Refactor to use step-ca-inspector API server

This commit is contained in:
Benjamin Collet 2025-03-24 08:06:19 +01:00
parent 27a39e6bbc
commit 1583cda39b
Signed by: bcollet
SSH key fingerprint: SHA256:8UJspOIcCOS+MtSOcnuq2HjKFube4ox1s/+A62ixov4
5 changed files with 162 additions and 441 deletions

View file

@ -23,3 +23,9 @@ class config:
for k, v in cfg.items(): for k, v in cfg.items():
setattr(self, k, v) setattr(self, k, v)
for setting in ["url"]:
if not hasattr(self, setting):
# FIXME: Raise instead
print(f"Mandatory setting {setting} is not configured.")
sys.exit(1)

View file

@ -1,166 +0,0 @@
import base64
import dateutil
import json
import mariadb
from cryptography import x509
from cryptography.hazmat.primitives import asymmetric, hashes, serialization
from datetime import datetime, timedelta, timezone
from models.config import config
from struct import unpack
config()
conn = mariadb.connect(
host=config.database_host,
user=config.database_user,
password=config.database_password,
database=config.database_name,
)
class list:
certs = []
def __new__(cls, sort_key=None):
cur = conn.cursor()
cur.execute(
"""SELECT ssh_certs.nvalue AS cert,
revoked_ssh_certs.nvalue AS revoked
FROM ssh_certs
LEFT JOIN revoked_ssh_certs USING(nkey)"""
)
for result in cur:
cert_object = cert(result)
cls.certs.append(cert_object)
cur.close()
if sort_key is not None:
cls.certs.sort(key=lambda item: getattr(item, sort_key))
return cls.certs
class cert:
def __init__(self, cert):
(cert_raw, cert_revoked_raw) = cert
size = unpack(">I", cert_raw[:4])[0] + 4
alg = cert_raw[4:size]
cert_pub_id = b" ".join([alg, base64.b64encode(cert_raw)])
if cert_revoked_raw is not None:
cert_revoked = json.loads(cert_revoked_raw)
else:
cert_revoked = None
self.load(cert_pub_id, cert_revoked, alg)
@classmethod
def from_serial(cls, serial):
return cls(cert=cls.get_cert(cls, serial))
def load(self, cert_pub_id, cert_revoked, cert_alg):
cert = serialization.load_ssh_public_identity(cert_pub_id)
self.serial = cert.serial
self.alg = cert_alg
if cert.type == serialization.SSHCertificateType.USER:
self.type = "User"
self.key_id = cert.key_id
self.principals = cert.valid_principals
self.not_after = datetime.fromtimestamp(cert.valid_before).replace(
tzinfo=timezone(offset=timedelta()), microsecond=0
)
self.not_before = datetime.fromtimestamp(cert.valid_after).replace(
tzinfo=timezone(offset=timedelta()), microsecond=0
)
# TODO: Implement critical options parsing
# cert.critical_options
self.extensions = cert.extensions
(self.signing_key, self.signing_key_type, self.signing_key_hash) = (
self.get_public_key_params(cert.signature_key())
)
(self.public_key, self.public_key_type, self.public_key_hash) = (
self.get_public_key_params(cert.public_key())
)
self.public_identity = cert.public_bytes()
if cert_revoked is not None:
self.revoked_at = dateutil.parser.isoparse(
cert_revoked.get("RevokedAt")
).replace(microsecond=0)
else:
self.revoked_at = None
now_with_tz = datetime.utcnow().replace(
tzinfo=timezone(offset=timedelta()), microsecond=0
)
if self.revoked_at is not None and self.revoked_at < now_with_tz:
self.status = status(status.REVOKED)
elif self.not_after < now_with_tz:
self.status = status(status.EXPIRED)
else:
self.status = status(status.VALID)
def get_cert(self, cert_serial):
cur = conn.cursor()
cur.execute(
"""SELECT ssh_certs.nvalue AS cert,
revoked_ssh_certs.nvalue AS revoked
FROM ssh_certs
LEFT JOIN revoked_ssh_certs USING(nkey)
WHERE nkey=?""",
(cert_serial,),
)
if cur.rowcount > 0:
cert = cur.fetchone()
else:
cert = None
cur.close()
return cert
def get_public_key_params(self, public_key):
if isinstance(public_key, asymmetric.ec.EllipticCurvePublicKey):
key_type = "ECDSA"
elif isinstance(public_key, asymmetric.ed25519.Ed25519PublicKey):
key_type = "ED25519"
elif isinstance(public_key, asymmetric.rsa.RSAPublicKey):
key_type = "RSA"
key_str = public_key.public_bytes(
serialization.Encoding.OpenSSH, serialization.PublicFormat.OpenSSH
)
key_data = key_str.strip().split()[1]
digest = hashes.Hash(hashes.SHA256())
digest.update(base64.b64decode(key_data))
hash_sha256 = digest.finalize()
key_hash = base64.b64encode(hash_sha256)
return key_str, key_type, key_hash
class status:
REVOKED = 1
EXPIRED = 2
VALID = 3
def __init__(self, status):
self.value = status
def __str__(self):
if self.value == self.EXPIRED:
return "Expired"
elif self.value == self.REVOKED:
return "Revoked"
elif self.value == self.VALID:
return "Valid"
else:
return "Undefined"

