Optimise SQL queries
This commit is contained in:
parent
c469dccb5c
commit
9859b1cd29
3 changed files with 75 additions and 64 deletions
|
@ -23,10 +23,15 @@ class list:
|
||||||
|
|
||||||
def __new__(cls, sort_key=None):
|
def __new__(cls, sort_key=None):
|
||||||
cur = conn.cursor()
|
cur = conn.cursor()
|
||||||
cur.execute("SELECT nkey FROM ssh_certs")
|
cur.execute(
|
||||||
|
"""SELECT ssh_certs.nvalue AS cert,
|
||||||
|
revoked_ssh_certs.nvalue AS revoked
|
||||||
|
FROM ssh_certs
|
||||||
|
LEFT JOIN revoked_ssh_certs USING(nkey)"""
|
||||||
|
)
|
||||||
|
|
||||||
for (cert_serial,) in cur:
|
for result in cur:
|
||||||
cert_object = cert(cert_serial)
|
cert_object = cert(result)
|
||||||
cls.certs.append(cert_object)
|
cls.certs.append(cert_object)
|
||||||
|
|
||||||
cur.close()
|
cur.close()
|
||||||
|
@ -38,15 +43,25 @@ class list:
|
||||||
|
|
||||||
|
|
||||||
class cert:
|
class cert:
|
||||||
def __init__(self, serial):
|
def __init__(self, cert):
|
||||||
cert_raw = self.get_cert(serial)
|
(cert_raw, cert_revoked_raw) = cert
|
||||||
|
|
||||||
size = unpack(">I", cert_raw[:4])[0] + 4
|
size = unpack(">I", cert_raw[:4])[0] + 4
|
||||||
alg = cert_raw[4:size]
|
alg = cert_raw[4:size]
|
||||||
|
|
||||||
cert_pub_id = b" ".join([alg, base64.b64encode(cert_raw)])
|
cert_pub_id = b" ".join([alg, base64.b64encode(cert_raw)])
|
||||||
cert_revoked = self.get_cert_revoked(serial)
|
|
||||||
|
if cert_revoked_raw is not None:
|
||||||
|
cert_revoked = json.loads(cert_revoked_raw)
|
||||||
|
else:
|
||||||
|
cert_revoked = None
|
||||||
|
|
||||||
self.load(cert_pub_id, cert_revoked, alg)
|
self.load(cert_pub_id, cert_revoked, alg)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_serial(cls, serial):
|
||||||
|
return cls(cert=cls.get_cert(cls, serial))
|
||||||
|
|
||||||
def load(self, cert_pub_id, cert_revoked, cert_alg):
|
def load(self, cert_pub_id, cert_revoked, cert_alg):
|
||||||
cert = serialization.load_ssh_public_identity(cert_pub_id)
|
cert = serialization.load_ssh_public_identity(cert_pub_id)
|
||||||
self.serial = cert.serial
|
self.serial = cert.serial
|
||||||
|
@ -95,27 +110,22 @@ class cert:
|
||||||
|
|
||||||
def get_cert(self, cert_serial):
|
def get_cert(self, cert_serial):
|
||||||
cur = conn.cursor()
|
cur = conn.cursor()
|
||||||
cur.execute("SELECT nvalue FROM ssh_certs WHERE nkey=?", (cert_serial,))
|
cur.execute(
|
||||||
|
"""SELECT ssh_certs.nvalue AS cert,
|
||||||
|
revoked_ssh_certs.nvalue AS revoked
|
||||||
|
FROM ssh_certs
|
||||||
|
LEFT JOIN revoked_ssh_certs USING(nkey)
|
||||||
|
WHERE nkey=?""",
|
||||||
|
(cert_serial,),
|
||||||
|
)
|
||||||
if cur.rowcount > 0:
|
if cur.rowcount > 0:
|
||||||
(cert,) = cur.fetchone()
|
cert = cur.fetchone()
|
||||||
else:
|
else:
|
||||||
cert = None
|
cert = None
|
||||||
|
|
||||||
cur.close()
|
cur.close()
|
||||||
return cert
|
return cert
|
||||||
|
|
||||||
def get_cert_revoked(self, cert_serial):
|
|
||||||
cur = conn.cursor()
|
|
||||||
cur.execute("SELECT nvalue FROM revoked_ssh_certs WHERE nkey=?", (cert_serial,))
|
|
||||||
if cur.rowcount > 0:
|
|
||||||
(cert_revoked_raw,) = cur.fetchone()
|
|
||||||
cert_revoked = json.loads(cert_revoked_raw)
|
|
||||||
else:
|
|
||||||
cert_revoked = None
|
|
||||||
|
|
||||||
cur.close()
|
|
||||||
return cert_revoked
|
|
||||||
|
|
||||||
def get_public_key_params(self, public_key):
|
def get_public_key_params(self, public_key):
|
||||||
if isinstance(public_key, asymmetric.ec.EllipticCurvePublicKey):
|
if isinstance(public_key, asymmetric.ec.EllipticCurvePublicKey):
|
||||||
key_type = "ECDSA"
|
key_type = "ECDSA"
|
||||||
|
|
|
@ -10,10 +10,10 @@ from models.config import config
|
||||||
|
|
||||||
config()
|
config()
|
||||||
conn = mariadb.connect(
|
conn = mariadb.connect(
|
||||||
host = config.database_host,
|
host=config.database_host,
|
||||||
user = config.database_user,
|
user=config.database_user,
|
||||||
password = config.database_password,
|
password=config.database_password,
|
||||||
database = config.database_name
|
database=config.database_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -22,10 +22,17 @@ class list:
|
||||||
|
|
||||||
def __new__(cls, sort_key=None):
|
def __new__(cls, sort_key=None):
|
||||||
cur = conn.cursor()
|
cur = conn.cursor()
|
||||||
cur.execute("SELECT nkey FROM x509_certs")
|
cur.execute(
|
||||||
|
"""SELECT x509_certs.nvalue AS cert,
|
||||||
|
x509_certs_data.nvalue AS data,
|
||||||
|
revoked_x509_certs.nvalue AS revoked
|
||||||
|
FROM x509_certs
|
||||||
|
INNER JOIN x509_certs_data USING(nkey)
|
||||||
|
LEFT JOIN revoked_x509_certs USING(nkey)"""
|
||||||
|
)
|
||||||
|
|
||||||
for (cert_serial,) in cur:
|
for result in cur:
|
||||||
cert_object = cert(cert_serial)
|
cert_object = cert(result)
|
||||||
cls.certs.append(cert_object)
|
cls.certs.append(cert_object)
|
||||||
|
|
||||||
cur.close()
|
cur.close()
|
||||||
|
@ -36,14 +43,21 @@ class list:
|
||||||
return cls.certs
|
return cls.certs
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class cert:
|
class cert:
|
||||||
def __init__(self, serial):
|
def __init__(self, cert):
|
||||||
cert_der = self.get_cert(serial)
|
(cert_der, cert_data_raw, cert_revoked_raw) = cert
|
||||||
cert_data = self.get_cert_data(serial)
|
|
||||||
cert_revoked = self.get_cert_revoked(serial)
|
cert_data = json.loads(cert_data_raw)
|
||||||
|
if cert_revoked_raw is not None:
|
||||||
|
cert_revoked = json.loads(cert_revoked_raw)
|
||||||
|
else:
|
||||||
|
cert_revoked = None
|
||||||
|
|
||||||
self.load(cert_der, cert_data, cert_revoked)
|
self.load(cert_der, cert_data, cert_revoked)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_serial(cls, serial):
|
||||||
|
return cls(cert=cls.get_cert(cls, serial))
|
||||||
|
|
||||||
def load(self, cert_der, cert_data, cert_revoked):
|
def load(self, cert_der, cert_data, cert_revoked):
|
||||||
cert = x509.load_der_x509_certificate(cert_der)
|
cert = x509.load_der_x509_certificate(cert_der)
|
||||||
|
@ -60,7 +74,9 @@ class cert:
|
||||||
self.not_before = cert.not_valid_before_utc.replace(microsecond=0)
|
self.not_before = cert.not_valid_before_utc.replace(microsecond=0)
|
||||||
self.not_after = cert.not_valid_after_utc.replace(microsecond=0)
|
self.not_after = cert.not_valid_after_utc.replace(microsecond=0)
|
||||||
try:
|
try:
|
||||||
san_data = cert.extensions.get_extension_for_class(x509.SubjectAlternativeName)
|
san_data = cert.extensions.get_extension_for_class(
|
||||||
|
x509.SubjectAlternativeName
|
||||||
|
)
|
||||||
self.san_names = san_data.value.get_values_for_type(x509.GeneralName)
|
self.san_names = san_data.value.get_values_for_type(x509.GeneralName)
|
||||||
except x509.extensions.ExtensionNotFound:
|
except x509.extensions.ExtensionNotFound:
|
||||||
self.san_names = []
|
self.san_names = []
|
||||||
|
@ -85,12 +101,21 @@ class cert:
|
||||||
else:
|
else:
|
||||||
self.status = status(status.VALID)
|
self.status = status(status.VALID)
|
||||||
|
|
||||||
|
|
||||||
def get_cert(self, cert_serial):
|
def get_cert(self, cert_serial):
|
||||||
cur = conn.cursor()
|
cur = conn.cursor()
|
||||||
cur.execute("SELECT nvalue FROM x509_certs WHERE nkey=?", (cert_serial,))
|
cur.execute(
|
||||||
|
"""SELECT x509_certs.nvalue AS cert,
|
||||||
|
x509_certs_data.nvalue AS data,
|
||||||
|
revoked_x509_certs.nvalue AS revoked
|
||||||
|
FROM x509_certs
|
||||||
|
INNER JOIN x509_certs_data USING(nkey)
|
||||||
|
LEFT JOIN revoked_x509_certs USING(nkey)
|
||||||
|
WHERE nkey=?""",
|
||||||
|
(cert_serial,),
|
||||||
|
)
|
||||||
|
|
||||||
if cur.rowcount > 0:
|
if cur.rowcount > 0:
|
||||||
(cert,) = cur.fetchone()
|
cert = cur.fetchone()
|
||||||
else:
|
else:
|
||||||
cert = None
|
cert = None
|
||||||
|
|
||||||
|
@ -98,30 +123,6 @@ class cert:
|
||||||
return cert
|
return cert
|
||||||
|
|
||||||
|
|
||||||
def get_cert_data(self, cert_serial):
|
|
||||||
cur = conn.cursor()
|
|
||||||
cur.execute("SELECT nvalue FROM x509_certs_data WHERE nkey=?", (cert_serial,))
|
|
||||||
(cert_data_raw,) = cur.fetchone()
|
|
||||||
cur.close()
|
|
||||||
cert_data = json.loads(cert_data_raw)
|
|
||||||
return cert_data
|
|
||||||
|
|
||||||
|
|
||||||
def get_cert_revoked(self, cert_serial):
|
|
||||||
cur = conn.cursor()
|
|
||||||
cur.execute(
|
|
||||||
"SELECT nvalue FROM revoked_x509_certs WHERE nkey=?", (cert_serial,)
|
|
||||||
)
|
|
||||||
if cur.rowcount > 0:
|
|
||||||
(cert_revoked_raw,) = cur.fetchone()
|
|
||||||
cert_revoked = json.loads(cert_revoked_raw)
|
|
||||||
else:
|
|
||||||
cert_revoked = None
|
|
||||||
|
|
||||||
cur.close()
|
|
||||||
return cert_revoked
|
|
||||||
|
|
||||||
|
|
||||||
class status:
|
class status:
|
||||||
REVOKED = 1
|
REVOKED = 1
|
||||||
EXPIRED = 2
|
EXPIRED = 2
|
||||||
|
|
|
@ -42,7 +42,7 @@ def list_ssh_certs(sort_key, revoked=False, expired=False):
|
||||||
|
|
||||||
|
|
||||||
def get_ssh_cert(serial):
|
def get_ssh_cert(serial):
|
||||||
cert = ssh_cert.cert(serial)
|
cert = ssh_cert.cert.from_serial(serial)
|
||||||
cert_tbl = []
|
cert_tbl = []
|
||||||
|
|
||||||
cert_tbl.append(["Serial", cert.serial])
|
cert_tbl.append(["Serial", cert.serial])
|
||||||
|
@ -68,7 +68,7 @@ def get_ssh_cert(serial):
|
||||||
|
|
||||||
|
|
||||||
def dump_ssh_cert(serial):
|
def dump_ssh_cert(serial):
|
||||||
cert = ssh_cert.cert(serial)
|
cert = ssh_cert.cert.from_serial(serial)
|
||||||
print(cert.public_identity.decode())
|
print(cert.public_identity.decode())
|
||||||
|
|
||||||
|
|
||||||
|
@ -105,7 +105,7 @@ def list_x509_certs(sort_key, revoked=False, expired=False):
|
||||||
|
|
||||||
|
|
||||||
def get_x509_cert(serial, show_pem=False):
|
def get_x509_cert(serial, show_pem=False):
|
||||||
cert = x509_cert.cert(serial)
|
cert = x509_cert.cert.from_serial(serial)
|
||||||
cert_tbl = []
|
cert_tbl = []
|
||||||
|
|
||||||
cert_tbl.append(["Serial", cert.serial])
|
cert_tbl.append(["Serial", cert.serial])
|
||||||
|
@ -135,7 +135,7 @@ def get_x509_cert(serial, show_pem=False):
|
||||||
|
|
||||||
|
|
||||||
def dump_x509_cert(serial, cert_format="pem"):
|
def dump_x509_cert(serial, cert_format="pem"):
|
||||||
cert = x509_cert.cert(serial)
|
cert = x509_cert.cert.from_serial(serial)
|
||||||
print(cert.pem.decode("utf-8").rstrip())
|
print(cert.pem.decode("utf-8").rstrip())
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue