Merge pull request 'Switch to enums, add search for x509 certs and improve ssh cert search' (#1) from rework into main
Reviewed-on: #1
This commit is contained in:
commit
c5132eb828
3 changed files with 104 additions and 69 deletions
step-ca-inspector
|
@ -69,10 +69,30 @@ metrics_app = make_asgi_app()
|
|||
app.mount("/metrics", metrics_app)
|
||||
|
||||
|
||||
class certStatus(str, Enum):
|
||||
REVOKED = "Revoked"
|
||||
EXPIRED = "Expired"
|
||||
VALID = "Valid"
|
||||
|
||||
|
||||
class provisionerType(str, Enum):
|
||||
# https://github.com/smallstep/certificates/blob/938a4da5adf2d32f36ffd06922e5c66956dfff41/authority/provisioner/provisioner.go#L200-L223
|
||||
ACME = "ACME"
|
||||
AWS = "AWS"
|
||||
GCP = "GCP"
|
||||
JWK = "JWK"
|
||||
Nebula = "Nebula"
|
||||
OIDC = "OIDC"
|
||||
SCEP = "SCEP"
|
||||
SSHPOP = "SSHPOP"
|
||||
X5C = "X5C"
|
||||
K8sSA = "K8sSA"
|
||||
|
||||
|
||||
class provisioner(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
type: str
|
||||
type: provisionerType
|
||||
|
||||
|
||||
class sanName(BaseModel):
|
||||
|
@ -88,7 +108,7 @@ class x509Cert(BaseModel):
|
|||
not_after: int
|
||||
not_before: int
|
||||
revoked_at: Union[int, None] = None
|
||||
status: str
|
||||
status: certStatus
|
||||
sha256: str
|
||||
sha1: str
|
||||
md5: str
|
||||
|
@ -99,16 +119,21 @@ class x509Cert(BaseModel):
|
|||
pem: str
|
||||
|
||||
|
||||
class sshCertType(str, Enum):
|
||||
HOST = "Host"
|
||||
USER = "User"
|
||||
|
||||
|
||||
class sshCert(BaseModel):
|
||||
serial: str
|
||||
alg: str
|
||||
type: str
|
||||
type: sshCertType
|
||||
key_id: str
|
||||
principals: List[str] = []
|
||||
not_after: int
|
||||
not_before: int
|
||||
revoked_at: Union[int, None] = None
|
||||
status: str
|
||||
status: certStatus
|
||||
signing_key: str
|
||||
signing_key_type: str
|
||||
signing_key_hash: str
|
||||
|
@ -146,7 +171,7 @@ async def update_metrics():
|
|||
"principals": ",".join([x.decode() for x in cert.principals]),
|
||||
"serial": cert.serial,
|
||||
"key_id": cert.key_id.decode(),
|
||||
"certificate_type": cert.type,
|
||||
"certificate_type": getattr(sshCertType, cert.type.name).value,
|
||||
}
|
||||
|
||||
ssh_cert_not_after.labels(**labels).set(cert.not_after)
|
||||
|
@ -160,18 +185,45 @@ async def update_metrics():
|
|||
|
||||
@app.get("/x509/certs", tags=["x509"])
|
||||
def list_x509_certs(
|
||||
sort_key: str = "not_after", revoked: bool = False, expired: bool = False
|
||||
sort_key: str = Query(enum=["not_after", "not_before"], default="not_after"),
|
||||
revoked: bool = Query(False, deprecated=True),
|
||||
expired: bool = Query(False, deprecated=True),
|
||||
cert_status: list[certStatus] = Query(["Valid"]),
|
||||
subject: str = None,
|
||||
san: str = None,
|
||||
provisioner: str = None,
|
||||
provisioner_type: list[provisionerType] = Query(list(provisionerType)),
|
||||
) -> list[x509Cert]:
|
||||
certs = x509_cert.list(db, sort_key=sort_key)
|
||||
cert_list = []
|
||||
|
||||
for cert in certs:
|
||||
if cert.status.value == x509_cert.status.EXPIRED and not expired:
|
||||
continue
|
||||
if cert.status.value == x509_cert.status.REVOKED and not revoked:
|
||||
continue
|
||||
if cert.status.name not in [item.name for item in cert_status]:
|
||||
# TODO: Remove handling of deprecated parameters
|
||||
if not expired and not revoked:
|
||||
continue
|
||||
if cert.status == x509_cert.status.EXPIRED and not expired:
|
||||
continue
|
||||
if cert.status == x509_cert.status.REVOKED and not revoked:
|
||||
continue
|
||||
|
||||
cert.status = str(cert.status)
|
||||
if (
|
||||
provisioner is not None
|
||||
and provisioner.casefold() not in cert.provisioner["name"].casefold()
|
||||
):
|
||||
continue
|
||||
if cert.provisioner["type"] not in [item.name for item in provisioner_type]:
|
||||
continue
|
||||
if subject is not None and subject.casefold() not in cert.subject.casefold():
|
||||
continue
|
||||
if san is not None:
|
||||
for cert_san_name in cert.san_names:
|
||||
if san.casefold() in cert_san_name["value"].casefold():
|
||||
break
|
||||
else:
|
||||
continue
|
||||
|
||||
cert.status = getattr(certStatus, cert.status.name)
|
||||
cert_list.append(cert)
|
||||
|
||||
return cert_list
|
||||
|
@ -182,18 +234,17 @@ def get_x509_cert(serial: str) -> Union[x509Cert, None]:
|
|||
cert = x509_cert.cert.from_serial(db, serial)
|
||||
if cert is None:
|
||||
return None
|
||||
cert.status = str(cert.status)
|
||||
cert.status = getattr(certStatus, cert.status.name)
|
||||
return cert
|
||||
|
||||
|
||||
@app.get("/ssh/certs", tags=["ssh"])
|
||||
def list_ssh_certs(
|
||||
sort_key: str = Query(enum=["not_after", "not_before"], default="not_after"),
|
||||
revoked: bool = False,
|
||||
expired: bool = False,
|
||||
cert_type: list[Enum("Types", [("Host", "Host"), ("User", "User")])] = Query(
|
||||
["Host", "User"]
|
||||
),
|
||||
revoked: bool = Query(False, deprecated=True),
|
||||
expired: bool = Query(False, deprecated=True),
|
||||
cert_type: list[sshCertType] = Query(["Host", "User"]),
|
||||
cert_status: list[certStatus] = Query(["Valid"]),
|
||||
key: str = None,
|
||||
principal: str = None,
|
||||
) -> list[sshCert]:
|
||||
|
@ -201,18 +252,30 @@ def list_ssh_certs(
|
|||
cert_list = []
|
||||
|
||||
for cert in certs:
|
||||
if cert.status.value == ssh_cert.status.EXPIRED and not expired:
|
||||
continue
|
||||
if cert.status.value == ssh_cert.status.REVOKED and not revoked:
|
||||
continue
|
||||
if cert.type not in [item.value for item in cert_type]:
|
||||
continue
|
||||
if key is not None and key not in str(cert.key_id):
|
||||
continue
|
||||
if principal is not None and principal not in str(cert.principals):
|
||||
continue
|
||||
if cert.status.name not in [item.name for item in cert_status]:
|
||||
# TODO: Remove handling of deprecated parameters
|
||||
if not expired and not revoked:
|
||||
continue
|
||||
|
||||
cert.status = str(cert.status)
|
||||
if cert.status == ssh_cert.status.EXPIRED and not expired:
|
||||
continue
|
||||
|
||||
if cert.status == ssh_cert.status.REVOKED and not revoked:
|
||||
continue
|
||||
|
||||
if cert.type.name not in [item.name for item in cert_type]:
|
||||
continue
|
||||
if key is not None and key.casefold() not in str(cert.key_id).casefold():
|
||||
continue
|
||||
if principal is not None:
|
||||
for cert_principal in cert.principals:
|
||||
if principal.casefold() in str(cert_principal).casefold():
|
||||
break
|
||||
else:
|
||||
continue
|
||||
|
||||
cert.type = getattr(sshCertType, cert.type.name)
|
||||
cert.status = getattr(certStatus, cert.status.name)
|
||||
cert_list.append(cert)
|
||||
|
||||
return cert_list
|
||||
|
@ -223,5 +286,6 @@ def get_ssh_cert(serial: str) -> Union[sshCert, None]:
|
|||
cert = ssh_cert.cert.from_serial(db, serial)
|
||||
if cert is None:
|
||||
return None
|
||||
cert.status = str(cert.status)
|
||||
cert.type = getattr(sshCertType, cert.type.name)
|
||||
cert.status = getattr(certStatus, cert.status.name)
|
||||
return cert
|
||||
|
|
|
@ -5,6 +5,7 @@ import mariadb
|
|||
from cryptography.hazmat.primitives import asymmetric, hashes, serialization
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from struct import unpack
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class list:
|
||||
|
@ -58,12 +59,7 @@ class cert:
|
|||
cert = serialization.load_ssh_public_identity(cert_pub_id)
|
||||
self.serial = str(cert.serial)
|
||||
self.alg = cert_alg
|
||||
if cert.type == serialization.SSHCertificateType.USER:
|
||||
self.type = "User"
|
||||
elif cert.type == serialization.SSHCertificateType.HOST:
|
||||
self.type = "Host"
|
||||
else:
|
||||
self.type = "Unknown"
|
||||
self.type = cert.type
|
||||
self.key_id = cert.key_id
|
||||
self.principals = cert.valid_principals
|
||||
self.not_after = cert.valid_before
|
||||
|
@ -96,11 +92,11 @@ class cert:
|
|||
)
|
||||
|
||||
if self.revoked_at is not None and self.revoked_at < now_with_tz:
|
||||
self.status = status(status.REVOKED)
|
||||
self.status = status.REVOKED
|
||||
elif self.not_after < now_with_tz:
|
||||
self.status = status(status.EXPIRED)
|
||||
self.status = status.EXPIRED
|
||||
else:
|
||||
self.status = status(status.VALID)
|
||||
self.status = status.VALID
|
||||
|
||||
def get_cert(self, db, cert_serial):
|
||||
cur = db.cursor()
|
||||
|
@ -142,20 +138,7 @@ class cert:
|
|||
return key_str, key_type, key_hash
|
||||
|
||||
|
||||
class status:
|
||||
class status(Enum):
|
||||
REVOKED = 1
|
||||
EXPIRED = 2
|
||||
VALID = 3
|
||||
|
||||
def __init__(self, status):
|
||||
self.value = status
|
||||
|
||||
def __str__(self):
|
||||
if self.value == self.EXPIRED:
|
||||
return "Expired"
|
||||
elif self.value == self.REVOKED:
|
||||
return "Revoked"
|
||||
elif self.value == self.VALID:
|
||||
return "Valid"
|
||||
else:
|
||||
return "Undefined"
|
||||
|
|
|
@ -5,6 +5,7 @@ import mariadb
|
|||
from cryptography import x509
|
||||
from cryptography.hazmat.primitives import hashes, serialization
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class list:
|
||||
|
@ -97,11 +98,11 @@ class cert:
|
|||
)
|
||||
|
||||
if self.revoked_at is not None and self.revoked_at < now_with_tz:
|
||||
self.status = status(status.REVOKED)
|
||||
self.status = status.REVOKED
|
||||
elif self.not_after < now_with_tz:
|
||||
self.status = status(status.EXPIRED)
|
||||
self.status = status.EXPIRED
|
||||
else:
|
||||
self.status = status(status.VALID)
|
||||
self.status = status.VALID
|
||||
|
||||
def get_cert(self, db, cert_serial):
|
||||
cur = db.cursor()
|
||||
|
@ -153,20 +154,7 @@ class cert:
|
|||
return sans
|
||||
|
||||
|
||||
class status:
|
||||
class status(Enum):
|
||||
REVOKED = 1
|
||||
EXPIRED = 2
|
||||
VALID = 3
|
||||
|
||||
def __init__(self, status):
|
||||
self.value = status
|
||||
|
||||
def __str__(self):
|
||||
if self.value == self.EXPIRED:
|
||||
return "Expired"
|
||||
elif self.value == self.REVOKED:
|
||||
return "Revoked"
|
||||
elif self.value == self.VALID:
|
||||
return "Valid"
|
||||
else:
|
||||
return "Undefined"
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue