Initial commit

This commit is contained in:
Benjamin Collet 2025-01-07 09:50:19 +01:00
commit 6c02075a57
Signed by: bcollet
SSH key fingerprint: SHA256:8UJspOIcCOS+MtSOcnuq2HjKFube4ox1s/+A62ixov4
5 changed files with 574 additions and 0 deletions

2
.gitignore vendored Normal file
View file

@ -0,0 +1,2 @@
.python-version
__pycache__/

25
models/config.py Normal file
View file

@ -0,0 +1,25 @@
import os
import sys
import yaml
class config:
@classmethod
def __init__(self):
for config_path in (
os.path.expanduser("~/.config/step-ca-inspector"),
os.environ.get("STEP_CA_INSPECTOR_CONF"),
):
if config_path is None:
continue
try:
with open(os.path.join(config_path, "config.yaml")) as ymlfile:
cfg = yaml.load(ymlfile, Loader=yaml.FullLoader)
break
except IOError:
pass
else:
print("No configuration file found")
sys.exit(1)
for k, v in cfg.items():
setattr(self, k, v)

156
models/ssh_cert.py Normal file
View file

@ -0,0 +1,156 @@
import base64
import dateutil
import json
import mariadb
from cryptography import x509
from cryptography.hazmat.primitives import asymmetric, hashes, serialization
from datetime import datetime, timedelta, timezone
from models.config import config
from struct import unpack
config()
conn = mariadb.connect(
host=config.database_host,
user=config.database_user,
password=config.database_password,
database=config.database_name,
)
class list:
certs = []
def __new__(cls, sort_key=None):
cur = conn.cursor()
cur.execute("SELECT nkey FROM ssh_certs")
for (cert_serial,) in cur:
cert_object = cert(cert_serial)
cls.certs.append(cert_object)
cur.close()
if sort_key is not None:
cls.certs.sort(key=lambda item: getattr(item, sort_key))
return cls.certs
class cert:
def __init__(self, serial):
cert_raw = self.get_cert(serial)
size = unpack(">I", cert_raw[:4])[0] + 4
alg = cert_raw[4:size]
cert_pub_id = b" ".join([alg, base64.b64encode(cert_raw)])
cert_revoked = self.get_cert_revoked(serial)
self.load(cert_pub_id, cert_revoked, alg)
def load(self, cert_pub_id, cert_revoked, cert_alg):
cert = serialization.load_ssh_public_identity(cert_pub_id)
self.serial = cert.serial
self.alg = cert_alg
if cert.type == serialization.SSHCertificateType.USER:
self.type = "User"
self.key_id = cert.key_id
self.principals = cert.valid_principals
self.not_after = datetime.fromtimestamp(cert.valid_before).replace(
tzinfo=timezone(offset=timedelta()), microsecond=0
)
self.not_before = datetime.fromtimestamp(cert.valid_after).replace(
tzinfo=timezone(offset=timedelta()), microsecond=0
)
# TODO: Implement critical options parsing
# cert.critical_options
self.extensions = cert.extensions
(self.signing_key, self.signing_key_type, self.signing_key_hash) = (
self.get_public_key_params(cert.signature_key())
)
(self.public_key, self.public_key_type, self.public_key_hash) = (
self.get_public_key_params(cert.public_key())
)
self.public_identity = cert.public_bytes()
if cert_revoked is not None:
self.revoked_at = dateutil.parser.isoparse(
cert_revoked.get("RevokedAt")
).replace(microsecond=0)
else:
self.revoked_at = None
now_with_tz = datetime.utcnow().replace(
tzinfo=timezone(offset=timedelta()), microsecond=0
)
if self.revoked_at is not None and self.revoked_at < now_with_tz:
self.status = status(status.REVOKED)
elif self.not_after < now_with_tz:
self.status = status(status.EXPIRED)
else:
self.status = status(status.VALID)
def get_cert(self, cert_serial):
cur = conn.cursor()
cur.execute("SELECT nvalue FROM ssh_certs WHERE nkey=?", (cert_serial,))
if cur.rowcount > 0:
(cert,) = cur.fetchone()
else:
cert = None
cur.close()
return cert
def get_cert_revoked(self, cert_serial):
cur = conn.cursor()
cur.execute("SELECT nvalue FROM revoked_ssh_certs WHERE nkey=?", (cert_serial,))
if cur.rowcount > 0:
(cert_revoked_raw,) = cur.fetchone()
cert_revoked = json.loads(cert_revoked_raw)
else:
cert_revoked = None
cur.close()
return cert_revoked
def get_public_key_params(self, public_key):
if isinstance(public_key, asymmetric.ec.EllipticCurvePublicKey):
key_type = "ECDSA"
elif isinstance(public_key, asymmetric.ed25519.Ed25519PublicKey):
key_type = "ED25519"
elif isinstance(public_key, asymmetric.rsa.RSAPublicKey):
key_type = "RSA"
key_str = public_key.public_bytes(
serialization.Encoding.OpenSSH, serialization.PublicFormat.OpenSSH
)
key_data = key_str.strip().split()[1]
digest = hashes.Hash(hashes.SHA256())
digest.update(base64.b64decode(key_data))
hash_sha256 = digest.finalize()
key_hash = base64.b64encode(hash_sha256)
return key_str, key_type, key_hash
class status:
REVOKED = 1
EXPIRED = 2
VALID = 3
def __init__(self, status):
self.value = status
def __str__(self):
if self.value == self.EXPIRED:
return "Expired"
elif self.value == self.REVOKED:
return "Revoked"
elif self.value == self.VALID:
return "Valid"
else:
return "Undefined"