View file

@ -1,172 +0,0 @@
import binascii
import dateutil
import json
import mariadb
from cryptography import x509
from cryptography.hazmat.primitives import hashes, serialization
from datetime import datetime, timedelta, timezone
from models.config import config
config()
conn = mariadb.connect(
host=config.database_host,
user=config.database_user,
password=config.database_password,
database=config.database_name,
)
class list:
certs = []
def __new__(cls, sort_key=None):
cur = conn.cursor()
cur.execute(
"""SELECT x509_certs.nvalue AS cert,
x509_certs_data.nvalue AS data,
revoked_x509_certs.nvalue AS revoked
FROM x509_certs
INNER JOIN x509_certs_data USING(nkey)
LEFT JOIN revoked_x509_certs USING(nkey)"""
)
for result in cur:
cert_object = cert(result)
cls.certs.append(cert_object)
cur.close()
if sort_key is not None:
cls.certs.sort(key=lambda item: getattr(item, sort_key))
return cls.certs
class cert:
def __init__(self, cert):
(cert_der, cert_data_raw, cert_revoked_raw) = cert
cert_data = json.loads(cert_data_raw)
if cert_revoked_raw is not None:
cert_revoked = json.loads(cert_revoked_raw)
else:
cert_revoked = None
self.load(cert_der, cert_data, cert_revoked)
@classmethod
def from_serial(cls, serial):
return cls(cert=cls.get_cert(cls, serial))
def load(self, cert_der, cert_data, cert_revoked):
cert = x509.load_der_x509_certificate(cert_der)
self.pem = cert.public_bytes(serialization.Encoding.PEM)
self.serial = str(cert.serial_number)
self.sha256 = binascii.b2a_hex(cert.fingerprint(hashes.SHA256()))
self.sha1 = binascii.b2a_hex(cert.fingerprint(hashes.SHA1()))
self.md5 = binascii.b2a_hex(cert.fingerprint(hashes.MD5()))
self.pub_key = cert.public_key().public_bytes(
serialization.Encoding.PEM, serialization.PublicFormat.SubjectPublicKeyInfo
)
self.pub_alg = cert.public_key_algorithm_oid._name
self.sig_alg = cert.signature_algorithm_oid._name
self.issuer = cert.issuer.rfc4514_string()
self.subject = cert.subject.rfc4514_string({x509.NameOID.EMAIL_ADDRESS: "E"})
self.not_before = cert.not_valid_before_utc.replace(microsecond=0)
self.not_after = cert.not_valid_after_utc.replace(microsecond=0)
try:
san_data = cert.extensions.get_extension_for_class(
x509.SubjectAlternativeName
)
self.san_names = self.get_sans(san_data)
except x509.extensions.ExtensionNotFound:
self.san_names = []
self.provisioner = cert_data.get("provisioner", None)
if cert_revoked is not None:
self.revoked_at = dateutil.parser.isoparse(
cert_revoked.get("RevokedAt")
).replace(microsecond=0)
else:
self.revoked_at = None
now_with_tz = datetime.utcnow().replace(
tzinfo=timezone(offset=timedelta()), microsecond=0
)
if self.revoked_at is not None and self.revoked_at < now_with_tz:
self.status = status(status.REVOKED)
elif self.not_after < now_with_tz:
self.status = status(status.EXPIRED)
else:
self.status = status(status.VALID)
def get_cert(self, cert_serial):
cur = conn.cursor()
cur.execute(
"""SELECT x509_certs.nvalue AS cert,
x509_certs_data.nvalue AS data,
revoked_x509_certs.nvalue AS revoked
FROM x509_certs
INNER JOIN x509_certs_data USING(nkey)
LEFT JOIN revoked_x509_certs USING(nkey)
WHERE nkey=?""",
(cert_serial,),
)
if cur.rowcount > 0:
cert = cur.fetchone()
else:
cert = None
cur.close()
return cert
def get_sans(self, san_data):
sans = []
for san_value in san_data.value:
san = {}
if isinstance(san_value, x509.general_name.DNSName):
san["type"] = "DNS"
elif isinstance(san_value, x509.general_name.UniformResourceIdentifier):
san["type"] = "URI"
elif isinstance(san_value, x509.general_name.RFC822Name):
san["type"] = "Email"
elif isinstance(san_value, x509.general_name.IPAddress):
san["type"] = "IP"
elif isinstance(san_value, x509.general_name.DirectoryName):
san["type"] = "DirectoryName"
elif isinstance(san_value, x509.general_name.RegisteredID):
san["type"] = "RegisteredID"
elif isinstance(san_value, x509.general_name.OtherName):
san["type"] = "Other ({san_value.type_id})"
else:
continue
san["value"] = san_value.value
sans.append(san)
return sans
class status:
REVOKED = 1
EXPIRED = 2
VALID = 3
def __init__(self, status):
self.value = status
def __str__(self):
if self.value == self.EXPIRED:
return "Expired"
elif self.value == self.REVOKED:
return "Revoked"
elif self.value == self.VALID:
return "Valid"
else:
return "Undefined"

