#!/usr/bin/env python3
"""Check the CSR and certificate files referenced by the certificate database.

For every row of the ``emailcerts``, ``domaincerts``, ``orgemailcerts`` and
``orgdomaincerts`` tables, the CSR file (``csr_name``) and the certificate file
(``crt_name``) are validated. Each outcome is tallied, and a summary is logged
at the end.

Configuration comes from the environment variables ``DB_USER``,
``DB_PASSWORD``, ``DB_HOST``, ``DB_PORT``, ``DB_NAME``,
``ROOT_CA_CERTIFICATE``, ``SUB_CA_CERTIFICATE`` and ``DEBUG``.
"""
import logging
import os
import re
import typing
from datetime import datetime, timezone
from typing import NamedTuple
from urllib.parse import quote_plus

from cryptography import x509
from cryptography.exceptions import UnsupportedAlgorithm
from cryptography.hazmat.primitives.asymmetric import rsa, ec
from sqlalchemy import create_engine, select, MetaData, Table


class CheckResult(NamedTuple):
    """Outcome of a CSR or certificate check: a readable name plus a numeric code.

    Ordering compares ``code`` only, so sorted summaries list results numerically.
    """

    name: str
    code: int

    def __str__(self):
        return f"[{self.code:04d}: {self.name}]"

    # All four ordering operators are defined. Without __le__/__ge__ they would
    # fall back to tuple ordering, which compares ``name`` first.
    def __lt__(self, other):
        return self.code < other.code

    def __le__(self, other):
        return self.code <= other.code

    def __gt__(self, other):
        return self.code > other.code

    def __ge__(self, other):
        return self.code >= other.code


CODE_OK = CheckResult("OK", 0)
CODE_FILE_MISSING = CheckResult("file missing", 1000)
CODE_UNSUPPORTED_FORMAT = CheckResult("unsupported format", 1001)
CODE_EMPTY = CheckResult("empty", 1002)
CODE_DEPRECATED_SPKAC = CheckResult("deprecated SPKAC", 1003)
CODE_INVALID_SIGNATURE = CheckResult("invalid signature", 1004)
CODE_UNSUPPORTED_SIGNATURE_ALGORITHM = CheckResult("unsupported signature algorithm", 1005)
CODE_PUBLIC_KEY_TOO_WEAK = CheckResult("public key too weak", 1006)
CODE_UNSUPPORTED_PUBLIC_KEY = CheckResult("unsupported public key", 1007)
CODE_CSR_AND_CRT_PUBLIC_KEY_MISMATCH = CheckResult("CSR and CRT public key mismatch", 1008)
CODE_CERTIFICATE_FOR_INVALID_CSR = CheckResult("certificate for invalid CSR", 1009)
CODE_NOT_SIGNED_BY_EXPECTED_CA_CERTIFICATE = CheckResult("not signed by expected CA", 1010)
CODE_CERTIFICATE_IS_EXPIRED = CheckResult("certificate is expired", 1011)

SUPPORTED_SIGNATURE_ALGORITHMS = [
    x509.oid.SignatureAlgorithmOID.RSA_WITH_SHA256,
    x509.oid.SignatureAlgorithmOID.RSA_WITH_SHA384,
    x509.oid.SignatureAlgorithmOID.RSA_WITH_SHA512,
    x509.oid.SignatureAlgorithmOID.ECDSA_WITH_SHA256,
    x509.oid.SignatureAlgorithmOID.ECDSA_WITH_SHA384,
    x509.oid.SignatureAlgorithmOID.ECDSA_WITH_SHA512,
]

# Minimum accepted public key sizes in bits.
MIN_RSA_KEY_SIZE = 2048
MIN_EC_KEY_SIZE = 256


def load_ca_certificates(root_ca_cert: str, sub_ca_cert: str) -> dict[int, x509.Certificate]:
    """Load the root and the class 3 sub CA certificates from PEM files.

    :param root_ca_cert: path to the PEM-encoded root CA certificate
    :param sub_ca_cert: path to the PEM-encoded class 3 sub CA certificate
    :return: mapping from the database ``rootcert`` value to the CA certificate
        (1 is the root CA, 2 is the class 3 sub CA)
    """
    with open(root_ca_cert, "rb") as f:
        root_cert = x509.load_pem_x509_certificate(f.read())
    with open(sub_ca_cert, "rb") as f:
        class3_cert = x509.load_pem_x509_certificate(f.read())
    return {
        1: root_cert,
        2: class3_cert,
    }


