Optimise SQL queries
This commit is contained in:
parent
c469dccb5c
commit
9859b1cd29
3 changed files with 75 additions and 64 deletions
|
@ -23,10 +23,15 @@ class list:
|
|||
|
||||
def __new__(cls, sort_key=None):
|
||||
cur = conn.cursor()
|
||||
cur.execute("SELECT nkey FROM ssh_certs")
|
||||
cur.execute(
|
||||
"""SELECT ssh_certs.nvalue AS cert,
|
||||
revoked_ssh_certs.nvalue AS revoked
|
||||
FROM ssh_certs
|
||||
LEFT JOIN revoked_ssh_certs USING(nkey)"""
|
||||
)
|
||||
|
||||
for (cert_serial,) in cur:
|
||||
cert_object = cert(cert_serial)
|
||||
for result in cur:
|
||||
cert_object = cert(result)
|
||||
cls.certs.append(cert_object)
|
||||
|
||||
cur.close()
|
||||
|
@ -38,15 +43,25 @@ class list:
|
|||
|
||||
|
||||
class cert:
|
||||
def __init__(self, serial):
|
||||
cert_raw = self.get_cert(serial)
|
||||
def __init__(self, cert):
|
||||
(cert_raw, cert_revoked_raw) = cert
|
||||
|
||||
size = unpack(">I", cert_raw[:4])[0] + 4
|
||||
alg = cert_raw[4:size]
|
||||
|
||||
cert_pub_id = b" ".join([alg, base64.b64encode(cert_raw)])
|
||||
cert_revoked = self.get_cert_revoked(serial)
|
||||
|
||||
if cert_revoked_raw is not None:
|
||||
cert_revoked = json.loads(cert_revoked_raw)
|
||||
else:
|
||||
cert_revoked = None
|
||||
|
||||
self.load(cert_pub_id, cert_revoked, alg)
|
||||
|
||||
@classmethod
|
||||
def from_serial(cls, serial):
|
||||
return cls(cert=cls.get_cert(cls, serial))
|
||||
|
||||
def load(self, cert_pub_id, cert_revoked, cert_alg):
|
||||
cert = serialization.load_ssh_public_identity(cert_pub_id)
|
||||
self.serial = cert.serial
|
||||
|
@ -95,27 +110,22 @@ class cert:
|
|||
|
||||
def get_cert(self, cert_serial):
|
||||
cur = conn.cursor()
|
||||
cur.execute("SELECT nvalue FROM ssh_certs WHERE nkey=?", (cert_serial,))
|
||||
cur.execute(
|
||||
"""SELECT ssh_certs.nvalue AS cert,
|
||||
revoked_ssh_certs.nvalue AS revoked
|
||||
FROM ssh_certs
|
||||
LEFT JOIN revoked_ssh_certs USING(nkey)
|
||||
WHERE nkey=?""",
|
||||
(cert_serial,),
|
||||
)
|
||||
if cur.rowcount > 0:
|
||||
(cert,) = cur.fetchone()
|
||||
cert = cur.fetchone()
|
||||
else:
|
||||
cert = None
|
||||
|
||||
cur.close()
|
||||
return cert
|
||||
|
||||
def get_cert_revoked(self, cert_serial):
|
||||
cur = conn.cursor()
|
||||
cur.execute("SELECT nvalue FROM revoked_ssh_certs WHERE nkey=?", (cert_serial,))
|
||||
if cur.rowcount > 0:
|
||||
(cert_revoked_raw,) = cur.fetchone()
|
||||
cert_revoked = json.loads(cert_revoked_raw)
|
||||
else:
|
||||
cert_revoked = None
|
||||
|
||||
cur.close()
|
||||
return cert_revoked
|
||||
|
||||
def get_public_key_params(self, public_key):
|
||||
if isinstance(public_key, asymmetric.ec.EllipticCurvePublicKey):
|
||||
key_type = "ECDSA"
|
||||
|
|
|
@ -10,10 +10,10 @@ from models.config import config
|
|||
|
||||
config()
|
||||
conn = mariadb.connect(
|
||||
host = config.database_host,
|
||||
user = config.database_user,
|
||||
password = config.database_password,
|
||||
database = config.database_name
|
||||
host=config.database_host,
|
||||
user=config.database_user,
|
||||
password=config.database_password,
|
||||
database=config.database_name,
|
||||
)
|
||||
|
||||
|
||||
|
@ -22,10 +22,17 @@ class list:
|
|||
|
||||
def __new__(cls, sort_key=None):
|
||||
cur = conn.cursor()
|
||||
cur.execute("SELECT nkey FROM x509_certs")
|
||||
cur.execute(
|
||||
"""SELECT x509_certs.nvalue AS cert,
|
||||
x509_certs_data.nvalue AS data,
|
||||
revoked_x509_certs.nvalue AS revoked
|
||||
FROM x509_certs
|
||||
INNER JOIN x509_certs_data USING(nkey)
|
||||
LEFT JOIN revoked_x509_certs USING(nkey)"""
|
||||
)
|
||||
|
||||
for (cert_serial,) in cur:
|
||||
cert_object = cert(cert_serial)
|
||||
for result in cur:
|
||||
cert_object = cert(result)
|
||||
cls.certs.append(cert_object)
|
||||
|
||||
cur.close()
|
||||
|
@ -36,14 +43,21 @@ class list:
|
|||
return cls.certs
|
||||
|
||||
|
||||
|
||||
class cert:
|
||||
def __init__(self, serial):
|
||||
cert_der = self.get_cert(serial)
|
||||
cert_data = self.get_cert_data(serial)
|
||||
cert_revoked = self.get_cert_revoked(serial)
|
||||
def __init__(self, cert):
|
||||
(cert_der, cert_data_raw, cert_revoked_raw) = cert
|
||||
|
||||
cert_data = json.loads(cert_data_raw)
|
||||
if cert_revoked_raw is not None:
|
||||
cert_revoked = json.loads(cert_revoked_raw)
|
||||
else:
|
||||
cert_revoked = None
|
||||
|
||||
self.load(cert_der, cert_data, cert_revoked)
|
||||
|
||||
@classmethod
|
||||
def from_serial(cls, serial):
|
||||
return cls(cert=cls.get_cert(cls, serial))
|
||||
|
||||
def load(self, cert_der, cert_data, cert_revoked):
|
||||
cert = x509.load_der_x509_certificate(cert_der)
|
||||
|
@ -60,7 +74,9 @@ class cert:
|
|||
self.not_before = cert.not_valid_before_utc.replace(microsecond=0)
|
||||
self.not_after = cert.not_valid_after_utc.replace(microsecond=0)
|
||||
try:
|
||||
san_data = cert.extensions.get_extension_for_class(x509.SubjectAlternativeName)
|
||||
san_data = cert.extensions.get_extension_for_class(
|
||||
x509.SubjectAlternativeName
|
||||
)
|
||||
self.san_names = san_data.value.get_values_for_type(x509.GeneralName)
|
||||
except x509.extensions.ExtensionNotFound:
|
||||
self.san_names = []
|
||||
|
@ -85,12 +101,21 @@ class cert:
|
|||
else:
|
||||
self.status = status(status.VALID)
|
||||
|
||||
|
||||
def get_cert(self, cert_serial):
|
||||
cur = conn.cursor()
|
||||
cur.execute("SELECT nvalue FROM x509_certs WHERE nkey=?", (cert_serial,))
|
||||
cur.execute(
|
||||
"""SELECT x509_certs.nvalue AS cert,
|
||||
x509_certs_data.nvalue AS data,
|
||||
revoked_x509_certs.nvalue AS revoked
|
||||
FROM x509_certs
|
||||
INNER JOIN x509_certs_data USING(nkey)
|
||||
LEFT JOIN revoked_x509_certs USING(nkey)
|
||||
WHERE nkey=?""",
|
||||
(cert_serial,),
|
||||
)
|
||||
|
||||
if cur.rowcount > 0:
|
||||
(cert,) = cur.fetchone()
|
||||
cert = cur.fetchone()
|
||||
else:
|
||||
cert = None
|
||||
|
||||
|
@ -98,30 +123,6 @@ class cert:
|
|||
return cert
|
||||
|
||||
|
||||
def get_cert_data(self, cert_serial):
|
||||
cur = conn.cursor()
|
||||
cur.execute("SELECT nvalue FROM x509_certs_data WHERE nkey=?", (cert_serial,))
|
||||
(cert_data_raw,) = cur.fetchone()
|
||||
cur.close()
|
||||
cert_data = json.loads(cert_data_raw)
|
||||
return cert_data
|
||||
|
||||
|
||||
def get_cert_revoked(self, cert_serial):
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"SELECT nvalue FROM revoked_x509_certs WHERE nkey=?", (cert_serial,)
|
||||
)
|
||||
if cur.rowcount > 0:
|
||||
(cert_revoked_raw,) = cur.fetchone()
|
||||
cert_revoked = json.loads(cert_revoked_raw)
|
||||
else:
|
||||
cert_revoked = None
|
||||
|
||||
cur.close()
|
||||
return cert_revoked
|
||||
|
||||
|
||||
class status:
|
||||
REVOKED = 1
|
||||
EXPIRED = 2
|
||||
|
|
|
@ -42,7 +42,7 @@ def list_ssh_certs(sort_key, revoked=False, expired=False):
|
|||
|
||||
|
||||
def get_ssh_cert(serial):
|
||||
cert = ssh_cert.cert(serial)
|
||||
cert = ssh_cert.cert.from_serial(serial)
|
||||
cert_tbl = []
|
||||
|
||||
cert_tbl.append(["Serial", cert.serial])
|
||||
|
@ -68,7 +68,7 @@ def get_ssh_cert(serial):
|
|||
|
||||
|
||||
def dump_ssh_cert(serial):
|
||||
cert = ssh_cert.cert(serial)
|
||||
cert = ssh_cert.cert.from_serial(serial)
|
||||
print(cert.public_identity.decode())
|
||||
|
||||
|
||||
|
@ -105,7 +105,7 @@ def list_x509_certs(sort_key, revoked=False, expired=False):
|
|||
|
||||
|
||||
def get_x509_cert(serial, show_pem=False):
|
||||
cert = x509_cert.cert(serial)
|
||||
cert = x509_cert.cert.from_serial(serial)
|
||||
cert_tbl = []
|
||||
|
||||
cert_tbl.append(["Serial", cert.serial])
|
||||
|
@ -135,7 +135,7 @@ def get_x509_cert(serial, show_pem=False):
|
|||
|
||||
|
||||
def dump_x509_cert(serial, cert_format="pem"):
|
||||
cert = x509_cert.cert(serial)
|
||||
cert = x509_cert.cert.from_serial(serial)
|
||||
print(cert.pem.decode("utf-8").rstrip())
|
||||
|
||||
|
||||
|
|
Loading…
Reference in a new issue