Compare commits

...

3 commits
v0.0.4 ... main

3 changed files with 109 additions and 60 deletions
step-ca-inspector

View file

@ -1,4 +1,4 @@
from fastapi import FastAPI, HTTPException
from fastapi import FastAPI, HTTPException, Query
from fastapi_utils.tasks import repeat_every
from prometheus_client import make_asgi_app, Gauge
from models import x509_cert, ssh_cert
@ -6,6 +6,7 @@ from config import config
from pydantic import BaseModel
from typing import List, Union
from datetime import datetime
from enum import Enum
import mariadb
import sys
@ -68,10 +69,30 @@ metrics_app = make_asgi_app()
app.mount("/metrics", metrics_app)
class certStatus(str, Enum):
REVOKED = "Revoked"
EXPIRED = "Expired"
VALID = "Valid"
class provisionerType(str, Enum):
# https://github.com/smallstep/certificates/blob/938a4da5adf2d32f36ffd06922e5c66956dfff41/authority/provisioner/provisioner.go#L200-L223
ACME = "ACME"
AWS = "AWS"
GCP = "GCP"
JWK = "JWK"
Nebula = "Nebula"
OIDC = "OIDC"
SCEP = "SCEP"
SSHPOP = "SSHPOP"
X5C = "X5C"
K8sSA = "K8sSA"
class provisioner(BaseModel):
id: str
name: str
type: str
type: provisionerType
class sanName(BaseModel):
@ -87,7 +108,7 @@ class x509Cert(BaseModel):
not_after: int
not_before: int
revoked_at: Union[int, None] = None
status: str
status: certStatus
sha256: str
sha1: str
md5: str
@ -98,16 +119,21 @@ class x509Cert(BaseModel):
pem: str
class sshCertType(str, Enum):
HOST = "Host"
USER = "User"
class sshCert(BaseModel):
serial: str
alg: str
type: str
type: sshCertType
key_id: str
principals: List[str] = []
not_after: int
not_before: int
revoked_at: Union[int, None] = None
status: str
status: certStatus
signing_key: str
signing_key_type: str
signing_key_hash: str
@ -145,7 +171,7 @@ async def update_metrics():
"principals": ",".join([x.decode() for x in cert.principals]),
"serial": cert.serial,
"key_id": cert.key_id.decode(),
"certificate_type": cert.type,
"certificate_type": getattr(sshCertType, cert.type.name).value,
}
ssh_cert_not_after.labels(**labels).set(cert.not_after)
@ -159,18 +185,45 @@ async def update_metrics():
@app.get("/x509/certs", tags=["x509"])
def list_x509_certs(
sort_key: str = "not_after", revoked: bool = False, expired: bool = False
sort_key: str = Query(enum=["not_after", "not_before"], default="not_after"),
revoked: bool = Query(False, deprecated=True),
expired: bool = Query(False, deprecated=True),
cert_status: list[certStatus] = Query(["Valid"]),
subject: str = None,
san: str = None,
provisioner: str = None,
provisioner_type: list[provisionerType] = Query(list(provisionerType)),
) -> list[x509Cert]:
certs = x509_cert.list(db, sort_key=sort_key)
cert_list = []
for cert in certs:
if cert.status.value == x509_cert.status.EXPIRED and not expired:
continue
if cert.status.value == x509_cert.status.REVOKED and not revoked:
continue
if cert.status.name not in [item.name for item in cert_status]:
# TODO: Remove handling of deprecated parameters
if not expired and not revoked:
continue
if cert.status == x509_cert.status.EXPIRED and not expired:
continue
if cert.status == x509_cert.status.REVOKED and not revoked:
continue
cert.status = str(cert.status)
if (
provisioner is not None
and provisioner.casefold() not in cert.provisioner["name"].casefold()
):
continue
if cert.provisioner["type"] not in [item.name for item in provisioner_type]:
continue
if subject is not None and subject.casefold() not in cert.subject.casefold():
continue
if san is not None:
for cert_san_name in cert.san_names:
if san.casefold() in cert_san_name["value"].casefold():
break
else:
continue
cert.status = getattr(certStatus, cert.status.name)
cert_list.append(cert)
return cert_list
@ -181,24 +234,48 @@ def get_x509_cert(serial: str) -> Union[x509Cert, None]:
cert = x509_cert.cert.from_serial(db, serial)
if cert is None:
return None
cert.status = str(cert.status)
cert.status = getattr(certStatus, cert.status.name)
return cert
@app.get("/ssh/certs", tags=["ssh"])
def list_ssh_certs(
sort_key: str = "not_after", revoked: bool = False, expired: bool = False
sort_key: str = Query(enum=["not_after", "not_before"], default="not_after"),
revoked: bool = Query(False, deprecated=True),
expired: bool = Query(False, deprecated=True),
cert_type: list[sshCertType] = Query(["Host", "User"]),
cert_status: list[certStatus] = Query(["Valid"]),
key: str = None,
principal: str = None,
) -> list[sshCert]:
certs = ssh_cert.list(db, sort_key=sort_key)
cert_list = []
for cert in certs:
if cert.status.value == ssh_cert.status.EXPIRED and not expired:
continue
if cert.status.value == ssh_cert.status.REVOKED and not revoked:
continue
if cert.status.name not in [item.name for item in cert_status]:
# TODO: Remove handling of deprecated parameters
if not expired and not revoked:
continue
cert.status = str(cert.status)
if cert.status == ssh_cert.status.EXPIRED and not expired:
continue
if cert.status == ssh_cert.status.REVOKED and not revoked:
continue
if cert.type.name not in [item.name for item in cert_type]:
continue
if key is not None and key.casefold() not in str(cert.key_id).casefold():
continue
if principal is not None:
for cert_principal in cert.principals:
if principal.casefold() in str(cert_principal).casefold():
break
else:
continue
cert.type = getattr(sshCertType, cert.type.name)
cert.status = getattr(certStatus, cert.status.name)
cert_list.append(cert)
return cert_list
@ -209,5 +286,6 @@ def get_ssh_cert(serial: str) -> Union[sshCert, None]:
cert = ssh_cert.cert.from_serial(db, serial)
if cert is None:
return None
cert.status = str(cert.status)
cert.type = getattr(sshCertType, cert.type.name)
cert.status = getattr(certStatus, cert.status.name)
return cert

