import base64
import dateutil
import json
import mariadb
from cryptography import x509
from cryptography.hazmat.primitives import asymmetric, hashes, serialization
from datetime import datetime, timedelta, timezone
from models.config import config
from struct import unpack


# Call the project's config object before its attributes are read below
# (presumably this loads the settings; confirm in models.config).
config()
# Module-level MariaDB connection, opened when this module is imported and
# used by the queries further down in this file.
conn = mariadb.connect(
    host=config.database_host,
    user=config.database_user,
    password=config.database_password,
    database=config.database_name,
)


class list:
    """Loader for every SSH certificate stored in the database.

    Calling ``list()`` never yields an instance of this class: ``__new__``
    returns a plain Python list of ``cert`` objects instead.

    NOTE: the class name shadows the ``list`` builtin inside this module, so
    code in this file must use ``[]`` literals rather than ``list(...)``.
    """

    # Result of the most recent call. It is rebound to a fresh list on every
    # call instead of being appended to, so repeated calls do not accumulate
    # duplicate entries.
    certs = []

    def __new__(cls, sort_key=None):
        """Query all certificates and return them as ``cert`` objects.

        Args:
            sort_key: Optional name of a ``cert`` attribute to sort the
                result by (ascending). ``None`` keeps the query order.

        Returns:
            A new list of ``cert`` objects; the same list is also stored in
            ``cls.certs``.
        """
        cur = conn.cursor()
        try:
            cur.execute(
                """SELECT ssh_certs.nvalue AS cert,
                      revoked_ssh_certs.nvalue AS revoked
               FROM ssh_certs
               LEFT JOIN revoked_ssh_certs USING(nkey)"""
            )
            certs = [cert(row) for row in cur]
        finally:
            # Release the cursor even when parsing one of the rows fails.
            cur.close()

        if sort_key is not None:
            certs.sort(key=lambda item: getattr(item, sort_key))

        cls.certs = certs
        return certs