137
models/x509_cert.py Normal file
View file

@ -0,0 +1,137 @@
import binascii
import dateutil
import json
import mariadb
from cryptography import x509
from cryptography.hazmat.primitives import hashes, serialization
from datetime import datetime, timedelta, timezone
from models.config import config
config()
conn = mariadb.connect(
host = config.database_host,
user = config.database_user,
password = config.database_password,
database = config.database_name
)
class list:
certs = []
def __new__(cls, sort_key=None):
cur = conn.cursor()
cur.execute("SELECT nkey FROM x509_certs")
for (cert_serial,) in cur:
cert_object = cert(cert_serial)
cls.certs.append(cert_object)
cur.close()
if sort_key is not None:
cls.certs.sort(key=lambda item: getattr(item, sort_key))
return cls.certs
class cert:
def __init__(self, serial):
cert_der = self.get_cert(serial)
cert_data = self.get_cert_data(serial)
cert_revoked = self.get_cert_revoked(serial)
self.load(cert_der, cert_data, cert_revoked)
def load(self, cert_der, cert_data, cert_revoked):
cert = x509.load_der_x509_certificate(cert_der)
self.pem = cert.public_bytes(serialization.Encoding.PEM)
self.serial = str(cert.serial_number)
self.sha256 = binascii.b2a_hex(cert.fingerprint(hashes.SHA256()))
self.sha1 = binascii.b2a_hex(cert.fingerprint(hashes.SHA1()))
self.md5 = binascii.b2a_hex(cert.fingerprint(hashes.MD5()))
self.pub_alg = cert.public_key_algorithm_oid._name
self.sig_alg = cert.signature_algorithm_oid._name
self.issuer = cert.issuer.rfc4514_string()
self.subject = cert.subject.rfc4514_string({x509.NameOID.EMAIL_ADDRESS: "E"})
self.not_before = cert.not_valid_before_utc.replace(microsecond=0)
self.not_after = cert.not_valid_after_utc.replace(microsecond=0)
san_data = cert.extensions.get_extension_for_class(x509.SubjectAlternativeName)
self.san_names = san_data.value.get_values_for_type(x509.GeneralName)
self.provisioner = cert_data.get("provisioner", None)
if cert_revoked is not None:
self.revoked_at = dateutil.parser.isoparse(
cert_revoked.get("RevokedAt")
).replace(microsecond=0)
else:
self.revoked_at = None
now_with_tz = datetime.utcnow().replace(
tzinfo=timezone(offset=timedelta()), microsecond=0
)
if self.revoked_at is not None and self.revoked_at < now_with_tz:
self.status = status(status.REVOKED)
elif self.not_after < now_with_tz:
self.status = status(status.EXPIRED)
else:
self.status = status(status.VALID)
def get_cert(self, cert_serial):
cur = conn.cursor()
cur.execute("SELECT nvalue FROM x509_certs WHERE nkey=?", (cert_serial,))
if cur.rowcount > 0:
(cert,) = cur.fetchone()
else:
cert = None
cur.close()
return cert
def get_cert_data(self, cert_serial):
cur = conn.cursor()
cur.execute("SELECT nvalue FROM x509_certs_data WHERE nkey=?", (cert_serial,))
(cert_data_raw,) = cur.fetchone()
cur.close()
cert_data = json.loads(cert_data_raw)
return cert_data
def get_cert_revoked(self, cert_serial):
cur = conn.cursor()
cur.execute(
"SELECT nvalue FROM revoked_x509_certs WHERE nkey=?", (cert_serial,)
)
if cur.rowcount > 0:
(cert_revoked_raw,) = cur.fetchone()
cert_revoked = json.loads(cert_revoked_raw)
else:
cert_revoked = None
cur.close()
return cert_revoked
class status:
REVOKED = 1
EXPIRED = 2
VALID = 3
def __init__(self, status):
self.value = status
def __str__(self):
if self.value == self.EXPIRED:
return "Expired"
elif self.value == self.REVOKED:
return "Revoked"
elif self.value == self.VALID:
return "Valid"
else:
return "Undefined"

