166 lines
4.9 KiB
Python
166 lines
4.9 KiB
Python
import base64
|
|
import dateutil
|
|
import json
|
|
import mariadb
|
|
from cryptography import x509
|
|
from cryptography.hazmat.primitives import asymmetric, hashes, serialization
|
|
from datetime import datetime, timedelta, timezone
|
|
from models.config import config
|
|
from struct import unpack
|
|
|
|
|
|
# NOTE(review): calling config() presumably populates the config.database_*
# attributes read below — confirm against models.config.
config()

# Module-level MariaDB connection, opened at import time and used by the
# queries in this module.
_connection_settings = {
    "host": config.database_host,
    "user": config.database_user,
    "password": config.database_password,
    "database": config.database_name,
}
conn = mariadb.connect(**_connection_settings)
|
|
|
|
|
|
class list:
    """Loader for every SSH certificate stored in the database.

    Calling ``list()`` never yields an instance of this class: ``__new__``
    runs the query and returns a plain sequence of ``cert`` objects instead.

    Note that this class shadows the ``list`` builtin inside this module, so
    the builtin must not be called by name here.
    """

    # The result of the most recent load; rebound (not appended to) per call.
    certs = []

    def __new__(cls, sort_key=None):
        """Query all certificates and return them as ``cert`` objects.

        Args:
            sort_key: Optional name of a ``cert`` attribute. When given, the
                result is sorted ascending by that attribute.

        Returns:
            A freshly built sequence with one ``cert`` per ``ssh_certs`` row.
            The same object is also stored on ``cls.certs``.

        Raises:
            AttributeError: If ``sort_key`` names an attribute that a loaded
                certificate does not have.
        """
        cur = conn.cursor()
        try:
            cur.execute(
                """SELECT ssh_certs.nvalue AS cert,
                          revoked_ssh_certs.nvalue AS revoked
                   FROM ssh_certs
                   LEFT JOIN revoked_ssh_certs USING(nkey)"""
            )
            # Build a new sequence on every call. Appending to the shared
            # class attribute made each call return all previously loaded
            # certificates again, duplicated.
            loaded = [cert(row) for row in cur]
        finally:
            # Release the cursor even when decoding a certificate fails.
            cur.close()

        if sort_key is not None:
            loaded.sort(key=lambda item: getattr(item, sort_key))

        cls.certs = loaded
        return loaded
