Merge pull request 'Switch to enums, add search for x509 certs and improve ssh cert search' (#1) from rework into main
Reviewed-on: #1
This commit is contained in:
commit
c5132eb828
3 changed files with 104 additions and 69 deletions
step-ca-inspector
|
@ -69,10 +69,30 @@ metrics_app = make_asgi_app()
|
||||||
app.mount("/metrics", metrics_app)
|
app.mount("/metrics", metrics_app)
|
||||||
|
|
||||||
|
|
||||||
|
class certStatus(str, Enum):
|
||||||
|
REVOKED = "Revoked"
|
||||||
|
EXPIRED = "Expired"
|
||||||
|
VALID = "Valid"
|
||||||
|
|
||||||
|
|
||||||
|
class provisionerType(str, Enum):
|
||||||
|
# https://github.com/smallstep/certificates/blob/938a4da5adf2d32f36ffd06922e5c66956dfff41/authority/provisioner/provisioner.go#L200-L223
|
||||||
|
ACME = "ACME"
|
||||||
|
AWS = "AWS"
|
||||||
|
GCP = "GCP"
|
||||||
|
JWK = "JWK"
|
||||||
|
Nebula = "Nebula"
|
||||||
|
OIDC = "OIDC"
|
||||||
|
SCEP = "SCEP"
|
||||||
|
SSHPOP = "SSHPOP"
|
||||||
|
X5C = "X5C"
|
||||||
|
K8sSA = "K8sSA"
|
||||||
|
|
||||||
|
|
||||||
class provisioner(BaseModel):
|
class provisioner(BaseModel):
|
||||||
id: str
|
id: str
|
||||||
name: str
|
name: str
|
||||||
type: str
|
type: provisionerType
|
||||||
|
|
||||||
|
|
||||||
class sanName(BaseModel):
|
class sanName(BaseModel):
|
||||||
|
@ -88,7 +108,7 @@ class x509Cert(BaseModel):
|
||||||
not_after: int
|
not_after: int
|
||||||
not_before: int
|
not_before: int
|
||||||
revoked_at: Union[int, None] = None
|
revoked_at: Union[int, None] = None
|
||||||
status: str
|
status: certStatus
|
||||||
sha256: str
|
sha256: str
|
||||||
sha1: str
|
sha1: str
|
||||||
md5: str
|
md5: str
|
||||||
|
@ -99,16 +119,21 @@ class x509Cert(BaseModel):
|
||||||
pem: str
|
pem: str
|
||||||
|
|
||||||
|
|
||||||
|
class sshCertType(str, Enum):
|
||||||
|
HOST = "Host"
|
||||||
|
USER = "User"
|
||||||
|
|
||||||
|
|
||||||
class sshCert(BaseModel):
|
class sshCert(BaseModel):
|
||||||
serial: str
|
serial: str
|
||||||
alg: str
|
alg: str
|
||||||
type: str
|
type: sshCertType
|
||||||
key_id: str
|
key_id: str
|
||||||
principals: List[str] = []
|
principals: List[str] = []
|
||||||
not_after: int
|
not_after: int
|
||||||
not_before: int
|
not_before: int
|
||||||
revoked_at: Union[int, None] = None
|
revoked_at: Union[int, None] = None
|
||||||
status: str
|
status: certStatus
|
||||||
signing_key: str
|
signing_key: str
|
||||||
signing_key_type: str
|
signing_key_type: str
|
||||||
signing_key_hash: str
|
signing_key_hash: str
|
||||||
|
@ -146,7 +171,7 @@ async def update_metrics():
|
||||||
"principals": ",".join([x.decode() for x in cert.principals]),
|
"principals": ",".join([x.decode() for x in cert.principals]),
|
||||||
"serial": cert.serial,
|
"serial": cert.serial,
|
||||||
"key_id": cert.key_id.decode(),
|
"key_id": cert.key_id.decode(),
|
||||||
"certificate_type": cert.type,
|
"certificate_type": getattr(sshCertType, cert.type.name).value,
|
||||||
}
|
}
|
||||||
|
|
||||||
ssh_cert_not_after.labels(**labels).set(cert.not_after)
|
ssh_cert_not_after.labels(**labels).set(cert.not_after)
|
||||||
|
@ -160,18 +185,45 @@ async def update_metrics():
|
||||||
|
|
||||||
@app.get("/x509/certs", tags=["x509"])
|
@app.get("/x509/certs", tags=["x509"])
|
||||||
def list_x509_certs(
|
def list_x509_certs(
|
||||||
sort_key: str = "not_after", revoked: bool = False, expired: bool = False
|
sort_key: str = Query(enum=["not_after", "not_before"], default="not_after"),
|
||||||
|
revoked: bool = Query(False, deprecated=True),
|
||||||
|
expired: bool = Query(False, deprecated=True),
|
||||||
|
cert_status: list[certStatus] = Query(["Valid"]),
|
||||||
|
subject: str = None,
|
||||||
|
san: str = None,
|
||||||
|
provisioner: str = None,
|
||||||
|
provisioner_type: list[provisionerType] = Query(list(provisionerType)),
|
||||||
) -> list[x509Cert]:
|
) -> list[x509Cert]:
|
||||||
certs = x509_cert.list(db, sort_key=sort_key)
|
certs = x509_cert.list(db, sort_key=sort_key)
|
||||||
cert_list = []
|
cert_list = []
|
||||||
|
|
||||||
for cert in certs:
|
for cert in certs:
|
||||||
if cert.status.value == x509_cert.status.EXPIRED and not expired:
|
if cert.status.name not in [item.name for item in cert_status]:
|
||||||
continue
|
# TODO: Remove handling of deprecated parameters
|
||||||
if cert.status.value == x509_cert.status.REVOKED and not revoked:
|
if not expired and not revoked:
|
||||||
continue
|
continue
|
||||||
|
if cert.status == x509_cert.status.EXPIRED and not expired:
|
||||||
|
continue
|
||||||
|
if cert.status == x509_cert.status.REVOKED and not revoked:
|
||||||
|
continue
|
||||||
|
|
||||||
cert.status = str(cert.status)
|
if (
|
||||||
|
provisioner is not None
|
||||||
|
and provisioner.casefold() not in cert.provisioner["name"].casefold()
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
if cert.provisioner["type"] not in [item.name for item in provisioner_type]:
|
||||||
|
continue
|
||||||
|
if subject is not None and subject.casefold() not in cert.subject.casefold():
|
||||||
|
continue
|
||||||
|
if san is not None:
|
||||||
|
for cert_san_name in cert.san_names:
|
||||||
|
if san.casefold() in cert_san_name["value"].casefold():
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
|
||||||
|
cert.status = getattr(certStatus, cert.status.name)
|
||||||
cert_list.append(cert)
|
cert_list.append(cert)
|
||||||
|
|
||||||
return cert_list
|
return cert_list
|
||||||
|
@ -182,18 +234,17 @@ def get_x509_cert(serial: str) -> Union[x509Cert, None]:
|
||||||
cert = x509_cert.cert.from_serial(db, serial)
|
cert = x509_cert.cert.from_serial(db, serial)
|
||||||
if cert is None:
|
if cert is None:
|
||||||
return None
|
return None
|
||||||
cert.status = str(cert.status)
|
cert.status = getattr(certStatus, cert.status.name)
|
||||||
return cert
|
return cert
|
||||||
|
|
||||||
|
|
||||||
@app.get("/ssh/certs", tags=["ssh"])
|
@app.get("/ssh/certs", tags=["ssh"])
|
||||||
def list_ssh_certs(
|
def list_ssh_certs(
|
||||||
sort_key: str = Query(enum=["not_after", "not_before"], default="not_after"),
|
sort_key: str = Query(enum=["not_after", "not_before"], default="not_after"),
|
||||||
revoked: bool = False,
|
revoked: bool = Query(False, deprecated=True),
|
||||||
expired: bool = False,
|
expired: bool = Query(False, deprecated=True),
|
||||||
cert_type: list[Enum("Types", [("Host", "Host"), ("User", "User")])] = Query(
|
cert_type: list[sshCertType] = Query(["Host", "User"]),
|
||||||
["Host", "User"]
|
cert_status: list[certStatus] = Query(["Valid"]),
|
||||||
),
|
|
||||||
key: str = None,
|
key: str = None,
|
||||||
principal: str = None,
|
principal: str = None,
|
||||||
) -> list[sshCert]:
|
) -> list[sshCert]:
|
||||||
|
@ -201,18 +252,30 @@ def list_ssh_certs(
|
||||||
cert_list = []
|
cert_list = []
|
||||||
|
|
||||||
for cert in certs:
|
for cert in certs:
|
||||||
if cert.status.value == ssh_cert.status.EXPIRED and not expired:
|
if cert.status.name not in [item.name for item in cert_status]:
|
||||||
continue
|
# TODO: Remove handling of deprecated parameters
|
||||||
if cert.status.value == ssh_cert.status.REVOKED and not revoked:
|
if not expired and not revoked:
|
||||||
continue
|
continue
|
||||||
if cert.type not in [item.value for item in cert_type]:
|
|
||||||
continue
|
|
||||||
if key is not None and key not in str(cert.key_id):
|
|
||||||
continue
|
|
||||||
if principal is not None and principal not in str(cert.principals):
|
|
||||||
continue
|
|
||||||
|
|
||||||
cert.status = str(cert.status)
|
if cert.status == ssh_cert.status.EXPIRED and not expired:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if cert.status == ssh_cert.status.REVOKED and not revoked:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if cert.type.name not in [item.name for item in cert_type]:
|
||||||
|
continue
|
||||||
|
if key is not None and key.casefold() not in str(cert.key_id).casefold():
|
||||||
|
continue
|
||||||
|
if principal is not None:
|
||||||
|
for cert_principal in cert.principals:
|
||||||
|
if principal.casefold() in str(cert_principal).casefold():
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
|
||||||
|
cert.type = getattr(sshCertType, cert.type.name)
|
||||||
|
cert.status = getattr(certStatus, cert.status.name)
|
||||||
cert_list.append(cert)
|
cert_list.append(cert)
|
||||||
|
|
||||||
return cert_list
|
return cert_list
|
||||||
|
@ -223,5 +286,6 @@ def get_ssh_cert(serial: str) -> Union[sshCert, None]:
|
||||||
cert = ssh_cert.cert.from_serial(db, serial)
|
cert = ssh_cert.cert.from_serial(db, serial)
|
||||||
if cert is None:
|
if cert is None:
|
||||||
return None
|
return None
|
||||||
cert.status = str(cert.status)
|
cert.type = getattr(sshCertType, cert.type.name)
|
||||||
|
cert.status = getattr(certStatus, cert.status.name)
|
||||||
return cert
|
return cert
|
||||||
|
|
|
@ -5,6 +5,7 @@ import mariadb
|
||||||
from cryptography.hazmat.primitives import asymmetric, hashes, serialization
|
from cryptography.hazmat.primitives import asymmetric, hashes, serialization
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from struct import unpack
|
from struct import unpack
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
class list:
|
class list:
|
||||||
|
@ -58,12 +59,7 @@ class cert:
|
||||||
cert = serialization.load_ssh_public_identity(cert_pub_id)
|
cert = serialization.load_ssh_public_identity(cert_pub_id)
|
||||||
self.serial = str(cert.serial)
|
self.serial = str(cert.serial)
|
||||||
self.alg = cert_alg
|
self.alg = cert_alg
|
||||||
if cert.type == serialization.SSHCertificateType.USER:
|
self.type = cert.type
|
||||||
self.type = "User"
|
|
||||||
elif cert.type == serialization.SSHCertificateType.HOST:
|
|
||||||
self.type = "Host"
|
|
||||||
else:
|
|
||||||
self.type = "Unknown"
|
|
||||||
self.key_id = cert.key_id
|
self.key_id = cert.key_id
|
||||||
self.principals = cert.valid_principals
|
self.principals = cert.valid_principals
|
||||||
self.not_after = cert.valid_before
|
self.not_after = cert.valid_before
|
||||||
|
@ -96,11 +92,11 @@ class cert:
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.revoked_at is not None and self.revoked_at < now_with_tz:
|
if self.revoked_at is not None and self.revoked_at < now_with_tz:
|
||||||
self.status = status(status.REVOKED)
|
self.status = status.REVOKED
|
||||||
elif self.not_after < now_with_tz:
|
elif self.not_after < now_with_tz:
|
||||||
self.status = status(status.EXPIRED)
|
self.status = status.EXPIRED
|
||||||
else:
|
else:
|
||||||
self.status = status(status.VALID)
|
self.status = status.VALID
|
||||||
|
|
||||||
def get_cert(self, db, cert_serial):
|
def get_cert(self, db, cert_serial):
|
||||||
cur = db.cursor()
|
cur = db.cursor()
|
||||||
|
@ -142,20 +138,7 @@ class cert:
|
||||||
return key_str, key_type, key_hash
|
return key_str, key_type, key_hash
|
||||||
|
|
||||||
|
|
||||||
class status:
|
class status(Enum):
|
||||||
REVOKED = 1
|
REVOKED = 1
|
||||||
EXPIRED = 2
|
EXPIRED = 2
|
||||||
VALID = 3
|
VALID = 3
|
||||||
|
|
||||||
def __init__(self, status):
|
|
||||||
self.value = status
|
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
if self.value == self.EXPIRED:
|
|
||||||
return "Expired"
|
|
||||||
elif self.value == self.REVOKED:
|
|
||||||
return "Revoked"
|
|
||||||
elif self.value == self.VALID:
|
|
||||||
return "Valid"
|
|
||||||
else:
|
|
||||||
return "Undefined"
|
|
||||||
|
|
|
@ -5,6 +5,7 @@ import mariadb
|
||||||
from cryptography import x509
|
from cryptography import x509
|
||||||
from cryptography.hazmat.primitives import hashes, serialization
|
from cryptography.hazmat.primitives import hashes, serialization
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
class list:
|
class list:
|
||||||
|
@ -97,11 +98,11 @@ class cert:
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.revoked_at is not None and self.revoked_at < now_with_tz:
|
if self.revoked_at is not None and self.revoked_at < now_with_tz:
|
||||||
self.status = status(status.REVOKED)
|
self.status = status.REVOKED
|
||||||
elif self.not_after < now_with_tz:
|
elif self.not_after < now_with_tz:
|
||||||
self.status = status(status.EXPIRED)
|
self.status = status.EXPIRED
|
||||||
else:
|
else:
|
||||||
self.status = status(status.VALID)
|
self.status = status.VALID
|
||||||
|
|
||||||
def get_cert(self, db, cert_serial):
|
def get_cert(self, db, cert_serial):
|
||||||
cur = db.cursor()
|
cur = db.cursor()
|
||||||
|
@ -153,20 +154,7 @@ class cert:
|
||||||
return sans
|
return sans
|
||||||
|
|
||||||
|
|
||||||
class status:
|
class status(Enum):
|
||||||
REVOKED = 1
|
REVOKED = 1
|
||||||
EXPIRED = 2
|
EXPIRED = 2
|
||||||
VALID = 3
|
VALID = 3
|
||||||
|
|
||||||
def __init__(self, status):
|
|
||||||
self.value = status
|
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
if self.value == self.EXPIRED:
|
|
||||||
return "Expired"
|
|
||||||
elif self.value == self.REVOKED:
|
|
||||||
return "Revoked"
|
|
||||||
elif self.value == self.VALID:
|
|
||||||
return "Valid"
|
|
||||||
else:
|
|
||||||
return "Undefined"
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue