commit 6c02075a5791db0e8db657087833c736c3e0b0c1
Author: Benjamin Collet <benjamin@collet.eu>
Date:   Tue Jan 7 09:50:19 2025 +0100

    Initial commit

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..69ad294
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+.python-version
+__pycache__/
diff --git a/models/config.py b/models/config.py
new file mode 100644
index 0000000..d9945b5
--- /dev/null
+++ b/models/config.py
@@ -0,0 +1,25 @@
+import os
+import sys
+import yaml
+
+class config:
+    @classmethod
+    def __init__(self):
+        for config_path in (
+            os.path.expanduser("~/.config/step-ca-inspector"),
+            os.environ.get("STEP_CA_INSPECTOR_CONF"),
+        ):
+            if config_path is None:
+                continue
+            try:
+                with open(os.path.join(config_path, "config.yaml")) as ymlfile:
+                    cfg = yaml.load(ymlfile, Loader=yaml.FullLoader)
+                    break
+            except IOError:
+                pass
+        else:
+            print("No configuration file found")
+            sys.exit(1)
+
+        for k, v in cfg.items():
+            setattr(self, k, v)
diff --git a/models/ssh_cert.py b/models/ssh_cert.py
new file mode 100644
index 0000000..3e021e0
--- /dev/null
+++ b/models/ssh_cert.py
@@ -0,0 +1,156 @@
+import base64
+import dateutil
+import json
+import mariadb
+from cryptography import x509
+from cryptography.hazmat.primitives import asymmetric, hashes, serialization
+from datetime import datetime, timedelta, timezone
+from models.config import config
+from struct import unpack
+
+
+config()
+conn = mariadb.connect(
+    host=config.database_host,
+    user=config.database_user,
+    password=config.database_password,
+    database=config.database_name,
+)
+
+
+class list:
+    certs = []
+
+    def __new__(cls, sort_key=None):
+        cur = conn.cursor()
+        cur.execute("SELECT nkey FROM ssh_certs")
+
+        for (cert_serial,) in cur:
+            cert_object = cert(cert_serial)
+            cls.certs.append(cert_object)
+
+        cur.close()
+
+        if sort_key is not None:
+            cls.certs.sort(key=lambda item: getattr(item, sort_key))
+
+        return cls.certs
+
+
+class cert:
+    def __init__(self, serial):
+        cert_raw = self.get_cert(serial)
+        size = unpack(">I", cert_raw[:4])[0] + 4
+        alg = cert_raw[4:size]
+
+        cert_pub_id = b" ".join([alg, base64.b64encode(cert_raw)])
+        cert_revoked = self.get_cert_revoked(serial)
+        self.load(cert_pub_id, cert_revoked, alg)
+
+    def load(self, cert_pub_id, cert_revoked, cert_alg):
+        cert = serialization.load_ssh_public_identity(cert_pub_id)
+        self.serial = cert.serial
+        self.alg = cert_alg
+        if cert.type == serialization.SSHCertificateType.USER:
+            self.type = "User"
+        self.key_id = cert.key_id
+        self.principals = cert.valid_principals
+        self.not_after = datetime.fromtimestamp(cert.valid_before).replace(
+            tzinfo=timezone(offset=timedelta()), microsecond=0
+        )
+        self.not_before = datetime.fromtimestamp(cert.valid_after).replace(
+            tzinfo=timezone(offset=timedelta()), microsecond=0
+        )
+        # TODO: Implement critical options parsing
+        # cert.critical_options
+        self.extensions = cert.extensions
+
+        (self.signing_key, self.signing_key_type, self.signing_key_hash) = (
+            self.get_public_key_params(cert.signature_key())
+        )
+
+        (self.public_key, self.public_key_type, self.public_key_hash) = (
+            self.get_public_key_params(cert.public_key())
+        )
+
+        self.public_identity = cert.public_bytes()
+
+        if cert_revoked is not None:
+            self.revoked_at = dateutil.parser.isoparse(
+                cert_revoked.get("RevokedAt")
+            ).replace(microsecond=0)
+        else:
+            self.revoked_at = None
+
+        now_with_tz = datetime.utcnow().replace(
+            tzinfo=timezone(offset=timedelta()), microsecond=0
+        )
+
+        if self.revoked_at is not None and self.revoked_at < now_with_tz:
+            self.status = status(status.REVOKED)
+        elif self.not_after < now_with_tz:
+            self.status = status(status.EXPIRED)
+        else:
+            self.status = status(status.VALID)
+
+    def get_cert(self, cert_serial):
+        cur = conn.cursor()
+        cur.execute("SELECT nvalue FROM ssh_certs WHERE nkey=?", (cert_serial,))
+        if cur.rowcount > 0:
+            (cert,) = cur.fetchone()
+        else:
+            cert = None
+
+        cur.close()
+        return cert
+
+    def get_cert_revoked(self, cert_serial):
+        cur = conn.cursor()
+        cur.execute("SELECT nvalue FROM revoked_ssh_certs WHERE nkey=?", (cert_serial,))
+        if cur.rowcount > 0:
+            (cert_revoked_raw,) = cur.fetchone()
+            cert_revoked = json.loads(cert_revoked_raw)
+        else:
+            cert_revoked = None
+
+        cur.close()
+        return cert_revoked
+
+    def get_public_key_params(self, public_key):
+        if isinstance(public_key, asymmetric.ec.EllipticCurvePublicKey):
+            key_type = "ECDSA"
+        elif isinstance(public_key, asymmetric.ed25519.Ed25519PublicKey):
+            key_type = "ED25519"
+        elif isinstance(public_key, asymmetric.rsa.RSAPublicKey):
+            key_type = "RSA"
+
+        key_str = public_key.public_bytes(
+            serialization.Encoding.OpenSSH, serialization.PublicFormat.OpenSSH
+        )
+
+        key_data = key_str.strip().split()[1]
+        digest = hashes.Hash(hashes.SHA256())
+        digest.update(base64.b64decode(key_data))
+        hash_sha256 = digest.finalize()
+        key_hash = base64.b64encode(hash_sha256)
+
+        return key_str, key_type, key_hash
+
+
+class status:
+    REVOKED = 1
+    EXPIRED = 2
+    VALID = 3
+
+    def __init__(self, status):
+        self.value = status
+
+    def __str__(self):
+        if self.value == self.EXPIRED:
+            return "Expired"
+        elif self.value == self.REVOKED:
+            return "Revoked"
+        elif self.value == self.VALID:
+            return "Valid"
+        else:
+            return "Undefined"
diff --git a/models/x509_cert.py b/models/x509_cert.py
new file mode 100644
index 0000000..f9872af
--- /dev/null
+++ b/models/x509_cert.py
@@ -0,0 +1,137 @@
+import binascii
+import dateutil
+import json
+import mariadb
+from cryptography import x509
+from cryptography.hazmat.primitives import hashes, serialization
+from datetime import datetime, timedelta, timezone
+from models.config import config
+
+
+config()
+conn = mariadb.connect(
+    host = config.database_host,
+    user = config.database_user,
+    password = config.database_password,
+    database = config.database_name
+)
+
+
+class list:
+    certs = []
+
+    def __new__(cls, sort_key=None):
+        cur = conn.cursor()
+        cur.execute("SELECT nkey FROM x509_certs")
+
+        for (cert_serial,) in cur:
+            cert_object = cert(cert_serial)
+            cls.certs.append(cert_object)
+
+        cur.close()
+
+        if sort_key is not None:
+            cls.certs.sort(key=lambda item: getattr(item, sort_key))
+
+        return cls.certs
+
+
+
+class cert:
+    def __init__(self, serial):
+        cert_der = self.get_cert(serial)
+        cert_data = self.get_cert_data(serial)
+        cert_revoked = self.get_cert_revoked(serial)
+        self.load(cert_der, cert_data, cert_revoked)
+
+
+    def load(self, cert_der, cert_data, cert_revoked):
+        cert = x509.load_der_x509_certificate(cert_der)
+
+        self.pem = cert.public_bytes(serialization.Encoding.PEM)
+        self.serial = str(cert.serial_number)
+        self.sha256 = binascii.b2a_hex(cert.fingerprint(hashes.SHA256()))
+        self.sha1 = binascii.b2a_hex(cert.fingerprint(hashes.SHA1()))
+        self.md5 = binascii.b2a_hex(cert.fingerprint(hashes.MD5()))
+        self.pub_alg = cert.public_key_algorithm_oid._name
+        self.sig_alg = cert.signature_algorithm_oid._name
+        self.issuer = cert.issuer.rfc4514_string()
+        self.subject = cert.subject.rfc4514_string({x509.NameOID.EMAIL_ADDRESS: "E"})
+        self.not_before = cert.not_valid_before_utc.replace(microsecond=0)
+        self.not_after = cert.not_valid_after_utc.replace(microsecond=0)
+        san_data = cert.extensions.get_extension_for_class(x509.SubjectAlternativeName)
+        self.san_names = san_data.value.get_values_for_type(x509.GeneralName)
+        self.provisioner = cert_data.get("provisioner", None)
+
+        if cert_revoked is not None:
+            self.revoked_at = dateutil.parser.isoparse(
+                cert_revoked.get("RevokedAt")
+            ).replace(microsecond=0)
+        else:
+            self.revoked_at = None
+
+        now_with_tz = datetime.utcnow().replace(
+            tzinfo=timezone(offset=timedelta()), microsecond=0
+        )
+
+        if self.revoked_at is not None and self.revoked_at < now_with_tz:
+            self.status = status(status.REVOKED)
+        elif self.not_after < now_with_tz:
+            self.status = status(status.EXPIRED)
+        else:
+            self.status = status(status.VALID)
+
+
+    def get_cert(self, cert_serial):
+        cur = conn.cursor()
+        cur.execute("SELECT nvalue FROM x509_certs WHERE nkey=?", (cert_serial,))
+        if cur.rowcount > 0:
+            (cert,) = cur.fetchone()
+        else:
+            cert = None
+
+        cur.close()
+        return cert
+
+
+    def get_cert_data(self, cert_serial):
+        cur = conn.cursor()
+        cur.execute("SELECT nvalue FROM x509_certs_data WHERE nkey=?", (cert_serial,))
+        (cert_data_raw,) = cur.fetchone()
+        cur.close()
+        cert_data = json.loads(cert_data_raw)
+        return cert_data
+
+
+    def get_cert_revoked(self, cert_serial):
+        cur = conn.cursor()
+        cur.execute(
+            "SELECT nvalue FROM revoked_x509_certs WHERE nkey=?", (cert_serial,)
+        )
+        if cur.rowcount > 0:
+            (cert_revoked_raw,) = cur.fetchone()
+            cert_revoked = json.loads(cert_revoked_raw)
+        else:
+            cert_revoked = None
+
+        cur.close()
+        return cert_revoked
+
+
+class status:
+    REVOKED = 1
+    EXPIRED = 2
+    VALID = 3
+
+    def __init__(self, status):
+        self.value = status
+
+    def __str__(self):
+        if self.value == self.EXPIRED:
+            return "Expired"
+        elif self.value == self.REVOKED:
+            return "Revoked"
+        elif self.value == self.VALID:
+            return "Valid"
+        else:
+            return "Undefined"
diff --git a/step-ca-inspector.py b/step-ca-inspector.py
new file mode 100755
index 0000000..e942a1f
--- /dev/null
+++ b/step-ca-inspector.py
@@ -0,0 +1,254 @@
+#!/usr/bin/env python3
+
+import argparse
+import os
+import sys
+import yaml
+from tabulate import tabulate
+from models import ssh_cert, x509_cert
+
+
+def list_ssh_certs(sort_key, revoked=False, expired=False):
+    cert_list = ssh_cert.list(sort_key=sort_key)
+    cert_tbl = []
+    for cert in cert_list:
+        if cert.status.value == ssh_cert.status.EXPIRED and not expired:
+            continue
+        if cert.status.value == ssh_cert.status.REVOKED and not revoked:
+            continue
+
+        cert_row = {}
+        cert_row["Serial"] = cert.serial
+        cert_row["Type"] = cert.type
+        cert_row["Key ID"] = cert.key_id
+        principals_count = len(cert.principals)
+        principals_list = [x.decode() for x in cert.principals]
+        if principals_count > 2:
+            principals = principals_list[:2] + [f"+{principals_count - 2} more"]
+        cert_row["Principals"] = "\n".join(principals)
+
+        validity = []
+        validity.append(f"Not before: {cert.not_before}")
+        validity.append(f"Not after:  {cert.not_after}")
+        if cert.revoked_at is not None:
+            validity.append(f"Revoked at: {cert.revoked_at}")
+
+        cert_row["Validity"] = "\n".join(validity)
+        cert_row["Status"] = cert.status
+
+        cert_tbl.append(cert_row)
+
+    print(tabulate(cert_tbl, headers="keys", tablefmt="fancy_grid"))
+
+
+def get_ssh_cert(serial):
+    cert = ssh_cert.cert(serial)
+    cert_tbl = []
+
+    cert_tbl.append(["Serial", cert.serial])
+    cert_tbl.append(["Certificate type", cert.type])
+    cert_tbl.append(["Certificate key type", cert.alg.decode()])
+    public_key = f"{cert.public_key_type} SHA256:{cert.public_key_hash.decode()}"
+    cert_tbl.append(["Public key", public_key])
+    signing_key = f"{cert.signing_key_type} SHA256:{cert.signing_key_hash.decode()}"
+    cert_tbl.append(["Signing key", signing_key])
+    cert_tbl.append(["Key ID", cert.key_id.decode()])
+    principals = [x.decode() for x in cert.principals]
+    cert_tbl.append(["Principals", "\n".join(principals)])
+    cert_tbl.append(["Not valid before", cert.not_before])
+    cert_tbl.append(["Not valid after", cert.not_after])
+    if cert.revoked_at is not None:
+        cert_tbl.append(["Revoked at", cert.revoked_at])
+    extensions = [x.decode() for x in cert.extensions]
+    cert_tbl.append(["Extensions", "\n".join(extensions)])
+    # cert_tbl.append(["Signing key", cert.signing_key.decode()])
+    cert_tbl.append(["Status", cert.status])
+
+    print(tabulate(cert_tbl, tablefmt="fancy_grid"))
+
+
+def dump_ssh_cert(serial):
+    cert = ssh_cert.cert(serial)
+    print(cert.public_identity.decode())
+
+
+def list_x509_certs(sort_key, revoked=False, expired=False):
+    cert_list = x509_cert.list(sort_key=sort_key)
+    cert_tbl = []
+    for cert in cert_list:
+        if cert.status.value == x509_cert.status.EXPIRED and not expired:
+            continue
+        if cert.status.value == x509_cert.status.REVOKED and not revoked:
+            continue
+
+        cert_row = {}
+        cert_row["Serial"] = cert.serial
+        cert_row["Subject"] = "%.30s" % cert.subject
+        cert_row["Subject Alt Names (SAN)"] = "\n".join(
+            ["%.30s" % x for x in cert.san_names]
+        )
+        cert_row["Provisioner"] = (
+            f"{cert.provisioner['name']} ({cert.provisioner['type']})"
+        )
+        validity = []
+        validity.append(f"Not before: {cert.not_before}")
+        validity.append(f"Not after:  {cert.not_after}")
+        if cert.revoked_at is not None:
+            validity.append(f"Revoked at: {cert.revoked_at}")
+
+        cert_row["Validity"] = "\n".join(validity)
+        cert_row["Status"] = cert.status
+
+        cert_tbl.append(cert_row)
+
+    print(tabulate(cert_tbl, headers="keys", tablefmt="fancy_grid"))
+
+
+def get_x509_cert(serial, show_pem=False):
+    cert = x509_cert.cert(serial)
+    cert_tbl = []
+
+    cert_tbl.append(["Serial", cert.serial])
+    cert_tbl.append(["Subject", cert.subject])
+    cert_tbl.append(["Subject Alt Names (SAN)", "\n".join(cert.san_names)])
+    cert_tbl.append(["Issuer", cert.issuer])
+    cert_tbl.append(["Not valid before", cert.not_before])
+    cert_tbl.append(["Not valid after", cert.not_after])
+    if cert.revoked_at is not None:
+        cert_tbl.append(["Revoked at", cert.revoked_at])
+    cert_tbl.append(
+        ["Provisioner", f"{cert.provisioner['name']} ({cert.provisioner['type']})"]
+    )
+    fingerprints = []
+    fingerprints.append(f"MD5:     {cert.md5.decode()}")
+    fingerprints.append(f"SHA-1:   {cert.sha1.decode()}")
+    fingerprints.append(f"SHA-256: {cert.sha256.decode()}")
+    cert_tbl.append(["Fingerprints", "\n".join(fingerprints)])
+    cert_tbl.append(["Public key algorithm", cert.pub_alg])
+    cert_tbl.append(["Signature algorithm", cert.sig_alg])
+    cert_tbl.append(["Status", cert.status])
+    # cert_tbl.append(["Extensions", cert.extensions])
+    if show_pem:
+        cert_tbl.append(["PEM", cert.pem.decode("utf-8")])
+
+    print(tabulate(cert_tbl, tablefmt="fancy_grid"))
+
+
+def dump_x509_cert(serial, cert_format="pem"):
+    cert = x509_cert.cert(serial)
+    print(cert.pem.decode("utf-8").rstrip())
+
+
+parser = argparse.ArgumentParser(description="Step CA Inspector")
+subparsers = parser.add_subparsers(
+    help="Object to inspect", dest="object", required=True
+)
+x509_parser = subparsers.add_parser("x509", help="x509 certificates")
+x509_subparsers = x509_parser.add_subparsers(
+    help="Action for perform", dest="action", required=True
+)
+x509_list_parser = x509_subparsers.add_parser("list", help="List x509 certificates")
+x509_list_parser.add_argument(
+    "--show-expired",
+    "-e",
+    action="store_true",
+    default=False,
+    help="Show expired certificates",
+)
+x509_list_parser.add_argument(
+    "--show-revoked",
+    "-r",
+    action="store_true",
+    default=False,
+    help="Show revoked certificates",
+)
+x509_list_parser.add_argument(
+    "--sort-by",
+    "-s",
+    type=str,
+    choices=["not_after", "not_before"],
+    default="not_after",
+    help="Sort certificates",
+)
+x509_details_parser = x509_subparsers.add_parser(
+    "details", help="Show an x509 certificate details"
+)
+x509_details_parser.add_argument(
+    "--serial", "-s", type=str, required=True, help="Certificate serial"
+)
+x509_details_parser.add_argument(
+    "--show-pem",
+    "-p",
+    action="store_true",
+    default=False,
+    help="Show PEM",
+)
+x509_dump_parser = x509_subparsers.add_parser("dump", help="Dump an x509 certificate")
+x509_dump_parser.add_argument(
+    "--serial", "-s", type=str, required=True, help="Certificate serial"
+)
+x509_dump_parser.add_argument(
+    "--format",
+    "-f",
+    type=str,
+    choices=["pem"],
+    required=False,
+    help="Certificate format",
+)
+ssh_parser = subparsers.add_parser("ssh", help="ssh certificates")
+ssh_subparsers = ssh_parser.add_subparsers(
+    help="Action for perform", dest="action", required=True
+)
+ssh_list_parser = ssh_subparsers.add_parser("list", help="List ssh certificates")
+ssh_list_parser.add_argument(
+    "--show-expired",
+    "-e",
+    action="store_true",
+    default=False,
+    help="Show expired certificates",
+)
+ssh_list_parser.add_argument(
+    "--show-revoked",
+    "-r",
+    action="store_true",
+    default=False,
+    help="Show revoked certificates",
+)
+ssh_list_parser.add_argument(
+    "--sort-by",
+    "-s",
+    type=str,
+    choices=["not_after", "not_before"],
+    default="not_after",
+    help="Sort certificates (default: not_after)",
+)
+ssh_details_parser = ssh_subparsers.add_parser(
+    "details", help="Show an ssh certificate details"
+)
+ssh_details_parser.add_argument(
+    "--serial", "-s", type=str, required=True, help="Certificate serial"
+)
+ssh_dump_parser = ssh_subparsers.add_parser("dump", help="Dump an ssh certificate")
+ssh_dump_parser.add_argument(
+    "--serial", "-s", type=str, required=True, help="Certificate serial"
+)
+args = parser.parse_args()
+
+if args.object == "x509":
+    if args.action == "list":
+        list_x509_certs(
+            revoked=args.show_revoked, expired=args.show_expired, sort_key=args.sort_by
+        )
+    elif args.action == "details":
+        get_x509_cert(serial=args.serial, show_pem=args.show_pem)
+    elif args.action == "dump":
+        dump_x509_cert(serial=args.serial)
+elif args.object == "ssh":
+    if args.action == "list":
+        list_ssh_certs(
+            revoked=args.show_revoked, expired=args.show_expired, sort_key=args.sort_by
+        )
+    elif args.action == "details":
+        get_ssh_cert(serial=args.serial)
+    elif args.action == "dump":
+        dump_ssh_cert(serial=args.serial)