From fe02217028240d1691e243fac37fa766e23f935a Mon Sep 17 00:00:00 2001 From: Jan Dittberner Date: Sun, 26 May 2024 11:29:20 +0200 Subject: [PATCH] Format using isort and black --- scripts/check_db_certificates.py | 108 ++++++++++++++++++++++--------- 1 file changed, 77 insertions(+), 31 deletions(-) diff --git a/scripts/check_db_certificates.py b/scripts/check_db_certificates.py index edd3cf7..c38d775 100644 --- a/scripts/check_db_certificates.py +++ b/scripts/check_db_certificates.py @@ -3,13 +3,12 @@ import logging import os import re import typing -from datetime import datetime, timezone -from typing import NamedTuple - from cryptography import x509 from cryptography.exceptions import UnsupportedAlgorithm -from cryptography.hazmat.primitives.asymmetric import rsa, ec -from sqlalchemy import create_engine, select, MetaData, Table +from cryptography.hazmat.primitives.asymmetric import ec, rsa +from datetime import datetime, timezone +from sqlalchemy import MetaData, Table, create_engine, select +from typing import NamedTuple class CheckResult(NamedTuple): @@ -32,18 +31,27 @@ CODE_UNSUPPORTED_FORMAT = CheckResult("unsupported format", 1001) CODE_EMPTY = CheckResult("empty", 1002) CODE_DEPRECATED_SPKAC = CheckResult("deprecated SPKAC", 1003) CODE_INVALID_SIGNATURE = CheckResult("invalid signature", 1004) -CODE_UNSUPPORTED_SIGNATURE_ALGORITHM = CheckResult("unsupported signature algorithm", 1005) +CODE_UNSUPPORTED_SIGNATURE_ALGORITHM = CheckResult( + "unsupported signature algorithm", 1005 +) CODE_PUBLIC_KEY_TOO_WEAK = CheckResult("public key too weak", 1006) CODE_UNSUPPORTED_PUBLIC_KEY = CheckResult("unsupported public key", 1007) -CODE_CSR_AND_CRT_PUBLIC_KEY_MISMATCH = CheckResult("CSR and CRT public key mismatch", 1008) +CODE_CSR_AND_CRT_PUBLIC_KEY_MISMATCH = CheckResult( + "CSR and CRT public key mismatch", 1008 +) CODE_CERTIFICATE_FOR_INVALID_CSR = CheckResult("certificate for invalid CSR", 1009) -CODE_NOT_SIGNED_BY_EXPECTED_CA_CERTIFICATE = CheckResult("not signed by expected CA", 1010) +CODE_NOT_SIGNED_BY_EXPECTED_CA_CERTIFICATE = CheckResult( + "not signed by expected CA", 1010 +) CODE_CERTIFICATE_IS_EXPIRED = CheckResult("certificate is expired", 1011) SUPPORTED_SIGNATURE_ALGORITHMS = [ - x509.oid.SignatureAlgorithmOID.RSA_WITH_SHA256, x509.oid.SignatureAlgorithmOID.RSA_WITH_SHA384, - x509.oid.SignatureAlgorithmOID.RSA_WITH_SHA512, x509.oid.SignatureAlgorithmOID.ECDSA_WITH_SHA256, - x509.oid.SignatureAlgorithmOID.ECDSA_WITH_SHA384, x509.oid.SignatureAlgorithmOID.ECDSA_WITH_SHA512 + x509.oid.SignatureAlgorithmOID.RSA_WITH_SHA256, + x509.oid.SignatureAlgorithmOID.RSA_WITH_SHA384, + x509.oid.SignatureAlgorithmOID.RSA_WITH_SHA512, + x509.oid.SignatureAlgorithmOID.ECDSA_WITH_SHA256, + x509.oid.SignatureAlgorithmOID.ECDSA_WITH_SHA384, + x509.oid.SignatureAlgorithmOID.ECDSA_WITH_SHA512, ] @@ -99,21 +107,36 @@ class Counters: def __str__(self): return ( - f"good CSR and certificate: {self.good}\n" - f"good CSR, issue with certificate: {self.good_csr}\n" - f"good certificate, issue with CSR: {self.good_crt}\n" - f"failed CSR and certificate: {self.fail}\n" - f"missing certificate: {self.missing_crt}\n" - f"expired certificate: {self.expired_crt}\n\nCSR results:\n" - ) + "\n".join([f"{code}: {count:d}" for code, count in sorted(self.csr_codes.items())]) + ( - "\n\nCertificate results:\n" - ) + ( - "\n".join([f"{code}: {count:d}" for code, count in sorted(self.crt_codes.items())]) + ( + f"good CSR and certificate: {self.good}\n" + f"good CSR, issue with certificate: {self.good_csr}\n" + f"good certificate, issue with CSR: {self.good_crt}\n" + f"failed CSR and certificate: {self.fail}\n" + f"missing certificate: {self.missing_crt}\n" + f"expired certificate: {self.expired_crt}\n\nCSR results:\n" + ) + + "\n".join( + [f"{code}: {count:d}" for code, count in sorted(self.csr_codes.items())] + ) + + ("\n\nCertificate results:\n") + + ( + "\n".join( + [ + f"{code}: {count:d}" + for code, count in sorted(self.crt_codes.items()) + ] + ) + ) ) class Analyzer: - def __init__(self, logger: logging.Logger, dsn: str, ca_certificates: dict[int, x509.Certificate]): + def __init__( + self, + logger: logging.Logger, + dsn: str, + ca_certificates: dict[int, x509.Certificate], + ): self.logger = logger self.engine = create_engine(dsn) self.ca_certificates = ca_certificates @@ -138,10 +161,17 @@ class Analyzer: crt_code = self.check_crt(row.crt_name, ca_cert, public_numbers) self.c.count_crt(crt_code) - if csr_code != CODE_OK and crt_code not in (CODE_OK, CODE_CERTIFICATE_IS_EXPIRED): + if csr_code != CODE_OK and crt_code not in ( + CODE_OK, + CODE_CERTIFICATE_IS_EXPIRED, + ): self.logger.debug( "%06d %s: %06d -> csr_code: %s, crt_code: %s", - self.c.fail, table, row.id, csr_code, crt_code + self.c.fail, + table, + row.id, + csr_code, + crt_code, ) self.c.count_fail() return @@ -156,7 +186,10 @@ class Analyzer: self.c.count_missing_crt() return - if csr_code == CODE_OK and crt_code not in (CODE_OK, CODE_CERTIFICATE_IS_EXPIRED): + if csr_code == CODE_OK and crt_code not in ( + CODE_OK, + CODE_CERTIFICATE_IS_EXPIRED, + ): self.c.count_good_csr() return @@ -207,7 +240,12 @@ class Analyzer: return CODE_OK, public_key.public_numbers() - def check_crt(self, crt_name: str, ca_certificate: x509.Certificate, public_numbers: typing.Any) -> CheckResult: + def check_crt( + self, + crt_name: str, + ca_certificate: x509.Certificate, + public_numbers: typing.Any, + ) -> CheckResult: if not crt_name: return CODE_EMPTY @@ -248,9 +286,13 @@ class Analyzer: try: crt.verify_directly_issued_by(ca_certificate) except Exception as e: - self.logger.debug("certificate verification failed: %s\n issuer of certificate: %s\n" - " CA certificate: %s", e, crt.issuer, - ca_certificate.subject) + self.logger.debug( + "certificate verification failed: %s\n issuer of certificate: %s\n" + " CA certificate: %s", + e, + crt.issuer, + ca_certificate.subject, + ) return CODE_NOT_SIGNED_BY_EXPECTED_CA_CERTIFICATE return CODE_OK @@ -266,10 +308,14 @@ def main(): db_port = os.getenv("DB_PORT", "3306") db_name = os.getenv("DB_NAME", "cacert") root_ca_cert: str = os.getenv("ROOT_CA_CERTIFICATE", "../www/certs/root_X0F.crt") - sub_ca_cert = os.getenv("SUB_CA_CERTIFICATE", "../www/certs/CAcert_Class3Root_x14E228.crt") + sub_ca_cert = os.getenv( + "SUB_CA_CERTIFICATE", "../www/certs/CAcert_Class3Root_x14E228.crt" + ) debug = bool(os.getenv("DEBUG", "false")) - logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s: %(message)s") + logging.basicConfig( + level=logging.INFO, format="%(asctime)s %(levelname)s: %(message)s" + ) logger = logging.getLogger(__name__) if debug: