Merge pull request 'Switch to enums, add search for x509 certs and improve ssh cert search' () from rework into main

Reviewed-on: 
This commit is contained in:
Benjamin Collet 2025-05-13 09:10:20 +00:00
commit c5132eb828
3 changed files with 104 additions and 69 deletions
step-ca-inspector

View file

@ -69,10 +69,30 @@ metrics_app = make_asgi_app()
app.mount("/metrics", metrics_app)
class certStatus(str, Enum):
REVOKED = "Revoked"
EXPIRED = "Expired"
VALID = "Valid"
class provisionerType(str, Enum):
# https://github.com/smallstep/certificates/blob/938a4da5adf2d32f36ffd06922e5c66956dfff41/authority/provisioner/provisioner.go#L200-L223
ACME = "ACME"
AWS = "AWS"
GCP = "GCP"
JWK = "JWK"
Nebula = "Nebula"
OIDC = "OIDC"
SCEP = "SCEP"
SSHPOP = "SSHPOP"
X5C = "X5C"
K8sSA = "K8sSA"
class provisioner(BaseModel):
id: str
name: str
type: str
type: provisionerType
class sanName(BaseModel):
@ -88,7 +108,7 @@ class x509Cert(BaseModel):
not_after: int
not_before: int
revoked_at: Union[int, None] = None
status: str
status: certStatus
sha256: str
sha1: str
md5: str
@ -99,16 +119,21 @@ class x509Cert(BaseModel):
pem: str
class sshCertType(str, Enum):
HOST = "Host"
USER = "User"
class sshCert(BaseModel):
serial: str
alg: str
type: str
type: sshCertType
key_id: str
principals: List[str] = []
not_after: int
not_before: int
revoked_at: Union[int, None] = None
status: str
status: certStatus
signing_key: str
signing_key_type: str
signing_key_hash: str
@ -146,7 +171,7 @@ async def update_metrics():
"principals": ",".join([x.decode() for x in cert.principals]),
"serial": cert.serial,
"key_id": cert.key_id.decode(),
"certificate_type": cert.type,
"certificate_type": getattr(sshCertType, cert.type.name).value,
}
ssh_cert_not_after.labels(**labels).set(cert.not_after)
@ -160,18 +185,45 @@ async def update_metrics():
@app.get("/x509/certs", tags=["x509"])
def list_x509_certs(
sort_key: str = "not_after", revoked: bool = False, expired: bool = False
sort_key: str = Query(enum=["not_after", "not_before"], default="not_after"),
revoked: bool = Query(False, deprecated=True),
expired: bool = Query(False, deprecated=True),
cert_status: list[certStatus] = Query(["Valid"]),
subject: str = None,
san: str = None,
provisioner: str = None,
provisioner_type: list[provisionerType] = Query(list(provisionerType)),
) -> list[x509Cert]:
certs = x509_cert.list(db, sort_key=sort_key)
cert_list = []
for cert in certs:
if cert.status.value == x509_cert.status.EXPIRED and not expired:
continue
if cert.status.value == x509_cert.status.REVOKED and not revoked:
continue
if cert.status.name not in [item.name for item in cert_status]:
# TODO: Remove handling of deprecated parameters
if not expired and not revoked:
continue
if cert.status == x509_cert.status.EXPIRED and not expired:
continue
if cert.status == x509_cert.status.REVOKED and not revoked:
continue
cert.status = str(cert.status)
if (
provisioner is not None
and provisioner.casefold() not in cert.provisioner["name"].casefold()
):
continue
if cert.provisioner["type"] not in [item.name for item in provisioner_type]:
continue
if subject is not None and subject.casefold() not in cert.subject.casefold():
continue
if san is not None:
for cert_san_name in cert.san_names:
if san.casefold() in cert_san_name["value"].casefold():
break
else:
continue
cert.status = getattr(certStatus, cert.status.name)
cert_list.append(cert)
return cert_list
@ -182,18 +234,17 @@ def get_x509_cert(serial: str) -> Union[x509Cert, None]:
cert = x509_cert.cert.from_serial(db, serial)
if cert is None:
return None
cert.status = str(cert.status)
cert.status = getattr(certStatus, cert.status.name)
return cert
@app.get("/ssh/certs", tags=["ssh"])
def list_ssh_certs(
sort_key: str = Query(enum=["not_after", "not_before"], default="not_after"),
revoked: bool = False,
expired: bool = False,
cert_type: list[Enum("Types", [("Host", "Host"), ("User", "User")])] = Query(
["Host", "User"]
),
revoked: bool = Query(False, deprecated=True),
expired: bool = Query(False, deprecated=True),
cert_type: list[sshCertType] = Query(["Host", "User"]),
cert_status: list[certStatus] = Query(["Valid"]),
key: str = None,
principal: str = None,
) -> list[sshCert]:
@ -201,18 +252,30 @@ def list_ssh_certs(
cert_list = []
for cert in certs:
if cert.status.value == ssh_cert.status.EXPIRED and not expired:
continue
if cert.status.value == ssh_cert.status.REVOKED and not revoked:
continue
if cert.type not in [item.value for item in cert_type]:
continue
if key is not None and key not in str(cert.key_id):
continue
if principal is not None and principal not in str(cert.principals):
continue
if cert.status.name not in [item.name for item in cert_status]:
# TODO: Remove handling of deprecated parameters
if not expired and not revoked:
continue
cert.status = str(cert.status)
if cert.status == ssh_cert.status.EXPIRED and not expired:
continue
if cert.status == ssh_cert.status.REVOKED and not revoked:
continue
if cert.type.name not in [item.name for item in cert_type]:
continue
if key is not None and key.casefold() not in str(cert.key_id).casefold():
continue
if principal is not None:
for cert_principal in cert.principals:
if principal.casefold() in str(cert_principal).casefold():
break
else:
continue
cert.type = getattr(sshCertType, cert.type.name)
cert.status = getattr(certStatus, cert.status.name)
cert_list.append(cert)
return cert_list
@ -223,5 +286,6 @@ def get_ssh_cert(serial: str) -> Union[sshCert, None]:
cert = ssh_cert.cert.from_serial(db, serial)
if cert is None:
return None
cert.status = str(cert.status)
cert.type = getattr(sshCertType, cert.type.name)
cert.status = getattr(certStatus, cert.status.name)
return cert

View file

@ -5,6 +5,7 @@ import mariadb
from cryptography.hazmat.primitives import asymmetric, hashes, serialization
from datetime import datetime, timedelta, timezone
from struct import unpack
from enum import Enum
class list:
@ -58,12 +59,7 @@ class cert:
cert = serialization.load_ssh_public_identity(cert_pub_id)
self.serial = str(cert.serial)
self.alg = cert_alg
if cert.type == serialization.SSHCertificateType.USER:
self.type = "User"
elif cert.type == serialization.SSHCertificateType.HOST:
self.type = "Host"
else:
self.type = "Unknown"
self.type = cert.type
self.key_id = cert.key_id
self.principals = cert.valid_principals
self.not_after = cert.valid_before
@ -96,11 +92,11 @@ class cert:
)
if self.revoked_at is not None and self.revoked_at < now_with_tz:
self.status = status(status.REVOKED)
self.status = status.REVOKED
elif self.not_after < now_with_tz:
self.status = status(status.EXPIRED)
self.status = status.EXPIRED
else:
self.status = status(status.VALID)
self.status = status.VALID
def get_cert(self, db, cert_serial):
cur = db.cursor()
@ -142,20 +138,7 @@ class cert:
return key_str, key_type, key_hash
class status:
class status(Enum):
REVOKED = 1
EXPIRED = 2
VALID = 3
def __init__(self, status):
self.value = status
def __str__(self):
if self.value == self.EXPIRED:
return "Expired"
elif self.value == self.REVOKED:
return "Revoked"
elif self.value == self.VALID:
return "Valid"
else:
return "Undefined"

View file

@ -5,6 +5,7 @@ import mariadb
from cryptography import x509
from cryptography.hazmat.primitives import hashes, serialization
from datetime import datetime, timedelta, timezone
from enum import Enum
class list:
@ -97,11 +98,11 @@ class cert:
)
if self.revoked_at is not None and self.revoked_at < now_with_tz:
self.status = status(status.REVOKED)
self.status = status.REVOKED
elif self.not_after < now_with_tz:
self.status = status(status.EXPIRED)
self.status = status.EXPIRED
else:
self.status = status(status.VALID)
self.status = status.VALID
def get_cert(self, db, cert_serial):
cur = db.cursor()
@ -153,20 +154,7 @@ class cert:
return sans
class status:
class status(Enum):
REVOKED = 1
EXPIRED = 2
VALID = 3
def __init__(self, status):
self.value = status
def __str__(self):
if self.value == self.EXPIRED:
return "Expired"
elif self.value == self.REVOKED:
return "Revoked"
elif self.value == self.VALID:
return "Valid"
else:
return "Undefined"