Switch to enums, add search for x509 certs and improve ssh cert search

This commit is contained in:
Benjamin Collet 2025-05-13 08:29:56 +02:00
parent 0cb5337e32
commit 204f85fb8e
Signed by: bcollet
SSH key fingerprint: SHA256:8UJspOIcCOS+MtSOcnuq2HjKFube4ox1s/+A62ixov4
3 changed files with 104 additions and 69 deletions
step-ca-inspector

View file

@ -69,10 +69,30 @@ metrics_app = make_asgi_app()
app.mount("/metrics", metrics_app) app.mount("/metrics", metrics_app)
class certStatus(str, Enum):
REVOKED = "Revoked"
EXPIRED = "Expired"
VALID = "Valid"
class provisionerType(str, Enum):
# https://github.com/smallstep/certificates/blob/938a4da5adf2d32f36ffd06922e5c66956dfff41/authority/provisioner/provisioner.go#L200-L223
ACME = "ACME"
AWS = "AWS"
GCP = "GCP"
JWK = "JWK"
Nebula = "Nebula"
OIDC = "OIDC"
SCEP = "SCEP"
SSHPOP = "SSHPOP"
X5C = "X5C"
K8sSA = "K8sSA"
class provisioner(BaseModel): class provisioner(BaseModel):
id: str id: str
name: str name: str
type: str type: provisionerType
class sanName(BaseModel): class sanName(BaseModel):
@ -88,7 +108,7 @@ class x509Cert(BaseModel):
not_after: int not_after: int
not_before: int not_before: int
revoked_at: Union[int, None] = None revoked_at: Union[int, None] = None
status: str status: certStatus
sha256: str sha256: str
sha1: str sha1: str
md5: str md5: str
@ -99,16 +119,21 @@ class x509Cert(BaseModel):
pem: str pem: str
class sshCertType(str, Enum):
HOST = "Host"
USER = "User"
class sshCert(BaseModel): class sshCert(BaseModel):
serial: str serial: str
alg: str alg: str
type: str type: sshCertType
key_id: str key_id: str
principals: List[str] = [] principals: List[str] = []
not_after: int not_after: int
not_before: int not_before: int
revoked_at: Union[int, None] = None revoked_at: Union[int, None] = None
status: str status: certStatus
signing_key: str signing_key: str
signing_key_type: str signing_key_type: str
signing_key_hash: str signing_key_hash: str
@ -146,7 +171,7 @@ async def update_metrics():
"principals": ",".join([x.decode() for x in cert.principals]), "principals": ",".join([x.decode() for x in cert.principals]),
"serial": cert.serial, "serial": cert.serial,
"key_id": cert.key_id.decode(), "key_id": cert.key_id.decode(),
"certificate_type": cert.type, "certificate_type": getattr(sshCertType, cert.type.name).value,
} }
ssh_cert_not_after.labels(**labels).set(cert.not_after) ssh_cert_not_after.labels(**labels).set(cert.not_after)
@ -160,18 +185,45 @@ async def update_metrics():
@app.get("/x509/certs", tags=["x509"]) @app.get("/x509/certs", tags=["x509"])
def list_x509_certs( def list_x509_certs(
sort_key: str = "not_after", revoked: bool = False, expired: bool = False sort_key: str = Query(enum=["not_after", "not_before"], default="not_after"),
revoked: bool = Query(False, deprecated=True),
expired: bool = Query(False, deprecated=True),
cert_status: list[certStatus] = Query(["Valid"]),
subject: str = None,
san: str = None,
provisioner: str = None,
provisioner_type: list[provisionerType] = Query(list(provisionerType)),
) -> list[x509Cert]: ) -> list[x509Cert]:
certs = x509_cert.list(db, sort_key=sort_key) certs = x509_cert.list(db, sort_key=sort_key)
cert_list = [] cert_list = []
for cert in certs: for cert in certs:
if cert.status.value == x509_cert.status.EXPIRED and not expired: if cert.status.name not in [item.name for item in cert_status]:
continue # TODO: Remove handling of deprecated parameters
if cert.status.value == x509_cert.status.REVOKED and not revoked: if not expired and not revoked:
continue continue
if cert.status == x509_cert.status.EXPIRED and not expired:
continue
if cert.status == x509_cert.status.REVOKED and not revoked:
continue
cert.status = str(cert.status) if (
provisioner is not None
and provisioner.casefold() not in cert.provisioner["name"].casefold()
):
continue
if cert.provisioner["type"] not in [item.name for item in provisioner_type]:
continue
if subject is not None and subject.casefold() not in cert.subject.casefold():
continue
if san is not None:
for cert_san_name in cert.san_names:
if san.casefold() in cert_san_name["value"].casefold():
break
else:
continue
cert.status = getattr(certStatus, cert.status.name)
cert_list.append(cert) cert_list.append(cert)
return cert_list return cert_list
@ -182,18 +234,17 @@ def get_x509_cert(serial: str) -> Union[x509Cert, None]:
cert = x509_cert.cert.from_serial(db, serial) cert = x509_cert.cert.from_serial(db, serial)
if cert is None: if cert is None:
return None return None
cert.status = str(cert.status) cert.status = getattr(certStatus, cert.status.name)
return cert return cert
@app.get("/ssh/certs", tags=["ssh"]) @app.get("/ssh/certs", tags=["ssh"])
def list_ssh_certs( def list_ssh_certs(
sort_key: str = Query(enum=["not_after", "not_before"], default="not_after"), sort_key: str = Query(enum=["not_after", "not_before"], default="not_after"),
revoked: bool = False, revoked: bool = Query(False, deprecated=True),
expired: bool = False, expired: bool = Query(False, deprecated=True),
cert_type: list[Enum("Types", [("Host", "Host"), ("User", "User")])] = Query( cert_type: list[sshCertType] = Query(["Host", "User"]),
["Host", "User"] cert_status: list[certStatus] = Query(["Valid"]),
),
key: str = None, key: str = None,
principal: str = None, principal: str = None,
) -> list[sshCert]: ) -> list[sshCert]:
@ -201,18 +252,30 @@ def list_ssh_certs(
cert_list = [] cert_list = []
for cert in certs: for cert in certs:
if cert.status.value == ssh_cert.status.EXPIRED and not expired: if cert.status.name not in [item.name for item in cert_status]:
continue # TODO: Remove handling of deprecated parameters
if cert.status.value == ssh_cert.status.REVOKED and not revoked: if not expired and not revoked:
continue continue
if cert.type not in [item.value for item in cert_type]:
continue
if key is not None and key not in str(cert.key_id):
continue
if principal is not None and principal not in str(cert.principals):
continue
cert.status = str(cert.status) if cert.status == ssh_cert.status.EXPIRED and not expired:
continue
if cert.status == ssh_cert.status.REVOKED and not revoked:
continue
if cert.type.name not in [item.name for item in cert_type]:
continue
if key is not None and key.casefold() not in str(cert.key_id).casefold():
continue
if principal is not None:
for cert_principal in cert.principals:
if principal.casefold() in str(cert_principal).casefold():
break
else:
continue
cert.type = getattr(sshCertType, cert.type.name)
cert.status = getattr(certStatus, cert.status.name)
cert_list.append(cert) cert_list.append(cert)
return cert_list return cert_list
@ -223,5 +286,6 @@ def get_ssh_cert(serial: str) -> Union[sshCert, None]:
cert = ssh_cert.cert.from_serial(db, serial) cert = ssh_cert.cert.from_serial(db, serial)
if cert is None: if cert is None:
return None return None
cert.status = str(cert.status) cert.type = getattr(sshCertType, cert.type.name)
cert.status = getattr(certStatus, cert.status.name)
return cert return cert

View file

@ -5,6 +5,7 @@ import mariadb
from cryptography.hazmat.primitives import asymmetric, hashes, serialization from cryptography.hazmat.primitives import asymmetric, hashes, serialization
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from struct import unpack from struct import unpack
from enum import Enum
class list: class list:
@ -58,12 +59,7 @@ class cert:
cert = serialization.load_ssh_public_identity(cert_pub_id) cert = serialization.load_ssh_public_identity(cert_pub_id)
self.serial = str(cert.serial) self.serial = str(cert.serial)
self.alg = cert_alg self.alg = cert_alg
if cert.type == serialization.SSHCertificateType.USER: self.type = cert.type
self.type = "User"
elif cert.type == serialization.SSHCertificateType.HOST:
self.type = "Host"
else:
self.type = "Unknown"
self.key_id = cert.key_id self.key_id = cert.key_id
self.principals = cert.valid_principals self.principals = cert.valid_principals
self.not_after = cert.valid_before self.not_after = cert.valid_before
@ -96,11 +92,11 @@ class cert:
) )
if self.revoked_at is not None and self.revoked_at < now_with_tz: if self.revoked_at is not None and self.revoked_at < now_with_tz:
self.status = status(status.REVOKED) self.status = status.REVOKED
elif self.not_after < now_with_tz: elif self.not_after < now_with_tz:
self.status = status(status.EXPIRED) self.status = status.EXPIRED
else: else:
self.status = status(status.VALID) self.status = status.VALID
def get_cert(self, db, cert_serial): def get_cert(self, db, cert_serial):
cur = db.cursor() cur = db.cursor()
@ -142,20 +138,7 @@ class cert:
return key_str, key_type, key_hash return key_str, key_type, key_hash
class status: class status(Enum):
REVOKED = 1 REVOKED = 1
EXPIRED = 2 EXPIRED = 2
VALID = 3 VALID = 3
def __init__(self, status):
self.value = status
def __str__(self):
if self.value == self.EXPIRED:
return "Expired"
elif self.value == self.REVOKED:
return "Revoked"
elif self.value == self.VALID:
return "Valid"
else:
return "Undefined"

