diff --git a/pyproject.toml b/pyproject.toml
index 87068b18..03f30491 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -22,8 +22,6 @@ requires-python = ">= 3.9"
 dependencies = [
-    "cryptography >=3.1",
+    "cryptography >=42.0.0", # *_utc validity properties need 42.0; verify_directly_issued_by needs 40.0
     "defusedxml",
-    "pyopenssl <24.3.0",
-    "python-dateutil",
     "requests >=2.0.0,<3.0.0", # ^2 means compatible with 2.x
     "xmlschema >=2.0.0,<3.0.0"
 ]
diff --git a/src/saml2/cert.py b/src/saml2/cert.py
index e90651e4..926aec8a 100644
--- a/src/saml2/cert.py
+++ b/src/saml2/cert.py
@@ -3,11 +3,13 @@ __author__ = "haho0032"
 import base64
 from os import remove
 from os.path import join
-from datetime import datetime
-from datetime import timezone
+from datetime import datetime, timedelta, timezone
 
-from OpenSSL import crypto
-import dateutil.parser
+from cryptography import x509
+from cryptography.exceptions import InvalidSignature
+from cryptography.hazmat.primitives import hashes, serialization
+from cryptography.hazmat.primitives.asymmetric import rsa
+from cryptography.x509.oid import NameOID
 
 import saml2.cryptography.pki
 
@@ -131,49 +133,64 @@ class OpenSSLWrapper:
             k_f = join(cert_dir, key_file)
 
         # create a key pair
-        k = crypto.PKey()
-        k.generate_key(crypto.TYPE_RSA, key_length)
+        k = rsa.generate_private_key(
+            public_exponent=65537,
+            key_size=key_length,
+        )
 
         # create a self-signed cert
-        cert = crypto.X509()
+        builder = x509.CertificateBuilder()
 
         if request:
-            cert = crypto.X509Req()
+            builder = x509.CertificateSigningRequestBuilder()
 
         if len(cert_info["country_code"]) != 2:
             raise WrongInput("Country code must be two letters!")
-        cert.get_subject().C = cert_info["country_code"]
-        cert.get_subject().ST = cert_info["state"]
-        cert.get_subject().L = cert_info["city"]
-        cert.get_subject().O = cert_info["organization"]  # noqa: E741
-        cert.get_subject().OU = cert_info["organization_unit"]
-        cert.get_subject().CN = cn
+        subject_name = x509.Name(
+            [
+                x509.NameAttribute(NameOID.COUNTRY_NAME, cert_info["country_code"]),
+                x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, cert_info["state"]),
+                x509.NameAttribute(NameOID.LOCALITY_NAME, cert_info["city"]),
+                x509.NameAttribute(NameOID.ORGANIZATION_NAME, cert_info["organization"]),
+                x509.NameAttribute(NameOID.ORGANIZATIONAL_UNIT_NAME, cert_info["organization_unit"]),
+                x509.NameAttribute(NameOID.COMMON_NAME, cn),
+            ]
+        )
+        builder = builder.subject_name(subject_name)
         if not request:
-            cert.set_serial_number(sn)
-            cert.gmtime_adj_notBefore(valid_from)  # Valid before present time
-            cert.gmtime_adj_notAfter(valid_to)  # 3 650 days
-            cert.set_issuer(cert.get_subject())
-        cert.set_pubkey(k)
-        cert.sign(k, hash_alg)
+            # Self-signed: issuer == subject; validity offsets are seconds relative to now (UTC).
+            now = datetime.now(timezone.utc)
+            builder = (
+                builder.serial_number(sn)
+                .not_valid_before(now + timedelta(seconds=valid_from))
+                .not_valid_after(now + timedelta(seconds=valid_to))
+                .issuer_name(subject_name)
+                .public_key(k.public_key())
+            )
+        # A CSR builder derives its public key from the signing key; certificates set it above.
+        cert = builder.sign(k, self._hash_algorithm(hash_alg))
 
         try:
-            if request:
-                tmp_cert = crypto.dump_certificate_request(crypto.FILETYPE_PEM, cert)
-            else:
-                tmp_cert = crypto.dump_certificate(crypto.FILETYPE_PEM, cert)
-            tmp_key = None
+            tmp_cert = cert.public_bytes(serialization.Encoding.PEM)
             if cipher_passphrase is not None:
                 passphrase = cipher_passphrase["passphrase"]
                 if isinstance(cipher_passphrase["passphrase"], str):
                     passphrase = passphrase.encode("utf-8")
-                tmp_key = crypto.dump_privatekey(crypto.FILETYPE_PEM, k, cipher_passphrase["cipher"], passphrase)
+                # cryptography selects the cipher itself; cipher_passphrase["cipher"] is no longer used.
+                key_encryption = serialization.BestAvailableEncryption(passphrase)
             else:
-                tmp_key = crypto.dump_privatekey(crypto.FILETYPE_PEM, k)
+                key_encryption = serialization.NoEncryption()
+            # PKCS#8 is the PEM layout pyOpenSSL's dump_privatekey emitted, so key files keep their format.
+            tmp_key = k.private_bytes(
+                encoding=serialization.Encoding.PEM,
+                format=serialization.PrivateFormat.PKCS8,
+                encryption_algorithm=key_encryption,
+            )
             if write_to_file:
-                with open(c_f, "w") as fc:
-                    fc.write(tmp_cert.decode("utf-8"))
-                with open(k_f, "w") as fk:
-                    fk.write(tmp_key.decode("utf-8"))
+                with open(c_f, "wb") as fc:
+                    fc.write(tmp_cert)
+                with open(k_f, "wb") as fk:
+                    fk.write(tmp_key)
                 return c_f, k_f
             return tmp_cert, tmp_key
         except Exception as ex:
@@ -237,27 +254,33 @@ class OpenSSLWrapper:
 
         :return: String representation of the signed certificate.
         """
-        ca_cert = crypto.load_certificate(crypto.FILETYPE_PEM, sign_cert_str)
-        ca_key = None
-        if passphrase is not None:
-            ca_key = crypto.load_privatekey(crypto.FILETYPE_PEM, sign_key_str, passphrase)
-        else:
-            ca_key = crypto.load_privatekey(crypto.FILETYPE_PEM, sign_key_str)
-        req_cert = crypto.load_certificate_request(crypto.FILETYPE_PEM, request_cert_str)
-
-        cert = crypto.X509()
-        cert.set_subject(req_cert.get_subject())
-        cert.set_serial_number(sn)
-        cert.gmtime_adj_notBefore(valid_from)
-        cert.gmtime_adj_notAfter(valid_to)
-        cert.set_issuer(ca_cert.get_subject())
-        cert.set_pubkey(req_cert.get_pubkey())
-        cert.sign(ca_key, hash_alg)
-
-        cert_dump = crypto.dump_certificate(crypto.FILETYPE_PEM, cert)
-        if isinstance(cert_dump, str):
-            return cert_dump
-        return cert_dump.decode("utf-8")
+        # cryptography's PEM loaders take bytes only; accept str for every input as well.
+        if isinstance(sign_cert_str, str):
+            sign_cert_str = sign_cert_str.encode("utf-8")
+        if isinstance(sign_key_str, str):
+            sign_key_str = sign_key_str.encode("utf-8")
+        if isinstance(request_cert_str, str):
+            request_cert_str = request_cert_str.encode("utf-8")
+        if isinstance(passphrase, str):
+            passphrase = passphrase.encode("utf-8")
+
+        ca_cert = x509.load_pem_x509_certificate(sign_cert_str)
+        ca_key = serialization.load_pem_private_key(sign_key_str, password=passphrase)
+        req_cert = x509.load_pem_x509_csr(request_cert_str)
+
+        now = datetime.now(timezone.utc)
+        cert = (
+            x509.CertificateBuilder()
+            .subject_name(req_cert.subject)
+            .serial_number(sn)
+            .not_valid_before(now + timedelta(seconds=valid_from))
+            .not_valid_after(now + timedelta(seconds=valid_to))
+            .issuer_name(ca_cert.subject)
+            .public_key(req_cert.public_key())
+            .sign(ca_key, self._hash_algorithm(hash_alg))
+        )
+
+        return cert.public_bytes(serialization.Encoding.PEM).decode("utf-8")
 
     def verify_chain(self, cert_chain_str_list, cert_str):
         """
@@ -276,12 +299,31 @@ class OpenSSLWrapper:
                 cert_str = tmp_cert_str
         return (True, "Signed certificate is valid and correctly signed by CA " "certificate.")
 
+    @staticmethod
+    def _hash_algorithm(hash_alg):
+        """Return a cryptography hash instance for ``hash_alg``.
+
+        ``hash_alg`` is a digest name such as "sha256" or "sha1", or a
+        ``hashes.HashAlgorithm`` instance, which is returned unchanged.
+        :raises ValueError: if cryptography has no hash with that name.
+        """
+        if isinstance(hash_alg, hashes.HashAlgorithm):
+            return hash_alg
+        algorithm = getattr(hashes, str(hash_alg).upper(), None)
+        if not (isinstance(algorithm, type) and issubclass(algorithm, hashes.HashAlgorithm)):
+            raise ValueError(f"No such digest method: {hash_alg}")
+        return algorithm()
+
     def certificate_not_valid_yet(self, cert):
