cacert-webdb/scripts/check_db_certificates.py

290 lines
9.7 KiB
Python

#!/usr/bin/env python3
import logging
import os
import re
import typing
from datetime import datetime, timezone
from typing import NamedTuple
from cryptography import x509
from cryptography.exceptions import UnsupportedAlgorithm
from cryptography.hazmat.primitives.asymmetric import rsa, ec
from sqlalchemy import create_engine, select, MetaData, Table
class CheckResult(NamedTuple):
    """Outcome of a single CSR or certificate check.

    ``code`` is a numeric identifier and ``name`` a short human-readable
    description. Results are ordered by ``code`` only, so sorting a collection
    of results lists them by code.

    Equality and hashing remain the inherited tuple ones (``name`` *and*
    ``code``). Two results with the same code but different names therefore
    compare ``<=`` and ``>=`` in both directions without being equal.
    """

    name: str
    code: int

    def __str__(self):
        return f"[{self.code:04d}: {self.name}]"

    # All four ordering operators compare ``code`` so that they agree with
    # each other; the inherited tuple versions would order by ``name`` first.
    def __lt__(self, other):
        return self.code < other.code

    def __le__(self, other):
        return self.code <= other.code

    def __gt__(self, other):
        return self.code > other.code

    def __ge__(self, other):
        return self.code >= other.code
# Result codes returned by the CSR and certificate checks. CODE_OK (code 0)
# means the check passed; every other code names the first problem found.
CODE_OK = CheckResult(name="OK", code=0)

# Problems with the referenced file or its encoding.
CODE_FILE_MISSING = CheckResult(name="file missing", code=1000)
CODE_UNSUPPORTED_FORMAT = CheckResult(name="unsupported format", code=1001)
CODE_EMPTY = CheckResult(name="empty", code=1002)
CODE_DEPRECATED_SPKAC = CheckResult(name="deprecated SPKAC", code=1003)

# Problems with signatures and keys.
CODE_INVALID_SIGNATURE = CheckResult(name="invalid signature", code=1004)
CODE_UNSUPPORTED_SIGNATURE_ALGORITHM = CheckResult(name="unsupported signature algorithm", code=1005)
CODE_PUBLIC_KEY_TOO_WEAK = CheckResult(name="public key too weak", code=1006)
CODE_UNSUPPORTED_PUBLIC_KEY = CheckResult(name="unsupported public key", code=1007)

# Problems relating the certificate to its CSR, its issuer, or the clock.
CODE_CSR_AND_CRT_PUBLIC_KEY_MISMATCH = CheckResult(name="CSR and CRT public key mismatch", code=1008)
CODE_CERTIFICATE_FOR_INVALID_CSR = CheckResult(name="certificate for invalid CSR", code=1009)
CODE_NOT_SIGNED_BY_EXPECTED_CA_CERTIFICATE = CheckResult(name="not signed by expected CA", code=1010)
CODE_CERTIFICATE_IS_EXPIRED = CheckResult(name="certificate is expired", code=1011)
# Signature algorithms accepted for both CSRs and certificates: RSA and ECDSA,
# each combined with SHA-256, SHA-384 or SHA-512. Anything else is reported
# as an unsupported signature algorithm.
SUPPORTED_SIGNATURE_ALGORITHMS = [
    # RSA family
    x509.oid.SignatureAlgorithmOID.RSA_WITH_SHA256,
    x509.oid.SignatureAlgorithmOID.RSA_WITH_SHA384,
    x509.oid.SignatureAlgorithmOID.RSA_WITH_SHA512,
    # ECDSA family
    x509.oid.SignatureAlgorithmOID.ECDSA_WITH_SHA256,
    x509.oid.SignatureAlgorithmOID.ECDSA_WITH_SHA384,
    x509.oid.SignatureAlgorithmOID.ECDSA_WITH_SHA512,
]
def load_ca_certificates(root_ca_cert, sub_ca_cert) -> dict[int, x509.Certificate]:
    """Load both CA certificates, keyed by the id used in the ``rootcert`` column.

    :param root_ca_cert: path of the PEM-encoded root CA certificate (id 1)
    :param sub_ca_cert: path of the PEM-encoded subordinate class 3 CA
        certificate (id 2)
    :return: ``{1: root certificate, 2: subordinate certificate}``
    """
    def read_pem_certificate(path):
        # Read the whole file, then parse it as a single PEM certificate.
        with open(path, "rb") as pem_file:
            pem_bytes = pem_file.read()
        return x509.load_pem_x509_certificate(pem_bytes)

    certificates = {}
    certificates[1] = read_pem_certificate(root_ca_cert)
    certificates[2] = read_pem_certificate(sub_ca_cert)
    return certificates
class Counters:
    """Accumulate statistics over all analyzed CSR/certificate rows.

    The integer counters each count rows that ended up in one outcome bucket
    (their labels are the ones printed by ``__str__``). ``csr_codes`` and
    ``crt_codes`` count how often each individual check result was seen.
    """

    # Immutable class-level defaults are safe: ``self.x += 1`` rebinds an
    # instance attribute instead of mutating the shared class value.
    fail: int = 0
    good: int = 0
    good_csr: int = 0
    good_crt: int = 0
    missing_crt: int = 0
    expired_crt: int = 0

    def __init__(self) -> None:
        # The per-code tallies must live on the instance. Class-level ``{}``
        # literals would be one dict shared (and mutated) by every instance.
        self.csr_codes: dict["CheckResult", int] = {}
        self.crt_codes: dict["CheckResult", int] = {}

    def count_fail(self):
        """Count a row whose CSR and certificate both failed."""
        self.fail += 1

    def count_good(self):
        """Count a row with a good CSR and a good certificate."""
        self.good += 1

    def count_good_csr(self):
        """Count a row with a good CSR but a problematic certificate."""
        self.good_csr += 1

    def count_good_crt(self):
        """Count a row with a good certificate but a problematic CSR."""
        self.good_crt += 1

    def count_missing_crt(self):
        """Count a row without a certificate."""
        self.missing_crt += 1

    def count_expired_crt(self):
        """Count a row whose certificate is expired."""
        self.expired_crt += 1

    def count_csr(self, csr_code: "CheckResult"):
        """Record one occurrence of the CSR check result *csr_code*."""
        self.csr_codes[csr_code] = self.csr_codes.get(csr_code, 0) + 1

    def count_crt(self, crt_code: "CheckResult"):
        """Record one occurrence of the certificate check result *crt_code*."""
        self.crt_codes[crt_code] = self.crt_codes.get(crt_code, 0) + 1

    @staticmethod
    def _format_codes(codes) -> str:
        """Render ``{code: count}`` as ``"<code>: <count>"`` lines, sorted by code."""
        return "\n".join(f"{code}: {count:d}" for code, count in sorted(codes.items()))

    def __str__(self):
        return (
            f"good CSR and certificate: {self.good}\n"
            f"good CSR, issue with certificate: {self.good_csr}\n"
            f"good certificate, issue with CSR: {self.good_crt}\n"
            f"failed CSR and certificate: {self.fail}\n"
            f"missing certificate: {self.missing_crt}\n"
            f"expired certificate: {self.expired_crt}\n"
            "\nCSR results:\n"
            f"{self._format_codes(self.csr_codes)}\n"
            "\nCertificate results:\n"
            f"{self._format_codes(self.crt_codes)}"
        )