class Counters:
    """Tally of per-row outcomes plus per-code counts for CSR and certificate checks."""

    fail = 0
    good = 0
    good_csr = 0
    good_crt = 0
    missing_crt = 0
    expired_crt = 0
    csr_codes: dict[CheckResult, int]
    crt_codes: dict[CheckResult, int]

    def __init__(self) -> None:
        # Each instance gets its own dicts. Class-level dicts would be shared
        # (and mutated) by every Counters object.
        self.csr_codes = {}
        self.crt_codes = {}

    def count_fail(self):
        self.fail += 1

    def count_good(self):
        self.good += 1

    def count_good_csr(self):
        self.good_csr += 1

    def count_good_crt(self):
        self.good_crt += 1

    def count_missing_crt(self):
        self.missing_crt += 1

    def count_expired_crt(self):
        self.expired_crt += 1

    def count_csr(self, csr_code: CheckResult):
        self.csr_codes[csr_code] = self.csr_codes.get(csr_code, 0) + 1

    def count_crt(self, crt_code: CheckResult):
        self.crt_codes[crt_code] = self.crt_codes.get(crt_code, 0) + 1

    def __str__(self):
        return (
            f"good CSR and certificate: {self.good}\n"
            f"good CSR, issue with certificate: {self.good_csr}\n"
            f"good certificate, issue with CSR: {self.good_crt}\n"
            f"failed CSR and certificate: {self.fail}\n"
            f"missing certificate: {self.missing_crt}\n"
            f"expired certificate: {self.expired_crt}\n\nCSR results:\n"
        ) + "\n".join(
            [f"{code}: {count:d}" for code, count in sorted(self.csr_codes.items())]
        ) + (
            "\n\nCertificate results:\n"
        ) + (
            "\n".join([f"{code}: {count:d}" for code, count in sorted(self.crt_codes.items())])
        )


def _check_public_key(public_key: typing.Any) -> CheckResult:
    """Return CODE_OK for a sufficiently strong RSA/EC key, otherwise the failure code."""
    if isinstance(public_key, rsa.RSAPublicKey):
        if public_key.key_size < MIN_RSA_KEY_SIZE:
            return CODE_PUBLIC_KEY_TOO_WEAK
    elif isinstance(public_key, ec.EllipticCurvePublicKey):
        if public_key.key_size < MIN_EC_KEY_SIZE:
            return CODE_PUBLIC_KEY_TOO_WEAK
    else:
        return CODE_UNSUPPORTED_PUBLIC_KEY
    return CODE_OK