-        starts_to_be_valid = dateutil.parser.parse(cert.get_notBefore())
-        now = datetime.now(timezone.utc)
-        if starts_to_be_valid < now:
-            return False
-        return True
+        """Return True if the notBefore time of ``cert`` is not earlier than now (UTC).
+
+        ``cert`` is a ``cryptography.x509.Certificate``; objects offering
+        ``to_cryptography()`` (such as pyOpenSSL's ``X509``) are converted first.
+        """
+        if hasattr(cert, "to_cryptography"):
+            cert = cert.to_cryptography()
+        return cert.not_valid_before_utc >= datetime.now(timezone.utc)
 
     def verify(self, signing_cert_str, cert_str):
         """
@@ -303,34 +345,36 @@ class OpenSSLWrapper:
                                  Message = Why the validation failed.
         """
         try:
-            ca_cert = crypto.load_certificate(crypto.FILETYPE_PEM, signing_cert_str)
-            cert = crypto.load_certificate(crypto.FILETYPE_PEM, cert_str)
+            if isinstance(signing_cert_str, str):
+                signing_cert_str = signing_cert_str.encode("utf-8")
+            if isinstance(cert_str, str):
+                cert_str = cert_str.encode("utf-8")
+            ca_cert = x509.load_pem_x509_certificate(signing_cert_str)
+            cert = x509.load_pem_x509_certificate(cert_str)
 
             if self.certificate_not_valid_yet(ca_cert):
                 return False, "CA certificate is not valid yet."
 
-            if ca_cert.has_expired() == 1:
+            now = datetime.now(timezone.utc)
+            if ca_cert.not_valid_after_utc < now:
                 return False, "CA certificate is expired."
 
-            if cert.has_expired() == 1:
+            if cert.not_valid_after_utc < now:
                 return False, "The signed certificate is expired."
 
             if self.certificate_not_valid_yet(cert):
                 return False, "The signed certificate is not valid yet."
 
-            if ca_cert.get_subject().CN == cert.get_subject().CN:
+            ca_cn = ca_cert.subject.get_attributes_for_oid(NameOID.COMMON_NAME)
+            cert_cn = cert.subject.get_attributes_for_oid(NameOID.COMMON_NAME)
+            if ca_cn == cert_cn:
                 return False, ("CN may not be equal for CA certificate and the " "signed certificate.")
 
-            cert_algorithm = cert.get_signature_algorithm()
-            cert_algorithm = cert_algorithm.decode("ascii")
-            cert_str = cert_str.encode("ascii")
-
-            cert_crypto = saml2.cryptography.pki.load_pem_x509_certificate(cert_str)
-
             try:
-                crypto.verify(ca_cert, cert_crypto.signature, cert_crypto.tbs_certificate_bytes, cert_algorithm)
+                # Also checks that cert's issuer name equals ca_cert's subject name.
+                cert.verify_directly_issued_by(ca_cert)
                 return True, "Signed certificate is valid and correctly signed by CA certificate."
-            except crypto.Error as e:
+            except (ValueError, TypeError, InvalidSignature) as e:
                 return False, f"Certificate is incorrectly signed: {str(e)}"
         except Exception as e:
             return False, f"Certificate is not valid for an unknown reason. {str(e)}"
diff --git a/src/saml2/sigver.py b/src/saml2/sigver.py
index 738ac04b..60c83718 100644
--- a/src/saml2/sigver.py
+++ b/src/saml2/sigver.py
@@ -18,8 +18,7 @@ from time import mktime
 from urllib import parse
 from uuid import uuid4 as gen_random_key
 
-from OpenSSL import crypto
-import dateutil
+from cryptography import x509
 
 from saml2 import ExtensionElement
 from saml2 import SamlBase
@@ -373,14 +372,14 @@ def active_cert(key):
     """
     try:
         cert_str = pem_format(key)
-        cert = crypto.load_certificate(crypto.FILETYPE_PEM, cert_str)
+        cert = x509.load_pem_x509_certificate(cert_str)
     except AttributeError:
         return False
 
     now = datetime.now(timezone.utc)
-    valid_from = dateutil.parser.parse(cert.get_notBefore())
-    valid_to = dateutil.parser.parse(cert.get_notAfter())
-    active = not cert.has_expired() and valid_from <= now < valid_to
+    valid_from = cert.not_valid_before_utc
+    valid_to = cert.not_valid_after_utc
+    active = valid_from <= now < valid_to
     return active
 
 