Compare commits
5 commits
Author | SHA1 | Date | |
---|---|---|---|
c5132eb828 | |||
204f85fb8e | |||
0cb5337e32 | |||
fdb4926260 | |||
8dcd79d427 |
4 changed files with 110 additions and 57 deletions
|
@ -1,4 +1,4 @@
|
||||||
FROM python:3.9
|
FROM --platform=linux/amd64 python:3.12
|
||||||
|
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
from fastapi import FastAPI, HTTPException
|
from fastapi import FastAPI, HTTPException, Query
|
||||||
from fastapi_utils.tasks import repeat_every
|
from fastapi_utils.tasks import repeat_every
|
||||||
from prometheus_client import make_asgi_app, Gauge
|
from prometheus_client import make_asgi_app, Gauge
|
||||||
from models import x509_cert, ssh_cert
|
from models import x509_cert, ssh_cert
|
||||||
|
@ -6,6 +6,7 @@ from config import config
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing import List, Union
|
from typing import List, Union
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from enum import Enum
|
||||||
import mariadb
|
import mariadb
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
@ -68,10 +69,30 @@ metrics_app = make_asgi_app()
|
||||||
app.mount("/metrics", metrics_app)
|
app.mount("/metrics", metrics_app)
|
||||||
|
|
||||||
|
|
||||||
|
class certStatus(str, Enum):
|
||||||
|
REVOKED = "Revoked"
|
||||||
|
EXPIRED = "Expired"
|
||||||
|
VALID = "Valid"
|
||||||
|
|
||||||
|
|
||||||
|
class provisionerType(str, Enum):
|
||||||
|
# https://github.com/smallstep/certificates/blob/938a4da5adf2d32f36ffd06922e5c66956dfff41/authority/provisioner/provisioner.go#L200-L223
|
||||||
|
ACME = "ACME"
|
||||||
|
AWS = "AWS"
|
||||||
|
GCP = "GCP"
|
||||||
|
JWK = "JWK"
|
||||||
|
Nebula = "Nebula"
|
||||||
|
OIDC = "OIDC"
|
||||||
|
SCEP = "SCEP"
|
||||||
|
SSHPOP = "SSHPOP"
|
||||||
|
X5C = "X5C"
|
||||||
|
K8sSA = "K8sSA"
|
||||||
|
|
||||||
|
|
||||||
class provisioner(BaseModel):
|
class provisioner(BaseModel):
|
||||||
id: str
|
id: str
|
||||||
name: str
|
name: str
|
||||||
type: str
|
type: provisionerType
|
||||||
|
|
||||||
|
|
||||||
class sanName(BaseModel):
|
class sanName(BaseModel):
|
||||||
|
@ -87,7 +108,7 @@ class x509Cert(BaseModel):
|
||||||
not_after: int
|
not_after: int
|
||||||
not_before: int
|
not_before: int
|
||||||
revoked_at: Union[int, None] = None
|
revoked_at: Union[int, None] = None
|
||||||
status: str
|
status: certStatus
|
||||||
sha256: str
|
sha256: str
|
||||||
sha1: str
|
sha1: str
|
||||||
md5: str
|
md5: str
|
||||||
|
@ -98,16 +119,21 @@ class x509Cert(BaseModel):
|
||||||
pem: str
|
pem: str
|
||||||
|
|
||||||
|
|
||||||
|
class sshCertType(str, Enum):
|
||||||
|
HOST = "Host"
|
||||||
|
USER = "User"
|
||||||
|
|
||||||
|
|
||||||
class sshCert(BaseModel):
|
class sshCert(BaseModel):
|
||||||
serial: str
|
serial: str
|
||||||
alg: str
|
alg: str
|
||||||
type: str
|
type: sshCertType
|
||||||
key_id: str
|
key_id: str
|
||||||
principals: List[str] = []
|
principals: List[str] = []
|
||||||
not_after: int
|
not_after: int
|
||||||
not_before: int
|
not_before: int
|
||||||
revoked_at: Union[int, None] = None
|
revoked_at: Union[int, None] = None
|
||||||
status: str
|
status: certStatus
|
||||||
signing_key: str
|
signing_key: str
|
||||||
signing_key_type: str
|
signing_key_type: str
|
||||||
signing_key_hash: str
|
signing_key_hash: str
|
||||||
|
@ -145,7 +171,7 @@ async def update_metrics():
|
||||||
"principals": ",".join([x.decode() for x in cert.principals]),
|
"principals": ",".join([x.decode() for x in cert.principals]),
|
||||||
"serial": cert.serial,
|
"serial": cert.serial,
|
||||||
"key_id": cert.key_id.decode(),
|
"key_id": cert.key_id.decode(),
|
||||||
"certificate_type": cert.type,
|
"certificate_type": getattr(sshCertType, cert.type.name).value,
|
||||||
}
|
}
|
||||||
|
|
||||||
ssh_cert_not_after.labels(**labels).set(cert.not_after)
|
ssh_cert_not_after.labels(**labels).set(cert.not_after)
|
||||||
|
@ -159,18 +185,45 @@ async def update_metrics():
|
||||||
|
|
||||||
@app.get("/x509/certs", tags=["x509"])
|
@app.get("/x509/certs", tags=["x509"])
|
||||||
def list_x509_certs(
|
def list_x509_certs(
|
||||||
sort_key: str = "not_after", revoked: bool = False, expired: bool = False
|
sort_key: str = Query(enum=["not_after", "not_before"], default="not_after"),
|
||||||
|
revoked: bool = Query(False, deprecated=True),
|
||||||
|
expired: bool = Query(False, deprecated=True),
|
||||||
|
cert_status: list[certStatus] = Query(["Valid"]),
|
||||||
|
subject: str = None,
|
||||||
|
san: str = None,
|
||||||
|
provisioner: str = None,
|
||||||
|
provisioner_type: list[provisionerType] = Query(list(provisionerType)),
|
||||||
) -> list[x509Cert]:
|
) -> list[x509Cert]:
|
||||||
certs = x509_cert.list(db, sort_key=sort_key)
|
certs = x509_cert.list(db, sort_key=sort_key)
|
||||||
cert_list = []
|
cert_list = []
|
||||||
|
|
||||||
for cert in certs:
|
for cert in certs:
|
||||||
if cert.status.value == x509_cert.status.EXPIRED and not expired:
|
if cert.status.name not in [item.name for item in cert_status]:
|
||||||
continue
|
# TODO: Remove handling of deprecated parameters
|
||||||
if cert.status.value == x509_cert.status.REVOKED and not revoked:
|
if not expired and not revoked:
|
||||||
continue
|
continue
|
||||||
|
if cert.status == x509_cert.status.EXPIRED and not expired:
|
||||||
|
continue
|
||||||
|
if cert.status == x509_cert.status.REVOKED and not revoked:
|
||||||
|
continue
|
||||||
|
|
||||||
cert.status = str(cert.status)
|
if (
|
||||||
|
provisioner is not None
|
||||||
|
and provisioner.casefold() not in cert.provisioner["name"].casefold()
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
if cert.provisioner["type"] not in [item.name for item in provisioner_type]:
|
||||||
|
continue
|
||||||
|
if subject is not None and subject.casefold() not in cert.subject.casefold():
|
||||||
|
continue
|
||||||
|
if san is not None:
|
||||||
|
for cert_san_name in cert.san_names:
|
||||||
|
if san.casefold() in cert_san_name["value"].casefold():
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
|
||||||
|
cert.status = getattr(certStatus, cert.status.name)
|
||||||
cert_list.append(cert)
|
cert_list.append(cert)
|
||||||
|
|
||||||
return cert_list
|
return cert_list
|
||||||
|
@ -181,24 +234,48 @@ def get_x509_cert(serial: str) -> Union[x509Cert, None]:
|
||||||
cert = x509_cert.cert.from_serial(db, serial)
|
cert = x509_cert.cert.from_serial(db, serial)
|
||||||
if cert is None:
|
if cert is None:
|
||||||
return None
|
return None
|
||||||
cert.status = str(cert.status)
|
cert.status = getattr(certStatus, cert.status.name)
|
||||||
return cert
|
return cert
|
||||||
|
|
||||||
|
|
||||||
@app.get("/ssh/certs", tags=["ssh"])
|
@app.get("/ssh/certs", tags=["ssh"])
|
||||||
def list_ssh_certs(
|
def list_ssh_certs(
|
||||||
sort_key: str = "not_after", revoked: bool = False, expired: bool = False
|
sort_key: str = Query(enum=["not_after", "not_before"], default="not_after"),
|
||||||
|
revoked: bool = Query(False, deprecated=True),
|
||||||
|
expired: bool = Query(False, deprecated=True),
|
||||||
|
cert_type: list[sshCertType] = Query(["Host", "User"]),
|
||||||
|
cert_status: list[certStatus] = Query(["Valid"]),
|
||||||
|
key: str = None,
|
||||||
|
principal: str = None,
|
||||||
) -> list[sshCert]:
|
) -> list[sshCert]:
|
||||||
certs = ssh_cert.list(db, sort_key=sort_key)
|
certs = ssh_cert.list(db, sort_key=sort_key)
|
||||||
cert_list = []
|
cert_list = []
|
||||||
|
|
||||||
for cert in certs:
|
for cert in certs:
|
||||||
if cert.status.value == ssh_cert.status.EXPIRED and not expired:
|
if cert.status.name not in [item.name for item in cert_status]:
|
||||||
continue
|
# TODO: Remove handling of deprecated parameters
|
||||||
if cert.status.value == ssh_cert.status.REVOKED and not revoked:
|
if not expired and not revoked:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
cert.status = str(cert.status)
|
if cert.status == ssh_cert.status.EXPIRED and not expired:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if cert.status == ssh_cert.status.REVOKED and not revoked:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if cert.type.name not in [item.name for item in cert_type]:
|
||||||
|
continue
|
||||||
|
if key is not None and key.casefold() not in str(cert.key_id).casefold():
|
||||||
|
continue
|
||||||
|
if principal is not None:
|
||||||
|
for cert_principal in cert.principals:
|
||||||
|
if principal.casefold() in str(cert_principal).casefold():
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
|
||||||
|
cert.type = getattr(sshCertType, cert.type.name)
|
||||||
|
cert.status = getattr(certStatus, cert.status.name)
|
||||||
cert_list.append(cert)
|
cert_list.append(cert)
|
||||||
|
|
||||||
return cert_list
|
return cert_list
|
||||||
|
@ -209,5 +286,6 @@ def get_ssh_cert(serial: str) -> Union[sshCert, None]:
|
||||||
cert = ssh_cert.cert.from_serial(db, serial)
|
cert = ssh_cert.cert.from_serial(db, serial)
|
||||||
if cert is None:
|
if cert is None:
|
||||||
return None
|
return None
|
||||||
cert.status = str(cert.status)
|
cert.type = getattr(sshCertType, cert.type.name)
|
||||||
|
cert.status = getattr(certStatus, cert.status.name)
|
||||||
return cert
|
return cert
|
||||||
|
|
|
@ -5,6 +5,7 @@ import mariadb
|
||||||
from cryptography.hazmat.primitives import asymmetric, hashes, serialization
|
from cryptography.hazmat.primitives import asymmetric, hashes, serialization
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from struct import unpack
|
from struct import unpack
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
class list:
|
class list:
|
||||||
|
@ -58,8 +59,7 @@ class cert:
|
||||||
cert = serialization.load_ssh_public_identity(cert_pub_id)
|
cert = serialization.load_ssh_public_identity(cert_pub_id)
|
||||||
self.serial = str(cert.serial)
|
self.serial = str(cert.serial)
|
||||||
self.alg = cert_alg
|
self.alg = cert_alg
|
||||||
if cert.type == serialization.SSHCertificateType.USER:
|
self.type = cert.type
|
||||||
self.type = "User"
|
|
||||||
self.key_id = cert.key_id
|
self.key_id = cert.key_id
|
||||||
self.principals = cert.valid_principals
|
self.principals = cert.valid_principals
|
||||||
self.not_after = cert.valid_before
|
self.not_after = cert.valid_before
|
||||||
|
@ -92,11 +92,11 @@ class cert:
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.revoked_at is not None and self.revoked_at < now_with_tz:
|
if self.revoked_at is not None and self.revoked_at < now_with_tz:
|
||||||
self.status = status(status.REVOKED)
|
self.status = status.REVOKED
|
||||||
elif self.not_after < now_with_tz:
|
elif self.not_after < now_with_tz:
|
||||||
self.status = status(status.EXPIRED)
|
self.status = status.EXPIRED
|
||||||
else:
|
else:
|
||||||
self.status = status(status.VALID)
|
self.status = status.VALID
|
||||||
|
|
||||||
def get_cert(self, db, cert_serial):
|
def get_cert(self, db, cert_serial):
|
||||||
cur = db.cursor()
|
cur = db.cursor()
|
||||||
|
@ -138,20 +138,7 @@ class cert:
|
||||||
return key_str, key_type, key_hash
|
return key_str, key_type, key_hash
|
||||||
|
|
||||||
|
|
||||||
class status:
|
class status(Enum):
|
||||||
REVOKED = 1
|
REVOKED = 1
|
||||||
EXPIRED = 2
|
EXPIRED = 2
|
||||||
VALID = 3
|
VALID = 3
|
||||||
|
|
||||||
def __init__(self, status):
|
|
||||||
self.value = status
|
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
if self.value == self.EXPIRED:
|
|
||||||
return "Expired"
|
|
||||||
elif self.value == self.REVOKED:
|
|
||||||
return "Revoked"
|
|
||||||
elif self.value == self.VALID:
|
|
||||||
return "Valid"
|
|
||||||
else:
|
|
||||||
return "Undefined"
|
|
||||||
|
|
|
@ -5,6 +5,7 @@ import mariadb
|
||||||
from cryptography import x509
|
from cryptography import x509
|
||||||
from cryptography.hazmat.primitives import hashes, serialization
|
from cryptography.hazmat.primitives import hashes, serialization
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
class list:
|
class list:
|
||||||
|
@ -97,11 +98,11 @@ class cert:
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.revoked_at is not None and self.revoked_at < now_with_tz:
|
if self.revoked_at is not None and self.revoked_at < now_with_tz:
|
||||||
self.status = status(status.REVOKED)
|
self.status = status.REVOKED
|
||||||
elif self.not_after < now_with_tz:
|
elif self.not_after < now_with_tz:
|
||||||
self.status = status(status.EXPIRED)
|
self.status = status.EXPIRED
|
||||||
else:
|
else:
|
||||||
self.status = status(status.VALID)
|
self.status = status.VALID
|
||||||
|
|
||||||
def get_cert(self, db, cert_serial):
|
def get_cert(self, db, cert_serial):
|
||||||
cur = db.cursor()
|
cur = db.cursor()
|
||||||
|
@ -153,20 +154,7 @@ class cert:
|
||||||
return sans
|
return sans
|
||||||
|
|
||||||
|
|
||||||
class status:
|
class status(Enum):
|
||||||
REVOKED = 1
|
REVOKED = 1
|
||||||
EXPIRED = 2
|
EXPIRED = 2
|
||||||
VALID = 3
|
VALID = 3
|
||||||
|
|
||||||
def __init__(self, status):
|
|
||||||
self.value = status
|
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
if self.value == self.EXPIRED:
|
|
||||||
return "Expired"
|
|
||||||
elif self.value == self.REVOKED:
|
|
||||||
return "Revoked"
|
|
||||||
elif self.value == self.VALID:
|
|
||||||
return "Valid"
|
|
||||||
else:
|
|
||||||
return "Undefined"
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue