From 9859b1cd293d69e40a902a202a2732de684fdb47 Mon Sep 17 00:00:00 2001
From: Benjamin Collet <benjamin@collet.eu>
Date: Tue, 7 Jan 2025 17:49:01 +0100
Subject: [PATCH] Optimise SQL queries

---
 models/ssh_cert.py   | 50 ++++++++++++++++-----------
 models/x509_cert.py  | 81 ++++++++++++++++++++++----------------------
 step-ca-inspector.py |  8 ++---
 3 files changed, 75 insertions(+), 64 deletions(-)

diff --git a/models/ssh_cert.py b/models/ssh_cert.py
index 3e021e0..0ade70c 100644
--- a/models/ssh_cert.py
+++ b/models/ssh_cert.py
@@ -23,10 +23,15 @@ class list:
 
     def __new__(cls, sort_key=None):
         cur = conn.cursor()
-        cur.execute("SELECT nkey FROM ssh_certs")
+        cur.execute(
+            """SELECT ssh_certs.nvalue AS cert,
+                      revoked_ssh_certs.nvalue AS revoked
+               FROM ssh_certs
+               LEFT JOIN revoked_ssh_certs USING(nkey)"""
+        )
 
-        for (cert_serial,) in cur:
-            cert_object = cert(cert_serial)
+        for result in cur:
+            cert_object = cert(result)
             cls.certs.append(cert_object)
 
         cur.close()
@@ -38,15 +43,25 @@ class list:
 
 
 class cert:
-    def __init__(self, serial):
-        cert_raw = self.get_cert(serial)
+    def __init__(self, cert):
+        (cert_raw, cert_revoked_raw) = cert
+
         size = unpack(">I", cert_raw[:4])[0] + 4
         alg = cert_raw[4:size]
 
         cert_pub_id = b" ".join([alg, base64.b64encode(cert_raw)])
-        cert_revoked = self.get_cert_revoked(serial)
+
+        if cert_revoked_raw is not None:
+            cert_revoked = json.loads(cert_revoked_raw)
+        else:
+            cert_revoked = None
+
         self.load(cert_pub_id, cert_revoked, alg)
 
+    @classmethod
+    def from_serial(cls, serial):
+        return cls(cert=cls.get_cert(cls, serial))
+
     def load(self, cert_pub_id, cert_revoked, cert_alg):
         cert = serialization.load_ssh_public_identity(cert_pub_id)
         self.serial = cert.serial
@@ -95,27 +110,22 @@ class cert:
 
     def get_cert(self, cert_serial):
         cur = conn.cursor()
-        cur.execute("SELECT nvalue FROM ssh_certs WHERE nkey=?", (cert_serial,))
+        cur.execute(
+            """SELECT ssh_certs.nvalue AS cert,
+                      revoked_ssh_certs.nvalue AS revoked
+               FROM ssh_certs
+               LEFT JOIN revoked_ssh_certs USING(nkey)
+               WHERE nkey=?""",
+            (cert_serial,),
+        )
         if cur.rowcount > 0:
-            (cert,) = cur.fetchone()
+            cert = cur.fetchone()
         else:
             cert = None
 
         cur.close()
         return cert
 
-    def get_cert_revoked(self, cert_serial):
-        cur = conn.cursor()
-        cur.execute("SELECT nvalue FROM revoked_ssh_certs WHERE nkey=?", (cert_serial,))
-        if cur.rowcount > 0:
-            (cert_revoked_raw,) = cur.fetchone()
-            cert_revoked = json.loads(cert_revoked_raw)
-        else:
-            cert_revoked = None
-
-        cur.close()
-        return cert_revoked
-
     def get_public_key_params(self, public_key):
         if isinstance(public_key, asymmetric.ec.EllipticCurvePublicKey):
             key_type = "ECDSA"
diff --git a/models/x509_cert.py b/models/x509_cert.py
index 19ce241..d5c1d7a 100644
--- a/models/x509_cert.py
+++ b/models/x509_cert.py
@@ -10,10 +10,10 @@ from models.config import config
 
 config()
 conn = mariadb.connect(
-    host = config.database_host,
-    user = config.database_user,
-    password = config.database_password,
-    database = config.database_name
+    host=config.database_host,
+    user=config.database_user,
+    password=config.database_password,
+    database=config.database_name,
 )
 
 
@@ -22,10 +22,17 @@ class list:
 
     def __new__(cls, sort_key=None):
         cur = conn.cursor()
-        cur.execute("SELECT nkey FROM x509_certs")
+        cur.execute(
+            """SELECT x509_certs.nvalue AS cert,
+                      x509_certs_data.nvalue AS data,
+                      revoked_x509_certs.nvalue AS revoked
+               FROM x509_certs
+               INNER JOIN x509_certs_data USING(nkey)
+               LEFT JOIN revoked_x509_certs USING(nkey)"""
+        )
 