class Analyzer:
    """Check the CSR and certificate files referenced by the certificate tables.

    Every row of the ``emailcerts``, ``domaincerts``, ``orgemailcerts`` and
    ``orgdomaincerts`` tables is inspected. Its CSR file (``csr_name``) and
    certificate file (``crt_name``) are validated, and the outcome is tallied
    in a ``Counters`` instance stored as ``self.c``.
    """

    def __init__(self, logger: logging.Logger, dsn: str, ca_certificates: dict[int, x509.Certificate]):
        """
        :param logger: logger for per-row debug details and the final statistics
        :param dsn: database URL passed to ``create_engine``
        :param ca_certificates: CA certificates keyed by a row's ``rootcert`` value
        """
        self.logger = logger
        self.engine = create_engine(dsn)
        self.ca_certificates = ca_certificates
        self.c = Counters()

    def analyze(self):
        """Reflect each certificate table and analyze every one of its rows."""
        metadata_obj = MetaData()
        for table in ("emailcerts", "domaincerts", "orgemailcerts", "orgdomaincerts"):
            certs_table = Table(table, metadata_obj, autoload_with=self.engine)
            stmt = select(certs_table)
            with self.engine.connect() as conn:
                for row in conn.execute(stmt):
                    self.analyze_row(table, row)

    def analyze_row(self, table: str, row) -> None:
        """Check one row's CSR and certificate and update the counters.

        Each row is counted in exactly one of the buckets fail / good_crt /
        missing_crt / good_csr / good. Expired certificates are additionally
        counted in ``expired_crt``.
        """
        # .get(): a row that references an unknown CA id is reported by
        # check_crt as "not signed by expected CA" instead of aborting the
        # whole scan with a KeyError.
        ca_cert = self.ca_certificates.get(row.rootcert)
        csr_code, public_numbers = self.check_csr(row.csr_name)
        self.c.count_csr(csr_code)
        crt_code = self.check_crt(row.crt_name, ca_cert, public_numbers)
        self.c.count_crt(crt_code)

        # For bucketing, an expired certificate counts as an otherwise valid one.
        crt_usable = crt_code in (CODE_OK, CODE_CERTIFICATE_IS_EXPIRED)

        if csr_code != CODE_OK:
            if not crt_usable:
                self.logger.debug(
                    "%06d %s: %06d -> csr_code: %s, crt_code: %s",
                    self.c.fail, table, row.id, csr_code, crt_code
                )
                self.c.count_fail()
                return
            self.c.count_good_crt()
            if crt_code == CODE_CERTIFICATE_IS_EXPIRED:
                self.c.count_expired_crt()
            return

        # From here on the CSR is good.
        if crt_code == CODE_EMPTY:
            self.c.count_missing_crt()
            return
        if not crt_usable:
            self.c.count_good_csr()
            return
        if crt_code == CODE_CERTIFICATE_IS_EXPIRED:
            self.c.count_expired_crt()
        self.c.count_good()

    @staticmethod
    def _check_public_key(public_key) -> typing.Optional[CheckResult]:
        """Return an error code if *public_key* is unsupported or too weak, else ``None``.

        RSA keys need at least 2048 bits and EC keys at least 256 bits. Any
        other key type is reported as unsupported.
        """
        if isinstance(public_key, rsa.RSAPublicKey):
            min_bits = 2048
        elif isinstance(public_key, ec.EllipticCurvePublicKey):
            min_bits = 256
        else:
            return CODE_UNSUPPORTED_PUBLIC_KEY
        if public_key.key_size < min_bits:
            return CODE_PUBLIC_KEY_TOO_WEAK
        return None

    def check_csr(self, csr_name: str) -> tuple[CheckResult, typing.Any]:
        """Validate the PEM CSR stored in the file *csr_name*.

        :return: ``(CODE_OK, public_numbers)`` for an acceptable CSR, where
            ``public_numbers`` are the public numbers of the CSR's key;
            otherwise ``(<first problem found>, None)``
        """
        if not csr_name:
            return CODE_EMPTY, None
        if not os.path.isfile(csr_name):
            return CODE_FILE_MISSING, None
        with open(csr_name, "rb") as f:
            csr_data = f.read()
        # SPKAC requests are flagged before any attempt to parse PEM. The
        # marker is plain ASCII, so a bytes search suffices (no decoding).
        if b"SPKAC = " in csr_data:
            return CODE_DEPRECATED_SPKAC, None
        try:
            csr = x509.load_pem_x509_csr(csr_data)
        except Exception as e:
            self.logger.debug("unsupported CSR format: %s for\n%s", e, csr_data)
            return CODE_UNSUPPORTED_FORMAT, None
        if csr.signature_algorithm_oid not in SUPPORTED_SIGNATURE_ALGORITHMS:
            return CODE_UNSUPPORTED_SIGNATURE_ALGORITHM, None
        try:
            if not csr.is_signature_valid:
                return CODE_INVALID_SIGNATURE, None
        except Exception as e:
            self.logger.debug("CSR signature check failed: %s for \n%s", e, csr_data)
            return CODE_INVALID_SIGNATURE, None
        # Guarded like in check_crt so that an unknown key type is reported
        # instead of raising out of the row loop.
        try:
            public_key = csr.public_key()
        except UnsupportedAlgorithm:
            return CODE_UNSUPPORTED_PUBLIC_KEY, None
        key_code = self._check_public_key(public_key)
        if key_code is not None:
            return key_code, None
        return CODE_OK, public_key.public_numbers()

    def check_crt(self, crt_name: str, ca_certificate: typing.Optional[x509.Certificate],
                  public_numbers: typing.Any) -> CheckResult:
        """Validate the PEM certificate stored in the file *crt_name*.

        :param crt_name: path of the certificate file (may be empty)
        :param ca_certificate: CA certificate expected to have issued the
            certificate, or ``None`` if the row references an unknown CA
        :param public_numbers: public numbers of the CSR's key, or ``None`` to
            skip the CSR/certificate key comparison
        :return: ``CODE_OK`` or the first problem found
        """
        if not crt_name:
            return CODE_EMPTY
        if not os.path.isfile(crt_name):
            return CODE_FILE_MISSING
        with open(crt_name, "rb") as f:
            crt_data = f.read()
        try:
            crt = x509.load_pem_x509_certificate(crt_data)
        except ValueError:
            return CODE_UNSUPPORTED_FORMAT
        if crt.signature_algorithm_oid not in SUPPORTED_SIGNATURE_ALGORITHMS:
            return CODE_UNSUPPORTED_SIGNATURE_ALGORITHM
        try:
            public_key = crt.public_key()
        except UnsupportedAlgorithm:
            return CODE_UNSUPPORTED_PUBLIC_KEY
        key_code = self._check_public_key(public_key)
        if key_code is not None:
            return key_code
        if public_numbers is not None and public_key.public_numbers() != public_numbers:
            return CODE_CSR_AND_CRT_PUBLIC_KEY_MISMATCH
        # Expiry is checked before the issuer, so an expired certificate is
        # reported as expired even if the expected CA did not sign it.
        if crt.not_valid_after_utc < datetime.now(timezone.utc):
            return CODE_CERTIFICATE_IS_EXPIRED
        if ca_certificate is None:
            self.logger.debug("no CA certificate known for certificate issued by %s", crt.issuer)
            return CODE_NOT_SIGNED_BY_EXPECTED_CA_CERTIFICATE
        try:
            crt.verify_directly_issued_by(ca_certificate)
        except Exception as e:
            self.logger.debug("certificate verification failed: %s\n issuer of certificate: %s\n"
                              " CA certificate: %s", e, crt.issuer,
                              ca_certificate.subject)
            return CODE_NOT_SIGNED_BY_EXPECTED_CA_CERTIFICATE
        return CODE_OK

    def get_statistics(self):
        """Log the accumulated counters at INFO level."""
        self.logger.info("Statistics:\n%s", self.c)
