From 1583cda39bdd5b17af894d8c4ff2191ffeb7c66b Mon Sep 17 00:00:00 2001 From: Benjamin Collet <benjamin@collet.eu> Date: Mon, 24 Mar 2025 08:06:19 +0100 Subject: [PATCH] Refactor to use step-ca-inspector API server --- models/config.py => config.py | 6 + models/ssh_cert.py | 166 ---------------------- models/x509_cert.py | 172 ----------------------- requirements.txt | 3 - step-ca-inspector.py | 256 +++++++++++++++++++++------------- 5 files changed, 162 insertions(+), 441 deletions(-) rename models/config.py => config.py (76%) delete mode 100644 models/ssh_cert.py delete mode 100644 models/x509_cert.py diff --git a/models/config.py b/config.py similarity index 76% rename from models/config.py rename to config.py index d9945b5..7acb267 100644 --- a/models/config.py +++ b/config.py @@ -23,3 +23,9 @@ class config: for k, v in cfg.items(): setattr(self, k, v) + + for setting in ["url"]: + if not hasattr(self, setting): + # FIXME: Raise instead + print(f"Mandatory setting {setting} is not configured.") + sys.exit(1) diff --git a/models/ssh_cert.py b/models/ssh_cert.py deleted file mode 100644 index 0ade70c..0000000 --- a/models/ssh_cert.py +++ /dev/null @@ -1,166 +0,0 @@ -import base64 -import dateutil -import json -import mariadb -from cryptography import x509 -from cryptography.hazmat.primitives import asymmetric, hashes, serialization -from datetime import datetime, timedelta, timezone -from models.config import config -from struct import unpack - - -config() -conn = mariadb.connect( - host=config.database_host, - user=config.database_user, - password=config.database_password, - database=config.database_name, -) - - -class list: - certs = [] - - def __new__(cls, sort_key=None): - cur = conn.cursor() - cur.execute( - """SELECT ssh_certs.nvalue AS cert, - revoked_ssh_certs.nvalue AS revoked - FROM ssh_certs - LEFT JOIN revoked_ssh_certs USING(nkey)""" - ) - - for result in cur: - cert_object = cert(result) - cls.certs.append(cert_object) - - cur.close() - - if sort_key is not None: - cls.certs.sort(key=lambda item: getattr(item, sort_key)) - - return cls.certs - - -class cert: - def __init__(self, cert): - (cert_raw, cert_revoked_raw) = cert - - size = unpack(">I", cert_raw[:4])[0] + 4 - alg = cert_raw[4:size] - - cert_pub_id = b" ".join([alg, base64.b64encode(cert_raw)]) - - if cert_revoked_raw is not None: - cert_revoked = json.loads(cert_revoked_raw) - else: - cert_revoked = None - - self.load(cert_pub_id, cert_revoked, alg) - - @classmethod - def from_serial(cls, serial): - return cls(cert=cls.get_cert(cls, serial)) - - def load(self, cert_pub_id, cert_revoked, cert_alg): - cert = serialization.load_ssh_public_identity(cert_pub_id) - self.serial = cert.serial - self.alg = cert_alg - if cert.type == serialization.SSHCertificateType.USER: - self.type = "User" - self.key_id = cert.key_id - self.principals = cert.valid_principals - self.not_after = datetime.fromtimestamp(cert.valid_before).replace( - tzinfo=timezone(offset=timedelta()), microsecond=0 - ) - self.not_before = datetime.fromtimestamp(cert.valid_after).replace( - tzinfo=timezone(offset=timedelta()), microsecond=0 - ) - # TODO: Implement critical options parsing - # cert.critical_options - self.extensions = cert.extensions - - (self.signing_key, self.signing_key_type, self.signing_key_hash) = ( - self.get_public_key_params(cert.signature_key()) - ) - - (self.public_key, self.public_key_type, self.public_key_hash) = ( - self.get_public_key_params(cert.public_key()) - ) - - self.public_identity = cert.public_bytes() - - if cert_revoked is not None: - self.revoked_at = dateutil.parser.isoparse( - cert_revoked.get("RevokedAt") - ).replace(microsecond=0) - else: - self.revoked_at = None - - now_with_tz = datetime.utcnow().replace( - tzinfo=timezone(offset=timedelta()), microsecond=0 - ) - - if self.revoked_at is not None and self.revoked_at < now_with_tz: - self.status = status(status.REVOKED) - elif self.not_after < now_with_tz: - self.status = status(status.EXPIRED) - else: - self.status = status(status.VALID) - - def get_cert(self, cert_serial): - cur = conn.cursor() - cur.execute( - """SELECT ssh_certs.nvalue AS cert, - revoked_ssh_certs.nvalue AS revoked - FROM ssh_certs - LEFT JOIN revoked_ssh_certs USING(nkey) - WHERE nkey=?""", - (cert_serial,), - ) - if cur.rowcount > 0: - cert = cur.fetchone() - else: - cert = None - - cur.close() - return cert - - def get_public_key_params(self, public_key): - if isinstance(public_key, asymmetric.ec.EllipticCurvePublicKey): - key_type = "ECDSA" - elif isinstance(public_key, asymmetric.ed25519.Ed25519PublicKey): - key_type = "ED25519" - elif isinstance(public_key, asymmetric.rsa.RSAPublicKey): - key_type = "RSA" - - key_str = public_key.public_bytes( - serialization.Encoding.OpenSSH, serialization.PublicFormat.OpenSSH - ) - - key_data = key_str.strip().split()[1] - digest = hashes.Hash(hashes.SHA256()) - digest.update(base64.b64decode(key_data)) - hash_sha256 = digest.finalize() - key_hash = base64.b64encode(hash_sha256) - - return key_str, key_type, key_hash - - -class status: - REVOKED = 1 - EXPIRED = 2 - VALID = 3 - - def __init__(self, status): - self.value = status - - def __str__(self): - if self.value == self.EXPIRED: - return "Expired" - elif self.value == self.REVOKED: - return "Revoked" - elif self.value == self.VALID: - return "Valid" - else: - return "Undefined" diff --git a/models/x509_cert.py b/models/x509_cert.py deleted file mode 100644 index 8d0c10d..0000000 --- a/models/x509_cert.py +++ /dev/null @@ -1,172 +0,0 @@ -import binascii -import dateutil -import json -import mariadb -from cryptography import x509 -from cryptography.hazmat.primitives import hashes, serialization -from datetime import datetime, timedelta, timezone -from models.config import config - - -config() -conn = mariadb.connect( - host=config.database_host, - user=config.database_user, - password=config.database_password, - database=config.database_name, -) - - -class list: - certs = [] - - def __new__(cls, sort_key=None): - cur = conn.cursor() - cur.execute( - """SELECT x509_certs.nvalue AS cert, - x509_certs_data.nvalue AS data, - revoked_x509_certs.nvalue AS revoked - FROM x509_certs - INNER JOIN x509_certs_data USING(nkey) - LEFT JOIN revoked_x509_certs USING(nkey)""" - ) - - for result in cur: - cert_object = cert(result) - cls.certs.append(cert_object) - - cur.close() - - if sort_key is not None: - cls.certs.sort(key=lambda item: getattr(item, sort_key)) - - return cls.certs - - -class cert: - def __init__(self, cert): - (cert_der, cert_data_raw, cert_revoked_raw) = cert - - cert_data = json.loads(cert_data_raw) - if cert_revoked_raw is not None: - cert_revoked = json.loads(cert_revoked_raw) - else: - cert_revoked = None - - self.load(cert_der, cert_data, cert_revoked) - - @classmethod - def from_serial(cls, serial): - return cls(cert=cls.get_cert(cls, serial)) - - def load(self, cert_der, cert_data, cert_revoked): - cert = x509.load_der_x509_certificate(cert_der) - - self.pem = cert.public_bytes(serialization.Encoding.PEM) - self.serial = str(cert.serial_number) - self.sha256 = binascii.b2a_hex(cert.fingerprint(hashes.SHA256())) - self.sha1 = binascii.b2a_hex(cert.fingerprint(hashes.SHA1())) - self.md5 = binascii.b2a_hex(cert.fingerprint(hashes.MD5())) - self.pub_key = cert.public_key().public_bytes( - serialization.Encoding.PEM, serialization.PublicFormat.SubjectPublicKeyInfo - ) - self.pub_alg = cert.public_key_algorithm_oid._name - self.sig_alg = cert.signature_algorithm_oid._name - self.issuer = cert.issuer.rfc4514_string() - self.subject = cert.subject.rfc4514_string({x509.NameOID.EMAIL_ADDRESS: "E"}) - self.not_before = cert.not_valid_before_utc.replace(microsecond=0) - self.not_after = cert.not_valid_after_utc.replace(microsecond=0) - try: - san_data = cert.extensions.get_extension_for_class( - x509.SubjectAlternativeName - ) - self.san_names = self.get_sans(san_data) - except x509.extensions.ExtensionNotFound: - self.san_names = [] - - self.provisioner = cert_data.get("provisioner", None) - - if cert_revoked is not None: - self.revoked_at = dateutil.parser.isoparse( - cert_revoked.get("RevokedAt") - ).replace(microsecond=0) - else: - self.revoked_at = None - - now_with_tz = datetime.utcnow().replace( - tzinfo=timezone(offset=timedelta()), microsecond=0 - ) - - if self.revoked_at is not None and self.revoked_at < now_with_tz: - self.status = status(status.REVOKED) - elif self.not_after < now_with_tz: - self.status = status(status.EXPIRED) - else: - self.status = status(status.VALID) - - def get_cert(self, cert_serial): - cur = conn.cursor() - cur.execute( - """SELECT x509_certs.nvalue AS cert, - x509_certs_data.nvalue AS data, - revoked_x509_certs.nvalue AS revoked - FROM x509_certs - INNER JOIN x509_certs_data USING(nkey) - LEFT JOIN revoked_x509_certs USING(nkey) - WHERE nkey=?""", - (cert_serial,), - ) - - if cur.rowcount > 0: - cert = cur.fetchone() - else: - cert = None - - cur.close() - return cert - - def get_sans(self, san_data): - sans = [] - - for san_value in san_data.value: - san = {} - if isinstance(san_value, x509.general_name.DNSName): - san["type"] = "DNS" - elif isinstance(san_value, x509.general_name.UniformResourceIdentifier): - san["type"] = "URI" - elif isinstance(san_value, x509.general_name.RFC822Name): - san["type"] = "Email" - elif isinstance(san_value, x509.general_name.IPAddress): - san["type"] = "IP" - elif isinstance(san_value, x509.general_name.DirectoryName): - san["type"] = "DirectoryName" - elif isinstance(san_value, x509.general_name.RegisteredID): - san["type"] = "RegisteredID" - elif isinstance(san_value, x509.general_name.OtherName): - san["type"] = "Other ({san_value.type_id})" - else: - continue - - san["value"] = san_value.value - sans.append(san) - - return sans - - -class status: - REVOKED = 1 - EXPIRED = 2 - VALID = 3 - - def __init__(self, status): - self.value = status - - def __str__(self): - if self.value == self.EXPIRED: - return "Expired" - elif self.value == self.REVOKED: - return "Revoked" - elif self.value == self.VALID: - return "Valid" - else: - return "Undefined" diff --git a/requirements.txt b/requirements.txt index 7eb2b06..2315779 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,2 @@ -PyYAML -cryptography -mariadb python-dateutil tabulate diff --git a/step-ca-inspector.py b/step-ca-inspector.py index 027387a..529347a 100755 --- a/step-ca-inspector.py +++ b/step-ca-inspector.py @@ -1,16 +1,17 @@ #!/usr/bin/env python3 import argparse -import os -import sys -import yaml +import requests +from urllib.parse import urljoin from datetime import datetime, timedelta, timezone from tabulate import tabulate -from models import ssh_cert, x509_cert +from config import config + +config() def delta_text(delta): - s = 's'[:abs(delta.days)^1] + s = "s"[: abs(delta.days) ^ 1] if delta < timedelta(days=-1): return f"in {abs(delta.days)} day{s}" @@ -22,38 +23,53 @@ def delta_text(delta): return f"{delta.days} day{s} ago" +def fetch_api(endpoint, params={}): + try: + results = requests.get(urljoin(config.url, endpoint), params=params) + results.raise_for_status() + except requests.HTTPError as e: + raise e + except requests.Timeout as e: + raise e + # request took too long + + return results.json() + + def list_ssh_certs(sort_key, revoked=False, expired=False): - cert_list = ssh_cert.list(sort_key=sort_key) + params = { + "sort_key": sort_key, + "revoked": revoked, + "expired": expired, + } + cert_list = fetch_api("ssh/certs", params=params) + cert_tbl = [] for cert in cert_list: - if cert.status.value == ssh_cert.status.EXPIRED and not expired: - continue - if cert.status.value == ssh_cert.status.REVOKED and not revoked: - continue - cert_row = {} - cert_row["Serial"] = cert.serial - cert_row["Type"] = cert.type - cert_row["Key ID"] = cert.key_id - principals_count = len(cert.principals) - principals_list = [x.decode() for x in cert.principals] + cert_row["Serial"] = cert["serial"] + cert_row["Type"] = cert["type"] + cert_row["Key ID"] = cert["key_id"] + principals_count = len(cert["principals"]) if principals_count > 2: - principals = principals_list[:2] + [f"+{principals_count - 2} more"] + principals = cert["principals"][:2] + [f"+{principals_count - 2} more"] else: - principals = principals_list + principals = cert["principals"] cert_row["Principals"] = "\n".join(principals) - now_with_tz = datetime.utcnow().replace( - tzinfo=timezone(offset=timedelta()), microsecond=0 - ) + now_with_tz = datetime.now(timezone.utc).replace(microsecond=0) - if cert.revoked_at is not None: - delta = now_with_tz - cert.revoked_at + if cert["revoked_at"] is not None: + delta = now_with_tz - datetime.fromtimestamp( + cert["revoked_at"], tz=timezone.utc + ) else: - delta = now_with_tz - cert.not_after + delta = now_with_tz - datetime.fromtimestamp( + cert["not_after"], tz=timezone.utc + ) cert_row["Expires"] = delta_text(delta).capitalize() - cert_row["Status"] = cert.status + cert_row["Status"] = cert["status"] cert_tbl.append(cert_row) @@ -61,87 +77,104 @@ def list_ssh_certs(sort_key, revoked=False, expired=False): def get_ssh_cert(serial): - cert = ssh_cert.cert.from_serial(serial) + cert = fetch_api(f"ssh/certs/{serial}") + if cert is None: + return + cert_tbl = [] - - cert_tbl.append(["Serial", cert.serial]) - cert_tbl.append(["Certificate type", cert.type]) - cert_tbl.append(["Certificate key type", cert.alg.decode()]) - public_key = f"{cert.public_key_type} SHA256:{cert.public_key_hash.decode()}" + cert_tbl.append(["Serial", cert["serial"]]) + cert_tbl.append(["Certificate type", cert["type"]]) + cert_tbl.append(["Certificate key type", cert["alg"]]) + public_key = f"{cert['public_key_type']} SHA256:{cert['public_key_hash']}" cert_tbl.append(["Public key", public_key]) - signing_key = f"{cert.signing_key_type} SHA256:{cert.signing_key_hash.decode()}" + signing_key = f"{cert['signing_key_type']} SHA256:{cert['signing_key_hash']}" cert_tbl.append(["Signing key", signing_key]) - cert_tbl.append(["Key ID", cert.key_id.decode()]) - principals = [x.decode() for x in cert.principals] - cert_tbl.append(["Principals", "\n".join(principals)]) + cert_tbl.append(["Key ID", cert["key_id"]]) + cert_tbl.append(["Principals", "\n".join(cert["principals"])]) - now_with_tz = datetime.utcnow().replace( - tzinfo=timezone(offset=timedelta()), microsecond=0 + now_with_tz = datetime.now(timezone.utc).replace(microsecond=0) + + delta_after = now_with_tz - datetime.fromtimestamp( + cert["not_after"], tz=timezone.utc + ) + delta_before = now_with_tz - datetime.fromtimestamp( + cert["not_before"], tz=timezone.utc ) - - delta_after = now_with_tz - cert.not_after - delta_before = now_with_tz - cert.not_before cert_tbl.append( - ["Not valid before", f"{cert.not_before} ({delta_text(delta_before)})"] + [ + "Not valid before", + f"{datetime.fromtimestamp(cert['not_before']).astimezone()} ({delta_text(delta_before)})", + ] ) cert_tbl.append( - ["Not valid after", f"{cert.not_after} ({delta_text(delta_after)})"] + [ + "Not valid after", + f"{datetime.fromtimestamp(cert['not_after']).astimezone()} ({delta_text(delta_after)})", + ] ) - if cert.revoked_at is not None: - delta_revoked = now_with_tz - cert.revoked_at - cert_tbl.append( - ["Revoked at", f"{cert.revoked_at} ({delta_text(delta_revoked)})"] + + if cert["revoked_at"] is not None: + delta_revoked = now_with_tz - datetime.fromtimestamp( + cert["revoked_at"], tz=timezone.utc ) - cert_tbl.append(["Valid for", f"{delta_revoked.days} days"]) - else: - cert_tbl.append(["Valid for", f"{abs(delta_after.days)} days"]) - extensions = [x.decode() for x in cert.extensions] - cert_tbl.append(["Extensions", "\n".join(extensions)]) - # cert_tbl.append(["Signing key", cert.signing_key.decode()]) - cert_tbl.append(["Status", cert.status]) + cert_tbl.append( + [ + "Revoked at", + f"{datetime.fromtimestamp(cert['revoked_at']).astimezone()} ({delta_text(delta_revoked)})", + ] + ) + + cert_tbl.append(["Extensions", "\n".join(cert["extensions"])]) + #cert_tbl.append(["Signing key", cert["signing_key"]]) + cert_tbl.append(["Status", cert["status"]]) print(tabulate(cert_tbl, tablefmt="fancy_grid")) def dump_ssh_cert(serial): - cert = ssh_cert.cert.from_serial(serial) - print(cert.public_identity.decode()) + cert = fetch_api(f"ssh/certs/{serial}") + if cert is None: + return + + print(cert["public_identity"]) def list_x509_certs(sort_key, revoked=False, expired=False): - cert_list = x509_cert.list(sort_key=sort_key) + params = { + "sort_key": sort_key, + "revoked": revoked, + "expired": expired, + } + cert_list = fetch_api(f"x509/certs", params=params) cert_tbl = [] for cert in cert_list: - if cert.status.value == x509_cert.status.EXPIRED and not expired: - continue - if cert.status.value == x509_cert.status.REVOKED and not revoked: - continue - cert_row = {} - cert_row["Serial"] = cert.serial + cert_row["Serial"] = cert["serial"] cert_row["Subject/Subject Alt Names (SAN)"] = "\n".join( [ "%.33s" % x - for x in [cert.subject] - + [f"{x['type']}: {x['value']}" for x in cert.san_names] + for x in [cert["subject"]] + + [f"{x['type']}: {x['value']}" for x in cert["san_names"]] ] ) cert_row["Provisioner"] = ( - f"{cert.provisioner['name']} ({cert.provisioner['type']})" + f"{cert['provisioner']['name']} ({cert['provisioner']['type']})" ) - now_with_tz = datetime.utcnow().replace( - tzinfo=timezone(offset=timedelta()), microsecond=0 - ) + now_with_tz = datetime.now(timezone.utc).replace(microsecond=0) - if cert.revoked_at is not None: - delta = now_with_tz - cert.revoked_at + if cert["revoked_at"] is not None: + delta = now_with_tz - datetime.fromtimestamp( + cert["revoked_at"], tz=timezone.utc + ) else: - delta = now_with_tz - cert.not_after + delta = now_with_tz - datetime.fromtimestamp( + cert["not_after"], tz=timezone.utc + ) cert_row["Expires"] = delta_text(delta).capitalize() - cert_row["Status"] = cert.status + cert_row["Status"] = cert["status"] cert_tbl.append(cert_row) @@ -149,64 +182,87 @@ def list_x509_certs(sort_key, revoked=False, expired=False): def get_x509_cert(serial, show_cert=False, show_pubkey=False): - cert = x509_cert.cert.from_serial(serial) - cert_tbl = [] + cert = fetch_api(f"x509/certs/{serial}") - cert_tbl.append(["Serial", cert.serial]) - cert_tbl.append(["Subject", cert.subject]) + if cert is None: + return + + cert_tbl = [] + cert_tbl.append(["Serial", cert["serial"]]) + cert_tbl.append(["Subject", cert["subject"]]) cert_tbl.append( [ "Subject Alt Names (SAN)", - "\n".join([f"{x['type']}: {x['value']}" for x in cert.san_names]), + "\n".join([f"{x['type']}: {x['value']}" for x in cert["san_names"]]), ] ) - cert_tbl.append(["Issuer", cert.issuer]) + cert_tbl.append(["Issuer", cert["issuer"]]) - now_with_tz = datetime.utcnow().replace( - tzinfo=timezone(offset=timedelta()), microsecond=0 + now_with_tz = datetime.now(timezone.utc).replace(microsecond=0) + + delta_after = now_with_tz - datetime.fromtimestamp( + cert["not_after"], tz=timezone.utc + ) + delta_before = now_with_tz - datetime.fromtimestamp( + cert["not_before"], tz=timezone.utc ) - - delta_after = now_with_tz - cert.not_after - delta_before = now_with_tz - cert.not_before cert_tbl.append( - ["Not valid before", f"{cert.not_before} ({delta_text(delta_before)})"] + [ + "Not valid before", + f"{datetime.fromtimestamp(cert['not_before']).astimezone()} ({delta_text(delta_before)})", + ] ) cert_tbl.append( - ["Not valid after", f"{cert.not_after} ({delta_text(delta_after)})"] + [ + "Not valid after", + f"{datetime.fromtimestamp(cert['not_after']).astimezone()} ({delta_text(delta_after)})", + ] ) - if cert.revoked_at is not None: - delta_revoked = now_with_tz - cert.revoked_at + if cert["revoked_at"] is not None: + delta_revoked = now_with_tz - datetime.fromtimestamp( + cert["revoked_at"], tz=timezone.utc + ) cert_tbl.append( - ["Revoked at", f"{cert.revoked_at} ({delta_text(delta_revoked)})"] + [ + "Revoked at", + f"{datetime.fromtimestamp(cert['revoked_at']).astimezone()} ({delta_text(delta_revoked)})", + ] ) cert_tbl.append(["Valid for", f"{delta_revoked.days} days"]) else: cert_tbl.append(["Valid for", f"{abs(delta_after.days)} days"]) cert_tbl.append( - ["Provisioner", f"{cert.provisioner['name']} ({cert.provisioner['type']})"] + [ + "Provisioner", + f"{cert['provisioner']['name']} ({cert['provisioner']['type']})", + ] ) fingerprints = [] - fingerprints.append(f"MD5: {cert.md5.decode()}") - fingerprints.append(f"SHA-1: {cert.sha1.decode()}") - fingerprints.append(f"SHA-256: {cert.sha256.decode()}") + fingerprints.append(f"MD5: {cert['md5']}") + fingerprints.append(f"SHA-1: {cert['sha1']}") + fingerprints.append(f"SHA-256: {cert['sha256']}") cert_tbl.append(["Fingerprints", "\n".join(fingerprints)]) - cert_tbl.append(["Public key algorithm", cert.pub_alg]) - cert_tbl.append(["Signature algorithm", cert.sig_alg]) - cert_tbl.append(["Status", cert.status]) - # cert_tbl.append(["Extensions", cert.extensions]) + cert_tbl.append(["Public key algorithm", cert["pub_alg"]]) + cert_tbl.append(["Signature algorithm", cert["sig_alg"]]) + cert_tbl.append(["Status", cert["status"]]) + if show_pubkey: - cert_tbl.append(["Public key", cert.pub_key.decode("utf-8")]) + cert_tbl.append(["Public key", cert["pub_key"]]) if show_cert: - cert_tbl.append(["PEM", cert.pem.decode("utf-8")]) + cert_tbl.append(["PEM", cert["pem"]]) print(tabulate(cert_tbl, tablefmt="fancy_grid")) def dump_x509_cert(serial, cert_format="pem"): - cert = x509_cert.cert.from_serial(serial) - print(cert.pem.decode("utf-8").rstrip()) + cert = fetch_api(f"x509/certs/{serial}") + + if cert is None: + return + + print(cert["pem"].rstrip()) parser = argparse.ArgumentParser(description="Step CA Inspector")