diff --git a/models/ssh_cert.py b/models/ssh_cert.py index 3e021e0..0ade70c 100644 --- a/models/ssh_cert.py +++ b/models/ssh_cert.py @@ -23,10 +23,15 @@ class list: def __new__(cls, sort_key=None): cur = conn.cursor() - cur.execute("SELECT nkey FROM ssh_certs") + cur.execute( + """SELECT ssh_certs.nvalue AS cert, + revoked_ssh_certs.nvalue AS revoked + FROM ssh_certs + LEFT JOIN revoked_ssh_certs USING(nkey)""" + ) - for (cert_serial,) in cur: - cert_object = cert(cert_serial) + for result in cur: + cert_object = cert(result) cls.certs.append(cert_object) cur.close() @@ -38,15 +43,25 @@ class list: class cert: - def __init__(self, serial): - cert_raw = self.get_cert(serial) + def __init__(self, cert): + (cert_raw, cert_revoked_raw) = cert + size = unpack(">I", cert_raw[:4])[0] + 4 alg = cert_raw[4:size] cert_pub_id = b" ".join([alg, base64.b64encode(cert_raw)]) - cert_revoked = self.get_cert_revoked(serial) + + if cert_revoked_raw is not None: + cert_revoked = json.loads(cert_revoked_raw) + else: + cert_revoked = None + self.load(cert_pub_id, cert_revoked, alg) + @classmethod + def from_serial(cls, serial): + return cls(cert=cls.get_cert(cls, serial)) + def load(self, cert_pub_id, cert_revoked, cert_alg): cert = serialization.load_ssh_public_identity(cert_pub_id) self.serial = cert.serial @@ -95,27 +110,22 @@ class cert: def get_cert(self, cert_serial): cur = conn.cursor() - cur.execute("SELECT nvalue FROM ssh_certs WHERE nkey=?", (cert_serial,)) + cur.execute( + """SELECT ssh_certs.nvalue AS cert, + revoked_ssh_certs.nvalue AS revoked + FROM ssh_certs + LEFT JOIN revoked_ssh_certs USING(nkey) + WHERE nkey=?""", + (cert_serial,), + ) if cur.rowcount > 0: - (cert,) = cur.fetchone() + cert = cur.fetchone() else: cert = None cur.close() return cert - def get_cert_revoked(self, cert_serial): - cur = conn.cursor() - cur.execute("SELECT nvalue FROM revoked_ssh_certs WHERE nkey=?", (cert_serial,)) - if cur.rowcount > 0: - (cert_revoked_raw,) = cur.fetchone() - cert_revoked = json.loads(cert_revoked_raw) - else: - cert_revoked = None - - cur.close() - return cert_revoked - def get_public_key_params(self, public_key): if isinstance(public_key, asymmetric.ec.EllipticCurvePublicKey): key_type = "ECDSA" diff --git a/models/x509_cert.py b/models/x509_cert.py index 19ce241..d5c1d7a 100644 --- a/models/x509_cert.py +++ b/models/x509_cert.py @@ -10,10 +10,10 @@ from models.config import config config() conn = mariadb.connect( - host = config.database_host, - user = config.database_user, - password = config.database_password, - database = config.database_name + host=config.database_host, + user=config.database_user, + password=config.database_password, + database=config.database_name, ) @@ -22,10 +22,17 @@ class list: def __new__(cls, sort_key=None): cur = conn.cursor() - cur.execute("SELECT nkey FROM x509_certs") + cur.execute( + """SELECT x509_certs.nvalue AS cert, + x509_certs_data.nvalue AS data, + revoked_x509_certs.nvalue AS revoked + FROM x509_certs + INNER JOIN x509_certs_data USING(nkey) + LEFT JOIN revoked_x509_certs USING(nkey)""" + ) - for (cert_serial,) in cur: - cert_object = cert(cert_serial) + for result in cur: + cert_object = cert(result) cls.certs.append(cert_object) cur.close() @@ -36,14 +43,21 @@ class list: return cls.certs - class cert: - def __init__(self, serial): - cert_der = self.get_cert(serial) - cert_data = self.get_cert_data(serial) - cert_revoked = self.get_cert_revoked(serial) + def __init__(self, cert): + (cert_der, cert_data_raw, cert_revoked_raw) = cert + + cert_data = json.loads(cert_data_raw) + if cert_revoked_raw is not None: + cert_revoked = json.loads(cert_revoked_raw) + else: + cert_revoked = None + self.load(cert_der, cert_data, cert_revoked) + @classmethod + def from_serial(cls, serial): + return cls(cert=cls.get_cert(cls, serial)) def load(self, cert_der, cert_data, cert_revoked): cert = x509.load_der_x509_certificate(cert_der) @@ -60,7 +74,9 @@ class cert: self.not_before = cert.not_valid_before_utc.replace(microsecond=0) self.not_after = cert.not_valid_after_utc.replace(microsecond=0) try: - san_data = cert.extensions.get_extension_for_class(x509.SubjectAlternativeName) + san_data = cert.extensions.get_extension_for_class( + x509.SubjectAlternativeName + ) self.san_names = san_data.value.get_values_for_type(x509.GeneralName) except x509.extensions.ExtensionNotFound: self.san_names = [] @@ -85,12 +101,21 @@ class cert: else: self.status = status(status.VALID) - def get_cert(self, cert_serial): cur = conn.cursor() - cur.execute("SELECT nvalue FROM x509_certs WHERE nkey=?", (cert_serial,)) + cur.execute( + """SELECT x509_certs.nvalue AS cert, + x509_certs_data.nvalue AS data, + revoked_x509_certs.nvalue AS revoked + FROM x509_certs + INNER JOIN x509_certs_data USING(nkey) + LEFT JOIN revoked_x509_certs USING(nkey) + WHERE nkey=?""", + (cert_serial,), + ) + if cur.rowcount > 0: - (cert,) = cur.fetchone() + cert = cur.fetchone() else: cert = None @@ -98,30 +123,6 @@ class cert: return cert - def get_cert_data(self, cert_serial): - cur = conn.cursor() - cur.execute("SELECT nvalue FROM x509_certs_data WHERE nkey=?", (cert_serial,)) - (cert_data_raw,) = cur.fetchone() - cur.close() - cert_data = json.loads(cert_data_raw) - return cert_data - - - def get_cert_revoked(self, cert_serial): - cur = conn.cursor() - cur.execute( - "SELECT nvalue FROM revoked_x509_certs WHERE nkey=?", (cert_serial,) - ) - if cur.rowcount > 0: - (cert_revoked_raw,) = cur.fetchone() - cert_revoked = json.loads(cert_revoked_raw) - else: - cert_revoked = None - - cur.close() - return cert_revoked - - class status: REVOKED = 1 EXPIRED = 2 diff --git a/step-ca-inspector.py b/step-ca-inspector.py index e942a1f..17c6c75 100755 --- a/step-ca-inspector.py +++ b/step-ca-inspector.py @@ -42,7 +42,7 @@ def list_ssh_certs(sort_key, revoked=False, expired=False): def get_ssh_cert(serial): - cert = ssh_cert.cert(serial) + cert = ssh_cert.cert.from_serial(serial) cert_tbl = [] cert_tbl.append(["Serial", cert.serial]) @@ -68,7 +68,7 @@ def get_ssh_cert(serial): def dump_ssh_cert(serial): - cert = ssh_cert.cert(serial) + cert = ssh_cert.cert.from_serial(serial) print(cert.public_identity.decode()) @@ -105,7 +105,7 @@ def list_x509_certs(sort_key, revoked=False, expired=False): def get_x509_cert(serial, show_pem=False): - cert = x509_cert.cert(serial) + cert = x509_cert.cert.from_serial(serial) cert_tbl = [] cert_tbl.append(["Serial", cert.serial]) @@ -135,7 +135,7 @@ def get_x509_cert(serial, show_pem=False): def dump_x509_cert(serial, cert_format="pem"): - cert = x509_cert.cert(serial) + cert = x509_cert.cert.from_serial(serial) print(cert.pem.decode("utf-8").rstrip())