class Analyzer:
    """Walk the certificate tables and validate each row's CSR and certificate."""

    def __init__(self, logger: logging.Logger, dsn: str, ca_certificates: dict[int, x509.Certificate]):
        self.logger = logger
        self.engine = create_engine(dsn)
        self.ca_certificates = ca_certificates
        self.c = Counters()

    def analyze(self):
        """Reflect each certificate table and analyze every row in it."""
        metadata_obj = MetaData()
        for table in ("emailcerts", "domaincerts", "orgemailcerts", "orgdomaincerts"):
            certs_table = Table(table, metadata_obj, autoload_with=self.engine)
            stmt = select(certs_table)
            with self.engine.connect() as conn:
                for row in conn.execute(stmt):
                    self.analyze_row(table, row)

    def analyze_row(self, table: str, row) -> None:
        """Check one row's CSR and certificate and record the combined outcome.

        Outcome categories:
        - CSR bad and certificate neither OK nor expired: counted as failed.
        - CSR bad, certificate OK or expired: counted as "good certificate".
        - CSR OK, certificate field empty: counted as missing certificate.
        - CSR OK, certificate has another problem: counted as "good CSR".
        - CSR OK, certificate OK or expired: counted as good.
        Expired certificates are additionally counted as expired.

        :param table: name of the table the row came from (used for logging)
        :param row: result row; reads ``id``, ``rootcert``, ``csr_name`` and ``crt_name``
        """
        ca_cert = self.ca_certificates[row.rootcert]
        csr_code, public_numbers = self.check_csr(row.csr_name)
        self.c.count_csr(csr_code)
        crt_code = self.check_crt(row.crt_name, ca_cert, public_numbers)
        self.c.count_crt(crt_code)

        crt_expired = crt_code == CODE_CERTIFICATE_IS_EXPIRED
        crt_usable = crt_code == CODE_OK or crt_expired

        if csr_code != CODE_OK:
            if not crt_usable:
                self.logger.debug(
                    "%06d %s: %06d -> csr_code: %s, crt_code: %s", self.c.fail, table, row.id, csr_code, crt_code
                )
                self.c.count_fail()
                return
            self.c.count_good_crt()
            if crt_expired:
                self.c.count_expired_crt()
            return

        if crt_code == CODE_EMPTY:
            self.c.count_missing_crt()
            return

        if not crt_usable:
            self.c.count_good_csr()
            return

        if crt_expired:
            self.c.count_expired_crt()
        self.c.count_good()

    def check_csr(self, csr_name: str) -> tuple[CheckResult, typing.Any]:
        """Validate the CSR file referenced by a database row.

        :param csr_name: path of the CSR file; may be empty or None
        :return: ``(CODE_OK, public_numbers)`` for an acceptable CSR, otherwise
            ``(<failure code>, None)``
        """
        if not csr_name:
            return CODE_EMPTY, None
        if not os.path.isfile(csr_name):
            return CODE_FILE_MISSING, None

        with open(csr_name, "rb") as f:
            csr_data = f.read()

        # ISO-8859-1 maps every byte to a character, so this decode cannot fail.
        if re.search(r"SPKAC = ", csr_data.decode("iso-8859-1")):
            return CODE_DEPRECATED_SPKAC, None

        try:
            csr = x509.load_pem_x509_csr(csr_data)
        except Exception as e:
            self.logger.debug("unsupported CSR format: %s for\n%s", e, csr_data)
            return CODE_UNSUPPORTED_FORMAT, None

        if csr.signature_algorithm_oid not in SUPPORTED_SIGNATURE_ALGORITHMS:
            return CODE_UNSUPPORTED_SIGNATURE_ALGORITHM, None

        try:
            if not csr.is_signature_valid:
                return CODE_INVALID_SIGNATURE, None
        except Exception as e:
            self.logger.debug("CSR signature check failed: %s for \n%s", e, csr_data)
            return CODE_INVALID_SIGNATURE, None

        # Unknown key types raise here instead of returning a key object;
        # report them the same way check_crt does.
        try:
            public_key = csr.public_key()
        except UnsupportedAlgorithm:
            return CODE_UNSUPPORTED_PUBLIC_KEY, None

        key_code = _check_public_key(public_key)
        if key_code != CODE_OK:
            return key_code, None

        return CODE_OK, public_key.public_numbers()

    def check_crt(self, crt_name: str, ca_certificate: x509.Certificate, public_numbers: typing.Any) -> CheckResult:
        """Validate the certificate file referenced by a database row.

        :param crt_name: path of the certificate file; may be empty or None
        :param ca_certificate: CA certificate that is expected to have issued it
        :param public_numbers: public numbers of the matching CSR's key, or None
            when the CSR was not acceptable (then no key comparison is done)
        :return: CODE_OK, or the first failure found. Expiry is checked before
            issuer verification, so an expired certificate is not CA-verified.
        """
        if not crt_name:
            return CODE_EMPTY
        if not os.path.isfile(crt_name):
            return CODE_FILE_MISSING

        with open(crt_name, "rb") as f:
            crt_data = f.read()

        try:
            crt = x509.load_pem_x509_certificate(crt_data)
        except ValueError:
            return CODE_UNSUPPORTED_FORMAT

        if crt.signature_algorithm_oid not in SUPPORTED_SIGNATURE_ALGORITHMS:
            return CODE_UNSUPPORTED_SIGNATURE_ALGORITHM

        try:
            public_key = crt.public_key()
        except UnsupportedAlgorithm:
            return CODE_UNSUPPORTED_PUBLIC_KEY

        key_code = _check_public_key(public_key)
        if key_code != CODE_OK:
            return key_code

        if public_numbers is not None and public_key.public_numbers() != public_numbers:
            return CODE_CSR_AND_CRT_PUBLIC_KEY_MISMATCH

        if crt.not_valid_after_utc < datetime.now(timezone.utc):
            return CODE_CERTIFICATE_IS_EXPIRED

        try:
            crt.verify_directly_issued_by(ca_certificate)
        except Exception as e:
            self.logger.debug("certificate verification failed: %s\n issuer of certificate: %s\n"
                              " CA certificate: %s", e, crt.issuer, ca_certificate.subject)
            return CODE_NOT_SIGNED_BY_EXPECTED_CA_CERTIFICATE

        return CODE_OK

    def get_statistics(self):
        """Log the accumulated counters at INFO level."""
        self.logger.info("Statistics:\n%s", self.c)


def main():
    """Read configuration from the environment, run the analysis and log statistics."""
    db_user = os.getenv("DB_USER", default="cacert")
    db_password = os.getenv("DB_PASSWORD", default="cacert")
    db_host = os.getenv("DB_HOST", "localhost")
    db_port = os.getenv("DB_PORT", "3306")
    db_name = os.getenv("DB_NAME", "cacert")
    root_ca_cert: str = os.getenv("ROOT_CA_CERTIFICATE", "../www/certs/root_X0F.crt")
    sub_ca_cert = os.getenv("SUB_CA_CERTIFICATE", "../www/certs/CAcert_Class3Root_x14E228.crt")
    # Parse the flag explicitly. bool() of any non-empty string (including the
    # default "false") is True, which would always enable debug logging.
    debug = os.getenv("DEBUG", "false").strip().lower() in ("1", "true", "yes", "on")

    logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s: %(message)s")
    logger = logging.getLogger(__name__)
    if debug:
        logger.setLevel(logging.DEBUG)

    # URL-encode the credentials so that characters such as "@", ":" or "/"
    # in the user name or password cannot corrupt the DSN.
    dsn = (
        f"mariadb+mariadbconnector://{quote_plus(db_user)}:{quote_plus(db_password)}"
        f"@{db_host}:{db_port}/{db_name}"
    )

    analyzer = Analyzer(
        logger=logger,
        dsn=dsn,
        ca_certificates=load_ca_certificates(root_ca_cert, sub_ca_cert),
    )
    analyzer.analyze()
    analyzer.get_statistics()


if __name__ == "__main__":
    main()