Initial commit

This commit is contained in:
Benjamin Collet 2025-03-23 18:07:32 +01:00
commit 7c2b4989cc
Signed by: bcollet
SSH key fingerprint: SHA256:8UJspOIcCOS+MtSOcnuq2HjKFube4ox1s/+A62ixov4
7 changed files with 587 additions and 0 deletions

3
.gitignore vendored Normal file
View file

@ -0,0 +1,3 @@
.python-version
__pycache__/
config.yaml

11
Dockerfile Normal file
View file

@ -0,0 +1,11 @@
FROM python:3.9
WORKDIR /app
COPY ./requirements.txt /app/requirements.txt
RUN pip install --no-cache-dir --upgrade -r /app/requirements.txt
COPY ./step-ca-inspector /app/step-ca-inspector
CMD ["fastapi", "run", "step-ca-inspector/main.py", "--port", "8080", "--proxy-headers"]

9
requirements.txt Normal file
View file

@ -0,0 +1,9 @@
PyYAML
cryptography
mariadb
python-dateutil
uvicorn
prometheus-client
fastapi[standard]
fastapi_utils
typing_inspect

View file

@ -0,0 +1,26 @@
import os
import sys
import yaml
class config:
@classmethod
def __init__(self):
config_path = os.environ.get("STEP_CA_CERTAPI_CONFIGURATION")
if config_path is None:
print("No configuration file found")
sys.exit(1)
try:
with open(config_path) as ymlfile:
cfg = yaml.load(ymlfile, Loader=yaml.FullLoader)
except IOError:
print("Cannot read configuration file")
sys.exit(1)
for k, v in cfg.items():
setattr(self, k, v)
for setting in ["database"]:
if not hasattr(self, setting):
print(f"Mandatory setting {setting} is not configured.")
sys.exit(1)

213
step-ca-inspector/main.py Normal file
View file

@ -0,0 +1,213 @@
from fastapi import FastAPI, HTTPException
from fastapi_utils.tasks import repeat_every
from prometheus_client import make_asgi_app, Gauge
from models import x509_cert, ssh_cert
from config import config
from pydantic import BaseModel
from typing import List, Union
from datetime import datetime
import mariadb
import sys
config()
try:
db = mariadb.connect(**config.database)
except Exception as e:
print(f"Could not connect to database: {e}")
sys.exit(1)
app = FastAPI(title="step-ca Inspector API")
x509_label_names = ["subject", "san", "serial", "provisioner", "provisioner_type"]
x509_cert_not_before = Gauge(
"step_ca_x509_certificate_not_before_timestamp_seconds",
"Certificate not valid before timestamp",
x509_label_names,
)
x509_cert_not_after = Gauge(
"step_ca_x509_certificate_not_after_timestamp_seconds",
"Certificate not valid after timestamp",
x509_label_names,
)
x509_cert_revoked_at = Gauge(
"step_ca_x509_certificate_revoked_at_timestamp_seconds",
"Certificate not valid after timestamp",
x509_label_names,
)
x509_cert_status = Gauge(
"step_ca_x509_certificate_status",
"Certificate status",
x509_label_names,
)
ssh_label_names = ["key_id", "principals", "serial", "certificate_type"]
ssh_cert_not_before = Gauge(
"step_ca_ssh_certificate_not_before_timestamp_seconds",
"Certificate not valid before timestamp",
ssh_label_names,
)
ssh_cert_not_after = Gauge(
"step_ca_ssh_certificate_not_after_timestamp_seconds",
"Certificate not valid after timestamp",
ssh_label_names,
)
ssh_cert_revoked_at = Gauge(
"step_ca_ssh_certificate_revoked_at_timestamp_seconds",
"Certificate not valid after timestamp",
ssh_label_names,
)
ssh_cert_status = Gauge(
"step_ca_ssh_certificate_status",
"Certificate status",
ssh_label_names,
)
metrics_app = make_asgi_app()
app.mount("/metrics", metrics_app)
class provisioner(BaseModel):
id: str
name: str
type: str
class sanName(BaseModel):
type: str
value: str
class x509Cert(BaseModel):
serial: str
subject: str
san_names: List[sanName] = []
provisioner: provisioner
not_after: int
not_before: int
revoked_at: Union[int, None] = None
status: str
sha256: str
sha1: str
md5: str
pub_key: str
pub_alg: str
sig_alg: str
issuer: str
pem: str
class sshCert(BaseModel):
serial: str
alg: str
type: str
key_id: str
principals: List[str] = []
not_after: int
not_before: int
revoked_at: Union[int, None] = None
status: str
signing_key: str
signing_key_type: str
signing_key_hash: str
public_key: str
public_key_type: str
public_key_hash: str
public_identity: str
extensions: dict = {}
@app.on_event("startup")
@repeat_every(seconds=15, raise_exceptions=False)
async def update_metrics():
x509_certs = x509_cert.list(db=db)
for cert in x509_certs:
labels = {
"subject": cert.subject,
"san": ",".join(f"{x['type']}:{x['value']}" for x in cert.san_names),
"serial": cert.serial,
"provisioner": cert.provisioner["name"],
"provisioner_type": cert.provisioner["type"],
}
x509_cert_not_after.labels(**labels).set(cert.not_after)
x509_cert_not_before.labels(**labels).set(cert.not_before)
if cert.revoked_at is not None:
x509_cert_revoked_at.labels(**labels).set(cert.revoked_at)
x509_cert_status.labels(**labels).set(cert.status.value)
ssh_certs = ssh_cert.list(db=db)
for cert in ssh_certs:
labels = {
"principals": ",".join([x.decode() for x in cert.principals]),
"serial": cert.serial,
"key_id": cert.key_id.decode(),
"certificate_type": cert.type,
}
ssh_cert_not_after.labels(**labels).set(cert.not_after)
ssh_cert_not_before.labels(**labels).set(cert.not_before)
if cert.revoked_at is not None:
ssh_cert_revoked_at.labels(**labels).set(cert.revoked_at)
ssh_cert_status.labels(**labels).set(cert.status.value)
@app.get("/x509/certs", tags=["x509"])
def list_x509_certs(
sort_key: str = "not_after", revoked: bool = False, expired: bool = False
) -> list[x509Cert]:
certs = x509_cert.list(db, sort_key=sort_key)
cert_list = []
for cert in certs:
if cert.status.value == x509_cert.status.EXPIRED and not expired:
continue
if cert.status.value == x509_cert.status.REVOKED and not revoked:
continue
cert.status = str(cert.status)
cert_list.append(cert)
return cert_list
@app.get("/x509/certs/{serial}", tags=["x509"])
def get_x509_cert(serial: str) -> Union[x509Cert, None]:
cert = x509_cert.cert.from_serial(db, serial)
if cert is None:
return None
cert.status = str(cert.status)
return cert
@app.get("/ssh/certs", tags=["ssh"])
def list_ssh_certs(
sort_key: str = "not_after", revoked: bool = False, expired: bool = False
) -> list[sshCert]:
certs = ssh_cert.list(db, sort_key=sort_key)
cert_list = []
for cert in certs:
if cert.status.value == ssh_cert.status.EXPIRED and not expired:
continue
if cert.status.value == ssh_cert.status.REVOKED and not revoked:
continue
cert.status = str(cert.status)
cert_list.append(cert)
return cert_list
@app.get("/ssh/certs/{serial}", tags=["ssh"])
def get_ssh_cert(serial: str) -> Union[sshCert, None]:
cert = ssh_cert.cert.from_serial(db, serial)
if cert is None:
return None
cert.status = str(cert.status)
return cert

View file

