Refactor to use step-ca-inspector API server

This commit is contained in:
Benjamin Collet 2025-03-24 08:06:19 +01:00
parent 27a39e6bbc
commit 1583cda39b
Signed by: bcollet
SSH key fingerprint: SHA256:8UJspOIcCOS+MtSOcnuq2HjKFube4ox1s/+A62ixov4
5 changed files with 162 additions and 441 deletions

View file

@ -23,3 +23,9 @@ class config:
for k, v in cfg.items():
setattr(self, k, v)
for setting in ["url"]:
if not hasattr(self, setting):
# FIXME: Raise instead
print(f"Mandatory setting {setting} is not configured.")
sys.exit(1)

View file

@ -1,166 +0,0 @@
import base64
import dateutil
import json
import mariadb
from cryptography import x509
from cryptography.hazmat.primitives import asymmetric, hashes, serialization
from datetime import datetime, timedelta, timezone
from models.config import config
from struct import unpack
config()
conn = mariadb.connect(
host=config.database_host,
user=config.database_user,
password=config.database_password,
database=config.database_name,
)
class list:
certs = []
def __new__(cls, sort_key=None):
cur = conn.cursor()
cur.execute(
"""SELECT ssh_certs.nvalue AS cert,
revoked_ssh_certs.nvalue AS revoked
FROM ssh_certs
LEFT JOIN revoked_ssh_certs USING(nkey)"""
)
for result in cur:
cert_object = cert(result)
cls.certs.append(cert_object)
cur.close()
if sort_key is not None:
cls.certs.sort(key=lambda item: getattr(item, sort_key))
return cls.certs
class cert:
def __init__(self, cert):
(cert_raw, cert_revoked_raw) = cert
size = unpack(">I", cert_raw[:4])[0] + 4
alg = cert_raw[4:size]
cert_pub_id = b" ".join([alg, base64.b64encode(cert_raw)])
if cert_revoked_raw is not None:
cert_revoked = json.loads(cert_revoked_raw)
else:
cert_revoked = None
self.load(cert_pub_id, cert_revoked, alg)
@classmethod
def from_serial(cls, serial):
return cls(cert=cls.get_cert(cls, serial))
def load(self, cert_pub_id, cert_revoked, cert_alg):
cert = serialization.load_ssh_public_identity(cert_pub_id)
self.serial = cert.serial
self.alg = cert_alg
if cert.type == serialization.SSHCertificateType.USER:
self.type = "User"
self.key_id = cert.key_id
self.principals = cert.valid_principals
self.not_after = datetime.fromtimestamp(cert.valid_before).replace(
tzinfo=timezone(offset=timedelta()), microsecond=0
)
self.not_before = datetime.fromtimestamp(cert.valid_after).replace(
tzinfo=timezone(offset=timedelta()), microsecond=0
)
# TODO: Implement critical options parsing
# cert.critical_options
self.extensions = cert.extensions
(self.signing_key, self.signing_key_type, self.signing_key_hash) = (
self.get_public_key_params(cert.signature_key())
)
(self.public_key, self.public_key_type, self.public_key_hash) = (
self.get_public_key_params(cert.public_key())
)
self.public_identity = cert.public_bytes()
if cert_revoked is not None:
self.revoked_at = dateutil.parser.isoparse(
cert_revoked.get("RevokedAt")
).replace(microsecond=0)
else:
self.revoked_at = None
now_with_tz = datetime.utcnow().replace(
tzinfo=timezone(offset=timedelta()), microsecond=0
)
if self.revoked_at is not None and self.revoked_at < now_with_tz:
self.status = status(status.REVOKED)
elif self.not_after < now_with_tz:
self.status = status(status.EXPIRED)
else:
self.status = status(status.VALID)
def get_cert(self, cert_serial):
cur = conn.cursor()
cur.execute(
"""SELECT ssh_certs.nvalue AS cert,
revoked_ssh_certs.nvalue AS revoked
FROM ssh_certs
LEFT JOIN revoked_ssh_certs USING(nkey)
WHERE nkey=?""",
(cert_serial,),
)
if cur.rowcount > 0:
cert = cur.fetchone()
else:
cert = None
cur.close()
return cert
def get_public_key_params(self, public_key):
if isinstance(public_key, asymmetric.ec.EllipticCurvePublicKey):
key_type = "ECDSA"
elif isinstance(public_key, asymmetric.ed25519.Ed25519PublicKey):
key_type = "ED25519"
elif isinstance(public_key, asymmetric.rsa.RSAPublicKey):
key_type = "RSA"
key_str = public_key.public_bytes(
serialization.Encoding.OpenSSH, serialization.PublicFormat.OpenSSH
)
key_data = key_str.strip().split()[1]
digest = hashes.Hash(hashes.SHA256())
digest.update(base64.b64decode(key_data))
hash_sha256 = digest.finalize()
key_hash = base64.b64encode(hash_sha256)
return key_str, key_type, key_hash
class status:
REVOKED = 1
EXPIRED = 2
VALID = 3
def __init__(self, status):
self.value = status
def __str__(self):
if self.value == self.EXPIRED:
return "Expired"
elif self.value == self.REVOKED:
return "Revoked"
elif self.value == self.VALID:
return "Valid"
else:
return "Undefined"