View file

@ -1,5 +1,2 @@
PyYAML
cryptography
mariadb
python-dateutil python-dateutil
tabulate tabulate

View file

@ -1,16 +1,17 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import argparse import argparse
import os import requests
import sys from urllib.parse import urljoin
import yaml
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from tabulate import tabulate from tabulate import tabulate
from models import ssh_cert, x509_cert from config import config
config()
def delta_text(delta): def delta_text(delta):
s = 's'[:abs(delta.days)^1] s = "s"[: abs(delta.days) ^ 1]
if delta < timedelta(days=-1): if delta < timedelta(days=-1):
return f"in {abs(delta.days)} day{s}" return f"in {abs(delta.days)} day{s}"
@ -22,38 +23,53 @@ def delta_text(delta):
return f"{delta.days} day{s} ago" return f"{delta.days} day{s} ago"
def fetch_api(endpoint, params={}):
try:
results = requests.get(urljoin(config.url, endpoint), params=params)
results.raise_for_status()
except requests.HTTPError as e:
raise e
except requests.Timeout as e:
raise e
# request took too long
return results.json()
def list_ssh_certs(sort_key, revoked=False, expired=False): def list_ssh_certs(sort_key, revoked=False, expired=False):
cert_list = ssh_cert.list(sort_key=sort_key) params = {
"sort_key": sort_key,
"revoked": revoked,
"expired": expired,
}
cert_list = fetch_api("ssh/certs", params=params)
cert_tbl = [] cert_tbl = []
for cert in cert_list: for cert in cert_list:
if cert.status.value == ssh_cert.status.EXPIRED and not expired:
continue
if cert.status.value == ssh_cert.status.REVOKED and not revoked:
continue
cert_row = {} cert_row = {}
cert_row["Serial"] = cert.serial cert_row["Serial"] = cert["serial"]
cert_row["Type"] = cert.type cert_row["Type"] = cert["type"]
cert_row["Key ID"] = cert.key_id cert_row["Key ID"] = cert["key_id"]
principals_count = len(cert.principals) principals_count = len(cert["principals"])
principals_list = [x.decode() for x in cert.principals]
if principals_count > 2: if principals_count > 2:
principals = principals_list[:2] + [f"+{principals_count - 2} more"] principals = cert["principals"][:2] + [f"+{principals_count - 2} more"]
else: else:
principals = principals_list principals = cert["principals"]
cert_row["Principals"] = "\n".join(principals) cert_row["Principals"] = "\n".join(principals)
now_with_tz = datetime.utcnow().replace( now_with_tz = datetime.now(timezone.utc).replace(microsecond=0)
tzinfo=timezone(offset=timedelta()), microsecond=0
)
if cert.revoked_at is not None: if cert["revoked_at"] is not None:
delta = now_with_tz - cert.revoked_at delta = now_with_tz - datetime.fromtimestamp(
cert["revoked_at"], tz=timezone.utc
)
else: else:
delta = now_with_tz - cert.not_after delta = now_with_tz - datetime.fromtimestamp(
cert["not_after"], tz=timezone.utc
)
cert_row["Expires"] = delta_text(delta).capitalize() cert_row["Expires"] = delta_text(delta).capitalize()
cert_row["Status"] = cert.status cert_row["Status"] = cert["status"]
cert_tbl.append(cert_row) cert_tbl.append(cert_row)
@ -61,87 +77,104 @@ def list_ssh_certs(sort_key, revoked=False, expired=False):
def get_ssh_cert(serial): def get_ssh_cert(serial):
cert = ssh_cert.cert.from_serial(serial) cert = fetch_api(f"ssh/certs/{serial}")
if cert is None:
return
cert_tbl = [] cert_tbl = []
cert_tbl.append(["Serial", cert["serial"]])
cert_tbl.append(["Serial", cert.serial]) cert_tbl.append(["Certificate type", cert["type"]])
cert_tbl.append(["Certificate type", cert.type]) cert_tbl.append(["Certificate key type", cert["alg"]])
cert_tbl.append(["Certificate key type", cert.alg.decode()]) public_key = f"{cert['public_key_type']} SHA256:{cert['public_key_hash']}"
public_key = f"{cert.public_key_type} SHA256:{cert.public_key_hash.decode()}"
cert_tbl.append(["Public key", public_key]) cert_tbl.append(["Public key", public_key])
signing_key = f"{cert.signing_key_type} SHA256:{cert.signing_key_hash.decode()}" signing_key = f"{cert['signing_key_type']} SHA256:{cert['signing_key_hash']}"
cert_tbl.append(["Signing key", signing_key]) cert_tbl.append(["Signing key", signing_key])
cert_tbl.append(["Key ID", cert.key_id.decode()]) cert_tbl.append(["Key ID", cert["key_id"]])
principals = [x.decode() for x in cert.principals] cert_tbl.append(["Principals", "\n".join(cert["principals"])])
cert_tbl.append(["Principals", "\n".join(principals)])
now_with_tz = datetime.utcnow().replace( now_with_tz = datetime.now(timezone.utc).replace(microsecond=0)
tzinfo=timezone(offset=timedelta()), microsecond=0
delta_after = now_with_tz - datetime.fromtimestamp(
cert["not_after"], tz=timezone.utc
)
delta_before = now_with_tz - datetime.fromtimestamp(
cert["not_before"], tz=timezone.utc
) )
delta_after = now_with_tz - cert.not_after
delta_before = now_with_tz - cert.not_before
cert_tbl.append( cert_tbl.append(
["Not valid before", f"{cert.not_before} ({delta_text(delta_before)})"] [
"Not valid before",
f"{datetime.fromtimestamp(cert['not_before']).astimezone()} ({delta_text(delta_before)})",
]
) )
cert_tbl.append( cert_tbl.append(
["Not valid after", f"{cert.not_after} ({delta_text(delta_after)})"] [
"Not valid after",
f"{datetime.fromtimestamp(cert['not_after']).astimezone()} ({delta_text(delta_after)})",
]
) )
if cert.revoked_at is not None:
delta_revoked = now_with_tz - cert.revoked_at if cert["revoked_at"] is not None:
cert_tbl.append( delta_revoked = now_with_tz - datetime.fromtimestamp(
["Revoked at", f"{cert.revoked_at} ({delta_text(delta_revoked)})"] cert["revoked_at"], tz=timezone.utc
) )
cert_tbl.append(["Valid for", f"{delta_revoked.days} days"]) cert_tbl.append(
else: [
cert_tbl.append(["Valid for", f"{abs(delta_after.days)} days"]) "Revoked at",
extensions = [x.decode() for x in cert.extensions] f"{datetime.fromtimestamp(cert['revoked_at']).astimezone()} ({delta_text(delta_revoked)})",
cert_tbl.append(["Extensions", "\n".join(extensions)]) ]
# cert_tbl.append(["Signing key", cert.signing_key.decode()]) )
cert_tbl.append(["Status", cert.status])
cert_tbl.append(["Extensions", "\n".join(cert["extensions"])])
#cert_tbl.append(["Signing key", cert["signing_key"]])
cert_tbl.append(["Status", cert["status"]])
print(tabulate(cert_tbl, tablefmt="fancy_grid")) print(tabulate(cert_tbl, tablefmt="fancy_grid"))
def dump_ssh_cert(serial): def dump_ssh_cert(serial):
cert = ssh_cert.cert.from_serial(serial) cert = fetch_api(f"ssh/certs/{serial}")
print(cert.public_identity.decode()) if cert is None:
return
print(cert["public_identity"])
def list_x509_certs(sort_key, revoked=False, expired=False): def list_x509_certs(sort_key, revoked=False, expired=False):
cert_list = x509_cert.list(sort_key=sort_key) params = {
"sort_key": sort_key,
"revoked": revoked,
"expired": expired,
}
cert_list = fetch_api(f"x509/certs", params=params)
cert_tbl = [] cert_tbl = []
for cert in cert_list: for cert in cert_list:
if cert.status.value == x509_cert.status.EXPIRED and not expired:
continue
if cert.status.value == x509_cert.status.REVOKED and not revoked:
continue
cert_row = {} cert_row = {}
cert_row["Serial"] = cert.serial cert_row["Serial"] = cert["serial"]
cert_row["Subject/Subject Alt Names (SAN)"] = "\n".join( cert_row["Subject/Subject Alt Names (SAN)"] = "\n".join(
[ [
"%.33s" % x "%.33s" % x
for x in [cert.subject] for x in [cert["subject"]]
+ [f"{x['type']}: {x['value']}" for x in cert.san_names] + [f"{x['type']}: {x['value']}" for x in cert["san_names"]]
] ]
) )
cert_row["Provisioner"] = ( cert_row["Provisioner"] = (
f"{cert.provisioner['name']} ({cert.provisioner['type']})" f"{cert['provisioner']['name']} ({cert['provisioner']['type']})"
) )
now_with_tz = datetime.utcnow().replace( now_with_tz = datetime.now(timezone.utc).replace(microsecond=0)
tzinfo=timezone(offset=timedelta()), microsecond=0
)
if cert.revoked_at is not None: if cert["revoked_at"] is not None:
delta = now_with_tz - cert.revoked_at delta = now_with_tz - datetime.fromtimestamp(
cert["revoked_at"], tz=timezone.utc
)
else: else:
delta = now_with_tz - cert.not_after delta = now_with_tz - datetime.fromtimestamp(
cert["not_after"], tz=timezone.utc
)
cert_row["Expires"] = delta_text(delta).capitalize() cert_row["Expires"] = delta_text(delta).capitalize()
cert_row["Status"] = cert.status cert_row["Status"] = cert["status"]
cert_tbl.append(cert_row) cert_tbl.append(cert_row)
@ -149,64 +182,87 @@ def list_x509_certs(sort_key, revoked=False, expired=False):
def get_x509_cert(serial, show_cert=False, show_pubkey=False): def get_x509_cert(serial, show_cert=False, show_pubkey=False):
cert = x509_cert.cert.from_serial(serial) cert = fetch_api(f"x509/certs/{serial}")
cert_tbl = []
cert_tbl.append(["Serial", cert.serial]) if cert is None:
cert_tbl.append(["Subject", cert.subject]) return
cert_tbl = []
cert_tbl.append(["Serial", cert["serial"]])
cert_tbl.append(["Subject", cert["subject"]])
cert_tbl.append( cert_tbl.append(
[ [
"Subject Alt Names (SAN)", "Subject Alt Names (SAN)",
"\n".join([f"{x['type']}: {x['value']}" for x in cert.san_names]), "\n".join([f"{x['type']}: {x['value']}" for x in cert["san_names"]]),
] ]
) )
cert_tbl.append(["Issuer", cert.issuer]) cert_tbl.append(["Issuer", cert["issuer"]])
now_with_tz = datetime.utcnow().replace( now_with_tz = datetime.now(timezone.utc).replace(microsecond=0)
tzinfo=timezone(offset=timedelta()), microsecond=0
delta_after = now_with_tz - datetime.fromtimestamp(
cert["not_after"], tz=timezone.utc
)
delta_before = now_with_tz - datetime.fromtimestamp(
cert["not_before"], tz=timezone.utc
) )
delta_after = now_with_tz - cert.not_after
delta_before = now_with_tz - cert.not_before
cert_tbl.append( cert_tbl.append(
["Not valid before", f"{cert.not_before} ({delta_text(delta_before)})"] [
"Not valid before",
f"{datetime.fromtimestamp(cert['not_before']).astimezone()} ({delta_text(delta_before)})",
]
) )
cert_tbl.append( cert_tbl.append(
["Not valid after", f"{cert.not_after} ({delta_text(delta_after)})"] [
"Not valid after",
f"{datetime.fromtimestamp(cert['not_after']).astimezone()} ({delta_text(delta_after)})",
]
) )
if cert.revoked_at is not None: if cert["revoked_at"] is not None:
delta_revoked = now_with_tz - cert.revoked_at delta_revoked = now_with_tz - datetime.fromtimestamp(
cert["revoked_at"], tz=timezone.utc
)
cert_tbl.append( cert_tbl.append(
["Revoked at", f"{cert.revoked_at} ({delta_text(delta_revoked)})"] [
"Revoked at",
f"{datetime.fromtimestamp(cert['revoked_at']).astimezone()} ({delta_text(delta_revoked)})",
]
) )
cert_tbl.append(["Valid for", f"{delta_revoked.days} days"]) cert_tbl.append(["Valid for", f"{delta_revoked.days} days"])
else: else:
cert_tbl.append(["Valid for", f"{abs(delta_after.days)} days"]) cert_tbl.append(["Valid for", f"{abs(delta_after.days)} days"])
cert_tbl.append( cert_tbl.append(
["Provisioner", f"{cert.provisioner['name']} ({cert.provisioner['type']})"] [
"Provisioner",
f"{cert['provisioner']['name']} ({cert['provisioner']['type']})",
]
) )
fingerprints = [] fingerprints = []
fingerprints.append(f"MD5: {cert.md5.decode()}") fingerprints.append(f"MD5: {cert['md5']}")
fingerprints.append(f"SHA-1: {cert.sha1.decode()}") fingerprints.append(f"SHA-1: {cert['sha1']}")
fingerprints.append(f"SHA-256: {cert.sha256.decode()}") fingerprints.append(f"SHA-256: {cert['sha256']}")
cert_tbl.append(["Fingerprints", "\n".join(fingerprints)]) cert_tbl.append(["Fingerprints", "\n".join(fingerprints)])
cert_tbl.append(["Public key algorithm", cert.pub_alg]) cert_tbl.append(["Public key algorithm", cert["pub_alg"]])
cert_tbl.append(["Signature algorithm", cert.sig_alg]) cert_tbl.append(["Signature algorithm", cert["sig_alg"]])
cert_tbl.append(["Status", cert.status]) cert_tbl.append(["Status", cert["status"]])
# cert_tbl.append(["Extensions", cert.extensions])
if show_pubkey: if show_pubkey:
cert_tbl.append(["Public key", cert.pub_key.decode("utf-8")]) cert_tbl.append(["Public key", cert["pub_key"]])
if show_cert: if show_cert:
cert_tbl.append(["PEM", cert.pem.decode("utf-8")]) cert_tbl.append(["PEM", cert["pem"]])
print(tabulate(cert_tbl, tablefmt="fancy_grid")) print(tabulate(cert_tbl, tablefmt="fancy_grid"))
def dump_x509_cert(serial, cert_format="pem"): def dump_x509_cert(serial, cert_format="pem"):
cert = x509_cert.cert.from_serial(serial) cert = fetch_api(f"x509/certs/{serial}")
print(cert.pem.decode("utf-8").rstrip())
if cert is None:
return
print(cert["pem"].rstrip())
parser = argparse.ArgumentParser(description="Step CA Inspector") parser = argparse.ArgumentParser(description="Step CA Inspector")