diff --git a/rsa/pkcs1.py b/rsa/pkcs1.py index 5992c7f..6ef477d 100644 --- a/rsa/pkcs1.py +++ b/rsa/pkcs1.py @@ -49,7 +49,7 @@ "SHA-512": b"\x30\x51\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x03\x05\x00\x04\x40", } -HASH_METHODS: typing.Dict[str, typing.Callable[[], HashType]] = { +HASH_METHODS: typing.Dict[str, typing.Callable[..., HashType]] = { "MD5": hashlib.md5, "SHA-1": hashlib.sha1, "SHA-224": hashlib.sha224, diff --git a/rsa/pkcs1_v2.py b/rsa/pkcs1_v2.py index d68b907..c956155 100644 --- a/rsa/pkcs1_v2.py +++ b/rsa/pkcs1_v2.py @@ -15,14 +15,24 @@ """Functions for PKCS#1 version 2 encryption and signing This module implements certain functionality from PKCS#1 version 2. Main -documentation is RFC 2437: https://tools.ietf.org/html/rfc2437 +documentation is RFC 8017: https://tools.ietf.org/html/rfc8017 """ -from rsa import ( - common, - pkcs1, - transform, -) +import os +from hmac import compare_digest + +from . import common, transform, core, key, pkcs1 +from ._compat import xor_bytes + + +def _constant_time_select(v: int, t: int, f: int) -> int: + """Return t if v else f. + + v must be 0 or 1. (False and True are allowed) + t and f are integer between 0 and 255. + """ + v -= 1 + return (~v & t) | (v & f) def mgf1(seed: bytes, length: int, hasher: str = "SHA-1") -> bytes: @@ -81,8 +91,193 @@ def mgf1(seed: bytes, length: int, hasher: str = "SHA-1") -> bytes: return output[:length] +def _OAEP_encode( + message: bytes, keylength: int, label, hash_method: str, mgf1_hash_method: str +) -> bytes: + try: + hasher = pkcs1.HASH_METHODS[hash_method](label) + except KeyError: + raise ValueError( + "Invalid `hash_method` specified. Please select one of: {hash_list}".format( + hash_list=", ".join(sorted(pkcs1.HASH_METHODS.keys())) + ) + ) + hash_length = hasher.digest_size + max_message_length = keylength - 2 * hash_length - 2 + message_length = len(message) + if message_length > max_message_length: + raise OverflowError( + "message is too long; at most %s bytes, given %s bytes" + % (max_message_length, len(message)) + ) + + lhash = hasher.digest() + ps = bytearray(keylength - message_length - 2 * hash_length - 2) + db = ( + hasher.digest() + + b"\0" * (keylength - message_length - 2 * hash_length - 2) + + b"\x01" + + message + ) + + seed = os.urandom(hash_length) + db_mask = mgf1(seed, keylength - hash_length - 1, mgf1_hash_method) + masked_db = xor_bytes(db, db_mask) + + seed_mask = mgf1(masked_db, hash_length, mgf1_hash_method) + masked_seed = xor_bytes(seed, seed_mask) + + em = b"\x00" + masked_seed + masked_db + return em + + +def encrypt_OAEP( + message: bytes, + pub_key: key.PublicKey, + label: bytes = b"", + hash_method: str = "SHA-1", + mgf1_hash_method: str = None, +) -> bytes: + """Encrypts the given message using PKCS#1 v2 RSA-OEAP. + + :param message: the message to encrypt. + :param pub_key: the public key to encrypt with. + :param label: optional RSA-OAEP label. + :param hash_method: hash function to be used. 'SHA-1' (default), + 'SHA-256', 'SHA-384', and 'SHA-512' can be used. + :param mgf1_hash_method: hash function to be used by MGF1 function. + If it is None (default), *hash_method* is used. + """ + # NOTE: Some hash method other than listed in the docstring can be used + # for hash_method. But the RFC 8017 recommends only them. + if mgf1_hash_method is None: + mgf1_hash_method = hash_method + keylength = common.byte_size(pub_key.n) + + em = _OAEP_encode(message, keylength, label, hash_method, mgf1_hash_method) + + m = transform.bytes2int(em) + encrypted = core.encrypt_int(m, pub_key.e, pub_key.n) + c = transform.int2bytes(encrypted, keylength) + + return c + + +def decrypt_OAEP( + crypto: bytes, + priv_key: key.PrivateKey, + label: bytes = b"", + hash_method: str = "SHA-1", + mgf1_hash_method: str = None, +) -> bytes: + """Decrypts the givem crypto using PKCS#1 v2 RSA-OAEP. + + :param crypto: the crypto text as returned by :py:func:`rsa.encrypt` + :param priv_key: the private key to decrypt with. + :param label: optional RSA-OAEP label. + :param hash_method: hash function to be used. 'SHA-1' (default), + 'SHA-256', 'SHA-384', and 'SHA-512' can be used. + :param mgf1_hash_method: hash function to be used by MGF1 function. + If it is None (default), *hash_method* is used. + + :raise rsa.pkcs1.DecryptionError: when the decryption fails. No details are given as + to why the code thinks the decryption fails, as this would leak + information about the private key. + + >>> import rsa + >>> (pub_key, priv_key) = rsa.newkeys(512) + + It works with binary data: + + >>> crypto = encrypt_OAEP(b'hello', pub_key) + >>> decrypt_OAEP(crypto, priv_key) + b'hello' + + You can pass optional label data too: + + >>> crypto = encrypt_OAEP(b'hello', pub_key, label=b'world') + >>> decrypt_OAEP(crypto, priv_key, label=b'world') + b'hello' + + Altering the encrypted information will cause a + :py:class:`rsa.pkcs1.DecryptionError`. + + >>> crypto = encrypt_OAEP(b'hello', pub_key) + >>> crypto = crypto[0:5] + bytes([(ord(crypto[5:6])+1)%256]) + crypto[6:] # change a byte + >>> decrypt_OAEP(crypto, priv_key) + Traceback (most recent call last): + ... + rsa.pkcs1.DecryptionError: Decryption failed + + Changing label will also cause the error. + + >>> crypto = encrypt_OAEP(b'hello', pub_key, label=b'world') + >>> decrypt_OAEP(crypto, priv_key, label=b'universe') + Traceback (most recent call last): + ... + rsa.pkcs1.DecryptionError: Decryption failed + """ + if mgf1_hash_method is None: + mgf1_hash_method = hash_method + + # todo: Step 1: length checking + k = common.byte_size(priv_key.n) + if k != len(crypto): + raise pkcs1.DecryptionError("Decryption failed") + + # Step 2: RSA Decryption + c = transform.bytes2int(crypto) + m = priv_key.blinded_decrypt(c) + em = transform.int2bytes(m, k) + + # Step 3: EME-OAEP decoding + try: + hasher = pkcs1.HASH_METHODS[hash_method](label) + except KeyError: + raise ValueError( + "Invalid `hash_method` specified. Please select one of: {hash_list}".format( + hash_list=", ".join(sorted(pkcs1.HASH_METHODS.keys())) + ) + ) + hash_length = hasher.digest_size + lhash = hasher.digest() + Y = em[0:1] + masked_seed = em[1 : 1 + hash_length] + masked_db = em[1 + hash_length :] + + seed_mask = mgf1(masked_db, hash_length, mgf1_hash_method) + seed = xor_bytes(masked_seed, seed_mask) + + db_mask = mgf1(seed, k - hash_length - 1, mgf1_hash_method) + db = xor_bytes(masked_db, db_mask) + + lhash_ = db[:hash_length] + rest = db[hash_length:] + + # NOTE: Take care about timing attack. See note in the RFC. + hash_is_good = compare_digest(lhash, lhash_) + + index = invalid = 0 + looking_one = 1 + + for i, c in enumerate(rest): + iszero = c == 0 + isone = c == 1 + + index = _constant_time_select(looking_one & isone, i, index) + looking_one = _constant_time_select(isone, 0, looking_one) + invalid = _constant_time_select(looking_one & ~iszero, 1, invalid) + + if invalid | looking_one | (not hash_is_good): + raise pkcs1.DecryptionError("Decryption failed") + + return rest[index + 1 :] + + __all__ = [ "mgf1", + "encrypt_OAEP", + "decrypt_OAEP", ] if __name__ == "__main__": diff --git a/tests/test_pkcs1_v2.py b/tests/test_pkcs1_v2.py index ead1393..eee1cbc 100644 --- a/tests/test_pkcs1_v2.py +++ b/tests/test_pkcs1_v2.py @@ -18,9 +18,13 @@ http://www.itomorrowmag.com/emc-plus/rsa-labs/standards-initiatives/pkcs-rsa-cryptography-standard.htm """ +import struct import unittest +import rsa from rsa import pkcs1_v2 +from rsa._compat import byte +from rsa.pkcs1 import DecryptionError class MGFTest(unittest.TestCase): @@ -77,3 +81,42 @@ def test_invalid_hasher(self): def test_invalid_length(self): with self.assertRaises(OverflowError): pkcs1_v2.mgf1(b"\x06\xe1\xde\xb2", length=2 ** 50) + + +class BinaryTest(unittest.TestCase): + def setUp(self): + (self.pub, self.priv) = rsa.newkeys(512) + + def test_enc_dec(self): + message = struct.pack(">IIII", 0, 0, 0, 1) + print("\tMessage: %r" % message) + + encrypted = pkcs1_v2.encrypt_OAEP(message, self.pub) + print("\tEncrypted: %r" % encrypted) + + decrypted = pkcs1_v2.decrypt_OAEP(encrypted, self.priv) + print("\tDecrypted: %r" % decrypted) + + self.assertEqual(message, decrypted) + + def test_decoding_failure(self): + message = struct.pack(">IIII", 0, 0, 0, 1) + encrypted = pkcs1_v2.encrypt_OAEP(message, self.pub) + + # Alter the encrypted stream + a = encrypted[5] + altered_a = (a + 1) % 256 + encrypted = encrypted[:5] + byte(altered_a) + encrypted[6:] + + self.assertRaises(DecryptionError, pkcs1_v2.decrypt_OAEP, encrypted, self.priv) + + def test_randomness(self): + """Encrypting the same message twice should result in different + cryptos. + """ + + message = struct.pack(">IIII", 0, 0, 0, 1) + encrypted1 = pkcs1_v2.encrypt_OAEP(message, self.pub) + encrypted2 = pkcs1_v2.encrypt_OAEP(message, self.pub) + + self.assertNotEqual(encrypted1, encrypted2)