@ -0,0 +1,155 @@
import base64
import dateutil
import json
import mariadb
from cryptography.hazmat.primitives import asymmetric, hashes, serialization
from datetime import datetime, timedelta, timezone
from struct import unpack
class list:
def __new__(cls, db, sort_key=None):
cls.certs = []
cur = db.cursor()
cur.execute(
"""SELECT ssh_certs.nvalue AS cert,
revoked_ssh_certs.nvalue AS revoked
FROM ssh_certs
LEFT JOIN revoked_ssh_certs USING(nkey)"""
)
for result in cur:
cert_object = cert(result)
cls.certs.append(cert_object)
cur.close()
if sort_key is not None:
cls.certs.sort(key=lambda item: getattr(item, sort_key))
return cls.certs
class cert:
def __init__(self, cert):
(cert_raw, cert_revoked_raw) = cert
size = unpack(">I", cert_raw[:4])[0] + 4
alg = cert_raw[4:size]
cert_pub_id = b" ".join([alg, base64.b64encode(cert_raw)])
if cert_revoked_raw is not None:
cert_revoked = json.loads(cert_revoked_raw)
else:
cert_revoked = None
self.load(cert_pub_id, cert_revoked, alg)
@classmethod
def from_serial(cls, db, serial):
cert = cls.get_cert(cls, db, serial)
if cert is None:
return None
return cls(cert=cert)
def load(self, cert_pub_id, cert_revoked, cert_alg):
cert = serialization.load_ssh_public_identity(cert_pub_id)
self.serial = str(cert.serial)
self.alg = cert_alg
if cert.type == serialization.SSHCertificateType.USER:
self.type = "User"
self.key_id = cert.key_id
self.principals = cert.valid_principals
self.not_after = cert.valid_before
self.not_before = cert.valid_after
# TODO: Implement critical options parsing
# cert.critical_options
self.extensions = cert.extensions
(self.signing_key, self.signing_key_type, self.signing_key_hash) = (
self.get_public_key_params(cert.signature_key())
)
(self.public_key, self.public_key_type, self.public_key_hash) = (
self.get_public_key_params(cert.public_key())
)
self.public_identity = cert.public_bytes()
if cert_revoked is not None:
self.revoked_at = datetime.timestamp(
dateutil.parser.isoparse(cert_revoked.get("RevokedAt")).replace(
microsecond=0
)
)
else:
self.revoked_at = None
now_with_tz = datetime.timestamp(
datetime.now(timezone.utc).replace(microsecond=0)
)
if self.revoked_at is not None and self.revoked_at < now_with_tz:
self.status = status(status.REVOKED)
elif self.not_after < now_with_tz:
self.status = status(status.EXPIRED)
else:
self.status = status(status.VALID)
def get_cert(self, db, cert_serial):
cur = db.cursor()
cur.execute(
"""SELECT ssh_certs.nvalue AS cert,
revoked_ssh_certs.nvalue AS revoked
FROM ssh_certs
LEFT JOIN revoked_ssh_certs USING(nkey)
WHERE nkey=?""",
(cert_serial,),
)
if cur.rowcount > 0:
cert = cur.fetchone()
else:
cert = None
cur.close()
return cert
def get_public_key_params(self, public_key):
if isinstance(public_key, asymmetric.ec.EllipticCurvePublicKey):
key_type = "ECDSA"
elif isinstance(public_key, asymmetric.ed25519.Ed25519PublicKey):
key_type = "ED25519"
elif isinstance(public_key, asymmetric.rsa.RSAPublicKey):
key_type = "RSA"
key_str = public_key.public_bytes(
serialization.Encoding.OpenSSH, serialization.PublicFormat.OpenSSH
)
key_data = key_str.strip().split()[1]
digest = hashes.Hash(hashes.SHA256())
digest.update(base64.b64decode(key_data))
hash_sha256 = digest.finalize()
key_hash = base64.b64encode(hash_sha256)
return key_str, key_type, key_hash
class status:
REVOKED = 1
EXPIRED = 2
VALID = 3
def __init__(self, status):
self.value = status
def __str__(self):
if self.value == self.EXPIRED:
return "Expired"
elif self.value == self.REVOKED:
return "Revoked"
elif self.value == self.VALID:
return "Valid"
else:
return "Undefined"

View file

