Compare commits
No commits in common. "main" and "v0.0.4" have entirely different histories.
10 changed files with 86 additions and 789 deletions
21
LICENSE
21
LICENSE
|
@ -1,21 +0,0 @@
|
|||
MIT License
|
||||
|
||||
Copyright (c) 2025 Benjamin Collet
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
13
README.md
13
README.md
|
@ -1,13 +0,0 @@
|
|||
# StepCA Inspector
|
||||
|
||||
StepCA Inspector is a companion app to
|
||||
[step-ca](https://github.com/smallstep/certificates/) that exposes Prometheus
|
||||
metrics about your CA and offer API endpoints to get x509 and SSH certificate
|
||||
data.
|
||||
|
||||
Additionally it can serve as a
|
||||
[webhook](https://smallstep.com/docs/step-ca/webhooks/) endpoint for
|
||||
certificate validation.
|
||||
|
||||
A CLI client is available
|
||||
[here](https://git.alt.tf/bcollet/step-ca-inspector-client/).
|
|
@ -1,51 +0,0 @@
|
|||
---
|
||||
database:
|
||||
host: "mysql.example.com"
|
||||
user: "stepca_inspector"
|
||||
password: "secret"
|
||||
database: "stepca"
|
||||
ssl: true
|
||||
ssl_verify_cert: true
|
||||
ssl_ca: "/app/root-ca.crt"
|
||||
reconnect: true
|
||||
webhook_config:
|
||||
- id: "<webhook_id>":
|
||||
secret: "<webhook_secret>"
|
||||
plugin:
|
||||
#name: "scep_static"
|
||||
#challenges:
|
||||
#- secret: "<challenge>":
|
||||
# allowed_dns_names:
|
||||
# - "host.example.com"
|
||||
# - "*.example.com"
|
||||
# allowed_email_addresses: []
|
||||
# allowed_ip_addresses: []
|
||||
name: hashicorp_vault
|
||||
hvac_connection:
|
||||
url: https://vault.example.com
|
||||
verify: "/app/root-ca.crt"
|
||||
#hvac_auth_method: token
|
||||
#hvac_token: "<token>"
|
||||
hvac_auth_method: approle
|
||||
hvac_role_id: "<approle_id>"
|
||||
hvac_secret_id: "<approle_secret>"
|
||||
hvac_engine: <engine>
|
||||
hvac_secret_path: "%s/scep"
|
||||
#hvac_challenge_key: "challenge"
|
||||
#hvac_allowed_dns_names_key: "allowed_dns_names"
|
||||
#hvac_allowed_email_addresses_key: "allowed_email_addresses"
|
||||
- id: "<webhook_id>"
|
||||
secret: "<webhook_secret>"
|
||||
plugin:
|
||||
name: "yubikey_embedded_attestation"
|
||||
yubikey_attestation_root: /app/yubico-piv-ca-1.pem
|
||||
yubikey_allowed_serials:
|
||||
- <yubikey_sn>
|
||||
#yubikey_pin_policies:
|
||||
# never: true
|
||||
# once: true
|
||||
# always: true
|
||||
#yubikey_touch_policies:
|
||||
# never: true
|
||||
# always: true
|
||||
# cached: true
|
|
@ -7,6 +7,3 @@ prometheus-client
|
|||
fastapi[standard]
|
||||
fastapi_utils
|
||||
typing_inspect
|
||||
hvac
|
||||
asgi_correlation_id
|
||||
pydantic-settings
|
||||
|
|
|
@ -1,132 +1,26 @@
|
|||
import os
|
||||
import sys
|
||||
import yaml
|
||||
from pydantic_settings import (
|
||||
BaseSettings,
|
||||
EnvSettingsSource,
|
||||
PydanticBaseSettingsSource,
|
||||
SettingsConfigDict,
|
||||
YamlConfigSettingsSource,
|
||||
)
|
||||
from pydantic import field_validator, ConfigDict, Field
|
||||
from pydantic_core import PydanticCustomError
|
||||
from typing import Optional, Literal, Union, List
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class DatabaseSettings(BaseSettings):
|
||||
host: str
|
||||
user: str
|
||||
password: str
|
||||
database: str
|
||||
ssl: bool = False
|
||||
ssl_verify_cert: bool = True
|
||||
ssl_ca: Optional[str] = None
|
||||
reconnect: bool = True
|
||||
|
||||
|
||||
class PluginSettings(BaseSettings):
|
||||
pass
|
||||
|
||||
|
||||
class ChallengeStaticSCEPPluginSettings(BaseSettings):
|
||||
secret: str
|
||||
allowed_dns_names: List[str] = []
|
||||
allowed_email_addresses: List[str] = []
|
||||
allowed_ip_addresses: List[str] = []
|
||||
|
||||
|
||||
class StaticSCEPPluginSettings(PluginSettings):
|
||||
name: Literal["scep_static"]
|
||||
challenges: List[ChallengeStaticSCEPPluginSettings]
|
||||
|
||||
|
||||
class VaultAuthMethod(str, Enum):
|
||||
TOKEN = "token"
|
||||
APPROLE = "approle"
|
||||
|
||||
|
||||
class VaultPluginSettings(PluginSettings):
|
||||
name: Literal["hashicorp_vault"]
|
||||
hvac_connection: dict = {}
|
||||
hvac_auth_method: VaultAuthMethod = VaultAuthMethod.TOKEN
|
||||
hvac_token: Optional[str] = None
|
||||
hvac_role_id: Optional[str] = None
|
||||
hvac_secret_id: Optional[str] = None
|
||||
hvac_engine: str
|
||||
hvac_secret_path: str = "%s"
|
||||
hvac_challenge_key: str = "challenge"
|
||||
hvac_allowed_dns_names_key: str = "allowed_dns_names"
|
||||
hvac_allowed_email_addresses_key: str = "allowed_email_addresses"
|
||||
hvac_allowed_ip_addresses_key: str = "allowed_ip_addresses"
|
||||
hvac_allowed_uris: str = "allowed_uris"
|
||||
|
||||
|
||||
class YubikeyPinPolicySettings(BaseSettings):
|
||||
never: Optional[bool] = True
|
||||
once: Optional[bool] = True
|
||||
always: Optional[bool] = True
|
||||
|
||||
|
||||
class YubikeyTouchPolicySettings(BaseSettings):
|
||||
never: Optional[bool] = True
|
||||
always: Optional[bool] = True
|
||||
cached: Optional[bool] = True
|
||||
|
||||
|
||||
class YubikeyEmbeddedAttestationSettings(PluginSettings):
|
||||
name: Literal["yubikey_embedded_attestation"]
|
||||
yubikey_attestation_root: str
|
||||
yubikey_allowed_serials: List[int] = []
|
||||
yubikey_pin_policies: Optional[YubikeyPinPolicySettings] = (
|
||||
YubikeyPinPolicySettings()
|
||||
)
|
||||
yubikey_touch_policies: Optional[YubikeyTouchPolicySettings] = (
|
||||
YubikeyTouchPolicySettings()
|
||||
)
|
||||
|
||||
|
||||
class WebhookSettings(BaseSettings):
|
||||
id: str
|
||||
secret: str
|
||||
plugin: Union[tuple(PluginSettings.__subclasses__())] = Field(discriminator="name")
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
database: DatabaseSettings
|
||||
webhook_config: list[WebhookSettings]
|
||||
|
||||
@field_validator("webhook_config", mode="after")
|
||||
class config:
|
||||
@classmethod
|
||||
def check_webhook_uniqueness(
|
||||
cls, webhooks: list[WebhookSettings]
|
||||
) -> list[WebhookSettings]:
|
||||
ids = [webhook.id for webhook in webhooks]
|
||||
if len(ids) != len(set(ids)):
|
||||
raise PydanticCustomError(
|
||||
"webhook_id_uniqueness", "Webhooks IDs must be unique"
|
||||
)
|
||||
return webhooks
|
||||
def __init__(self):
|
||||
config_path = os.environ.get("STEP_CA_INSPECTOR_CONFIGURATION")
|
||||
if config_path is None:
|
||||
print("No configuration file found")
|
||||
sys.exit(1)
|
||||
try:
|
||||
with open(config_path) as ymlfile:
|
||||
cfg = yaml.load(ymlfile, Loader=yaml.FullLoader)
|
||||
except IOError:
|
||||
print("Cannot read configuration file")
|
||||
sys.exit(1)
|
||||
|
||||
@classmethod
|
||||
def settings_customise_sources(
|
||||
cls,
|
||||
settings_cls: type[BaseSettings],
|
||||
init_settings: PydanticBaseSettingsSource,
|
||||
env_settings: PydanticBaseSettingsSource,
|
||||
dotenv_settings: PydanticBaseSettingsSource,
|
||||
file_secret_settings: PydanticBaseSettingsSource,
|
||||
) -> tuple[PydanticBaseSettingsSource, ...]:
|
||||
return (
|
||||
init_settings,
|
||||
EnvSettingsSource(
|
||||
settings_cls,
|
||||
env_nested_delimiter="__",
|
||||
case_sensitive=False,
|
||||
env_prefix="STEP_CA_INSPECTOR_",
|
||||
),
|
||||
YamlConfigSettingsSource(
|
||||
settings_cls,
|
||||
yaml_file=os.environ.get("STEP_CA_INSPECTOR_CONFIGURATION"),
|
||||
),
|
||||
)
|
||||
for k, v in cfg.items():
|
||||
setattr(self, k, v)
|
||||
|
||||
for setting in ["database"]:
|
||||
if not hasattr(self, setting):
|
||||
print(f"Mandatory setting {setting} is not configured.")
|
||||
sys.exit(1)
|
||||
|
|
|
@ -1,54 +1,25 @@
|
|||
from fastapi import FastAPI, HTTPException, Header, Query, Request, Depends
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi_utils.tasks import repeat_every
|
||||
from prometheus_client import make_asgi_app, Gauge
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from models import x509_cert, ssh_cert
|
||||
from config import config
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Union
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from config import Settings, WebhookSettings
|
||||
from models import x509_cert, ssh_cert
|
||||
from webhook import scep_challenge, x509
|
||||
import asgi_correlation_id
|
||||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
import logging
|
||||
import mariadb
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
||||
def configure_logging():
|
||||
console_handler = logging.StreamHandler()
|
||||
console_handler.addFilter(asgi_correlation_id.CorrelationIdFilter())
|
||||
logging.basicConfig(
|
||||
handlers=[console_handler],
|
||||
level=os.environ.get("STEP_CA_INSPECTOR_LOGLEVEL", logging.INFO),
|
||||
format="%(levelname)s [%(correlation_id)s] %(message)s",
|
||||
)
|
||||
|
||||
|
||||
app = FastAPI(title="step-ca Inspector API", on_startup=[configure_logging])
|
||||
app.add_middleware(asgi_correlation_id.CorrelationIdMiddleware)
|
||||
|
||||
logger = logging.getLogger()
|
||||
config()
|
||||
|
||||
try:
|
||||
config = Settings()
|
||||
except ValidationError as e:
|
||||
for error in e.errors():
|
||||
logger.error(
|
||||
f"Configuration error: {error['msg']}: {'.'.join(str(node) for node in error['loc'])}"
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
db = mariadb.connect(**dict(config.database))
|
||||
db = mariadb.connect(**config.database)
|
||||
except Exception as e:
|
||||
print(f"Could not connect to database: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
app = FastAPI(title="step-ca Inspector API")
|
||||
|
||||
x509_label_names = ["subject", "san", "serial", "provisioner", "provisioner_type"]
|
||||
x509_cert_not_before = Gauge(
|
||||
"step_ca_x509_certificate_not_before_timestamp_seconds",
|
||||
|
@ -97,30 +68,10 @@ metrics_app = make_asgi_app()
|
|||
app.mount("/metrics", metrics_app)
|
||||
|
||||
|
||||
class certStatus(str, Enum):
|
||||
REVOKED = "Revoked"
|
||||
EXPIRED = "Expired"
|
||||
VALID = "Valid"
|
||||
|
||||
|
||||
class provisionerType(str, Enum):
|
||||
# https://github.com/smallstep/certificates/blob/938a4da5adf2d32f36ffd06922e5c66956dfff41/authority/provisioner/provisioner.go#L200-L223
|
||||
ACME = "ACME"
|
||||
AWS = "AWS"
|
||||
GCP = "GCP"
|
||||
JWK = "JWK"
|
||||
Nebula = "Nebula"
|
||||
OIDC = "OIDC"
|
||||
SCEP = "SCEP"
|
||||
SSHPOP = "SSHPOP"
|
||||
X5C = "X5C"
|
||||
K8sSA = "K8sSA"
|
||||
|
||||
|
||||
class provisioner(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
type: provisionerType
|
||||
type: str
|
||||
|
||||
|
||||
class sanName(BaseModel):
|
||||
|
@ -136,7 +87,7 @@ class x509Cert(BaseModel):
|
|||
not_after: int
|
||||
not_before: int
|
||||
revoked_at: Union[int, None] = None
|
||||
status: certStatus
|
||||
status: str
|
||||
sha256: str
|
||||
sha1: str
|
||||
md5: str
|
||||
|
@ -147,60 +98,16 @@ class x509Cert(BaseModel):
|
|||
pem: str
|
||||
|
||||
|
||||
class x509Extension(BaseModel):
|
||||
id: str
|
||||
critical: bool
|
||||
value: str
|
||||
|
||||
|
||||
# https://pkg.go.dev/crypto/x509#CertificateRequest
|
||||
class x509CertificateRequest(BaseModel):
|
||||
version: int
|
||||
signature: Union[str, None] = None
|
||||
signatureAlgorithm: str
|
||||
|
||||
publicKey: str
|
||||
publicKeyAlgorithm: str
|
||||
|
||||
subject: dict
|
||||
|
||||
extensions: Union[List[x509Extension], None] = None
|
||||
extraExtensions: Union[List[x509Extension], None] = None
|
||||
|
||||
dnsNames: Union[list, None] = None
|
||||
emailAddresses: Union[list, None] = None
|
||||
ipAddresses: Union[list, None] = None
|
||||
uris: Union[list, None] = None
|
||||
|
||||
|
||||
class webhookSCEPChallenge(BaseModel):
|
||||
provisionerName: str
|
||||
scepChallenge: str
|
||||
scepTransactionID: str
|
||||
x509CertificateRequest: x509CertificateRequest
|
||||
|
||||
|
||||
class webhookx509CertificateRequest(BaseModel):
|
||||
# NOTE: provisionerName is missing from step-ca requests
|
||||
# provisionerName: str
|
||||
x509CertificateRequest: x509CertificateRequest
|
||||
|
||||
|
||||
class sshCertType(str, Enum):
|
||||
HOST = "Host"
|
||||
USER = "User"
|
||||
|
||||
|
||||
class sshCert(BaseModel):
|
||||
serial: str
|
||||
alg: str
|
||||
type: sshCertType
|
||||
type: str
|
||||
key_id: str
|
||||
principals: List[str] = []
|
||||
not_after: int
|
||||
not_before: int
|
||||
revoked_at: Union[int, None] = None
|
||||
status: certStatus
|
||||
status: str
|
||||
signing_key: str
|
||||
signing_key_type: str
|
||||
signing_key_hash: str
|
||||
|
@ -211,11 +118,6 @@ class sshCert(BaseModel):
|
|||
extensions: dict = {}
|
||||
|
||||
|
||||
class webhookResponse(BaseModel):
|
||||
allow: bool
|
||||
data: dict = {}
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
@repeat_every(seconds=15, raise_exceptions=False)
|
||||
async def update_metrics():
|
||||
|
@ -243,7 +145,7 @@ async def update_metrics():
|
|||
"principals": ",".join([x.decode() for x in cert.principals]),
|
||||
"serial": cert.serial,
|
||||
"key_id": cert.key_id.decode(),
|
||||
"certificate_type": getattr(sshCertType, cert.type.name).value,
|
||||
"certificate_type": cert.type,
|
||||
}
|
||||
|
||||
ssh_cert_not_after.labels(**labels).set(cert.not_after)
|
||||
|
@ -255,199 +157,57 @@ async def update_metrics():
|
|||
ssh_cert_status.labels(**labels).set(cert.status.value)
|
||||
|
||||
|
||||
@app.get("/x509/certs", tags=["x509"], summary="Get a list of x509 certificates")
|
||||
@app.get("/x509/certs", tags=["x509"])
|
||||
def list_x509_certs(
|
||||
sort_key: str = Query(enum=["not_after", "not_before"], default="not_after"),
|
||||
cert_status: list[certStatus] = Query(["Valid"]),
|
||||
subject: str = None,
|
||||
san: str = None,
|
||||
provisioner: str = None,
|
||||
provisioner_type: list[provisionerType] = Query(list(provisionerType)),
|
||||
sort_key: str = "not_after", revoked: bool = False, expired: bool = False
|
||||
) -> list[x509Cert]:
|
||||
certs = x509_cert.list(db, sort_key=sort_key)
|
||||
cert_list = []
|
||||
|
||||
for cert in certs:
|
||||
if cert.status.name not in [item.name for item in cert_status]:
|
||||
if cert.status.value == x509_cert.status.EXPIRED and not expired:
|
||||
continue
|
||||
if (
|
||||
provisioner is not None
|
||||
and provisioner.casefold() not in cert.provisioner["name"].casefold()
|
||||
):
|
||||
if cert.status.value == x509_cert.status.REVOKED and not revoked:
|
||||
continue
|
||||
if cert.provisioner["type"] not in [item.name for item in provisioner_type]:
|
||||
continue
|
||||
if subject is not None and subject.casefold() not in cert.subject.casefold():
|
||||
continue
|
||||
if san is not None:
|
||||
for cert_san_name in cert.san_names:
|
||||
if san.casefold() in cert_san_name["value"].casefold():
|
||||
break
|
||||
else:
|
||||
continue
|
||||
|
||||
cert.status = getattr(certStatus, cert.status.name)
|
||||
cert.status = str(cert.status)
|
||||
cert_list.append(cert)
|
||||
|
||||
return cert_list
|
||||
|
||||
|
||||
@app.get(
|
||||
"/x509/certs/{serial}", tags=["x509"], summary="Get details on an x509 certificate"
|
||||
)
|
||||
@app.get("/x509/certs/{serial}", tags=["x509"])
|
||||
def get_x509_cert(serial: str) -> Union[x509Cert, None]:
|
||||
cert = x509_cert.cert.from_serial(db, serial)
|
||||
if cert is None:
|
||||
return None
|
||||
cert.status = getattr(certStatus, cert.status.name)
|
||||
cert.status = str(cert.status)
|
||||
return cert
|
||||
|
||||
|
||||
@app.get("/ssh/certs", tags=["ssh"], summary="Get a list of SSH certificates")
|
||||
@app.get("/ssh/certs", tags=["ssh"])
|
||||
def list_ssh_certs(
|
||||
sort_key: str = Query(enum=["not_after", "not_before"], default="not_after"),
|
||||
cert_type: list[sshCertType] = Query(["Host", "User"]),
|
||||
cert_status: list[certStatus] = Query(["Valid"]),
|
||||
key: str = None,
|
||||
principal: str = None,
|
||||
sort_key: str = "not_after", revoked: bool = False, expired: bool = False
|
||||
) -> list[sshCert]:
|
||||
certs = ssh_cert.list(db, sort_key=sort_key)
|
||||
cert_list = []
|
||||
|
||||
for cert in certs:
|
||||
if cert.status.name not in [item.name for item in cert_status]:
|
||||
if cert.status.value == ssh_cert.status.EXPIRED and not expired:
|
||||
continue
|
||||
if cert.type.name not in [item.name for item in cert_type]:
|
||||
if cert.status.value == ssh_cert.status.REVOKED and not revoked:
|
||||
continue
|
||||
if key is not None and key.casefold() not in str(cert.key_id).casefold():
|
||||
continue
|
||||
if principal is not None:
|
||||
for cert_principal in cert.principals:
|
||||
if principal.casefold() in str(cert_principal).casefold():
|
||||
break
|
||||
else:
|
||||
continue
|
||||
|
||||
cert.type = getattr(sshCertType, cert.type.name)
|
||||
cert.status = getattr(certStatus, cert.status.name)
|
||||
cert.status = str(cert.status)
|
||||
cert_list.append(cert)
|
||||
|
||||
return cert_list
|
||||
|
||||
|
||||
@app.get(
|
||||
"/ssh/certs/{serial}", tags=["ssh"], summary="Get details on an SSH certificate"
|
||||
)
|
||||
@app.get("/ssh/certs/{serial}", tags=["ssh"])
|
||||
def get_ssh_cert(serial: str) -> Union[sshCert, None]:
|
||||
cert = ssh_cert.cert.from_serial(db, serial)
|
||||
if cert is None:
|
||||
return None
|
||||
cert.type = getattr(sshCertType, cert.type.name)
|
||||
cert.status = getattr(certStatus, cert.status.name)
|
||||
cert.status = str(cert.status)
|
||||
return cert
|
||||
|
||||
|
||||
async def webhook_validate(
|
||||
request: Request,
|
||||
x_smallstep_webhook_id: str = Header(),
|
||||
x_smallstep_signature: str = Header(),
|
||||
) -> WebhookSettings:
|
||||
|
||||
logger.debug(f"Received webhook request for webhook ID {x_smallstep_webhook_id}")
|
||||
|
||||
webhook_config = next(
|
||||
(
|
||||
webhook
|
||||
for webhook in config.webhook_config
|
||||
if webhook.id == x_smallstep_webhook_id
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
if webhook_config is None:
|
||||
logger.error("Invalid webhook ID")
|
||||
raise HTTPException(status_code=400, detail="Invalid webhook ID")
|
||||
|
||||
try:
|
||||
signing_secret = base64.b64decode(webhook_config.secret)
|
||||
except ValueError:
|
||||
logger.error("Misconfigured webhook secret")
|
||||
raise HTTPException(status_code=500)
|
||||
|
||||
try:
|
||||
sig = bytes.fromhex(x_smallstep_signature)
|
||||
except ValueError:
|
||||
logger.error("Invalid X-Smallstep-Signature header")
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Invalid X-Smallstep-Signature header"
|
||||
)
|
||||
|
||||
body = await request.body()
|
||||
|
||||
h = hmac.new(signing_secret, body, hashlib.sha256)
|
||||
|
||||
if not hmac.compare_digest(sig, h.digest()):
|
||||
logger.error("Invalid signature")
|
||||
raise HTTPException(status_code=400, detail="Invalid signature")
|
||||
|
||||
return webhook_config
|
||||
|
||||
|
||||
@app.post(
|
||||
"/webhook/scepchallenge", tags=["webhooks"], summary="Valiate a SCEP challenge"
|
||||
)
|
||||
def webhook_scepchallenge(
|
||||
req: webhookSCEPChallenge,
|
||||
webhook_config: dict = Depends(webhook_validate),
|
||||
) -> webhookResponse:
|
||||
|
||||
response = webhookResponse
|
||||
|
||||
logger.info("Received SCEP challenge webhook request")
|
||||
|
||||
if not hasattr(scep_challenge, webhook_config.plugin.name):
|
||||
logger.error("Invalid challenge plugin configured")
|
||||
raise HTTPException(status_code=500)
|
||||
|
||||
validator = getattr(scep_challenge, webhook_config.plugin.name)(
|
||||
webhook_config.plugin
|
||||
)
|
||||
|
||||
if validator.validate(req):
|
||||
logger.info("Validator approved certificate request")
|
||||
response.allow = True
|
||||
else:
|
||||
logger.warning("Validator refused certificate request")
|
||||
response.allow = False
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@app.post(
|
||||
"/webhook/oidc",
|
||||
tags=["webhooks"],
|
||||
summary="Valiate and enrich an OIDC certificate request",
|
||||
)
|
||||
async def webhook_oidc(
|
||||
req: webhookx509CertificateRequest,
|
||||
webhook_config: WebhookSettings = Depends(webhook_validate),
|
||||
) -> webhookResponse:
|
||||
|
||||
response = webhookResponse
|
||||
|
||||
logger.info("Received OIDC webhook request")
|
||||
|
||||
if not hasattr(x509, webhook_config.plugin.name):
|
||||
logger.error("Invalid x509 plugin configured")
|
||||
raise HTTPException(status_code=500)
|
||||
|
||||
validator = getattr(x509, webhook_config.plugin.name)(webhook_config.plugin)
|
||||
|
||||
if validator.validate(req):
|
||||
logger.info("Validator approved certificate request")
|
||||
response.allow = True
|
||||
else:
|
||||
logger.warning("Validator refused certificate request")
|
||||
response.allow = False
|
||||
|
||||
return response
|
||||
|
|
|
@ -5,7 +5,6 @@ import mariadb
|
|||
from cryptography.hazmat.primitives import asymmetric, hashes, serialization
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from struct import unpack
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class list:
|
||||
|
@ -59,7 +58,12 @@ class cert:
|
|||
cert = serialization.load_ssh_public_identity(cert_pub_id)
|
||||
self.serial = str(cert.serial)
|
||||
self.alg = cert_alg
|
||||
self.type = cert.type
|
||||
if cert.type == serialization.SSHCertificateType.USER:
|
||||
self.type = "User"
|
||||
elif cert.type == serialization.SSHCertificateType.HOST:
|
||||
self.type = "Host"
|
||||
else:
|
||||
self.type = "Unknown"
|
||||
self.key_id = cert.key_id
|
||||
self.principals = cert.valid_principals
|
||||
self.not_after = cert.valid_before
|
||||
|
@ -92,11 +96,11 @@ class cert:
|
|||
)
|
||||
|
||||
if self.revoked_at is not None and self.revoked_at < now_with_tz:
|
||||
self.status = status.REVOKED
|
||||
self.status = status(status.REVOKED)
|
||||
elif self.not_after < now_with_tz:
|
||||
self.status = status.EXPIRED
|
||||
self.status = status(status.EXPIRED)
|
||||
else:
|
||||
self.status = status.VALID
|
||||
self.status = status(status.VALID)
|
||||
|
||||
def get_cert(self, db, cert_serial):
|
||||
cur = db.cursor()
|
||||
|
@ -138,7 +142,20 @@ class cert:
|
|||
return key_str, key_type, key_hash
|
||||
|
||||
|
||||
class status(Enum):
|
||||
class status:
|
||||
REVOKED = 1
|
||||
EXPIRED = 2
|
||||
VALID = 3
|
||||
|
||||
def __init__(self, status):
|
||||
self.value = status
|
||||
|
||||
def __str__(self):
|
||||
if self.value == self.EXPIRED:
|
||||
return "Expired"
|
||||
elif self.value == self.REVOKED:
|
||||
return "Revoked"
|
||||
elif self.value == self.VALID:
|
||||
return "Valid"
|
||||
else:
|
||||
return "Undefined"
|
||||
|
|
|
@ -5,7 +5,6 @@ import mariadb
|
|||
from cryptography import x509
|
||||
from cryptography.hazmat.primitives import hashes, serialization
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class list:
|
||||
|
@ -98,11 +97,11 @@ class cert:
|
|||
)
|
||||
|
||||
if self.revoked_at is not None and self.revoked_at < now_with_tz:
|
||||
self.status = status.REVOKED
|
||||
self.status = status(status.REVOKED)
|
||||
elif self.not_after < now_with_tz:
|
||||
self.status = status.EXPIRED
|
||||
self.status = status(status.EXPIRED)
|
||||
else:
|
||||
self.status = status.VALID
|
||||
self.status = status(status.VALID)
|
||||
|
||||
def get_cert(self, db, cert_serial):
|
||||
cur = db.cursor()
|
||||
|
@ -154,7 +153,20 @@ class cert:
|
|||
return sans
|
||||
|
||||
|
||||
class status(Enum):
|
||||
class status:
|
||||
REVOKED = 1
|
||||
EXPIRED = 2
|
||||
VALID = 3
|
||||
|
||||
def __init__(self, status):
|
||||
self.value = status
|
||||
|
||||
def __str__(self):
|
||||
if self.value == self.EXPIRED:
|
||||
return "Expired"
|
||||
elif self.value == self.REVOKED:
|
||||
return "Revoked"
|
||||
elif self.value == self.VALID:
|
||||
return "Valid"
|
||||
else:
|
||||
return "Undefined"
|
||||
|
|
|
@ -1,149 +0,0 @@
|
|||
from fastapi import HTTPException
|
||||
from fnmatch import fnmatch
|
||||
from config import VaultAuthMethod
|
||||
import hvac
|
||||
import logging
|
||||
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
class hashicorp_vault:
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.client = hvac.Client(**self.config.hvac_connection)
|
||||
|
||||
if self.config.hvac_auth_method == VaultAuthMethod.TOKEN:
|
||||
self.client.token = self.config.hvac_token
|
||||
elif self.config.hvac_auth_method == VaultAuthMethod.APPROLE:
|
||||
try:
|
||||
self.client.auth.approle.login(
|
||||
role_id=self.config.hvac_role_id,
|
||||
secret_id=self.config.hvac_secret_id,
|
||||
)
|
||||
except hvac.exceptions.VaultError as e:
|
||||
logger.error(f"HashiCorp Vault error: {e}")
|
||||
raise HTTPException(status_code=500)
|
||||
|
||||
if not self.client.is_authenticated():
|
||||
logger.error("HashiCorp Vault client is not authenticated")
|
||||
raise HTTPException(status_code=500)
|
||||
|
||||
def validate(self, req):
|
||||
logger.debug("Validating with hashicorp_vault plugin")
|
||||
cn = req.x509CertificateRequest.subject.get("commonName")
|
||||
|
||||
try:
|
||||
secret = self.client.secrets.kv.v2.read_secret(
|
||||
path=self.config.hvac_secret_path % cn,
|
||||
mount_point=self.config.hvac_engine,
|
||||
)
|
||||
except hvac.exceptions.VaultError as e:
|
||||
logger.warning(f"HashiCorp Vault error: {e}")
|
||||
return False
|
||||
|
||||
challenge = secret["data"]["data"].get(self.config.hvac_challenge_key)
|
||||
|
||||
if req.scepChallenge != challenge:
|
||||
logger.error("SCEP challenge does not match")
|
||||
return False
|
||||
|
||||
allowed_dns_names = secret["data"]["data"].get(
|
||||
self.config.hvac_allowed_dns_names_key, []
|
||||
) + [cn]
|
||||
allowed_email_addresses = secret["data"]["data"].get(
|
||||
self.config.hvac_allowed_email_addresses_key, []
|
||||
)
|
||||
allowed_ip_addresses = secret["data"]["data"].get(
|
||||
self.config.hvac_allowed_ip_addresses_key, []
|
||||
)
|
||||
allowed_uris = secret["data"]["data"].get(self.config.hvac_allowed_uris, [])
|
||||
|
||||
for dns_name in req.x509CertificateRequest.dnsNames or []:
|
||||
for allowed_dns_name in allowed_dns_names:
|
||||
if fnmatch(dns_name, allowed_dns_name):
|
||||
logger.debug(f"DNS name {dns_name} is allowed")
|
||||
break
|
||||
else:
|
||||
logger.error(f"DNS name {dns_name} is not allowed")
|
||||
return False
|
||||
|
||||
for email_address in req.x509CertificateRequest.emailAddresses or []:
|
||||
if email_address not in allowed_email_addresses:
|
||||
logger.error(f"Email address {email_address} is not allowed")
|
||||
return False
|
||||
logger.debug(f"Email address {email_address} is allowed")
|
||||
|
||||
for ip_address in req.x509CertificateRequest.ipAddresses or []:
|
||||
if ip_address not in allowed_ip_addresses:
|
||||
logger.error(f"IP address {ip_address} is not allowed")
|
||||
return False
|
||||
logger.debug(f"IP address {ip_address} is allowed")
|
||||
|
||||
for uri in req.x509CertificateRequest.uris or []:
|
||||
if uri not in allowed_uris:
|
||||
logger.error(f"URI {uri} is not allowed")
|
||||
return False
|
||||
logger.debug(f"URI {uri} is allowed")
|
||||
|
||||
return True
|
||||
|
||||
|
||||
class scep_static:
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
|
||||
def validate(self, req):
|
||||
logger.debug("Validating with static plugin")
|
||||
|
||||
challenge_config = next(
|
||||
(
|
||||
challenge
|
||||
for challenge in self.config.challenges
|
||||
if challenge.secret == req.scepChallenge
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
if challenge_config is None:
|
||||
logger.error("SCEP challenge does not match")
|
||||
return False
|
||||
|
||||
cn = req.x509CertificateRequest.subject.get("commonName")
|
||||
|
||||
for allowed_dns_name in challenge_config.allowed_dns_names:
|
||||
if fnmatch(cn, allowed_dns_name):
|
||||
logger.debug(f"Subject CN={cn} is allowed")
|
||||
break
|
||||
else:
|
||||
logger.error(f"Subject CN={cn} is not allowed")
|
||||
return False
|
||||
|
||||
for dns_name in req.x509CertificateRequest.dnsNames or []:
|
||||
for allowed_dns_name in challenge_config.allowed_dns_names:
|
||||
if fnmatch(dns_name, allowed_dns_name):
|
||||
logger.debug(f"DNS name {dns_name} is allowed")
|
||||
break
|
||||
else:
|
||||
logger.error(f"DNS name {dns_name} is not allowed")
|
||||
return False
|
||||
|
||||
for email_address in req.x509CertificateRequest.emailAddresses or []:
|
||||
if email_address not in challenge_config.allowed_email_addresses:
|
||||
logger.error(f"Email address {email_address} is not allowed")
|
||||
return False
|
||||
logger.debug(f"Email address {email_address} is allowed")
|
||||
|
||||
for ip_address in req.x509CertificateRequest.ipAddresses or []:
|
||||
if ip_address not in challenge_config.allowed_ip_addresses:
|
||||
logger.error(f"IP address {ip_address} is not allowed")
|
||||
return False
|
||||
logger.debug(f"IP address {ip_address} is allowed")
|
||||
|
||||
for uri in req.x509CertificateRequest.uris or []:
|
||||
if uri not in challenge_config.allowed_uris:
|
||||
logger.error(f"URI {uri} is not allowed")
|
||||
return False
|
||||
logger.debug(f"URI {uri} is allowed")
|
||||
|
||||
return True
|
|
@ -1,149 +0,0 @@
|
|||
import base64
|
||||
import binascii
|
||||
import logging
|
||||
from cryptography import x509
|
||||
from cryptography.hazmat.primitives import hashes, serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import padding
|
||||
from datetime import datetime, timezone
|
||||
|
||||
PIN_POLICY = {"01": "never", "02": "once", "03": "always"}
|
||||
|
||||
TOUCH_POLICY = {"01": "never", "02": "always", "03": "cached"}
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
class yubikey_embedded_attestation:
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
|
||||
def validate(self, req):
|
||||
logger.debug("Validating with yubikey_embedded_attestation plugin")
|
||||
pub_key = req.x509CertificateRequest.publicKey
|
||||
pub_alg = req.x509CertificateRequest.publicKeyAlgorithm
|
||||
extensions = req.x509CertificateRequest.extensions
|
||||
|
||||
attestation_cert = None
|
||||
intermediate_cert = None
|
||||
with open(self.config.yubikey_attestation_root, "rb") as file:
|
||||
root_cert = x509.load_pem_x509_certificate(file.read())
|
||||
|
||||
for extension in extensions:
|
||||
if extension.id == "1.3.6.1.4.1.41482.3.1":
|
||||
attestation_cert = x509.load_der_x509_certificate(
|
||||
base64.b64decode(extension.value)
|
||||
)
|
||||
|
||||
elif extension.id == "1.3.6.1.4.1.41482.3.2":
|
||||
intermediate_cert = x509.load_der_x509_certificate(
|
||||
base64.b64decode(extension.value)
|
||||
)
|
||||
|
||||
if attestation_cert is None:
|
||||
logger.error("CSR does not include an attestation certificate")
|
||||
return False
|
||||
|
||||
if intermediate_cert is None:
|
||||
logger.error("CSR does not include an intermediate attestation certificate")
|
||||
return False
|
||||
|
||||
try:
|
||||
intermediate_cert.public_key().verify(
|
||||
attestation_cert.signature,
|
||||
attestation_cert.tbs_certificate_bytes,
|
||||
padding.PKCS1v15(),
|
||||
hashes.SHA256(),
|
||||
)
|
||||
logger.debug("Valid intermediate attestation certificate signature")
|
||||
except Exception as e:
|
||||
logger.error(f"Invalid intermediate attestation certificate signature {e}")
|
||||
return False
|
||||
|
||||
try:
|
||||
root_cert.public_key().verify(
|
||||
intermediate_cert.signature,
|
||||
intermediate_cert.tbs_certificate_bytes,
|
||||
padding.PKCS1v15(),
|
||||
hashes.SHA256(),
|
||||
)
|
||||
logger.debug("Valid root attestation certificate signature")
|
||||
except Exception as e:
|
||||
logger.error(f"Invalid root attestation certificate signature: {e}")
|
||||
return False
|
||||
|
||||
current_time = datetime.now(timezone.utc)
|
||||
for cert in [attestation_cert, intermediate_cert, root_cert]:
|
||||
if cert.not_valid_before_utc <= current_time <= cert.not_valid_after_utc:
|
||||
logger.debug(f"Certificate {cert.subject.rfc4514_string()} is valid")
|
||||
else:
|
||||
logger.error(
|
||||
f"Certificate {cert.subject.rfc4514_string()} is not valid"
|
||||
)
|
||||
return False
|
||||
|
||||
csr_public_key_bytes = base64.b64decode(pub_key)
|
||||
attestation_public_key_bytes = attestation_cert.public_key().public_bytes(
|
||||
encoding=serialization.Encoding.DER,
|
||||
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||
)
|
||||
|
||||
if csr_public_key_bytes == attestation_public_key_bytes:
|
||||
logger.debug("CSR and attestation public keys match")
|
||||
else:
|
||||
logger.error("CSR and attestation public keys do not match")
|
||||
return False
|
||||
|
||||
firmware_version = serial_number = pin_policy = touch_policy = "Not Found"
|
||||
# https://docs.yubico.com/hardware/oid/webdocs.pdf
|
||||
for ext in attestation_cert.extensions:
|
||||
if ext.oid.dotted_string == "1.3.6.1.4.1.41482.3.3":
|
||||
# Decode Firmware Version
|
||||
ext_data = binascii.hexlify(ext.value.value).decode("utf-8")
|
||||
firmware_version = f"{int(ext_data[:2], 16)}.{int(ext_data[2:4], 16)}.{int(ext_data[4:6], 16)}"
|
||||
elif ext.oid.dotted_string == "1.3.6.1.4.1.41482.3.7":
|
||||
# Decode Serial Number
|
||||
ext_data = ext.value.value
|
||||
# Assuming the first two bytes are not part of the serial number, skip them
|
||||
serial_number = int(binascii.hexlify(ext_data[2:]), 16)
|
||||
elif ext.oid.dotted_string == "1.3.6.1.4.1.41482.3.8":
|
||||
# Decode Pin Policy and Touch Policy
|
||||
ext_data = binascii.hexlify(ext.value.value).decode("utf-8")
|
||||
pin_policy = ext_data[:2]
|
||||
pin_policy_value = PIN_POLICY.get(pin_policy)
|
||||
touch_policy = ext_data[2:4]
|
||||
touch_policy_value = TOUCH_POLICY.get(touch_policy)
|
||||
|
||||
if self.config.yubikey_allowed_serials is None:
|
||||
logger.debug("No serial filtering configured")
|
||||
pass
|
||||
elif serial_number not in self.config.yubikey_allowed_serials:
|
||||
logger.error(f"Yubikey S/N {serial_number} is not allowed")
|
||||
return False
|
||||
else:
|
||||
logger.debug(f"Yubikey S/N {serial_number} is allowed")
|
||||
|
||||
if pin_policy_value is None:
|
||||
logger.error(f"Unknown PIN policy")
|
||||
return False
|
||||
elif not getattr(self.config.yubikey_pin_policies, pin_policy_value):
|
||||
logger.error(
|
||||
f"PIN policy “{pin_policy_value}” ({pin_policy}) is not allowed"
|
||||
)
|
||||
return False
|
||||
else:
|
||||
logger.debug(f"PIN policy “{pin_policy_value}” ({pin_policy}) is allowed")
|
||||
|
||||
if pin_policy_value is None:
|
||||
logger.error(f"Unknown touch policy")
|
||||
return False
|
||||
elif not getattr(self.config.yubikey_touch_policies, touch_policy_value):
|
||||
logger.error(
|
||||
f"Touch policy “{touch_policy_value}” ({touch_policy}) is not allowed"
|
||||
)
|
||||
return False
|
||||
else:
|
||||
logger.debug(
|
||||
f"Touch policy “{touch_policy_value}” ({touch_policy}) is allowed"
|
||||
)
|
||||
|
||||
return True
|
Loading…
Add table
Add a link
Reference in a new issue