-        for (cert_serial,) in cur:
-            cert_object = cert(cert_serial)
+        for result in cur:
+            cert_object = cert(result)
             cls.certs.append(cert_object)
 
         cur.close()
@@ -36,14 +43,21 @@ class list:
         return cls.certs
 
 
-
 class cert:
-    def __init__(self, serial):
-        cert_der = self.get_cert(serial)
-        cert_data = self.get_cert_data(serial)
-        cert_revoked = self.get_cert_revoked(serial)
+    def __init__(self, cert):
+        (cert_der, cert_data_raw, cert_revoked_raw) = cert
+
+        cert_data = json.loads(cert_data_raw)
+        if cert_revoked_raw is not None:
+            cert_revoked = json.loads(cert_revoked_raw)
+        else:
+            cert_revoked = None
+
         self.load(cert_der, cert_data, cert_revoked)
 
+    @classmethod
+    def from_serial(cls, serial):
+        return cls(cert=cls.get_cert(cls, serial))
 
     def load(self, cert_der, cert_data, cert_revoked):
         cert = x509.load_der_x509_certificate(cert_der)
@@ -60,7 +74,9 @@ class cert:
         self.not_before = cert.not_valid_before_utc.replace(microsecond=0)
         self.not_after = cert.not_valid_after_utc.replace(microsecond=0)
         try:
-            san_data = cert.extensions.get_extension_for_class(x509.SubjectAlternativeName)
+            san_data = cert.extensions.get_extension_for_class(
+                x509.SubjectAlternativeName
+            )
             self.san_names = san_data.value.get_values_for_type(x509.GeneralName)
         except x509.extensions.ExtensionNotFound:
             self.san_names = []
@@ -85,12 +101,21 @@ class cert:
         else:
             self.status = status(status.VALID)
 
-
     def get_cert(self, cert_serial):
         cur = conn.cursor()
-        cur.execute("SELECT nvalue FROM x509_certs WHERE nkey=?", (cert_serial,))
+        cur.execute(
+            """SELECT x509_certs.nvalue AS cert,
+                      x509_certs_data.nvalue AS data,
+                      revoked_x509_certs.nvalue AS revoked
+               FROM x509_certs
+               INNER JOIN x509_certs_data USING(nkey)
+               LEFT JOIN revoked_x509_certs USING(nkey)
+               WHERE nkey=?""",
+            (cert_serial,),
+        )
+
         if cur.rowcount > 0:
-            (cert,) = cur.fetchone()
+            cert = cur.fetchone()
         else:
             cert = None
 
@@ -98,30 +123,6 @@ class cert:
         return cert
 
 
-    def get_cert_data(self, cert_serial):
-        cur = conn.cursor()
-        cur.execute("SELECT nvalue FROM x509_certs_data WHERE nkey=?", (cert_serial,))
-        (cert_data_raw,) = cur.fetchone()
-        cur.close()
-        cert_data = json.loads(cert_data_raw)
-        return cert_data
-
-
-    def get_cert_revoked(self, cert_serial):
-        cur = conn.cursor()
-        cur.execute(
-            "SELECT nvalue FROM revoked_x509_certs WHERE nkey=?", (cert_serial,)
-        )
-        if cur.rowcount > 0:
-            (cert_revoked_raw,) = cur.fetchone()
-            cert_revoked = json.loads(cert_revoked_raw)
-        else:
-            cert_revoked = None
-
-        cur.close()
-        return cert_revoked
-
-
 class status:
     REVOKED = 1
     EXPIRED = 2
diff --git a/step-ca-inspector.py b/step-ca-inspector.py
index e942a1f..17c6c75 100755
--- a/step-ca-inspector.py
+++ b/step-ca-inspector.py
@@ -42,7 +42,7 @@ def list_ssh_certs(sort_key, revoked=False, expired=False):
 
 
 def get_ssh_cert(serial):
-    cert = ssh_cert.cert(serial)
+    cert = ssh_cert.cert.from_serial(serial)
     cert_tbl = []
 
     cert_tbl.append(["Serial", cert.serial])
@@ -68,7 +68,7 @@ def get_ssh_cert(serial):
 
 
 def dump_ssh_cert(serial):
-    cert = ssh_cert.cert(serial)
+    cert = ssh_cert.cert.from_serial(serial)
     print(cert.public_identity.decode())
 
 
@@ -105,7 +105,7 @@ def list_x509_certs(sort_key, revoked=False, expired=False):
 
 
 def get_x509_cert(serial, show_pem=False):
-    cert = x509_cert.cert(serial)
+    cert = x509_cert.cert.from_serial(serial)
     cert_tbl = []
 
     cert_tbl.append(["Serial", cert.serial])
@@ -135,7 +135,7 @@ def get_x509_cert(serial, show_pem=False):
 
 
 def dump_x509_cert(serial, cert_format="pem"):
-    cert = x509_cert.cert(serial)
+    cert = x509_cert.cert.from_serial(serial)
     print(cert.pem.decode("utf-8").rstrip())