diff --git a/Cargo.lock b/Cargo.lock index 12d1df34d98..3d08978569c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3238,6 +3238,7 @@ dependencies = [ "bzip2", "cfg-if", "chrono", + "constant_time_eq", "crc32fast", "crossbeam-utils", "csv-core", diff --git a/Lib/test/test_hmac.py b/Lib/test/test_hmac.py index 1506bb7982a..42b8db39271 100644 --- a/Lib/test/test_hmac.py +++ b/Lib/test/test_hmac.py @@ -1484,19 +1484,11 @@ def test_compare_digest_func(self): else: self.assertIs(self.compare_digest, operator_compare_digest) - @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: TypeError not raised by compare_digest - def test_exceptions(self): - return super().test_exceptions() - @hashlib_helper.requires_hashlib() class OpenSSLCompareDigestTestCase(CompareDigestMixin, unittest.TestCase): compare_digest = openssl_compare_digest - @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: TypeError not raised by compare_digest - def test_exceptions(self): - return super().test_exceptions() - class OperatorCompareDigestTestCase(CompareDigestMixin, unittest.TestCase): compare_digest = operator_compare_digest diff --git a/crates/stdlib/Cargo.toml b/crates/stdlib/Cargo.toml index 7df5ba522f9..21655547fd0 100644 --- a/crates/stdlib/Cargo.toml +++ b/crates/stdlib/Cargo.toml @@ -69,6 +69,7 @@ sha3 = "0.10.1" blake2 = "0.10.4" hmac = "0.12" pbkdf2 = { version = "0.12", features = ["hmac"] } +constant_time_eq = { workspace = true } ## unicode stuff unicode_names2 = { workspace = true } diff --git a/crates/stdlib/src/hashlib.rs b/crates/stdlib/src/hashlib.rs index 441b8f44815..924009884f8 100644 --- a/crates/stdlib/src/hashlib.rs +++ b/crates/stdlib/src/hashlib.rs @@ -14,7 +14,6 @@ pub mod _hashlib { PyBaseExceptionRef, PyBytes, PyFrozenSet, PyStr, PyTypeRef, PyUtf8StrRef, PyValueError, }, class::StaticType, - convert::ToPyObject, function::{ArgBytesLike, ArgStrOrBytesLike, FuncArgs, OptionalArg}, types::{Constructor, Representable}, }; @@ -724,22 +723,26 @@ pub mod _hashlib { a: ArgStrOrBytesLike, b: ArgStrOrBytesLike, vm: &VirtualMachine, - ) -> PyResult { - const fn is_str(arg: &ArgStrOrBytesLike) -> bool { - matches!(arg, ArgStrOrBytesLike::Str(_)) - } - - if is_str(&a) != is_str(&b) { - return Err(vm.new_type_error(format!( + ) -> PyResult { + use constant_time_eq::constant_time_eq; + + match (&a, &b) { + (ArgStrOrBytesLike::Str(a), ArgStrOrBytesLike::Str(b)) => { + if !a.isascii() || !b.isascii() { + return Err(vm.new_type_error( + "comparing strings with non-ASCII characters is not supported", + )); + } + Ok(constant_time_eq(a.as_bytes(), b.as_bytes())) + } + (ArgStrOrBytesLike::Buf(a), ArgStrOrBytesLike::Buf(b)) => { + Ok(a.with_ref(|a| b.with_ref(|b| constant_time_eq(a, b)))) + } + _ => Err(vm.new_type_error(format!( "a bytes-like object is required, not '{}'", b.as_object().class().name() - ))); + ))), } - - let a_hash = a.borrow_bytes().to_vec(); - let b_hash = b.borrow_bytes().to_vec(); - - Ok((a_hash == b_hash).to_pyobject(vm)) } #[derive(FromArgs, Debug)]