From 6c02075a5791db0e8db657087833c736c3e0b0c1 Mon Sep 17 00:00:00 2001 From: Benjamin Collet Date: Tue, 7 Jan 2025 09:50:19 +0100 Subject: [PATCH] Initial commit --- .gitignore | 2 + models/config.py | 25 +++++ models/ssh_cert.py | 156 ++++++++++++++++++++++++++ models/x509_cert.py | 137 +++++++++++++++++++++++ step-ca-inspector.py | 254 +++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 574 insertions(+) create mode 100644 .gitignore create mode 100644 models/config.py create mode 100644 models/ssh_cert.py create mode 100644 models/x509_cert.py create mode 100755 step-ca-inspector.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..69ad294 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +.python-version +__pycache__/ diff --git a/models/config.py b/models/config.py new file mode 100644 index 0000000..d9945b5 --- /dev/null +++ b/models/config.py @@ -0,0 +1,25 @@ +import os +import sys +import yaml + +class config: + @classmethod + def __init__(self): + for config_path in ( + os.path.expanduser("~/.config/step-ca-inspector"), + os.environ.get("STEP_CA_INSPECTOR_CONF"), + ): + if config_path is None: + continue + try: + with open(os.path.join(config_path, "config.yaml")) as ymlfile: + cfg = yaml.load(ymlfile, Loader=yaml.FullLoader) + break + except IOError: + pass + else: + print("No configuration file found") + sys.exit(1) + + for k, v in cfg.items(): + setattr(self, k, v) diff --git a/models/ssh_cert.py b/models/ssh_cert.py new file mode 100644 index 0000000..3e021e0 --- /dev/null +++ b/models/ssh_cert.py @@ -0,0 +1,156 @@ +import base64 +import dateutil +import json +import mariadb +from cryptography import x509 +from cryptography.hazmat.primitives import asymmetric, hashes, serialization +from datetime import datetime, timedelta, timezone +from models.config import config +from struct import unpack + + +config() +conn = mariadb.connect( + host=config.database_host, + user=config.database_user, + password=config.database_password, + database=config.database_name, +) + + +class list: + certs = [] + + def __new__(cls, sort_key=None): + cur = conn.cursor() + cur.execute("SELECT nkey FROM ssh_certs") + + for (cert_serial,) in cur: + cert_object = cert(cert_serial) + cls.certs.append(cert_object) + + cur.close() + + if sort_key is not None: + cls.certs.sort(key=lambda item: getattr(item, sort_key)) + + return cls.certs + + +class cert: + def __init__(self, serial): + cert_raw = self.get_cert(serial) + size = unpack(">I", cert_raw[:4])[0] + 4 + alg = cert_raw[4:size] + + cert_pub_id = b" ".join([alg, base64.b64encode(cert_raw)]) + cert_revoked = self.get_cert_revoked(serial) + self.load(cert_pub_id, cert_revoked, alg) + + def load(self, cert_pub_id, cert_revoked, cert_alg): + cert = serialization.load_ssh_public_identity(cert_pub_id) + self.serial = cert.serial + self.alg = cert_alg + if cert.type == serialization.SSHCertificateType.USER: + self.type = "User" + self.key_id = cert.key_id + self.principals = cert.valid_principals + self.not_after = datetime.fromtimestamp(cert.valid_before).replace( + tzinfo=timezone(offset=timedelta()), microsecond=0 + ) + self.not_before = datetime.fromtimestamp(cert.valid_after).replace( + tzinfo=timezone(offset=timedelta()), microsecond=0 + ) + # TODO: Implement critical options parsing + # cert.critical_options + self.extensions = cert.extensions + + (self.signing_key, self.signing_key_type, self.signing_key_hash) = ( + self.get_public_key_params(cert.signature_key()) + ) + + (self.public_key, self.public_key_type, self.public_key_hash) = ( + self.get_public_key_params(cert.public_key()) + ) + + self.public_identity = cert.public_bytes() + + if cert_revoked is not None: + self.revoked_at = dateutil.parser.isoparse( + cert_revoked.get("RevokedAt") + ).replace(microsecond=0) + else: + self.revoked_at = None + + now_with_tz = datetime.utcnow().replace( + tzinfo=timezone(offset=timedelta()), microsecond=0 + ) + + if self.revoked_at is not None and self.revoked_at < now_with_tz: + self.status = status(status.REVOKED) + elif self.not_after < now_with_tz: + self.status = status(status.EXPIRED) + else: + self.status = status(status.VALID) + + def get_cert(self, cert_serial): + cur = conn.cursor() + cur.execute("SELECT nvalue FROM ssh_certs WHERE nkey=?", (cert_serial,)) + if cur.rowcount > 0: + (cert,) = cur.fetchone() + else: + cert = None + + cur.close() + return cert + + def get_cert_revoked(self, cert_serial): + cur = conn.cursor() + cur.execute("SELECT nvalue FROM revoked_ssh_certs WHERE nkey=?", (cert_serial,)) + if cur.rowcount > 0: + (cert_revoked_raw,) = cur.fetchone() + cert_revoked = json.loads(cert_revoked_raw) + else: + cert_revoked = None + + cur.close() + return cert_revoked + + def get_public_key_params(self, public_key): + if isinstance(public_key, asymmetric.ec.EllipticCurvePublicKey): + key_type = "ECDSA" + elif isinstance(public_key, asymmetric.ed25519.Ed25519PublicKey): + key_type = "ED25519" + elif isinstance(public_key, asymmetric.rsa.RSAPublicKey): + key_type = "RSA" + + key_str = public_key.public_bytes( + serialization.Encoding.OpenSSH, serialization.PublicFormat.OpenSSH + ) + + key_data = key_str.strip().split()[1] + digest = hashes.Hash(hashes.SHA256()) + digest.update(base64.b64decode(key_data)) + hash_sha256 = digest.finalize() + key_hash = base64.b64encode(hash_sha256) + + return key_str, key_type, key_hash + + +class status: + REVOKED = 1 + EXPIRED = 2 + VALID = 3 + + def __init__(self, status): + self.value = status + + def __str__(self): + if self.value == self.EXPIRED: + return "Expired" + elif self.value == self.REVOKED: + return "Revoked" + elif self.value == self.VALID: + return "Valid" + else: + return "Undefined" diff --git a/models/x509_cert.py b/models/x509_cert.py new file mode 100644 index 0000000..f9872af --- /dev/null +++ b/models/x509_cert.py @@ -0,0 +1,137 @@ +import binascii +import dateutil +import json +import mariadb +from cryptography import x509 +from cryptography.hazmat.primitives import hashes, serialization +from datetime import datetime, timedelta, timezone +from models.config import config + + +config() +conn = mariadb.connect( + host = config.database_host, + user = config.database_user, + password = config.database_password, + database = config.database_name +) + + +class list: + certs = [] + + def __new__(cls, sort_key=None): + cur = conn.cursor() + cur.execute("SELECT nkey FROM x509_certs") + + for (cert_serial,) in cur: + cert_object = cert(cert_serial) + cls.certs.append(cert_object) + + cur.close() + + if sort_key is not None: + cls.certs.sort(key=lambda item: getattr(item, sort_key)) + + return cls.certs + + + +class cert: + def __init__(self, serial): + cert_der = self.get_cert(serial) + cert_data = self.get_cert_data(serial) + cert_revoked = self.get_cert_revoked(serial) + self.load(cert_der, cert_data, cert_revoked) + + + def load(self, cert_der, cert_data, cert_revoked): + cert = x509.load_der_x509_certificate(cert_der) + + self.pem = cert.public_bytes(serialization.Encoding.PEM) + self.serial = str(cert.serial_number) + self.sha256 = binascii.b2a_hex(cert.fingerprint(hashes.SHA256())) + self.sha1 = binascii.b2a_hex(cert.fingerprint(hashes.SHA1())) + self.md5 = binascii.b2a_hex(cert.fingerprint(hashes.MD5())) + self.pub_alg = cert.public_key_algorithm_oid._name + self.sig_alg = cert.signature_algorithm_oid._name + self.issuer = cert.issuer.rfc4514_string() + self.subject = cert.subject.rfc4514_string({x509.NameOID.EMAIL_ADDRESS: "E"}) + self.not_before = cert.not_valid_before_utc.replace(microsecond=0) + self.not_after = cert.not_valid_after_utc.replace(microsecond=0) + san_data = cert.extensions.get_extension_for_class(x509.SubjectAlternativeName) + self.san_names = san_data.value.get_values_for_type(x509.GeneralName) + self.provisioner = cert_data.get("provisioner", None) + + if cert_revoked is not None: + self.revoked_at = dateutil.parser.isoparse( + cert_revoked.get("RevokedAt") + ).replace(microsecond=0) + else: + self.revoked_at = None + + now_with_tz = datetime.utcnow().replace( + tzinfo=timezone(offset=timedelta()), microsecond=0 + ) + + if self.revoked_at is not None and self.revoked_at < now_with_tz: + self.status = status(status.REVOKED) + elif self.not_after < now_with_tz: + self.status = status(status.EXPIRED) + else: + self.status = status(status.VALID) + + + def get_cert(self, cert_serial): + cur = conn.cursor() + cur.execute("SELECT nvalue FROM x509_certs WHERE nkey=?", (cert_serial,)) + if cur.rowcount > 0: + (cert,) = cur.fetchone() + else: + cert = None + + cur.close() + return cert + + + def get_cert_data(self, cert_serial): + cur = conn.cursor() + cur.execute("SELECT nvalue FROM x509_certs_data WHERE nkey=?", (cert_serial,)) + (cert_data_raw,) = cur.fetchone() + cur.close() + cert_data = json.loads(cert_data_raw) + return cert_data + + + def get_cert_revoked(self, cert_serial): + cur = conn.cursor() + cur.execute( + "SELECT nvalue FROM revoked_x509_certs WHERE nkey=?", (cert_serial,) + ) + if cur.rowcount > 0: + (cert_revoked_raw,) = cur.fetchone() + cert_revoked = json.loads(cert_revoked_raw) + else: + cert_revoked = None + + cur.close() + return cert_revoked + + +class status: + REVOKED = 1 + EXPIRED = 2 + VALID = 3 + + def __init__(self, status): + self.value = status + + def __str__(self): + if self.value == self.EXPIRED: + return "Expired" + elif self.value == self.REVOKED: + return "Revoked" + elif self.value == self.VALID: + return "Valid" + else: + return "Undefined" diff --git a/step-ca-inspector.py b/step-ca-inspector.py new file mode 100755 index 0000000..e942a1f --- /dev/null +++ b/step-ca-inspector.py @@ -0,0 +1,254 @@ +#!/usr/bin/env python3 + +import argparse +import os +import sys +import yaml +from tabulate import tabulate +from models import ssh_cert, x509_cert + + +def list_ssh_certs(sort_key, revoked=False, expired=False): + cert_list = ssh_cert.list(sort_key=sort_key) + cert_tbl = [] + for cert in cert_list: + if cert.status.value == ssh_cert.status.EXPIRED and not expired: + continue + if cert.status.value == ssh_cert.status.REVOKED and not revoked: + continue + + cert_row = {} + cert_row["Serial"] = cert.serial + cert_row["Type"] = cert.type + cert_row["Key ID"] = cert.key_id + principals_count = len(cert.principals) + principals_list = [x.decode() for x in cert.principals] + if principals_count > 2: + principals = principals_list[:2] + [f"+{principals_count - 2} more"] + cert_row["Principals"] = "\n".join(principals) + + validity = [] + validity.append(f"Not before: {cert.not_before}") + validity.append(f"Not after: {cert.not_after}") + if cert.revoked_at is not None: + validity.append(f"Revoked at: {cert.revoked_at}") + + cert_row["Validity"] = "\n".join(validity) + cert_row["Status"] = cert.status + + cert_tbl.append(cert_row) + + print(tabulate(cert_tbl, headers="keys", tablefmt="fancy_grid")) + + +def get_ssh_cert(serial): + cert = ssh_cert.cert(serial) + cert_tbl = [] + + cert_tbl.append(["Serial", cert.serial]) + cert_tbl.append(["Certificate type", cert.type]) + cert_tbl.append(["Certificate key type", cert.alg.decode()]) + public_key = f"{cert.public_key_type} SHA256:{cert.public_key_hash.decode()}" + cert_tbl.append(["Public key", public_key]) + signing_key = f"{cert.signing_key_type} SHA256:{cert.signing_key_hash.decode()}" + cert_tbl.append(["Signing key", signing_key]) + cert_tbl.append(["Key ID", cert.key_id.decode()]) + principals = [x.decode() for x in cert.principals] + cert_tbl.append(["Principals", "\n".join(principals)]) + cert_tbl.append(["Not valid before", cert.not_before]) + cert_tbl.append(["Not valid after", cert.not_after]) + if cert.revoked_at is not None: + cert_tbl.append(["Revoked at", cert.revoked_at]) + extensions = [x.decode() for x in cert.extensions] + cert_tbl.append(["Extensions", "\n".join(extensions)]) + # cert_tbl.append(["Signing key", cert.signing_key.decode()]) + cert_tbl.append(["Status", cert.status]) + + print(tabulate(cert_tbl, tablefmt="fancy_grid")) + + +def dump_ssh_cert(serial): + cert = ssh_cert.cert(serial) + print(cert.public_identity.decode()) + + +def list_x509_certs(sort_key, revoked=False, expired=False): + cert_list = x509_cert.list(sort_key=sort_key) + cert_tbl = [] + for cert in cert_list: + if cert.status.value == x509_cert.status.EXPIRED and not expired: + continue + if cert.status.value == x509_cert.status.REVOKED and not revoked: + continue + + cert_row = {} + cert_row["Serial"] = cert.serial + cert_row["Subject"] = "%.30s" % cert.subject + cert_row["Subject Alt Names (SAN)"] = "\n".join( + ["%.30s" % x for x in cert.san_names] + ) + cert_row["Provisioner"] = ( + f"{cert.provisioner['name']} ({cert.provisioner['type']})" + ) + validity = [] + validity.append(f"Not before: {cert.not_before}") + validity.append(f"Not after: {cert.not_after}") + if cert.revoked_at is not None: + validity.append(f"Revoked at: {cert.revoked_at}") + + cert_row["Validity"] = "\n".join(validity) + cert_row["Status"] = cert.status + + cert_tbl.append(cert_row) + + print(tabulate(cert_tbl, headers="keys", tablefmt="fancy_grid")) + + +def get_x509_cert(serial, show_pem=False): + cert = x509_cert.cert(serial) + cert_tbl = [] + + cert_tbl.append(["Serial", cert.serial]) + cert_tbl.append(["Subject", cert.subject]) + cert_tbl.append(["Subject Alt Names (SAN)", "\n".join(cert.san_names)]) + cert_tbl.append(["Issuer", cert.issuer]) + cert_tbl.append(["Not valid before", cert.not_before]) + cert_tbl.append(["Not valid after", cert.not_after]) + if cert.revoked_at is not None: + cert_tbl.append(["Revoked at", cert.revoked_at]) + cert_tbl.append( + ["Provisioner", f"{cert.provisioner['name']} ({cert.provisioner['type']})"] + ) + fingerprints = [] + fingerprints.append(f"MD5: {cert.md5.decode()}") + fingerprints.append(f"SHA-1: {cert.sha1.decode()}") + fingerprints.append(f"SHA-256: {cert.sha256.decode()}") + cert_tbl.append(["Fingerprints", "\n".join(fingerprints)]) + cert_tbl.append(["Public key algorithm", cert.pub_alg]) + cert_tbl.append(["Signature algorithm", cert.sig_alg]) + cert_tbl.append(["Status", cert.status]) + # cert_tbl.append(["Extensions", cert.extensions]) + if show_pem: + cert_tbl.append(["PEM", cert.pem.decode("utf-8")]) + + print(tabulate(cert_tbl, tablefmt="fancy_grid")) + + +def dump_x509_cert(serial, cert_format="pem"): + cert = x509_cert.cert(serial) + print(cert.pem.decode("utf-8").rstrip()) + + +parser = argparse.ArgumentParser(description="Step CA Inspector") +subparsers = parser.add_subparsers( + help="Object to inspect", dest="object", required=True +) +x509_parser = subparsers.add_parser("x509", help="x509 certificates") +x509_subparsers = x509_parser.add_subparsers( + help="Action for perform", dest="action", required=True +) +x509_list_parser = x509_subparsers.add_parser("list", help="List x509 certificates") +x509_list_parser.add_argument( + "--show-expired", + "-e", + action="store_true", + default=False, + help="Show expired certificates", +) +x509_list_parser.add_argument( + "--show-revoked", + "-r", + action="store_true", + default=False, + help="Show revoked certificates", +) +x509_list_parser.add_argument( + "--sort-by", + "-s", + type=str, + choices=["not_after", "not_before"], + default="not_after", + help="Sort certificates", +) +x509_details_parser = x509_subparsers.add_parser( + "details", help="Show an x509 certificate details" +) +x509_details_parser.add_argument( + "--serial", "-s", type=str, required=True, help="Certificate serial" +) +x509_details_parser.add_argument( + "--show-pem", + "-p", + action="store_true", + default=False, + help="Show PEM", +) +x509_dump_parser = x509_subparsers.add_parser("dump", help="Dump an x509 certificate") +x509_dump_parser.add_argument( + "--serial", "-s", type=str, required=True, help="Certificate serial" +) +x509_dump_parser.add_argument( + "--format", + "-f", + type=str, + choices=["pem"], + required=False, + help="Certificate format", +) +ssh_parser = subparsers.add_parser("ssh", help="ssh certificates") +ssh_subparsers = ssh_parser.add_subparsers( + help="Action for perform", dest="action", required=True +) +ssh_list_parser = ssh_subparsers.add_parser("list", help="List ssh certificates") +ssh_list_parser.add_argument( + "--show-expired", + "-e", + action="store_true", + default=False, + help="Show expired certificates", +) +ssh_list_parser.add_argument( + "--show-revoked", + "-r", + action="store_true", + default=False, + help="Show revoked certificates", +) +ssh_list_parser.add_argument( + "--sort-by", + "-s", + type=str, + choices=["not_after", "not_before"], + default="not_after", + help="Sort certificates (default: not_after)", +) +ssh_details_parser = ssh_subparsers.add_parser( + "details", help="Show an ssh certificate details" +) +ssh_details_parser.add_argument( + "--serial", "-s", type=str, required=True, help="Certificate serial" +) +ssh_dump_parser = ssh_subparsers.add_parser("dump", help="Dump an ssh certificate") +ssh_dump_parser.add_argument( + "--serial", "-s", type=str, required=True, help="Certificate serial" +) +args = parser.parse_args() + +if args.object == "x509": + if args.action == "list": + list_x509_certs( + revoked=args.show_revoked, expired=args.show_expired, sort_key=args.sort_by + ) + elif args.action == "details": + get_x509_cert(serial=args.serial, show_pem=args.show_pem) + elif args.action == "dump": + dump_x509_cert(serial=args.serial) +elif args.object == "ssh": + if args.action == "list": + list_ssh_certs( + revoked=args.show_revoked, expired=args.show_expired, sort_key=args.sort_by + ) + elif args.action == "details": + get_ssh_cert(serial=args.serial) + elif args.action == "dump": + dump_ssh_cert(serial=args.serial)