class cert:
    """Parsed view of one SSH certificate row and its revocation record.

    Attributes set by ``load``: ``serial``, ``alg``, ``type``, ``key_id``,
    ``principals``, ``not_before``, ``not_after``, ``extensions``,
    ``signing_key*``, ``public_key*``, ``public_identity``, ``revoked_at``
    and ``status``.
    """

    def __init__(self, cert):
        """Build the object from a ``(cert_raw, cert_revoked_raw)`` pair.

        Args:
            cert: Two-item sequence holding the raw certificate blob and the
                revocation record as JSON text (or ``None`` when the
                certificate has no revocation record).

        Raises:
            TypeError: If ``cert`` is ``None`` or otherwise not unpackable.
        """
        (cert_raw, cert_revoked_raw) = cert

        # The blob begins with a big-endian uint32 length followed by that
        # many bytes, which are used as the algorithm name.
        size = unpack(">I", cert_raw[:4])[0] + 4
        alg = cert_raw[4:size]

        # Rebuild the "<algorithm> <base64 blob>" text form accepted by
        # load_ssh_public_identity().
        cert_pub_id = b" ".join([alg, base64.b64encode(cert_raw)])

        if cert_revoked_raw is not None:
            cert_revoked = json.loads(cert_revoked_raw)
        else:
            cert_revoked = None

        self.load(cert_pub_id, cert_revoked, alg)

    @classmethod
    def from_serial(cls, serial):
        """Alternate constructor: look the certificate up by ``serial``.

        Raises:
            TypeError: If no row matches, because ``get_cert`` then returns
                ``None`` and the constructor cannot unpack it.
        """
        return cls(cert=cls.get_cert(cls, serial))

    @staticmethod
    def _utc_from_timestamp(timestamp):
        """Convert a Unix timestamp to an aware UTC datetime (whole seconds).

        Timestamps too large for ``datetime`` (for example the all-ones
        value OpenSSH uses for certificates that never expire) are clamped
        to the largest representable datetime instead of raising.
        """
        try:
            moment = datetime.fromtimestamp(timestamp, tz=timezone.utc)
        except (OverflowError, OSError, ValueError):
            moment = datetime.max.replace(tzinfo=timezone.utc)
        return moment.replace(microsecond=0)

    def load(self, cert_pub_id, cert_revoked, cert_alg):
        """Parse the certificate and populate this object's attributes.

        Args:
            cert_pub_id: Certificate in OpenSSH text form (bytes).
            cert_revoked: Decoded revocation record (mapping with a
                ``"RevokedAt"`` ISO-8601 entry) or ``None``.
            cert_alg: Algorithm name, stored unchanged as ``self.alg``.
        """
        cert = serialization.load_ssh_public_identity(cert_pub_id)
        self.serial = cert.serial
        self.alg = cert_alg
        # Always set ``type`` so host certificates do not leave the
        # attribute missing.
        if cert.type == serialization.SSHCertificateType.USER:
            self.type = "User"
        elif cert.type == serialization.SSHCertificateType.HOST:
            self.type = "Host"
        else:
            self.type = "Unknown"
        self.key_id = cert.key_id
        self.principals = cert.valid_principals
        # Validity bounds are Unix timestamps; interpret them as UTC rather
        # than as local wall-clock time relabelled as UTC.
        self.not_after = self._utc_from_timestamp(cert.valid_before)
        self.not_before = self._utc_from_timestamp(cert.valid_after)
        # TODO: Implement critical options parsing
        # cert.critical_options
        self.extensions = cert.extensions

        (self.signing_key, self.signing_key_type, self.signing_key_hash) = (
            self.get_public_key_params(cert.signature_key())
        )

        (self.public_key, self.public_key_type, self.public_key_hash) = (
            self.get_public_key_params(cert.public_key())
        )

        self.public_identity = cert.public_bytes()

        if cert_revoked is not None:
            # A bare ``import dateutil`` does not reliably load the parser
            # submodule, so import the function explicitly.
            from dateutil.parser import isoparse

            self.revoked_at = isoparse(cert_revoked.get("RevokedAt")).replace(
                microsecond=0
            )
        else:
            self.revoked_at = None

        now_with_tz = datetime.now(timezone.utc).replace(microsecond=0)

        # Revocation takes precedence over expiry.
        if self.revoked_at is not None and self.revoked_at < now_with_tz:
            self.status = status(status.REVOKED)
        elif self.not_after < now_with_tz:
            self.status = status(status.EXPIRED)
        else:
            self.status = status(status.VALID)

    def get_cert(self, cert_serial):
        """Fetch the ``(cert, revoked)`` row whose ``nkey`` is ``cert_serial``.

        ``self`` is not used; ``from_serial`` calls this with the class in
        its place.

        Returns:
            The matching row, or ``None`` when there is no such row.
        """
        cur = conn.cursor()
        try:
            cur.execute(
                """SELECT ssh_certs.nvalue AS cert,
                      revoked_ssh_certs.nvalue AS revoked
               FROM ssh_certs
               LEFT JOIN revoked_ssh_certs USING(nkey)
               WHERE nkey=?""",
                (cert_serial,),
            )
            # fetchone() already yields None for an empty result, so there
            # is no need to consult rowcount, which may not be known before
            # rows are fetched.
            return cur.fetchone()
        finally:
            cur.close()

    def get_public_key_params(self, public_key):
        """Describe a public key.

        Returns:
            A ``(key_str, key_type, key_hash)`` tuple: the key in OpenSSH
            text form (bytes), a type label (``"ECDSA"``, ``"ED25519"``,
            ``"RSA"`` or ``"Unknown"``) and the base64-encoded SHA-256
            digest of the decoded key blob (bytes).
        """
        if isinstance(public_key, asymmetric.ec.EllipticCurvePublicKey):
            key_type = "ECDSA"
        elif isinstance(public_key, asymmetric.ed25519.Ed25519PublicKey):
            key_type = "ED25519"
        elif isinstance(public_key, asymmetric.rsa.RSAPublicKey):
            key_type = "RSA"
        else:
            # Avoid an unbound local for any other key class.
            key_type = "Unknown"

        key_str = public_key.public_bytes(
            serialization.Encoding.OpenSSH, serialization.PublicFormat.OpenSSH
        )

        # OpenSSH text form is "<type> <base64 blob>"; hash the decoded blob.
        key_data = key_str.strip().split()[1]
        digest = hashes.Hash(hashes.SHA256())
        digest.update(base64.b64decode(key_data))
        hash_sha256 = digest.finalize()
        key_hash = base64.b64encode(hash_sha256)

        return key_str, key_type, key_hash


class status:
    """Small wrapper around a certificate status code with a display name."""

    REVOKED = 1
    EXPIRED = 2
    VALID = 3

    def __init__(self, status):
        """Store the given status code in ``self.value``."""
        self.value = status

    def __str__(self):
        """Return the label for the stored code, or "Undefined" if unknown."""
        # Checked in the same order as before: expired, revoked, valid.
        labels = (
            (self.EXPIRED, "Expired"),
            (self.REVOKED, "Revoked"),
            (self.VALID, "Valid"),
        )
        for code, label in labels:
            if self.value == code:
                return label
        return "Undefined"