Optimise SQL queries

This commit is contained in:
Benjamin Collet 2025-01-07 17:49:01 +01:00
parent c469dccb5c
commit 9859b1cd29
Signed by: bcollet
SSH key fingerprint: SHA256:8UJspOIcCOS+MtSOcnuq2HjKFube4ox1s/+A62ixov4
3 changed files with 75 additions and 64 deletions

View file

@ -23,10 +23,15 @@ class list:
def __new__(cls, sort_key=None):
cur = conn.cursor()
cur.execute("SELECT nkey FROM ssh_certs")
cur.execute(
"""SELECT ssh_certs.nvalue AS cert,
revoked_ssh_certs.nvalue AS revoked
FROM ssh_certs
LEFT JOIN revoked_ssh_certs USING(nkey)"""
)
for (cert_serial,) in cur:
cert_object = cert(cert_serial)
for result in cur:
cert_object = cert(result)
cls.certs.append(cert_object)
cur.close()
@ -38,15 +43,25 @@ class list:
class cert:
def __init__(self, serial):
cert_raw = self.get_cert(serial)
def __init__(self, cert):
(cert_raw, cert_revoked_raw) = cert
size = unpack(">I", cert_raw[:4])[0] + 4
alg = cert_raw[4:size]
cert_pub_id = b" ".join([alg, base64.b64encode(cert_raw)])
cert_revoked = self.get_cert_revoked(serial)
if cert_revoked_raw is not None:
cert_revoked = json.loads(cert_revoked_raw)
else:
cert_revoked = None
self.load(cert_pub_id, cert_revoked, alg)
@classmethod
def from_serial(cls, serial):
return cls(cert=cls.get_cert(cls, serial))
def load(self, cert_pub_id, cert_revoked, cert_alg):
cert = serialization.load_ssh_public_identity(cert_pub_id)
self.serial = cert.serial
@ -95,27 +110,22 @@ class cert:
def get_cert(self, cert_serial):
cur = conn.cursor()
cur.execute("SELECT nvalue FROM ssh_certs WHERE nkey=?", (cert_serial,))
cur.execute(
"""SELECT ssh_certs.nvalue AS cert,
revoked_ssh_certs.nvalue AS revoked
FROM ssh_certs
LEFT JOIN revoked_ssh_certs USING(nkey)
WHERE nkey=?""",
(cert_serial,),
)
if cur.rowcount > 0:
(cert,) = cur.fetchone()
cert = cur.fetchone()
else:
cert = None
cur.close()
return cert
def get_cert_revoked(self, cert_serial):
cur = conn.cursor()
cur.execute("SELECT nvalue FROM revoked_ssh_certs WHERE nkey=?", (cert_serial,))
if cur.rowcount > 0:
(cert_revoked_raw,) = cur.fetchone()
cert_revoked = json.loads(cert_revoked_raw)
else:
cert_revoked = None
cur.close()
return cert_revoked
def get_public_key_params(self, public_key):
if isinstance(public_key, asymmetric.ec.EllipticCurvePublicKey):
key_type = "ECDSA"

View file