|
|
|
|
|
|
class cert:
    """One SSH certificate row, decoded into display-ready attributes.

    After construction the instance exposes ``serial``, ``alg``, ``type``,
    ``key_id``, ``principals``, ``not_before``, ``not_after``, ``extensions``,
    ``signing_key*``, ``public_key*``, ``public_identity``, ``revoked_at``
    and ``status``.
    """

    def __init__(self, cert):
        """Decode a ``(certificate, revocation)`` database row.

        Args:
            cert: Two-item row as returned by ``get_cert``: the raw
                certificate blob, and the revocation record as JSON (or
                ``None`` when there is no revocation row).

        Raises:
            TypeError: If ``cert`` is ``None`` (for example because no row
                matched), since it cannot be unpacked.
        """
        cert_raw, cert_revoked_raw = cert

        # The blob starts with a 4-byte big-endian length followed by that
        # many bytes, which are used as the algorithm name.
        alg_end = unpack(">I", cert_raw[:4])[0] + 4
        alg = cert_raw[4:alg_end]

        # Rebuild the "<algorithm> <base64 blob>" line handed to the parser.
        cert_pub_id = b" ".join([alg, base64.b64encode(cert_raw)])

        cert_revoked = (
            json.loads(cert_revoked_raw) if cert_revoked_raw is not None else None
        )

        self.load(cert_pub_id, cert_revoked, alg)

    @classmethod
    def from_serial(cls, serial):
        """Build a ``cert`` from the database row whose key is ``serial``.

        Raises:
            TypeError: If no row matches, because ``get_cert`` then returns
                ``None`` and the constructor cannot unpack it.
        """
        # get_cert is an instance method that ignores ``self``, so the class
        # itself is passed in that position.
        return cls(cert=cls.get_cert(cls, serial))

    @staticmethod
    def _utc_from_timestamp(seconds):
        """Convert Unix seconds to an aware UTC datetime, second precision.

        Values too large for ``datetime`` (such as the 2**64 - 1 "forever"
        bound of the SSH certificate format) are clamped to ``datetime.max``.
        """
        try:
            moment = datetime.fromtimestamp(seconds, tz=timezone.utc)
        except (OverflowError, OSError, ValueError):
            moment = datetime.max.replace(tzinfo=timezone.utc)
        return moment.replace(microsecond=0)

    def load(self, cert_pub_id, cert_revoked, cert_alg):
        """Parse the certificate and populate every attribute of ``self``.

        Args:
            cert_pub_id: ``b"<algorithm> <base64 blob>"`` public identity.
            cert_revoked: Decoded revocation record (a mapping with a
                ``"RevokedAt"`` entry) or ``None``.
            cert_alg: Algorithm name, stored unchanged as ``self.alg``.
        """
        # Imported explicitly: a bare ``import dateutil`` does not guarantee
        # that the ``parser`` submodule is loaded.
        from dateutil.parser import isoparse

        parsed = serialization.load_ssh_public_identity(cert_pub_id)
        self.serial = parsed.serial
        self.alg = cert_alg

        # Always set ``type`` so host certificates do not lack the attribute.
        if parsed.type == serialization.SSHCertificateType.USER:
            self.type = "User"
        elif parsed.type == serialization.SSHCertificateType.HOST:
            self.type = "Host"
        else:
            self.type = "Unknown"

        self.key_id = parsed.key_id
        self.principals = parsed.valid_principals

        # Interpret the bounds as UTC directly. Converting to naive local
        # time and then labelling it UTC shifts them on non-UTC hosts.
        self.not_after = self._utc_from_timestamp(parsed.valid_before)
        self.not_before = self._utc_from_timestamp(parsed.valid_after)

        # TODO: Implement critical options parsing
        # parsed.critical_options
        self.extensions = parsed.extensions

        (self.signing_key, self.signing_key_type, self.signing_key_hash) = (
            self.get_public_key_params(parsed.signature_key())
        )
        (self.public_key, self.public_key_type, self.public_key_hash) = (
            self.get_public_key_params(parsed.public_key())
        )

        self.public_identity = parsed.public_bytes()

        if cert_revoked is not None:
            revoked_at = isoparse(cert_revoked.get("RevokedAt")).replace(
                microsecond=0
            )
            if revoked_at.tzinfo is None:
                # NOTE(review): a timestamp without an offset is assumed to
                # be UTC; comparing naive and aware datetimes would raise.
                revoked_at = revoked_at.replace(tzinfo=timezone.utc)
            self.revoked_at = revoked_at
        else:
            self.revoked_at = None

        now_utc = datetime.now(timezone.utc).replace(microsecond=0)

        # Revocation takes precedence over expiry.
        if self.revoked_at is not None and self.revoked_at < now_utc:
            self.status = status(status.REVOKED)
        elif self.not_after < now_utc:
            self.status = status(status.EXPIRED)
        else:
            self.status = status(status.VALID)

    def get_cert(self, cert_serial):
        """Fetch the ``(certificate, revocation)`` row for one serial.

        Args:
            cert_serial: Value matched against the ``nkey`` column.

        Returns:
            The first matching row, or ``None`` when nothing matches.
        """
        cur = conn.cursor()
        try:
            cur.execute(
                """SELECT ssh_certs.nvalue AS cert,
                          revoked_ssh_certs.nvalue AS revoked
                   FROM ssh_certs
                   LEFT JOIN revoked_ssh_certs USING(nkey)
                   WHERE nkey=?""",
                (cert_serial,),
            )
            # fetchone() already returns None for an empty result, so the
            # row count (not dependable for unbuffered SELECTs) is not used.
            return cur.fetchone()
        finally:
            cur.close()

    def get_public_key_params(self, public_key):
        """Describe a public key for display.

        Args:
            public_key: A ``cryptography`` public key object.

        Returns:
            A ``(key_str, key_type, key_hash)`` tuple: the OpenSSH-encoded
            key as bytes, one of ``"ECDSA"``, ``"ED25519"``, ``"RSA"`` or
            ``"Unknown"``, and the base64-encoded SHA-256 digest of the
            decoded key data as bytes.
        """
        if isinstance(public_key, asymmetric.ec.EllipticCurvePublicKey):
            key_type = "ECDSA"
        elif isinstance(public_key, asymmetric.ed25519.Ed25519PublicKey):
            key_type = "ED25519"
        elif isinstance(public_key, asymmetric.rsa.RSAPublicKey):
            key_type = "RSA"
        else:
            # Previously left unbound, which raised UnboundLocalError below.
            key_type = "Unknown"

        key_str = public_key.public_bytes(
            serialization.Encoding.OpenSSH, serialization.PublicFormat.OpenSSH
        )

        # The encoded key is "<type> <base64 data>"; hash the decoded data.
        key_data = key_str.strip().split()[1]
        digest = hashes.Hash(hashes.SHA256())
        digest.update(base64.b64decode(key_data))
        key_hash = base64.b64encode(digest.finalize())

        return key_str, key_type, key_hash
|
|
|
|
|
|
class status:
    """Validity state of a certificate, printable as a readable label."""

    REVOKED = 1
    EXPIRED = 2
    VALID = 3

    def __init__(self, status):
        """Store the given state code as ``value``."""
        self.value = status

    def __str__(self):
        """Return the label for ``value``, or "Undefined" for other codes."""
        labelled_codes = (
            (self.EXPIRED, "Expired"),
            (self.REVOKED, "Revoked"),
            (self.VALID, "Valid"),
        )
        # Compare with == rather than a dict lookup so that unhashable
        # values still fall through to "Undefined".
        for code, label in labelled_codes:
            if self.value == code:
                return label
        return "Undefined"
|