View file

@ -5,6 +5,7 @@ import mariadb
from cryptography.hazmat.primitives import asymmetric, hashes, serialization
from datetime import datetime, timedelta, timezone
from struct import unpack
from enum import Enum
class list:
@ -58,12 +59,7 @@ class cert:
cert = serialization.load_ssh_public_identity(cert_pub_id)
self.serial = str(cert.serial)
self.alg = cert_alg
if cert.type == serialization.SSHCertificateType.USER:
self.type = "User"
elif cert.type == serialization.SSHCertificateType.HOST:
self.type = "Host"
else:
self.type = "Unknown"
self.type = cert.type
self.key_id = cert.key_id
self.principals = cert.valid_principals
self.not_after = cert.valid_before
@ -96,11 +92,11 @@ class cert:
)
if self.revoked_at is not None and self.revoked_at < now_with_tz:
self.status = status(status.REVOKED)
self.status = status.REVOKED
elif self.not_after < now_with_tz:
self.status = status(status.EXPIRED)
self.status = status.EXPIRED
else:
self.status = status(status.VALID)
self.status = status.VALID
def get_cert(self, db, cert_serial):
cur = db.cursor()
@ -142,20 +138,7 @@ class cert:
return key_str, key_type, key_hash
class status:
class status(Enum):
REVOKED = 1
EXPIRED = 2
VALID = 3
def __init__(self, status):
self.value = status
def __str__(self):
if self.value == self.EXPIRED:
return "Expired"
elif self.value == self.REVOKED:
return "Revoked"
elif self.value == self.VALID:
return "Valid"
else:
return "Undefined"

View file

@ -5,6 +5,7 @@ import mariadb
from cryptography import x509
from cryptography.hazmat.primitives import hashes, serialization
from datetime import datetime, timedelta, timezone
from enum import Enum
class list:
@ -97,11 +98,11 @@ class cert:
)
if self.revoked_at is not None and self.revoked_at < now_with_tz:
self.status = status(status.REVOKED)
self.status = status.REVOKED
elif self.not_after < now_with_tz:
self.status = status(status.EXPIRED)
self.status = status.EXPIRED
else:
self.status = status(status.VALID)
self.status = status.VALID
def get_cert(self, db, cert_serial):
cur = db.cursor()
@ -153,20 +154,7 @@ class cert:
return sans
class status:
class status(Enum):
REVOKED = 1
EXPIRED = 2
VALID = 3
def __init__(self, status):
self.value = status
def __str__(self):
if self.value == self.EXPIRED:
return "Expired"
elif self.value == self.REVOKED:
return "Revoked"
elif self.value == self.VALID:
return "Valid"
else:
return "Undefined"