View file

@ -5,6 +5,7 @@ import mariadb
from cryptography import x509 from cryptography import x509
from cryptography.hazmat.primitives import hashes, serialization from cryptography.hazmat.primitives import hashes, serialization
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from enum import Enum
class list: class list:
@ -97,11 +98,11 @@ class cert:
) )
if self.revoked_at is not None and self.revoked_at < now_with_tz: if self.revoked_at is not None and self.revoked_at < now_with_tz:
self.status = status(status.REVOKED) self.status = status.REVOKED
elif self.not_after < now_with_tz: elif self.not_after < now_with_tz:
self.status = status(status.EXPIRED) self.status = status.EXPIRED
else: else:
self.status = status(status.VALID) self.status = status.VALID
def get_cert(self, db, cert_serial): def get_cert(self, db, cert_serial):
cur = db.cursor() cur = db.cursor()
@ -153,20 +154,7 @@ class cert:
return sans return sans
class status: class status(Enum):
REVOKED = 1 REVOKED = 1
EXPIRED = 2 EXPIRED = 2
VALID = 3 VALID = 3
def __init__(self, status):
self.value = status
def __str__(self):
if self.value == self.EXPIRED:
return "Expired"
elif self.value == self.REVOKED:
return "Revoked"
elif self.value == self.VALID:
return "Valid"
else:
return "Undefined"