@ -13,7 +13,7 @@ conn = mariadb.connect(
host=config.database_host,
user=config.database_user,
password=config.database_password,
database = config.database_name
database=config.database_name,
)
@ -22,10 +22,17 @@ class list:
def __new__(cls, sort_key=None):
cur = conn.cursor()
cur.execute("SELECT nkey FROM x509_certs")
cur.execute(
"""SELECT x509_certs.nvalue AS cert,
x509_certs_data.nvalue AS data,
revoked_x509_certs.nvalue AS revoked
FROM x509_certs
INNER JOIN x509_certs_data USING(nkey)
LEFT JOIN revoked_x509_certs USING(nkey)"""
)
for (cert_serial,) in cur:
cert_object = cert(cert_serial)
for result in cur:
cert_object = cert(result)
cls.certs.append(cert_object)
cur.close()
@ -36,14 +43,21 @@ class list:
return cls.certs
class cert:
def __init__(self, serial):
cert_der = self.get_cert(serial)
cert_data = self.get_cert_data(serial)
cert_revoked = self.get_cert_revoked(serial)
def __init__(self, cert):
(cert_der, cert_data_raw, cert_revoked_raw) = cert
cert_data = json.loads(cert_data_raw)
if cert_revoked_raw is not None:
cert_revoked = json.loads(cert_revoked_raw)
else:
cert_revoked = None
self.load(cert_der, cert_data, cert_revoked)
@classmethod
def from_serial(cls, serial):
return cls(cert=cls.get_cert(cls, serial))
def load(self, cert_der, cert_data, cert_revoked):
cert = x509.load_der_x509_certificate(cert_der)
@ -60,7 +74,9 @@ class cert:
self.not_before = cert.not_valid_before_utc.replace(microsecond=0)
self.not_after = cert.not_valid_after_utc.replace(microsecond=0)
try:
san_data = cert.extensions.get_extension_for_class(x509.SubjectAlternativeName)
san_data = cert.extensions.get_extension_for_class(
x509.SubjectAlternativeName
)
self.san_names = san_data.value.get_values_for_type(x509.GeneralName)
except x509.extensions.ExtensionNotFound:
self.san_names = []
@ -85,12 +101,21 @@ class cert:
else:
self.status = status(status.VALID)
def get_cert(self, cert_serial):
cur = conn.cursor()
cur.execute("SELECT nvalue FROM x509_certs WHERE nkey=?", (cert_serial,))
cur.execute(
"""SELECT x509_certs.nvalue AS cert,
x509_certs_data.nvalue AS data,
revoked_x509_certs.nvalue AS revoked
FROM x509_certs
INNER JOIN x509_certs_data USING(nkey)
LEFT JOIN revoked_x509_certs USING(nkey)
WHERE nkey=?""",
(cert_serial,),
)
if cur.rowcount > 0:
(cert,) = cur.fetchone()
cert = cur.fetchone()
else:
cert = None
@ -98,30 +123,6 @@ class cert:
return cert
def get_cert_data(self, cert_serial):
cur = conn.cursor()
cur.execute("SELECT nvalue FROM x509_certs_data WHERE nkey=?", (cert_serial,))
(cert_data_raw,) = cur.fetchone()
cur.close()
cert_data = json.loads(cert_data_raw)
return cert_data
def get_cert_revoked(self, cert_serial):
cur = conn.cursor()
cur.execute(
"SELECT nvalue FROM revoked_x509_certs WHERE nkey=?", (cert_serial,)
)
if cur.rowcount > 0:
(cert_revoked_raw,) = cur.fetchone()
cert_revoked = json.loads(cert_revoked_raw)
else:
cert_revoked = None
cur.close()
return cert_revoked
class status:
REVOKED = 1
EXPIRED = 2

View file

@ -42,7 +42,7 @@ def list_ssh_certs(sort_key, revoked=False, expired=False):
def get_ssh_cert(serial):
cert = ssh_cert.cert(serial)
cert = ssh_cert.cert.from_serial(serial)
cert_tbl = []
cert_tbl.append(["Serial", cert.serial])
@ -68,7 +68,7 @@ def get_ssh_cert(serial):
def dump_ssh_cert(serial):
cert = ssh_cert.cert(serial)
cert = ssh_cert.cert.from_serial(serial)
print(cert.public_identity.decode())
@ -105,7 +105,7 @@ def list_x509_certs(sort_key, revoked=False, expired=False):
def get_x509_cert(serial, show_pem=False):
cert = x509_cert.cert(serial)
cert = x509_cert.cert.from_serial(serial)
cert_tbl = []
cert_tbl.append(["Serial", cert.serial])
@ -135,7 +135,7 @@ def get_x509_cert(serial, show_pem=False):
def dump_x509_cert(serial, cert_format="pem"):
cert = x509_cert.cert(serial)
cert = x509_cert.cert.from_serial(serial)
print(cert.pem.decode("utf-8").rstrip())