View file

@ -1,172 +0,0 @@
import binascii
import dateutil
import json
import mariadb
from cryptography import x509
from cryptography.hazmat.primitives import hashes, serialization
from datetime import datetime, timedelta, timezone
from models.config import config
config()
conn = mariadb.connect(
host=config.database_host,
user=config.database_user,
password=config.database_password,
database=config.database_name,
)
class list:
certs = []
def __new__(cls, sort_key=None):
cur = conn.cursor()
cur.execute(
"""SELECT x509_certs.nvalue AS cert,
x509_certs_data.nvalue AS data,
revoked_x509_certs.nvalue AS revoked
FROM x509_certs
INNER JOIN x509_certs_data USING(nkey)
LEFT JOIN revoked_x509_certs USING(nkey)"""
)
for result in cur:
cert_object = cert(result)
cls.certs.append(cert_object)
cur.close()
if sort_key is not None:
cls.certs.sort(key=lambda item: getattr(item, sort_key))
return cls.certs
class cert:
def __init__(self, cert):
(cert_der, cert_data_raw, cert_revoked_raw) = cert
cert_data = json.loads(cert_data_raw)
if cert_revoked_raw is not None:
cert_revoked = json.loads(cert_revoked_raw)
else:
cert_revoked = None
self.load(cert_der, cert_data, cert_revoked)
@classmethod
def from_serial(cls, serial):
return cls(cert=cls.get_cert(cls, serial))
def load(self, cert_der, cert_data, cert_revoked):
cert = x509.load_der_x509_certificate(cert_der)
self.pem = cert.public_bytes(serialization.Encoding.PEM)
self.serial = str(cert.serial_number)
self.sha256 = binascii.b2a_hex(cert.fingerprint(hashes.SHA256()))
self.sha1 = binascii.b2a_hex(cert.fingerprint(hashes.SHA1()))
self.md5 = binascii.b2a_hex(cert.fingerprint(hashes.MD5()))
self.pub_key = cert.public_key().public_bytes(
serialization.Encoding.PEM, serialization.PublicFormat.SubjectPublicKeyInfo
)
self.pub_alg = cert.public_key_algorithm_oid._name
self.sig_alg = cert.signature_algorithm_oid._name
self.issuer = cert.issuer.rfc4514_string()
self.subject = cert.subject.rfc4514_string({x509.NameOID.EMAIL_ADDRESS: "E"})
self.not_before = cert.not_valid_before_utc.replace(microsecond=0)
self.not_after = cert.not_valid_after_utc.replace(microsecond=0)
try:
san_data = cert.extensions.get_extension_for_class(
x509.SubjectAlternativeName
)
self.san_names = self.get_sans(san_data)
except x509.extensions.ExtensionNotFound:
self.san_names = []
self.provisioner = cert_data.get("provisioner", None)
if cert_revoked is not None:
self.revoked_at = dateutil.parser.isoparse(
cert_revoked.get("RevokedAt")
).replace(microsecond=0)
else:
self.revoked_at = None
now_with_tz = datetime.utcnow().replace(
tzinfo=timezone(offset=timedelta()), microsecond=0
)
if self.revoked_at is not None and self.revoked_at < now_with_tz:
self.status = status(status.REVOKED)
elif self.not_after < now_with_tz:
self.status = status(status.EXPIRED)
else:
self.status = status(status.VALID)
def get_cert(self, cert_serial):
cur = conn.cursor()
cur.execute(
"""SELECT x509_certs.nvalue AS cert,
x509_certs_data.nvalue AS data,
revoked_x509_certs.nvalue AS revoked
FROM x509_certs
INNER JOIN x509_certs_data USING(nkey)
LEFT JOIN revoked_x509_certs USING(nkey)
WHERE nkey=?""",
(cert_serial,),
)
if cur.rowcount > 0:
cert = cur.fetchone()
else:
cert = None
cur.close()
return cert
def get_sans(self, san_data):
sans = []
for san_value in san_data.value:
san = {}
if isinstance(san_value, x509.general_name.DNSName):
san["type"] = "DNS"
elif isinstance(san_value, x509.general_name.UniformResourceIdentifier):
san["type"] = "URI"
elif isinstance(san_value, x509.general_name.RFC822Name):
san["type"] = "Email"
elif isinstance(san_value, x509.general_name.IPAddress):
san["type"] = "IP"
elif isinstance(san_value, x509.general_name.DirectoryName):
san["type"] = "DirectoryName"
elif isinstance(san_value, x509.general_name.RegisteredID):
san["type"] = "RegisteredID"
elif isinstance(san_value, x509.general_name.OtherName):
san["type"] = "Other ({san_value.type_id})"
else:
continue
san["value"] = san_value.value
sans.append(san)
return sans
class status:
REVOKED = 1
EXPIRED = 2
VALID = 3
def __init__(self, status):
self.value = status
def __str__(self):
if self.value == self.EXPIRED:
return "Expired"
elif self.value == self.REVOKED:
return "Revoked"
elif self.value == self.VALID:
return "Valid"
else:
return "Undefined"