254
step-ca-inspector.py Executable file
View file

@ -0,0 +1,254 @@
#!/usr/bin/env python3
import argparse
import os
import sys
import yaml
from tabulate import tabulate
from models import ssh_cert, x509_cert
def list_ssh_certs(sort_key, revoked=False, expired=False):
cert_list = ssh_cert.list(sort_key=sort_key)
cert_tbl = []
for cert in cert_list:
if cert.status.value == ssh_cert.status.EXPIRED and not expired:
continue
if cert.status.value == ssh_cert.status.REVOKED and not revoked:
continue
cert_row = {}
cert_row["Serial"] = cert.serial
cert_row["Type"] = cert.type
cert_row["Key ID"] = cert.key_id
principals_count = len(cert.principals)
principals_list = [x.decode() for x in cert.principals]
if principals_count > 2:
principals = principals_list[:2] + [f"+{principals_count - 2} more"]
cert_row["Principals"] = "\n".join(principals)
validity = []
validity.append(f"Not before: {cert.not_before}")
validity.append(f"Not after: {cert.not_after}")
if cert.revoked_at is not None:
validity.append(f"Revoked at: {cert.revoked_at}")
cert_row["Validity"] = "\n".join(validity)
cert_row["Status"] = cert.status
cert_tbl.append(cert_row)
print(tabulate(cert_tbl, headers="keys", tablefmt="fancy_grid"))
def get_ssh_cert(serial):
cert = ssh_cert.cert(serial)
cert_tbl = []
cert_tbl.append(["Serial", cert.serial])
cert_tbl.append(["Certificate type", cert.type])
cert_tbl.append(["Certificate key type", cert.alg.decode()])
public_key = f"{cert.public_key_type} SHA256:{cert.public_key_hash.decode()}"
cert_tbl.append(["Public key", public_key])
signing_key = f"{cert.signing_key_type} SHA256:{cert.signing_key_hash.decode()}"
cert_tbl.append(["Signing key", signing_key])
cert_tbl.append(["Key ID", cert.key_id.decode()])
principals = [x.decode() for x in cert.principals]
cert_tbl.append(["Principals", "\n".join(principals)])
cert_tbl.append(["Not valid before", cert.not_before])
cert_tbl.append(["Not valid after", cert.not_after])
if cert.revoked_at is not None:
cert_tbl.append(["Revoked at", cert.revoked_at])
extensions = [x.decode() for x in cert.extensions]
cert_tbl.append(["Extensions", "\n".join(extensions)])
# cert_tbl.append(["Signing key", cert.signing_key.decode()])
cert_tbl.append(["Status", cert.status])
print(tabulate(cert_tbl, tablefmt="fancy_grid"))
def dump_ssh_cert(serial):
cert = ssh_cert.cert(serial)
print(cert.public_identity.decode())
def list_x509_certs(sort_key, revoked=False, expired=False):
cert_list = x509_cert.list(sort_key=sort_key)
cert_tbl = []
for cert in cert_list:
if cert.status.value == x509_cert.status.EXPIRED and not expired:
continue
if cert.status.value == x509_cert.status.REVOKED and not revoked:
continue
cert_row = {}
cert_row["Serial"] = cert.serial
cert_row["Subject"] = "%.30s" % cert.subject
cert_row["Subject Alt Names (SAN)"] = "\n".join(
["%.30s" % x for x in cert.san_names]
)
cert_row["Provisioner"] = (
f"{cert.provisioner['name']} ({cert.provisioner['type']})"
)
validity = []
validity.append(f"Not before: {cert.not_before}")
validity.append(f"Not after: {cert.not_after}")
if cert.revoked_at is not None:
validity.append(f"Revoked at: {cert.revoked_at}")
cert_row["Validity"] = "\n".join(validity)
cert_row["Status"] = cert.status
cert_tbl.append(cert_row)
print(tabulate(cert_tbl, headers="keys", tablefmt="fancy_grid"))
def get_x509_cert(serial, show_pem=False):
cert = x509_cert.cert(serial)
cert_tbl = []
cert_tbl.append(["Serial", cert.serial])
cert_tbl.append(["Subject", cert.subject])
cert_tbl.append(["Subject Alt Names (SAN)", "\n".join(cert.san_names)])
cert_tbl.append(["Issuer", cert.issuer])
cert_tbl.append(["Not valid before", cert.not_before])
cert_tbl.append(["Not valid after", cert.not_after])
if cert.revoked_at is not None:
cert_tbl.append(["Revoked at", cert.revoked_at])
cert_tbl.append(
["Provisioner", f"{cert.provisioner['name']} ({cert.provisioner['type']})"]
)
fingerprints = []
fingerprints.append(f"MD5: {cert.md5.decode()}")
fingerprints.append(f"SHA-1: {cert.sha1.decode()}")
fingerprints.append(f"SHA-256: {cert.sha256.decode()}")
cert_tbl.append(["Fingerprints", "\n".join(fingerprints)])
cert_tbl.append(["Public key algorithm", cert.pub_alg])
cert_tbl.append(["Signature algorithm", cert.sig_alg])
cert_tbl.append(["Status", cert.status])
# cert_tbl.append(["Extensions", cert.extensions])
if show_pem:
cert_tbl.append(["PEM", cert.pem.decode("utf-8")])
print(tabulate(cert_tbl, tablefmt="fancy_grid"))
def dump_x509_cert(serial, cert_format="pem"):
cert = x509_cert.cert(serial)
print(cert.pem.decode("utf-8").rstrip())
parser = argparse.ArgumentParser(description="Step CA Inspector")
subparsers = parser.add_subparsers(
help="Object to inspect", dest="object", required=True
)
x509_parser = subparsers.add_parser("x509", help="x509 certificates")
x509_subparsers = x509_parser.add_subparsers(
help="Action for perform", dest="action", required=True
)
x509_list_parser = x509_subparsers.add_parser("list", help="List x509 certificates")
x509_list_parser.add_argument(
"--show-expired",
"-e",
action="store_true",
default=False,
help="Show expired certificates",
)
x509_list_parser.add_argument(
"--show-revoked",
"-r",
action="store_true",
default=False,
help="Show revoked certificates",
)
x509_list_parser.add_argument(
"--sort-by",
"-s",
type=str,
choices=["not_after", "not_before"],
default="not_after",
help="Sort certificates",
)
x509_details_parser = x509_subparsers.add_parser(
"details", help="Show an x509 certificate details"
)
x509_details_parser.add_argument(
"--serial", "-s", type=str, required=True, help="Certificate serial"
)
x509_details_parser.add_argument(
"--show-pem",
"-p",
action="store_true",
default=False,
help="Show PEM",
)
x509_dump_parser = x509_subparsers.add_parser("dump", help="Dump an x509 certificate")
x509_dump_parser.add_argument(
"--serial", "-s", type=str, required=True, help="Certificate serial"
)
x509_dump_parser.add_argument(
"--format",
"-f",
type=str,
choices=["pem"],
required=False,
help="Certificate format",
)
ssh_parser = subparsers.add_parser("ssh", help="ssh certificates")
ssh_subparsers = ssh_parser.add_subparsers(
help="Action for perform", dest="action", required=True
)
ssh_list_parser = ssh_subparsers.add_parser("list", help="List ssh certificates")
ssh_list_parser.add_argument(
"--show-expired",
"-e",
action="store_true",
default=False,
help="Show expired certificates",
)
ssh_list_parser.add_argument(
"--show-revoked",
"-r",
action="store_true",
default=False,
help="Show revoked certificates",
)
ssh_list_parser.add_argument(
"--sort-by",
"-s",
type=str,
choices=["not_after", "not_before"],
default="not_after",
help="Sort certificates (default: not_after)",
)
ssh_details_parser = ssh_subparsers.add_parser(
"details", help="Show an ssh certificate details"
)
ssh_details_parser.add_argument(
"--serial", "-s", type=str, required=True, help="Certificate serial"
)
ssh_dump_parser = ssh_subparsers.add_parser("dump", help="Dump an ssh certificate")
ssh_dump_parser.add_argument(
"--serial", "-s", type=str, required=True, help="Certificate serial"
)
args = parser.parse_args()
if args.object == "x509":
if args.action == "list":
list_x509_certs(
revoked=args.show_revoked, expired=args.show_expired, sort_key=args.sort_by
)
elif args.action == "details":
get_x509_cert(serial=args.serial, show_pem=args.show_pem)
elif args.action == "dump":
dump_x509_cert(serial=args.serial)
elif args.object == "ssh":
if args.action == "list":
list_ssh_certs(
revoked=args.show_revoked, expired=args.show_expired, sort_key=args.sort_by
)
elif args.action == "details":
get_ssh_cert(serial=args.serial)
elif args.action == "dump":
dump_ssh_cert(serial=args.serial)