@ -0,0 +1,170 @@
import binascii
import dateutil
import json
import mariadb
from cryptography import x509
from cryptography.hazmat.primitives import hashes, serialization
from datetime import datetime, timedelta, timezone
class list:
def __new__(cls, db, sort_key=None):
cls.certs = []
cur = db.cursor()
cur.execute(
"""SELECT x509_certs.nvalue AS cert,
x509_certs_data.nvalue AS data,
revoked_x509_certs.nvalue AS revoked
FROM x509_certs
INNER JOIN x509_certs_data USING(nkey)
LEFT JOIN revoked_x509_certs USING(nkey)"""
)
for result in cur:
cert_object = cert(result)
cls.certs.append(cert_object)
cur.close()
if sort_key is not None:
cls.certs.sort(key=lambda item: getattr(item, sort_key))
return cls.certs
class cert:
def __init__(self, cert):
(cert_der, cert_data_raw, cert_revoked_raw) = cert
cert_data = json.loads(cert_data_raw)
if cert_revoked_raw is not None:
cert_revoked = json.loads(cert_revoked_raw)
else:
cert_revoked = None
self.load(cert_der, cert_data, cert_revoked)
@classmethod
def from_serial(cls, db, serial):
cert = cls.get_cert(cls, db, serial)
if cert is None:
return None
return cls(cert=cert)
def load(self, cert_der, cert_data, cert_revoked):
cert = x509.load_der_x509_certificate(cert_der)
self.pem = cert.public_bytes(serialization.Encoding.PEM)
self.serial = str(cert.serial_number)
self.sha256 = binascii.b2a_hex(cert.fingerprint(hashes.SHA256()))
self.sha1 = binascii.b2a_hex(cert.fingerprint(hashes.SHA1()))
self.md5 = binascii.b2a_hex(cert.fingerprint(hashes.MD5()))
self.pub_key = cert.public_key().public_bytes(
serialization.Encoding.PEM, serialization.PublicFormat.SubjectPublicKeyInfo
)
self.pub_alg = cert.public_key_algorithm_oid._name
self.sig_alg = cert.signature_algorithm_oid._name
self.issuer = cert.issuer.rfc4514_string()
self.subject = cert.subject.rfc4514_string({x509.NameOID.EMAIL_ADDRESS: "E"})
self.not_before = datetime.timestamp(
cert.not_valid_before_utc.replace(microsecond=0)
)
self.not_after = datetime.timestamp(
cert.not_valid_after_utc.replace(microsecond=0)
)
try:
san_data = cert.extensions.get_extension_for_class(
x509.SubjectAlternativeName
)
self.san_names = self.get_sans(san_data)
except x509.extensions.ExtensionNotFound:
self.san_names = []
self.provisioner = cert_data.get("provisioner", None)
if cert_revoked is not None:
self.revoked_at = datetime.timestamp(
dateutil.parser.isoparse(cert_revoked.get("RevokedAt")).replace(
microsecond=0
)
)
else:
self.revoked_at = None
now_with_tz = datetime.timestamp(
datetime.now(timezone.utc).replace(microsecond=0)
)
if self.revoked_at is not None and self.revoked_at < now_with_tz:
self.status = status(status.REVOKED)
elif self.not_after < now_with_tz:
self.status = status(status.EXPIRED)
else:
self.status = status(status.VALID)
def get_cert(self, db, cert_serial):
cur = db.cursor()
cur.execute(
"""SELECT x509_certs.nvalue AS cert,
x509_certs_data.nvalue AS data,
revoked_x509_certs.nvalue AS revoked
FROM x509_certs
INNER JOIN x509_certs_data USING(nkey)
LEFT JOIN revoked_x509_certs USING(nkey)
WHERE nkey=?""",
(cert_serial,),
)
if cur.rowcount > 0:
cert = cur.fetchone()
else:
cert = None
cur.close()
return cert
def get_sans(self, san_data):
sans = []
for san_value in san_data.value:
san = {}
if isinstance(san_value, x509.general_name.DNSName):
san["type"] = "DNS"
elif isinstance(san_value, x509.general_name.UniformResourceIdentifier):
san["type"] = "URI"
elif isinstance(san_value, x509.general_name.RFC822Name):
san["type"] = "Email"
elif isinstance(san_value, x509.general_name.IPAddress):
san["type"] = "IP"
elif isinstance(san_value, x509.general_name.DirectoryName):
san["type"] = "DirectoryName"
elif isinstance(san_value, x509.general_name.RegisteredID):
san["type"] = "RegisteredID"
elif isinstance(san_value, x509.general_name.OtherName):
san["type"] = "Other ({san_value.type_id})"
else:
continue
san["value"] = san_value.value
sans.append(san)
return sans
class status:
REVOKED = 1
EXPIRED = 2
VALID = 3
def __init__(self, status):
self.value = status
def __str__(self):
if self.value == self.EXPIRED:
return "Expired"
elif self.value == self.REVOKED:
return "Revoked"
elif self.value == self.VALID:
return "Valid"
else:
return "Undefined"