View file

@ -1,5 +1,2 @@
PyYAML
cryptography
mariadb
python-dateutil
tabulate

View file

@ -1,16 +1,17 @@
#!/usr/bin/env python3
import argparse
import os
import sys
import yaml
import requests
from urllib.parse import urljoin
from datetime import datetime, timedelta, timezone
from tabulate import tabulate
from models import ssh_cert, x509_cert
from config import config
config()
def delta_text(delta):
s = 's'[:abs(delta.days)^1]
s = "s"[: abs(delta.days) ^ 1]
if delta < timedelta(days=-1):
return f"in {abs(delta.days)} day{s}"
@ -22,38 +23,53 @@ def delta_text(delta):
return f"{delta.days} day{s} ago"
def fetch_api(endpoint, params={}):
try:
results = requests.get(urljoin(config.url, endpoint), params=params)
results.raise_for_status()
except requests.HTTPError as e:
raise e
except requests.Timeout as e:
raise e
# request took too long
return results.json()
def list_ssh_certs(sort_key, revoked=False, expired=False):
cert_list = ssh_cert.list(sort_key=sort_key)
params = {
"sort_key": sort_key,
"revoked": revoked,
"expired": expired,
}
cert_list = fetch_api("ssh/certs", params=params)
cert_tbl = []
for cert in cert_list:
if cert.status.value == ssh_cert.status.EXPIRED and not expired:
continue
if cert.status.value == ssh_cert.status.REVOKED and not revoked:
continue
cert_row = {}
cert_row["Serial"] = cert.serial
cert_row["Type"] = cert.type
cert_row["Key ID"] = cert.key_id
principals_count = len(cert.principals)
principals_list = [x.decode() for x in cert.principals]
cert_row["Serial"] = cert["serial"]
cert_row["Type"] = cert["type"]
cert_row["Key ID"] = cert["key_id"]
principals_count = len(cert["principals"])
if principals_count > 2:
principals = principals_list[:2] + [f"+{principals_count - 2} more"]
principals = cert["principals"][:2] + [f"+{principals_count - 2} more"]
else:
principals = principals_list
principals = cert["principals"]
cert_row["Principals"] = "\n".join(principals)
now_with_tz = datetime.utcnow().replace(
tzinfo=timezone(offset=timedelta()), microsecond=0
)
now_with_tz = datetime.now(timezone.utc).replace(microsecond=0)
if cert.revoked_at is not None:
delta = now_with_tz - cert.revoked_at
if cert["revoked_at"] is not None:
delta = now_with_tz - datetime.fromtimestamp(
cert["revoked_at"], tz=timezone.utc
)
else:
delta = now_with_tz - cert.not_after
delta = now_with_tz - datetime.fromtimestamp(
cert["not_after"], tz=timezone.utc
)
cert_row["Expires"] = delta_text(delta).capitalize()
cert_row["Status"] = cert.status
cert_row["Status"] = cert["status"]
cert_tbl.append(cert_row)
@ -61,87 +77,104 @@ def list_ssh_certs(sort_key, revoked=False, expired=False):
def get_ssh_cert(serial):
cert = ssh_cert.cert.from_serial(serial)
cert = fetch_api(f"ssh/certs/{serial}")
if cert is None:
return
cert_tbl = []
cert_tbl.append(["Serial", cert.serial])
cert_tbl.append(["Certificate type", cert.type])
cert_tbl.append(["Certificate key type", cert.alg.decode()])
public_key = f"{cert.public_key_type} SHA256:{cert.public_key_hash.decode()}"
cert_tbl.append(["Serial", cert["serial"]])
cert_tbl.append(["Certificate type", cert["type"]])
cert_tbl.append(["Certificate key type", cert["alg"]])
public_key = f"{cert['public_key_type']} SHA256:{cert['public_key_hash']}"
cert_tbl.append(["Public key", public_key])
signing_key = f"{cert.signing_key_type} SHA256:{cert.signing_key_hash.decode()}"
signing_key = f"{cert['signing_key_type']} SHA256:{cert['signing_key_hash']}"
cert_tbl.append(["Signing key", signing_key])
cert_tbl.append(["Key ID", cert.key_id.decode()])
principals = [x.decode() for x in cert.principals]
cert_tbl.append(["Principals", "\n".join(principals)])
cert_tbl.append(["Key ID", cert["key_id"]])
cert_tbl.append(["Principals", "\n".join(cert["principals"])])
now_with_tz = datetime.utcnow().replace(
tzinfo=timezone(offset=timedelta()), microsecond=0
now_with_tz = datetime.now(timezone.utc).replace(microsecond=0)
delta_after = now_with_tz - datetime.fromtimestamp(
cert["not_after"], tz=timezone.utc
)
delta_before = now_with_tz - datetime.fromtimestamp(
cert["not_before"], tz=timezone.utc
)
delta_after = now_with_tz - cert.not_after
delta_before = now_with_tz - cert.not_before
cert_tbl.append(
["Not valid before", f"{cert.not_before} ({delta_text(delta_before)})"]
[
"Not valid before",
f"{datetime.fromtimestamp(cert['not_before']).astimezone()} ({delta_text(delta_before)})",
]
)
cert_tbl.append(
["Not valid after", f"{cert.not_after} ({delta_text(delta_after)})"]
[
"Not valid after",
f"{datetime.fromtimestamp(cert['not_after']).astimezone()} ({delta_text(delta_after)})",
]
)
if cert.revoked_at is not None:
delta_revoked = now_with_tz - cert.revoked_at
cert_tbl.append(
["Revoked at", f"{cert.revoked_at} ({delta_text(delta_revoked)})"]
if cert["revoked_at"] is not None:
delta_revoked = now_with_tz - datetime.fromtimestamp(
cert["revoked_at"], tz=timezone.utc
)
cert_tbl.append(["Valid for", f"{delta_revoked.days} days"])
else:
cert_tbl.append(["Valid for", f"{abs(delta_after.days)} days"])
extensions = [x.decode() for x in cert.extensions]
cert_tbl.append(["Extensions", "\n".join(extensions)])
# cert_tbl.append(["Signing key", cert.signing_key.decode()])
cert_tbl.append(["Status", cert.status])
cert_tbl.append(
[
"Revoked at",
f"{datetime.fromtimestamp(cert['revoked_at']).astimezone()} ({delta_text(delta_revoked)})",
]
)
cert_tbl.append(["Extensions", "\n".join(cert["extensions"])])
#cert_tbl.append(["Signing key", cert["signing_key"]])
cert_tbl.append(["Status", cert["status"]])
print(tabulate(cert_tbl, tablefmt="fancy_grid"))
def dump_ssh_cert(serial):
cert = ssh_cert.cert.from_serial(serial)
print(cert.public_identity.decode())
cert = fetch_api(f"ssh/certs/{serial}")
if cert is None:
return
print(cert["public_identity"])
def list_x509_certs(sort_key, revoked=False, expired=False):
cert_list = x509_cert.list(sort_key=sort_key)
params = {
"sort_key": sort_key,
"revoked": revoked,
"expired": expired,
}
cert_list = fetch_api(f"x509/certs", params=params)
cert_tbl = []
for cert in cert_list:
if cert.status.value == x509_cert.status.EXPIRED and not expired:
continue
if cert.status.value == x509_cert.status.REVOKED and not revoked:
continue
cert_row = {}
cert_row["Serial"] = cert.serial
cert_row["Serial"] = cert["serial"]
cert_row["Subject/Subject Alt Names (SAN)"] = "\n".join(
[
"%.33s" % x
for x in [cert.subject]
+ [f"{x['type']}: {x['value']}" for x in cert.san_names]
for x in [cert["subject"]]
+ [f"{x['type']}: {x['value']}" for x in cert["san_names"]]
]
)
cert_row["Provisioner"] = (
f"{cert.provisioner['name']} ({cert.provisioner['type']})"
f"{cert['provisioner']['name']} ({cert['provisioner']['type']})"
)
now_with_tz = datetime.utcnow().replace(
tzinfo=timezone(offset=timedelta()), microsecond=0
)
now_with_tz = datetime.now(timezone.utc).replace(microsecond=0)
if cert.revoked_at is not None:
delta = now_with_tz - cert.revoked_at
if cert["revoked_at"] is not None:
delta = now_with_tz - datetime.fromtimestamp(
cert["revoked_at"], tz=timezone.utc
)
else:
delta = now_with_tz - cert.not_after
delta = now_with_tz - datetime.fromtimestamp(
cert["not_after"], tz=timezone.utc
)
cert_row["Expires"] = delta_text(delta).capitalize()
cert_row["Status"] = cert.status
cert_row["Status"] = cert["status"]
cert_tbl.append(cert_row)
@ -149,64 +182,87 @@ def list_x509_certs(sort_key, revoked=False, expired=False):
def get_x509_cert(serial, show_cert=False, show_pubkey=False):
cert = x509_cert.cert.from_serial(serial)
cert_tbl = []
cert = fetch_api(f"x509/certs/{serial}")
cert_tbl.append(["Serial", cert.serial])
cert_tbl.append(["Subject", cert.subject])
if cert is None:
return
cert_tbl = []
cert_tbl.append(["Serial", cert["serial"]])
cert_tbl.append(["Subject", cert["subject"]])
cert_tbl.append(
[
"Subject Alt Names (SAN)",
"\n".join([f"{x['type']}: {x['value']}" for x in cert.san_names]),
"\n".join([f"{x['type']}: {x['value']}" for x in cert["san_names"]]),
]
)
cert_tbl.append(["Issuer", cert.issuer])
cert_tbl.append(["Issuer", cert["issuer"]])
now_with_tz = datetime.utcnow().replace(
tzinfo=timezone(offset=timedelta()), microsecond=0
now_with_tz = datetime.now(timezone.utc).replace(microsecond=0)
delta_after = now_with_tz - datetime.fromtimestamp(
cert["not_after"], tz=timezone.utc
)
delta_before = now_with_tz - datetime.fromtimestamp(
cert["not_before"], tz=timezone.utc
)
delta_after = now_with_tz - cert.not_after
delta_before = now_with_tz - cert.not_before
cert_tbl.append(
["Not valid before", f"{cert.not_before} ({delta_text(delta_before)})"]
[
"Not valid before",
f"{datetime.fromtimestamp(cert['not_before']).astimezone()} ({delta_text(delta_before)})",
]
)
cert_tbl.append(
["Not valid after", f"{cert.not_after} ({delta_text(delta_after)})"]
[
"Not valid after",
f"{datetime.fromtimestamp(cert['not_after']).astimezone()} ({delta_text(delta_after)})",
]
)
if cert.revoked_at is not None:
delta_revoked = now_with_tz - cert.revoked_at
if cert["revoked_at"] is not None:
delta_revoked = now_with_tz - datetime.fromtimestamp(
cert["revoked_at"], tz=timezone.utc
)
cert_tbl.append(
["Revoked at", f"{cert.revoked_at} ({delta_text(delta_revoked)})"]
[
"Revoked at",
f"{datetime.fromtimestamp(cert['revoked_at']).astimezone()} ({delta_text(delta_revoked)})",
]
)
cert_tbl.append(["Valid for", f"{delta_revoked.days} days"])
else:
cert_tbl.append(["Valid for", f"{abs(delta_after.days)} days"])
cert_tbl.append(
["Provisioner", f"{cert.provisioner['name']} ({cert.provisioner['type']})"]
[
"Provisioner",
f"{cert['provisioner']['name']} ({cert['provisioner']['type']})",
]
)
fingerprints = []
fingerprints.append(f"MD5: {cert.md5.decode()}")
fingerprints.append(f"SHA-1: {cert.sha1.decode()}")
fingerprints.append(f"SHA-256: {cert.sha256.decode()}")
fingerprints.append(f"MD5: {cert['md5']}")
fingerprints.append(f"SHA-1: {cert['sha1']}")
fingerprints.append(f"SHA-256: {cert['sha256']}")
cert_tbl.append(["Fingerprints", "\n".join(fingerprints)])
cert_tbl.append(["Public key algorithm", cert.pub_alg])
cert_tbl.append(["Signature algorithm", cert.sig_alg])
cert_tbl.append(["Status", cert.status])
# cert_tbl.append(["Extensions", cert.extensions])
cert_tbl.append(["Public key algorithm", cert["pub_alg"]])
cert_tbl.append(["Signature algorithm", cert["sig_alg"]])
cert_tbl.append(["Status", cert["status"]])
if show_pubkey:
cert_tbl.append(["Public key", cert.pub_key.decode("utf-8")])
cert_tbl.append(["Public key", cert["pub_key"]])
if show_cert:
cert_tbl.append(["PEM", cert.pem.decode("utf-8")])
cert_tbl.append(["PEM", cert["pem"]])
print(tabulate(cert_tbl, tablefmt="fancy_grid"))
def dump_x509_cert(serial, cert_format="pem"):
cert = x509_cert.cert.from_serial(serial)
print(cert.pem.decode("utf-8").rstrip())
cert = fetch_api(f"x509/certs/{serial}")
if cert is None:
return
print(cert["pem"].rstrip())
parser = argparse.ArgumentParser(description="Step CA Inspector")