def main():
    """Run the certificate analysis configured through environment variables.

    Database: DB_USER (default "cacert"), DB_PASSWORD ("cacert"), DB_HOST
    ("localhost"), DB_PORT ("3306"), DB_NAME ("cacert").
    CA certificates: ROOT_CA_CERTIFICATE and SUB_CA_CERTIFICATE (PEM paths).
    Logging: DEBUG enables debug output when set to "1", "true", "yes" or "on"
    (case-insensitive); any other value, including the default "false",
    leaves it off.
    """
    from urllib.parse import quote

    db_user = os.getenv("DB_USER", default="cacert")
    db_password = os.getenv("DB_PASSWORD", default="cacert")
    db_host = os.getenv("DB_HOST", "localhost")
    db_port = os.getenv("DB_PORT", "3306")
    db_name = os.getenv("DB_NAME", "cacert")
    root_ca_cert: str = os.getenv("ROOT_CA_CERTIFICATE", "../www/certs/root_X0F.crt")
    sub_ca_cert = os.getenv("SUB_CA_CERTIFICATE", "../www/certs/CAcert_Class3Root_x14E228.crt")
    # bool() of any non-empty string is True, which made the default "false"
    # enable debug logging; parse the value explicitly instead.
    debug = os.getenv("DEBUG", "false").strip().lower() in ("1", "true", "yes", "on")
    logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s: %(message)s")
    logger = logging.getLogger(__name__)
    if debug:
        logger.setLevel(logging.DEBUG)
    # Percent-encode the credentials: characters such as "@", ":" or "/" in a
    # password would otherwise corrupt the URL. safe="" also encodes "/".
    dsn = (
        f"mariadb+mariadbconnector://{quote(db_user, safe='')}:{quote(db_password, safe='')}"
        f"@{db_host}:{db_port}/{db_name}"
    )
    analyzer = Analyzer(
        logger=logger,
        dsn=dsn,
        ca_certificates=load_ca_certificates(root_ca_cert, sub_ca_cert),
    )
    analyzer.analyze()
    analyzer.get_statistics()


if __name__ == "__main__":
    main()