diff --git a/step-ca-inspector/main.py b/step-ca-inspector/main.py index 6320f0e..eee5105 100644 --- a/step-ca-inspector/main.py +++ b/step-ca-inspector/main.py @@ -69,10 +69,30 @@ metrics_app = make_asgi_app() app.mount("/metrics", metrics_app) +class certStatus(str, Enum): + REVOKED = "Revoked" + EXPIRED = "Expired" + VALID = "Valid" + + +class provisionerType(str, Enum): + # https://github.com/smallstep/certificates/blob/938a4da5adf2d32f36ffd06922e5c66956dfff41/authority/provisioner/provisioner.go#L200-L223 + ACME = "ACME" + AWS = "AWS" + GCP = "GCP" + JWK = "JWK" + Nebula = "Nebula" + OIDC = "OIDC" + SCEP = "SCEP" + SSHPOP = "SSHPOP" + X5C = "X5C" + K8sSA = "K8sSA" + + class provisioner(BaseModel): id: str name: str - type: str + type: provisionerType class sanName(BaseModel): @@ -88,7 +108,7 @@ class x509Cert(BaseModel): not_after: int not_before: int revoked_at: Union[int, None] = None - status: str + status: certStatus sha256: str sha1: str md5: str @@ -99,16 +119,21 @@ class x509Cert(BaseModel): pem: str +class sshCertType(str, Enum): + HOST = "Host" + USER = "User" + + class sshCert(BaseModel): serial: str alg: str - type: str + type: sshCertType key_id: str principals: List[str] = [] not_after: int not_before: int revoked_at: Union[int, None] = None - status: str + status: certStatus signing_key: str signing_key_type: str signing_key_hash: str @@ -146,7 +171,7 @@ async def update_metrics(): "principals": ",".join([x.decode() for x in cert.principals]), "serial": cert.serial, "key_id": cert.key_id.decode(), - "certificate_type": cert.type, + "certificate_type": getattr(sshCertType, cert.type.name).value, } ssh_cert_not_after.labels(**labels).set(cert.not_after) @@ -160,18 +185,45 @@ async def update_metrics(): @app.get("/x509/certs", tags=["x509"]) def list_x509_certs( - sort_key: str = "not_after", revoked: bool = False, expired: bool = False + sort_key: str = Query(enum=["not_after", "not_before"], default="not_after"), + revoked: bool = Query(False, deprecated=True), + expired: bool = Query(False, deprecated=True), + cert_status: list[certStatus] = Query(["Valid"]), + subject: str = None, + san: str = None, + provisioner: str = None, + provisioner_type: list[provisionerType] = Query(list(provisionerType)), ) -> list[x509Cert]: certs = x509_cert.list(db, sort_key=sort_key) cert_list = [] for cert in certs: - if cert.status.value == x509_cert.status.EXPIRED and not expired: - continue - if cert.status.value == x509_cert.status.REVOKED and not revoked: - continue + if cert.status.name not in [item.name for item in cert_status]: + # TODO: Remove handling of deprecated parameters + if not expired and not revoked: + continue + if cert.status == x509_cert.status.EXPIRED and not expired: + continue + if cert.status == x509_cert.status.REVOKED and not revoked: + continue - cert.status = str(cert.status) + if ( + provisioner is not None + and provisioner.casefold() not in cert.provisioner["name"].casefold() + ): + continue + if cert.provisioner["type"] not in [item.name for item in provisioner_type]: + continue + if subject is not None and subject.casefold() not in cert.subject.casefold(): + continue + if san is not None: + for cert_san_name in cert.san_names: + if san.casefold() in cert_san_name["value"].casefold(): + break + else: + continue + + cert.status = getattr(certStatus, cert.status.name) cert_list.append(cert) return cert_list @@ -182,18 +234,17 @@ def get_x509_cert(serial: str) -> Union[x509Cert, None]: cert = x509_cert.cert.from_serial(db, serial) if cert is None: return None - cert.status = str(cert.status) + cert.status = getattr(certStatus, cert.status.name) return cert @app.get("/ssh/certs", tags=["ssh"]) def list_ssh_certs( sort_key: str = Query(enum=["not_after", "not_before"], default="not_after"), - revoked: bool = False, - expired: bool = False, - cert_type: list[Enum("Types", [("Host", "Host"), ("User", "User")])] = Query( - ["Host", "User"] - ), + revoked: bool = Query(False, deprecated=True), + expired: bool = Query(False, deprecated=True), + cert_type: list[sshCertType] = Query(["Host", "User"]), + cert_status: list[certStatus] = Query(["Valid"]), key: str = None, principal: str = None, ) -> list[sshCert]: @@ -201,18 +252,30 @@ def list_ssh_certs( cert_list = [] for cert in certs: - if cert.status.value == ssh_cert.status.EXPIRED and not expired: - continue - if cert.status.value == ssh_cert.status.REVOKED and not revoked: - continue - if cert.type not in [item.value for item in cert_type]: - continue - if key is not None and key not in str(cert.key_id): - continue - if principal is not None and principal not in str(cert.principals): - continue + if cert.status.name not in [item.name for item in cert_status]: + # TODO: Remove handling of deprecated parameters + if not expired and not revoked: + continue - cert.status = str(cert.status) + if cert.status == ssh_cert.status.EXPIRED and not expired: + continue + + if cert.status == ssh_cert.status.REVOKED and not revoked: + continue + + if cert.type.name not in [item.name for item in cert_type]: + continue + if key is not None and key.casefold() not in str(cert.key_id).casefold(): + continue + if principal is not None: + for cert_principal in cert.principals: + if principal.casefold() in str(cert_principal).casefold(): + break + else: + continue + + cert.type = getattr(sshCertType, cert.type.name) + cert.status = getattr(certStatus, cert.status.name) cert_list.append(cert) return cert_list @@ -223,5 +286,6 @@ def get_ssh_cert(serial: str) -> Union[sshCert, None]: cert = ssh_cert.cert.from_serial(db, serial) if cert is None: return None - cert.status = str(cert.status) + cert.type = getattr(sshCertType, cert.type.name) + cert.status = getattr(certStatus, cert.status.name) return cert diff --git a/step-ca-inspector/models/ssh_cert.py b/step-ca-inspector/models/ssh_cert.py index b222fd5..12785b5 100644 --- a/step-ca-inspector/models/ssh_cert.py +++ b/step-ca-inspector/models/ssh_cert.py @@ -5,6 +5,7 @@ import mariadb from cryptography.hazmat.primitives import asymmetric, hashes, serialization from datetime import datetime, timedelta, timezone from struct import unpack +from enum import Enum class list: @@ -58,12 +59,7 @@ class cert: cert = serialization.load_ssh_public_identity(cert_pub_id) self.serial = str(cert.serial) self.alg = cert_alg - if cert.type == serialization.SSHCertificateType.USER: - self.type = "User" - elif cert.type == serialization.SSHCertificateType.HOST: - self.type = "Host" - else: - self.type = "Unknown" + self.type = cert.type self.key_id = cert.key_id self.principals = cert.valid_principals self.not_after = cert.valid_before @@ -96,11 +92,11 @@ class cert: ) if self.revoked_at is not None and self.revoked_at < now_with_tz: - self.status = status(status.REVOKED) + self.status = status.REVOKED elif self.not_after < now_with_tz: - self.status = status(status.EXPIRED) + self.status = status.EXPIRED else: - self.status = status(status.VALID) + self.status = status.VALID def get_cert(self, db, cert_serial): cur = db.cursor() @@ -142,20 +138,7 @@ class cert: return key_str, key_type, key_hash -class status: +class status(Enum): REVOKED = 1 EXPIRED = 2 VALID = 3 - - def __init__(self, status): - self.value = status - - def __str__(self): - if self.value == self.EXPIRED: - return "Expired" - elif self.value == self.REVOKED: - return "Revoked" - elif self.value == self.VALID: - return "Valid" - else: - return "Undefined" diff --git a/step-ca-inspector/models/x509_cert.py b/step-ca-inspector/models/x509_cert.py index 162a303..12817a4 100644 --- a/step-ca-inspector/models/x509_cert.py +++ b/step-ca-inspector/models/x509_cert.py @@ -5,6 +5,7 @@ import mariadb from cryptography import x509 from cryptography.hazmat.primitives import hashes, serialization from datetime import datetime, timedelta, timezone +from enum import Enum class list: @@ -97,11 +98,11 @@ class cert: ) if self.revoked_at is not None and self.revoked_at < now_with_tz: - self.status = status(status.REVOKED) + self.status = status.REVOKED elif self.not_after < now_with_tz: - self.status = status(status.EXPIRED) + self.status = status.EXPIRED else: - self.status = status(status.VALID) + self.status = status.VALID def get_cert(self, db, cert_serial): cur = db.cursor() @@ -153,20 +154,7 @@ class cert: return sans -class status: +class status(Enum): REVOKED = 1 EXPIRED = 2 VALID = 3 - - def __init__(self, status): - self.value = status - - def __str__(self): - if self.value == self.EXPIRED: - return "Expired" - elif self.value == self.REVOKED: - return "Revoked" - elif self.value == self.VALID: - return "Valid" - else: - return "Undefined"