diff --git a/.cspell.dict/cpython.txt b/.cspell.dict/cpython.txt index 8c733e343d1..5676549ec1a 100644 --- a/.cspell.dict/cpython.txt +++ b/.cspell.dict/cpython.txt @@ -2,11 +2,14 @@ argtypes asdl asname augassign +badcert badsyntax basetype boolop bxor cached_tsver +cadata +cafile cellarg cellvar cellvars @@ -23,8 +26,8 @@ freevars fromlist heaptype HIGHRES -Itertool IMMUTABLETYPE +Itertool kwonlyarg kwonlyargs lasti @@ -47,6 +50,7 @@ stackdepth stringlib structseq subparams +ticketer tok_oldval tvars unaryop @@ -56,6 +60,7 @@ VARKEYWORDS varkwarg wbits weakreflist +webpki withitem withs xstat diff --git a/.cspell.dict/rust-more.txt b/.cspell.dict/rust-more.txt index 6f89fdfafe1..f27e53bd6ed 100644 --- a/.cspell.dict/rust-more.txt +++ b/.cspell.dict/rust-more.txt @@ -50,6 +50,7 @@ nanos nonoverlapping objclass peekable +pemfile powc powf powi @@ -61,6 +62,7 @@ rposition rsplitn rustc rustfmt +rustls rustyline seedable seekfrom diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 9ab4610ed11..07784b2667e 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -16,7 +16,8 @@ concurrency: cancel-in-progress: true env: - CARGO_ARGS: --no-default-features --features stdlib,importlib,stdio,encodings,sqlite,ssl + CARGO_ARGS: --no-default-features --features stdlib,importlib,stdio,encodings,sqlite,ssl-rustls + CARGO_ARGS_NO_SSL: --no-default-features --features stdlib,importlib,stdio,encodings,sqlite # Skip additional tests on Windows. They are checked on Linux and MacOS. # test_glob: many failing tests # test_io: many failing tests @@ -169,7 +170,7 @@ jobs: target: aarch64-apple-ios if: runner.os == 'macOS' - name: Check compilation for iOS - run: cargo check --target aarch64-apple-ios + run: cargo check --target aarch64-apple-ios ${{ env.CARGO_ARGS_NO_SSL }} if: runner.os == 'macOS' exotic_targets: @@ -186,14 +187,14 @@ jobs: - name: Install gcc-multilib and musl-tools run: sudo apt-get update && sudo apt-get install gcc-multilib musl-tools - name: Check compilation for x86 32bit - run: cargo check --target i686-unknown-linux-gnu + run: cargo check --target i686-unknown-linux-gnu ${{ env.CARGO_ARGS_NO_SSL }} - uses: dtolnay/rust-toolchain@stable with: target: aarch64-linux-android - name: Check compilation for android - run: cargo check --target aarch64-linux-android + run: cargo check --target aarch64-linux-android ${{ env.CARGO_ARGS_NO_SSL }} - uses: dtolnay/rust-toolchain@stable with: @@ -202,28 +203,28 @@ jobs: - name: Install gcc-aarch64-linux-gnu run: sudo apt install gcc-aarch64-linux-gnu - name: Check compilation for aarch64 linux gnu - run: cargo check --target aarch64-unknown-linux-gnu + run: cargo check --target aarch64-unknown-linux-gnu ${{ env.CARGO_ARGS_NO_SSL }} - uses: dtolnay/rust-toolchain@stable with: target: i686-unknown-linux-musl - name: Check compilation for musl - run: cargo check --target i686-unknown-linux-musl + run: cargo check --target i686-unknown-linux-musl ${{ env.CARGO_ARGS_NO_SSL }} - uses: dtolnay/rust-toolchain@stable with: target: x86_64-unknown-freebsd - name: Check compilation for freebsd - run: cargo check --target x86_64-unknown-freebsd + run: cargo check --target x86_64-unknown-freebsd ${{ env.CARGO_ARGS_NO_SSL }} - uses: dtolnay/rust-toolchain@stable with: target: x86_64-unknown-freebsd - name: Check compilation for freeBSD - run: cargo check --target x86_64-unknown-freebsd + run: cargo check --target x86_64-unknown-freebsd ${{ env.CARGO_ARGS_NO_SSL }} # - name: Prepare repository for redox compilation # run: bash scripts/redox/uncomment-cargo.sh diff --git a/Cargo.lock b/Cargo.lock index e1b1d958d30..cec9f7afb47 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -14,6 +14,17 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "aae1277d39aeec15cb388266ecc24b11c80469deae6067e17a1a7aa9e5c1f234" +[[package]] +name = "aes" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b169f7a6d4742236a0a00c541b845991d0ac43e546831af1249753ab4c3aa3a0" +dependencies = [ + "cfg-if", + "cipher", + "cpufeatures", +] + [[package]] name = "ahash" version = "0.8.12" @@ -134,6 +145,45 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d92bec98840b8f03a5ff5413de5293bfcd8bf96467cf5452609f939ec6f5de16" +[[package]] +name = "asn1-rs" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5493c3bedbacf7fd7382c6346bbd66687d12bbaad3a89a2d2c303ee6cf20b048" +dependencies = [ + "asn1-rs-derive", + "asn1-rs-impl", + "displaydoc", + "nom", + "num-traits", + "rusticata-macros", + "thiserror 1.0.69", + "time", +] + +[[package]] +name = "asn1-rs-derive" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "965c2d33e53cb6b267e148a4cb0760bc01f4904c1cd4bb4002a085bb016d1490" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "synstructure", +] + +[[package]] +name = "asn1-rs-impl" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b18050c2cd6fe86c3a76584ef5e0baf286d038cda203eb6223df2cc413565f7" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "atomic" version = "0.6.1" @@ -179,12 +229,57 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" +[[package]] +name = "aws-lc-fips-sys" +version = "0.13.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ede71ad84efb06d748d9af3bc500b14957a96282a69a6833b1420dcacb411cc3" +dependencies = [ + "bindgen 0.72.1", + "cc", + "cmake", + "dunce", + "fs_extra", + "regex", +] + +[[package]] +name = "aws-lc-rs" +version = "1.14.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "879b6c89592deb404ba4dc0ae6b58ffd1795c78991cbb5b8bc441c48a070440d" +dependencies = [ + "aws-lc-fips-sys", + "aws-lc-sys", + "untrusted 0.7.1", + "zeroize", +] + +[[package]] +name = "aws-lc-sys" +version = "0.32.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "107a4e9d9cab9963e04e84bb8dee0e25f2a987f9a8bad5ed054abd439caa8f8c" +dependencies = [ + "bindgen 0.72.1", + "cc", + "cmake", + "dunce", + "fs_extra", +] + [[package]] name = "base64" version = "0.22.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" +[[package]] +name = "base64ct" +version = "1.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55248b47b0caf0546f7988906588779981c43bb1bc9d0c44087278f80cdb44ba" + [[package]] name = "bindgen" version = "0.71.1" @@ -205,6 +300,26 @@ dependencies = [ "syn", ] +[[package]] +name = "bindgen" +version = "0.72.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "993776b509cfb49c750f11b8f07a46fa23e0a1386ffc01fb1e7d343efc387895" +dependencies = [ + "bitflags 2.9.4", + "cexpr", + "clang-sys", + "itertools 0.13.0", + "log", + "prettyplease", + "proc-macro2", + "quote", + "regex", + "rustc-hash", + "shlex", + "syn", +] + [[package]] name = "bitflags" version = "1.3.2" @@ -235,6 +350,15 @@ dependencies = [ "generic-array", ] +[[package]] +name = "block-padding" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8894febbff9f758034a5b8e12d87918f56dfc64a8e1fe757d65e29041538d93" +dependencies = [ + "generic-array", +] + [[package]] name = "bstr" version = "1.12.0" @@ -261,6 +385,12 @@ version = "1.24.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1fbdf580320f38b612e485521afda1ee26d10cc9884efaaa750d383e13e3c5f4" +[[package]] +name = "bytes" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b35204fbdc0b3f4446b89fc1ac2cf84a8a68971995d0bf2e925ec7cd960f9cb3" + [[package]] name = "bzip2" version = "0.6.1" @@ -294,6 +424,15 @@ dependencies = [ "rustversion", ] +[[package]] +name = "cbc" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26b52a9543ae338f279b96b0b9fed9c8093744685043739079ce85cd58f289a6" +dependencies = [ + "cipher", +] + [[package]] name = "cc" version = "1.2.41" @@ -301,9 +440,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ac9fe6cdbb24b6ade63616c0a0688e45bb56732262c158df3c0c4bea4ca47cb7" dependencies = [ "find-msvc-tools", + "jobserver", + "libc", "shlex", ] +[[package]] +name = "cesu8" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d43a04d8753f35258c91f8ec639f792891f748a1edbd759cf1dcea3382ad83c" + [[package]] name = "cexpr" version = "0.6.0" @@ -365,6 +512,16 @@ dependencies = [ "half", ] +[[package]] +name = "cipher" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "773f3b9af64447d2ce9850330c473515014aa235e6a783b02db81ff39e4a3dad" +dependencies = [ + "crypto-common", + "inout", +] + [[package]] name = "clang-sys" version = "1.8.1" @@ -410,6 +567,15 @@ dependencies = [ "error-code", ] +[[package]] +name = "cmake" +version = "0.1.54" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7caa3f9de89ddbe2c607f4101924c5abec803763ae9534e4f4d7d8f84aa81f0" +dependencies = [ + "cc", +] + [[package]] name = "collection_literals" version = "1.0.3" @@ -422,6 +588,16 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b05b61dc5112cbb17e4b6cd61790d9845d13888356391624cbe7e41efeac1e75" +[[package]] +name = "combine" +version = "4.6.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba5a308b75df32fe02788e748662718f03fde005016435c444eea572398219fd" +dependencies = [ + "bytes", + "memchr", +] + [[package]] name = "compact_str" version = "0.9.0" @@ -458,6 +634,12 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "const-oid" +version = "0.9.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c2459377285ad874054d797f3ccebf984978aa39129f6eafde5cdc8315b612f8" + [[package]] name = "constant_time_eq" version = "0.4.2" @@ -474,6 +656,16 @@ dependencies = [ "libc", ] +[[package]] +name = "core-foundation" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2a6cd9ae233e7f62ba4e9353e81a88df7fc8a5987b8d445b4d90c879bd156f6" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "core-foundation-sys" version = "0.8.7" @@ -751,6 +943,59 @@ dependencies = [ "memchr", ] +[[package]] +name = "data-encoding" +version = "2.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2a2330da5de22e8a3cb63252ce2abb30116bf5265e89c0e01bc17015ce30a476" + +[[package]] +name = "der" +version = "0.7.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7c1832837b905bbfb5101e07cc24c8deddf52f93225eee6ead5f4d63d53ddcb" +dependencies = [ + "const-oid", + "der_derive", + "flagset", + "pem-rfc7468", + "zeroize", +] + +[[package]] +name = "der-parser" +version = "9.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5cd0a5c643689626bec213c4d8bd4d96acc8ffdb4ad4bb6bc16abf27d5f4b553" +dependencies = [ + "asn1-rs", + "displaydoc", + "nom", + "num-bigint", + "num-traits", + "rusticata-macros", +] + +[[package]] +name = "der_derive" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8034092389675178f570469e6c3b0465d3d30b4505c294a6550db47f3c17ad18" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "deranged" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ececcb659e7ba858fb4f10388c250a7252eb0a27373f1a72b8748afdd248e587" +dependencies = [ + "powerfmt", +] + [[package]] name = "derive-where" version = "1.6.0" @@ -794,6 +1039,17 @@ dependencies = [ "winapi", ] +[[package]] +name = "displaydoc" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "dns-lookup" version = "3.0.0" @@ -806,6 +1062,12 @@ dependencies = [ "windows-sys 0.60.2", ] +[[package]] +name = "dunce" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92773504d58c093f6de2459af4af33faa518c13451eb8f2b5698ed3d36e7c813" + [[package]] name = "dyn-clone" version = "1.0.20" @@ -916,13 +1178,19 @@ version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "52051878f80a721bb68ebfbc930e07b65ba72f2da88968ea5c06fd6ca3d3a127" +[[package]] +name = "flagset" +version = "0.4.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7ac824320a75a52197e8f2d787f6a38b6718bb6897a35142d749af3c0e8f4fe" + [[package]] name = "flame" version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1fc2706461e1ee94f55cab2ed2e3d34ae9536cfa830358ef80acff1a3dacab30" dependencies = [ - "lazy_static", + "lazy_static 0.2.11", "serde", "serde_derive", "serde_json", @@ -984,6 +1252,12 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" +[[package]] +name = "fs_extra" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c" + [[package]] name = "generic-array" version = "0.14.9" @@ -1128,6 +1402,15 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dfa686283ad6dd069f105e5ab091b04c62850d3e4cf5d67debad1933f55023df" +[[package]] +name = "hmac" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e" +dependencies = [ + "digest", +] + [[package]] name = "home" version = "0.5.11" @@ -1177,6 +1460,16 @@ version = "2.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f4c7245a08504955605670dbf141fceab975f15ca21570696aebe9d2e71576bd" +[[package]] +name = "inout" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "879f10e63c20629ecabbb64a8010319738c66a5cd0c29b02d63d272b03751d01" +dependencies = [ + "block-padding", + "generic-array", +] + [[package]] name = "insta" version = "1.43.2" @@ -1260,6 +1553,38 @@ dependencies = [ "syn", ] +[[package]] +name = "jni" +version = "0.21.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a87aa2bb7d2af34197c04845522473242e1aa17c12f4935d5856491a7fb8c97" +dependencies = [ + "cesu8", + "cfg-if", + "combine", + "jni-sys", + "log", + "thiserror 1.0.69", + "walkdir", + "windows-sys 0.45.0", +] + +[[package]] +name = "jni-sys" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8eaf4bc02d17cbdd7ff4c7438cafcdf7fb9a4613313ad11b4f8fefe7d3fa0130" + +[[package]] +name = "jobserver" +version = "0.1.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9afb3de4395d6b3e67a780b6de64b51c978ecf11cb9a462c66be7d4ca9039d33" +dependencies = [ + "getrandom 0.3.4", + "libc", +] + [[package]] name = "js-sys" version = "0.3.81" @@ -1295,6 +1620,12 @@ version = "0.2.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "76f033c7ad61445c5b347c7382dd1237847eb1bce590fe50365dcb33d546be73" +[[package]] +name = "lazy_static" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" + [[package]] name = "lexical-parse-float" version = "1.0.6" @@ -1650,6 +1981,16 @@ dependencies = [ "minimal-lexical", ] +[[package]] +name = "num-bigint" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" +dependencies = [ + "num-integer", + "num-traits", +] + [[package]] name = "num-complex" version = "0.4.6" @@ -1659,6 +2000,12 @@ dependencies = [ "num-traits", ] +[[package]] +name = "num-conv" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" + [[package]] name = "num-integer" version = "0.1.46" @@ -1708,6 +2055,15 @@ dependencies = [ "syn", ] +[[package]] +name = "oid-registry" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8d8034d9489cdaf79228eb9f6a3b8d7bb32ba00d6645ebd48eef4077ceb5bd9" +dependencies = [ + "asn1-rs", +] + [[package]] name = "once_cell" version = "1.21.3" @@ -1825,6 +2181,25 @@ version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" +[[package]] +name = "pbkdf2" +version = "0.12.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8ed6a7761f76e3b9f92dfb0a60a6a6477c61024b775147ff0973a02653abaf2" +dependencies = [ + "digest", + "hmac", +] + +[[package]] +name = "pem-rfc7468" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88b39c9bfcfc231068454382784bb460aae594343fb030d46e9f50a645418412" +dependencies = [ + "base64ct", +] + [[package]] name = "phf" version = "0.11.3" @@ -1919,6 +2294,33 @@ dependencies = [ "siphasher", ] +[[package]] +name = "pkcs5" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e847e2c91a18bfa887dd028ec33f2fe6f25db77db3619024764914affe8b69a6" +dependencies = [ + "aes", + "cbc", + "der", + "pbkdf2", + "scrypt", + "sha2", + "spki", +] + +[[package]] +name = "pkcs8" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f950b2377845cebe5cf8b5165cb3cc1a5e0fa5cfa3e1f7f55707d8fd82e0a7b7" +dependencies = [ + "der", + "pkcs5", + "rand_core 0.6.4", + "spki", +] + [[package]] name = "pkg-config" version = "0.3.32" @@ -1979,6 +2381,12 @@ dependencies = [ "portable-atomic", ] +[[package]] +name = "powerfmt" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" + [[package]] name = "ppv-lite86" version = "0.2.21" @@ -2325,6 +2733,20 @@ dependencies = [ "syn", ] +[[package]] +name = "ring" +version = "0.17.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4689e6c2294d81e88dc6261c768b63bc4fcdb852be6d1352498b114f61383b7" +dependencies = [ + "cc", + "cfg-if", + "getrandom 0.2.16", + "libc", + "untrusted 0.9.0", + "windows-sys 0.52.0", +] + [[package]] name = "ruff_python_ast" version = "0.0.0" @@ -2398,6 +2820,15 @@ version = "2.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d" +[[package]] +name = "rusticata-macros" +version = "4.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "faf0c4a6ece9950b9abdb62b1cfcf2a68b3b67a10ba445b3bb85be2a293d0632" +dependencies = [ + "nom", +] + [[package]] name = "rustix" version = "1.1.2" @@ -2411,6 +2842,89 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "rustls" +version = "0.23.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "533f54bc6a7d4f647e46ad909549eda97bf5afc1585190ef692b4286b198bd8f" +dependencies = [ + "aws-lc-rs", + "once_cell", + "rustls-pki-types", + "rustls-webpki", + "subtle", + "zeroize", +] + +[[package]] +name = "rustls-native-certs" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9980d917ebb0c0536119ba501e90834767bffc3d60641457fd84a1f3fd337923" +dependencies = [ + "openssl-probe", + "rustls-pki-types", + "schannel", + "security-framework", +] + +[[package]] +name = "rustls-pemfile" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dce314e5fee3f39953d46bb63bb8a46d40c2f8fb7cc5a3b6cab2bde9721d6e50" +dependencies = [ + "rustls-pki-types", +] + +[[package]] +name = "rustls-pki-types" +version = "1.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94182ad936a0c91c324cd46c6511b9510ed16af436d7b5bab34beab0afd55f7a" +dependencies = [ + "zeroize", +] + +[[package]] +name = "rustls-platform-verifier" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d99feebc72bae7ab76ba994bb5e121b8d83d910ca40b36e0921f53becc41784" +dependencies = [ + "core-foundation 0.10.1", + "core-foundation-sys", + "jni", + "log", + "once_cell", + "rustls", + "rustls-native-certs", + "rustls-platform-verifier-android", + "rustls-webpki", + "security-framework", + "security-framework-sys", + "webpki-root-certs", + "windows-sys 0.61.2", +] + +[[package]] +name = "rustls-platform-verifier-android" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f87165f0995f63a9fbeea62b64d10b4d9d8e78ec6d7d51fb2125fda7bb36788f" + +[[package]] +name = "rustls-webpki" +version = "0.103.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2ffdfa2f5286e2247234e03f680868ac2815974dc39e00ea15adc445d0aafe52" +dependencies = [ + "aws-lc-rs", + "ring", + "rustls-pki-types", + "untrusted 0.9.0", +] + [[package]] name = "rustpython" version = "0.4.0" @@ -2597,13 +3111,16 @@ dependencies = [ "adler32", "ahash", "ascii", + "aws-lc-rs", "base64", "blake2", "bzip2", "cfg-if", + "chrono", "crc32fast", "crossbeam-utils", "csv-core", + "der", "digest", "dns-lookup", "dyn-clone", @@ -2628,16 +3145,23 @@ dependencies = [ "num-integer", "num-traits", "num_enum", + "oid-registry", "openssl", "openssl-probe", "openssl-sys", "page_size", "parking_lot", "paste", + "pem-rfc7468", "phf 0.11.3", + "pkcs8", "pymath", "rand_core 0.9.3", "rustix", + "rustls", + "rustls-native-certs", + "rustls-pemfile", + "rustls-platform-verifier", "rustpython-common", "rustpython-derive", "rustpython-vm", @@ -2661,8 +3185,11 @@ dependencies = [ "unicode-casing", "unicode_names2 2.0.0", "uuid", + "webpki-roots 0.26.11", "widestring", "windows-sys 0.59.0", + "x509-cert", + "x509-parser", "xml", "xz2", ] @@ -2815,6 +3342,15 @@ dependencies = [ "bytemuck", ] +[[package]] +name = "salsa20" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97a22f5af31f73a954c10289c93e8a50cc23d971e80ee446f1f6f7137a088213" +dependencies = [ + "cipher", +] + [[package]] name = "same-file" version = "1.0.6" @@ -2845,6 +3381,40 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" +[[package]] +name = "scrypt" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0516a385866c09368f0b5bcd1caff3366aace790fcd46e2bb032697bb172fd1f" +dependencies = [ + "pbkdf2", + "salsa20", + "sha2", +] + +[[package]] +name = "security-framework" +version = "3.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b3297343eaf830f66ede390ea39da1d462b6b0c1b000f420d0a83f898bbbe6ef" +dependencies = [ + "bitflags 2.9.4", + "core-foundation 0.10.1", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + +[[package]] +name = "security-framework-sys" +version = "2.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc1f0cbffaac4852523ce30d8bd3c5cdc873501d96ff467ca09b6767bb8cd5c0" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "serde" version = "1.0.228" @@ -2919,6 +3489,17 @@ dependencies = [ "digest", ] +[[package]] +name = "sha1" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + [[package]] name = "sha2" version = "0.10.9" @@ -2945,7 +3526,7 @@ name = "shared-build" version = "0.2.0" source = "git+https://github.com/arihant2math/tkinter.git?tag=v0.2.0#198fc35b1f18f4eda401f97a641908f321b1403a" dependencies = [ - "bindgen", + "bindgen 0.71.1", ] [[package]] @@ -2954,6 +3535,15 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" +[[package]] +name = "signature" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77549399552de45a898a580c1b41d445bf730df867cc44e6c0233bbc4b8329de" +dependencies = [ + "rand_core 0.6.4", +] + [[package]] name = "simd-adler32" version = "0.3.7" @@ -2988,6 +3578,16 @@ dependencies = [ "windows-sys 0.60.2", ] +[[package]] +name = "spki" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d91ed6c858b01f942cd56b37a94b3e0a1798290327d1236e4d9cf4eaca44d29d" +dependencies = [ + "base64ct", + "der", +] + [[package]] name = "stable_deref_trait" version = "1.2.1" @@ -3046,6 +3646,17 @@ dependencies = [ "syn", ] +[[package]] +name = "synstructure" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "728a70f3dbaf5bab7f0c4b1ac8d7ae5ea60a4b5549c8a5914361c99147a709d2" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "system-configuration" version = "0.6.1" @@ -3053,7 +3664,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3c879d448e9d986b661742763247d3693ed13609438cf3d006f51f5368a5ba6b" dependencies = [ "bitflags 2.9.4", - "core-foundation", + "core-foundation 0.9.4", "system-configuration-sys", ] @@ -3157,6 +3768,37 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "time" +version = "0.3.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91e7d9e3bb61134e77bde20dd4825b97c010155709965fedf0f49bb138e52a9d" +dependencies = [ + "deranged", + "itoa", + "num-conv", + "powerfmt", + "serde", + "time-core", + "time-macros", +] + +[[package]] +name = "time-core" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "40868e7c1d2f0b8d73e4a8c7f0ff63af4f6d19be117e90bd73eb1d62cf831c6b" + +[[package]] +name = "time-macros" +version = "0.2.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30cfb0125f12d9c277f35663a0a33f8c30190f4e4574868a330595412d34ebf3" +dependencies = [ + "num-conv", + "time-core", +] + [[package]] name = "timsort" version = "0.1.3" @@ -3197,6 +3839,27 @@ dependencies = [ "shared-build", ] +[[package]] +name = "tls_codec" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0de2e01245e2bb89d6f05801c564fa27624dbd7b1846859876c7dad82e90bf6b" +dependencies = [ + "tls_codec_derive", + "zeroize", +] + +[[package]] +name = "tls_codec_derive" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2d2e76690929402faae40aebdda620a2c0e25dd6d3b9afe48867dfd95991f4bd" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "toml" version = "0.8.23" @@ -3457,6 +4120,18 @@ version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7264e107f553ccae879d21fbea1d6724ac785e8c3bfc762137959b5802826ef3" +[[package]] +name = "untrusted" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a" + +[[package]] +name = "untrusted" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" + [[package]] name = "utf8parse" version = "0.2.2" @@ -3605,6 +4280,33 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "webpki-root-certs" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee3e3b5f5e80bc89f30ce8d0343bf4e5f12341c51f3e26cbeecbc7c85443e85b" +dependencies = [ + "rustls-pki-types", +] + +[[package]] +name = "webpki-roots" +version = "0.26.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "521bc38abb08001b01866da9f51eb7c5d647a19260e00054a8c7fd5f9e57f7a9" +dependencies = [ + "webpki-roots 1.0.4", +] + +[[package]] +name = "webpki-roots" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2878ef029c47c6e8cf779119f20fcf52bde7ad42a731b2a304bc221df17571e" +dependencies = [ + "rustls-pki-types", +] + [[package]] name = "which" version = "8.0.0" @@ -3741,6 +4443,15 @@ dependencies = [ "windows-link", ] +[[package]] +name = "windows-sys" +version = "0.45.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75283be5efb2831d37ea142365f009c02ec203cd29a3ebecbc093d52315b66d0" +dependencies = [ + "windows-targets 0.42.2", +] + [[package]] name = "windows-sys" version = "0.52.0" @@ -3777,6 +4488,21 @@ dependencies = [ "windows-link", ] +[[package]] +name = "windows-targets" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e5180c00cd44c9b1c88adb3693291f1cd93605ded80c250a75d472756b4d071" +dependencies = [ + "windows_aarch64_gnullvm 0.42.2", + "windows_aarch64_msvc 0.42.2", + "windows_i686_gnu 0.42.2", + "windows_i686_msvc 0.42.2", + "windows_x86_64_gnu 0.42.2", + "windows_x86_64_gnullvm 0.42.2", + "windows_x86_64_msvc 0.42.2", +] + [[package]] name = "windows-targets" version = "0.52.6" @@ -3810,6 +4536,12 @@ dependencies = [ "windows_x86_64_msvc 0.53.1", ] +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "597a5118570b68bc08d8d59125332c54f1ba9d9adeedeef5b99b02ba2b0698f8" + [[package]] name = "windows_aarch64_gnullvm" version = "0.52.6" @@ -3822,6 +4554,12 @@ version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a9d8416fa8b42f5c947f8482c43e7d89e73a173cead56d044f6a56104a6d1b53" +[[package]] +name = "windows_aarch64_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e08e8864a60f06ef0d0ff4ba04124db8b0fb3be5776a5cd47641e942e58c4d43" + [[package]] name = "windows_aarch64_msvc" version = "0.52.6" @@ -3834,6 +4572,12 @@ version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b9d782e804c2f632e395708e99a94275910eb9100b2114651e04744e9b125006" +[[package]] +name = "windows_i686_gnu" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c61d927d8da41da96a81f029489353e68739737d3beca43145c8afec9a31a84f" + [[package]] name = "windows_i686_gnu" version = "0.52.6" @@ -3858,6 +4602,12 @@ version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fa7359d10048f68ab8b09fa71c3daccfb0e9b559aed648a8f95469c27057180c" +[[package]] +name = "windows_i686_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44d840b6ec649f480a41c8d80f9c65108b92d89345dd94027bfe06ac444d1060" + [[package]] name = "windows_i686_msvc" version = "0.52.6" @@ -3870,6 +4620,12 @@ version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1e7ac75179f18232fe9c285163565a57ef8d3c89254a30685b57d83a38d326c2" +[[package]] +name = "windows_x86_64_gnu" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8de912b8b8feb55c064867cf047dda097f92d51efad5b491dfb98f6bbb70cb36" + [[package]] name = "windows_x86_64_gnu" version = "0.52.6" @@ -3882,6 +4638,12 @@ version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c3842cdd74a865a8066ab39c8a7a473c0778a3f29370b5fd6b4b9aa7df4a499" +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26d41b46a36d453748aedef1486d5c7a85db22e56aff34643984ea85514e94a3" + [[package]] name = "windows_x86_64_gnullvm" version = "0.52.6" @@ -3894,6 +4656,12 @@ version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0ffa179e2d07eee8ad8f57493436566c7cc30ac536a3379fdf008f47f6bb7ae1" +[[package]] +name = "windows_x86_64_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9aec5da331524158c6d1a4ac0ab1541149c0b9505fde06423b02f5ef0106b9f0" + [[package]] name = "windows_x86_64_msvc" version = "0.52.6" @@ -3947,6 +4715,37 @@ version = "0.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f17a85883d4e6d00e8a97c586de764dabcc06133f7f1d55dce5cdc070ad7fe59" +[[package]] +name = "x509-cert" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1301e935010a701ae5f8655edc0ad17c44bad3ac5ce8c39185f75453b720ae94" +dependencies = [ + "const-oid", + "der", + "sha1", + "signature", + "spki", + "tls_codec", +] + +[[package]] +name = "x509-parser" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fcbc162f30700d6f3f82a24bf7cc62ffe7caea42c0b2cba8bf7f3ae50cf51f69" +dependencies = [ + "asn1-rs", + "data-encoding", + "der-parser", + "lazy_static 1.5.0", + "nom", + "oid-registry", + "rusticata-macros", + "thiserror 1.0.69", + "time", +] + [[package]] name = "xml" version = "1.0.1" @@ -3982,6 +4781,26 @@ dependencies = [ "syn", ] +[[package]] +name = "zeroize" +version = "1.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b97154e67e32c85465826e8bcc1c59429aaaf107c1e4a9e53c8d8ccd5eff88d0" +dependencies = [ + "zeroize_derive", +] + +[[package]] +name = "zeroize_derive" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "zlib-rs" version = "0.5.2" diff --git a/Cargo.toml b/Cargo.toml index 8d344c133c8..37bc5d146c9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,7 +10,7 @@ repository.workspace = true license.workspace = true [features] -default = ["threading", "stdlib", "stdio", "importlib"] +default = ["threading", "stdlib", "stdio", "importlib", "ssl-rustls"] importlib = ["rustpython-vm/importlib"] encodings = ["rustpython-vm/encodings"] stdio = ["rustpython-vm/stdio"] @@ -20,8 +20,10 @@ freeze-stdlib = ["stdlib", "rustpython-vm/freeze-stdlib", "rustpython-pylib?/fre jit = ["rustpython-vm/jit"] threading = ["rustpython-vm/threading", "rustpython-stdlib/threading"] sqlite = ["rustpython-stdlib/sqlite"] -ssl = ["rustpython-stdlib/ssl"] -ssl-vendor = ["ssl", "rustpython-stdlib/ssl-vendor"] +ssl = [] +ssl-rustls = ["ssl", "rustpython-stdlib/ssl-rustls"] +ssl-openssl = ["ssl", "rustpython-stdlib/ssl-openssl"] +ssl-vendor = ["ssl-openssl", "rustpython-stdlib/ssl-vendor"] tkinter = ["rustpython-stdlib/tkinter"] [build-dependencies] diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py index fea3a2ce692..f073def5bc1 100644 --- a/Lib/test/test_ssl.py +++ b/Lib/test/test_ssl.py @@ -426,7 +426,6 @@ def test_random(self): ssl.RAND_add(b"this is a random bytes object", 75.0) ssl.RAND_add(bytearray(b"this is a random bytearray object"), 75.0) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_parse_cert(self): # note that this uses an 'unofficial' function in _ssl.c, # provided solely for this test, to exercise the certificate @@ -506,7 +505,6 @@ def test_parse_cert_CVE_2013_4238(self): self.assertEqual(p['subjectAltName'], san) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_parse_all_sans(self): p = ssl._ssl._test_decode_cert(ALLSANFILE) self.assertEqual(p['subjectAltName'], @@ -927,7 +925,6 @@ def test_connect_ex_error(self): ) self.assertIn(rc, errors) - @unittest.skip("TODO: RUSTPYTHON; hangs") def test_read_write_zero(self): # empty reads and writes now work, bpo-42854, bpo-31711 client_context, server_context, hostname = testing_context() @@ -993,7 +990,6 @@ def test_get_ciphers(self): len(intersection), 2, f"\ngot: {sorted(names)}\nexpected: {sorted(expected)}" ) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_options(self): # Test default SSLContext options ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) @@ -1066,8 +1062,8 @@ def test_hostname_checks_common_name(self): with self.assertRaises(AttributeError): ctx.hostname_checks_common_name = True - @ignore_deprecation @unittest.expectedFailure # TODO: RUSTPYTHON + @ignore_deprecation def test_min_max_version(self): ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) # OpenSSL default is MINIMUM_SUPPORTED, however some vendors like @@ -1185,7 +1181,6 @@ def test_verify_flags(self): with self.assertRaises(TypeError): ctx.verify_flags = None - @unittest.expectedFailure # TODO: RUSTPYTHON def test_load_cert_chain(self): ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) # Combined key and cert in a single file @@ -1294,7 +1289,6 @@ def race(): self.assertIsNone(cm.exc_value) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_load_verify_locations(self): ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) ctx.load_verify_locations(CERTFILE) @@ -1314,7 +1308,6 @@ def test_load_verify_locations(self): # Issue #10989: crash if the second argument type is invalid self.assertRaises(TypeError, ctx.load_verify_locations, None, True) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_load_verify_cadata(self): # test cadata with open(CAFILE_CACERT) as f: @@ -1380,7 +1373,6 @@ def test_load_verify_cadata(self): with self.assertRaises(ssl.SSLError): ctx.load_verify_locations(cadata=cacert_der + b"A") - @unittest.expectedFailure # TODO: RUSTPYTHON def test_load_dh_params(self): ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) try: @@ -1473,7 +1465,6 @@ def test_cert_store_stats(self): self.assertEqual(ctx.cert_store_stats(), {'x509_ca': 1, 'crl': 0, 'x509': 2}) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_get_ca_certs(self): ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) self.assertEqual(ctx.get_ca_certs(), []) @@ -1732,7 +1723,6 @@ def test_lib_reason(self): s = str(cm.exception) self.assertTrue("NO_START_LINE" in s, s) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_subclass(self): # Check that the appropriate SSLError subclass is raised # (this only tests one of them) @@ -1751,7 +1741,6 @@ def test_subclass(self): self.assertEqual(cm.exception.errno, ssl.SSL_ERROR_WANT_READ) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_bad_server_hostname(self): ctx = ssl.create_default_context() with self.assertRaises(ValueError): @@ -1838,7 +1827,6 @@ def test_private_init(self): with self.assertRaisesRegex(TypeError, "public constructor"): ssl.SSLObject(bio, bio) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_unwrap(self): client_ctx, server_ctx, hostname = testing_context() c_in = ssl.MemoryBIO() @@ -2193,7 +2181,6 @@ def ssl_io_loop(self, sock, incoming, outgoing, func, *args, **kwargs): % (count, func.__name__)) return ret - @unittest.expectedFailure # TODO: RUSTPYTHON def test_bio_handshake(self): sock = socket.socket(socket.AF_INET) self.addCleanup(sock.close) @@ -2230,7 +2217,6 @@ def test_bio_handshake(self): pass self.assertRaises(ssl.SSLError, sslobj.write, b'foo') - @unittest.expectedFailure # TODO: RUSTPYTHON def test_bio_read_write_data(self): sock = socket.socket(socket.AF_INET) self.addCleanup(sock.close) @@ -2248,7 +2234,6 @@ def test_bio_read_write_data(self): self.assertEqual(buf, b'foo\n') self.ssl_io_loop(sock, incoming, outgoing, sslobj.unwrap) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_transport_eof(self): client_context, server_context, hostname = testing_context() with socket.socket(socket.AF_INET) as sock: @@ -3565,7 +3550,6 @@ def test_socketserver(self): f.close() self.assertEqual(d1, d2) - @unittest.skip("TODO: RUSTPYTHON; hangs") def test_asyncore_server(self): """Check the example asyncore integration.""" if support.verbose: @@ -3595,7 +3579,6 @@ def test_asyncore_server(self): if support.verbose: sys.stdout.write(" client: connection closed.\n") - @unittest.skip("TODO: RUSTPYTHON; hangs") def test_recv_send(self): """Test recv(), send() and friends.""" if support.verbose: @@ -3732,7 +3715,6 @@ def _recvfrom_into(): s.close() - @unittest.expectedFailure # TODO: RUSTPYTHON def test_recv_zero(self): server = ThreadedEchoServer(CERTFILE) self.enterContext(server) @@ -4040,6 +4022,7 @@ def test_default_ecdh_curve(self): s.connect((HOST, server.port)) self.assertIn("ECDH", s.cipher()[0]) + @unittest.expectedFailure # TODO: RUSTPYTHON @unittest.skipUnless("tls-unique" in ssl.CHANNEL_BINDING_TYPES, "'tls-unique' channel binding not available") def test_tls_unique_channel_binding(self): @@ -4212,7 +4195,6 @@ def test_selected_alpn_protocol_if_server_uses_alpn(self): sni_name=hostname) self.assertIs(stats['client_alpn_protocol'], None) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_alpn_protocols(self): server_protocols = ['foo', 'bar', 'milkshake'] protocol_tests = [ @@ -4263,7 +4245,6 @@ def check_common_name(self, stats, name): cert = stats['peercert'] self.assertIn((('commonName', name),), cert['subject']) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_sni_callback(self): calls = [] server_context, other_context, client_context = self.sni_contexts() @@ -4514,7 +4495,6 @@ def test_session_handling(self): 'Session refers to a different SSLContext.') @requires_tls_version('TLSv1_2') - @unittest.expectedFailure # TODO: RUSTPYTHON @unittest.skipUnless(ssl.HAS_PSK, 'TLS-PSK disabled on this OpenSSL build') def test_psk(self): psk = bytes.fromhex('deadbeef') @@ -4583,7 +4563,6 @@ def server_callback(identity): s.connect((HOST, server.port)) @requires_tls_version('TLSv1_3') - @unittest.expectedFailure # TODO: RUSTPYTHON @unittest.skipUnless(ssl.HAS_PSK, 'TLS-PSK disabled on this OpenSSL build') def test_psk_tls1_3(self): psk = bytes.fromhex('deadbeef') @@ -4616,6 +4595,43 @@ def server_callback(identity): with client_context.wrap_socket(socket.socket()) as s: s.connect((HOST, server.port)) + @unittest.skip("TODO: rustpython") + def test_thread_recv_while_main_thread_sends(self): + # GH-137583: Locking was added to calls to send() and recv() on SSL + # socket objects. This seemed fine at the surface level because those + # calls weren't re-entrant, but recv() calls would implicitly mimick + # holding a lock by blocking until it received data. This means that + # if a thread started to infinitely block until data was received, calls + # to send() would deadlock, because it would wait forever on the lock + # that the recv() call held. + data = b"1" * 1024 + event = threading.Event() + def background(sock): + event.set() + received = sock.recv(len(data)) + self.assertEqual(received, data) + + client_context, server_context, hostname = testing_context() + server = ThreadedEchoServer(context=server_context) + with server: + with client_context.wrap_socket(socket.socket(), + server_hostname=hostname) as sock: + sock.connect((HOST, server.port)) + sock.settimeout(1) + sock.setblocking(1) + # Ensure that the server is ready to accept requests + sock.sendall(b"123") + self.assertEqual(sock.recv(3), b"123") + with threading_helper.catch_threading_exception() as cm: + thread = threading.Thread(target=background, + args=(sock,), daemon=True) + thread.start() + event.wait() + sock.sendall(data) + thread.join() + if cm.exc_value is not None: + raise cm.exc_value + @unittest.skipUnless(has_tls_version('TLSv1_3'), "Test needs TLS 1.3") class TestPostHandshakeAuth(unittest.TestCase): @@ -4736,6 +4752,7 @@ def test_pha_optional(self): s.write(b'HASCERT') self.assertEqual(s.recv(1024), b'TRUE\n') + @unittest.expectedFailure # TODO: RUSTPYTHON def test_pha_optional_nocert(self): if support.verbose: sys.stdout.write("\n") @@ -4775,6 +4792,7 @@ def test_pha_no_pha_client(self): s.write(b'PHA') self.assertIn(b'extension not received', s.recv(1024)) + @unittest.expectedFailure # TODO: RUSTPYTHON def test_pha_no_pha_server(self): # server doesn't have PHA enabled, cert is requested in handshake client_context, server_context, hostname = testing_context() @@ -4844,7 +4862,6 @@ def test_bpo37428_pha_cert_none(self): # server cert has not been validated self.assertEqual(s.getpeercert(), {}) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_internal_chain_client(self): client_context, server_context, hostname = testing_context( server_chain=False @@ -4916,7 +4933,6 @@ def test_certificate_chain(self): self.assertEqual(ee, uvc[0]) self.assertNotEqual(ee, ca) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_internal_chain_server(self): client_context, server_context, hostname = testing_context() client_context.load_cert_chain(SIGNED_CERTFILE) @@ -5040,7 +5056,6 @@ def test_keylog_env(self): ctx = ssl._create_stdlib_context() self.assertEqual(ctx.keylog_filename, os_helper.TESTFN) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_msg_callback(self): client_context, server_context, hostname = testing_context() @@ -5085,7 +5100,6 @@ def msg_cb(conn, direction, version, content_type, msg_type, data): msg ) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_msg_callback_deadlock_bpo43577(self): client_context, server_context, hostname = testing_context() server_context2 = testing_context()[1] diff --git a/README.md b/README.md index ce5f02bee23..86d0738ec8e 100644 --- a/README.md +++ b/README.md @@ -66,17 +66,11 @@ Welcome to the magnificent Rust Python interpreter >>>>> ``` -If you'd like to make https requests, you can enable the `ssl` feature, which -also lets you install the `pip` package manager. Note that on Windows, you may -need to install OpenSSL, or you can enable the `ssl-vendor` feature instead, -which compiles OpenSSL for you but requires a C compiler, perl, and `make`. -OpenSSL version 3 is expected and tested in CI. Older versions may not work. - -Once you've installed rustpython with SSL support, you can install pip by +You can install pip by running: ```bash -cargo install --git https://github.com/RustPython/RustPython --features ssl +cargo install --git https://github.com/RustPython/RustPython rustpython --install-pip ``` @@ -88,6 +82,13 @@ conda install rustpython -c conda-forge rustpython ``` +### SSL provider + +For HTTPS requests, `ssl-rustls` feature is enabled by default. You can replace it with `ssl-openssl` feature if your environment requires OpenSSL. +Note that to use OpenSSL on Windows, you may need to install OpenSSL, or you can enable the `ssl-vendor` feature instead, +which compiles OpenSSL for you but requires a C compiler, perl, and `make`. +OpenSSL version 3 is expected and tested in CI. Older versions may not work. + ### WASI You can compile RustPython to a standalone WebAssembly WASI module so it can run anywhere. diff --git a/crates/vm/src/frame.rs b/crates/vm/src/frame.rs index 89d19f15c8a..ab90f52b2df 100644 --- a/crates/vm/src/frame.rs +++ b/crates/vm/src/frame.rs @@ -520,7 +520,7 @@ impl ExecutingFrame<'_> { trace!(" {:#?}", self); trace!( " Executing op code: {}", - instruction.display(arg, &self.code.code).to_string() + instruction.display(arg, &self.code.code) ); trace!("======="); } diff --git a/src/lib.rs b/src/lib.rs index 362adfba490..5d0537818a6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -59,6 +59,14 @@ pub use rustpython_vm as vm; pub use settings::{InstallPipMode, RunMode, parse_opts}; pub use shell::run_shell; +#[cfg(all( + feature = "ssl", + not(any(feature = "ssl-rustls", feature = "ssl-openssl")) +))] +compile_error!( + "Feature \"ssl\" is now enabled by either \"ssl-rustls\" or \"ssl-openssl\" to be enabled. Do not manually pass \"ssl\" feature. To enable ssl-openssl, use --no-default-features to disable ssl-rustls" +); + /// The main cli of the `rustpython` interpreter. This function will return `std::process::ExitCode` /// based on the return code of the python code ran through the cli. pub fn run(init: impl FnOnce(&mut VirtualMachine) + 'static) -> ExitCode { @@ -141,7 +149,7 @@ __import__("io").TextIOWrapper( } fn install_pip(installer: InstallPipMode, scope: Scope, vm: &VirtualMachine) -> PyResult<()> { - if cfg!(not(feature = "ssl")) { + if !cfg!(feature = "ssl") { return Err(vm.new_exception_msg( vm.ctx.exceptions.system_error.to_owned(), "install-pip requires rustpython be build with '--features=ssl'".to_owned(), diff --git a/stdlib/Cargo.toml b/stdlib/Cargo.toml index 7f64802d352..e62872324ea 100644 --- a/stdlib/Cargo.toml +++ b/stdlib/Cargo.toml @@ -15,8 +15,12 @@ default = ["compiler"] compiler = ["rustpython-vm/compiler"] threading = ["rustpython-common/threading", "rustpython-vm/threading"] sqlite = ["dep:libsqlite3-sys"] -ssl = ["openssl", "openssl-sys", "foreign-types-shared", "openssl-probe"] -ssl-vendor = ["ssl", "openssl/vendored"] +# SSL backends - default to rustls +ssl = [] +ssl-rustls = ["ssl", "rustls", "rustls-native-certs", "rustls-pemfile", "rustls-platform-verifier", "x509-cert", "x509-parser", "der", "pem-rfc7468", "webpki-roots", "aws-lc-rs", "oid-registry", "pkcs8"] +ssl-rustls-fips = ["ssl-rustls", "aws-lc-rs/fips"] +ssl-openssl = ["ssl", "openssl", "openssl-sys", "foreign-types-shared", "openssl-probe"] +ssl-vendor = ["ssl-openssl", "openssl/vendored"] tkinter = ["dep:tk-sys", "dep:tcl-sys"] [dependencies] @@ -86,6 +90,7 @@ bzip2 = "0.6" # tkinter tk-sys = { git = "https://github.com/arihant2math/tkinter.git", tag = "v0.2.0", optional = true } tcl-sys = { git = "https://github.com/arihant2math/tkinter.git", tag = "v0.2.0", optional = true } +chrono.workspace = true # uuid [target.'cfg(not(any(target_os = "ios", target_os = "android", target_os = "windows", target_arch = "wasm32", target_os = "redox")))'.dependencies] @@ -107,11 +112,27 @@ rustix = { workspace = true } gethostname = "1.0.2" socket2 = { version = "0.6.0", features = ["all"] } dns-lookup = "3.0" + +# OpenSSL dependencies (optional, for ssl-openssl feature) openssl = { version = "0.10.72", optional = true } openssl-sys = { version = "0.9.110", optional = true } openssl-probe = { version = "0.1.5", optional = true } foreign-types-shared = { version = "0.1.1", optional = true } +# Rustls dependencies (optional, for ssl-rustls feature) +rustls = { version = "0.23.35", default-features = false, features = ["std", "tls12", "aws_lc_rs"], optional = true } +rustls-native-certs = { version = "0.8", optional = true } +rustls-pemfile = { version = "2.2", optional = true } +rustls-platform-verifier = { version = "0.6", optional = true } +x509-cert = { version = "0.2.5", features = ["pem", "builder"], optional = true } +x509-parser = { version = "0.16", optional = true } +der = { version = "0.7", features = ["alloc", "oid"], optional = true } +pem-rfc7468 = { version = "0.7", optional = true } +webpki-roots = { version = "0.26", optional = true } +aws-lc-rs = { version = "1.14.1", optional = true } +oid-registry = { version = "0.7", features = ["x509", "pkcs1", "nist_algs"], optional = true } +pkcs8 = { version = "0.10", features = ["encryption", "pkcs5", "pem"], optional = true } + [target.'cfg(not(any(target_os = "android", target_arch = "wasm32")))'.dependencies] libsqlite3-sys = { version = "0.28", features = ["bundled"], optional = true } lzma-sys = "0.1" diff --git a/stdlib/build.rs b/stdlib/build.rs index b7bf6307157..83ebd81ead6 100644 --- a/stdlib/build.rs +++ b/stdlib/build.rs @@ -23,25 +23,28 @@ fn main() { println!("cargo::rustc-check-cfg=cfg({cfg})"); } - #[allow(clippy::unusual_byte_groupings)] - if let Ok(v) = std::env::var("DEP_OPENSSL_VERSION_NUMBER") { - println!("cargo:rustc-env=OPENSSL_API_VERSION={v}"); - // cfg setup from openssl crate's build script - let version = u64::from_str_radix(&v, 16).unwrap(); - for (ver, cfg) in ossl_vers { - if version >= ver { - println!("cargo:rustc-cfg={cfg}"); + #[cfg(feature = "ssl-openssl")] + { + #[allow(clippy::unusual_byte_groupings)] + if let Ok(v) = std::env::var("DEP_OPENSSL_VERSION_NUMBER") { + println!("cargo:rustc-env=OPENSSL_API_VERSION={v}"); + // cfg setup from openssl crate's build script + let version = u64::from_str_radix(&v, 16).unwrap(); + for (ver, cfg) in ossl_vers { + if version >= ver { + println!("cargo:rustc-cfg={cfg}"); + } } } - } - if let Ok(v) = std::env::var("DEP_OPENSSL_CONF") { - for conf in v.split(',') { - println!("cargo:rustc-cfg=osslconf=\"{conf}\""); + if let Ok(v) = std::env::var("DEP_OPENSSL_CONF") { + for conf in v.split(',') { + println!("cargo:rustc-cfg=osslconf=\"{conf}\""); + } + } + // it's possible for openssl-sys to link against the system openssl under certain conditions, + // so let the ssl module know to only perform a probe if we're actually vendored + if std::env::var("DEP_OPENSSL_VENDORED").is_ok_and(|s| s == "1") { + println!("cargo::rustc-cfg=openssl_vendored") } - } - // it's possible for openssl-sys to link against the system openssl under certain conditions, - // so let the ssl module know to only perform a probe if we're actually vendored - if std::env::var("DEP_OPENSSL_VENDORED").is_ok_and(|s| s == "1") { - println!("cargo::rustc-cfg=openssl_vendored") } } diff --git a/stdlib/src/lib.rs b/stdlib/src/lib.rs index 706ce0ef210..01a27b76609 100644 --- a/stdlib/src/lib.rs +++ b/stdlib/src/lib.rs @@ -75,8 +75,14 @@ mod select; not(any(target_os = "android", target_arch = "wasm32")) ))] mod sqlite; -#[cfg(all(not(target_arch = "wasm32"), feature = "ssl"))] + +#[cfg(all(not(target_arch = "wasm32"), feature = "ssl-openssl"))] +mod openssl; +#[cfg(all(not(target_arch = "wasm32"), feature = "ssl-rustls"))] mod ssl; +#[cfg(all(feature = "ssl-openssl", feature = "ssl-rustls"))] +compile_error!("features \"ssl-openssl\" and \"ssl-rustls\" are mutually exclusive"); + #[cfg(all(unix, not(target_os = "redox"), not(target_os = "ios")))] mod termios; #[cfg(not(any( @@ -167,10 +173,14 @@ pub fn get_module_inits() -> impl Iterator, StdlibInit { "_sqlite3" => sqlite::make_module, } - #[cfg(feature = "ssl")] + #[cfg(all(not(target_arch = "wasm32"), feature = "ssl-rustls"))] { "_ssl" => ssl::make_module, } + #[cfg(all(not(target_arch = "wasm32"), feature = "ssl-openssl"))] + { + "_ssl" => openssl::make_module, + } #[cfg(windows)] { "_overlapped" => overlapped::make_module, diff --git a/stdlib/src/openssl.rs b/stdlib/src/openssl.rs new file mode 100644 index 00000000000..ea67d605f76 --- /dev/null +++ b/stdlib/src/openssl.rs @@ -0,0 +1,3705 @@ +// spell-checker:disable + +mod cert; + +// Conditional compilation for OpenSSL version-specific error codes +cfg_if::cfg_if! { + if #[cfg(ossl310)] { + // OpenSSL 3.1.0+ + mod ssl_data_31; + use ssl_data_31 as ssl_data; + } else if #[cfg(ossl300)] { + // OpenSSL 3.0.0+ + mod ssl_data_300; + use ssl_data_300 as ssl_data; + } else { + // OpenSSL 1.1.1+ (fallback) + mod ssl_data_111; + use ssl_data_111 as ssl_data; + } +} + +use crate::vm::{PyRef, VirtualMachine, builtins::PyModule}; +use openssl_probe::ProbeResult; + +pub(crate) fn make_module(vm: &VirtualMachine) -> PyRef { + // if openssl is vendored, it doesn't know the locations + // of system certificates - cache the probe result now. + #[cfg(openssl_vendored)] + LazyLock::force(&PROBE); + _ssl::make_module(vm) +} + +// define our own copy of ProbeResult so we can handle the vendor case +// easily, without having to have a bunch of cfgs +cfg_if::cfg_if! { + if #[cfg(openssl_vendored)] { + use std::sync::LazyLock; + static PROBE: LazyLock = LazyLock::new(openssl_probe::probe); + fn probe() -> &'static ProbeResult { &PROBE } + } else { + fn probe() -> &'static ProbeResult { + &ProbeResult { cert_file: None, cert_dir: None } + } + } +} + +#[allow(non_upper_case_globals)] +#[pymodule(with(cert::ssl_cert, ossl101, ossl111, windows))] +mod _ssl { + use super::{bio, probe}; + use crate::{ + common::lock::{ + PyMappedRwLockReadGuard, PyMutex, PyRwLock, PyRwLockReadGuard, PyRwLockWriteGuard, + }, + socket::{self, PySocket}, + vm::{ + AsObject, Py, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, + builtins::{ + PyBaseExceptionRef, PyBytesRef, PyListRef, PyOSError, PyStrRef, PyTypeRef, PyWeak, + }, + class_or_notimplemented, + convert::ToPyException, + exceptions, + function::{ + ArgBytesLike, ArgCallable, ArgMemoryBuffer, ArgStrOrBytesLike, Either, FsPath, + OptionalArg, PyComparisonValue, + }, + types::{Comparable, Constructor, PyComparisonOp}, + utils::ToCString, + }, + }; + use crossbeam_utils::atomic::AtomicCell; + use foreign_types_shared::{ForeignType, ForeignTypeRef}; + use openssl::{ + asn1::{Asn1Object, Asn1ObjectRef}, + error::ErrorStack, + nid::Nid, + ssl::{self, SslContextBuilder, SslOptions, SslVerifyMode}, + x509::X509, + }; + use openssl_sys as sys; + use rustpython_vm::ospath::OsPath; + use std::{ + ffi::CStr, + fmt, + io::{Read, Write}, + path::{Path, PathBuf}, + sync::LazyLock, + time::Instant, + }; + + // Import certificate types from parent module + use super::cert::{self, cert_to_certificate, cert_to_py}; + + // Re-export PySSLCertificate to make it available in the _ssl module + // It will be automatically exposed to Python via #[pyclass] + #[allow(unused_imports)] + use super::cert::PySSLCertificate; + + // Constants + #[pyattr] + use sys::{ + // SSL Alert Descriptions that are exported by openssl_sys + SSL_AD_DECODE_ERROR, + SSL_AD_ILLEGAL_PARAMETER, + SSL_AD_UNRECOGNIZED_NAME, + // SSL_ERROR_INVALID_ERROR_CODE, + SSL_ERROR_SSL, + // SSL_ERROR_WANT_X509_LOOKUP, + SSL_ERROR_SYSCALL, + SSL_ERROR_WANT_CONNECT, + SSL_ERROR_WANT_READ, + SSL_ERROR_WANT_WRITE, + SSL_ERROR_ZERO_RETURN, + SSL_OP_CIPHER_SERVER_PREFERENCE as OP_CIPHER_SERVER_PREFERENCE, + SSL_OP_ENABLE_MIDDLEBOX_COMPAT as OP_ENABLE_MIDDLEBOX_COMPAT, + SSL_OP_LEGACY_SERVER_CONNECT as OP_LEGACY_SERVER_CONNECT, + SSL_OP_NO_SSLv2 as OP_NO_SSLv2, + SSL_OP_NO_SSLv3 as OP_NO_SSLv3, + SSL_OP_NO_TICKET as OP_NO_TICKET, + SSL_OP_NO_TLSv1 as OP_NO_TLSv1, + SSL_OP_SINGLE_DH_USE as OP_SINGLE_DH_USE, + SSL_OP_SINGLE_ECDH_USE as OP_SINGLE_ECDH_USE, + X509_V_FLAG_ALLOW_PROXY_CERTS as VERIFY_ALLOW_PROXY_CERTS, + X509_V_FLAG_CRL_CHECK as VERIFY_CRL_CHECK_LEAF, + X509_V_FLAG_PARTIAL_CHAIN as VERIFY_X509_PARTIAL_CHAIN, + X509_V_FLAG_TRUSTED_FIRST as VERIFY_X509_TRUSTED_FIRST, + X509_V_FLAG_X509_STRICT as VERIFY_X509_STRICT, + }; + + // SSL Alert Descriptions (RFC 5246 and extensions) + // Hybrid approach: use openssl_sys constants where available, hardcode others + #[pyattr] + const ALERT_DESCRIPTION_CLOSE_NOTIFY: libc::c_int = 0; + #[pyattr] + const ALERT_DESCRIPTION_UNEXPECTED_MESSAGE: libc::c_int = 10; + #[pyattr] + const ALERT_DESCRIPTION_BAD_RECORD_MAC: libc::c_int = 20; + #[pyattr] + const ALERT_DESCRIPTION_RECORD_OVERFLOW: libc::c_int = 22; + #[pyattr] + const ALERT_DESCRIPTION_DECOMPRESSION_FAILURE: libc::c_int = 30; + #[pyattr] + const ALERT_DESCRIPTION_HANDSHAKE_FAILURE: libc::c_int = 40; + #[pyattr] + const ALERT_DESCRIPTION_BAD_CERTIFICATE: libc::c_int = 42; + #[pyattr] + const ALERT_DESCRIPTION_UNSUPPORTED_CERTIFICATE: libc::c_int = 43; + #[pyattr] + const ALERT_DESCRIPTION_CERTIFICATE_REVOKED: libc::c_int = 44; + #[pyattr] + const ALERT_DESCRIPTION_CERTIFICATE_EXPIRED: libc::c_int = 45; + #[pyattr] + const ALERT_DESCRIPTION_CERTIFICATE_UNKNOWN: libc::c_int = 46; + #[pyattr] + const ALERT_DESCRIPTION_ILLEGAL_PARAMETER: libc::c_int = SSL_AD_ILLEGAL_PARAMETER; + #[pyattr] + const ALERT_DESCRIPTION_UNKNOWN_CA: libc::c_int = 48; + #[pyattr] + const ALERT_DESCRIPTION_ACCESS_DENIED: libc::c_int = 49; + #[pyattr] + const ALERT_DESCRIPTION_DECODE_ERROR: libc::c_int = SSL_AD_DECODE_ERROR; + #[pyattr] + const ALERT_DESCRIPTION_DECRYPT_ERROR: libc::c_int = 51; + #[pyattr] + const ALERT_DESCRIPTION_PROTOCOL_VERSION: libc::c_int = 70; + #[pyattr] + const ALERT_DESCRIPTION_INSUFFICIENT_SECURITY: libc::c_int = 71; + #[pyattr] + const ALERT_DESCRIPTION_INTERNAL_ERROR: libc::c_int = 80; + #[pyattr] + const ALERT_DESCRIPTION_USER_CANCELLED: libc::c_int = 90; + #[pyattr] + const ALERT_DESCRIPTION_NO_RENEGOTIATION: libc::c_int = 100; + #[pyattr] + const ALERT_DESCRIPTION_UNSUPPORTED_EXTENSION: libc::c_int = 110; + #[pyattr] + const ALERT_DESCRIPTION_CERTIFICATE_UNOBTAINABLE: libc::c_int = 111; + #[pyattr] + const ALERT_DESCRIPTION_UNRECOGNIZED_NAME: libc::c_int = SSL_AD_UNRECOGNIZED_NAME; + #[pyattr] + const ALERT_DESCRIPTION_BAD_CERTIFICATE_STATUS_RESPONSE: libc::c_int = 113; + #[pyattr] + const ALERT_DESCRIPTION_BAD_CERTIFICATE_HASH_VALUE: libc::c_int = 114; + #[pyattr] + const ALERT_DESCRIPTION_UNKNOWN_PSK_IDENTITY: libc::c_int = 115; + + // CRL verification constants + #[pyattr] + const VERIFY_CRL_CHECK_CHAIN: libc::c_ulong = + sys::X509_V_FLAG_CRL_CHECK | sys::X509_V_FLAG_CRL_CHECK_ALL; + + // taken from CPython, should probably be kept up to date with their version if it ever changes + #[pyattr] + const _DEFAULT_CIPHERS: &str = + "DEFAULT:!aNULL:!eNULL:!MD5:!3DES:!DES:!RC4:!IDEA:!SEED:!aDSS:!SRP:!PSK"; + // #[pyattr] PROTOCOL_SSLv2: u32 = SslVersion::Ssl2 as u32; // unsupported + // #[pyattr] PROTOCOL_SSLv3: u32 = SslVersion::Ssl3 as u32; + #[pyattr] + const PROTOCOL_SSLv23: u32 = SslVersion::Tls as u32; + #[pyattr] + const PROTOCOL_TLS: u32 = SslVersion::Tls as u32; + #[pyattr] + const PROTOCOL_TLS_CLIENT: u32 = SslVersion::TlsClient as u32; + #[pyattr] + const PROTOCOL_TLS_SERVER: u32 = SslVersion::TlsServer as u32; + #[pyattr] + const PROTOCOL_TLSv1: u32 = SslVersion::Tls1 as u32; + #[pyattr] + const PROTOCOL_TLSv1_1: u32 = SslVersion::Tls1_1 as u32; + #[pyattr] + const PROTOCOL_TLSv1_2: u32 = SslVersion::Tls1_2 as u32; + #[pyattr] + const PROTO_MINIMUM_SUPPORTED: i32 = ProtoVersion::MinSupported as i32; + #[pyattr] + const PROTO_SSLv3: i32 = ProtoVersion::Ssl3 as i32; + #[pyattr] + const PROTO_TLSv1: i32 = ProtoVersion::Tls1 as i32; + #[pyattr] + const PROTO_TLSv1_1: i32 = ProtoVersion::Tls1_1 as i32; + #[pyattr] + const PROTO_TLSv1_2: i32 = ProtoVersion::Tls1_2 as i32; + #[pyattr] + const PROTO_TLSv1_3: i32 = ProtoVersion::Tls1_3 as i32; + #[pyattr] + const PROTO_MAXIMUM_SUPPORTED: i32 = ProtoVersion::MaxSupported as i32; + #[pyattr] + const OP_ALL: libc::c_ulong = (sys::SSL_OP_ALL & !sys::SSL_OP_DONT_INSERT_EMPTY_FRAGMENTS) as _; + #[pyattr] + const HAS_TLS_UNIQUE: bool = true; + #[pyattr] + const CERT_NONE: u32 = CertRequirements::None as u32; + #[pyattr] + const CERT_OPTIONAL: u32 = CertRequirements::Optional as u32; + #[pyattr] + const CERT_REQUIRED: u32 = CertRequirements::Required as u32; + #[pyattr] + const VERIFY_DEFAULT: u32 = 0; + #[pyattr] + const SSL_ERROR_EOF: u32 = 8; // custom for python + #[pyattr] + const HAS_SNI: bool = true; + #[pyattr] + const HAS_ECDH: bool = true; + #[pyattr] + const HAS_NPN: bool = false; + #[pyattr] + const HAS_ALPN: bool = true; + #[pyattr] + const HAS_SSLv2: bool = false; + #[pyattr] + const HAS_SSLv3: bool = false; + #[pyattr] + const HAS_TLSv1: bool = true; + #[pyattr] + const HAS_TLSv1_1: bool = true; + #[pyattr] + const HAS_TLSv1_2: bool = true; + #[pyattr] + const HAS_TLSv1_3: bool = cfg!(ossl111); + #[pyattr] + const HAS_PSK: bool = true; + + // Encoding constants for Certificate.public_bytes() + #[pyattr] + pub(crate) const ENCODING_PEM: i32 = sys::X509_FILETYPE_PEM; + #[pyattr] + pub(crate) const ENCODING_DER: i32 = sys::X509_FILETYPE_ASN1; + #[pyattr] + const ENCODING_PEM_AUX: i32 = sys::X509_FILETYPE_PEM + 0x100; + + // OpenSSL error codes for unexpected EOF detection + const ERR_LIB_SSL: i32 = 20; + const SSL_R_UNEXPECTED_EOF_WHILE_READING: i32 = 294; + + // SSL_VERIFY constants for post-handshake authentication + #[cfg(ossl111)] + const SSL_VERIFY_POST_HANDSHAKE: libc::c_int = 0x20; + + // the openssl version from the API headers + + #[pyattr(name = "OPENSSL_VERSION")] + fn openssl_version(_vm: &VirtualMachine) -> &str { + openssl::version::version() + } + #[pyattr(name = "OPENSSL_VERSION_NUMBER")] + fn openssl_version_number(_vm: &VirtualMachine) -> i64 { + openssl::version::number() + } + #[pyattr(name = "OPENSSL_VERSION_INFO")] + fn openssl_version_info(_vm: &VirtualMachine) -> OpensslVersionInfo { + parse_version_info(openssl::version::number()) + } + + #[pyattr(name = "_OPENSSL_API_VERSION")] + fn _openssl_api_version(_vm: &VirtualMachine) -> OpensslVersionInfo { + let openssl_api_version = i64::from_str_radix(env!("OPENSSL_API_VERSION"), 16) + .expect("OPENSSL_API_VERSION is malformed"); + parse_version_info(openssl_api_version) + } + + // SSL Exception Types + + /// An error occurred in the SSL implementation. + #[pyattr] + #[pyexception(name = "SSLError", base = PyOSError)] + #[derive(Debug)] + pub struct PySslError {} + + #[pyexception] + impl PySslError { + // Returns strerror attribute if available, otherwise str(args) + #[pymethod] + fn __str__(exc: PyBaseExceptionRef, vm: &VirtualMachine) -> PyResult { + // Try to get strerror attribute first (OSError compatibility) + if let Ok(strerror) = exc.as_object().get_attr("strerror", vm) + && !vm.is_none(&strerror) + { + return strerror.str(vm); + } + + // Otherwise return str(args) + exc.args().as_object().str(vm) + } + } + + /// A certificate could not be verified. + #[pyattr] + #[pyexception(name = "SSLCertVerificationError", base = PySslError)] + #[derive(Debug)] + pub struct PySslCertVerificationError {} + + #[pyexception] + impl PySslCertVerificationError {} + + /// SSL/TLS session closed cleanly. + #[pyattr] + #[pyexception(name = "SSLZeroReturnError", base = PySslError)] + #[derive(Debug)] + pub struct PySslZeroReturnError {} + + #[pyexception] + impl PySslZeroReturnError {} + + /// Non-blocking SSL socket needs to read more data. + #[pyattr] + #[pyexception(name = "SSLWantReadError", base = PySslError)] + #[derive(Debug)] + pub struct PySslWantReadError {} + + #[pyexception] + impl PySslWantReadError {} + + /// Non-blocking SSL socket needs to write more data. + #[pyattr] + #[pyexception(name = "SSLWantWriteError", base = PySslError)] + #[derive(Debug)] + pub struct PySslWantWriteError {} + + #[pyexception] + impl PySslWantWriteError {} + + /// System error when attempting SSL operation. + #[pyattr] + #[pyexception(name = "SSLSyscallError", base = PySslError)] + #[derive(Debug)] + pub struct PySslSyscallError {} + + #[pyexception] + impl PySslSyscallError {} + + /// SSL/TLS connection terminated abruptly. + #[pyattr] + #[pyexception(name = "SSLEOFError", base = PySslError)] + #[derive(Debug)] + pub struct PySslEOFError {} + + #[pyexception] + impl PySslEOFError {} + + type OpensslVersionInfo = (u8, u8, u8, u8, u8); + const fn parse_version_info(mut n: i64) -> OpensslVersionInfo { + let status = (n & 0xF) as u8; + n >>= 4; + let patch = (n & 0xFF) as u8; + n >>= 8; + let fix = (n & 0xFF) as u8; + n >>= 8; + let minor = (n & 0xFF) as u8; + n >>= 8; + let major = (n & 0xFF) as u8; + (major, minor, fix, patch, status) + } + + #[derive(Copy, Clone, num_enum::IntoPrimitive, num_enum::TryFromPrimitive, PartialEq)] + #[repr(i32)] + enum SslVersion { + Ssl2, + Ssl3 = 1, + Tls, + Tls1, + Tls1_1, + Tls1_2, + TlsClient = 0x10, + TlsServer, + } + + #[derive(Copy, Clone, num_enum::IntoPrimitive, num_enum::TryFromPrimitive)] + #[repr(i32)] + enum ProtoVersion { + MinSupported = -2, + Ssl3 = sys::SSL3_VERSION, + Tls1 = sys::TLS1_VERSION, + Tls1_1 = sys::TLS1_1_VERSION, + Tls1_2 = sys::TLS1_2_VERSION, + #[cfg(ossl111)] + Tls1_3 = sys::TLS1_3_VERSION, + #[cfg(not(ossl111))] + Tls1_3 = 0x304, + MaxSupported = -1, + } + + #[derive(num_enum::IntoPrimitive, num_enum::TryFromPrimitive)] + #[repr(i32)] + enum CertRequirements { + None, + Optional, + Required, + } + + #[derive(Debug, PartialEq)] + enum SslServerOrClient { + Client, + Server, + } + + unsafe fn ptr2obj(ptr: *mut sys::ASN1_OBJECT) -> Option { + if ptr.is_null() { + None + } else { + Some(unsafe { Asn1Object::from_ptr(ptr) }) + } + } + + fn _txt2obj(s: &CStr, no_name: bool) -> Option { + unsafe { ptr2obj(sys::OBJ_txt2obj(s.as_ptr(), i32::from(no_name))) } + } + fn _nid2obj(nid: Nid) -> Option { + unsafe { ptr2obj(sys::OBJ_nid2obj(nid.as_raw())) } + } + + type PyNid = (libc::c_int, String, String, Option); + fn obj2py(obj: &Asn1ObjectRef, vm: &VirtualMachine) -> PyResult { + let nid = obj.nid(); + let short_name = nid + .short_name() + .map_err(|_| vm.new_value_error("NID has no short name".to_owned()))? + .to_owned(); + let long_name = nid + .long_name() + .map_err(|_| vm.new_value_error("NID has no long name".to_owned()))? + .to_owned(); + Ok(( + nid.as_raw(), + short_name, + long_name, + cert::obj2txt(obj, true), + )) + } + + #[derive(FromArgs)] + struct Txt2ObjArgs { + txt: PyStrRef, + #[pyarg(any, default = false)] + name: bool, + } + + #[pyfunction] + fn txt2obj(args: Txt2ObjArgs, vm: &VirtualMachine) -> PyResult { + _txt2obj(&args.txt.to_cstring(vm)?, !args.name) + .as_deref() + .ok_or_else(|| vm.new_value_error(format!("unknown object '{}'", args.txt))) + .and_then(|obj| obj2py(obj, vm)) + } + + #[pyfunction] + fn nid2obj(nid: libc::c_int, vm: &VirtualMachine) -> PyResult { + _nid2obj(Nid::from_raw(nid)) + .as_deref() + .ok_or_else(|| vm.new_value_error(format!("unknown NID {nid}"))) + .and_then(|obj| obj2py(obj, vm)) + } + + // Lazily compute and cache cert file/dir paths + static CERT_PATHS: LazyLock<(PathBuf, PathBuf)> = LazyLock::new(|| { + fn path_from_cstr(c: &CStr) -> PathBuf { + #[cfg(unix)] + { + use std::os::unix::ffi::OsStrExt; + std::ffi::OsStr::from_bytes(c.to_bytes()).into() + } + #[cfg(windows)] + { + // Use lossy conversion for potential non-UTF8 + PathBuf::from(c.to_string_lossy().as_ref()) + } + } + + let probe = probe(); + let cert_file = probe + .cert_file + .as_ref() + .map(PathBuf::from) + .unwrap_or_else(|| { + path_from_cstr(unsafe { CStr::from_ptr(sys::X509_get_default_cert_file()) }) + }); + let cert_dir = probe + .cert_dir + .as_ref() + .map(PathBuf::from) + .unwrap_or_else(|| { + path_from_cstr(unsafe { CStr::from_ptr(sys::X509_get_default_cert_dir()) }) + }); + (cert_file, cert_dir) + }); + + fn get_cert_file_dir() -> (&'static Path, &'static Path) { + let (cert_file, cert_dir) = &*CERT_PATHS; + (cert_file.as_path(), cert_dir.as_path()) + } + + // Lazily compute and cache cert environment variable names + static CERT_ENV_NAMES: LazyLock<(String, String)> = LazyLock::new(|| { + let cert_file_env = unsafe { CStr::from_ptr(sys::X509_get_default_cert_file_env()) } + .to_string_lossy() + .into_owned(); + let cert_dir_env = unsafe { CStr::from_ptr(sys::X509_get_default_cert_dir_env()) } + .to_string_lossy() + .into_owned(); + (cert_file_env, cert_dir_env) + }); + + #[pyfunction] + fn get_default_verify_paths( + vm: &VirtualMachine, + ) -> PyResult<(&'static str, PyObjectRef, &'static str, PyObjectRef)> { + let (cert_file_env, cert_dir_env) = &*CERT_ENV_NAMES; + let (cert_file, cert_dir) = get_cert_file_dir(); + let cert_file = OsPath::new_str(cert_file).filename(vm); + let cert_dir = OsPath::new_str(cert_dir).filename(vm); + Ok(( + cert_file_env.as_str(), + cert_file, + cert_dir_env.as_str(), + cert_dir, + )) + } + + #[pyfunction(name = "RAND_status")] + fn rand_status() -> i32 { + unsafe { sys::RAND_status() } + } + + #[pyfunction(name = "RAND_add")] + fn rand_add(string: ArgStrOrBytesLike, entropy: f64) { + let f = |b: &[u8]| { + for buf in b.chunks(libc::c_int::MAX as usize) { + unsafe { sys::RAND_add(buf.as_ptr() as *const _, buf.len() as _, entropy) } + } + }; + f(&string.borrow_bytes()) + } + + #[pyfunction(name = "RAND_bytes")] + fn rand_bytes(n: i32, vm: &VirtualMachine) -> PyResult> { + if n < 0 { + return Err(vm.new_value_error("num must be positive")); + } + let mut buf = vec![0; n as usize]; + openssl::rand::rand_bytes(&mut buf).map_err(|e| convert_openssl_error(vm, e))?; + Ok(buf) + } + + // Callback data stored in SSL context for SNI + struct SniCallbackData { + ssl_context: PyRef, + vm_ptr: *const VirtualMachine, + } + + impl Drop for SniCallbackData { + fn drop(&mut self) { + // PyRef will handle reference counting + } + } + + // Get or create an ex_data index for SNI callback data + fn get_sni_ex_data_index() -> libc::c_int { + use std::sync::LazyLock; + static SNI_EX_DATA_IDX: LazyLock = LazyLock::new(|| unsafe { + sys::SSL_get_ex_new_index( + 0, + std::ptr::null_mut(), + None, + None, + Some(sni_callback_data_free), + ) + }); + *SNI_EX_DATA_IDX + } + + // Free function for callback data + unsafe extern "C" fn sni_callback_data_free( + _parent: *mut libc::c_void, + ptr: *mut libc::c_void, + _ad: *mut sys::CRYPTO_EX_DATA, + _idx: libc::c_int, + _argl: libc::c_long, + _argp: *mut libc::c_void, + ) { + if !ptr.is_null() { + unsafe { + let _ = Box::from_raw(ptr as *mut SniCallbackData); + } + } + } + + // SNI callback function called by OpenSSL + unsafe extern "C" fn _servername_callback( + ssl_ptr: *mut sys::SSL, + al: *mut libc::c_int, + arg: *mut libc::c_void, + ) -> libc::c_int { + const SSL_TLSEXT_ERR_OK: libc::c_int = 0; + const SSL_TLSEXT_ERR_ALERT_FATAL: libc::c_int = 2; + const SSL_AD_INTERNAL_ERROR: libc::c_int = 80; + const TLSEXT_NAMETYPE_host_name: libc::c_int = 0; + + if arg.is_null() { + return SSL_TLSEXT_ERR_OK; + } + + unsafe { + let ctx = &*(arg as *const PySslContext); + + // Get the callback + let callback_opt = ctx.sni_callback.lock().clone(); + let Some(callback) = callback_opt else { + return SSL_TLSEXT_ERR_OK; + }; + + // Get callback data from SSL ex_data + let idx = get_sni_ex_data_index(); + let data_ptr = sys::SSL_get_ex_data(ssl_ptr, idx); + if data_ptr.is_null() { + return SSL_TLSEXT_ERR_ALERT_FATAL; + } + + let callback_data = &*(data_ptr as *const SniCallbackData); + + // SAFETY: vm_ptr is stored during wrap_socket and is valid for the lifetime + // of the SSL connection. The handshake happens synchronously in the same thread. + let vm = &*callback_data.vm_ptr; + + // Get server name + let servername = sys::SSL_get_servername(ssl_ptr, TLSEXT_NAMETYPE_host_name); + let server_name_arg = if servername.is_null() { + vm.ctx.none() + } else { + let name_cstr = std::ffi::CStr::from_ptr(servername); + match name_cstr.to_str() { + Ok(name_str) => vm.ctx.new_str(name_str).into(), + Err(_) => vm.ctx.none(), + } + }; + + // Get SSL socket from SSL ex_data (stored as PySslSocket pointer) + let ssl_socket_ptr = sys::SSL_get_ex_data(ssl_ptr, 0); // Index 0 for SSL socket + let ssl_socket_obj = if !ssl_socket_ptr.is_null() { + let ssl_socket = &*(ssl_socket_ptr as *const PySslSocket); + // Try to get owner first + ssl_socket + .owner + .read() + .as_ref() + .and_then(|weak| weak.upgrade()) + .unwrap_or_else(|| vm.ctx.none()) + } else { + vm.ctx.none() + }; + + // Call the Python callback + match callback.call( + ( + ssl_socket_obj, + server_name_arg, + callback_data.ssl_context.to_owned(), + ), + vm, + ) { + Ok(result) => { + // Check return value type (must be None or integer) + if vm.is_none(&result) { + // None is OK + SSL_TLSEXT_ERR_OK + } else { + // Try to convert to integer + match result.try_to_value::(vm) { + Ok(alert_code) => { + // Valid integer - use as alert code + *al = alert_code; + SSL_TLSEXT_ERR_ALERT_FATAL + } + Err(_) => { + // Type conversion failed - raise TypeError + let type_error = vm.new_type_error(format!( + "servername callback must return None or an integer, not '{}'", + result.class().name() + )); + vm.run_unraisable(type_error, None, result); + *al = SSL_AD_INTERNAL_ERROR; + SSL_TLSEXT_ERR_ALERT_FATAL + } + } + } + } + Err(exc) => { + // Log the exception but don't propagate it + vm.run_unraisable(exc, None, vm.ctx.none()); + *al = SSL_AD_INTERNAL_ERROR; + SSL_TLSEXT_ERR_ALERT_FATAL + } + } + } + } + + // Message callback function called by OpenSSL + // Based on CPython's _PySSL_msg_callback in Modules/_ssl/debughelpers.c + unsafe extern "C" fn _msg_callback( + write_p: libc::c_int, + version: libc::c_int, + content_type: libc::c_int, + buf: *const libc::c_void, + len: usize, + ssl_ptr: *mut sys::SSL, + _arg: *mut libc::c_void, + ) { + if ssl_ptr.is_null() { + return; + } + + unsafe { + // Get SSL socket from SSL_get_app_data (index 0) + let ssl_socket_ptr = sys::SSL_get_ex_data(ssl_ptr, 0); + if ssl_socket_ptr.is_null() { + return; + } + + let ssl_socket = &*(ssl_socket_ptr as *const PySslSocket); + + // Get the callback from the context + let callback_opt = ssl_socket.ctx.read().msg_callback.lock().clone(); + let Some(callback) = callback_opt else { + return; + }; + + // Get callback data from SSL ex_data (for VM) + let idx = get_sni_ex_data_index(); + let data_ptr = sys::SSL_get_ex_data(ssl_ptr, idx); + if data_ptr.is_null() { + return; + } + + let callback_data = &*(data_ptr as *const SniCallbackData); + let vm = &*callback_data.vm_ptr; + + // Get SSL socket owner object + let ssl_socket_obj = ssl_socket + .owner + .read() + .as_ref() + .and_then(|weak| weak.upgrade()) + .unwrap_or_else(|| vm.ctx.none()); + + // Create the message bytes + let buf_slice = std::slice::from_raw_parts(buf as *const u8, len); + let msg_bytes = vm.ctx.new_bytes(buf_slice.to_vec()); + + // Determine direction string + let direction_str = if write_p != 0 { "write" } else { "read" }; + + // Call the Python callback + // Signature: callback(conn, direction, version, content_type, msg_type, data) + // For simplicity, we'll pass msg_type as 0 (would need more parsing to get the actual type) + match callback.call( + ( + ssl_socket_obj, + vm.ctx.new_str(direction_str), + vm.ctx.new_int(version), + vm.ctx.new_int(content_type), + vm.ctx.new_int(0), // msg_type - would need parsing + msg_bytes, + ), + vm, + ) { + Ok(_) => {} + Err(exc) => { + // Log the exception but don't propagate it + vm.run_unraisable(exc, None, vm.ctx.none()); + } + } + } + } + + #[pyfunction(name = "RAND_pseudo_bytes")] + fn rand_pseudo_bytes(n: i32, vm: &VirtualMachine) -> PyResult<(Vec, bool)> { + if n < 0 { + return Err(vm.new_value_error("num must be positive")); + } + let mut buf = vec![0; n as usize]; + let ret = unsafe { sys::RAND_bytes(buf.as_mut_ptr(), n) }; + match ret { + 0 | 1 => Ok((buf, ret == 1)), + _ => Err(convert_openssl_error(vm, ErrorStack::get())), + } + } + + #[pyattr] + #[pyclass(module = "ssl", name = "_SSLContext")] + #[derive(PyPayload)] + struct PySslContext { + ctx: PyRwLock, + check_hostname: AtomicCell, + protocol: SslVersion, + post_handshake_auth: PyMutex, + sni_callback: PyMutex>, + msg_callback: PyMutex>, + } + + impl fmt::Debug for PySslContext { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.pad("_SSLContext") + } + } + + fn builder_as_ctx(x: &SslContextBuilder) -> &ssl::SslContextRef { + unsafe { ssl::SslContextRef::from_ptr(x.as_ptr()) } + } + + impl Constructor for PySslContext { + type Args = i32; + + fn py_new(cls: PyTypeRef, proto_version: Self::Args, vm: &VirtualMachine) -> PyResult { + let proto = SslVersion::try_from(proto_version) + .map_err(|_| vm.new_value_error("invalid protocol version"))?; + let method = match proto { + // SslVersion::Ssl3 => unsafe { ssl::SslMethod::from_ptr(sys::SSLv3_method()) }, + SslVersion::Tls => ssl::SslMethod::tls(), + SslVersion::Tls1 => ssl::SslMethod::tls(), + SslVersion::Tls1_1 => ssl::SslMethod::tls(), + SslVersion::Tls1_2 => ssl::SslMethod::tls(), + SslVersion::TlsClient => ssl::SslMethod::tls_client(), + SslVersion::TlsServer => ssl::SslMethod::tls_server(), + _ => return Err(vm.new_value_error("invalid protocol version")), + }; + let mut builder = + SslContextBuilder::new(method).map_err(|e| convert_openssl_error(vm, e))?; + + #[cfg(target_os = "android")] + android::load_client_ca_list(vm, &mut builder)?; + + let check_hostname = proto == SslVersion::TlsClient; + builder.set_verify(if check_hostname { + SslVerifyMode::PEER | SslVerifyMode::FAIL_IF_NO_PEER_CERT + } else { + SslVerifyMode::NONE + }); + + let mut options = SslOptions::ALL & !SslOptions::DONT_INSERT_EMPTY_FRAGMENTS; + if proto != SslVersion::Ssl2 { + options |= SslOptions::NO_SSLV2; + } + if proto != SslVersion::Ssl3 { + options |= SslOptions::NO_SSLV3; + } + options |= SslOptions::NO_COMPRESSION; + options |= SslOptions::CIPHER_SERVER_PREFERENCE; + options |= SslOptions::SINGLE_DH_USE; + options |= SslOptions::SINGLE_ECDH_USE; + options |= SslOptions::ENABLE_MIDDLEBOX_COMPAT; + builder.set_options(options); + + let mode = ssl::SslMode::ACCEPT_MOVING_WRITE_BUFFER | ssl::SslMode::AUTO_RETRY; + builder.set_mode(mode); + + #[cfg(ossl111)] + unsafe { + sys::SSL_CTX_set_post_handshake_auth(builder.as_ptr(), 0); + } + + // Note: Unlike some other implementations, we do NOT set session_id_context at the + // context level. CPython sets it only on individual SSL objects (server-side only). + // This matches CPython's behavior in _ssl.c where SSL_set_session_id_context is called + // in newPySSLSocket() at line 862, not during context creation. + + // Set protocol version limits based on the protocol version + unsafe { + let ctx_ptr = builder.as_ptr(); + match proto { + SslVersion::Tls1 => { + sys::SSL_CTX_set_min_proto_version(ctx_ptr, sys::TLS1_VERSION); + sys::SSL_CTX_set_max_proto_version(ctx_ptr, sys::TLS1_VERSION); + } + SslVersion::Tls1_1 => { + sys::SSL_CTX_set_min_proto_version(ctx_ptr, sys::TLS1_1_VERSION); + sys::SSL_CTX_set_max_proto_version(ctx_ptr, sys::TLS1_1_VERSION); + } + SslVersion::Tls1_2 => { + sys::SSL_CTX_set_min_proto_version(ctx_ptr, sys::TLS1_2_VERSION); + sys::SSL_CTX_set_max_proto_version(ctx_ptr, sys::TLS1_2_VERSION); + } + _ => { + // For Tls, TlsClient, TlsServer, use default (no restrictions) + } + } + } + + // Set default verify flags: VERIFY_X509_TRUSTED_FIRST + unsafe { + let ctx_ptr = builder.as_ptr(); + let param = sys::SSL_CTX_get0_param(ctx_ptr); + sys::X509_VERIFY_PARAM_set_flags(param, sys::X509_V_FLAG_TRUSTED_FIRST); + } + + PySslContext { + ctx: PyRwLock::new(builder), + check_hostname: AtomicCell::new(check_hostname), + protocol: proto, + post_handshake_auth: PyMutex::new(false), + sni_callback: PyMutex::new(None), + msg_callback: PyMutex::new(None), + } + .into_ref_with_type(vm, cls) + .map(Into::into) + } + } + + #[pyclass(flags(BASETYPE, IMMUTABLETYPE), with(Constructor))] + impl PySslContext { + fn builder(&self) -> PyRwLockWriteGuard<'_, SslContextBuilder> { + self.ctx.write() + } + fn ctx(&self) -> PyMappedRwLockReadGuard<'_, ssl::SslContextRef> { + PyRwLockReadGuard::map(self.ctx.read(), builder_as_ctx) + } + + #[pygetset] + fn post_handshake_auth(&self) -> bool { + *self.post_handshake_auth.lock() + } + #[pygetset(setter)] + fn set_post_handshake_auth( + &self, + value: Option, + vm: &VirtualMachine, + ) -> PyResult<()> { + let value = value.ok_or_else(|| vm.new_attribute_error("cannot delete attribute"))?; + *self.post_handshake_auth.lock() = value.is_true(vm)?; + Ok(()) + } + + #[cfg(ossl110)] + #[pygetset] + fn security_level(&self) -> i32 { + unsafe { SSL_CTX_get_security_level(self.ctx().as_ptr()) } + } + + #[pymethod] + fn set_ciphers(&self, cipherlist: PyStrRef, vm: &VirtualMachine) -> PyResult<()> { + let ciphers = cipherlist.as_str(); + if ciphers.contains('\0') { + return Err(exceptions::cstring_error(vm)); + } + self.builder().set_cipher_list(ciphers).map_err(|_| { + vm.new_exception_msg( + PySslError::class(&vm.ctx).to_owned(), + "No cipher can be selected.".to_owned(), + ) + }) + } + + #[pymethod] + fn get_ciphers(&self, vm: &VirtualMachine) -> PyResult { + let ctx = self.ctx(); + let ssl = ssl::Ssl::new(&ctx).map_err(|e| convert_openssl_error(vm, e))?; + + unsafe { + let ciphers_ptr = SSL_get_ciphers(ssl.as_ptr()); + if ciphers_ptr.is_null() { + return Ok(vm.ctx.new_list(vec![])); + } + + let num_ciphers = sys::OPENSSL_sk_num(ciphers_ptr as *const _); + let mut result = Vec::new(); + + for i in 0..num_ciphers { + let cipher_ptr = + sys::OPENSSL_sk_value(ciphers_ptr as *const _, i) as *const sys::SSL_CIPHER; + let cipher = ssl::SslCipherRef::from_ptr(cipher_ptr as *mut _); + + let (name, version, bits) = cipher_to_tuple(cipher); + let dict = vm.ctx.new_dict(); + dict.set_item("name", vm.ctx.new_str(name).into(), vm)?; + dict.set_item("protocol", vm.ctx.new_str(version).into(), vm)?; + dict.set_item("secret_bits", vm.ctx.new_int(bits).into(), vm)?; + + // Add description field + let description = cipher_description(cipher_ptr); + dict.set_item("description", vm.ctx.new_str(description).into(), vm)?; + + result.push(dict.into()); + } + + Ok(vm.ctx.new_list(result)) + } + } + + #[pymethod] + fn set_ecdh_curve( + &self, + name: Either, + vm: &VirtualMachine, + ) -> PyResult<()> { + use openssl::ec::{EcGroup, EcKey}; + + // Convert name to CString, supporting both str and bytes + let name_cstr = match name { + Either::A(s) => { + if s.as_str().contains('\0') { + return Err(exceptions::cstring_error(vm)); + } + s.to_cstring(vm)? + } + Either::B(b) => std::ffi::CString::new(b.borrow_buf().to_vec()) + .map_err(|_| exceptions::cstring_error(vm))?, + }; + + // Find the NID for the curve name using OBJ_sn2nid + let nid_raw = unsafe { sys::OBJ_sn2nid(name_cstr.as_ptr()) }; + if nid_raw == 0 { + return Err(vm.new_value_error("unknown curve name")); + } + let nid = Nid::from_raw(nid_raw); + + // Create EC key from the curve + let group = EcGroup::from_curve_name(nid).map_err(|e| convert_openssl_error(vm, e))?; + let key = EcKey::from_group(&group).map_err(|e| convert_openssl_error(vm, e))?; + + // Set the temporary ECDH key + self.builder() + .set_tmp_ecdh(&key) + .map_err(|e| convert_openssl_error(vm, e)) + } + + #[pygetset] + fn options(&self) -> libc::c_ulong { + self.ctx.read().options().bits() as _ + } + #[pygetset(setter)] + fn set_options(&self, opts: libc::c_ulong) { + self.builder() + .set_options(SslOptions::from_bits_truncate(opts as _)); + } + #[pygetset] + fn protocol(&self) -> i32 { + self.protocol as i32 + } + #[pygetset] + fn verify_mode(&self) -> i32 { + let mode = self.ctx().verify_mode(); + if mode == SslVerifyMode::NONE { + CertRequirements::None.into() + } else if mode == SslVerifyMode::PEER { + CertRequirements::Optional.into() + } else if mode == SslVerifyMode::PEER | SslVerifyMode::FAIL_IF_NO_PEER_CERT { + CertRequirements::Required.into() + } else { + unreachable!() + } + } + #[pygetset(setter)] + fn set_verify_mode(&self, cert: i32, vm: &VirtualMachine) -> PyResult<()> { + let mut ctx = self.builder(); + let cert_req = CertRequirements::try_from(cert) + .map_err(|_| vm.new_value_error("invalid value for verify_mode"))?; + let mode = match cert_req { + CertRequirements::None if self.check_hostname.load() => { + return Err(vm.new_value_error( + "Cannot set verify_mode to CERT_NONE when check_hostname is enabled.", + )); + } + CertRequirements::None => SslVerifyMode::NONE, + CertRequirements::Optional => SslVerifyMode::PEER, + CertRequirements::Required => { + SslVerifyMode::PEER | SslVerifyMode::FAIL_IF_NO_PEER_CERT + } + }; + ctx.set_verify(mode); + Ok(()) + } + #[pygetset] + fn verify_flags(&self) -> libc::c_ulong { + unsafe { + let ctx_ptr = self.ctx().as_ptr(); + let param = sys::SSL_CTX_get0_param(ctx_ptr); + sys::X509_VERIFY_PARAM_get_flags(param) + } + } + #[pygetset(setter)] + fn set_verify_flags(&self, new_flags: libc::c_ulong, vm: &VirtualMachine) -> PyResult<()> { + unsafe { + let ctx_ptr = self.ctx().as_ptr(); + let param = sys::SSL_CTX_get0_param(ctx_ptr); + let flags = sys::X509_VERIFY_PARAM_get_flags(param); + let clear = flags & !new_flags; + let set = !flags & new_flags; + + if clear != 0 && sys::X509_VERIFY_PARAM_clear_flags(param, clear) == 0 { + return Err(vm.new_exception_msg( + PySslError::class(&vm.ctx).to_owned(), + "Failed to clear verify flags".to_owned(), + )); + } + if set != 0 && sys::X509_VERIFY_PARAM_set_flags(param, set) == 0 { + return Err(vm.new_exception_msg( + PySslError::class(&vm.ctx).to_owned(), + "Failed to set verify flags".to_owned(), + )); + } + Ok(()) + } + } + #[pygetset] + fn check_hostname(&self) -> bool { + self.check_hostname.load() + } + #[pygetset(setter)] + fn set_check_hostname(&self, ch: bool) { + let mut ctx = self.builder(); + if ch && builder_as_ctx(&ctx).verify_mode() == SslVerifyMode::NONE { + ctx.set_verify(SslVerifyMode::PEER | SslVerifyMode::FAIL_IF_NO_PEER_CERT); + } + self.check_hostname.store(ch); + } + + // PY_PROTO_MINIMUM_SUPPORTED = -2, PY_PROTO_MAXIMUM_SUPPORTED = -1 + #[pygetset] + fn minimum_version(&self) -> i32 { + let ctx = self.ctx(); + let version = unsafe { sys::SSL_CTX_get_min_proto_version(ctx.as_ptr()) }; + if version == 0 { + -2 // PY_PROTO_MINIMUM_SUPPORTED + } else { + version + } + } + #[pygetset(setter)] + fn set_minimum_version(&self, value: i32, vm: &VirtualMachine) -> PyResult<()> { + // Handle special values + let proto_version = match value { + -2 => { + // PY_PROTO_MINIMUM_SUPPORTED -> use minimum available (TLS 1.2) + sys::TLS1_2_VERSION + } + -1 => { + // PY_PROTO_MAXIMUM_SUPPORTED -> use maximum available + // For max on min_proto_version, we use the newest available + sys::TLS1_3_VERSION + } + _ => value, + }; + + let ctx = self.builder(); + let result = unsafe { sys::SSL_CTX_set_min_proto_version(ctx.as_ptr(), proto_version) }; + if result == 0 { + return Err(vm.new_value_error("invalid protocol version")); + } + Ok(()) + } + + #[pygetset] + fn maximum_version(&self) -> i32 { + let ctx = self.ctx(); + let version = unsafe { sys::SSL_CTX_get_max_proto_version(ctx.as_ptr()) }; + if version == 0 { + -1 // PY_PROTO_MAXIMUM_SUPPORTED + } else { + version + } + } + #[pygetset(setter)] + fn set_maximum_version(&self, value: i32, vm: &VirtualMachine) -> PyResult<()> { + // Handle special values + let proto_version = match value { + -1 => { + // PY_PROTO_MAXIMUM_SUPPORTED -> use 0 for OpenSSL (means no limit) + 0 + } + -2 => { + // PY_PROTO_MINIMUM_SUPPORTED -> use minimum available (TLS 1.2) + sys::TLS1_2_VERSION + } + _ => value, + }; + + let ctx = self.builder(); + let result = unsafe { sys::SSL_CTX_set_max_proto_version(ctx.as_ptr(), proto_version) }; + if result == 0 { + return Err(vm.new_value_error("invalid protocol version")); + } + Ok(()) + } + + #[pygetset] + fn num_tickets(&self, _vm: &VirtualMachine) -> PyResult { + // Only supported for TLS 1.3 + #[cfg(ossl110)] + { + let ctx = self.ctx(); + let num = unsafe { sys::SSL_CTX_get_num_tickets(ctx.as_ptr()) }; + Ok(num) + } + #[cfg(not(ossl110))] + { + Ok(0) + } + } + #[pygetset(setter)] + fn set_num_tickets(&self, value: isize, vm: &VirtualMachine) -> PyResult<()> { + // Check for negative values + if value < 0 { + return Err( + vm.new_value_error("num_tickets must be a non-negative integer".to_owned()) + ); + } + + // Check that this is a server context + if self.protocol != SslVersion::TlsServer { + return Err(vm.new_value_error("SSLContext is not a server context.".to_owned())); + } + + #[cfg(ossl110)] + { + let ctx = self.builder(); + let result = unsafe { sys::SSL_CTX_set_num_tickets(ctx.as_ptr(), value as usize) }; + if result != 1 { + return Err(vm.new_value_error("failed to set num tickets.")); + } + Ok(()) + } + #[cfg(not(ossl110))] + { + let _ = (value, vm); + Ok(()) + } + } + + #[pymethod] + fn set_default_verify_paths(&self, vm: &VirtualMachine) -> PyResult<()> { + cfg_if::cfg_if! { + if #[cfg(openssl_vendored)] { + let (cert_file, cert_dir) = get_cert_file_dir(); + self.builder() + .load_verify_locations(Some(cert_file), Some(cert_dir)) + .map_err(|e| convert_openssl_error(vm, e)) + } else { + self.builder() + .set_default_verify_paths() + .map_err(|e| convert_openssl_error(vm, e)) + } + } + } + + #[pymethod] + fn _set_alpn_protocols(&self, protos: ArgBytesLike, vm: &VirtualMachine) -> PyResult<()> { + #[cfg(ossl102)] + { + let mut ctx = self.builder(); + let server = protos.with_ref(|pbuf| { + if pbuf.len() > libc::c_uint::MAX as usize { + return Err(vm.new_overflow_error(format!( + "protocols longer than {} bytes", + libc::c_uint::MAX + ))); + } + ctx.set_alpn_protos(pbuf) + .map_err(|e| convert_openssl_error(vm, e))?; + Ok(pbuf.to_vec()) + })?; + ctx.set_alpn_select_callback(move |_, client| { + let proto = + ssl::select_next_proto(&server, client).ok_or(ssl::AlpnError::NOACK)?; + let pos = memchr::memmem::find(client, proto) + .expect("selected alpn proto should be present in client protos"); + Ok(&client[pos..proto.len()]) + }); + Ok(()) + } + #[cfg(not(ossl102))] + { + Err(vm.new_not_implemented_error( + "The NPN extension requires OpenSSL 1.0.1 or later.", + )) + } + } + + #[pymethod] + fn load_verify_locations( + &self, + args: LoadVerifyLocationsArgs, + vm: &VirtualMachine, + ) -> PyResult<()> { + if let (None, None, None) = (&args.cafile, &args.capath, &args.cadata) { + return Err(vm.new_type_error("cafile, capath and cadata cannot be all omitted")); + } + + #[cold] + fn invalid_cadata(vm: &VirtualMachine) -> PyBaseExceptionRef { + vm.new_type_error("cadata should be an ASCII string or a bytes-like object") + } + + let mut ctx = self.builder(); + + // validate cadata type and load cadata + if let Some(cadata) = args.cadata { + let certs = match cadata { + Either::A(s) => { + if !s.is_ascii() { + return Err(invalid_cadata(vm)); + } + X509::stack_from_pem(s.as_bytes()) + } + Either::B(b) => b.with_ref(x509_stack_from_der), + }; + let certs = certs.map_err(|e| convert_openssl_error(vm, e))?; + let store = ctx.cert_store_mut(); + for cert in certs { + store + .add_cert(cert) + .map_err(|e| convert_openssl_error(vm, e))?; + } + } + + if args.cafile.is_some() || args.capath.is_some() { + let cafile_path = args.cafile.map(|p| p.to_path_buf(vm)).transpose()?; + let capath_path = args.capath.map(|p| p.to_path_buf(vm)).transpose()?; + ctx.load_verify_locations(cafile_path.as_deref(), capath_path.as_deref()) + .map_err(|e| convert_openssl_error(vm, e))?; + } + + Ok(()) + } + + #[pymethod] + fn get_ca_certs( + &self, + binary_form: OptionalArg, + vm: &VirtualMachine, + ) -> PyResult> { + let binary_form = binary_form.unwrap_or(false); + let ctx = self.ctx(); + #[cfg(ossl300)] + let certs = ctx.cert_store().all_certificates(); + #[cfg(not(ossl300))] + let certs = ctx.cert_store().objects().iter().filter_map(|x| x.x509()); + + // Filter to only include CA certificates (Basic Constraints: CA=TRUE) + let certs = certs + .into_iter() + .filter(|cert| { + unsafe { + // X509_check_ca() returns 1 for CA certificates + X509_check_ca(cert.as_ptr()) == 1 + } + }) + .map(|ref cert| cert_to_py(vm, cert, binary_form)) + .collect::, _>>()?; + Ok(certs) + } + + #[pymethod] + fn cert_store_stats(&self, vm: &VirtualMachine) -> PyResult { + let ctx = self.ctx(); + let store_ptr = unsafe { sys::SSL_CTX_get_cert_store(ctx.as_ptr()) }; + + if store_ptr.is_null() { + return Err(vm.new_memory_error("failed to get cert store".to_owned())); + } + + let objs_ptr = unsafe { sys::X509_STORE_get0_objects(store_ptr) }; + if objs_ptr.is_null() { + return Err(vm.new_memory_error("failed to query cert store".to_owned())); + } + + let mut x509_count = 0; + let mut crl_count = 0; + let mut ca_count = 0; + + unsafe { + let num_objs = sys::OPENSSL_sk_num(objs_ptr as *const _); + for i in 0..num_objs { + let obj_ptr = + sys::OPENSSL_sk_value(objs_ptr as *const _, i) as *const sys::X509_OBJECT; + let obj_type = X509_OBJECT_get_type(obj_ptr); + + match obj_type { + X509_LU_X509 => { + x509_count += 1; + let x509_ptr = sys::X509_OBJECT_get0_X509(obj_ptr); + if !x509_ptr.is_null() && X509_check_ca(x509_ptr) == 1 { + ca_count += 1; + } + } + X509_LU_CRL => { + crl_count += 1; + } + _ => { + // Ignore unrecognized types + } + } + } + // Note: No need to free objs_ptr as X509_STORE_get0_objects returns + // a pointer to internal data that should not be freed by the caller + } + + let dict = vm.ctx.new_dict(); + dict.set_item("x509", vm.ctx.new_int(x509_count).into(), vm)?; + dict.set_item("crl", vm.ctx.new_int(crl_count).into(), vm)?; + dict.set_item("x509_ca", vm.ctx.new_int(ca_count).into(), vm)?; + Ok(dict.into()) + } + + #[pymethod] + fn session_stats(&self, vm: &VirtualMachine) -> PyResult { + let ctx = self.ctx(); + let ctx_ptr = ctx.as_ptr(); + + let dict = vm.ctx.new_dict(); + + macro_rules! add_stat { + ($key:expr, $func:ident) => { + let value = unsafe { $func(ctx_ptr) }; + dict.set_item($key, vm.ctx.new_int(value).into(), vm)?; + }; + } + + add_stat!("number", SSL_CTX_sess_number); + add_stat!("connect", SSL_CTX_sess_connect); + add_stat!("connect_good", SSL_CTX_sess_connect_good); + add_stat!("connect_renegotiate", SSL_CTX_sess_connect_renegotiate); + add_stat!("accept", SSL_CTX_sess_accept); + add_stat!("accept_good", SSL_CTX_sess_accept_good); + add_stat!("accept_renegotiate", SSL_CTX_sess_accept_renegotiate); + add_stat!("hits", SSL_CTX_sess_hits); + add_stat!("misses", SSL_CTX_sess_misses); + add_stat!("timeouts", SSL_CTX_sess_timeouts); + add_stat!("cache_full", SSL_CTX_sess_cache_full); + + Ok(dict.into()) + } + + #[pymethod] + fn load_dh_params(&self, filepath: FsPath, vm: &VirtualMachine) -> PyResult<()> { + let path = filepath.to_path_buf(vm)?; + + // Open the file using fopen (cross-platform) + let fp = + rustpython_common::fileutils::fopen(path.as_path(), "rb").map_err(|e| { + match e.kind() { + std::io::ErrorKind::NotFound => vm.new_exception_msg( + vm.ctx.exceptions.file_not_found_error.to_owned(), + e.to_string(), + ), + _ => vm.new_os_error(e.to_string()), + } + })?; + + // Read DH parameters + let dh = unsafe { + PEM_read_DHparams( + fp, + std::ptr::null_mut(), + std::ptr::null_mut(), + std::ptr::null_mut(), + ) + }; + unsafe { + libc::fclose(fp); + } + + if dh.is_null() { + return Err(convert_openssl_error(vm, ErrorStack::get())); + } + + // Set temporary DH parameters + let ctx = self.builder(); + let result = unsafe { sys::SSL_CTX_set_tmp_dh(ctx.as_ptr(), dh) }; + unsafe { + sys::DH_free(dh); + } + + if result != 1 { + return Err(convert_openssl_error(vm, ErrorStack::get())); + } + + Ok(()) + } + + #[pygetset] + fn sni_callback(&self) -> Option { + self.sni_callback.lock().clone() + } + + #[pygetset(setter)] + fn set_sni_callback( + &self, + value: Option, + vm: &VirtualMachine, + ) -> PyResult<()> { + // Check if this is a server context + if self.protocol == SslVersion::TlsClient { + return Err(vm.new_value_error( + "sni_callback cannot be set on TLS_CLIENT context".to_owned(), + )); + } + + let mut callback_guard = self.sni_callback.lock(); + + if let Some(callback_obj) = value { + if !vm.is_none(&callback_obj) { + // Check if callable + if !callback_obj.is_callable() { + return Err(vm.new_type_error("not a callable object".to_owned())); + } + + // Set the callback + *callback_guard = Some(callback_obj); + + // Set OpenSSL callback + unsafe { + sys::SSL_CTX_set_tlsext_servername_callback__fixed_rust( + self.ctx().as_ptr(), + Some(_servername_callback), + ); + sys::SSL_CTX_set_tlsext_servername_arg( + self.ctx().as_ptr(), + self as *const _ as *mut _, + ); + } + } else { + // Clear callback + *callback_guard = None; + unsafe { + sys::SSL_CTX_set_tlsext_servername_callback__fixed_rust( + self.ctx().as_ptr(), + None, + ); + } + } + } else { + // Clear callback + *callback_guard = None; + unsafe { + sys::SSL_CTX_set_tlsext_servername_callback__fixed_rust( + self.ctx().as_ptr(), + None, + ); + } + } + + Ok(()) + } + + #[pymethod] + fn set_servername_callback( + &self, + callback: Option, + vm: &VirtualMachine, + ) -> PyResult<()> { + self.set_sni_callback(callback, vm) + } + + #[pygetset(name = "_msg_callback")] + fn msg_callback(&self) -> Option { + self.msg_callback.lock().clone() + } + + #[pygetset(setter, name = "_msg_callback")] + fn set_msg_callback( + &self, + value: Option, + vm: &VirtualMachine, + ) -> PyResult<()> { + let mut callback_guard = self.msg_callback.lock(); + + if let Some(callback_obj) = value { + if !vm.is_none(&callback_obj) { + // Check if callable + if !callback_obj.is_callable() { + return Err(vm.new_type_error("not a callable object".to_owned())); + } + + // Set the callback + *callback_guard = Some(callback_obj); + + // Set OpenSSL callback + unsafe { + SSL_CTX_set_msg_callback(self.ctx().as_ptr(), Some(_msg_callback)); + } + } else { + // Clear callback + *callback_guard = None; + unsafe { + SSL_CTX_set_msg_callback(self.ctx().as_ptr(), None); + } + } + } else { + // Clear callback when value is None + *callback_guard = None; + unsafe { + SSL_CTX_set_msg_callback(self.ctx().as_ptr(), None); + } + } + + Ok(()) + } + + #[pymethod] + fn load_cert_chain(&self, args: LoadCertChainArgs, vm: &VirtualMachine) -> PyResult<()> { + let LoadCertChainArgs { + certfile, + keyfile, + password, + } = args; + // TODO: requires passing a callback to C + if password.is_some() { + return Err(vm.new_not_implemented_error("password arg not yet supported")); + } + let mut ctx = self.builder(); + let key_path = keyfile.map(|path| path.to_path_buf(vm)).transpose()?; + let cert_path = certfile.to_path_buf(vm)?; + ctx.set_certificate_chain_file(&cert_path) + .and_then(|()| { + ctx.set_private_key_file( + key_path.as_ref().unwrap_or(&cert_path), + ssl::SslFiletype::PEM, + ) + }) + .and_then(|()| ctx.check_private_key()) + .map_err(|e| convert_openssl_error(vm, e)) + } + + // Helper function to create SSL socket + // = CPython's newPySSLSocket() + fn new_py_ssl_socket( + ctx_ref: PyRef, + server_side: bool, + server_hostname: Option, + vm: &VirtualMachine, + ) -> PyResult<(ssl::Ssl, SslServerOrClient, Option)> { + // Validate socket type and context protocol + if server_side && ctx_ref.protocol == SslVersion::TlsClient { + return Err(vm.new_exception_msg( + PySslError::class(&vm.ctx).to_owned(), + "Cannot create a server socket with a PROTOCOL_TLS_CLIENT context".to_owned(), + )); + } + if !server_side && ctx_ref.protocol == SslVersion::TlsServer { + return Err(vm.new_exception_msg( + PySslError::class(&vm.ctx).to_owned(), + "Cannot create a client socket with a PROTOCOL_TLS_SERVER context".to_owned(), + )); + } + + // Create SSL object + let mut ssl = + ssl::Ssl::new(&ctx_ref.ctx()).map_err(|e| convert_openssl_error(vm, e))?; + + // Set session id context for server-side sockets + let socket_type = if server_side { + unsafe { + const SID_CTX: &[u8] = b"Python"; + let ret = SSL_set_session_id_context( + ssl.as_ptr(), + SID_CTX.as_ptr(), + SID_CTX.len() as libc::c_uint, + ); + if ret == 0 { + return Err(convert_openssl_error(vm, ErrorStack::get())); + } + } + SslServerOrClient::Server + } else { + SslServerOrClient::Client + }; + + // Configure server hostname + if let Some(hostname) = &server_hostname { + let hostname_str = hostname.as_str(); + if hostname_str.is_empty() || hostname_str.starts_with('.') { + return Err(vm.new_value_error( + "server_hostname cannot be an empty string or start with a leading dot.", + )); + } + if hostname_str.contains('\0') { + return Err(vm.new_value_error("embedded null byte in server_hostname")); + } + let ip = hostname_str.parse::(); + if ip.is_err() { + ssl.set_hostname(hostname_str) + .map_err(|e| convert_openssl_error(vm, e))?; + } + if ctx_ref.check_hostname.load() { + if let Ok(ip) = ip { + ssl.param_mut() + .set_ip(ip) + .map_err(|e| convert_openssl_error(vm, e))?; + } else { + ssl.param_mut() + .set_host(hostname_str) + .map_err(|e| convert_openssl_error(vm, e))?; + } + } + } + + // Configure post-handshake authentication + #[cfg(ossl111)] + if *ctx_ref.post_handshake_auth.lock() { + unsafe { + if server_side { + // Server socket: add SSL_VERIFY_POST_HANDSHAKE flag + // Only in combination with SSL_VERIFY_PEER + let mode = sys::SSL_get_verify_mode(ssl.as_ptr()); + if (mode & sys::SSL_VERIFY_PEER as libc::c_int) != 0 { + sys::SSL_set_verify( + ssl.as_ptr(), + mode | SSL_VERIFY_POST_HANDSHAKE, + None, + ); + } + } else { + // Client socket: call SSL_set_post_handshake_auth + SSL_set_post_handshake_auth(ssl.as_ptr(), 1); + } + } + } + + // Set connect/accept state + if server_side { + ssl.set_accept_state(); + } else { + ssl.set_connect_state(); + } + + Ok((ssl, socket_type, server_hostname)) + } + + #[pymethod] + fn _wrap_socket( + zelf: PyRef, + args: WrapSocketArgs, + vm: &VirtualMachine, + ) -> PyResult { + // Use common helper function + let (ssl, socket_type, server_hostname) = + Self::new_py_ssl_socket(zelf.clone(), args.server_side, args.server_hostname, vm)?; + + // Create SslStream with socket + let stream = ssl::SslStream::new(ssl, SocketStream(args.sock.clone())) + .map_err(|e| convert_openssl_error(vm, e))?; + + let py_ssl_socket = PySslSocket { + ctx: PyRwLock::new(zelf.clone()), + connection: PyRwLock::new(SslConnection::Socket(stream)), + socket_type, + server_hostname, + owner: PyRwLock::new(args.owner.map(|o| o.downgrade(None, vm)).transpose()?), + }; + + // Convert to PyRef (heap allocation) to avoid use-after-free + let py_ref = + py_ssl_socket.into_ref_with_type(vm, PySslSocket::class(&vm.ctx).to_owned())?; + + // Set SNI callback data if callback is configured + if zelf.sni_callback.lock().is_some() { + unsafe { + let ssl_ptr = py_ref.connection.read().ssl().as_ptr(); + + // Store callback data in SSL ex_data + let callback_data = Box::new(SniCallbackData { + ssl_context: zelf.clone(), + vm_ptr: vm as *const _, + }); + let idx = get_sni_ex_data_index(); + sys::SSL_set_ex_data(ssl_ptr, idx, Box::into_raw(callback_data) as *mut _); + + // Store PyRef pointer (heap-allocated) in ex_data index 0 + sys::SSL_set_ex_data(ssl_ptr, 0, &*py_ref as *const _ as *mut _); + } + } + + // Set session if provided + if let Some(session) = args.session + && !vm.is_none(&session) + { + py_ref.set_session(session, vm)?; + } + + Ok(py_ref.into()) + } + + #[pymethod] + fn _wrap_bio( + zelf: PyRef, + args: WrapBioArgs, + vm: &VirtualMachine, + ) -> PyResult { + // Use common helper function + let (ssl, socket_type, server_hostname) = + Self::new_py_ssl_socket(zelf.clone(), args.server_side, args.server_hostname, vm)?; + + // Create BioStream wrapper + let bio_stream = BioStream { + inbio: args.incoming, + outbio: args.outgoing, + }; + + // Create SslStream with BioStream + let stream = + ssl::SslStream::new(ssl, bio_stream).map_err(|e| convert_openssl_error(vm, e))?; + + let py_ssl_socket = PySslSocket { + ctx: PyRwLock::new(zelf.clone()), + connection: PyRwLock::new(SslConnection::Bio(stream)), + socket_type, + server_hostname, + owner: PyRwLock::new(args.owner.map(|o| o.downgrade(None, vm)).transpose()?), + }; + + // Convert to PyRef (heap allocation) to avoid use-after-free + let py_ref = + py_ssl_socket.into_ref_with_type(vm, PySslSocket::class(&vm.ctx).to_owned())?; + + // Set SNI callback data if callback is configured + if zelf.sni_callback.lock().is_some() { + unsafe { + let ssl_ptr = py_ref.connection.read().ssl().as_ptr(); + + // Store callback data in SSL ex_data + let callback_data = Box::new(SniCallbackData { + ssl_context: zelf.clone(), + vm_ptr: vm as *const _, + }); + let idx = get_sni_ex_data_index(); + sys::SSL_set_ex_data(ssl_ptr, idx, Box::into_raw(callback_data) as *mut _); + + // Store PyRef pointer (heap-allocated) in ex_data index 0 + sys::SSL_set_ex_data(ssl_ptr, 0, &*py_ref as *const _ as *mut _); + } + } + + // Set session if provided + if let Some(session) = args.session + && !vm.is_none(&session) + { + py_ref.set_session(session, vm)?; + } + + Ok(py_ref.into()) + } + } + + #[derive(FromArgs)] + #[allow(dead_code)] // Fields will be used when _wrap_bio is fully implemented + struct WrapBioArgs { + incoming: PyRef, + outgoing: PyRef, + server_side: bool, + #[pyarg(any, default)] + server_hostname: Option, + #[pyarg(named, default)] + owner: Option, + #[pyarg(named, default)] + session: Option, + } + + #[derive(FromArgs)] + struct WrapSocketArgs { + sock: PyRef, + server_side: bool, + #[pyarg(any, default)] + server_hostname: Option, + #[pyarg(named, default)] + owner: Option, + #[pyarg(named, default)] + session: Option, + } + + #[derive(FromArgs)] + struct LoadVerifyLocationsArgs { + #[pyarg(any, default)] + cafile: Option, + #[pyarg(any, default)] + capath: Option, + #[pyarg(any, default)] + cadata: Option>, + } + + #[derive(FromArgs)] + struct LoadCertChainArgs { + certfile: FsPath, + #[pyarg(any, optional)] + keyfile: Option, + #[pyarg(any, optional)] + password: Option>, + } + + // Err is true if the socket is blocking + type SocketDeadline = Result; + + enum SelectRet { + Nonblocking, + TimedOut, + IsBlocking, + Closed, + Ok, + } + + #[derive(Clone, Copy)] + enum SslNeeds { + Read, + Write, + } + + struct SocketStream(PyRef); + + impl SocketStream { + fn timeout_deadline(&self) -> SocketDeadline { + self.0.get_timeout().map(|d| Instant::now() + d) + } + + fn select(&self, needs: SslNeeds, deadline: &SocketDeadline) -> SelectRet { + let sock = match self.0.sock_opt() { + Some(s) => s, + None => return SelectRet::Closed, + }; + let deadline = match &deadline { + Ok(deadline) => match deadline.checked_duration_since(Instant::now()) { + Some(deadline) => deadline, + None => return SelectRet::TimedOut, + }, + Err(true) => return SelectRet::IsBlocking, + Err(false) => return SelectRet::Nonblocking, + }; + let res = socket::sock_select( + &sock, + match needs { + SslNeeds::Read => socket::SelectKind::Read, + SslNeeds::Write => socket::SelectKind::Write, + }, + Some(deadline), + ); + match res { + Ok(true) => SelectRet::TimedOut, + _ => SelectRet::Ok, + } + } + + fn socket_needs( + &self, + err: &ssl::Error, + deadline: &SocketDeadline, + ) -> (Option, SelectRet) { + let needs = match err.code() { + ssl::ErrorCode::WANT_READ => Some(SslNeeds::Read), + ssl::ErrorCode::WANT_WRITE => Some(SslNeeds::Write), + _ => None, + }; + let state = needs.map_or(SelectRet::Ok, |needs| self.select(needs, deadline)); + (needs, state) + } + } + + fn socket_closed_error(vm: &VirtualMachine) -> PyBaseExceptionRef { + vm.new_exception_msg( + PySslError::class(&vm.ctx).to_owned(), + "Underlying socket has been closed.".to_owned(), + ) + } + + // BIO stream wrapper to implement Read/Write traits for MemoryBIO + struct BioStream { + inbio: PyRef, + outbio: PyRef, + } + + impl Read for BioStream { + fn read(&mut self, buf: &mut [u8]) -> std::io::Result { + // Read from incoming MemoryBIO + unsafe { + let nbytes = sys::BIO_read( + self.inbio.bio, + buf.as_mut_ptr() as *mut _, + buf.len().min(i32::MAX as usize) as i32, + ); + if nbytes < 0 { + // BIO_read returns -1 on error or when no data is available + // Check if it's a retry condition (WANT_READ) + Err(std::io::Error::new( + std::io::ErrorKind::WouldBlock, + "BIO has no data available", + )) + } else { + Ok(nbytes as usize) + } + } + } + } + + impl Write for BioStream { + fn write(&mut self, buf: &[u8]) -> std::io::Result { + // Write to outgoing MemoryBIO + unsafe { + let nbytes = sys::BIO_write( + self.outbio.bio, + buf.as_ptr() as *const _, + buf.len().min(i32::MAX as usize) as i32, + ); + if nbytes < 0 { + return Err(std::io::Error::other("BIO write failed")); + } + Ok(nbytes as usize) + } + } + + fn flush(&mut self) -> std::io::Result<()> { + // MemoryBIO doesn't need flushing + Ok(()) + } + } + + // Enum to represent different SSL connection modes + enum SslConnection { + Socket(ssl::SslStream), + Bio(ssl::SslStream), + } + + impl SslConnection { + // Get a reference to the SSL object + fn ssl(&self) -> &ssl::SslRef { + match self { + SslConnection::Socket(stream) => stream.ssl(), + SslConnection::Bio(stream) => stream.ssl(), + } + } + + // Get underlying socket stream reference (only for socket mode) + fn get_ref(&self) -> Option<&SocketStream> { + match self { + SslConnection::Socket(stream) => Some(stream.get_ref()), + SslConnection::Bio(_) => None, + } + } + + // Check if this is in BIO mode + fn is_bio(&self) -> bool { + matches!(self, SslConnection::Bio(_)) + } + + // Perform SSL handshake + fn do_handshake(&mut self) -> Result<(), ssl::Error> { + match self { + SslConnection::Socket(stream) => stream.do_handshake(), + SslConnection::Bio(stream) => stream.do_handshake(), + } + } + + // Write data to SSL connection + fn ssl_write(&mut self, buf: &[u8]) -> Result { + match self { + SslConnection::Socket(stream) => stream.ssl_write(buf), + SslConnection::Bio(stream) => stream.ssl_write(buf), + } + } + + // Read data from SSL connection + fn ssl_read(&mut self, buf: &mut [u8]) -> Result { + match self { + SslConnection::Socket(stream) => stream.ssl_read(buf), + SslConnection::Bio(stream) => stream.ssl_read(buf), + } + } + + // Get SSL shutdown state + fn get_shutdown(&mut self) -> ssl::ShutdownState { + match self { + SslConnection::Socket(stream) => stream.get_shutdown(), + SslConnection::Bio(stream) => stream.get_shutdown(), + } + } + } + + #[pyattr] + #[pyclass(module = "ssl", name = "_SSLSocket", traverse)] + #[derive(PyPayload)] + struct PySslSocket { + ctx: PyRwLock>, + #[pytraverse(skip)] + connection: PyRwLock, + #[pytraverse(skip)] + socket_type: SslServerOrClient, + server_hostname: Option, + owner: PyRwLock>>, + } + + impl fmt::Debug for PySslSocket { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.pad("_SSLSocket") + } + } + + #[pyclass(flags(IMMUTABLETYPE))] + impl PySslSocket { + #[pygetset] + fn owner(&self) -> Option { + self.owner.read().as_ref().and_then(|weak| weak.upgrade()) + } + #[pygetset(setter)] + fn set_owner(&self, owner: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { + let mut lock = self.owner.write(); + lock.take(); + *lock = Some(owner.downgrade(None, vm)?); + Ok(()) + } + #[pygetset] + fn server_side(&self) -> bool { + self.socket_type == SslServerOrClient::Server + } + #[pygetset] + fn context(&self) -> PyRef { + self.ctx.read().clone() + } + #[pygetset(setter)] + fn set_context(&self, value: PyRef, vm: &VirtualMachine) -> PyResult<()> { + // Update the SSL context in the underlying SSL object + let stream = self.connection.read(); + + // Set the new SSL_CTX on the SSL object + unsafe { + let result = SSL_set_SSL_CTX(stream.ssl().as_ptr(), value.ctx().as_ptr()); + if result.is_null() { + return Err(vm.new_runtime_error("Failed to set SSL context".to_owned())); + } + } + + // Update self.ctx to the new context + *self.ctx.write() = value; + Ok(()) + } + #[pygetset] + fn server_hostname(&self) -> Option { + self.server_hostname.clone() + } + + #[pymethod] + fn getpeercert( + &self, + binary: OptionalArg, + vm: &VirtualMachine, + ) -> PyResult> { + let binary = binary.unwrap_or(false); + let stream = self.connection.read(); + if !stream.ssl().is_init_finished() { + return Err(vm.new_value_error("handshake not done yet")); + } + + let peer_cert = stream.ssl().peer_certificate(); + let Some(cert) = peer_cert else { + return Ok(None); + }; + + if binary { + // Return DER-encoded certificate + cert_to_py(vm, &cert, true).map(Some) + } else { + // Check verify_mode + unsafe { + let ssl_ctx = sys::SSL_get_SSL_CTX(stream.ssl().as_ptr()); + let verify_mode = sys::SSL_CTX_get_verify_mode(ssl_ctx); + if (verify_mode & sys::SSL_VERIFY_PEER as libc::c_int) == 0 { + // Return empty dict when SSL_VERIFY_PEER is not set + Ok(Some(vm.ctx.new_dict().into())) + } else { + // Return decoded certificate + cert_to_py(vm, &cert, false).map(Some) + } + } + } + } + + #[pymethod] + fn get_unverified_chain(&self, vm: &VirtualMachine) -> PyResult> { + let stream = self.connection.read(); + let Some(chain) = stream.ssl().peer_cert_chain() else { + return Ok(None); + }; + + // Return Certificate objects + let certs: Vec = chain + .iter() + .map(|cert| unsafe { + sys::X509_up_ref(cert.as_ptr()); + let owned = X509::from_ptr(cert.as_ptr()); + cert_to_certificate(vm, owned) + }) + .collect::>()?; + Ok(Some(vm.ctx.new_list(certs))) + } + + #[pymethod] + fn get_verified_chain(&self, vm: &VirtualMachine) -> PyResult> { + let stream = self.connection.read(); + unsafe { + let chain = sys::SSL_get0_verified_chain(stream.ssl().as_ptr()); + if chain.is_null() { + return Ok(None); + } + + let num_certs = sys::OPENSSL_sk_num(chain as *const _); + + let mut certs = Vec::with_capacity(num_certs as usize); + // Return Certificate objects + for i in 0..num_certs { + let cert_ptr = sys::OPENSSL_sk_value(chain as *const _, i) as *mut sys::X509; + if cert_ptr.is_null() { + continue; + } + // Clone the X509 certificate to create an owned copy + sys::X509_up_ref(cert_ptr); + let owned_cert = X509::from_ptr(cert_ptr); + let cert_obj = cert_to_certificate(vm, owned_cert)?; + certs.push(cert_obj); + } + + Ok(if certs.is_empty() { + None + } else { + Some(vm.ctx.new_list(certs)) + }) + } + } + + #[pymethod] + fn version(&self) -> Option<&'static str> { + let v = self.connection.read().ssl().version_str(); + if v == "unknown" { None } else { Some(v) } + } + + #[pymethod] + fn cipher(&self) -> Option { + self.connection + .read() + .ssl() + .current_cipher() + .map(cipher_to_tuple) + } + + #[pymethod] + fn shared_ciphers(&self, vm: &VirtualMachine) -> Option { + #[cfg(ossl110)] + { + let stream = self.connection.read(); + unsafe { + let server_ciphers = SSL_get_ciphers(stream.ssl().as_ptr()); + if server_ciphers.is_null() { + return None; + } + + let client_ciphers = SSL_get_client_ciphers(stream.ssl().as_ptr()); + if client_ciphers.is_null() { + return None; + } + + let mut result = Vec::new(); + let num_server = sys::OPENSSL_sk_num(server_ciphers as *const _); + let num_client = sys::OPENSSL_sk_num(client_ciphers as *const _); + + for i in 0..num_server { + let server_cipher_ptr = sys::OPENSSL_sk_value(server_ciphers as *const _, i) + as *const sys::SSL_CIPHER; + + // Check if client supports this cipher by comparing pointers + let mut found = false; + for j in 0..num_client { + let client_cipher_ptr = + sys::OPENSSL_sk_value(client_ciphers as *const _, j) + as *const sys::SSL_CIPHER; + + if server_cipher_ptr == client_cipher_ptr { + found = true; + break; + } + } + + if found { + let cipher = ssl::SslCipherRef::from_ptr(server_cipher_ptr as *mut _); + let (name, version, bits) = cipher_to_tuple(cipher); + let tuple = vm.new_tuple(( + vm.ctx.new_str(name), + vm.ctx.new_str(version), + vm.ctx.new_int(bits), + )); + result.push(tuple.into()); + } + } + + if result.is_empty() { + None + } else { + Some(vm.ctx.new_list(result)) + } + } + } + #[cfg(not(ossl110))] + { + let _ = vm; + None + } + } + + #[pymethod] + fn selected_alpn_protocol(&self) -> Option { + #[cfg(ossl102)] + { + let stream = self.connection.read(); + unsafe { + let mut out: *const libc::c_uchar = std::ptr::null(); + let mut outlen: libc::c_uint = 0; + + sys::SSL_get0_alpn_selected(stream.ssl().as_ptr(), &mut out, &mut outlen); + + if out.is_null() { + None + } else { + let slice = std::slice::from_raw_parts(out, outlen as usize); + Some(String::from_utf8_lossy(slice).into_owned()) + } + } + } + #[cfg(not(ossl102))] + { + None + } + } + + #[pymethod] + fn get_channel_binding( + &self, + cb_type: OptionalArg, + vm: &VirtualMachine, + ) -> PyResult> { + const CB_MAXLEN: usize = 512; + + let cb_type_str = cb_type.as_ref().map_or("tls-unique", |s| s.as_str()); + + if cb_type_str != "tls-unique" { + return Err(vm.new_value_error(format!( + "Unsupported channel binding type '{}'", + cb_type_str + ))); + } + + let stream = self.connection.read(); + let ssl_ptr = stream.ssl().as_ptr(); + + unsafe { + let session_reused = sys::SSL_session_reused(ssl_ptr) != 0; + let is_client = matches!(self.socket_type, SslServerOrClient::Client); + + // Use XOR logic from CPython + let use_finished = session_reused ^ is_client; + + let mut buf = vec![0u8; CB_MAXLEN]; + let len = if use_finished { + sys::SSL_get_finished(ssl_ptr, buf.as_mut_ptr() as *mut _, CB_MAXLEN) + } else { + sys::SSL_get_peer_finished(ssl_ptr, buf.as_mut_ptr() as *mut _, CB_MAXLEN) + }; + + if len == 0 { + Ok(None) + } else { + buf.truncate(len); + Ok(Some(vm.ctx.new_bytes(buf))) + } + } + } + + #[pymethod] + fn verify_client_post_handshake(&self, vm: &VirtualMachine) -> PyResult<()> { + #[cfg(ossl111)] + { + let stream = self.connection.read(); + let result = unsafe { SSL_verify_client_post_handshake(stream.ssl().as_ptr()) }; + if result == 0 { + Err(convert_openssl_error(vm, openssl::error::ErrorStack::get())) + } else { + Ok(()) + } + } + #[cfg(not(ossl111))] + { + Err(vm.new_not_implemented_error( + "Post-handshake auth is not supported by your OpenSSL version.".to_owned(), + )) + } + } + + #[pymethod] + fn shutdown(&self, vm: &VirtualMachine) -> PyResult> { + let stream = self.connection.read(); + + // BIO mode doesn't have an underlying socket + if stream.is_bio() { + return Err(vm.new_not_implemented_error( + "shutdown() is not supported for BIO-based SSL objects".to_owned(), + )); + } + + let ssl_ptr = stream.ssl().as_ptr(); + + // Perform SSL shutdown + let ret = unsafe { sys::SSL_shutdown(ssl_ptr) }; + + if ret < 0 { + // Error occurred + let err = unsafe { sys::SSL_get_error(ssl_ptr, ret) }; + + if err == sys::SSL_ERROR_WANT_READ || err == sys::SSL_ERROR_WANT_WRITE { + // Non-blocking would block - this is okay for shutdown + // Return the underlying socket + } else { + return Err(vm.new_exception_msg( + PySslError::class(&vm.ctx).to_owned(), + format!("SSL shutdown failed: error code {}", err), + )); + } + } + + // Return the underlying socket + // Get the socket from the stream (SocketStream wraps PyRef) + let socket = stream + .get_ref() + .expect("unwrap() called on bio mode; should only be called in socket mode"); + Ok(socket.0.clone()) + } + + #[cfg(osslconf = "OPENSSL_NO_COMP")] + #[pymethod] + fn compression(&self) -> Option<&'static str> { + None + } + #[cfg(not(osslconf = "OPENSSL_NO_COMP"))] + #[pymethod] + fn compression(&self) -> Option<&'static str> { + let stream = self.connection.read(); + let comp_method = unsafe { sys::SSL_get_current_compression(stream.ssl().as_ptr()) }; + if comp_method.is_null() { + return None; + } + let typ = unsafe { sys::COMP_get_type(comp_method) }; + let nid = Nid::from_raw(typ); + if nid == Nid::UNDEF { + return None; + } + nid.short_name().ok() + } + + #[pymethod] + fn do_handshake(&self, vm: &VirtualMachine) -> PyResult<()> { + let mut stream = self.connection.write(); + let ssl_ptr = stream.ssl().as_ptr(); + + // BIO mode: no timeout/select logic, just do handshake + if stream.is_bio() { + return stream.do_handshake().map_err(|e| { + let exc = convert_ssl_error(vm, e); + // If it's a cert verification error, set verify info + if exc.class().is(PySslCertVerificationError::class(&vm.ctx)) { + set_verify_error_info(&exc, ssl_ptr, vm); + } + exc + }); + } + + // Socket mode: handle timeout and blocking + let timeout = stream + .get_ref() + .expect("handshake called in bio mode; should only be called in socket mode") + .timeout_deadline(); + loop { + let err = match stream.do_handshake() { + Ok(()) => return Ok(()), + Err(e) => e, + }; + let (needs, state) = stream + .get_ref() + .expect("handshake called in bio mode; should only be called in socket mode") + .socket_needs(&err, &timeout); + match state { + SelectRet::TimedOut => { + return Err(socket::timeout_error_msg( + vm, + "The handshake operation timed out".to_owned(), + )); + } + SelectRet::Closed => return Err(socket_closed_error(vm)), + SelectRet::Nonblocking => {} + SelectRet::IsBlocking | SelectRet::Ok => { + // For blocking sockets, select() has completed successfully + // Continue the handshake loop (matches CPython's SOCKET_IS_BLOCKING behavior) + if needs.is_some() { + continue; + } + } + } + let exc = convert_ssl_error(vm, err); + // If it's a cert verification error, set verify info + if exc.class().is(PySslCertVerificationError::class(&vm.ctx)) { + set_verify_error_info(&exc, ssl_ptr, vm); + } + return Err(exc); + } + } + + #[pymethod] + fn write(&self, data: ArgBytesLike, vm: &VirtualMachine) -> PyResult { + let mut stream = self.connection.write(); + let data = data.borrow_buf(); + let data = &*data; + + // BIO mode: no timeout/select logic + if stream.is_bio() { + return stream.ssl_write(data).map_err(|e| convert_ssl_error(vm, e)); + } + + // Socket mode: handle timeout and blocking + let socket_ref = stream + .get_ref() + .expect("write called in bio mode; should only be called in socket mode"); + let timeout = socket_ref.timeout_deadline(); + let state = socket_ref.select(SslNeeds::Write, &timeout); + match state { + SelectRet::TimedOut => { + return Err(socket::timeout_error_msg( + vm, + "The write operation timed out".to_owned(), + )); + } + SelectRet::Closed => return Err(socket_closed_error(vm)), + _ => {} + } + loop { + let err = match stream.ssl_write(data) { + Ok(len) => return Ok(len), + Err(e) => e, + }; + let (needs, state) = stream + .get_ref() + .expect("write called in bio mode; should only be called in socket mode") + .socket_needs(&err, &timeout); + match state { + SelectRet::TimedOut => { + return Err(socket::timeout_error_msg( + vm, + "The write operation timed out".to_owned(), + )); + } + SelectRet::Closed => return Err(socket_closed_error(vm)), + SelectRet::Nonblocking => {} + SelectRet::IsBlocking | SelectRet::Ok => { + // For blocking sockets, select() has completed successfully + // Continue the write loop (matches CPython's SOCKET_IS_BLOCKING behavior) + if needs.is_some() { + continue; + } + } + } + return Err(convert_ssl_error(vm, err)); + } + } + + #[pygetset] + fn session(&self, _vm: &VirtualMachine) -> PyResult> { + let stream = self.connection.read(); + unsafe { + // Use SSL_get1_session which returns an owned reference (ref count already incremented) + let session_ptr = SSL_get1_session(stream.ssl().as_ptr()); + if session_ptr.is_null() { + Ok(None) + } else { + Ok(Some(PySslSession { + session: session_ptr, + ctx: self.ctx.read().clone(), + })) + } + } + } + + #[pygetset(setter)] + fn set_session(&self, value: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { + // Check if value is SSLSession type + let session = value + .downcast_ref::() + .ok_or_else(|| vm.new_type_error("Value is not a SSLSession.".to_owned()))?; + + // Check if session refers to the same SSLContext + if !std::ptr::eq( + self.ctx.read().ctx.read().as_ptr(), + session.ctx.ctx.read().as_ptr(), + ) { + return Err( + vm.new_value_error("Session refers to a different SSLContext.".to_owned()) + ); + } + + // Check if this is a client socket + if self.socket_type != SslServerOrClient::Client { + return Err( + vm.new_value_error("Cannot set session for server-side SSLSocket.".to_owned()) + ); + } + + // Check if handshake is not finished + let stream = self.connection.read(); + unsafe { + if sys::SSL_is_init_finished(stream.ssl().as_ptr()) != 0 { + return Err( + vm.new_value_error("Cannot set session after handshake.".to_owned()) + ); + } + + let ret = sys::SSL_set_session(stream.ssl().as_ptr(), session.session); + if ret == 0 { + return Err(convert_openssl_error(vm, ErrorStack::get())); + } + } + + Ok(()) + } + + #[pygetset] + fn session_reused(&self) -> bool { + let stream = self.connection.read(); + unsafe { sys::SSL_session_reused(stream.ssl().as_ptr()) != 0 } + } + + #[pymethod] + fn read( + &self, + n: usize, + buffer: OptionalArg, + vm: &VirtualMachine, + ) -> PyResult { + // Special case: reading 0 bytes should return empty bytes immediately + if n == 0 { + return if buffer.is_present() { + Ok(vm.ctx.new_int(0).into()) + } else { + Ok(vm.ctx.new_bytes(vec![]).into()) + }; + } + + let mut stream = self.connection.write(); + let mut inner_buffer = if let OptionalArg::Present(buffer) = &buffer { + Either::A(buffer.borrow_buf_mut()) + } else { + Either::B(vec![0u8; n]) + }; + let buf = match &mut inner_buffer { + Either::A(b) => &mut **b, + Either::B(b) => b.as_mut_slice(), + }; + let buf = match buf.get_mut(..n) { + Some(b) => b, + None => buf, + }; + + // BIO mode: no timeout/select logic + let count = if stream.is_bio() { + match stream.ssl_read(buf) { + Ok(count) => count, + Err(e) => return Err(convert_ssl_error(vm, e)), + } + } else { + // Socket mode: handle timeout and blocking + let timeout = stream + .get_ref() + .expect("read called in bio mode; should only be called in socket mode") + .timeout_deadline(); + loop { + let err = match stream.ssl_read(buf) { + Ok(count) => break count, + Err(e) => e, + }; + if err.code() == ssl::ErrorCode::ZERO_RETURN + && stream.get_shutdown() == ssl::ShutdownState::RECEIVED + { + break 0; + } + let (needs, state) = stream + .get_ref() + .expect("read called in bio mode; should only be called in socket mode") + .socket_needs(&err, &timeout); + match state { + SelectRet::TimedOut => { + return Err(socket::timeout_error_msg( + vm, + "The read operation timed out".to_owned(), + )); + } + SelectRet::Closed => return Err(socket_closed_error(vm)), + SelectRet::Nonblocking => {} + SelectRet::IsBlocking | SelectRet::Ok => { + // For blocking sockets, select() has completed successfully + // Continue the read loop (matches CPython's SOCKET_IS_BLOCKING behavior) + if needs.is_some() { + continue; + } + } + } + return Err(convert_ssl_error(vm, err)); + } + }; + let ret = match inner_buffer { + Either::A(_buf) => vm.ctx.new_int(count).into(), + Either::B(mut buf) => { + buf.truncate(count); + buf.shrink_to_fit(); + vm.ctx.new_bytes(buf).into() + } + }; + Ok(ret) + } + } + + #[pyattr] + #[pyclass(module = "ssl", name = "SSLSession")] + #[derive(PyPayload)] + struct PySslSession { + session: *mut sys::SSL_SESSION, + ctx: PyRef, + } + + impl fmt::Debug for PySslSession { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.pad("SSLSession") + } + } + + impl Drop for PySslSession { + fn drop(&mut self) { + if !self.session.is_null() { + unsafe { + sys::SSL_SESSION_free(self.session); + } + } + } + } + + unsafe impl Send for PySslSession {} + unsafe impl Sync for PySslSession {} + + impl Comparable for PySslSession { + fn cmp( + zelf: &Py, + other: &crate::vm::PyObject, + op: PyComparisonOp, + _vm: &VirtualMachine, + ) -> PyResult { + let other = class_or_notimplemented!(Self, other); + + if !matches!(op, PyComparisonOp::Eq | PyComparisonOp::Ne) { + return Ok(PyComparisonValue::NotImplemented); + } + let mut eq = unsafe { + let mut self_len: libc::c_uint = 0; + let mut other_len: libc::c_uint = 0; + let self_id = sys::SSL_SESSION_get_id(zelf.session, &mut self_len); + let other_id = sys::SSL_SESSION_get_id(other.session, &mut other_len); + + if self_len != other_len { + false + } else { + let self_slice = std::slice::from_raw_parts(self_id, self_len as usize); + let other_slice = std::slice::from_raw_parts(other_id, other_len as usize); + self_slice == other_slice + } + }; + if matches!(op, PyComparisonOp::Ne) { + eq = !eq; + } + Ok(PyComparisonValue::Implemented(eq)) + } + } + + #[pyattr] + #[pyclass(module = "ssl", name = "MemoryBIO")] + #[derive(PyPayload)] + struct PySslMemoryBio { + bio: *mut sys::BIO, + eof_written: AtomicCell, + } + + impl fmt::Debug for PySslMemoryBio { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.pad("MemoryBIO") + } + } + + impl Drop for PySslMemoryBio { + fn drop(&mut self) { + if !self.bio.is_null() { + unsafe { + sys::BIO_free_all(self.bio); + } + } + } + } + + unsafe impl Send for PySslMemoryBio {} + unsafe impl Sync for PySslMemoryBio {} + + // OpenSSL functions not in openssl-sys + + unsafe extern "C" { + // X509_check_ca returns 1 for CA certificates, 0 otherwise + fn X509_check_ca(x: *const sys::X509) -> libc::c_int; + } + + unsafe extern "C" { + fn SSL_get_ciphers(ssl: *const sys::SSL) -> *const sys::stack_st_SSL_CIPHER; + } + + #[cfg(ossl110)] + unsafe extern "C" { + fn SSL_get_client_ciphers(ssl: *const sys::SSL) -> *const sys::stack_st_SSL_CIPHER; + } + + #[cfg(ossl111)] + unsafe extern "C" { + fn SSL_verify_client_post_handshake(ssl: *const sys::SSL) -> libc::c_int; + fn SSL_set_post_handshake_auth(ssl: *mut sys::SSL, val: libc::c_int); + } + + #[cfg(ossl110)] + unsafe extern "C" { + fn SSL_CTX_get_security_level(ctx: *const sys::SSL_CTX) -> libc::c_int; + } + + unsafe extern "C" { + fn SSL_set_SSL_CTX(ssl: *mut sys::SSL, ctx: *mut sys::SSL_CTX) -> *mut sys::SSL_CTX; + } + + // Message callback type + #[allow(non_camel_case_types)] + type SSL_CTX_msg_callback = Option< + unsafe extern "C" fn( + write_p: libc::c_int, + version: libc::c_int, + content_type: libc::c_int, + buf: *const libc::c_void, + len: usize, + ssl: *mut sys::SSL, + arg: *mut libc::c_void, + ), + >; + + unsafe extern "C" { + fn SSL_CTX_set_msg_callback(ctx: *mut sys::SSL_CTX, cb: SSL_CTX_msg_callback); + } + + #[cfg(ossl110)] + unsafe extern "C" { + fn SSL_SESSION_has_ticket(session: *const sys::SSL_SESSION) -> libc::c_int; + fn SSL_SESSION_get_ticket_lifetime_hint(session: *const sys::SSL_SESSION) -> libc::c_ulong; + } + + // X509 object types + const X509_LU_X509: libc::c_int = 1; + const X509_LU_CRL: libc::c_int = 2; + + unsafe extern "C" { + fn X509_OBJECT_get_type(obj: *const sys::X509_OBJECT) -> libc::c_int; + fn SSL_set_session_id_context( + ssl: *mut sys::SSL, + sid_ctx: *const libc::c_uchar, + sid_ctx_len: libc::c_uint, + ) -> libc::c_int; + fn SSL_get1_session(ssl: *const sys::SSL) -> *mut sys::SSL_SESSION; + } + + // SSL session statistics constants (used with SSL_CTX_ctrl) + const SSL_CTRL_SESS_NUMBER: libc::c_int = 20; + const SSL_CTRL_SESS_CONNECT: libc::c_int = 21; + const SSL_CTRL_SESS_CONNECT_GOOD: libc::c_int = 22; + const SSL_CTRL_SESS_CONNECT_RENEGOTIATE: libc::c_int = 23; + const SSL_CTRL_SESS_ACCEPT: libc::c_int = 24; + const SSL_CTRL_SESS_ACCEPT_GOOD: libc::c_int = 25; + const SSL_CTRL_SESS_ACCEPT_RENEGOTIATE: libc::c_int = 26; + const SSL_CTRL_SESS_HIT: libc::c_int = 27; + const SSL_CTRL_SESS_MISSES: libc::c_int = 29; + const SSL_CTRL_SESS_TIMEOUTS: libc::c_int = 30; + const SSL_CTRL_SESS_CACHE_FULL: libc::c_int = 31; + + // SSL session statistics functions (implemented as macros in OpenSSL) + #[allow(non_snake_case)] + unsafe fn SSL_CTX_sess_number(ctx: *const sys::SSL_CTX) -> libc::c_long { + unsafe { sys::SSL_CTX_ctrl(ctx as *mut _, SSL_CTRL_SESS_NUMBER, 0, std::ptr::null_mut()) } + } + + #[allow(non_snake_case)] + unsafe fn SSL_CTX_sess_connect(ctx: *const sys::SSL_CTX) -> libc::c_long { + unsafe { + sys::SSL_CTX_ctrl( + ctx as *mut _, + SSL_CTRL_SESS_CONNECT, + 0, + std::ptr::null_mut(), + ) + } + } + + #[allow(non_snake_case)] + unsafe fn SSL_CTX_sess_connect_good(ctx: *const sys::SSL_CTX) -> libc::c_long { + unsafe { + sys::SSL_CTX_ctrl( + ctx as *mut _, + SSL_CTRL_SESS_CONNECT_GOOD, + 0, + std::ptr::null_mut(), + ) + } + } + + #[allow(non_snake_case)] + unsafe fn SSL_CTX_sess_connect_renegotiate(ctx: *const sys::SSL_CTX) -> libc::c_long { + unsafe { + sys::SSL_CTX_ctrl( + ctx as *mut _, + SSL_CTRL_SESS_CONNECT_RENEGOTIATE, + 0, + std::ptr::null_mut(), + ) + } + } + + #[allow(non_snake_case)] + unsafe fn SSL_CTX_sess_accept(ctx: *const sys::SSL_CTX) -> libc::c_long { + unsafe { sys::SSL_CTX_ctrl(ctx as *mut _, SSL_CTRL_SESS_ACCEPT, 0, std::ptr::null_mut()) } + } + + #[allow(non_snake_case)] + unsafe fn SSL_CTX_sess_accept_good(ctx: *const sys::SSL_CTX) -> libc::c_long { + unsafe { + sys::SSL_CTX_ctrl( + ctx as *mut _, + SSL_CTRL_SESS_ACCEPT_GOOD, + 0, + std::ptr::null_mut(), + ) + } + } + + #[allow(non_snake_case)] + unsafe fn SSL_CTX_sess_accept_renegotiate(ctx: *const sys::SSL_CTX) -> libc::c_long { + unsafe { + sys::SSL_CTX_ctrl( + ctx as *mut _, + SSL_CTRL_SESS_ACCEPT_RENEGOTIATE, + 0, + std::ptr::null_mut(), + ) + } + } + + #[allow(non_snake_case)] + unsafe fn SSL_CTX_sess_hits(ctx: *const sys::SSL_CTX) -> libc::c_long { + unsafe { sys::SSL_CTX_ctrl(ctx as *mut _, SSL_CTRL_SESS_HIT, 0, std::ptr::null_mut()) } + } + + #[allow(non_snake_case)] + unsafe fn SSL_CTX_sess_misses(ctx: *const sys::SSL_CTX) -> libc::c_long { + unsafe { sys::SSL_CTX_ctrl(ctx as *mut _, SSL_CTRL_SESS_MISSES, 0, std::ptr::null_mut()) } + } + + #[allow(non_snake_case)] + unsafe fn SSL_CTX_sess_timeouts(ctx: *const sys::SSL_CTX) -> libc::c_long { + unsafe { + sys::SSL_CTX_ctrl( + ctx as *mut _, + SSL_CTRL_SESS_TIMEOUTS, + 0, + std::ptr::null_mut(), + ) + } + } + + #[allow(non_snake_case)] + unsafe fn SSL_CTX_sess_cache_full(ctx: *const sys::SSL_CTX) -> libc::c_long { + unsafe { + sys::SSL_CTX_ctrl( + ctx as *mut _, + SSL_CTRL_SESS_CACHE_FULL, + 0, + std::ptr::null_mut(), + ) + } + } + + // DH parameters functions + unsafe extern "C" { + fn PEM_read_DHparams( + fp: *mut libc::FILE, + x: *mut *mut sys::DH, + cb: *mut libc::c_void, + u: *mut libc::c_void, + ) -> *mut sys::DH; + } + + // OpenSSL BIO helper functions + // These are typically macros in OpenSSL, implemented via BIO_ctrl + const BIO_CTRL_PENDING: libc::c_int = 10; + const BIO_CTRL_SET_EOF: libc::c_int = 2; + + #[allow(non_snake_case)] + unsafe fn BIO_ctrl_pending(bio: *mut sys::BIO) -> usize { + unsafe { sys::BIO_ctrl(bio, BIO_CTRL_PENDING, 0, std::ptr::null_mut()) as usize } + } + + #[allow(non_snake_case)] + unsafe fn BIO_set_mem_eof_return(bio: *mut sys::BIO, eof: libc::c_int) -> libc::c_int { + unsafe { + sys::BIO_ctrl( + bio, + BIO_CTRL_SET_EOF, + eof as libc::c_long, + std::ptr::null_mut(), + ) as libc::c_int + } + } + + #[allow(non_snake_case)] + unsafe fn BIO_clear_retry_flags(bio: *mut sys::BIO) { + unsafe { + sys::BIO_clear_flags(bio, sys::BIO_FLAGS_RWS | sys::BIO_FLAGS_SHOULD_RETRY); + } + } + + impl Constructor for PySslMemoryBio { + type Args = (); + + fn py_new(cls: PyTypeRef, _args: Self::Args, vm: &VirtualMachine) -> PyResult { + unsafe { + let bio = sys::BIO_new(sys::BIO_s_mem()); + if bio.is_null() { + return Err(vm.new_memory_error("failed to allocate BIO".to_owned())); + } + + sys::BIO_set_retry_read(bio); + BIO_set_mem_eof_return(bio, -1); + + PySslMemoryBio { + bio, + eof_written: AtomicCell::new(false), + } + .into_ref_with_type(vm, cls) + .map(Into::into) + } + } + } + + #[pyclass(flags(IMMUTABLETYPE), with(Constructor))] + impl PySslMemoryBio { + #[pygetset] + fn pending(&self) -> usize { + unsafe { BIO_ctrl_pending(self.bio) } + } + + #[pygetset] + fn eof(&self) -> bool { + let pending = unsafe { BIO_ctrl_pending(self.bio) }; + pending == 0 && self.eof_written.load() + } + + #[pymethod] + fn read(&self, size: OptionalArg, vm: &VirtualMachine) -> PyResult> { + unsafe { + let avail = BIO_ctrl_pending(self.bio).min(i32::MAX as usize) as i32; + let len = size.unwrap_or(-1); + let len = if len < 0 || len > avail { avail } else { len }; + + // Check if EOF has been written and no data available + // This matches CPython's behavior where read() returns b'' when EOF is set + if len == 0 && self.eof_written.load() { + return Ok(Vec::new()); + } + + if len == 0 { + // No data available and no EOF - would block + // Call BIO_read() to get the proper error (SSL_ERROR_WANT_READ) + let mut test_buf = [0u8; 1]; + let nbytes = sys::BIO_read(self.bio, test_buf.as_mut_ptr() as *mut _, 1); + if nbytes < 0 { + return Err(convert_openssl_error(vm, ErrorStack::get())); + } + // Shouldn't reach here, but if we do, return what we got + return Ok(test_buf[..nbytes as usize].to_vec()); + } + + let mut buf = vec![0u8; len as usize]; + let nbytes = sys::BIO_read(self.bio, buf.as_mut_ptr() as *mut _, len); + + if nbytes < 0 { + return Err(convert_openssl_error(vm, ErrorStack::get())); + } + + buf.truncate(nbytes as usize); + Ok(buf) + } + } + + #[pymethod] + fn write(&self, data: ArgBytesLike, vm: &VirtualMachine) -> PyResult { + if self.eof_written.load() { + return Err(vm.new_exception_msg( + PySslError::class(&vm.ctx).to_owned(), + "cannot write() after write_eof()".to_owned(), + )); + } + + data.with_ref(|buf| unsafe { + if buf.len() > i32::MAX as usize { + return Err( + vm.new_overflow_error(format!("string longer than {} bytes", i32::MAX)) + ); + } + + let nbytes = sys::BIO_write(self.bio, buf.as_ptr() as *const _, buf.len() as i32); + if nbytes < 0 { + return Err(convert_openssl_error(vm, ErrorStack::get())); + } + + Ok(nbytes) + }) + } + + #[pymethod] + fn write_eof(&self) { + self.eof_written.store(true); + unsafe { + BIO_clear_retry_flags(self.bio); + BIO_set_mem_eof_return(self.bio, 0); + } + } + } + + #[pyclass(flags(IMMUTABLETYPE), with(Comparable))] + impl PySslSession { + #[pygetset] + fn time(&self) -> i64 { + unsafe { + #[cfg(ossl330)] + { + sys::SSL_SESSION_get_time(self.session) as i64 + } + #[cfg(not(ossl330))] + { + sys::SSL_SESSION_get_time(self.session) as i64 + } + } + } + + #[pygetset] + fn timeout(&self) -> i64 { + unsafe { sys::SSL_SESSION_get_timeout(self.session) as i64 } + } + + #[pygetset] + fn ticket_lifetime_hint(&self) -> u64 { + // SSL_SESSION_get_ticket_lifetime_hint available in OpenSSL 1.1.0+ + #[cfg(ossl110)] + { + unsafe { SSL_SESSION_get_ticket_lifetime_hint(self.session) as u64 } + } + #[cfg(not(ossl110))] + { + // Not available in older OpenSSL versions + 0 + } + } + + #[pygetset] + fn id(&self, vm: &VirtualMachine) -> PyBytesRef { + unsafe { + let mut len: libc::c_uint = 0; + let id_ptr = sys::SSL_SESSION_get_id(self.session, &mut len); + let id_slice = std::slice::from_raw_parts(id_ptr, len as usize); + vm.ctx.new_bytes(id_slice.to_vec()) + } + } + + #[pygetset] + fn has_ticket(&self) -> bool { + // SSL_SESSION_has_ticket available in OpenSSL 1.1.0+ + #[cfg(ossl110)] + { + unsafe { SSL_SESSION_has_ticket(self.session) != 0 } + } + #[cfg(not(ossl110))] + { + // Not available in older OpenSSL versions + false + } + } + } + + #[track_caller] + pub(crate) fn convert_openssl_error( + vm: &VirtualMachine, + err: ErrorStack, + ) -> PyBaseExceptionRef { + match err.errors().last() { + Some(e) => { + // Check if this is a system library error (errno-based) + let lib = sys::ERR_GET_LIB(e.code()); + + if lib == sys::ERR_LIB_SYS { + // A system error is being reported; reason is set to errno + let reason = sys::ERR_GET_REASON(e.code()); + + // errno 2 = ENOENT = FileNotFoundError + let exc_type = if reason == 2 { + vm.ctx.exceptions.file_not_found_error.to_owned() + } else { + vm.ctx.exceptions.os_error.to_owned() + }; + let exc = vm.new_exception(exc_type, vec![vm.ctx.new_int(reason).into()]); + // Set errno attribute explicitly + let _ = exc + .as_object() + .set_attr("errno", vm.ctx.new_int(reason), vm); + return exc; + } + + let caller = std::panic::Location::caller(); + let (file, line) = (caller.file(), caller.line()); + let file = file + .rsplit_once(&['/', '\\'][..]) + .map_or(file, |(_, basename)| basename); + + // Get error codes - same approach as CPython + let lib = sys::ERR_GET_LIB(e.code()); + let reason = sys::ERR_GET_REASON(e.code()); + + // Look up error mnemonic from our static tables + // CPython uses dict lookup: err_codes_to_names[(lib, reason)] + let key = super::ssl_data::encode_error_key(lib, reason); + let errstr = super::ssl_data::ERROR_CODES + .get(&key) + .copied() + .or_else(|| { + // Fallback: use OpenSSL's error string + e.reason() + }) + .unwrap_or("unknown error"); + + // Check if this is a certificate verification error + // ERR_LIB_SSL = 20 (from _ssl_data_300.h) + // SSL_R_CERTIFICATE_VERIFY_FAILED = 134 (from _ssl_data_300.h) + let is_cert_verify_error = lib == 20 && reason == 134; + + // Look up library name from our static table + // CPython uses: lib_codes_to_names[lib] + let lib_name = super::ssl_data::LIBRARY_CODES.get(&(lib as u32)).copied(); + + // Use SSLCertVerificationError for certificate verification failures + let cls = if is_cert_verify_error { + PySslCertVerificationError::class(&vm.ctx).to_owned() + } else { + PySslError::class(&vm.ctx).to_owned() + }; + + // Build message + let msg = if let Some(lib_str) = lib_name { + format!("[{lib_str}] {errstr} ({file}:{line})") + } else { + format!("{errstr} ({file}:{line})") + }; + + // Create exception instance + let reason = sys::ERR_GET_REASON(e.code()); + let exc = vm.new_exception( + cls, + vec![vm.ctx.new_int(reason).into(), vm.ctx.new_str(msg).into()], + ); + + // Set attributes on instance, not class + let exc_obj: PyObjectRef = exc.into(); + + // Set reason attribute (always set, even if just the error string) + let reason_value = vm.ctx.new_str(errstr); + let _ = exc_obj.set_attr("reason", reason_value, vm); + + // Set library attribute (None if not available) + let library_value: PyObjectRef = if let Some(lib_str) = lib_name { + vm.ctx.new_str(lib_str).into() + } else { + vm.ctx.none() + }; + let _ = exc_obj.set_attr("library", library_value, vm); + + // For SSLCertVerificationError, set verify_code and verify_message + // Note: These will be set to None here, and can be updated by the caller + // if they have access to the SSL object + if is_cert_verify_error { + let _ = exc_obj.set_attr("verify_code", vm.ctx.none(), vm); + let _ = exc_obj.set_attr("verify_message", vm.ctx.none(), vm); + } + + // Convert back to PyBaseExceptionRef + exc_obj.downcast().expect( + "exc_obj is created as PyBaseExceptionRef and must downcast successfully", + ) + } + None => { + let cls = PySslError::class(&vm.ctx).to_owned(); + vm.new_exception_empty(cls) + } + } + } + + // Helper function to set verify_code and verify_message on SSLCertVerificationError + fn set_verify_error_info( + exc: &PyBaseExceptionRef, + ssl_ptr: *const sys::SSL, + vm: &VirtualMachine, + ) { + // Get verify result + let verify_code = unsafe { sys::SSL_get_verify_result(ssl_ptr) }; + let verify_code_obj = vm.ctx.new_int(verify_code); + + // Get verify message + let verify_message = unsafe { + let verify_str = sys::X509_verify_cert_error_string(verify_code); + if verify_str.is_null() { + vm.ctx.none() + } else { + let c_str = std::ffi::CStr::from_ptr(verify_str); + vm.ctx.new_str(c_str.to_string_lossy()).into() + } + }; + + let exc_obj = exc.as_object(); + let _ = exc_obj.set_attr("verify_code", verify_code_obj, vm); + let _ = exc_obj.set_attr("verify_message", verify_message, vm); + } + #[track_caller] + fn convert_ssl_error( + vm: &VirtualMachine, + e: impl std::borrow::Borrow, + ) -> PyBaseExceptionRef { + let e = e.borrow(); + let (cls, msg) = match e.code() { + ssl::ErrorCode::WANT_READ => ( + PySslWantReadError::class(&vm.ctx).to_owned(), + "The operation did not complete (read)", + ), + ssl::ErrorCode::WANT_WRITE => ( + PySslWantWriteError::class(&vm.ctx).to_owned(), + "The operation did not complete (write)", + ), + ssl::ErrorCode::SYSCALL => match e.io_error() { + Some(io_err) => return io_err.to_pyexception(vm), + // When no I/O error and OpenSSL error queue is empty, + // this is an EOF in violation of protocol -> SSLEOFError + // Need to set args[0] = SSL_ERROR_EOF for suppress_ragged_eofs check + None => { + return vm.new_exception( + PySslEOFError::class(&vm.ctx).to_owned(), + vec![ + vm.ctx.new_int(SSL_ERROR_EOF).into(), + vm.ctx + .new_str("EOF occurred in violation of protocol") + .into(), + ], + ); + } + }, + ssl::ErrorCode::SSL => { + // Check for OpenSSL 3.0 SSL_R_UNEXPECTED_EOF_WHILE_READING + if let Some(ssl_err) = e.ssl_error() { + // In OpenSSL 3.0+, unexpected EOF is reported as SSL_ERROR_SSL + // with this specific reason code instead of SSL_ERROR_SYSCALL + unsafe { + let err_code = sys::ERR_peek_last_error(); + let reason = sys::ERR_GET_REASON(err_code); + let lib = sys::ERR_GET_LIB(err_code); + if lib == ERR_LIB_SSL && reason == SSL_R_UNEXPECTED_EOF_WHILE_READING { + return vm.new_exception( + PySslEOFError::class(&vm.ctx).to_owned(), + vec![ + vm.ctx.new_int(SSL_ERROR_EOF).into(), + vm.ctx + .new_str("EOF occurred in violation of protocol") + .into(), + ], + ); + } + } + return convert_openssl_error(vm, ssl_err.clone()); + } + ( + PySslError::class(&vm.ctx).to_owned(), + "A failure in the SSL library occurred", + ) + } + _ => ( + PySslError::class(&vm.ctx).to_owned(), + "A failure in the SSL library occurred", + ), + }; + vm.new_exception_msg(cls, msg.to_owned()) + } + + // SSL_FILETYPE_ASN1 part of _add_ca_certs in CPython + fn x509_stack_from_der(der: &[u8]) -> Result, ErrorStack> { + unsafe { + openssl::init(); + let bio = bio::MemBioSlice::new(der)?; + + let mut certs = vec![]; + + loop { + let cert = sys::d2i_X509_bio(bio.as_ptr(), std::ptr::null_mut()); + if cert.is_null() { + break; + } + certs.push(X509::from_ptr(cert)); + } + + if certs.is_empty() { + // No certificates loaded at all + return Err(ErrorStack::get()); + } + + // Successfully loaded at least one certificate from DER data. + // Clear any trailing errors from EOF. + // CPython clears errors when: + // - DER: was_bio_eof is set (EOF reached) + // - PEM: PEM_R_NO_START_LINE error (normal EOF) + // Both cases mean successful completion with loaded certs. + eprintln!( + "[x509_stack_from_der] SUCCESS: Clearing errors and returning {} certs", + certs.len() + ); + sys::ERR_clear_error(); + Ok(certs) + } + } + + type CipherTuple = (&'static str, &'static str, i32); + + fn cipher_to_tuple(cipher: &ssl::SslCipherRef) -> CipherTuple { + (cipher.name(), cipher.version(), cipher.bits().secret) + } + + fn cipher_description(cipher: *const sys::SSL_CIPHER) -> String { + unsafe { + // SSL_CIPHER_description writes up to 128 bytes + let mut buf = vec![0u8; 256]; + let result = sys::SSL_CIPHER_description( + cipher, + buf.as_mut_ptr() as *mut libc::c_char, + buf.len() as i32, + ); + if result.is_null() { + return String::from("No description available"); + } + // Find the null terminator + let len = buf.iter().position(|&c| c == 0).unwrap_or(buf.len()); + String::from_utf8_lossy(&buf[..len]).trim().to_string() + } + } + + impl Read for SocketStream { + fn read(&mut self, buf: &mut [u8]) -> std::io::Result { + let mut socket: &PySocket = &self.0; + socket.read(buf) + } + } + + impl Write for SocketStream { + fn write(&mut self, buf: &[u8]) -> std::io::Result { + let mut socket: &PySocket = &self.0; + socket.write(buf) + } + fn flush(&mut self) -> std::io::Result<()> { + let mut socket: &PySocket = &self.0; + socket.flush() + } + } + + #[cfg(target_os = "android")] + mod android { + use super::convert_openssl_error; + use crate::vm::{VirtualMachine, builtins::PyBaseExceptionRef}; + use openssl::{ + ssl::SslContextBuilder, + x509::{X509, store::X509StoreBuilder}, + }; + use std::{ + fs::{File, read_dir}, + io::Read, + path::Path, + }; + + static CERT_DIR: &'static str = "/system/etc/security/cacerts"; + + pub(super) fn load_client_ca_list( + vm: &VirtualMachine, + b: &mut SslContextBuilder, + ) -> Result<(), PyBaseExceptionRef> { + let root = Path::new(CERT_DIR); + if !root.is_dir() { + return Err(vm.new_exception_msg( + vm.ctx.exceptions.file_not_found_error.to_owned(), + CERT_DIR.to_string(), + )); + } + + let mut combined_pem = String::new(); + let entries = read_dir(root) + .map_err(|err| vm.new_os_error(format!("read cert root: {}", err)))?; + for entry in entries { + let entry = + entry.map_err(|err| vm.new_os_error(format!("iter cert root: {}", err)))?; + + let path = entry.path(); + if !path.is_file() { + continue; + } + + File::open(&path) + .and_then(|mut file| file.read_to_string(&mut combined_pem)) + .map_err(|err| { + vm.new_os_error(format!("open cert file {}: {}", path.display(), err)) + })?; + + combined_pem.push('\n'); + } + + let mut store_b = + X509StoreBuilder::new().map_err(|err| convert_openssl_error(vm, err))?; + let x509_vec = X509::stack_from_pem(combined_pem.as_bytes()) + .map_err(|err| convert_openssl_error(vm, err))?; + for x509 in x509_vec { + store_b + .add_cert(x509) + .map_err(|err| convert_openssl_error(vm, err))?; + } + b.set_cert_store(store_b.build()); + + Ok(()) + } + } +} + +#[cfg(not(ossl101))] +#[pymodule(sub)] +mod ossl101 {} + +#[cfg(not(ossl111))] +#[pymodule(sub)] +mod ossl111 {} + +#[cfg(not(windows))] +#[pymodule(sub)] +mod windows {} + +#[allow(non_upper_case_globals)] +#[cfg(ossl101)] +#[pymodule(sub)] +mod ossl101 { + #[pyattr] + use openssl_sys::{ + SSL_OP_NO_COMPRESSION as OP_NO_COMPRESSION, SSL_OP_NO_TLSv1_1 as OP_NO_TLSv1_1, + SSL_OP_NO_TLSv1_2 as OP_NO_TLSv1_2, + }; +} + +#[allow(non_upper_case_globals)] +#[cfg(ossl111)] +#[pymodule(sub)] +mod ossl111 { + #[pyattr] + use openssl_sys::SSL_OP_NO_TLSv1_3 as OP_NO_TLSv1_3; +} + +#[cfg(windows)] +#[pymodule(sub)] +mod windows { + use crate::{ + common::ascii, + vm::{ + PyObjectRef, PyPayload, PyResult, VirtualMachine, + builtins::{PyFrozenSet, PyStrRef}, + convert::ToPyException, + }, + }; + + #[pyfunction] + fn enum_certificates(store_name: PyStrRef, vm: &VirtualMachine) -> PyResult> { + use schannel::{RawPointer, cert_context::ValidUses, cert_store::CertStore}; + use windows_sys::Win32::Security::Cryptography; + + // TODO: check every store for it, not just 2 of them: + // https://github.com/python/cpython/blob/3.8/Modules/_ssl.c#L5603-L5610 + let open_fns = [CertStore::open_current_user, CertStore::open_local_machine]; + let stores = open_fns + .iter() + .filter_map(|open| open(store_name.as_str()).ok()) + .collect::>(); + let certs = stores.iter().flat_map(|s| s.certs()).map(|c| { + let cert = vm.ctx.new_bytes(c.to_der().to_owned()); + let enc_type = unsafe { + let ptr = c.as_ptr() as *const Cryptography::CERT_CONTEXT; + (*ptr).dwCertEncodingType + }; + let enc_type = match enc_type { + Cryptography::X509_ASN_ENCODING => vm.new_pyobj(ascii!("x509_asn")), + Cryptography::PKCS_7_ASN_ENCODING => vm.new_pyobj(ascii!("pkcs_7_asn")), + other => vm.new_pyobj(other), + }; + let usage: PyObjectRef = match c.valid_uses().map_err(|e| e.to_pyexception(vm))? { + ValidUses::All => vm.ctx.new_bool(true).into(), + ValidUses::Oids(oids) => PyFrozenSet::from_iter( + vm, + oids.into_iter().map(|oid| vm.ctx.new_str(oid).into()), + )? + .into_ref(&vm.ctx) + .into(), + }; + Ok(vm.new_tuple((cert, enc_type, usage)).into()) + }); + let certs: Vec = certs.collect::>>()?; + Ok(certs) + } +} + +mod bio { + //! based off rust-openssl's private `bio` module + + use libc::c_int; + use openssl::error::ErrorStack; + use openssl_sys as sys; + use std::marker::PhantomData; + + pub struct MemBioSlice<'a>(*mut sys::BIO, PhantomData<&'a [u8]>); + + impl Drop for MemBioSlice<'_> { + fn drop(&mut self) { + unsafe { + sys::BIO_free_all(self.0); + } + } + } + + impl<'a> MemBioSlice<'a> { + pub fn new(buf: &'a [u8]) -> Result, ErrorStack> { + openssl::init(); + + assert!(buf.len() <= c_int::MAX as usize); + let bio = unsafe { sys::BIO_new_mem_buf(buf.as_ptr() as *const _, buf.len() as c_int) }; + if bio.is_null() { + return Err(ErrorStack::get()); + } + + Ok(MemBioSlice(bio, PhantomData)) + } + + pub fn as_ptr(&self) -> *mut sys::BIO { + self.0 + } + } +} diff --git a/stdlib/src/openssl/cert.rs b/stdlib/src/openssl/cert.rs new file mode 100644 index 00000000000..1139f0e26f0 --- /dev/null +++ b/stdlib/src/openssl/cert.rs @@ -0,0 +1,229 @@ +pub(super) use ssl_cert::{PySSLCertificate, cert_to_certificate, cert_to_py, obj2txt}; + +// Certificate type for SSL module + +#[pymodule(sub)] +pub(crate) mod ssl_cert { + use crate::{ + common::ascii, + vm::{ + PyObjectRef, PyPayload, PyResult, VirtualMachine, + convert::{ToPyException, ToPyObject}, + function::{FsPath, OptionalArg}, + }, + }; + use foreign_types_shared::ForeignTypeRef; + use openssl::{ + asn1::Asn1ObjectRef, + x509::{self, X509, X509Ref}, + }; + use openssl_sys as sys; + use std::fmt; + + // Import constants and error converter from _ssl module + use crate::openssl::_ssl::{ENCODING_DER, ENCODING_PEM, convert_openssl_error}; + + pub(crate) fn obj2txt(obj: &Asn1ObjectRef, no_name: bool) -> Option { + let no_name = i32::from(no_name); + let ptr = obj.as_ptr(); + let b = unsafe { + let buflen = sys::OBJ_obj2txt(std::ptr::null_mut(), 0, ptr, no_name); + assert!(buflen >= 0); + if buflen == 0 { + return None; + } + let buflen = buflen as usize; + let mut buf = Vec::::with_capacity(buflen + 1); + let ret = sys::OBJ_obj2txt( + buf.as_mut_ptr() as *mut libc::c_char, + buf.capacity() as _, + ptr, + no_name, + ); + assert!(ret >= 0); + // SAFETY: OBJ_obj2txt initialized the buffer successfully + buf.set_len(buflen); + buf + }; + let s = String::from_utf8(b) + .unwrap_or_else(|e| String::from_utf8_lossy(e.as_bytes()).into_owned()); + Some(s) + } + + #[pyattr] + #[pyclass(module = "ssl", name = "Certificate")] + #[derive(PyPayload)] + pub(crate) struct PySSLCertificate { + cert: X509, + } + + impl fmt::Debug for PySSLCertificate { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.pad("Certificate") + } + } + + #[pyclass] + impl PySSLCertificate { + #[pymethod] + fn public_bytes( + &self, + format: OptionalArg, + vm: &VirtualMachine, + ) -> PyResult { + let format = format.unwrap_or(ENCODING_PEM); + + match format { + ENCODING_DER => { + // DER encoding + let der = self + .cert + .to_der() + .map_err(|e| convert_openssl_error(vm, e))?; + Ok(vm.ctx.new_bytes(der).into()) + } + ENCODING_PEM => { + // PEM encoding + let pem = self + .cert + .to_pem() + .map_err(|e| convert_openssl_error(vm, e))?; + Ok(vm.ctx.new_bytes(pem).into()) + } + _ => Err(vm.new_value_error("Unsupported format")), + } + } + + #[pymethod] + fn get_info(&self, vm: &VirtualMachine) -> PyResult { + cert_to_dict(vm, &self.cert) + } + } + + fn name_to_py(vm: &VirtualMachine, name: &x509::X509NameRef) -> PyResult { + let list = name + .entries() + .map(|entry| { + let txt = obj2txt(entry.object(), false).to_pyobject(vm); + let asn1_str = entry.data(); + let data_bytes = asn1_str.as_slice(); + let data = match std::str::from_utf8(data_bytes) { + Ok(s) => vm.ctx.new_str(s.to_owned()), + Err(_) => vm + .ctx + .new_str(String::from_utf8_lossy(data_bytes).into_owned()), + }; + Ok(vm.new_tuple(((txt, data),)).into()) + }) + .collect::>()?; + Ok(vm.ctx.new_tuple(list).into()) + } + + // Helper to convert X509 to dict (for getpeercert with binary=False) + fn cert_to_dict(vm: &VirtualMachine, cert: &X509Ref) -> PyResult { + let dict = vm.ctx.new_dict(); + + dict.set_item("subject", name_to_py(vm, cert.subject_name())?, vm)?; + dict.set_item("issuer", name_to_py(vm, cert.issuer_name())?, vm)?; + // X.509 version: OpenSSL uses 0-based (0=v1, 1=v2, 2=v3) but Python uses 1-based (1=v1, 2=v2, 3=v3) + dict.set_item("version", vm.new_pyobj(cert.version() + 1), vm)?; + + let serial_num = cert + .serial_number() + .to_bn() + .and_then(|bn| bn.to_hex_str()) + .map_err(|e| convert_openssl_error(vm, e))?; + dict.set_item( + "serialNumber", + vm.ctx.new_str(serial_num.to_owned()).into(), + vm, + )?; + + dict.set_item( + "notBefore", + vm.ctx.new_str(cert.not_before().to_string()).into(), + vm, + )?; + dict.set_item( + "notAfter", + vm.ctx.new_str(cert.not_after().to_string()).into(), + vm, + )?; + + if let Some(names) = cert.subject_alt_names() { + let san: Vec = names + .iter() + .map(|gen_name| { + if let Some(email) = gen_name.email() { + vm.new_tuple((ascii!("email"), email)).into() + } else if let Some(dnsname) = gen_name.dnsname() { + vm.new_tuple((ascii!("DNS"), dnsname)).into() + } else if let Some(ip) = gen_name.ipaddress() { + // Parse IP address properly (IPv4 or IPv6) + let ip_str = if ip.len() == 4 { + // IPv4 + format!("{}.{}.{}.{}", ip[0], ip[1], ip[2], ip[3]) + } else if ip.len() == 16 { + // IPv6 - format with all zeros visible (not compressed) + let ip_addr = std::net::Ipv6Addr::from(ip[0..16]); + let s = ip_addr.segments(); + format!( + "{:X}:{:X}:{:X}:{:X}:{:X}:{:X}:{:X}:{:X}", + s[0], s[1], s[2], s[3], s[4], s[5], s[6], s[7] + ) + } else { + // Fallback for unexpected length + String::from_utf8_lossy(ip).into_owned() + }; + vm.new_tuple((ascii!("IP Address"), ip_str)).into() + } else if let Some(uri) = gen_name.uri() { + vm.new_tuple((ascii!("URI"), uri)).into() + } else { + // Handle DirName, Registered ID, and othername + // Check if this is a directory name + if let Some(dirname) = gen_name.directory_name() + && let Ok(py_name) = name_to_py(vm, dirname) + { + return vm.new_tuple((ascii!("DirName"), py_name)).into(); + } + + // TODO: Handle Registered ID (GEN_RID) + // CPython implementation uses i2t_ASN1_OBJECT to convert OID + // This requires accessing GENERAL_NAME union which is complex in Rust + // For now, we return for unhandled types + + // For othername and other unsupported types + vm.new_tuple((ascii!("othername"), ascii!(""))) + .into() + } + }) + .collect(); + dict.set_item("subjectAltName", vm.ctx.new_tuple(san).into(), vm)?; + }; + + Ok(dict.into()) + } + + // Helper to create Certificate object from X509 + pub(crate) fn cert_to_certificate(vm: &VirtualMachine, cert: X509) -> PyResult { + Ok(PySSLCertificate { cert }.into_ref(&vm.ctx).into()) + } + + // For getpeercert() - returns bytes or dict depending on binary flag + pub(crate) fn cert_to_py(vm: &VirtualMachine, cert: &X509Ref, binary: bool) -> PyResult { + if binary { + let b = cert.to_der().map_err(|e| convert_openssl_error(vm, e))?; + Ok(vm.ctx.new_bytes(b).into()) + } else { + cert_to_dict(vm, cert) + } + } + + #[pyfunction] + pub(crate) fn _test_decode_cert(path: FsPath, vm: &VirtualMachine) -> PyResult { + let path = path.to_path_buf(vm)?; + let pem = std::fs::read(path).map_err(|e| e.to_pyexception(vm))?; + let x509 = X509::from_pem(&pem).map_err(|e| convert_openssl_error(vm, e))?; + cert_to_py(vm, &x509, false) + } +} diff --git a/stdlib/src/ssl/ssl_data_111.rs b/stdlib/src/openssl/ssl_data_111.rs similarity index 100% rename from stdlib/src/ssl/ssl_data_111.rs rename to stdlib/src/openssl/ssl_data_111.rs diff --git a/stdlib/src/ssl/ssl_data_300.rs b/stdlib/src/openssl/ssl_data_300.rs similarity index 100% rename from stdlib/src/ssl/ssl_data_300.rs rename to stdlib/src/openssl/ssl_data_300.rs diff --git a/stdlib/src/ssl/ssl_data_31.rs b/stdlib/src/openssl/ssl_data_31.rs similarity index 100% rename from stdlib/src/ssl/ssl_data_31.rs rename to stdlib/src/openssl/ssl_data_31.rs diff --git a/stdlib/src/ssl.rs b/stdlib/src/ssl.rs index 9604999d7da..e0019ae4750 100644 --- a/stdlib/src/ssl.rs +++ b/stdlib/src/ssl.rs @@ -1,257 +1,356 @@ -// spell-checker:disable - +// spell-checker: ignore ssleof aesccm aesgcm getblocking setblocking ENDTLS + +//! Pure Rust SSL/TLS implementation using rustls +//! +//! This module provides SSL/TLS support without requiring C dependencies. +//! It implements the Python ssl module API using: +//! - rustls: TLS protocol implementation +//! - x509-parser/x509-cert: Certificate parsing +//! - ring: Cryptographic primitives +//! - rustls-platform-verifier: Platform-native certificate verification +//! +//! DO NOT add openssl dependency here. +//! +//! Warning: This library contains AI-generated code and comments. Do not trust any code or comment without verification. Please have a qualified expert review the code and remove this notice after review. + +// OID (Object Identifier) management module +mod oid; + +// Certificate operations module (parsing, validation, conversion) mod cert; -// Conditional compilation for OpenSSL version-specific error codes -cfg_if::cfg_if! { - if #[cfg(ossl310)] { - // OpenSSL 3.1.0+ - #[path = "ssl/ssl_data_31.rs"] - mod ssl_data; - } else if #[cfg(ossl300)] { - // OpenSSL 3.0.0+ - #[path = "ssl/ssl_data_300.rs"] - mod ssl_data; - } else { - // OpenSSL 1.1.1+ (fallback) - #[path = "ssl/ssl_data_111.rs"] - mod ssl_data; - } -} - -use crate::vm::{PyRef, VirtualMachine, builtins::PyModule}; -use openssl_probe::ProbeResult; +// OpenSSL compatibility layer (abstracts rustls operations) +mod compat; -pub(crate) fn make_module(vm: &VirtualMachine) -> PyRef { - // if openssl is vendored, it doesn't know the locations - // of system certificates - cache the probe result now. - #[cfg(openssl_vendored)] - LazyLock::force(&PROBE); - _ssl::make_module(vm) -} - -// define our own copy of ProbeResult so we can handle the vendor case -// easily, without having to have a bunch of cfgs -cfg_if::cfg_if! { - if #[cfg(openssl_vendored)] { - use std::sync::LazyLock; - static PROBE: LazyLock = LazyLock::new(openssl_probe::probe); - fn probe() -> &'static ProbeResult { &PROBE } - } else { - fn probe() -> &'static ProbeResult { - &ProbeResult { cert_file: None, cert_dir: None } - } - } -} +pub(crate) use _ssl::make_module; +#[allow(non_snake_case)] #[allow(non_upper_case_globals)] -#[pymodule(with(cert::ssl_cert, ossl101, ossl111, windows))] +#[pymodule] mod _ssl { - use super::{bio, probe}; use crate::{ - common::lock::{ - PyMappedRwLockReadGuard, PyMutex, PyRwLock, PyRwLockReadGuard, PyRwLockWriteGuard, + common::{ + hash::PyHash, + lock::{PyMutex, PyRwLock}, }, - socket::{self, PySocket}, + socket::{PySocket, SelectKind, sock_select, timeout_error_msg}, vm::{ - AsObject, Py, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, - builtins::{ - PyBaseExceptionRef, PyBytesRef, PyListRef, PyOSError, PyStrRef, PyTypeRef, PyWeak, - }, - class_or_notimplemented, - convert::ToPyException, - exceptions, - function::{ - ArgBytesLike, ArgCallable, ArgMemoryBuffer, ArgStrOrBytesLike, Either, FsPath, - OptionalArg, PyComparisonValue, - }, - types::{Comparable, Constructor, PyComparisonOp}, - utils::ToCString, + AsObject, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject, + VirtualMachine, + builtins::{PyBaseExceptionRef, PyBytesRef, PyListRef, PyStrRef, PyTypeRef}, + convert::IntoPyException, + function::{ArgBytesLike, ArgMemoryBuffer, OptionalArg, PyComparisonValue}, + stdlib::warnings, + types::{Comparable, Constructor, Hashable, PyComparisonOp, Representable}, }, }; - use crossbeam_utils::atomic::AtomicCell; - use foreign_types_shared::{ForeignType, ForeignTypeRef}; - use openssl::{ - asn1::{Asn1Object, Asn1ObjectRef}, - error::ErrorStack, - nid::Nid, - ssl::{self, SslContextBuilder, SslOptions, SslVerifyMode}, - x509::X509, - }; - use openssl_sys as sys; - use rustpython_vm::ospath::OsPath; use std::{ - ffi::CStr, - fmt, - io::{Read, Write}, - path::{Path, PathBuf}, - sync::LazyLock, - time::Instant, + collections::HashMap, + io::Write, + sync::{ + Arc, + atomic::{AtomicUsize, Ordering}, + }, + time::{Duration, SystemTime}, + }; + + // Rustls imports + use parking_lot::{Mutex as ParkingMutex, RwLock as ParkingRwLock}; + use pem_rfc7468::{LineEnding, encode_string}; + use rustls::{ + ClientConfig, ClientConnection, RootCertStore, ServerConfig, ServerConnection, + client::{ClientSessionMemoryCache, ClientSessionStore}, + crypto::SupportedKxGroup, + pki_types::{CertificateDer, CertificateRevocationListDer, PrivateKeyDer, ServerName}, + server::{ClientHello, ResolvesServerCert}, + sign::CertifiedKey, + version::{TLS12, TLS13}, }; + use sha2::{Digest, Sha256}; + + // Import certificate operations module + use super::cert; - // Import certificate types from parent module - use super::cert::{self, cert_to_certificate, cert_to_py}; - - // Re-export PySSLCertificate to make it available in the _ssl module - // It will be automatically exposed to Python via #[pyclass] - #[allow(unused_imports)] - use super::cert::PySSLCertificate; - - // Constants - #[pyattr] - use sys::{ - // TODO: so many more of these - SSL_AD_DECODE_ERROR as ALERT_DESCRIPTION_DECODE_ERROR, - SSL_AD_ILLEGAL_PARAMETER as ALERT_DESCRIPTION_ILLEGAL_PARAMETER, - SSL_AD_UNRECOGNIZED_NAME as ALERT_DESCRIPTION_UNRECOGNIZED_NAME, - // SSL_ERROR_INVALID_ERROR_CODE, - SSL_ERROR_SSL, - // SSL_ERROR_WANT_X509_LOOKUP, - SSL_ERROR_SYSCALL, - SSL_ERROR_WANT_CONNECT, - SSL_ERROR_WANT_READ, - SSL_ERROR_WANT_WRITE, - SSL_ERROR_ZERO_RETURN, - SSL_OP_CIPHER_SERVER_PREFERENCE as OP_CIPHER_SERVER_PREFERENCE, - SSL_OP_ENABLE_MIDDLEBOX_COMPAT as OP_ENABLE_MIDDLEBOX_COMPAT, - SSL_OP_LEGACY_SERVER_CONNECT as OP_LEGACY_SERVER_CONNECT, - SSL_OP_NO_SSLv2 as OP_NO_SSLv2, - SSL_OP_NO_SSLv3 as OP_NO_SSLv3, - SSL_OP_NO_TICKET as OP_NO_TICKET, - SSL_OP_NO_TLSv1 as OP_NO_TLSv1, - SSL_OP_SINGLE_DH_USE as OP_SINGLE_DH_USE, - SSL_OP_SINGLE_ECDH_USE as OP_SINGLE_ECDH_USE, - X509_V_FLAG_ALLOW_PROXY_CERTS as VERIFY_ALLOW_PROXY_CERTS, - X509_V_FLAG_CRL_CHECK as VERIFY_CRL_CHECK_LEAF, - X509_V_FLAG_PARTIAL_CHAIN as VERIFY_X509_PARTIAL_CHAIN, - X509_V_FLAG_TRUSTED_FIRST as VERIFY_X509_TRUSTED_FIRST, - X509_V_FLAG_X509_STRICT as VERIFY_X509_STRICT, + // Import OID module + use super::oid; + + // Import compat module (OpenSSL compatibility layer) + use super::compat::{ + ClientConfigOptions, MultiCertResolver, ProtocolSettings, ServerConfigOptions, SslError, + TlsConnection, create_client_config, create_server_config, curve_name_to_kx_group, + extract_cipher_info, get_cipher_encryption_desc, is_blocking_io_error, + normalize_cipher_name, ssl_do_handshake, }; - // CRL verification constants + // Type aliases for better readability + // Additional type alias for certificate/key pairs (SessionCache and SniCertName defined below) + + /// Certificate and private key pair used in SSL contexts + type CertKeyPair = (Arc, PrivateKeyDer<'static>); + + // Constants matching Python ssl module + + // SSL/TLS Protocol versions + #[pyattr] + const PROTOCOL_TLS: i32 = 2; // Auto-negotiate best version + #[pyattr] + const PROTOCOL_SSLv23: i32 = PROTOCOL_TLS; // Alias for PROTOCOL_TLS #[pyattr] - const VERIFY_CRL_CHECK_CHAIN: libc::c_ulong = - sys::X509_V_FLAG_CRL_CHECK | sys::X509_V_FLAG_CRL_CHECK_ALL; + const PROTOCOL_TLS_CLIENT: i32 = 16; + #[pyattr] + const PROTOCOL_TLS_SERVER: i32 = 17; - // taken from CPython, should probably be kept up to date with their version if it ever changes + // Note: rustls doesn't support TLS 1.0/1.1 for security reasons + // These are defined for API compatibility but will raise errors if used #[pyattr] - const _DEFAULT_CIPHERS: &str = - "DEFAULT:!aNULL:!eNULL:!MD5:!3DES:!DES:!RC4:!IDEA:!SEED:!aDSS:!SRP:!PSK"; - // #[pyattr] PROTOCOL_SSLv2: u32 = SslVersion::Ssl2 as u32; // unsupported - // #[pyattr] PROTOCOL_SSLv3: u32 = SslVersion::Ssl3 as u32; + const PROTOCOL_TLSv1: i32 = 3; #[pyattr] - const PROTOCOL_SSLv23: u32 = SslVersion::Tls as u32; + const PROTOCOL_TLSv1_1: i32 = 4; #[pyattr] - const PROTOCOL_TLS: u32 = SslVersion::Tls as u32; + const PROTOCOL_TLSv1_2: i32 = 5; #[pyattr] - const PROTOCOL_TLS_CLIENT: u32 = SslVersion::TlsClient as u32; + const PROTOCOL_TLSv1_3: i32 = 6; + + // Protocol version constants for TLSVersion enum #[pyattr] - const PROTOCOL_TLS_SERVER: u32 = SslVersion::TlsServer as u32; + const PROTO_SSLv3: i32 = 0x0300; #[pyattr] - const PROTOCOL_TLSv1: u32 = SslVersion::Tls1 as u32; + const PROTO_TLSv1: i32 = 0x0301; #[pyattr] - const PROTOCOL_TLSv1_1: u32 = SslVersion::Tls1_1 as u32; + const PROTO_TLSv1_1: i32 = 0x0302; #[pyattr] - const PROTOCOL_TLSv1_2: u32 = SslVersion::Tls1_2 as u32; + const PROTO_TLSv1_2: i32 = 0x0303; #[pyattr] - const PROTO_MINIMUM_SUPPORTED: i32 = ProtoVersion::MinSupported as i32; + const PROTO_TLSv1_3: i32 = 0x0304; + + // Minimum and maximum supported protocol versions for rustls + // Use special values -2 and -1 to avoid enum name conflicts #[pyattr] - const PROTO_SSLv3: i32 = ProtoVersion::Ssl3 as i32; + const PROTO_MINIMUM_SUPPORTED: i32 = -2; // special value #[pyattr] - const PROTO_TLSv1: i32 = ProtoVersion::Tls1 as i32; + const PROTO_MAXIMUM_SUPPORTED: i32 = -1; // special value + + // Internal constants for rustls actual supported versions + // rustls only supports TLS 1.2 and TLS 1.3 + const MINIMUM_VERSION: i32 = PROTO_TLSv1_2; // 0x0303 + const MAXIMUM_VERSION: i32 = PROTO_TLSv1_3; // 0x0304 + + // Buffer sizes and limits (OpenSSL/CPython compatibility) + const PEM_BUFSIZE: usize = 1024; + // OpenSSL: ssl/ssl_local.h + const SSL3_RT_MAX_PLAIN_LENGTH: usize = 16384; + // SSL session cache size (common practice, similar to OpenSSL defaults) + const SSL_SESSION_CACHE_SIZE: usize = 256; + + // Certificate verification modes #[pyattr] - const PROTO_TLSv1_1: i32 = ProtoVersion::Tls1_1 as i32; + const CERT_NONE: i32 = 0; #[pyattr] - const PROTO_TLSv1_2: i32 = ProtoVersion::Tls1_2 as i32; + const CERT_OPTIONAL: i32 = 1; #[pyattr] - const PROTO_TLSv1_3: i32 = ProtoVersion::Tls1_3 as i32; + const CERT_REQUIRED: i32 = 2; + + // Certificate requirements #[pyattr] - const PROTO_MAXIMUM_SUPPORTED: i32 = ProtoVersion::MaxSupported as i32; + const VERIFY_DEFAULT: i32 = 0; #[pyattr] - const OP_ALL: libc::c_ulong = (sys::SSL_OP_ALL & !sys::SSL_OP_DONT_INSERT_EMPTY_FRAGMENTS) as _; + const VERIFY_CRL_CHECK_LEAF: i32 = 4; #[pyattr] - const HAS_TLS_UNIQUE: bool = true; + const VERIFY_CRL_CHECK_CHAIN: i32 = 12; #[pyattr] - const CERT_NONE: u32 = CertRequirements::None as u32; + const VERIFY_X509_STRICT: i32 = 32; #[pyattr] - const CERT_OPTIONAL: u32 = CertRequirements::Optional as u32; + const VERIFY_ALLOW_PROXY_CERTS: i32 = 64; #[pyattr] - const CERT_REQUIRED: u32 = CertRequirements::Required as u32; + const VERIFY_X509_TRUSTED_FIRST: i32 = 32768; #[pyattr] - const VERIFY_DEFAULT: u32 = 0; + const VERIFY_X509_PARTIAL_CHAIN: i32 = 0x80000; + + // Options (OpenSSL-compatible flags, mostly no-op in rustls) #[pyattr] - const SSL_ERROR_EOF: u32 = 8; // custom for python + const OP_NO_SSLv2: i32 = 0x00000000; // Not supported anyway #[pyattr] - const HAS_SNI: bool = true; + const OP_NO_SSLv3: i32 = 0x02000000; #[pyattr] - const HAS_ECDH: bool = true; + const OP_NO_TLSv1: i32 = 0x04000000; #[pyattr] - const HAS_NPN: bool = false; + const OP_NO_TLSv1_1: i32 = 0x10000000; #[pyattr] - const HAS_ALPN: bool = true; + const OP_NO_TLSv1_2: i32 = 0x08000000; #[pyattr] - const HAS_SSLv2: bool = false; + const OP_NO_TLSv1_3: i32 = 0x20000000; #[pyattr] - const HAS_SSLv3: bool = false; + const OP_NO_COMPRESSION: i32 = 0x00020000; + #[pyattr] + const OP_CIPHER_SERVER_PREFERENCE: i32 = 0x00400000; + #[pyattr] + const OP_SINGLE_DH_USE: i32 = 0x00000000; // No-op in rustls #[pyattr] - const HAS_TLSv1: bool = true; + const OP_SINGLE_ECDH_USE: i32 = 0x00000000; // No-op in rustls #[pyattr] - const HAS_TLSv1_1: bool = true; + const OP_NO_TICKET: i32 = 0x00004000; #[pyattr] - const HAS_TLSv1_2: bool = true; + const OP_LEGACY_SERVER_CONNECT: i32 = 0x00000004; #[pyattr] - const HAS_TLSv1_3: bool = cfg!(ossl111); + const OP_NO_RENEGOTIATION: i32 = 0x40000000; #[pyattr] - const HAS_PSK: bool = true; + const OP_IGNORE_UNEXPECTED_EOF: i32 = 0x00000080; + #[pyattr] + const OP_ENABLE_MIDDLEBOX_COMPAT: i32 = 0x00100000; + #[pyattr] + const OP_ALL: i32 = 0x00000BFB; // Combined "safe" options (reduced for i32, excluding OP_LEGACY_SERVER_CONNECT for OpenSSL 3.0.0+ compatibility) - // Encoding constants for Certificate.public_bytes() + // Error types + #[pyattr] + const SSL_ERROR_NONE: i32 = 0; + #[pyattr] + const SSL_ERROR_SSL: i32 = 1; #[pyattr] - pub(crate) const ENCODING_PEM: i32 = sys::X509_FILETYPE_PEM; + const SSL_ERROR_WANT_READ: i32 = 2; #[pyattr] - pub(crate) const ENCODING_DER: i32 = sys::X509_FILETYPE_ASN1; + const SSL_ERROR_WANT_WRITE: i32 = 3; #[pyattr] - const ENCODING_PEM_AUX: i32 = sys::X509_FILETYPE_PEM + 0x100; + const SSL_ERROR_WANT_X509_LOOKUP: i32 = 4; + #[pyattr] + const SSL_ERROR_SYSCALL: i32 = 5; + #[pyattr] + const SSL_ERROR_ZERO_RETURN: i32 = 6; + #[pyattr] + const SSL_ERROR_WANT_CONNECT: i32 = 7; + #[pyattr] + const SSL_ERROR_EOF: i32 = 8; + #[pyattr] + const SSL_ERROR_INVALID_ERROR_CODE: i32 = 10; - // OpenSSL error codes for unexpected EOF detection - const ERR_LIB_SSL: i32 = 20; - const SSL_R_UNEXPECTED_EOF_WHILE_READING: i32 = 294; + // Alert types (matching _TLSAlertType enum) + #[pyattr] + const ALERT_DESCRIPTION_CLOSE_NOTIFY: i32 = 0; + #[pyattr] + const ALERT_DESCRIPTION_UNEXPECTED_MESSAGE: i32 = 10; + #[pyattr] + const ALERT_DESCRIPTION_BAD_RECORD_MAC: i32 = 20; + #[pyattr] + const ALERT_DESCRIPTION_DECRYPTION_FAILED: i32 = 21; + #[pyattr] + const ALERT_DESCRIPTION_RECORD_OVERFLOW: i32 = 22; + #[pyattr] + const ALERT_DESCRIPTION_DECOMPRESSION_FAILURE: i32 = 30; + #[pyattr] + const ALERT_DESCRIPTION_HANDSHAKE_FAILURE: i32 = 40; + #[pyattr] + const ALERT_DESCRIPTION_NO_CERTIFICATE: i32 = 41; + #[pyattr] + const ALERT_DESCRIPTION_BAD_CERTIFICATE: i32 = 42; + #[pyattr] + const ALERT_DESCRIPTION_UNSUPPORTED_CERTIFICATE: i32 = 43; + #[pyattr] + const ALERT_DESCRIPTION_CERTIFICATE_REVOKED: i32 = 44; + #[pyattr] + const ALERT_DESCRIPTION_CERTIFICATE_EXPIRED: i32 = 45; + #[pyattr] + const ALERT_DESCRIPTION_CERTIFICATE_UNKNOWN: i32 = 46; + #[pyattr] + const ALERT_DESCRIPTION_ILLEGAL_PARAMETER: i32 = 47; + #[pyattr] + const ALERT_DESCRIPTION_UNKNOWN_CA: i32 = 48; + #[pyattr] + const ALERT_DESCRIPTION_ACCESS_DENIED: i32 = 49; + #[pyattr] + const ALERT_DESCRIPTION_DECODE_ERROR: i32 = 50; + #[pyattr] + const ALERT_DESCRIPTION_DECRYPT_ERROR: i32 = 51; + #[pyattr] + const ALERT_DESCRIPTION_EXPORT_RESTRICTION: i32 = 60; + #[pyattr] + const ALERT_DESCRIPTION_PROTOCOL_VERSION: i32 = 70; + #[pyattr] + const ALERT_DESCRIPTION_INSUFFICIENT_SECURITY: i32 = 71; + #[pyattr] + const ALERT_DESCRIPTION_INTERNAL_ERROR: i32 = 80; + #[pyattr] + const ALERT_DESCRIPTION_INAPPROPRIATE_FALLBACK: i32 = 86; + #[pyattr] + const ALERT_DESCRIPTION_USER_CANCELLED: i32 = 90; + #[pyattr] + const ALERT_DESCRIPTION_NO_RENEGOTIATION: i32 = 100; + #[pyattr] + const ALERT_DESCRIPTION_MISSING_EXTENSION: i32 = 109; + #[pyattr] + const ALERT_DESCRIPTION_UNSUPPORTED_EXTENSION: i32 = 110; + #[pyattr] + const ALERT_DESCRIPTION_CERTIFICATE_UNOBTAINABLE: i32 = 111; + #[pyattr] + const ALERT_DESCRIPTION_UNRECOGNIZED_NAME: i32 = 112; + #[pyattr] + const ALERT_DESCRIPTION_BAD_CERTIFICATE_STATUS_RESPONSE: i32 = 113; + #[pyattr] + const ALERT_DESCRIPTION_BAD_CERTIFICATE_HASH_VALUE: i32 = 114; + #[pyattr] + const ALERT_DESCRIPTION_UNKNOWN_PSK_IDENTITY: i32 = 115; + #[pyattr] + const ALERT_DESCRIPTION_CERTIFICATE_REQUIRED: i32 = 116; + #[pyattr] + const ALERT_DESCRIPTION_NO_APPLICATION_PROTOCOL: i32 = 120; - // SSL_VERIFY constants for post-handshake authentication - #[cfg(ossl111)] - const SSL_VERIFY_POST_HANDSHAKE: libc::c_int = 0x20; + // Version info - reporting as OpenSSL 3.3.0 for compatibility + #[pyattr] + const OPENSSL_VERSION_NUMBER: i32 = 0x30300000; // OpenSSL 3.3.0 (808452096) + #[pyattr] + const OPENSSL_VERSION: &str = "OpenSSL 3.3.0 (rustls/0.23)"; + #[pyattr] + const OPENSSL_VERSION_INFO: (i32, i32, i32, i32, i32) = (3, 3, 0, 0, 15); // 3.3.0 release + #[pyattr] + const _OPENSSL_API_VERSION: (i32, i32, i32, i32, i32) = (3, 3, 0, 0, 15); // 3.3.0 release - // the openssl version from the API headers + // Default cipher list for rustls - using modern secure ciphers + #[pyattr] + const _DEFAULT_CIPHERS: &str = + "TLS_AES_256_GCM_SHA384:TLS_AES_128_GCM_SHA256:TLS_CHACHA20_POLY1305_SHA256"; - #[pyattr(name = "OPENSSL_VERSION")] - fn openssl_version(_vm: &VirtualMachine) -> &str { - openssl::version::version() - } - #[pyattr(name = "OPENSSL_VERSION_NUMBER")] - fn openssl_version_number(_vm: &VirtualMachine) -> i64 { - openssl::version::number() - } - #[pyattr(name = "OPENSSL_VERSION_INFO")] - fn openssl_version_info(_vm: &VirtualMachine) -> OpensslVersionInfo { - parse_version_info(openssl::version::number()) - } + // Has features + #[pyattr] + const HAS_SNI: bool = true; + #[pyattr] + const HAS_TLS_UNIQUE: bool = false; // Not supported + #[pyattr] + const HAS_ECDH: bool = true; + #[pyattr] + const HAS_NPN: bool = false; // Deprecated, use ALPN + #[pyattr] + const HAS_ALPN: bool = true; + #[pyattr] + const HAS_PSK: bool = false; // PSK not supported in rustls + #[pyattr] + const HAS_SSLv2: bool = false; + #[pyattr] + const HAS_SSLv3: bool = false; + #[pyattr] + const HAS_TLSv1: bool = false; // Not supported for security + #[pyattr] + const HAS_TLSv1_1: bool = false; // Not supported for security + #[pyattr] + const HAS_TLSv1_2: bool = true; // rustls supports TLS 1.2 + #[pyattr] + const HAS_TLSv1_3: bool = true; - #[pyattr(name = "_OPENSSL_API_VERSION")] - fn _openssl_api_version(_vm: &VirtualMachine) -> OpensslVersionInfo { - let openssl_api_version = i64::from_str_radix(env!("OPENSSL_API_VERSION"), 16) - .expect("OPENSSL_API_VERSION is malformed"); - parse_version_info(openssl_api_version) - } + // Encoding constants (matching OpenSSL) + #[pyattr] + const ENCODING_PEM: i32 = 1; + #[pyattr] + const ENCODING_DER: i32 = 2; + #[pyattr] + const ENCODING_PEM_AUX: i32 = 0x101; // PEM + 0x100 - // SSL Exception Types + // Exception types + use rustpython_vm::builtins::PyOSError; - /// An error occurred in the SSL implementation. #[pyattr] #[pyexception(name = "SSLError", base = PyOSError)] #[derive(Debug)] - pub struct PySslError {} + pub struct PySSLError {} #[pyexception] - impl PySslError { + impl PySSLError { // Returns strerror attribute if available, otherwise str(args) #[pymethod] fn __str__(exc: PyBaseExceptionRef, vm: &VirtualMachine) -> PyResult { @@ -263,890 +362,973 @@ mod _ssl { } // Otherwise return str(args) - exc.args().as_object().str(vm) + let args = exc.args(); + if args.len() == 1 { + args.as_slice()[0].str(vm) + } else { + args.as_object().str(vm) + } } } - /// A certificate could not be verified. #[pyattr] - #[pyexception(name = "SSLCertVerificationError", base = PySslError)] + #[pyexception(name = "SSLZeroReturnError", base = PySSLError)] #[derive(Debug)] - pub struct PySslCertVerificationError {} + pub struct PySSLZeroReturnError {} #[pyexception] - impl PySslCertVerificationError {} + impl PySSLZeroReturnError {} - /// SSL/TLS session closed cleanly. #[pyattr] - #[pyexception(name = "SSLZeroReturnError", base = PySslError)] + #[pyexception(name = "SSLWantReadError", base = PySSLError)] #[derive(Debug)] - pub struct PySslZeroReturnError {} + pub struct PySSLWantReadError {} #[pyexception] - impl PySslZeroReturnError {} + impl PySSLWantReadError {} - /// Non-blocking SSL socket needs to read more data. #[pyattr] - #[pyexception(name = "SSLWantReadError", base = PySslError)] + #[pyexception(name = "SSLWantWriteError", base = PySSLError)] #[derive(Debug)] - pub struct PySslWantReadError {} + pub struct PySSLWantWriteError {} #[pyexception] - impl PySslWantReadError {} + impl PySSLWantWriteError {} - /// Non-blocking SSL socket needs to write more data. #[pyattr] - #[pyexception(name = "SSLWantWriteError", base = PySslError)] + #[pyexception(name = "SSLSyscallError", base = PySSLError)] #[derive(Debug)] - pub struct PySslWantWriteError {} + pub struct PySSLSyscallError {} #[pyexception] - impl PySslWantWriteError {} + impl PySSLSyscallError {} - /// System error when attempting SSL operation. #[pyattr] - #[pyexception(name = "SSLSyscallError", base = PySslError)] + #[pyexception(name = "SSLEOFError", base = PySSLError)] #[derive(Debug)] - pub struct PySslSyscallError {} + pub struct PySSLEOFError {} #[pyexception] - impl PySslSyscallError {} + impl PySSLEOFError {} - /// SSL/TLS connection terminated abruptly. #[pyattr] - #[pyexception(name = "SSLEOFError", base = PySslError)] + #[pyexception(name = "SSLCertVerificationError", base = PySSLError)] #[derive(Debug)] - pub struct PySslEOFError {} + pub struct PySSLCertVerificationError {} #[pyexception] - impl PySslEOFError {} - - type OpensslVersionInfo = (u8, u8, u8, u8, u8); - const fn parse_version_info(mut n: i64) -> OpensslVersionInfo { - let status = (n & 0xF) as u8; - n >>= 4; - let patch = (n & 0xFF) as u8; - n >>= 8; - let fix = (n & 0xFF) as u8; - n >>= 8; - let minor = (n & 0xFF) as u8; - n >>= 8; - let major = (n & 0xFF) as u8; - (major, minor, fix, patch, status) - } - - #[derive(Copy, Clone, num_enum::IntoPrimitive, num_enum::TryFromPrimitive, PartialEq)] - #[repr(i32)] - enum SslVersion { - Ssl2, - Ssl3 = 1, - Tls, - Tls1, - Tls1_1, - Tls1_2, - TlsClient = 0x10, - TlsServer, + impl PySSLCertVerificationError {} + + // Helper functions to create SSL exceptions with proper errno attribute + pub(super) fn create_ssl_want_read_error(vm: &VirtualMachine) -> PyBaseExceptionRef { + // args = (errno, message) + vm.new_exception( + PySSLWantReadError::class(&vm.ctx).to_owned(), + vec![ + vm.ctx.new_int(SSL_ERROR_WANT_READ).into(), + vm.ctx + .new_str("The operation did not complete (read)") + .into(), + ], + ) } - #[derive(Copy, Clone, num_enum::IntoPrimitive, num_enum::TryFromPrimitive)] - #[repr(i32)] - enum ProtoVersion { - MinSupported = -2, - Ssl3 = sys::SSL3_VERSION, - Tls1 = sys::TLS1_VERSION, - Tls1_1 = sys::TLS1_1_VERSION, - Tls1_2 = sys::TLS1_2_VERSION, - #[cfg(ossl111)] - Tls1_3 = sys::TLS1_3_VERSION, - #[cfg(not(ossl111))] - Tls1_3 = 0x304, - MaxSupported = -1, + pub(super) fn create_ssl_want_write_error(vm: &VirtualMachine) -> PyBaseExceptionRef { + // args = (errno, message) + vm.new_exception( + PySSLWantWriteError::class(&vm.ctx).to_owned(), + vec![ + vm.ctx.new_int(SSL_ERROR_WANT_WRITE).into(), + vm.ctx + .new_str("The operation did not complete (write)") + .into(), + ], + ) } - #[derive(num_enum::IntoPrimitive, num_enum::TryFromPrimitive)] - #[repr(i32)] - enum CertRequirements { - None, - Optional, - Required, + pub(crate) fn create_ssl_eof_error(vm: &VirtualMachine) -> PyBaseExceptionRef { + vm.new_exception_msg( + PySSLEOFError::class(&vm.ctx).to_owned(), + "EOF occurred in violation of protocol".to_owned(), + ) } - #[derive(Debug, PartialEq)] - enum SslServerOrClient { - Client, - Server, + pub(crate) fn create_ssl_zero_return_error(vm: &VirtualMachine) -> PyBaseExceptionRef { + vm.new_exception_msg( + PySSLZeroReturnError::class(&vm.ctx).to_owned(), + "TLS/SSL connection has been closed (EOF)".to_owned(), + ) } - unsafe fn ptr2obj(ptr: *mut sys::ASN1_OBJECT) -> Option { - if ptr.is_null() { - None - } else { - Some(unsafe { Asn1Object::from_ptr(ptr) }) + /// Validate server hostname for TLS SNI + /// + /// Checks that the hostname: + /// - Is not empty + /// - Does not start with a dot + /// - Is not an IP address (SNI requires DNS names) + /// - Does not contain null bytes + /// - Does not exceed 253 characters (DNS limit) + /// + /// Returns Ok(()) if validation passes, or an appropriate error. + fn validate_hostname(hostname: &str, vm: &VirtualMachine) -> PyResult<()> { + if hostname.is_empty() { + return Err(vm.new_value_error("server_hostname cannot be an empty string")); } - } - - fn _txt2obj(s: &CStr, no_name: bool) -> Option { - unsafe { ptr2obj(sys::OBJ_txt2obj(s.as_ptr(), i32::from(no_name))) } - } - fn _nid2obj(nid: Nid) -> Option { - unsafe { ptr2obj(sys::OBJ_nid2obj(nid.as_raw())) } - } - - type PyNid = (libc::c_int, String, String, Option); - fn obj2py(obj: &Asn1ObjectRef, vm: &VirtualMachine) -> PyResult { - let nid = obj.nid(); - let short_name = nid - .short_name() - .map_err(|_| vm.new_value_error("NID has no short name".to_owned()))? - .to_owned(); - let long_name = nid - .long_name() - .map_err(|_| vm.new_value_error("NID has no long name".to_owned()))? - .to_owned(); - Ok(( - nid.as_raw(), - short_name, - long_name, - cert::obj2txt(obj, true), - )) - } - - #[derive(FromArgs)] - struct Txt2ObjArgs { - txt: PyStrRef, - #[pyarg(any, default = false)] - name: bool, - } - #[pyfunction] - fn txt2obj(args: Txt2ObjArgs, vm: &VirtualMachine) -> PyResult { - _txt2obj(&args.txt.to_cstring(vm)?, !args.name) - .as_deref() - .ok_or_else(|| vm.new_value_error(format!("unknown object '{}'", args.txt))) - .and_then(|obj| obj2py(obj, vm)) - } + if hostname.starts_with('.') { + return Err(vm.new_value_error("server_hostname cannot start with a dot")); + } - #[pyfunction] - fn nid2obj(nid: libc::c_int, vm: &VirtualMachine) -> PyResult { - _nid2obj(Nid::from_raw(nid)) - .as_deref() - .ok_or_else(|| vm.new_value_error(format!("unknown NID {nid}"))) - .and_then(|obj| obj2py(obj, vm)) - } + if hostname.parse::().is_ok() { + return Err(vm.new_value_error("server_hostname cannot be an IP address")); + } - // Lazily compute and cache cert file/dir paths - static CERT_PATHS: LazyLock<(PathBuf, PathBuf)> = LazyLock::new(|| { - fn path_from_cstr(c: &CStr) -> PathBuf { - #[cfg(unix)] - { - use std::os::unix::ffi::OsStrExt; - std::ffi::OsStr::from_bytes(c.to_bytes()).into() - } - #[cfg(windows)] - { - // Use lossy conversion for potential non-UTF8 - PathBuf::from(c.to_string_lossy().as_ref()) - } + if hostname.contains('\0') { + return Err(vm.new_type_error("embedded null character")); } - let probe = probe(); - let cert_file = probe - .cert_file - .as_ref() - .map(PathBuf::from) - .unwrap_or_else(|| { - path_from_cstr(unsafe { CStr::from_ptr(sys::X509_get_default_cert_file()) }) - }); - let cert_dir = probe - .cert_dir - .as_ref() - .map(PathBuf::from) - .unwrap_or_else(|| { - path_from_cstr(unsafe { CStr::from_ptr(sys::X509_get_default_cert_dir()) }) - }); - (cert_file, cert_dir) - }); + if hostname.len() > 253 { + return Err(vm.new_value_error("server_hostname is too long (maximum 253 characters)")); + } - fn get_cert_file_dir() -> (&'static Path, &'static Path) { - let (cert_file, cert_dir) = &*CERT_PATHS; - (cert_file.as_path(), cert_dir.as_path()) + Ok(()) } - // Lazily compute and cache cert environment variable names - static CERT_ENV_NAMES: LazyLock<(String, String)> = LazyLock::new(|| { - let cert_file_env = unsafe { CStr::from_ptr(sys::X509_get_default_cert_file_env()) } - .to_string_lossy() - .into_owned(); - let cert_dir_env = unsafe { CStr::from_ptr(sys::X509_get_default_cert_dir_env()) } - .to_string_lossy() - .into_owned(); - (cert_file_env, cert_dir_env) - }); - - #[pyfunction] - fn get_default_verify_paths( - vm: &VirtualMachine, - ) -> PyResult<(&'static str, PyObjectRef, &'static str, PyObjectRef)> { - let (cert_file_env, cert_dir_env) = &*CERT_ENV_NAMES; - let (cert_file, cert_dir) = get_cert_file_dir(); - let cert_file = OsPath::new_str(cert_file).filename(vm); - let cert_dir = OsPath::new_str(cert_dir).filename(vm); - Ok(( - cert_file_env.as_str(), - cert_file, - cert_dir_env.as_str(), - cert_dir, - )) + // SNI certificate resolver that uses shared mutable state + // The Python SNI callback updates this state, and resolve() reads from it + #[derive(Debug)] + struct SniCertResolver { + // SNI state: (certificate, server_name) + sni_state: Arc>, } - #[pyfunction(name = "RAND_status")] - fn rand_status() -> i32 { - unsafe { sys::RAND_status() } - } + impl ResolvesServerCert for SniCertResolver { + fn resolve(&self, client_hello: ClientHello<'_>) -> Option> { + let mut state = self.sni_state.lock(); - #[pyfunction(name = "RAND_add")] - fn rand_add(string: ArgStrOrBytesLike, entropy: f64) { - let f = |b: &[u8]| { - for buf in b.chunks(libc::c_int::MAX as usize) { - unsafe { sys::RAND_add(buf.as_ptr() as *const _, buf.len() as _, entropy) } + // Extract and store SNI from client hello for later use + if let Some(sni) = client_hello.server_name() { + state.1 = Some(sni.to_string()); + } else { + state.1 = None; } - }; - f(&string.borrow_bytes()) - } - #[pyfunction(name = "RAND_bytes")] - fn rand_bytes(n: i32, vm: &VirtualMachine) -> PyResult> { - if n < 0 { - return Err(vm.new_value_error("num must be positive")); + // Return the current certificate (may have been updated by Python callback) + Some(state.0.clone()) } - let mut buf = vec![0; n as usize]; - openssl::rand::rand_bytes(&mut buf).map_err(|e| convert_openssl_error(vm, e))?; - Ok(buf) } - // Callback data stored in SSL context for SNI - struct SniCallbackData { - ssl_context: PyRef, - vm_ptr: *const VirtualMachine, + // Session data structure for tracking TLS sessions + #[derive(Debug, Clone)] + struct SessionData { + #[allow(dead_code)] + server_name: String, + session_id: Vec, + creation_time: SystemTime, + lifetime: u64, } - impl Drop for SniCallbackData { - fn drop(&mut self) { - // PyRef will handle reference counting - } + // Type alias to simplify complex session cache type + type SessionCache = Arc, Arc>>>>; + + // Type alias for SNI state + type SniCertName = (Arc, Option); + + // SESSION EMULATION IMPLEMENTATION + // + // IMPORTANT: This is an EMULATION of CPython's SSL session management. + // Rustls 0.23 does NOT expose session data (ticket bytes, session IDs, etc.) + // through public APIs. All session value fields are private. + // + // LIMITATIONS: + // - Session IDs are generated from metadata (server name + timestamp hash) + // NOT actual TLS session IDs + // - Ticket data is not stored (Rustls keeps it internally) + // - Session resumption works (via Rustls's automatic mechanism) + // but we can't access the actual session state + // + // This implementation provides: + // ✓ session.id - synthetic ID based on metadata + // ✓ session.time - creation timestamp + // ✓ session.timeout - default lifetime value + // ✓ session.has_ticket - always True when session exists + // ✓ session_reused - tracked via handshake_kind() + // ✗ Actual TLS session ID/ticket data - NOT ACCESSIBLE + + // Generate a synthetic session ID from server name and timestamp + // NOTE: This is NOT the actual TLS session ID, just a unique identifier + fn generate_session_id_from_metadata(server_name: &str, time: &SystemTime) -> Vec { + let mut hasher = Sha256::new(); + hasher.update(server_name.as_bytes()); + hasher.update( + time.duration_since(SystemTime::UNIX_EPOCH) + .unwrap() + .as_secs() + .to_le_bytes(), + ); + hasher.finalize()[..16].to_vec() } - // Get or create an ex_data index for SNI callback data - fn get_sni_ex_data_index() -> libc::c_int { - use std::sync::LazyLock; - static SNI_EX_DATA_IDX: LazyLock = LazyLock::new(|| unsafe { - sys::SSL_get_ex_new_index( - 0, - std::ptr::null_mut(), - None, - None, - Some(sni_callback_data_free), - ) - }); - *SNI_EX_DATA_IDX + // Custom ClientSessionStore that tracks session metadata for Python access + // NOTE: This wraps ClientSessionMemoryCache and records metadata when sessions are stored + #[derive(Debug)] + struct PythonClientSessionStore { + inner: Arc, + session_cache: SessionCache, } - // Free function for callback data - unsafe extern "C" fn sni_callback_data_free( - _parent: *mut libc::c_void, - ptr: *mut libc::c_void, - _ad: *mut sys::CRYPTO_EX_DATA, - _idx: libc::c_int, - _argl: libc::c_long, - _argp: *mut libc::c_void, - ) { - if !ptr.is_null() { - unsafe { - let _ = Box::from_raw(ptr as *mut SniCallbackData); - } + impl ClientSessionStore for PythonClientSessionStore { + fn set_kx_hint(&self, server_name: ServerName<'static>, group: rustls::NamedGroup) { + self.inner.set_kx_hint(server_name, group); } - } - // SNI callback function called by OpenSSL - unsafe extern "C" fn _servername_callback( - ssl_ptr: *mut sys::SSL, - _al: *mut libc::c_int, - arg: *mut libc::c_void, - ) -> libc::c_int { - const SSL_TLSEXT_ERR_OK: libc::c_int = 0; - const SSL_TLSEXT_ERR_ALERT_FATAL: libc::c_int = 2; - const TLSEXT_NAMETYPE_host_name: libc::c_int = 0; - - if arg.is_null() { - return SSL_TLSEXT_ERR_OK; + fn kx_hint(&self, server_name: &ServerName<'_>) -> Option { + self.inner.kx_hint(server_name) } - unsafe { - let ctx = &*(arg as *const PySslContext); - - // Get the callback - let callback_opt = ctx.sni_callback.lock().clone(); - let Some(callback) = callback_opt else { - return SSL_TLSEXT_ERR_OK; - }; - - // Get callback data from SSL ex_data - let idx = get_sni_ex_data_index(); - let data_ptr = sys::SSL_get_ex_data(ssl_ptr, idx); - if data_ptr.is_null() { - return SSL_TLSEXT_ERR_ALERT_FATAL; - } - - let callback_data = &*(data_ptr as *const SniCallbackData); - - // SAFETY: vm_ptr is stored during wrap_socket and is valid for the lifetime - // of the SSL connection. The handshake happens synchronously in the same thread. - let vm = &*callback_data.vm_ptr; - - // Get server name - let servername = sys::SSL_get_servername(ssl_ptr, TLSEXT_NAMETYPE_host_name); - let server_name_arg = if servername.is_null() { - vm.ctx.none() - } else { - let name_cstr = std::ffi::CStr::from_ptr(servername); - match name_cstr.to_str() { - Ok(name_str) => vm.ctx.new_str(name_str).into(), - Err(_) => vm.ctx.none(), - } - }; - - // Get SSL socket from SSL ex_data (stored as PySslSocket pointer) - let ssl_socket_ptr = sys::SSL_get_ex_data(ssl_ptr, 0); // Index 0 for SSL socket - let ssl_socket_obj = if !ssl_socket_ptr.is_null() { - let ssl_socket = &*(ssl_socket_ptr as *const PySslSocket); - // Try to get owner first - ssl_socket - .owner - .read() - .as_ref() - .and_then(|weak| weak.upgrade()) - .unwrap_or_else(|| vm.ctx.none()) - } else { - vm.ctx.none() + fn set_tls12_session( + &self, + server_name: ServerName<'static>, + value: rustls::client::Tls12ClientSessionValue, + ) { + // Store in inner cache for actual resumption (Rustls handles this) + self.inner.set_tls12_session(server_name.clone(), value); + + // Record metadata in Python-accessible cache + // NOTE: We can't access value.session_id or value.ticket (private fields) + // So we generate a synthetic ID from metadata + let creation_time = SystemTime::now(); + let server_name_str = server_name.to_str(); + let session_data = SessionData { + server_name: server_name_str.as_ref().to_string(), + session_id: generate_session_id_from_metadata( + server_name_str.as_ref(), + &creation_time, + ), + creation_time, + lifetime: 7200, // TLS 1.2 default session lifetime }; - // Call the Python callback - match callback.call( - ( - ssl_socket_obj, - server_name_arg, - callback_data.ssl_context.to_owned(), - ), - vm, - ) { - Ok(_) => SSL_TLSEXT_ERR_OK, - Err(exc) => { - // Log the exception but don't propagate it - vm.run_unraisable(exc, None, vm.ctx.none()); - SSL_TLSEXT_ERR_ALERT_FATAL - } - } + let key = server_name_str.as_bytes().to_vec(); + self.session_cache + .write() + .insert(key, Arc::new(ParkingMutex::new(session_data))); } - } - #[pyfunction(name = "RAND_pseudo_bytes")] - fn rand_pseudo_bytes(n: i32, vm: &VirtualMachine) -> PyResult<(Vec, bool)> { - if n < 0 { - return Err(vm.new_value_error("num must be positive")); - } - let mut buf = vec![0; n as usize]; - let ret = unsafe { sys::RAND_bytes(buf.as_mut_ptr(), n) }; - match ret { - 0 | 1 => Ok((buf, ret == 1)), - _ => Err(convert_openssl_error(vm, ErrorStack::get())), + fn tls12_session( + &self, + server_name: &ServerName<'_>, + ) -> Option { + self.inner.tls12_session(server_name) } - } - #[pyattr] - #[pyclass(module = "ssl", name = "_SSLContext")] - #[derive(PyPayload)] - struct PySslContext { - ctx: PyRwLock, - check_hostname: AtomicCell, - protocol: SslVersion, - post_handshake_auth: PyMutex, - sni_callback: PyMutex>, - } + fn remove_tls12_session(&self, server_name: &ServerName<'static>) { + self.inner.remove_tls12_session(server_name); - impl fmt::Debug for PySslContext { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.pad("_SSLContext") + // Also remove from Python cache + let key = server_name.to_str().as_bytes().to_vec(); + self.session_cache.write().remove(&key); } - } - - fn builder_as_ctx(x: &SslContextBuilder) -> &ssl::SslContextRef { - unsafe { ssl::SslContextRef::from_ptr(x.as_ptr()) } - } - impl Constructor for PySslContext { - type Args = i32; - - fn py_new(cls: PyTypeRef, proto_version: Self::Args, vm: &VirtualMachine) -> PyResult { - let proto = SslVersion::try_from(proto_version) - .map_err(|_| vm.new_value_error("invalid protocol version"))?; - let method = match proto { - // SslVersion::Ssl3 => unsafe { ssl::SslMethod::from_ptr(sys::SSLv3_method()) }, - SslVersion::Tls => ssl::SslMethod::tls(), - SslVersion::Tls1 => ssl::SslMethod::tls(), - SslVersion::Tls1_1 => ssl::SslMethod::tls(), - SslVersion::Tls1_2 => ssl::SslMethod::tls(), - SslVersion::TlsClient => ssl::SslMethod::tls_client(), - SslVersion::TlsServer => ssl::SslMethod::tls_server(), - _ => return Err(vm.new_value_error("invalid protocol version")), + fn insert_tls13_ticket( + &self, + server_name: ServerName<'static>, + value: rustls::client::Tls13ClientSessionValue, + ) { + // Store in inner cache for actual resumption (Rustls handles this) + self.inner.insert_tls13_ticket(server_name.clone(), value); + + // Record metadata in Python-accessible cache + // NOTE: We can't access value.ticket or value.lifetime_secs (private fields) + // So we use default values + let creation_time = SystemTime::now(); + let server_name_str = server_name.to_str(); + let session_data = SessionData { + server_name: server_name_str.to_string(), + session_id: generate_session_id_from_metadata( + server_name_str.as_ref(), + &creation_time, + ), + creation_time, + lifetime: 7200, // Default TLS 1.3 ticket lifetime (Rustls uses this) }; - let mut builder = - SslContextBuilder::new(method).map_err(|e| convert_openssl_error(vm, e))?; - - #[cfg(target_os = "android")] - android::load_client_ca_list(vm, &mut builder)?; - - let check_hostname = proto == SslVersion::TlsClient; - builder.set_verify(if check_hostname { - SslVerifyMode::PEER | SslVerifyMode::FAIL_IF_NO_PEER_CERT - } else { - SslVerifyMode::NONE - }); - let mut options = SslOptions::ALL & !SslOptions::DONT_INSERT_EMPTY_FRAGMENTS; - if proto != SslVersion::Ssl2 { - options |= SslOptions::NO_SSLV2; - } - if proto != SslVersion::Ssl3 { - options |= SslOptions::NO_SSLV3; - } - options |= SslOptions::NO_COMPRESSION; - options |= SslOptions::CIPHER_SERVER_PREFERENCE; - options |= SslOptions::SINGLE_DH_USE; - options |= SslOptions::SINGLE_ECDH_USE; - options |= SslOptions::ENABLE_MIDDLEBOX_COMPAT; - builder.set_options(options); + let key = server_name_str.as_bytes().to_vec(); + self.session_cache + .write() + .insert(key, Arc::new(ParkingMutex::new(session_data))); + } - let mode = ssl::SslMode::ACCEPT_MOVING_WRITE_BUFFER | ssl::SslMode::AUTO_RETRY; - builder.set_mode(mode); + fn take_tls13_ticket( + &self, + server_name: &ServerName<'static>, + ) -> Option { + self.inner.take_tls13_ticket(server_name) + } + } - #[cfg(ossl111)] - unsafe { - sys::SSL_CTX_set_post_handshake_auth(builder.as_ptr(), 0); + /// Parse length-prefixed ALPN protocol list + /// + /// Format: [len1, proto1..., len2, proto2..., ...] + /// + /// This is the wire format used by Python's ssl.py when calling _set_alpn_protocols(). + /// Each protocol is prefixed with a single byte indicating its length. + /// + /// # Arguments + /// * `bytes` - The length-prefixed protocol data + /// * `vm` - VirtualMachine for error creation + /// + /// # Returns + /// * `Ok(Vec>)` - List of protocol names as byte vectors + /// * `Err(PyBaseExceptionRef)` - ValueError with detailed error message + fn parse_length_prefixed_alpn(bytes: &[u8], vm: &VirtualMachine) -> PyResult>> { + let mut alpn_list = Vec::new(); + let mut offset = 0; + + while offset < bytes.len() { + // Check if we can read the length byte + if offset + 1 > bytes.len() { + return Err(vm.new_value_error(format!( + "Invalid ALPN protocol data: unexpected end at offset {offset}", + ))); } - builder - .set_session_id_context(b"Python") - .map_err(|e| convert_openssl_error(vm, e))?; + let proto_len = bytes[offset] as usize; + offset += 1; - // Set protocol version limits based on the protocol version - unsafe { - let ctx_ptr = builder.as_ptr(); - match proto { - SslVersion::Tls1 => { - sys::SSL_CTX_set_min_proto_version(ctx_ptr, sys::TLS1_VERSION); - sys::SSL_CTX_set_max_proto_version(ctx_ptr, sys::TLS1_VERSION); - } - SslVersion::Tls1_1 => { - sys::SSL_CTX_set_min_proto_version(ctx_ptr, sys::TLS1_1_VERSION); - sys::SSL_CTX_set_max_proto_version(ctx_ptr, sys::TLS1_1_VERSION); - } - SslVersion::Tls1_2 => { - sys::SSL_CTX_set_min_proto_version(ctx_ptr, sys::TLS1_2_VERSION); - sys::SSL_CTX_set_max_proto_version(ctx_ptr, sys::TLS1_2_VERSION); - } - _ => { - // For Tls, TlsClient, TlsServer, use default (no restrictions) - } - } + // Validate protocol length + if proto_len == 0 { + return Err(vm.new_value_error(format!( + "Invalid ALPN protocol data: protocol length cannot be 0 at offset {}", + offset - 1 + ))); } - // Set default verify flags: VERIFY_X509_TRUSTED_FIRST - unsafe { - let ctx_ptr = builder.as_ptr(); - let param = sys::SSL_CTX_get0_param(ctx_ptr); - sys::X509_VERIFY_PARAM_set_flags(param, sys::X509_V_FLAG_TRUSTED_FIRST); + // Check if we have enough bytes for the protocol data + if offset + proto_len > bytes.len() { + return Err(vm.new_value_error(format!( + "Invalid ALPN protocol data: expected {} bytes at offset {}, but only {} bytes remain", + proto_len, offset, bytes.len() - offset + ))); } - PySslContext { - ctx: PyRwLock::new(builder), - check_hostname: AtomicCell::new(check_hostname), - protocol: proto, - post_handshake_auth: PyMutex::new(false), - sni_callback: PyMutex::new(None), - } - .into_ref_with_type(vm, cls) - .map(Into::into) + // Extract protocol bytes + let proto = bytes[offset..offset + proto_len].to_vec(); + alpn_list.push(proto); + offset += proto_len; } + + Ok(alpn_list) } - #[pyclass(flags(BASETYPE, IMMUTABLETYPE), with(Constructor))] - impl PySslContext { - fn builder(&self) -> PyRwLockWriteGuard<'_, SslContextBuilder> { - self.ctx.write() - } - fn ctx(&self) -> PyMappedRwLockReadGuard<'_, ssl::SslContextRef> { - PyRwLockReadGuard::map(self.ctx.read(), builder_as_ctx) + /// Parse OpenSSL cipher string to rustls SupportedCipherSuite list + /// + /// Supports patterns like: + /// - "AES128" → filters for AES_128 + /// - "AES256" → filters for AES_256 + /// - "AES128:AES256" → both + /// - "ECDHE+AESGCM" → ECDHE AND AESGCM (both conditions must match) + /// - "ALL" or "DEFAULT" → all available + /// - "!MD5" → exclusion (ignored, rustls doesn't support weak ciphers anyway) + fn parse_cipher_string(cipher_str: &str) -> Result, String> { + use rustls::crypto::aws_lc_rs::ALL_CIPHER_SUITES; + + if cipher_str.is_empty() { + return Err("No cipher can be selected".to_string()); } - #[pygetset] - fn post_handshake_auth(&self) -> bool { - *self.post_handshake_auth.lock() - } - #[pygetset(setter)] - fn set_post_handshake_auth( - &self, - value: Option, - vm: &VirtualMachine, - ) -> PyResult<()> { - let value = value.ok_or_else(|| vm.new_attribute_error("cannot delete attribute"))?; - *self.post_handshake_auth.lock() = value.is_true(vm)?; - Ok(()) - } + let all_suites = ALL_CIPHER_SUITES; + let mut selected = Vec::new(); - #[cfg(ossl110)] - #[pygetset] - fn security_level(&self) -> i32 { - unsafe { SSL_CTX_get_security_level(self.ctx().as_ptr()) } - } + for part in cipher_str.split(':') { + let part = part.trim(); - #[pymethod] - fn set_ciphers(&self, cipherlist: PyStrRef, vm: &VirtualMachine) -> PyResult<()> { - let ciphers = cipherlist.as_str(); - if ciphers.contains('\0') { - return Err(exceptions::cstring_error(vm)); + // Skip exclusions (rustls doesn't support these) + if part.starts_with('!') { + continue; } - self.builder().set_cipher_list(ciphers).map_err(|_| { - vm.new_exception_msg( - PySslError::class(&vm.ctx).to_owned(), - "No cipher can be selected.".to_owned(), - ) - }) - } - - #[pymethod] - fn get_ciphers(&self, vm: &VirtualMachine) -> PyResult { - let ctx = self.ctx(); - let ssl = ssl::Ssl::new(&ctx).map_err(|e| convert_openssl_error(vm, e))?; - - unsafe { - let ciphers_ptr = SSL_get_ciphers(ssl.as_ptr()); - if ciphers_ptr.is_null() { - return Ok(vm.ctx.new_list(vec![])); - } - - let num_ciphers = sys::OPENSSL_sk_num(ciphers_ptr as *const _); - let mut result = Vec::new(); - - for i in 0..num_ciphers { - let cipher_ptr = - sys::OPENSSL_sk_value(ciphers_ptr as *const _, i) as *const sys::SSL_CIPHER; - let cipher = ssl::SslCipherRef::from_ptr(cipher_ptr as *mut _); - - let (name, version, bits) = cipher_to_tuple(cipher); - let dict = vm.ctx.new_dict(); - dict.set_item("name", vm.ctx.new_str(name).into(), vm)?; - dict.set_item("protocol", vm.ctx.new_str(version).into(), vm)?; - dict.set_item("secret_bits", vm.ctx.new_int(bits).into(), vm)?; - // Add description field - let description = cipher_description(cipher_ptr); - dict.set_item("description", vm.ctx.new_str(description).into(), vm)?; + // Skip priority markers starting with + + if part.starts_with('+') { + continue; + } - result.push(dict.into()); + // Match pattern + match part { + "ALL" | "DEFAULT" | "HIGH" => { + // Add all available cipher suites + selected.extend_from_slice(all_suites); } + _ => { + // Check if this is a compound pattern with + (AND condition) + // e.g., "ECDHE+AESGCM" means ECDHE AND AESGCM + let patterns: Vec<&str> = part.split('+').collect(); + + let mut found_any = false; + for suite in all_suites { + let name = format!("{:?}", suite.suite()); + + // Check if all patterns match (AND condition) + let matches = patterns.iter().all(|&pattern| { + // Handle common OpenSSL pattern variations + if pattern.contains("AES128") { + name.contains("AES_128") + } else if pattern.contains("AES256") { + name.contains("AES_256") + } else if pattern == "AESGCM" { + // AESGCM: AES with GCM mode + name.contains("AES") && name.contains("GCM") + } else if pattern == "AESCCM" { + // AESCCM: AES with CCM mode + name.contains("AES") && name.contains("CCM") + } else if pattern == "CHACHA20" { + name.contains("CHACHA20") + } else if pattern == "ECDHE" { + name.contains("ECDHE") + } else if pattern == "DHE" { + // DHE but not ECDHE + name.contains("DHE") && !name.contains("ECDHE") + } else if pattern == "ECDH" { + // ECDH but not ECDHE + name.contains("ECDH") && !name.contains("ECDHE") + } else if pattern == "DH" { + // DH but not DHE or ECDH + name.contains("DH") + && !name.contains("DHE") + && !name.contains("ECDH") + } else if pattern == "RSA" { + name.contains("RSA") + } else if pattern == "AES" { + name.contains("AES") + } else if pattern == "ECDSA" { + name.contains("ECDSA") + } else { + // Direct substring match for other patterns + name.contains(pattern) + } + }); - Ok(vm.ctx.new_list(result)) - } - } - - #[pymethod] - fn set_ecdh_curve( - &self, - name: Either, - vm: &VirtualMachine, - ) -> PyResult<()> { - use openssl::ec::{EcGroup, EcKey}; + if matches { + selected.push(*suite); + found_any = true; + } + } - // Convert name to CString, supporting both str and bytes - let name_cstr = match name { - Either::A(s) => { - if s.as_str().contains('\0') { - return Err(exceptions::cstring_error(vm)); + if !found_any { + // No matching cipher suite found - warn but continue } - s.to_cstring(vm)? } - Either::B(b) => std::ffi::CString::new(b.borrow_buf().to_vec()) - .map_err(|_| exceptions::cstring_error(vm))?, - }; - - // Find the NID for the curve name using OBJ_sn2nid - let nid_raw = unsafe { sys::OBJ_sn2nid(name_cstr.as_ptr()) }; - if nid_raw == 0 { - return Err(vm.new_value_error("unknown curve name")); } - let nid = Nid::from_raw(nid_raw); + } - // Create EC key from the curve - let group = EcGroup::from_curve_name(nid).map_err(|e| convert_openssl_error(vm, e))?; - let key = EcKey::from_group(&group).map_err(|e| convert_openssl_error(vm, e))?; + // Remove duplicates + selected.dedup_by_key(|s| s.suite()); - // Set the temporary ECDH key - self.builder() - .set_tmp_ecdh(&key) - .map_err(|e| convert_openssl_error(vm, e)) + if selected.is_empty() { + Err("No cipher can be selected".to_string()) + } else { + Ok(selected) + } + } + + // SSLContext - manages TLS configuration + #[pyattr] + #[pyclass(name = "_SSLContext", module = "ssl", traverse)] + #[derive(Debug, PyPayload)] + struct PySSLContext { + #[pytraverse(skip)] + protocol: i32, + #[pytraverse(skip)] + check_hostname: PyRwLock, + #[pytraverse(skip)] + verify_mode: PyRwLock, + #[pytraverse(skip)] + verify_flags: PyRwLock, + // Rustls configuration (built lazily) + #[allow(dead_code)] + #[pytraverse(skip)] + client_config: PyRwLock>>, + #[allow(dead_code)] + #[pytraverse(skip)] + server_config: PyRwLock>>, + // Certificate store + #[pytraverse(skip)] + root_certs: PyRwLock, + // Store full CA certificates for get_ca_certs() + // RootCertStore only keeps TrustAnchors, not full certificates + #[pytraverse(skip)] + ca_certs_der: PyRwLock>>, + // Store CA certificates from capath for lazy loading simulation + // (CPython only returns these in get_ca_certs() after they're used in handshake) + #[pytraverse(skip)] + capath_certs_der: PyRwLock>>, + // Certificate Revocation Lists for CRL checking + #[pytraverse(skip)] + crls: PyRwLock>>, + // Server certificate/key pairs (supports multiple for RSA+ECC dual mode) + // OpenSSL allows multiple cert/key pairs to be loaded, and selects the appropriate + // one based on client capabilities during handshake + // Stored as (CertifiedKey, PrivateKeyDer) to support both server and client usage + #[pytraverse(skip)] + cert_keys: PyRwLock>, + // Options + #[allow(dead_code)] + #[pytraverse(skip)] + options: PyRwLock, + // ALPN protocols + #[allow(dead_code)] + #[pytraverse(skip)] + alpn_protocols: PyRwLock>>, + // ALPN strict matching flag + // When false (default), mimics OpenSSL behavior: no ALPN negotiation failure + // When true, requires ALPN match (Rustls default behavior) + #[allow(dead_code)] + #[pytraverse(skip)] + require_alpn_match: PyRwLock, + // TLS 1.3 features + #[pytraverse(skip)] + post_handshake_auth: PyRwLock, + #[pytraverse(skip)] + num_tickets: PyRwLock, + // Protocol version limits + #[pytraverse(skip)] + minimum_version: PyRwLock, + #[pytraverse(skip)] + maximum_version: PyRwLock, + // SNI callback for server-side (contains PyObjectRef - needs GC tracking) + sni_callback: PyRwLock>, + // Message callback for debugging (contains PyObjectRef - needs GC tracking) + msg_callback: PyRwLock>, + // ECDH curve name for key exchange + #[pytraverse(skip)] + ecdh_curve: PyRwLock>, + // Certificate statistics for cert_store_stats() + #[pytraverse(skip)] + ca_cert_count: PyRwLock, // Number of CA certificates + #[pytraverse(skip)] + x509_cert_count: PyRwLock, // Total number of certificates + // Session management + #[pytraverse(skip)] + client_session_cache: SessionCache, + // Rustls session store for actual TLS session resumption + #[pytraverse(skip)] + rustls_session_store: Arc, + // Rustls server session store for server-side session resumption + #[pytraverse(skip)] + rustls_server_session_store: Arc, + // Shared ticketer for TLS 1.2 session tickets + #[pytraverse(skip)] + server_ticketer: Arc, + // Server-side session statistics + #[pytraverse(skip)] + accept_count: AtomicUsize, // Total number of accepts + #[pytraverse(skip)] + session_hits: AtomicUsize, // Number of session reuses + // Cipher suite selection + /// Selected cipher suites (None = use all rustls defaults) + #[pytraverse(skip)] + selected_ciphers: PyRwLock>>, + } + + #[derive(FromArgs)] + struct WrapSocketArgs { + sock: PyObjectRef, + server_side: bool, + #[pyarg(positional, optional)] + server_hostname: OptionalArg>, + #[pyarg(named, optional)] + owner: OptionalArg, + #[pyarg(named, optional)] + session: OptionalArg, + } + + #[derive(FromArgs)] + struct WrapBioArgs { + incoming: PyRef, + outgoing: PyRef, + #[pyarg(named, optional)] + server_side: OptionalArg, + #[pyarg(named, optional)] + server_hostname: OptionalArg>, + #[pyarg(named, optional)] + owner: OptionalArg, + #[pyarg(named, optional)] + session: OptionalArg, + } + + #[derive(FromArgs)] + struct LoadVerifyLocationsArgs { + #[pyarg(any, optional)] + cafile: OptionalArg>, + #[pyarg(any, optional)] + capath: OptionalArg>, + #[pyarg(any, optional)] + cadata: OptionalArg, + } + + #[derive(FromArgs)] + struct LoadCertChainArgs { + #[pyarg(any)] + certfile: PyObjectRef, + #[pyarg(any, optional)] + keyfile: OptionalArg>, + #[pyarg(any, optional)] + password: OptionalArg, + } + + #[pyclass(with(Constructor), flags(BASETYPE))] + impl PySSLContext { + // Helper method to convert DER certificate bytes to Python dict + fn cert_der_to_dict(&self, vm: &VirtualMachine, cert_der: &[u8]) -> PyResult { + cert::cert_der_to_dict_helper(vm, cert_der) + } + + #[pymethod] + fn __repr__(&self) -> String { + format!("", self.protocol) + } + + #[pygetset] + fn check_hostname(&self) -> bool { + *self.check_hostname.read() + } + + #[pygetset(setter)] + fn set_check_hostname(&self, value: bool) { + *self.check_hostname.write() = value; + // When check_hostname is enabled, ensure verify_mode is at least CERT_REQUIRED + if value { + let current_verify_mode = *self.verify_mode.read(); + if current_verify_mode == CERT_NONE { + *self.verify_mode.write() = CERT_REQUIRED; + } + } } #[pygetset] - fn options(&self) -> libc::c_ulong { - self.ctx.read().options().bits() as _ + fn verify_mode(&self) -> i32 { + *self.verify_mode.read() } + #[pygetset(setter)] - fn set_options(&self, opts: libc::c_ulong) { - self.builder() - .set_options(SslOptions::from_bits_truncate(opts as _)); + fn set_verify_mode(&self, mode: i32, vm: &VirtualMachine) -> PyResult<()> { + if !(CERT_NONE..=CERT_REQUIRED).contains(&mode) { + return Err(vm.new_value_error("invalid verify mode")); + } + // Cannot set CERT_NONE when check_hostname is enabled + if mode == CERT_NONE && *self.check_hostname.read() { + return Err(vm.new_value_error( + "Cannot set verify_mode to CERT_NONE when check_hostname is enabled", + )); + } + *self.verify_mode.write() = mode; + Ok(()) } + #[pygetset] fn protocol(&self) -> i32 { - self.protocol as i32 + self.protocol } + #[pygetset] - fn verify_mode(&self) -> i32 { - let mode = self.ctx().verify_mode(); - if mode == SslVerifyMode::NONE { - CertRequirements::None.into() - } else if mode == SslVerifyMode::PEER { - CertRequirements::Optional.into() - } else if mode == SslVerifyMode::PEER | SslVerifyMode::FAIL_IF_NO_PEER_CERT { - CertRequirements::Required.into() - } else { - unreachable!() - } + fn verify_flags(&self) -> i32 { + *self.verify_flags.read() } + #[pygetset(setter)] - fn set_verify_mode(&self, cert: i32, vm: &VirtualMachine) -> PyResult<()> { - let mut ctx = self.builder(); - let cert_req = CertRequirements::try_from(cert) - .map_err(|_| vm.new_value_error("invalid value for verify_mode"))?; - let mode = match cert_req { - CertRequirements::None if self.check_hostname.load() => { - return Err(vm.new_value_error( - "Cannot set verify_mode to CERT_NONE when check_hostname is enabled.", - )); - } - CertRequirements::None => SslVerifyMode::NONE, - CertRequirements::Optional => SslVerifyMode::PEER, - CertRequirements::Required => { - SslVerifyMode::PEER | SslVerifyMode::FAIL_IF_NO_PEER_CERT - } - }; - ctx.set_verify(mode); - Ok(()) + fn set_verify_flags(&self, value: i32) { + *self.verify_flags.write() = value; } + #[pygetset] - fn verify_flags(&self) -> libc::c_ulong { - unsafe { - let ctx_ptr = self.ctx().as_ptr(); - let param = sys::SSL_CTX_get0_param(ctx_ptr); - sys::X509_VERIFY_PARAM_get_flags(param) - } + fn post_handshake_auth(&self) -> bool { + *self.post_handshake_auth.read() } + #[pygetset(setter)] - fn set_verify_flags(&self, new_flags: libc::c_ulong, vm: &VirtualMachine) -> PyResult<()> { - unsafe { - let ctx_ptr = self.ctx().as_ptr(); - let param = sys::SSL_CTX_get0_param(ctx_ptr); - let flags = sys::X509_VERIFY_PARAM_get_flags(param); - let clear = flags & !new_flags; - let set = !flags & new_flags; - - if clear != 0 && sys::X509_VERIFY_PARAM_clear_flags(param, clear) == 0 { - return Err(vm.new_exception_msg( - PySslError::class(&vm.ctx).to_owned(), - "Failed to clear verify flags".to_owned(), - )); - } - if set != 0 && sys::X509_VERIFY_PARAM_set_flags(param, set) == 0 { - return Err(vm.new_exception_msg( - PySslError::class(&vm.ctx).to_owned(), - "Failed to set verify flags".to_owned(), - )); - } - Ok(()) + fn set_post_handshake_auth(&self, value: bool) { + *self.post_handshake_auth.write() = value; + } + + #[pygetset] + fn num_tickets(&self) -> i32 { + *self.num_tickets.read() + } + + #[pygetset(setter)] + fn set_num_tickets(&self, value: i32, vm: &VirtualMachine) -> PyResult<()> { + if value < 0 { + return Err(vm.new_value_error("num_tickets must be a non-negative integer")); } + if self.protocol != PROTOCOL_TLS_SERVER { + return Err( + vm.new_value_error("num_tickets can only be set on server-side contexts") + ); + } + *self.num_tickets.write() = value; + Ok(()) } + #[pygetset] - fn check_hostname(&self) -> bool { - self.check_hostname.load() + fn options(&self) -> i32 { + *self.options.read() } + #[pygetset(setter)] - fn set_check_hostname(&self, ch: bool) { - let mut ctx = self.builder(); - if ch && builder_as_ctx(&ctx).verify_mode() == SslVerifyMode::NONE { - ctx.set_verify(SslVerifyMode::PEER | SslVerifyMode::FAIL_IF_NO_PEER_CERT); + fn set_options(&self, value: i32, vm: &VirtualMachine) -> PyResult<()> { + // Validate that the value is non-negative + if value < 0 { + return Err(vm.new_overflow_error("options must be non-negative".to_owned())); } - self.check_hostname.store(ch); + + // Deprecated SSL/TLS protocol version options + let opt_no = OP_NO_SSLv2 + | OP_NO_SSLv3 + | OP_NO_TLSv1 + | OP_NO_TLSv1_1 + | OP_NO_TLSv1_2 + | OP_NO_TLSv1_3; + + // Get current options and calculate newly set bits + let old_opts = *self.options.read(); + let set = !old_opts & value; // Bits being newly set + + // Warn if any deprecated options are being newly set + if (set & opt_no) != 0 { + warnings::warn( + vm.ctx.exceptions.deprecation_warning, + "ssl.OP_NO_SSL*/ssl.OP_NO_TLS* options are deprecated".to_owned(), + 2, // stack_level = 2 + vm, + )?; + } + + *self.options.write() = value; + Ok(()) } - // PY_PROTO_MINIMUM_SUPPORTED = -2, PY_PROTO_MAXIMUM_SUPPORTED = -1 #[pygetset] fn minimum_version(&self) -> i32 { - let ctx = self.ctx(); - let version = unsafe { sys::SSL_CTX_get_min_proto_version(ctx.as_ptr()) }; - if version == 0 { - -2 // PY_PROTO_MINIMUM_SUPPORTED - } else { - version - } + let v = *self.minimum_version.read(); + // return MINIMUM_SUPPORTED if value is 0 + if v == 0 { PROTO_MINIMUM_SUPPORTED } else { v } } + #[pygetset(setter)] fn set_minimum_version(&self, value: i32, vm: &VirtualMachine) -> PyResult<()> { - // Handle special values - let proto_version = match value { - -2 => { - // PY_PROTO_MINIMUM_SUPPORTED -> use minimum available (TLS 1.2) - sys::TLS1_2_VERSION - } - -1 => { - // PY_PROTO_MAXIMUM_SUPPORTED -> use maximum available - // For max on min_proto_version, we use the newest available - sys::TLS1_3_VERSION - } + // Validate that the value is a valid TLS version constant + // Valid values: 0 (default), -2 (MINIMUM_SUPPORTED), -1 (MAXIMUM_SUPPORTED), + // or 0x0300-0x0304 (SSLv3-TLSv1.3) + if value != 0 + && value != -2 + && value != -1 + && !(PROTO_SSLv3..=PROTO_TLSv1_3).contains(&value) + { + return Err(vm.new_value_error(format!("invalid protocol version: {value}"))); + } + // Convert special values to rustls actual supported versions + // MINIMUM_SUPPORTED (-2) -> 0 (auto-negotiate) + // MAXIMUM_SUPPORTED (-1) -> MAXIMUM_VERSION (TLSv1.3) + let normalized_value = match value { + PROTO_MINIMUM_SUPPORTED => 0, // Auto-negotiate + PROTO_MAXIMUM_SUPPORTED => MAXIMUM_VERSION, // TLSv1.3 _ => value, }; - - let ctx = self.builder(); - let result = unsafe { sys::SSL_CTX_set_min_proto_version(ctx.as_ptr(), proto_version) }; - if result == 0 { - return Err(vm.new_value_error("invalid protocol version")); - } + *self.minimum_version.write() = normalized_value; Ok(()) } #[pygetset] fn maximum_version(&self) -> i32 { - let ctx = self.ctx(); - let version = unsafe { sys::SSL_CTX_get_max_proto_version(ctx.as_ptr()) }; - if version == 0 { - -1 // PY_PROTO_MAXIMUM_SUPPORTED - } else { - version - } + let v = *self.maximum_version.read(); + // return MAXIMUM_SUPPORTED if value is 0 + if v == 0 { PROTO_MAXIMUM_SUPPORTED } else { v } } + #[pygetset(setter)] fn set_maximum_version(&self, value: i32, vm: &VirtualMachine) -> PyResult<()> { - // Handle special values - let proto_version = match value { - -1 => { - // PY_PROTO_MAXIMUM_SUPPORTED -> use 0 for OpenSSL (means no limit) - 0 - } - -2 => { - // PY_PROTO_MINIMUM_SUPPORTED -> use minimum available (TLS 1.2) - sys::TLS1_2_VERSION - } + // Validate that the value is a valid TLS version constant + // Valid values: 0 (default), -2 (MINIMUM_SUPPORTED), -1 (MAXIMUM_SUPPORTED), + // or 0x0300-0x0304 (SSLv3-TLSv1.3) + if value != 0 + && value != -2 + && value != -1 + && !(PROTO_SSLv3..=PROTO_TLSv1_3).contains(&value) + { + return Err(vm.new_value_error(format!("invalid protocol version: {value}"))); + } + // Convert special values to rustls actual supported versions + // MAXIMUM_SUPPORTED (-1) -> 0 (auto-negotiate) + // MINIMUM_SUPPORTED (-2) -> MINIMUM_VERSION (TLSv1.2) + let normalized_value = match value { + PROTO_MAXIMUM_SUPPORTED => 0, // Auto-negotiate + PROTO_MINIMUM_SUPPORTED => MINIMUM_VERSION, // TLSv1.2 _ => value, }; - - let ctx = self.builder(); - let result = unsafe { sys::SSL_CTX_set_max_proto_version(ctx.as_ptr(), proto_version) }; - if result == 0 { - return Err(vm.new_value_error("invalid protocol version")); - } + *self.maximum_version.write() = normalized_value; Ok(()) } - #[pygetset] - fn num_tickets(&self, _vm: &VirtualMachine) -> PyResult { - // Only supported for TLS 1.3 - #[cfg(ossl110)] - { - let ctx = self.ctx(); - let num = unsafe { sys::SSL_CTX_get_num_tickets(ctx.as_ptr()) }; - Ok(num) - } - #[cfg(not(ossl110))] + #[pymethod] + fn load_cert_chain(&self, args: LoadCertChainArgs, vm: &VirtualMachine) -> PyResult<()> { + // Parse certfile argument (str or bytes) to path + let cert_path = Self::parse_path_arg(&args.certfile, vm)?; + + // Parse keyfile argument (default to certfile if not provided) + let key_path = match args.keyfile { + OptionalArg::Present(Some(ref k)) => Self::parse_path_arg(k, vm)?, + _ => cert_path.clone(), + }; + + // Parse password argument (str, bytes-like, or callable) + // Callable passwords are NOT invoked immediately (lazy evaluation) + let (password_str, password_callable) = + Self::parse_password_argument(&args.password, vm)?; + + // Validate immediate password length (limit: PEM_BUFSIZE = 1024 bytes) + if let Some(ref pwd) = password_str + && pwd.len() > PEM_BUFSIZE { - let _ = vm; - Ok(0) - } - } - #[pygetset(setter)] - fn set_num_tickets(&self, value: isize, vm: &VirtualMachine) -> PyResult<()> { - // Check for negative values - if value < 0 { - return Err( - vm.new_value_error("num_tickets must be a non-negative integer".to_owned()) - ); + return Err(vm.new_value_error(format!( + "password cannot be longer than {PEM_BUFSIZE} bytes", + ))); } - // Check that this is a server context - if self.protocol != SslVersion::TlsServer { - return Err(vm.new_value_error("SSLContext is not a server context.".to_owned())); - } + // First attempt: Load with immediate password (or None if callable) + let mut result = + cert::load_cert_chain_from_file(&cert_path, &key_path, password_str.as_deref()); - #[cfg(ossl110)] + // If failed and callable exists, invoke it and retry + // This implements lazy evaluation: callable only invoked if password is actually needed + if result.is_err() + && let Some(callable) = password_callable { - let ctx = self.builder(); - let result = unsafe { sys::SSL_CTX_set_num_tickets(ctx.as_ptr(), value as usize) }; - if result != 1 { - return Err(vm.new_value_error("failed to set num tickets.")); + // Invoke callable - exceptions propagate naturally + let pwd_result = callable.call((), vm)?; + + // Convert callable result to string + let password_from_callable = if let Ok(pwd_str) = + PyStrRef::try_from_object(vm, pwd_result.clone()) + { + pwd_str.as_str().to_owned() + } else if let Ok(pwd_bytes_like) = ArgBytesLike::try_from_object(vm, pwd_result) { + String::from_utf8(pwd_bytes_like.borrow_buf().to_vec()).map_err(|_| { + vm.new_type_error( + "password callback returned invalid UTF-8 bytes".to_owned(), + ) + })? + } else { + return Err(vm.new_type_error( + "password callback must return a string or bytes".to_owned(), + )); + }; + + // Validate callable password length + if password_from_callable.len() > PEM_BUFSIZE { + return Err(vm.new_value_error(format!( + "password cannot be longer than {PEM_BUFSIZE} bytes", + ))); } - Ok(()) - } - #[cfg(not(ossl110))] - { - let _ = (value, vm); - Ok(()) + + // Retry with callable password + result = cert::load_cert_chain_from_file( + &cert_path, + &key_path, + Some(&password_from_callable), + ); } - } - #[pymethod] - fn set_default_verify_paths(&self, vm: &VirtualMachine) -> PyResult<()> { - cfg_if::cfg_if! { - if #[cfg(openssl_vendored)] { - let (cert_file, cert_dir) = get_cert_file_dir(); - self.builder() - .load_verify_locations(Some(cert_file), Some(cert_dir)) - .map_err(|e| convert_openssl_error(vm, e)) + // Process result + let (certs, key) = result.map_err(|e| { + // Try to downcast to io::Error to preserve errno information + if let Ok(io_err) = e.downcast::() { + match io_err.kind() { + // File access errors (NotFound, PermissionDenied) - preserve errno + std::io::ErrorKind::NotFound | std::io::ErrorKind::PermissionDenied => { + io_err.into_pyexception(vm) + } + // Other io::Error types + std::io::ErrorKind::Other => { + let msg = io_err.to_string(); + if msg.contains("Failed to decrypt") || msg.contains("wrong password") { + // Wrong password error + vm.new_exception_msg(PySSLError::class(&vm.ctx).to_owned(), msg) + } else { + // [SSL] PEM lib + super::compat::SslError::create_ssl_error_with_reason( + vm, "SSL", "", "PEM lib", + ) + } + } + // PEM parsing errors - [SSL] PEM lib + _ => super::compat::SslError::create_ssl_error_with_reason( + vm, "SSL", "", "PEM lib", + ), + } } else { - self.builder() - .set_default_verify_paths() - .map_err(|e| convert_openssl_error(vm, e)) + // Unknown error type - [SSL] PEM lib + super::compat::SslError::create_ssl_error_with_reason(vm, "SSL", "", "PEM lib") } - } - } + })?; - #[pymethod] - fn _set_alpn_protocols(&self, protos: ArgBytesLike, vm: &VirtualMachine) -> PyResult<()> { - #[cfg(ossl102)] - { - let mut ctx = self.builder(); - let server = protos.with_ref(|pbuf| { - if pbuf.len() > libc::c_uint::MAX as usize { - return Err(vm.new_overflow_error(format!( - "protocols longer than {} bytes", - libc::c_uint::MAX - ))); + // Validate certificate and key match + cert::validate_cert_key_match(&certs, &key).map_err(|e| { + vm.new_exception_msg( + PySSLError::class(&vm.ctx).to_owned(), + if e.contains("key values mismatch") { + "[SSL: KEY_VALUES_MISMATCH] key values mismatch".to_owned() + } else { + e + }, + ) + })?; + + // Auto-build certificate chain: if only leaf cert is in file, try to add CA certs + // This matches OpenSSL behavior where it automatically includes intermediate/CA certs + let mut full_chain = certs.clone(); + if full_chain.len() == 1 { + // Only have leaf cert, try to build chain from CA certs + let ca_certs_der = self.ca_certs_der.read(); + if !ca_certs_der.is_empty() { + // Use build_verified_chain to construct full chain + let chain_result = cert::build_verified_chain(&full_chain, &ca_certs_der); + if chain_result.len() > 1 { + // Successfully built a longer chain + full_chain = chain_result.into_iter().map(CertificateDer::from).collect(); } - ctx.set_alpn_protos(pbuf) - .map_err(|e| convert_openssl_error(vm, e))?; - Ok(pbuf.to_vec()) + } + } + + // Additional validation: Create CertifiedKey to ensure rustls accepts it + let signing_key = + rustls::crypto::aws_lc_rs::sign::any_supported_type(&key).map_err(|_| { + vm.new_exception_msg( + PySSLError::class(&vm.ctx).to_owned(), + "[SSL: KEY_VALUES_MISMATCH] key values mismatch".to_owned(), + ) })?; - ctx.set_alpn_select_callback(move |_, client| { - let proto = - ssl::select_next_proto(&server, client).ok_or(ssl::AlpnError::NOACK)?; - let pos = memchr::memmem::find(client, proto) - .expect("selected alpn proto should be present in client protos"); - Ok(&client[pos..proto.len()]) - }); - Ok(()) - } - #[cfg(not(ossl102))] - { - Err(vm.new_not_implemented_error( - "The NPN extension requires OpenSSL 1.0.1 or later.", - )) + + let certified_key = CertifiedKey::new(full_chain.clone(), signing_key); + if certified_key.keys_match().is_err() { + return Err(vm.new_exception_msg( + PySSLError::class(&vm.ctx).to_owned(), + "[SSL: KEY_VALUES_MISMATCH] key values mismatch".to_owned(), + )); } + + // Add cert/key pair to collection (OpenSSL allows multiple cert/key pairs) + // Store both CertifiedKey (for server) and PrivateKeyDer (for client mTLS) + let cert_der = &full_chain[0]; + let mut cert_keys = self.cert_keys.write(); + + // Remove any existing cert/key pair with the same certificate + // (This allows updating cert/key pair without duplicating) + cert_keys.retain(|(existing, _)| &existing.cert[0] != cert_der); + + // Add new cert/key pair as tuple + cert_keys.push((Arc::new(certified_key), key)); + + Ok(()) } #[pymethod] @@ -1155,264 +1337,357 @@ mod _ssl { args: LoadVerifyLocationsArgs, vm: &VirtualMachine, ) -> PyResult<()> { - if let (None, None, None) = (&args.cafile, &args.capath, &args.cadata) { - return Err(vm.new_type_error("cafile, capath and cadata cannot be all omitted")); - } + // Check that at least one argument is provided + let has_cafile = matches!(&args.cafile, OptionalArg::Present(Some(_))); + let has_capath = matches!(&args.capath, OptionalArg::Present(Some(_))); + let has_cadata = matches!(&args.cadata, OptionalArg::Present(obj) if !vm.is_none(obj)); - #[cold] - fn invalid_cadata(vm: &VirtualMachine) -> PyBaseExceptionRef { - vm.new_type_error("cadata should be an ASCII string or a bytes-like object") + if !has_cafile && !has_capath && !has_cadata { + return Err( + vm.new_type_error("cafile, capath and cadata cannot be all omitted".to_owned()) + ); } - let mut ctx = self.builder(); + // Get mutable references to store and ca_certs_der + let mut root_store = self.root_certs.write(); + let mut ca_certs_der = self.ca_certs_der.write(); - // validate cadata type and load cadata - if let Some(cadata) = args.cadata { - let certs = match cadata { - Either::A(s) => { - if !s.is_ascii() { - return Err(invalid_cadata(vm)); - } - X509::stack_from_pem(s.as_bytes()) - } - Either::B(b) => b.with_ref(x509_stack_from_der), - }; - let certs = certs.map_err(|e| convert_openssl_error(vm, e))?; - let store = ctx.cert_store_mut(); - for cert in certs { - store - .add_cert(cert) - .map_err(|e| convert_openssl_error(vm, e))?; + // Load from file + if let OptionalArg::Present(Some(ref cafile_obj)) = args.cafile { + let path = Self::parse_path_arg(cafile_obj, vm)?; + + // Try to load as CRL first + if let Some(crl) = self.load_crl_from_file(&path, vm)? { + self.crls.write().push(crl); + } else { + // Not a CRL, load as certificate + let stats = self.load_certs_from_file_helper( + &mut root_store, + &mut ca_certs_der, + &path, + vm, + )?; + self.update_cert_stats(stats); } } - if args.cafile.is_some() || args.capath.is_some() { - let cafile_path = args.cafile.map(|p| p.to_path_buf(vm)).transpose()?; - let capath_path = args.capath.map(|p| p.to_path_buf(vm)).transpose()?; - ctx.load_verify_locations(cafile_path.as_deref(), capath_path.as_deref()) - .map_err(|e| convert_openssl_error(vm, e))?; + // Load from directory (don't add to ca_certs_der) + if let OptionalArg::Present(Some(ref capath_obj)) = args.capath { + let dir_path = Self::parse_path_arg(capath_obj, vm)?; + let stats = self.load_certs_from_dir_helper(&mut root_store, &dir_path, vm)?; + self.update_cert_stats(stats); + } + + // Load from bytes or str + if let OptionalArg::Present(cadata_obj) = args.cadata + && !vm.is_none(&cadata_obj) + { + // Check if input is string or bytes + let is_string = PyStrRef::try_from_object(vm, cadata_obj.clone()).is_ok(); + let data_vec = self.parse_cadata_arg(&cadata_obj, vm)?; + let stats = self.load_certs_from_bytes_helper( + &mut root_store, + &mut ca_certs_der, + &data_vec, + is_string, // PEM only for strings + vm, + )?; + self.update_cert_stats(stats); } Ok(()) } - #[pymethod] - fn get_ca_certs( - &self, - binary_form: OptionalArg, + /// Helper: Get path from Python's os.environ + fn get_env_path( + environ: &PyObjectRef, + var_name: &str, vm: &VirtualMachine, - ) -> PyResult> { - let binary_form = binary_form.unwrap_or(false); - let ctx = self.ctx(); - #[cfg(ossl300)] - let certs = ctx.cert_store().all_certificates(); - #[cfg(not(ossl300))] - let certs = ctx.cert_store().objects().iter().filter_map(|x| x.x509()); - - // Filter to only include CA certificates (Basic Constraints: CA=TRUE) - let certs = certs - .into_iter() - .filter(|cert| { - unsafe { - // X509_check_ca() returns 1 for CA certificates - X509_check_ca(cert.as_ptr()) == 1 - } - }) - .map(|ref cert| cert_to_py(vm, cert, binary_form)) - .collect::, _>>()?; - Ok(certs) + ) -> PyResult { + let path_obj = environ.get_item(var_name, vm)?; + path_obj.try_into_value(vm) } - #[pymethod] - fn cert_store_stats(&self, vm: &VirtualMachine) -> PyResult { - let ctx = self.ctx(); - let store_ptr = unsafe { sys::SSL_CTX_get_cert_store(ctx.as_ptr()) }; + /// Helper: Try to load certificates from Python's os.environ variables + /// + /// Returns true if certificates were successfully loaded. + /// + /// We use Python's os.environ instead of Rust's std::env + /// because Python code can modify os.environ at runtime (e.g., + /// `os.environ['SSL_CERT_FILE'] = '/path'`), but rustls-native-certs uses + /// std::env which only sees the process environment at startup. + fn try_load_from_python_environ( + &self, + loader: &mut cert::CertLoader<'_>, + vm: &VirtualMachine, + ) -> PyResult { + use std::path::Path; - if store_ptr.is_null() { - return Err(vm.new_memory_error("failed to get cert store".to_owned())); - } + let os_module = vm.import("os", 0)?; + let environ = os_module.get_attr("environ", vm)?; - let objs_ptr = unsafe { sys::X509_STORE_get0_objects(store_ptr) }; - if objs_ptr.is_null() { - return Err(vm.new_memory_error("failed to query cert store".to_owned())); + // Try SSL_CERT_FILE first + if let Ok(cert_file) = Self::get_env_path(&environ, "SSL_CERT_FILE", vm) + && Path::new(&cert_file).exists() + && let Ok(stats) = loader.load_from_file(&cert_file) + { + self.update_cert_stats(stats); + return Ok(true); } - let mut x509_count = 0; - let mut crl_count = 0; - let mut ca_count = 0; + // Try SSL_CERT_DIR (only if SSL_CERT_FILE didn't work) + if let Ok(cert_dir) = Self::get_env_path(&environ, "SSL_CERT_DIR", vm) + && Path::new(&cert_dir).is_dir() + && let Ok(stats) = loader.load_from_dir(&cert_dir) + { + self.update_cert_stats(stats); + return Ok(true); + } - unsafe { - let num_objs = sys::OPENSSL_sk_num(objs_ptr as *const _); - for i in 0..num_objs { - let obj_ptr = - sys::OPENSSL_sk_value(objs_ptr as *const _, i) as *const sys::X509_OBJECT; - let obj_type = X509_OBJECT_get_type(obj_ptr); + Ok(false) + } - match obj_type { - X509_LU_X509 => { - x509_count += 1; - let x509_ptr = sys::X509_OBJECT_get0_X509(obj_ptr); - if !x509_ptr.is_null() && X509_check_ca(x509_ptr) == 1 { - ca_count += 1; - } - } - X509_LU_CRL => { - crl_count += 1; - } - _ => { - // Ignore unrecognized types - } + /// Helper: Load system certificates using rustls-native-certs + /// + /// This uses platform-specific methods: + /// - Linux: openssl-probe to find certificate files + /// - macOS: Keychain API + /// - Windows: System certificate store + fn load_system_certificates( + &self, + store: &mut rustls::RootCertStore, + vm: &VirtualMachine, + ) -> PyResult<()> { + let result = rustls_native_certs::load_native_certs(); + + // Load successfully found certificates + for cert in result.certs { + let is_ca = cert::is_ca_certificate(cert.as_ref()); + if store.add(cert).is_ok() { + *self.x509_cert_count.write() += 1; + if is_ca { + *self.ca_cert_count.write() += 1; } } - // Note: No need to free objs_ptr as X509_STORE_get0_objects returns - // a pointer to internal data that should not be freed by the caller } - let dict = vm.ctx.new_dict(); - dict.set_item("x509", vm.ctx.new_int(x509_count).into(), vm)?; - dict.set_item("crl", vm.ctx.new_int(crl_count).into(), vm)?; - dict.set_item("x509_ca", vm.ctx.new_int(ca_count).into(), vm)?; - Ok(dict.into()) - } - - #[pymethod] - fn session_stats(&self, vm: &VirtualMachine) -> PyResult { - let ctx = self.ctx(); - let ctx_ptr = ctx.as_ptr(); - - let dict = vm.ctx.new_dict(); - - macro_rules! add_stat { - ($key:expr, $func:ident) => { - let value = unsafe { $func(ctx_ptr) }; - dict.set_item($key, vm.ctx.new_int(value).into(), vm)?; - }; + // If there were errors but some certs loaded, just continue + // If NO certs loaded and there were errors, report the first error + if *self.x509_cert_count.read() == 0 && !result.errors.is_empty() { + return Err(vm.new_os_error(format!( + "Failed to load native certificates: {}", + result.errors[0] + ))); } - add_stat!("number", SSL_CTX_sess_number); - add_stat!("connect", SSL_CTX_sess_connect); - add_stat!("connect_good", SSL_CTX_sess_connect_good); - add_stat!("connect_renegotiate", SSL_CTX_sess_connect_renegotiate); - add_stat!("accept", SSL_CTX_sess_accept); - add_stat!("accept_good", SSL_CTX_sess_accept_good); - add_stat!("accept_renegotiate", SSL_CTX_sess_accept_renegotiate); - add_stat!("hits", SSL_CTX_sess_hits); - add_stat!("misses", SSL_CTX_sess_misses); - add_stat!("timeouts", SSL_CTX_sess_timeouts); - add_stat!("cache_full", SSL_CTX_sess_cache_full); - - Ok(dict.into()) + Ok(()) } #[pymethod] - fn load_dh_params(&self, filepath: FsPath, vm: &VirtualMachine) -> PyResult<()> { - let path = filepath.to_path_buf(vm)?; - - // Open the file using fopen (cross-platform) - let fp = - rustpython_common::fileutils::fopen(path.as_path(), "rb").map_err(|e| { - match e.kind() { - std::io::ErrorKind::NotFound => vm.new_exception_msg( - vm.ctx.exceptions.file_not_found_error.to_owned(), - e.to_string(), - ), - _ => vm.new_os_error(e.to_string()), - } - })?; + fn load_default_certs( + &self, + _purpose: OptionalArg, + vm: &VirtualMachine, + ) -> PyResult<()> { + let mut store = self.root_certs.write(); - // Read DH parameters - let dh = unsafe { - PEM_read_DHparams( - fp, - std::ptr::null_mut(), - std::ptr::null_mut(), - std::ptr::null_mut(), - ) - }; - unsafe { - libc::fclose(fp); - } + // Create loader (without ca_certs_der - default certs don't go to get_ca_certs()) + let mut lazy_ca_certs = Vec::new(); + let mut loader = cert::CertLoader::new(&mut store, &mut lazy_ca_certs); - if dh.is_null() { - return Err(convert_openssl_error(vm, ErrorStack::get())); - } + // Try Python os.environ first (allows runtime env changes) + // This checks SSL_CERT_FILE and SSL_CERT_DIR from Python's os.environ + let loaded = self.try_load_from_python_environ(&mut loader, vm)?; - // Set temporary DH parameters - let ctx = self.builder(); - let result = unsafe { sys::SSL_CTX_set_tmp_dh(ctx.as_ptr(), dh) }; - unsafe { - sys::DH_free(dh); + // Fallback to system certificates if environment variables didn't provide any + if !loaded { + let _ = self.load_system_certificates(&mut store, vm); } - if result != 1 { - return Err(convert_openssl_error(vm, ErrorStack::get())); + // If no certificates were loaded from system, fallback to webpki-roots (Mozilla CA bundle) + // This ensures we always have some trusted root certificates even if system cert loading fails + if *self.x509_cert_count.read() == 0 { + use webpki_roots; + + // webpki_roots provides TLS_SERVER_ROOTS as &[TrustAnchor] + // We can use extend() to add them to the RootCertStore + let webpki_count = webpki_roots::TLS_SERVER_ROOTS.len(); + store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); + + *self.x509_cert_count.write() += webpki_count; + *self.ca_cert_count.write() += webpki_count; } Ok(()) } - #[pygetset] + #[pymethod] + fn set_alpn_protocols(&self, protocols: PyListRef, vm: &VirtualMachine) -> PyResult<()> { + let mut alpn_list = Vec::new(); + for item in protocols.borrow_vec().iter() { + let bytes = ArgBytesLike::try_from_object(vm, item.clone())?; + alpn_list.push(bytes.borrow_buf().to_vec()); + } + *self.alpn_protocols.write() = alpn_list; + Ok(()) + } + + #[pymethod] + fn _set_alpn_protocols(&self, protos: ArgBytesLike, vm: &VirtualMachine) -> PyResult<()> { + let bytes = protos.borrow_buf(); + let alpn_list = parse_length_prefixed_alpn(&bytes, vm)?; + *self.alpn_protocols.write() = alpn_list; + Ok(()) + } + + #[pymethod] + fn set_ciphers(&self, ciphers: PyStrRef, vm: &VirtualMachine) -> PyResult<()> { + let cipher_str = ciphers.as_str(); + + // Parse cipher string and store selected ciphers + let selected_ciphers = parse_cipher_string(cipher_str) + .map_err(|e| vm.new_exception_msg(PySSLError::class(&vm.ctx).to_owned(), e))?; + + // Store in context + *self.selected_ciphers.write() = Some(selected_ciphers); + + Ok(()) + } + + #[pymethod] + fn get_ciphers(&self, vm: &VirtualMachine) -> PyResult { + // Dynamically generate cipher list from rustls ALL_CIPHER_SUITES + // This automatically includes all cipher suites supported by the current rustls version + use rustls::crypto::aws_lc_rs::ALL_CIPHER_SUITES; + + let cipher_list = ALL_CIPHER_SUITES + .iter() + .map(|suite| { + // Extract cipher information using unified helper + let cipher_info = extract_cipher_info(suite); + + // Convert to OpenSSL-style name + // e.g., "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256" -> "ECDHE-RSA-AES128-GCM-SHA256" + let openssl_name = normalize_cipher_name(&cipher_info.name); + + // Determine key exchange and auth methods + let (kx, auth) = if cipher_info.protocol == "TLSv1.3" { + // TLS 1.3 doesn't distinguish - all use modern algos + ("any", "any") + } else if cipher_info.name.contains("ECDHE") { + // TLS 1.2 with ECDHE + let auth = if cipher_info.name.contains("ECDSA") { + "ECDSA" + } else if cipher_info.name.contains("RSA") { + "RSA" + } else { + "any" + }; + ("ECDH", auth) + } else { + ("any", "any") + }; + + // Build description string + // Format: "{name} {protocol} Kx={kx} Au={auth} Enc={enc} Mac={mac}" + let enc = get_cipher_encryption_desc(&openssl_name); + + let description = format!( + "{} {} Kx={} Au={} Enc={} Mac=AEAD", + openssl_name, cipher_info.protocol, kx, auth, enc + ); + + // Create cipher dict + let dict = vm.ctx.new_dict(); + dict.set_item("name", vm.ctx.new_str(openssl_name).into(), vm) + .unwrap(); + dict.set_item("protocol", vm.ctx.new_str(cipher_info.protocol).into(), vm) + .unwrap(); + dict.set_item("id", vm.ctx.new_int(0).into(), vm).unwrap(); // Placeholder ID + dict.set_item("strength_bits", vm.ctx.new_int(cipher_info.bits).into(), vm) + .unwrap(); + dict.set_item("alg_bits", vm.ctx.new_int(cipher_info.bits).into(), vm) + .unwrap(); + dict.set_item("description", vm.ctx.new_str(description).into(), vm) + .unwrap(); + dict.into() + }) + .collect::>(); + + Ok(PyListRef::from(vm.ctx.new_list(cipher_list))) + } + + #[pymethod] + fn set_default_verify_paths(&self, vm: &VirtualMachine) -> PyResult<()> { + // Just call load_default_certs + self.load_default_certs(OptionalArg::Missing, vm) + } + + #[pymethod] + fn cert_store_stats(&self, vm: &VirtualMachine) -> PyResult { + // Use the certificate counters that are updated in load_verify_locations + let x509_count = *self.x509_cert_count.read() as i32; + let ca_count = *self.ca_cert_count.read() as i32; + + let dict = vm.ctx.new_dict(); + dict.set_item("x509", vm.ctx.new_int(x509_count).into(), vm)?; + dict.set_item("crl", vm.ctx.new_int(0).into(), vm)?; // CRL not supported + dict.set_item("x509_ca", vm.ctx.new_int(ca_count).into(), vm)?; + Ok(dict.into()) + } + + #[pymethod] + fn session_stats(&self, vm: &VirtualMachine) -> PyResult { + // Return session statistics + // NOTE: This is a partial implementation - rustls doesn't expose all OpenSSL stats + let dict = vm.ctx.new_dict(); + + // Number of sessions currently in the cache + let session_count = self.client_session_cache.read().len() as i32; + dict.set_item("number", vm.ctx.new_int(session_count).into(), vm)?; + + // Client-side statistics (not tracked separately in this implementation) + dict.set_item("connect", vm.ctx.new_int(0).into(), vm)?; + dict.set_item("connect_good", vm.ctx.new_int(0).into(), vm)?; + dict.set_item("connect_renegotiate", vm.ctx.new_int(0).into(), vm)?; // rustls doesn't support renegotiation + + // Server-side statistics + let accept_count = self.accept_count.load(Ordering::SeqCst) as i32; + dict.set_item("accept", vm.ctx.new_int(accept_count).into(), vm)?; + dict.set_item("accept_good", vm.ctx.new_int(accept_count).into(), vm)?; // Assume all accepts are good + dict.set_item("accept_renegotiate", vm.ctx.new_int(0).into(), vm)?; // rustls doesn't support renegotiation + + // Session reuse statistics + let hits = self.session_hits.load(Ordering::SeqCst) as i32; + dict.set_item("hits", vm.ctx.new_int(hits).into(), vm)?; + + // Misses, timeouts, and cache_full are not tracked in this implementation + dict.set_item("misses", vm.ctx.new_int(0).into(), vm)?; + dict.set_item("timeouts", vm.ctx.new_int(0).into(), vm)?; + dict.set_item("cache_full", vm.ctx.new_int(0).into(), vm)?; + + Ok(dict.into()) + } + + #[pygetset] fn sni_callback(&self) -> Option { - self.sni_callback.lock().clone() + self.sni_callback.read().clone() } #[pygetset(setter)] fn set_sni_callback( &self, - value: Option, + callback: Option, vm: &VirtualMachine, ) -> PyResult<()> { - // Check if this is a server context - if self.protocol == SslVersion::TlsClient { - return Err(vm.new_value_error( - "sni_callback cannot be set on TLS_CLIENT context".to_owned(), - )); - } - - let mut callback_guard = self.sni_callback.lock(); - - if let Some(callback_obj) = value { - if !vm.is_none(&callback_obj) { - // Check if callable - if !callback_obj.is_callable() { - return Err(vm.new_type_error("not a callable object".to_owned())); - } - - // Set the callback - *callback_guard = Some(callback_obj); - - // Set OpenSSL callback - unsafe { - sys::SSL_CTX_set_tlsext_servername_callback__fixed_rust( - self.ctx().as_ptr(), - Some(_servername_callback), - ); - sys::SSL_CTX_set_tlsext_servername_arg( - self.ctx().as_ptr(), - self as *const _ as *mut _, - ); - } - } else { - // Clear callback - *callback_guard = None; - unsafe { - sys::SSL_CTX_set_tlsext_servername_callback__fixed_rust( - self.ctx().as_ptr(), - None, - ); - } - } - } else { - // Clear callback - *callback_guard = None; - unsafe { - sys::SSL_CTX_set_tlsext_servername_callback__fixed_rust( - self.ctx().as_ptr(), - None, - ); - } + // Validate callback is callable or None + if let Some(ref cb) = callback + && !cb.is(vm.ctx.types.none_type) + && !cb.is_callable() + { + return Err(vm.new_type_error("sni_callback must be callable or None")); } - + *self.sni_callback.write() = callback; Ok(()) } @@ -1422,156 +1697,258 @@ mod _ssl { callback: Option, vm: &VirtualMachine, ) -> PyResult<()> { + // Alias for set_sni_callback self.set_sni_callback(callback, vm) } - #[pymethod] - fn load_cert_chain(&self, args: LoadCertChainArgs, vm: &VirtualMachine) -> PyResult<()> { - let LoadCertChainArgs { - certfile, - keyfile, - password, - } = args; - // TODO: requires passing a callback to C - if password.is_some() { - return Err(vm.new_not_implemented_error("password arg not yet supported")); - } - let mut ctx = self.builder(); - let key_path = keyfile.map(|path| path.to_path_buf(vm)).transpose()?; - let cert_path = certfile.to_path_buf(vm)?; - ctx.set_certificate_chain_file(&cert_path) - .and_then(|()| { - ctx.set_private_key_file( - key_path.as_ref().unwrap_or(&cert_path), - ssl::SslFiletype::PEM, - ) - }) - .and_then(|()| ctx.check_private_key()) - .map_err(|e| convert_openssl_error(vm, e)) + #[pygetset] + fn security_level(&self) -> i32 { + // rustls uses a fixed security level + // Return 2 which is a reasonable default (equivalent to OpenSSL 1.1.0+ level 2) + 2 + } + + #[pygetset] + fn _msg_callback(&self) -> Option { + self.msg_callback.read().clone() + } + + #[pygetset(setter)] + fn set__msg_callback( + &self, + callback: Option, + vm: &VirtualMachine, + ) -> PyResult<()> { + // Validate callback is callable or None + if let Some(ref cb) = callback + && !cb.is(vm.ctx.types.none_type) + && !cb.is_callable() + { + return Err(vm.new_type_error("msg_callback must be callable or None")); + } + *self.msg_callback.write() = callback; + Ok(()) } #[pymethod] - fn _wrap_socket( - zelf: PyRef, - args: WrapSocketArgs, + fn get_ca_certs( + &self, + binary_form: OptionalArg, vm: &VirtualMachine, - ) -> PyResult { - // validate socket type and context protocol - if !args.server_side && zelf.protocol == SslVersion::TlsServer { - return Err(vm.new_exception_msg( - PySslError::class(&vm.ctx).to_owned(), - "Cannot create a client socket with a PROTOCOL_TLS_SERVER context".to_owned(), - )); + ) -> PyResult { + let binary_form = binary_form.unwrap_or(false); + let ca_certs_der = self.ca_certs_der.read(); + + let mut certs = Vec::new(); + for cert_der in ca_certs_der.iter() { + // Parse certificate to check if it's a CA and get info + match x509_parser::parse_x509_certificate(cert_der) { + Ok((_, cert)) => { + // Check if this is a CA certificate (BasicConstraints: CA=TRUE) + let is_ca = if let Ok(Some(bc_ext)) = cert.basic_constraints() { + bc_ext.value.ca + } else { + false + }; + + // Only include CA certificates + if !is_ca { + continue; + } + + if binary_form { + // Return DER-encoded certificate as bytes + certs.push(vm.ctx.new_bytes(cert_der.clone()).into()); + } else { + // Return certificate as dict (use helper from _test_decode_cert) + let dict = self.cert_der_to_dict(vm, cert_der)?; + certs.push(dict); + } + } + Err(_) => { + // Skip invalid certificates + continue; + } + } } - if args.server_side && zelf.protocol == SslVersion::TlsClient { - return Err(vm.new_exception_msg( - PySslError::class(&vm.ctx).to_owned(), - "Cannot create a server socket with a PROTOCOL_TLS_CLIENT context".to_owned(), + + Ok(PyListRef::from(vm.ctx.new_list(certs))) + } + + #[pymethod] + fn load_dh_params(&self, filepath: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { + // Validate filepath is not None + if vm.is_none(&filepath) { + return Err(vm.new_type_error("DH params filepath cannot be None".to_owned())); + } + + // Validate filepath is str or bytes + let path_str = if let Ok(s) = PyStrRef::try_from_object(vm, filepath.clone()) { + s.as_str().to_owned() + } else if let Ok(b) = ArgBytesLike::try_from_object(vm, filepath) { + String::from_utf8(b.borrow_buf().to_vec()) + .map_err(|_| vm.new_value_error("Invalid path encoding".to_owned()))? + } else { + return Err(vm.new_type_error("DH params filepath must be str or bytes".to_owned())); + }; + + // Check if file exists + if !std::path::Path::new(&path_str).exists() { + // Create FileNotFoundError with errno=ENOENT (2) using args + let exc = vm.new_exception( + vm.ctx.exceptions.file_not_found_error.to_owned(), + vec![ + vm.ctx.new_int(2).into(), // errno = ENOENT (2) + vm.ctx.new_str("No such file or directory").into(), + vm.ctx.new_str(path_str.clone()).into(), // filename + ], + ); + return Err(exc); + } + + // Validate that the file contains DH parameters + // Read the file and check for DH PARAMETERS header + let contents = + std::fs::read_to_string(&path_str).map_err(|e| vm.new_os_error(e.to_string()))?; + + if !contents.contains("BEGIN DH PARAMETERS") + && !contents.contains("BEGIN X9.42 DH PARAMETERS") + { + // File exists but doesn't contain DH parameters - raise SSLError + // [PEM: NO_START_LINE] no start line + return Err(super::compat::SslError::create_ssl_error_with_reason( + vm, + "PEM", + "NO_START_LINE", + "[PEM: NO_START_LINE] no start line", )); } - let mut ssl = ssl::Ssl::new(&zelf.ctx()).map_err(|e| convert_openssl_error(vm, e))?; + // rustls doesn't use DH parameters (it uses ECDHE for key exchange) + // This is a no-op for compatibility with OpenSSL-based code + Ok(()) + } + + #[pymethod] + fn set_ecdh_curve(&self, name: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { + // Validate name is not None + if vm.is_none(&name) { + return Err(vm.new_type_error("ECDH curve name cannot be None".to_owned())); + } - let socket_type = if args.server_side { - ssl.set_accept_state(); - SslServerOrClient::Server + // Validate name is str or bytes + let curve_name = if let Ok(s) = PyStrRef::try_from_object(vm, name.clone()) { + s.as_str().to_owned() + } else if let Ok(b) = ArgBytesLike::try_from_object(vm, name) { + String::from_utf8(b.borrow_buf().to_vec()) + .map_err(|_| vm.new_value_error("Invalid curve name encoding".to_owned()))? } else { - ssl.set_connect_state(); - SslServerOrClient::Client + return Err(vm.new_type_error("ECDH curve name must be str or bytes".to_owned())); }; - if let Some(hostname) = &args.server_hostname { - let hostname = hostname.as_str(); - if hostname.is_empty() || hostname.starts_with('.') { - return Err(vm.new_value_error( - "server_hostname cannot be an empty string or start with a leading dot.", - )); - } - if hostname.contains('\0') { - return Err(vm.new_value_error("embedded null byte in server_hostname")); - } - let ip = hostname.parse::(); - if ip.is_err() { - ssl.set_hostname(hostname) - .map_err(|e| convert_openssl_error(vm, e))?; - } - if zelf.check_hostname.load() { - if let Ok(ip) = ip { - ssl.param_mut() - .set_ip(ip) - .map_err(|e| convert_openssl_error(vm, e))?; - } else { - ssl.param_mut() - .set_host(hostname) - .map_err(|e| convert_openssl_error(vm, e))?; - } - } + // Validate curve name (common curves for compatibility) + // rustls supports: X25519, secp256r1 (prime256v1), secp384r1 + let valid_curves = [ + "prime256v1", + "secp256r1", + "prime384v1", + "secp384r1", + "prime521v1", + "secp521r1", + "X25519", + "x25519", + "x448", // For future compatibility + ]; + + if !valid_curves.contains(&curve_name.as_str()) { + return Err(vm.new_value_error(format!("unknown curve name '{curve_name}'"))); } - // Configure post-handshake authentication (PHA) - #[cfg(ossl111)] - if *zelf.post_handshake_auth.lock() { - unsafe { - if args.server_side { - // Server socket: add SSL_VERIFY_POST_HANDSHAKE flag - // Only in combination with SSL_VERIFY_PEER - let mode = sys::SSL_get_verify_mode(ssl.as_ptr()); - if (mode & sys::SSL_VERIFY_PEER as libc::c_int) != 0 { - // Add POST_HANDSHAKE flag (keep existing flags including FAIL_IF_NO_PEER_CERT) - sys::SSL_set_verify( - ssl.as_ptr(), - mode | SSL_VERIFY_POST_HANDSHAKE, - None, - ); - } - } else { - // Client socket: call SSL_set_post_handshake_auth - SSL_set_post_handshake_auth(ssl.as_ptr(), 1); + // Store the curve name to be used during handshake + // This will limit the key exchange groups offered/accepted + *self.ecdh_curve.write() = Some(curve_name); + Ok(()) + } + + #[pymethod] + fn _wrap_socket( + zelf: PyRef, + args: WrapSocketArgs, + vm: &VirtualMachine, + ) -> PyResult> { + // Convert server_hostname to Option + // Handle both missing argument and None value + let hostname = match args.server_hostname.into_option().flatten() { + Some(hostname_str) => { + let hostname = hostname_str.as_str(); + + // Validate hostname + if hostname.is_empty() { + return Err(vm.new_value_error("server_hostname cannot be an empty string")); } - } - } - let stream = ssl::SslStream::new(ssl, SocketStream(args.sock.clone())) - .map_err(|e| convert_openssl_error(vm, e))?; + // Check if it starts with a dot + if hostname.starts_with('.') { + return Err(vm.new_value_error("server_hostname cannot start with a dot")); + } - let py_ssl_socket = PySslSocket { - ctx: PyRwLock::new(zelf.clone()), - connection: PyRwLock::new(SslConnection::Socket(stream)), - socket_type, - server_hostname: args.server_hostname, - owner: PyRwLock::new(args.owner.map(|o| o.downgrade(None, vm)).transpose()?), - }; + // Check if it's a bare IP address (not allowed for SNI) + if hostname.parse::().is_ok() { + return Err(vm.new_value_error("server_hostname cannot be an IP address")); + } + + // Check for NULL bytes + if hostname.contains('\0') { + return Err(vm.new_type_error("embedded null character")); + } - // Convert to PyRef (heap allocation) to avoid use-after-free - let py_ref = - py_ssl_socket.into_ref_with_type(vm, PySslSocket::class(&vm.ctx).to_owned())?; - - // Set SNI callback data if callback is configured - if zelf.sni_callback.lock().is_some() { - unsafe { - let ssl_ptr = py_ref.connection.read().ssl().as_ptr(); - - // Store callback data in SSL ex_data - let callback_data = Box::new(SniCallbackData { - ssl_context: zelf.clone(), - vm_ptr: vm as *const _, - }); - let idx = get_sni_ex_data_index(); - sys::SSL_set_ex_data(ssl_ptr, idx, Box::into_raw(callback_data) as *mut _); - - // Store PyRef pointer (heap-allocated) in ex_data index 0 - sys::SSL_set_ex_data(ssl_ptr, 0, &*py_ref as *const _ as *mut _); + Some(hostname.to_string()) } - } + None => None, + }; - // Set session if provided - if let Some(session) = args.session - && !vm.is_none(&session) - { - py_ref.set_session(session, vm)?; + // Validate socket type and context protocol + if args.server_side && zelf.protocol == PROTOCOL_TLS_CLIENT { + return Err(vm.new_exception_msg( + PySSLError::class(&vm.ctx).to_owned(), + "Cannot create a server socket with a PROTOCOL_TLS_CLIENT context".to_owned(), + )); + } + if !args.server_side && zelf.protocol == PROTOCOL_TLS_SERVER { + return Err(vm.new_exception_msg( + PySSLError::class(&vm.ctx).to_owned(), + "Cannot create a client socket with a PROTOCOL_TLS_SERVER context".to_owned(), + )); } - Ok(py_ref.into()) + // Create _SSLSocket instance + let ssl_socket = PySSLSocket { + sock: args.sock.clone(), + context: PyRwLock::new(zelf), + server_side: args.server_side, + server_hostname: PyRwLock::new(hostname), + connection: PyMutex::new(None), + handshake_done: PyMutex::new(false), + session_was_reused: PyMutex::new(false), + owner: PyRwLock::new(args.owner.into_option()), + // Filter out Python None objects - only store actual SSLSession objects + session: PyRwLock::new(args.session.into_option().filter(|s| !vm.is_none(s))), + verified_chain: PyRwLock::new(None), + incoming_bio: None, + outgoing_bio: None, + sni_state: PyRwLock::new(None), + pending_context: PyRwLock::new(None), + client_hello_buffer: PyMutex::new(None), + shutdown_state: PyMutex::new(ShutdownState::NotStarted), + deferred_cert_error: Arc::new(ParkingRwLock::new(None)), + }; + + // Create PyRef with correct type + let ssl_socket_ref = ssl_socket + .into_ref_with_type(vm, vm.class("_ssl", "_SSLSocket")) + .map_err(|_| vm.new_type_error("Failed to create SSLSocket"))?; + + Ok(ssl_socket_ref) } #[pymethod] @@ -1579,1911 +1956,2546 @@ mod _ssl { zelf: PyRef, args: WrapBioArgs, vm: &VirtualMachine, - ) -> PyResult { - // validate socket type and context protocol - if !args.server_side && zelf.protocol == SslVersion::TlsServer { + ) -> PyResult> { + // Convert server_hostname to Option + // Handle both missing argument and None value + let hostname = match args.server_hostname.into_option().flatten() { + Some(hostname_str) => { + let hostname = hostname_str.as_str(); + validate_hostname(hostname, vm)?; + Some(hostname.to_string()) + } + None => None, + }; + + // Extract server_side value + let server_side = args.server_side.unwrap_or(false); + + // Validate socket type and context protocol + if server_side && zelf.protocol == PROTOCOL_TLS_CLIENT { return Err(vm.new_exception_msg( - PySslError::class(&vm.ctx).to_owned(), - "Cannot create a client socket with a PROTOCOL_TLS_SERVER context".to_owned(), + PySSLError::class(&vm.ctx).to_owned(), + "Cannot create a server socket with a PROTOCOL_TLS_CLIENT context".to_owned(), )); } - if args.server_side && zelf.protocol == SslVersion::TlsClient { + if !server_side && zelf.protocol == PROTOCOL_TLS_SERVER { return Err(vm.new_exception_msg( - PySslError::class(&vm.ctx).to_owned(), - "Cannot create a server socket with a PROTOCOL_TLS_CLIENT context".to_owned(), + PySSLError::class(&vm.ctx).to_owned(), + "Cannot create a client socket with a PROTOCOL_TLS_SERVER context".to_owned(), )); } - let mut ssl = ssl::Ssl::new(&zelf.ctx()).map_err(|e| convert_openssl_error(vm, e))?; + // Create _SSLSocket instance with BIO mode + let ssl_socket = PySSLSocket { + sock: vm.ctx.none(), // No socket in BIO mode + context: PyRwLock::new(zelf), + server_side, + server_hostname: PyRwLock::new(hostname), + connection: PyMutex::new(None), + handshake_done: PyMutex::new(false), + session_was_reused: PyMutex::new(false), + owner: PyRwLock::new(args.owner.into_option()), + // Filter out Python None objects - only store actual SSLSession objects + session: PyRwLock::new(args.session.into_option().filter(|s| !vm.is_none(s))), + verified_chain: PyRwLock::new(None), + incoming_bio: Some(args.incoming), + outgoing_bio: Some(args.outgoing), + sni_state: PyRwLock::new(None), + pending_context: PyRwLock::new(None), + client_hello_buffer: PyMutex::new(None), + shutdown_state: PyMutex::new(ShutdownState::NotStarted), + deferred_cert_error: Arc::new(ParkingRwLock::new(None)), + }; + + let ssl_socket_ref = ssl_socket + .into_ref_with_type(vm, vm.class("_ssl", "_SSLSocket")) + .map_err(|_| vm.new_type_error("Failed to create SSLSocket"))?; + + Ok(ssl_socket_ref) + } + + // Helper functions (private): - let socket_type = if args.server_side { - ssl.set_accept_state(); - SslServerOrClient::Server + /// Parse path argument (str or bytes) to string + fn parse_path_arg(arg: &PyObjectRef, vm: &VirtualMachine) -> PyResult { + if let Ok(s) = PyStrRef::try_from_object(vm, arg.clone()) { + Ok(s.as_str().to_owned()) + } else if let Ok(b) = ArgBytesLike::try_from_object(vm, arg.clone()) { + String::from_utf8(b.borrow_buf().to_vec()) + .map_err(|_| vm.new_value_error("path contains invalid UTF-8".to_owned())) } else { - ssl.set_connect_state(); - SslServerOrClient::Client - }; + Err(vm.new_type_error("path should be a str or bytes".to_owned())) + } + } - if let Some(hostname) = &args.server_hostname { - let hostname = hostname.as_str(); - if hostname.is_empty() || hostname.starts_with('.') { - return Err(vm.new_value_error( - "server_hostname cannot be an empty string or start with a leading dot.", - )); - } - if hostname.contains('\0') { - return Err(vm.new_value_error("embedded null byte in server_hostname")); - } - let ip = hostname.parse::(); - if ip.is_err() { - ssl.set_hostname(hostname) - .map_err(|e| convert_openssl_error(vm, e))?; - } - if zelf.check_hostname.load() { - if let Ok(ip) = ip { - ssl.param_mut() - .set_ip(ip) - .map_err(|e| convert_openssl_error(vm, e))?; + /// Parse password argument (str, bytes-like, or callable) + /// + /// Returns (immediate_password, callable) where: + /// - immediate_password: Some(string) if password is str/bytes, None if callable + /// - callable: Some(PyObjectRef) if password is callable, None otherwise + fn parse_password_argument( + password: &OptionalArg, + vm: &VirtualMachine, + ) -> PyResult<(Option, Option)> { + match password { + OptionalArg::Present(p) => { + // Try string first + if let Ok(pwd_str) = PyStrRef::try_from_object(vm, p.clone()) { + Ok((Some(pwd_str.as_str().to_owned()), None)) + } + // Try bytes-like + else if let Ok(pwd_bytes_like) = ArgBytesLike::try_from_object(vm, p.clone()) + { + let pwd = String::from_utf8(pwd_bytes_like.borrow_buf().to_vec()).map_err( + |_| vm.new_type_error("password bytes must be valid UTF-8".to_owned()), + )?; + Ok((Some(pwd), None)) + } + // Try callable + else if p.is_callable() { + Ok((None, Some(p.clone()))) } else { - ssl.param_mut() - .set_host(hostname) - .map_err(|e| convert_openssl_error(vm, e))?; + Err(vm.new_type_error( + "password should be a string, bytes, or callable".to_owned(), + )) } } + _ => Ok((None, None)), } + } - // Don't use SSL_set_bio - let SslStream drive I/O through BioStream Read/Write - - // Configure post-handshake authentication (PHA) - #[cfg(ossl111)] - if *zelf.post_handshake_auth.lock() { - unsafe { - if args.server_side { - // Server socket: add SSL_VERIFY_POST_HANDSHAKE flag - // Only in combination with SSL_VERIFY_PEER - let mode = sys::SSL_get_verify_mode(ssl.as_ptr()); - if (mode & sys::SSL_VERIFY_PEER as libc::c_int) != 0 { - // Add POST_HANDSHAKE flag (keep existing flags including FAIL_IF_NO_PEER_CERT) - sys::SSL_set_verify( - ssl.as_ptr(), - mode | SSL_VERIFY_POST_HANDSHAKE, - None, - ); - } - } else { - // Client socket: call SSL_set_post_handshake_auth - SSL_set_post_handshake_auth(ssl.as_ptr(), 1); + /// Helper: Load certificates from file into existing store + fn load_certs_from_file_helper( + &self, + root_store: &mut RootCertStore, + ca_certs_der: &mut Vec>, + path: &str, + vm: &VirtualMachine, + ) -> PyResult { + let mut loader = cert::CertLoader::new(root_store, ca_certs_der); + loader.load_from_file(path).map_err(|e| { + // Preserve errno for file access errors (NotFound, PermissionDenied) + match e.kind() { + std::io::ErrorKind::NotFound | std::io::ErrorKind::PermissionDenied => { + e.into_pyexception(vm) } + // PEM parsing errors + _ => super::compat::SslError::create_ssl_error_with_reason( + vm, "X509", "", "PEM lib", + ), } - } + }) + } - // Create a BioStream wrapper (dummy, actual IO goes through BIOs) - let bio_stream = BioStream { - inbio: args.incoming, - outbio: args.outgoing, - }; + /// Helper: Load certificates from directory into existing store + fn load_certs_from_dir_helper( + &self, + root_store: &mut RootCertStore, + path: &str, + vm: &VirtualMachine, + ) -> PyResult { + // Load certs and store them in capath_certs_der for lazy loading simulation + // (CPython only returns these in get_ca_certs() after they're used in handshake) + let mut capath_certs = Vec::new(); + let mut loader = cert::CertLoader::new(root_store, &mut capath_certs); + let stats = loader + .load_from_dir(path) + .map_err(|e| e.into_pyexception(vm))?; + + // Store loaded certs for potential tracking after handshake + *self.capath_certs_der.write() = capath_certs; + + Ok(stats) + } + + /// Helper: Load certificates from bytes into existing store + fn load_certs_from_bytes_helper( + &self, + root_store: &mut RootCertStore, + ca_certs_der: &mut Vec>, + data: &[u8], + pem_only: bool, + vm: &VirtualMachine, + ) -> PyResult { + let mut loader = cert::CertLoader::new(root_store, ca_certs_der); + // treat_all_as_ca=true: CPython counts all certificates loaded via cadata as CA certs + // regardless of their Basic Constraints extension + // pem_only=true for string input + loader + .load_from_bytes_ex(data, true, pem_only) + .map_err(|e| { + // Preserve specific error messages from cert.rs + let err_msg = e.to_string(); + if err_msg.contains("no start line") { + // no start line: cadata does not contain a certificate + vm.new_exception_msg( + PySSLError::class(&vm.ctx).to_owned(), + "no start line: cadata does not contain a certificate".to_string(), + ) + } else if err_msg.contains("not enough data") { + // not enough data: cadata does not contain a certificate + vm.new_exception_msg( + PySSLError::class(&vm.ctx).to_owned(), + "not enough data: cadata does not contain a certificate".to_string(), + ) + } else { + // Generic PEM error + vm.new_exception_msg(PySSLError::class(&vm.ctx).to_owned(), err_msg) + } + }) + } - // Create SslStream with BioStream - let stream = - ssl::SslStream::new(ssl, bio_stream).map_err(|e| convert_openssl_error(vm, e))?; + /// Helper: Try to parse data as CRL (PEM or DER format) + fn try_parse_crl( + &self, + data: &[u8], + ) -> Result, String> { + // Try PEM format first + let mut cursor = std::io::Cursor::new(data); + let mut crl_iter = rustls_pemfile::crls(&mut cursor); + if let Some(Ok(crl)) = crl_iter.next() { + return Ok(crl); + } - let py_ssl_socket = PySslSocket { - ctx: PyRwLock::new(zelf.clone()), - connection: PyRwLock::new(SslConnection::Bio(stream)), - socket_type, - server_hostname: args.server_hostname, - owner: PyRwLock::new(args.owner.map(|o| o.downgrade(None, vm)).transpose()?), - }; + // Try DER format + // Basic validation: CRL should start with SEQUENCE tag (0x30) + if !data.is_empty() && data[0] == 0x30 { + return Ok(CertificateRevocationListDer::from(data.to_vec())); + } - // Convert to PyRef (heap allocation) to avoid use-after-free - let py_ref = - py_ssl_socket.into_ref_with_type(vm, PySslSocket::class(&vm.ctx).to_owned())?; - - // Set SNI callback data if callback is configured - if zelf.sni_callback.lock().is_some() { - unsafe { - let ssl_ptr = py_ref.connection.read().ssl().as_ptr(); - - // Store callback data in SSL ex_data - let callback_data = Box::new(SniCallbackData { - ssl_context: zelf.clone(), - vm_ptr: vm as *const _, - }); - let idx = get_sni_ex_data_index(); - sys::SSL_set_ex_data(ssl_ptr, idx, Box::into_raw(callback_data) as *mut _); - - // Store PyRef pointer (heap-allocated) in ex_data index 0 - sys::SSL_set_ex_data(ssl_ptr, 0, &*py_ref as *const _ as *mut _); + Err("Not a valid CRL file".to_string()) + } + + /// Helper: Load CRL from file + fn load_crl_from_file( + &self, + path: &str, + vm: &VirtualMachine, + ) -> PyResult>> { + let data = std::fs::read(path).map_err(|e| match e.kind() { + std::io::ErrorKind::NotFound | std::io::ErrorKind::PermissionDenied => { + e.into_pyexception(vm) } + _ => vm.new_os_error(e.to_string()), + })?; + + match self.try_parse_crl(&data) { + Ok(crl) => Ok(Some(crl)), + Err(_) => Ok(None), // Not a CRL file, might be a cert file } + } - // Set session if provided - if let Some(session) = args.session - && !vm.is_none(&session) - { - py_ref.set_session(session, vm)?; + /// Helper: Parse cadata argument (str or bytes) + fn parse_cadata_arg(&self, arg: &PyObjectRef, vm: &VirtualMachine) -> PyResult> { + if let Ok(s) = PyStrRef::try_from_object(vm, arg.clone()) { + Ok(s.as_str().as_bytes().to_vec()) + } else if let Ok(b) = ArgBytesLike::try_from_object(vm, arg.clone()) { + Ok(b.borrow_buf().to_vec()) + } else { + Err(vm.new_type_error("cadata should be a str or bytes".to_owned())) } + } - Ok(py_ref.into()) + /// Helper: Update certificate statistics + fn update_cert_stats(&self, stats: cert::CertStats) { + *self.x509_cert_count.write() += stats.total_certs; + *self.ca_cert_count.write() += stats.ca_certs; } } - #[derive(FromArgs)] - #[allow(dead_code)] // Fields will be used when _wrap_bio is fully implemented - struct WrapBioArgs { - incoming: PyRef, - outgoing: PyRef, - server_side: bool, - #[pyarg(any, default)] - server_hostname: Option, - #[pyarg(named, default)] - owner: Option, - #[pyarg(named, default)] - session: Option, - } + impl Constructor for PySSLContext { + type Args = (i32,); - #[derive(FromArgs)] - struct WrapSocketArgs { - sock: PyRef, - server_side: bool, - #[pyarg(any, default)] - server_hostname: Option, - #[pyarg(named, default)] - owner: Option, - #[pyarg(named, default)] - session: Option, - } + fn py_new(cls: PyTypeRef, (protocol,): Self::Args, vm: &VirtualMachine) -> PyResult { + // Validate protocol + match protocol { + PROTOCOL_TLS | PROTOCOL_TLS_CLIENT | PROTOCOL_TLS_SERVER | PROTOCOL_TLSv1_2 + | PROTOCOL_TLSv1_3 => { + // Valid protocols + } + PROTOCOL_TLSv1 | PROTOCOL_TLSv1_1 => { + return Err(vm.new_value_error( + "TLS 1.0 and 1.1 are not supported by rustls for security reasons", + )); + } + _ => { + return Err(vm.new_value_error(format!("invalid protocol version: {protocol}"))); + } + } - #[derive(FromArgs)] - struct LoadVerifyLocationsArgs { - #[pyarg(any, default)] - cafile: Option, - #[pyarg(any, default)] - capath: Option, - #[pyarg(any, default)] - cadata: Option>, - } + // Set default options + // OP_ALL | OP_NO_SSLv2 | OP_NO_SSLv3 | OP_NO_COMPRESSION | + // OP_CIPHER_SERVER_PREFERENCE | OP_SINGLE_DH_USE | OP_SINGLE_ECDH_USE | + // OP_ENABLE_MIDDLEBOX_COMPAT + let default_options = OP_ALL + | OP_NO_SSLv2 + | OP_NO_SSLv3 + | OP_NO_COMPRESSION + | OP_CIPHER_SERVER_PREFERENCE + | OP_SINGLE_DH_USE + | OP_SINGLE_ECDH_USE + | OP_ENABLE_MIDDLEBOX_COMPAT; + + // Set default verify_mode based on protocol + // PROTOCOL_TLS_CLIENT defaults to CERT_REQUIRED + // PROTOCOL_TLS_SERVER defaults to CERT_NONE + let default_verify_mode = if protocol == PROTOCOL_TLS_CLIENT { + CERT_REQUIRED + } else { + CERT_NONE + }; - #[derive(FromArgs)] - struct LoadCertChainArgs { - certfile: FsPath, - #[pyarg(any, optional)] - keyfile: Option, - #[pyarg(any, optional)] - password: Option>, - } + // Set default verify_flags based on protocol + // Both PROTOCOL_TLS_CLIENT and PROTOCOL_TLS_SERVER only set VERIFY_X509_TRUSTED_FIRST + // Note: VERIFY_X509_PARTIAL_CHAIN and VERIFY_X509_STRICT are NOT set here + // - they're only added by create_default_context() in Python's ssl.py + let default_verify_flags = VERIFY_DEFAULT | VERIFY_X509_TRUSTED_FIRST; + + // Set minimum and maximum protocol versions based on protocol constant + // specific protocol versions fix both min and max + let (min_version, max_version) = match protocol { + PROTOCOL_TLSv1_2 => (PROTO_TLSv1_2, PROTO_TLSv1_2), // Only TLS 1.2 + PROTOCOL_TLSv1_3 => (PROTO_TLSv1_3, PROTO_TLSv1_3), // Only TLS 1.3 + _ => (PROTO_MINIMUM_SUPPORTED, PROTO_MAXIMUM_SUPPORTED), // Auto-negotiate + }; - // Err is true if the socket is blocking - type SocketDeadline = Result; + // IMPORTANT: Create shared session cache BEFORE PySSLContext + // Both client_session_cache and PythonClientSessionStore.session_cache + // MUST point to the same HashMap to ensure Python-level and Rustls-level + // sessions are synchronized + let shared_session_cache = Arc::new(ParkingRwLock::new(HashMap::new())); + let rustls_client_store = Arc::new(PythonClientSessionStore { + inner: Arc::new(rustls::client::ClientSessionMemoryCache::new( + SSL_SESSION_CACHE_SIZE, + )), + session_cache: shared_session_cache.clone(), + }); - enum SelectRet { - Nonblocking, - TimedOut, - IsBlocking, - Closed, - Ok, + PySSLContext { + protocol, + check_hostname: PyRwLock::new(protocol == PROTOCOL_TLS_CLIENT), + verify_mode: PyRwLock::new(default_verify_mode), + verify_flags: PyRwLock::new(default_verify_flags), + client_config: PyRwLock::new(None), + server_config: PyRwLock::new(None), + root_certs: PyRwLock::new(RootCertStore::empty()), + ca_certs_der: PyRwLock::new(Vec::new()), + capath_certs_der: PyRwLock::new(Vec::new()), + crls: PyRwLock::new(Vec::new()), + cert_keys: PyRwLock::new(Vec::new()), + options: PyRwLock::new(default_options), + alpn_protocols: PyRwLock::new(Vec::new()), + require_alpn_match: PyRwLock::new(false), + post_handshake_auth: PyRwLock::new(false), + num_tickets: PyRwLock::new(2), // TLS 1.3 default + minimum_version: PyRwLock::new(min_version), + maximum_version: PyRwLock::new(max_version), + sni_callback: PyRwLock::new(None), + msg_callback: PyRwLock::new(None), + ecdh_curve: PyRwLock::new(None), + ca_cert_count: PyRwLock::new(0), + x509_cert_count: PyRwLock::new(0), + // Use the shared cache created above + client_session_cache: shared_session_cache, + rustls_session_store: rustls_client_store, + rustls_server_session_store: rustls::server::ServerSessionMemoryCache::new( + SSL_SESSION_CACHE_SIZE, + ), + server_ticketer: rustls::crypto::aws_lc_rs::Ticketer::new() + .expect("Failed to create shared ticketer for TLS 1.2 session resumption"), + accept_count: AtomicUsize::new(0), + session_hits: AtomicUsize::new(0), + selected_ciphers: PyRwLock::new(None), + } + .into_ref_with_type(vm, cls) + .map(Into::into) + } } - #[derive(Clone, Copy)] - enum SslNeeds { - Read, - Write, + // SSLSocket - represents a TLS-wrapped socket + #[pyattr] + #[pyclass(name = "_SSLSocket", module = "ssl")] + #[derive(Debug, PyPayload)] + pub(crate) struct PySSLSocket { + // Underlying socket + sock: PyObjectRef, + // SSL context + context: PyRwLock>, + // Server-side or client-side + server_side: bool, + // Server hostname for SNI + server_hostname: PyRwLock>, + // TLS connection state + connection: PyMutex>, + // Handshake completed flag + handshake_done: PyMutex, + // Session was reused (for session resumption tracking) + session_was_reused: PyMutex, + // Owner (SSLSocket instance that owns this _SSLSocket) + owner: PyRwLock>, + // Session for resumption + session: PyRwLock>, + // Verified certificate chain (built during verification) + #[allow(dead_code)] + verified_chain: PyRwLock>>>, + // MemoryBIO mode (optional) + incoming_bio: Option>, + outgoing_bio: Option>, + // SNI certificate resolver state (for server-side only) + sni_state: PyRwLock>>>, + // Pending context change (for SNI callback deferred handling) + pending_context: PyRwLock>>, + // Buffer to store ClientHello for connection recreation + client_hello_buffer: PyMutex>>, + // Shutdown state for tracking close-notify exchange + shutdown_state: PyMutex, + // Deferred client certificate verification error (for TLS 1.3) + // Stores error message if client cert verification failed during handshake + // Error is raised on first I/O operation after handshake + // Using Arc to share with the certificate verifier + deferred_cert_error: Arc>>, } - struct SocketStream(PyRef); + // Shutdown state for tracking close-notify exchange + #[derive(Debug, Clone, Copy, PartialEq)] + enum ShutdownState { + NotStarted, // unwrap() not called yet + SentCloseNotify, // close-notify sent, waiting for peer's response + Completed, // unwrap() completed successfully + } - impl SocketStream { - fn timeout_deadline(&self) -> SocketDeadline { - self.0.get_timeout().map(|d| Instant::now() + d) + #[pyclass(with(Constructor), flags(BASETYPE))] + impl PySSLSocket { + // Check if this is BIO mode + pub(crate) fn is_bio_mode(&self) -> bool { + self.incoming_bio.is_some() && self.outgoing_bio.is_some() } - fn select(&self, needs: SslNeeds, deadline: &SocketDeadline) -> SelectRet { - let sock = match self.0.sock_opt() { - Some(s) => s, - None => return SelectRet::Closed, - }; - let deadline = match &deadline { - Ok(deadline) => match deadline.checked_duration_since(Instant::now()) { - Some(deadline) => deadline, - None => return SelectRet::TimedOut, - }, - Err(true) => return SelectRet::IsBlocking, - Err(false) => return SelectRet::Nonblocking, - }; - let res = socket::sock_select( - &sock, - match needs { - SslNeeds::Read => socket::SelectKind::Read, - SslNeeds::Write => socket::SelectKind::Write, - }, - Some(deadline), - ); - match res { - Ok(true) => SelectRet::TimedOut, - _ => SelectRet::Ok, - } + // Get incoming BIO reference (for EOF checking) + pub(crate) fn incoming_bio(&self) -> Option { + self.incoming_bio.as_ref().map(|bio| bio.clone().into()) } - fn socket_needs( - &self, - err: &ssl::Error, - deadline: &SocketDeadline, - ) -> (Option, SelectRet) { - let needs = match err.code() { - ssl::ErrorCode::WANT_READ => Some(SslNeeds::Read), - ssl::ErrorCode::WANT_WRITE => Some(SslNeeds::Write), - _ => None, - }; - let state = needs.map_or(SelectRet::Ok, |needs| self.select(needs, deadline)); - (needs, state) + // Check for deferred certificate verification errors (TLS 1.3) + // If an error exists, raise it and clear it from storage + fn check_deferred_cert_error(&self, vm: &VirtualMachine) -> PyResult<()> { + let error_opt = self.deferred_cert_error.read().clone(); + if let Some(error_msg) = error_opt { + // Clear the error so it's only raised once + *self.deferred_cert_error.write() = None; + // Raise OSError with the stored error message + return Err(vm.new_os_error(error_msg)); + } + Ok(()) } - } - fn socket_closed_error(vm: &VirtualMachine) -> PyBaseExceptionRef { - vm.new_exception_msg( - PySslError::class(&vm.ctx).to_owned(), - "Underlying socket has been closed.".to_owned(), - ) - } + // Get socket timeout as Duration + pub(crate) fn get_socket_timeout(&self, vm: &VirtualMachine) -> PyResult> { + if self.is_bio_mode() { + return Ok(None); + } - // BIO stream wrapper to implement Read/Write traits for MemoryBIO - struct BioStream { - inbio: PyRef, - outbio: PyRef, - } + // Get timeout from socket + let timeout_obj = self.sock.get_attr("gettimeout", vm)?.call((), vm)?; - impl Read for BioStream { - fn read(&mut self, buf: &mut [u8]) -> std::io::Result { - // Read from incoming MemoryBIO - unsafe { - let nbytes = sys::BIO_read( - self.inbio.bio, - buf.as_mut_ptr() as *mut _, - buf.len().min(i32::MAX as usize) as i32, - ); - if nbytes < 0 { - // BIO_read returns -1 on error or when no data is available - // Check if it's a retry condition (WANT_READ) - Err(std::io::Error::new( - std::io::ErrorKind::WouldBlock, - "BIO has no data available", - )) + // timeout can be None (blocking), 0.0 (non-blocking), or positive float + if vm.is_none(&timeout_obj) { + // None means blocking forever + Ok(None) + } else { + let timeout_float: f64 = timeout_obj.try_into_value(vm)?; + if timeout_float <= 0.0 { + // 0 means non-blocking + Ok(Some(Duration::from_secs(0))) } else { - Ok(nbytes as usize) + // Positive timeout + Ok(Some(Duration::from_secs_f64(timeout_float))) } } } - } - impl Write for BioStream { - fn write(&mut self, buf: &[u8]) -> std::io::Result { - // Write to outgoing MemoryBIO - unsafe { - let nbytes = sys::BIO_write( - self.outbio.bio, - buf.as_ptr() as *const _, - buf.len().min(i32::MAX as usize) as i32, - ); - if nbytes < 0 { - return Err(std::io::Error::other("BIO write failed")); + // Create and store a session object after successful handshake + fn create_session_after_handshake(&self, vm: &VirtualMachine) -> PyResult<()> { + // Only create session for client-side connections + if self.server_side { + return Ok(()); + } + + // Check if session already exists + let session_opt = self.session.read().clone(); + if let Some(ref s) = session_opt { + if vm.is_none(s) { + } else { + return Ok(()); } - Ok(nbytes as usize) } - } - fn flush(&mut self) -> std::io::Result<()> { - // MemoryBIO doesn't need flushing + // Get server hostname + let server_name = self.server_hostname.read().clone(); + + // Try to get session data from context's session cache + // IMPORTANT: Acquire and release locks quickly to avoid deadlock + let context = self.context.read(); + let session_cache_arc = context.client_session_cache.clone(); + drop(context); // Release context lock ASAP + + let (session_id, creation_time, lifetime) = if let Some(ref name) = server_name { + let key = name.as_bytes().to_vec(); + + // Clone the data we need while holding the lock, then immediately release + let session_data_opt = { + let cache_guard = session_cache_arc.read(); + cache_guard.get(&key).cloned() // Clone Arc> + }; // Lock released here + + if let Some(session_data_arc) = session_data_opt { + let data = session_data_arc.lock(); + let result = (data.session_id.clone(), data.creation_time, data.lifetime); + drop(data); // Explicit unlock + result + } else { + // Create new session ID if not in cache + let time = std::time::SystemTime::now(); + (generate_session_id_from_metadata(name, &time), time, 7200) + } + } else { + // No server name, use defaults + let time = std::time::SystemTime::now(); + (vec![0; 16], time, 7200) + }; + + // Create a new SSLSession object with real metadata + let session = PySSLSession { + // Use dummy session data to indicate we have a ticket + // TLS 1.2+ always uses session tickets/resumption + session_data: vec![1], // Non-empty to indicate has_ticket=True + session_id, + creation_time, + lifetime, + }; + + let py_session = session.into_pyobject(vm); + + *self.session.write() = Some(py_session); + Ok(()) } - } - // Enum to represent different SSL connection modes - enum SslConnection { - Socket(ssl::SslStream), - Bio(ssl::SslStream), - } + // Complete handshake and create session + /// Track which CA certificate from capath was used to verify peer + /// + /// This simulates lazy loading behavior: capath certificates + /// are only added to get_ca_certs() after they're actually used in a handshake. + fn track_used_ca_from_capath(&self) -> Result<(), String> { + let context = self.context.read(); + let capath_certs = context.capath_certs_der.read(); + + // No capath certs to track + if capath_certs.is_empty() { + return Ok(()); + } - impl SslConnection { - // Get a reference to the SSL object - fn ssl(&self) -> &ssl::SslRef { - match self { - SslConnection::Socket(stream) => stream.ssl(), - SslConnection::Bio(stream) => stream.ssl(), + // Get peer certificate chain + let conn_guard = self.connection.lock(); + let conn = conn_guard.as_ref().ok_or("No connection")?; + + let peer_certs = conn.peer_certificates().ok_or("No peer certificates")?; + + if peer_certs.is_empty() { + return Ok(()); } - } - // Get underlying socket stream reference (only for socket mode) - fn get_ref(&self) -> Option<&SocketStream> { - match self { - SslConnection::Socket(stream) => Some(stream.get_ref()), - SslConnection::Bio(_) => None, + // Get the top certificate in the chain (closest to root) + // Note: Server usually doesn't send the root CA, so we check the last cert's issuer + let top_cert_der = peer_certs.last().unwrap(); + let (_, top_cert) = x509_parser::parse_x509_certificate(top_cert_der) + .map_err(|e| format!("Failed to parse top cert: {e}"))?; + + let top_issuer = top_cert.issuer(); + + // Find matching CA in capath certs + for ca_der in capath_certs.iter() { + let (_, ca) = x509_parser::parse_x509_certificate(ca_der) + .map_err(|e| format!("Failed to parse CA: {e}"))?; + + // Check if this CA is self-signed and matches the issuer + if ca.subject() == ca.issuer() // Self-signed (root CA) + && ca.subject() == top_issuer + // Matches top cert's issuer + { + // Check if not already in ca_certs_der + let mut ca_certs_der = context.ca_certs_der.write(); + if !ca_certs_der.iter().any(|c| c == ca_der) { + ca_certs_der.push(ca_der.clone()); + } + break; + } } - } - // Check if this is in BIO mode - fn is_bio(&self) -> bool { - matches!(self, SslConnection::Bio(_)) + Ok(()) } - // Perform SSL handshake - fn do_handshake(&mut self) -> Result<(), ssl::Error> { - match self { - SslConnection::Socket(stream) => stream.do_handshake(), - SslConnection::Bio(stream) => stream.do_handshake(), + fn complete_handshake(&self, vm: &VirtualMachine) -> PyResult<()> { + *self.handshake_done.lock() = true; + + // Check if session was resumed before creating session object + let conn_guard = self.connection.lock(); + if let Some(ref conn) = *conn_guard { + let was_resumed = conn.is_session_resumed(); + *self.session_was_reused.lock() = was_resumed; + + // Update context session statistics if server-side + if self.server_side { + let context = self.context.read(); + // Increment accept count for every successful server handshake + context.accept_count.fetch_add(1, Ordering::SeqCst); + // Increment hits count if session was resumed + if was_resumed { + context.session_hits.fetch_add(1, Ordering::SeqCst); + } + } } - } + drop(conn_guard); - // Write data to SSL connection - fn ssl_write(&mut self, buf: &[u8]) -> Result { - match self { - SslConnection::Socket(stream) => stream.ssl_write(buf), - SslConnection::Bio(stream) => stream.ssl_write(buf), + // Track CA certificate used during handshake (client-side only) + // This simulates lazy loading behavior for capath certificates + if !self.server_side { + // Don't fail handshake if tracking fails + let _ = self.track_used_ca_from_capath(); } + + self.create_session_after_handshake(vm)?; + Ok(()) } - // Read data from SSL connection - fn ssl_read(&mut self, buf: &mut [u8]) -> Result { - match self { - SslConnection::Socket(stream) => stream.ssl_read(buf), - SslConnection::Bio(stream) => stream.ssl_read(buf), + // Internal implementation with timeout control + pub(crate) fn sock_wait_for_io_impl( + &self, + kind: SelectKind, + vm: &VirtualMachine, + ) -> PyResult { + if self.is_bio_mode() { + // BIO mode doesn't use select + return Ok(false); } - } - // Get SSL shutdown state - fn get_shutdown(&mut self) -> ssl::ShutdownState { - match self { - SslConnection::Socket(stream) => stream.get_shutdown(), - SslConnection::Bio(stream) => stream.get_shutdown(), + // Get timeout + let timeout = self.get_socket_timeout(vm)?; + + // Check for non-blocking mode (timeout = 0) + if let Some(t) = timeout + && t.is_zero() + { + // Non-blocking mode - don't use select + return Ok(false); } - } - } - #[pyattr] - #[pyclass(module = "ssl", name = "_SSLSocket", traverse)] - #[derive(PyPayload)] - struct PySslSocket { - ctx: PyRwLock>, - #[pytraverse(skip)] - connection: PyRwLock, - #[pytraverse(skip)] - socket_type: SslServerOrClient, - server_hostname: Option, - owner: PyRwLock>>, - } + // Use select with the effective timeout + let py_socket: PyRef = self.sock.clone().try_into_value(vm)?; + let socket = py_socket + .sock() + .map_err(|e| vm.new_os_error(format!("Failed to get socket: {e}")))?; - impl fmt::Debug for PySslSocket { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.pad("_SSLSocket") + let timed_out = sock_select(&socket, kind, timeout) + .map_err(|e| vm.new_os_error(format!("select failed: {e}")))?; + + Ok(timed_out) } - } - #[pyclass(flags(IMMUTABLETYPE))] - impl PySslSocket { - #[pygetset] - fn owner(&self) -> Option { - self.owner.read().as_ref().and_then(|weak| weak.upgrade()) + // SNI (Server Name Indication) Helper Methods: + // These methods support the server-side handshake SNI callback mechanism + + /// Check if this is the first read during handshake (for SNI callback) + /// Returns true if we haven't processed ClientHello yet, regardless of SNI presence + pub(crate) fn is_first_sni_read(&self) -> bool { + self.client_hello_buffer.lock().is_none() } - #[pygetset(setter)] - fn set_owner(&self, owner: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { - let mut lock = self.owner.write(); - lock.take(); - *lock = Some(owner.downgrade(None, vm)?); - Ok(()) + + /// Check if SNI callback is configured + pub(crate) fn has_sni_callback(&self) -> bool { + self.context.read().sni_callback.read().is_some() } - #[pygetset] - fn server_side(&self) -> bool { - self.socket_type == SslServerOrClient::Server + + /// Save ClientHello data from PyObjectRef for potential connection recreation + pub(crate) fn save_client_hello_from_bytes(&self, bytes_data: &[u8]) { + *self.client_hello_buffer.lock() = Some(bytes_data.to_vec()); } - #[pygetset] - fn context(&self) -> PyRef { - self.ctx.read().clone() + + /// Get the extracted SNI name from resolver + pub(crate) fn get_extracted_sni_name(&self) -> Option { + self.sni_state + .read() + .as_ref() + .and_then(|arc| arc.lock().1.clone()) } - #[pygetset(setter)] - fn set_context(&self, value: PyRef, vm: &VirtualMachine) -> PyResult<()> { - // Update the SSL context in the underlying SSL object - let stream = self.connection.read(); - - // Set the new SSL_CTX on the SSL object - unsafe { - let result = SSL_set_SSL_CTX(stream.ssl().as_ptr(), value.ctx().as_ptr()); - if result.is_null() { - return Err(vm.new_runtime_error("Failed to set SSL context".to_owned())); + + /// Invoke the Python SNI callback + pub(crate) fn invoke_sni_callback( + &self, + sni_name: Option<&str>, + vm: &VirtualMachine, + ) -> PyResult<()> { + let callback = self + .context + .read() + .sni_callback + .read() + .clone() + .ok_or_else(|| vm.new_value_error("SNI callback not set"))?; + + let ssl_sock = self.owner.read().clone().unwrap_or(vm.ctx.none()); + let server_name_py: PyObjectRef = match sni_name { + Some(name) => vm.ctx.new_str(name.to_string()).into(), + None => vm.ctx.none(), + }; + let initial_context: PyObjectRef = self.context.read().clone().into(); + + let result = callback.call((ssl_sock, server_name_py, initial_context), vm)?; + + // Check return value type (must be None or integer) + if !vm.is_none(&result) { + // Try to convert to integer + if result.try_to_value::(vm).is_err() { + // Type conversion failed - raise TypeError as unraisable + let type_error = vm.new_type_error(format!( + "servername callback must return None or an integer, not '{}'", + result.class().name() + )); + vm.run_unraisable(type_error, None, result.clone()); + + // Return SSL error with reason set to TLSV1_ALERT_INTERNAL_ERROR + // + // RUSTLS API LIMITATION: + // We cannot send a TLS InternalError alert to the client here because: + // 1. Rustls does not provide a public API like send_fatal_alert() + // 2. This method is called AFTER dropping the connection lock (to prevent deadlock) + // 3. By the time we detect the error, the connection is no longer available + // + // CPython/OpenSSL behavior: + // - SNI callback runs inside SSL_do_handshake with connection active + // - Sets *al = SSL_AD_INTERNAL_ERROR + // - OpenSSL automatically sends alert before returning + // + // RustPython/Rustls behavior: + // - SNI callback runs after dropping connection lock (deadlock prevention) + // - Exception has _reason='TLSV1_ALERT_INTERNAL_ERROR' for error reporting + // - TCP connection closes without sending TLS alert to client + // + // If rustls adds send_fatal_alert() API in the future, we should: + // - Re-acquire connection lock after callback + // - Call: connection.send_fatal_alert(AlertDescription::InternalError) + // - Then close connection + let exc = vm.new_exception_msg( + PySSLError::class(&vm.ctx).to_owned(), + "SNI callback returned invalid type".to_owned(), + ); + let _ = exc.as_object().set_attr( + "reason", + vm.ctx.new_str("TLSV1_ALERT_INTERNAL_ERROR"), + vm, + ); + return Err(exc); } } - // Update self.ctx to the new context - *self.ctx.write() = value; Ok(()) } - #[pygetset] - fn server_hostname(&self) -> Option { - self.server_hostname.clone() + + // Helper to call socket methods, bypassing any SSL wrapper + pub(crate) fn sock_recv(&self, size: usize, vm: &VirtualMachine) -> PyResult { + // In BIO mode, read from incoming BIO + if let Some(ref bio) = self.incoming_bio { + let bio_obj: PyObjectRef = bio.clone().into(); + let read_method = bio_obj.get_attr("read", vm)?; + return read_method.call((vm.ctx.new_int(size),), vm); + } + + // Normal socket mode + let socket_mod = vm.import("socket", 0)?; + let socket_class = socket_mod.get_attr("socket", vm)?; + + // Call socket.socket.recv(self.sock, size) + let recv_method = socket_class.get_attr("recv", vm)?; + recv_method.call((self.sock.clone(), vm.ctx.new_int(size)), vm) + } + + pub(crate) fn sock_send( + &self, + data: Vec, + vm: &VirtualMachine, + ) -> PyResult { + // In BIO mode, write to outgoing BIO + if let Some(ref bio) = self.outgoing_bio { + let bio_obj: PyObjectRef = bio.clone().into(); + let write_method = bio_obj.get_attr("write", vm)?; + return write_method.call((vm.ctx.new_bytes(data),), vm); + } + + // Normal socket mode + let socket_mod = vm.import("socket", 0)?; + let socket_class = socket_mod.get_attr("socket", vm)?; + + // Call socket.socket.send(self.sock, data) + let send_method = socket_class.get_attr("send", vm)?; + send_method.call((self.sock.clone(), vm.ctx.new_bytes(data)), vm) } #[pymethod] - fn getpeercert( + fn __repr__(&self) -> String { + "".to_string() + } + + // Helper function to convert Python PROTO_* constants to rustls versions + fn get_rustls_versions( + minimum: i32, + maximum: i32, + options: i32, + ) -> &'static [&'static rustls::SupportedProtocolVersion] { + // Rustls only supports TLS 1.2 and 1.3 + // PROTO_TLSv1_2 = 0x0303, PROTO_TLSv1_3 = 0x0304 + // PROTO_MINIMUM_SUPPORTED = -2, PROTO_MAXIMUM_SUPPORTED = -1 + // If minimum and maximum are 0, use default (both TLS 1.2 and 1.3) + + // Static arrays for single-version configurations + static TLS12_ONLY: &[&rustls::SupportedProtocolVersion] = &[&TLS12]; + static TLS13_ONLY: &[&rustls::SupportedProtocolVersion] = &[&TLS13]; + + // Normalize special values: -2 (MINIMUM_SUPPORTED) → TLS 1.2, -1 (MAXIMUM_SUPPORTED) → TLS 1.3 + let min = if minimum == -2 { + PROTO_TLSv1_2 + } else { + minimum + }; + let max = if maximum == -1 { + PROTO_TLSv1_3 + } else { + maximum + }; + + // Check if versions are disabled by options + let tls12_disabled = (options & OP_NO_TLSv1_2) != 0; + let tls13_disabled = (options & OP_NO_TLSv1_3) != 0; + + let want_tls12 = (min == 0 || min <= PROTO_TLSv1_2) + && (max == 0 || max >= PROTO_TLSv1_2) + && !tls12_disabled; + let want_tls13 = (min == 0 || min <= PROTO_TLSv1_3) + && (max == 0 || max >= PROTO_TLSv1_3) + && !tls13_disabled; + + match (want_tls12, want_tls13) { + (true, true) => rustls::DEFAULT_VERSIONS, // Both TLS 1.2 and 1.3 + (true, false) => TLS12_ONLY, // Only TLS 1.2 + (false, true) => TLS13_ONLY, // Only TLS 1.3 + (false, false) => rustls::DEFAULT_VERSIONS, // Fallback to default + } + } + + /// Helper: Prepare TLS versions from context settings + fn prepare_tls_versions(&self) -> &'static [&'static rustls::SupportedProtocolVersion] { + let ctx = self.context.read(); + let min_ver = *ctx.minimum_version.read(); + let max_ver = *ctx.maximum_version.read(); + let options = *ctx.options.read(); + Self::get_rustls_versions(min_ver, max_ver, options) + } + + /// Helper: Prepare KX groups (ECDH curve) from context settings + fn prepare_kx_groups( &self, - binary: OptionalArg, vm: &VirtualMachine, - ) -> PyResult> { - let binary = binary.unwrap_or(false); - let stream = self.connection.read(); - if !stream.ssl().is_init_finished() { - return Err(vm.new_value_error("handshake not done yet")); + ) -> PyResult>> { + let ctx = self.context.read(); + let ecdh_curve = ctx.ecdh_curve.read().clone(); + drop(ctx); + + if let Some(ref curve_name) = ecdh_curve { + match curve_name_to_kx_group(curve_name) { + Ok(groups) => Ok(Some(groups)), + Err(e) => Err(vm.new_value_error(format!("Failed to set ECDH curve: {e}"))), + } + } else { + Ok(None) } + } - let peer_cert = stream.ssl().peer_certificate(); - let Some(cert) = peer_cert else { - return Ok(None); + /// Helper: Prepare all common protocol settings (versions, KX groups, ciphers, ALPN) + fn prepare_protocol_settings(&self, vm: &VirtualMachine) -> PyResult { + let ctx = self.context.read(); + let versions = self.prepare_tls_versions(); + let kx_groups = self.prepare_kx_groups(vm)?; + let cipher_suites = ctx.selected_ciphers.read().clone(); + let alpn_protocols = ctx.alpn_protocols.read().clone(); + + Ok(ProtocolSettings { + versions, + kx_groups, + cipher_suites, + alpn_protocols, + }) + } + + /// Initialize server-side TLS connection with configuration + /// + /// This method handles all server-side setup including: + /// - Certificate and key validation + /// - Client authentication configuration + /// - SNI (Server Name Indication) setup + /// - ALPN protocol negotiation + /// - Session resumption configuration + /// + /// Returns the configured ServerConnection. + fn initialize_server_connection( + &self, + conn_guard: &mut Option, + vm: &VirtualMachine, + ) -> PyResult<()> { + let ctx = self.context.read(); + let cert_keys = ctx.cert_keys.read(); + + if cert_keys.is_empty() { + return Err(vm.new_value_error( + "Server-side connection requires certificate and key (use load_cert_chain)", + )); + } + + // Clone cert_keys for use in config + // PrivateKeyDer doesn't implement Clone, use clone_key() + let cert_keys_clone: Vec = cert_keys + .iter() + .map(|(ck, pk)| (ck.clone(), pk.clone_key())) + .collect(); + drop(cert_keys); + + // Prepare common protocol settings (TLS versions, ECDH curve, cipher suites, ALPN) + let protocol_settings = self.prepare_protocol_settings(vm)?; + let min_ver = *ctx.minimum_version.read(); + + // Check if client certificate verification is required + let verify_mode = *ctx.verify_mode.read(); + let root_store = ctx.root_certs.read(); + let pha_enabled = *ctx.post_handshake_auth.read(); + + // Check if TLS 1.3 is being used + let is_tls13 = min_ver >= PROTO_TLSv1_3; + + // For TLS 1.3: always use deferred validation for client certificates + // For TLS 1.2: use immediate validation during handshake + let use_deferred_validation = is_tls13 + && !pha_enabled + && (verify_mode == CERT_REQUIRED || verify_mode == CERT_OPTIONAL); + + // For TLS 1.3 + PHA: if PHA is enabled, don't request cert in initial handshake + // The certificate will be requested later via verify_client_post_handshake() + let request_initial_cert = if pha_enabled { + // PHA enabled: don't request cert initially (will use PHA later) + false + } else if verify_mode == CERT_REQUIRED || verify_mode == CERT_OPTIONAL { + // PHA not enabled or TLS 1.2: request cert in initial handshake + true + } else { + // CERT_NONE + false }; - if binary { - // Return DER-encoded certificate - cert_to_py(vm, &cert, true).map(Some) + // Check if SNI callback is set + let sni_callback = ctx.sni_callback.read().clone(); + let use_sni_resolver = sni_callback.is_some(); + + // Create SNI state if needed (to be stored in PySSLSocket later) + // For SNI, use the first cert_key pair as the initial certificate + let sni_state: Option>> = if use_sni_resolver { + // Use first cert_key as initial certificate for SNI + // Extract CertifiedKey from tuple + let (first_cert_key, _) = &cert_keys_clone[0]; + let first_cert_key = first_cert_key.clone(); + + // Check if we already have existing SNI state (from previous connection) + let existing_sni_state = self.sni_state.read().clone(); + + if let Some(sni_state_arc) = existing_sni_state { + // Reuse existing Arc and update its contents + // This is crucial: rustls SniCertResolver holds references to this Arc + let mut state = sni_state_arc.lock(); + state.0 = first_cert_key; + state.1 = None; // Reset SNI name for new connection + drop(state); + + // Return the existing Arc (not a new one!) + Some(sni_state_arc) + } else { + // First connection: create new SNI state + Some(Arc::new(ParkingMutex::new((first_cert_key, None)))) + } } else { - // Check verify_mode - unsafe { - let ssl_ctx = sys::SSL_get_SSL_CTX(stream.ssl().as_ptr()); - let verify_mode = sys::SSL_CTX_get_verify_mode(ssl_ctx); - if (verify_mode & sys::SSL_VERIFY_PEER as libc::c_int) == 0 { - // Return empty dict when SSL_VERIFY_PEER is not set - Ok(Some(vm.ctx.new_dict().into())) - } else { - // Return decoded certificate - cert_to_py(vm, &cert, false).map(Some) - } + None + }; + + // Determine which cert resolver to use + // Priority: SNI > Multi-cert/Single-cert via MultiCertResolver + let cert_resolver: Option> = if use_sni_resolver { + // SNI takes precedence - use first cert_key for initial setup + sni_state.as_ref().map(|sni_state_arc| { + Arc::new(SniCertResolver { + sni_state: sni_state_arc.clone(), + }) as Arc + }) + } else { + // Use MultiCertResolver for all cases (single or multiple certs) + // Extract CertifiedKey from tuples for MultiCertResolver + let cert_keys_only: Vec> = + cert_keys_clone.iter().map(|(ck, _)| ck.clone()).collect(); + Some(Arc::new(MultiCertResolver::new(cert_keys_only))) + }; + + // Extract cert_chain and private_key from first cert_key + // + // Note: Since we always use cert_resolver now, these values won't actually be used + // by create_server_config. But we still need to provide them for the API signature. + let (first_cert_key, _) = &cert_keys_clone[0]; + let certs_clone = first_cert_key.cert.clone(); + + // Provide a dummy key since cert_resolver will handle cert selection + let key_clone = PrivateKeyDer::Pkcs8(Vec::new().into()); + + // Get shared server session storage and ticketer from context + let server_session_storage = ctx.rustls_server_session_store.clone(); + let server_ticketer = ctx.server_ticketer.clone(); + + // Build server config using compat helper + let config_options = ServerConfigOptions { + protocol_settings, + cert_chain: certs_clone, + private_key: key_clone, + root_store: if request_initial_cert { + Some(root_store.clone()) + } else { + None + }, + request_client_cert: request_initial_cert, + use_deferred_validation, + cert_resolver, + deferred_cert_error: if use_deferred_validation { + Some(self.deferred_cert_error.clone()) + } else { + None + }, + session_storage: Some(server_session_storage), + ticketer: Some(server_ticketer), + }; + + drop(root_store); + + // Check if we have a cached ServerConfig + let cached_config_arc = ctx.server_config.read().clone(); + drop(ctx); + + let config_arc = if let Some(cached) = cached_config_arc { + // Don't use cache when SNI is enabled, because each connection needs + // a fresh SniCertResolver with the correct Arc references + if use_sni_resolver { + let config = + create_server_config(config_options).map_err(|e| vm.new_value_error(e))?; + Arc::new(config) + } else { + cached } + } else { + let config = + create_server_config(config_options).map_err(|e| vm.new_value_error(e))?; + let config_arc = Arc::new(config); + + // Cache the ServerConfig for future connections + let ctx = self.context.read(); + *ctx.server_config.write() = Some(config_arc.clone()); + drop(ctx); + + config_arc + }; + + let conn = ServerConnection::new(config_arc).map_err(|e| { + vm.new_value_error(format!("Failed to create server connection: {e}")) + })?; + + *conn_guard = Some(TlsConnection::Server(conn)); + + // If ClientHello buffer exists (from SNI callback), re-inject it + if let Some(ref hello_data) = *self.client_hello_buffer.lock() + && let Some(TlsConnection::Server(ref mut server)) = *conn_guard + { + let mut cursor = std::io::Cursor::new(hello_data.as_slice()); + let _ = server.read_tls(&mut cursor); + + // Process the re-injected ClientHello + let _ = server.process_new_packets(); + + // DON'T clear buffer - keep it to prevent callback from being invoked again + // The buffer being non-empty signals that SNI callback was already processed + } + + // Store SNI state if we're using SNI resolver + if let Some(sni_state_arc) = sni_state { + *self.sni_state.write() = Some(sni_state_arc); } + + Ok(()) } #[pymethod] - fn get_unverified_chain(&self, vm: &VirtualMachine) -> PyResult> { - let stream = self.connection.read(); - let Some(chain) = stream.ssl().peer_cert_chain() else { - return Ok(None); - }; + fn do_handshake(&self, vm: &VirtualMachine) -> PyResult<()> { + // Check if handshake already done + if *self.handshake_done.lock() { + return Ok(()); + } - // Return Certificate objects - let certs: Vec = chain - .iter() - .map(|cert| unsafe { - sys::X509_up_ref(cert.as_ptr()); - let owned = X509::from_ptr(cert.as_ptr()); - cert_to_certificate(vm, owned) - }) - .collect::>()?; - Ok(Some(vm.ctx.new_list(certs))) + let mut conn_guard = self.connection.lock(); + + // Initialize connection if not already done + if conn_guard.is_none() { + // Check for pending context change (from SNI callback) + if let Some(new_ctx) = self.pending_context.write().take() { + *self.context.write() = new_ctx; + } + + if self.server_side { + // Server-side connection - delegate to helper method + self.initialize_server_connection(&mut conn_guard, vm)?; + } else { + // Client-side connection + let ctx = self.context.read(); + + // Prepare common protocol settings (TLS versions, ECDH curve, cipher suites, ALPN) + let protocol_settings = self.prepare_protocol_settings(vm)?; + + // Clone values we need before building config + let verify_mode = *ctx.verify_mode.read(); + let root_store_clone = ctx.root_certs.read().clone(); + let ca_certs_der_clone = ctx.ca_certs_der.read().clone(); + + // For client mTLS: extract cert_chain and private_key from first cert_key (if any) + // Now we store both CertifiedKey and PrivateKeyDer as tuple + let cert_keys_guard = ctx.cert_keys.read(); + let (cert_chain_clone, private_key_opt) = if !cert_keys_guard.is_empty() { + let (first_cert_key, private_key) = &cert_keys_guard[0]; + let certs = first_cert_key.cert.clone(); + (certs, Some(private_key.clone_key())) + } else { + (Vec::new(), None) + }; + drop(cert_keys_guard); + + let check_hostname = *ctx.check_hostname.read(); + let verify_flags = *ctx.verify_flags.read(); + + // Get session store before dropping ctx + let session_store = ctx.rustls_session_store.clone(); + + // Get CRLs for revocation checking + let crls_clone = ctx.crls.read().clone(); + + // Drop ctx early to avoid borrow conflicts + drop(ctx); + + // Build client config using compat helper + let config_options = ClientConfigOptions { + protocol_settings, + root_store: if verify_mode != CERT_NONE { + Some(root_store_clone) + } else { + None + }, + ca_certs_der: ca_certs_der_clone, + cert_chain: if !cert_chain_clone.is_empty() { + Some(cert_chain_clone) + } else { + None + }, + private_key: private_key_opt, + verify_server_cert: verify_mode != CERT_NONE, + check_hostname, + verify_flags, + session_store: Some(session_store), + crls: crls_clone, + }; + + let config = + create_client_config(config_options).map_err(|e| vm.new_value_error(e))?; + + // Parse server name for SNI + // Convert to ServerName + use rustls::pki_types::ServerName; + let hostname_opt = self.server_hostname.read().clone(); + + let server_name = if let Some(ref hostname) = hostname_opt { + // Use the provided hostname for SNI + ServerName::try_from(hostname.clone()).map_err(|e| { + vm.new_value_error(format!("Invalid server hostname: {e:?}")) + })? + } else { + // When server_hostname=None, use an IP address to suppress SNI + // no hostname = no SNI extension + ServerName::IpAddress( + std::net::IpAddr::V4(std::net::Ipv4Addr::new(127, 0, 0, 1)).into(), + ) + }; + + let conn = ClientConnection::new(Arc::new(config), server_name.clone()) + .map_err(|e| { + vm.new_value_error(format!("Failed to create client connection: {e}")) + })?; + + *conn_guard = Some(TlsConnection::Client(conn)); + } + } + + // Perform the actual handshake by exchanging data with the socket/BIO + match conn_guard.as_mut() { + Some(TlsConnection::Client(_conn)) => { + // CLIENT is simple - no SNI callback handling needed + ssl_do_handshake(conn_guard.as_mut().unwrap(), self, vm) + .map_err(|e| e.into_py_err(vm))?; + + drop(conn_guard); + self.complete_handshake(vm)?; + Ok(()) + } + Some(TlsConnection::Server(_conn)) => { + // Use OpenSSL-compatible handshake for server + // Handle SNI callback restart + match ssl_do_handshake(conn_guard.as_mut().unwrap(), self, vm) { + Ok(()) => { + // Handshake completed successfully + drop(conn_guard); + self.complete_handshake(vm)?; + Ok(()) + } + Err(SslError::SniCallbackRestart) => { + // SNI detected - need to call callback and recreate connection + + // CRITICAL: Drop connection lock BEFORE calling Python callback to avoid deadlock + // + // Deadlock scenario if we keep the lock: + // 1. This thread holds self.connection.lock() + // 2. Python callback invokes other SSL methods (e.g., getpeercert(), cipher()) + // 3. Those methods try to acquire self.connection.lock() again + // 4. PyMutex (parking_lot::Mutex) is not reentrant -> DEADLOCK + // + // Trade-off: By dropping the lock, we lose the ability to send TLS alerts + // because Rustls doesn't provide a send_fatal_alert() API. See detailed + // explanation in invoke_sni_callback() where we set _reason attribute. + drop(conn_guard); + + // Get the SNI name that was extracted (may be None if client didn't send SNI) + let sni_name = self.get_extracted_sni_name(); + + // Now safe to call Python callback (no locks held) + self.invoke_sni_callback(sni_name.as_deref(), vm)?; + + // Clear connection to trigger recreation + *self.connection.lock() = None; + + // Recursively call do_handshake to recreate with new context + self.do_handshake(vm) + } + Err(e) => { + // Other errors - convert to Python exception + drop(conn_guard); + Err(e.into_py_err(vm)) + } + } + } + None => unreachable!(), + } } #[pymethod] - fn get_verified_chain(&self, vm: &VirtualMachine) -> PyResult> { - let stream = self.connection.read(); - unsafe { - let chain = sys::SSL_get0_verified_chain(stream.ssl().as_ptr()); - if chain.is_null() { - return Ok(None); + fn read( + &self, + len: OptionalArg, + buffer: OptionalArg, + vm: &VirtualMachine, + ) -> PyResult { + // Convert len to usize, defaulting to 1024 if not provided + // -1 means read all available data (treat as large buffer size) + let len_val = len.unwrap_or(PEM_BUFSIZE as isize); + let mut len = if len_val == -1 { + // -1 is only valid when a buffer is provided + match &buffer { + OptionalArg::Present(buf_arg) => buf_arg.len(), + OptionalArg::Missing => { + return Err(vm.new_value_error("negative read length")); + } } + } else if len_val < 0 { + return Err(vm.new_value_error("negative read length")); + } else { + len_val as usize + }; - let num_certs = sys::OPENSSL_sk_num(chain as *const _); + // if buffer is provided, limit len to buffer size + if let OptionalArg::Present(buf_arg) = &buffer { + let buf_len = buf_arg.len(); + if len_val <= 0 || len > buf_len { + len = buf_len; + } + } - let mut certs = Vec::with_capacity(num_certs as usize); - // Return Certificate objects - for i in 0..num_certs { - let cert_ptr = sys::OPENSSL_sk_value(chain as *const _, i) as *mut sys::X509; - if cert_ptr.is_null() { - continue; + // return empty bytes immediately for len=0 + if len == 0 { + return match buffer { + OptionalArg::Present(_) => Ok(vm.ctx.new_int(0).into()), + OptionalArg::Missing => Ok(vm.ctx.new_bytes(vec![]).into()), + }; + } + + // Ensure handshake is done + if !*self.handshake_done.lock() { + return Err(vm.new_value_error("Handshake not completed")); + } + + // Check if connection has been shut down + // After unwrap()/shutdown(), read operations should fail with SSLError + let shutdown_state = *self.shutdown_state.lock(); + if shutdown_state != ShutdownState::NotStarted { + return Err(vm.new_exception_msg( + PySSLError::class(&vm.ctx).to_owned(), + "cannot read after shutdown".to_owned(), + )); + } + + // Check for deferred certificate verification errors (TLS 1.3) + self.check_deferred_cert_error(vm)?; + + // Helper function to handle return value based on buffer presence + let return_data = |data: Vec, + buffer_arg: &OptionalArg, + vm: &VirtualMachine| + -> PyResult { + match buffer_arg { + OptionalArg::Present(buf_arg) => { + // Write into buffer and return number of bytes written + let n = data.len(); + if n > 0 { + let mut buf = buf_arg.borrow_buf_mut(); + let buf_slice = &mut *buf; + let copy_len = n.min(buf_slice.len()); + buf_slice[..copy_len].copy_from_slice(&data[..copy_len]); + } + Ok(vm.ctx.new_int(n).into()) } - // Clone the X509 certificate to create an owned copy - sys::X509_up_ref(cert_ptr); - let owned_cert = X509::from_ptr(cert_ptr); - let cert_obj = cert_to_certificate(vm, owned_cert)?; - certs.push(cert_obj); + OptionalArg::Missing => { + // Return bytes object + Ok(vm.ctx.new_bytes(data).into()) + } + } + }; + + let mut conn_guard = self.connection.lock(); + let conn = conn_guard + .as_mut() + .ok_or_else(|| vm.new_value_error("Connection not established"))?; + + // Use compat layer for unified read logic with proper EOF handling + // This matches CPython's SSL_read_ex() approach + let mut buf = vec![0u8; len]; + + match crate::ssl::compat::ssl_read(conn, &mut buf, self, vm) { + Ok(n) => { + buf.truncate(n); + return_data(buf, &buffer, vm) + } + Err(crate::ssl::compat::SslError::Eof) => { + // EOF occurred in violation of protocol (unexpected closure) + Err(vm.new_exception_msg( + PySSLEOFError::class(&vm.ctx).to_owned(), + "EOF occurred in violation of protocol".to_owned(), + )) + } + Err(crate::ssl::compat::SslError::ZeroReturn) => { + // Clean closure with close_notify - return empty data + return_data(vec![], &buffer, vm) + } + Err(crate::ssl::compat::SslError::WantRead) => { + // Non-blocking mode: would block + Err(create_ssl_want_read_error(vm)) + } + Err(crate::ssl::compat::SslError::WantWrite) => { + // Non-blocking mode: would block on write + Err(create_ssl_want_write_error(vm)) + } + Err(crate::ssl::compat::SslError::Timeout(msg)) => Err(timeout_error_msg(vm, msg)), + Err(crate::ssl::compat::SslError::Py(e)) => { + // Python exception - pass through + Err(e) + } + Err(e) => { + // Other SSL errors + Err(e.into_py_err(vm)) } - - Ok(if certs.is_empty() { - None - } else { - Some(vm.ctx.new_list(certs)) - }) } } #[pymethod] - fn version(&self) -> Option<&'static str> { - let v = self.connection.read().ssl().version_str(); - if v == "unknown" { None } else { Some(v) } - } + fn pending(&self) -> PyResult { + // Returns the number of already decrypted bytes available for read + // This is critical for asyncore's readable() method which checks socket.pending() > 0 + let mut conn_guard = self.connection.lock(); + let conn = match conn_guard.as_mut() { + Some(c) => c, + None => return Ok(0), // No connection established yet + }; - #[pymethod] - fn cipher(&self) -> Option { - self.connection - .read() - .ssl() - .current_cipher() - .map(cipher_to_tuple) + // Use rustls Reader's fill_buf() to check buffered plaintext + // fill_buf() returns a reference to buffered data without consuming it + // This matches OpenSSL's SSL_pending() behavior + use std::io::BufRead; + let mut reader = conn.reader(); + match reader.fill_buf() { + Ok(buf) => Ok(buf.len()), + Err(_) => { + // WouldBlock or other errors mean no data available + // Return 0 like OpenSSL does when buffer is empty + Ok(0) + } + } } #[pymethod] - fn shared_ciphers(&self, vm: &VirtualMachine) -> Option { - #[cfg(ossl110)] - { - let stream = self.connection.read(); - unsafe { - let server_ciphers = SSL_get_ciphers(stream.ssl().as_ptr()); - if server_ciphers.is_null() { - return None; - } - - let client_ciphers = SSL_get_client_ciphers(stream.ssl().as_ptr()); - if client_ciphers.is_null() { - return None; - } + fn write(&self, data: ArgBytesLike, vm: &VirtualMachine) -> PyResult { + let data_bytes = data.borrow_buf(); + let data_len = data_bytes.len(); - let mut result = Vec::new(); - let num_server = sys::OPENSSL_sk_num(server_ciphers as *const _); - let num_client = sys::OPENSSL_sk_num(client_ciphers as *const _); + // return 0 immediately for empty write + if data_len == 0 { + return Ok(0); + } - for i in 0..num_server { - let server_cipher_ptr = sys::OPENSSL_sk_value(server_ciphers as *const _, i) - as *const sys::SSL_CIPHER; + // Ensure handshake is done + if !*self.handshake_done.lock() { + return Err(vm.new_value_error("Handshake not completed")); + } - // Check if client supports this cipher by comparing pointers - let mut found = false; - for j in 0..num_client { - let client_cipher_ptr = - sys::OPENSSL_sk_value(client_ciphers as *const _, j) - as *const sys::SSL_CIPHER; + // Check if connection has been shut down + // After unwrap()/shutdown(), write operations should fail with SSLError + let shutdown_state = *self.shutdown_state.lock(); + if shutdown_state != ShutdownState::NotStarted { + return Err(vm.new_exception_msg( + PySSLError::class(&vm.ctx).to_owned(), + "cannot write after shutdown".to_owned(), + )); + } - if server_cipher_ptr == client_cipher_ptr { - found = true; - break; - } + // Check for deferred certificate verification errors (TLS 1.3) + self.check_deferred_cert_error(vm)?; + + let mut conn_guard = self.connection.lock(); + let conn = conn_guard + .as_mut() + .ok_or_else(|| vm.new_value_error("Connection not established"))?; + + // Unified write logic - no need to match on Client/Server anymore + let mut writer = conn.writer(); + writer + .write_all(data_bytes.as_ref()) + .map_err(|e| vm.new_os_error(format!("Write failed: {e}")))?; + + // Flush to get TLS-encrypted data (writer automatically flushed on drop) + // Send encrypted data to socket + if conn.wants_write() { + let is_bio = self.is_bio_mode(); + + if is_bio { + // BIO mode: Write ALL pending TLS data to outgoing BIO + // This prevents hangs where Python's ssl_io_loop waits for data + self.write_pending_tls(conn, vm)?; + } else { + // Socket mode: Try once and may return SSLWantWriteError + let mut buf = Vec::new(); + conn.write_tls(&mut buf) + .map_err(|e| vm.new_os_error(format!("TLS write failed: {e}")))?; + + if !buf.is_empty() { + // Wait for socket to be ready for writing + let timed_out = self.sock_wait_for_io_impl(SelectKind::Write, vm)?; + if timed_out { + return Err(vm.new_os_error("Write operation timed out")); } - if found { - let cipher = ssl::SslCipherRef::from_ptr(server_cipher_ptr as *mut _); - let (name, version, bits) = cipher_to_tuple(cipher); - let tuple = vm.new_tuple(( - vm.ctx.new_str(name), - vm.ctx.new_str(version), - vm.ctx.new_int(bits), - )); - result.push(tuple.into()); + // Send encrypted data to socket + // Convert BlockingIOError to SSLWantWriteError + match self.sock_send(buf, vm) { + Ok(_) => {} + Err(e) => { + if is_blocking_io_error(&e, vm) { + // Non-blocking socket would block - return SSLWantWriteError + return Err(create_ssl_want_write_error(vm)); + } + return Err(e); + } } } - - if result.is_empty() { - None - } else { - Some(vm.ctx.new_list(result)) - } } } - #[cfg(not(ossl110))] - { - let _ = vm; - None - } - } - - #[pymethod] - fn selected_alpn_protocol(&self) -> Option { - #[cfg(ossl102)] - { - let stream = self.connection.read(); - unsafe { - let mut out: *const libc::c_uchar = std::ptr::null(); - let mut outlen: libc::c_uint = 0; - - sys::SSL_get0_alpn_selected(stream.ssl().as_ptr(), &mut out, &mut outlen); - if out.is_null() { - None - } else { - let slice = std::slice::from_raw_parts(out, outlen as usize); - Some(String::from_utf8_lossy(slice).into_owned()) - } - } - } - #[cfg(not(ossl102))] - { - None - } + Ok(data_len) } #[pymethod] - fn get_channel_binding( + fn getpeercert( &self, - cb_type: OptionalArg, + binary_form: OptionalArg, vm: &VirtualMachine, - ) -> PyResult> { - const CB_MAXLEN: usize = 512; - - let cb_type_str = cb_type.as_ref().map_or("tls-unique", |s| s.as_str()); + ) -> PyResult> { + let binary = binary_form.unwrap_or(false); - if cb_type_str != "tls-unique" { - return Err(vm.new_value_error(format!( - "Unsupported channel binding type '{}'", - cb_type_str - ))); + // Check if handshake is complete + if !*self.handshake_done.lock() { + return Err(vm.new_value_error("handshake not done yet")); } - let stream = self.connection.read(); - let ssl_ptr = stream.ssl().as_ptr(); + // Get peer certificates from TLS connection + let conn_guard = self.connection.lock(); + let conn = conn_guard + .as_ref() + .ok_or_else(|| vm.new_value_error("No TLS connection established"))?; + + let certs = conn.peer_certificates(); - unsafe { - let session_reused = sys::SSL_session_reused(ssl_ptr) != 0; - let is_client = matches!(self.socket_type, SslServerOrClient::Client); + // Return None if no peer certificate + let Some(certs) = certs else { + return Ok(None); + }; - // Use XOR logic from CPython - let use_finished = session_reused ^ is_client; + // Get first certificate (peer's certificate) + let cert_der = certs + .first() + .ok_or_else(|| vm.new_value_error("No peer certificate available"))?; - let mut buf = vec![0u8; CB_MAXLEN]; - let len = if use_finished { - sys::SSL_get_finished(ssl_ptr, buf.as_mut_ptr() as *mut _, CB_MAXLEN) - } else { - sys::SSL_get_peer_finished(ssl_ptr, buf.as_mut_ptr() as *mut _, CB_MAXLEN) - }; + if binary { + // Return DER-encoded certificate as bytes + let der_bytes = cert_der.as_ref().to_vec(); + return Ok(Some(vm.ctx.new_bytes(der_bytes).into())); + } - if len == 0 { - Ok(None) - } else { - buf.truncate(len); - Ok(Some(vm.ctx.new_bytes(buf))) - } + // Dictionary mode: check verify_mode + let verify_mode = *self.context.read().verify_mode.read(); + + if verify_mode == CERT_NONE { + // Return empty dict when CERT_NONE + return Ok(Some(vm.ctx.new_dict().into())); } + + // Parse DER certificate and convert to dict + let der_bytes = cert_der.as_ref(); + let (_, cert) = x509_parser::parse_x509_certificate(der_bytes) + .map_err(|e| vm.new_value_error(format!("Failed to parse certificate: {e}")))?; + + cert::cert_to_dict(vm, &cert).map(Some) } #[pymethod] - fn verify_client_post_handshake(&self, vm: &VirtualMachine) -> PyResult<()> { - #[cfg(ossl111)] - { - let stream = self.connection.read(); - let result = unsafe { SSL_verify_client_post_handshake(stream.ssl().as_ptr()) }; - if result == 0 { - Err(convert_openssl_error(vm, openssl::error::ErrorStack::get())) - } else { - Ok(()) - } - } - #[cfg(not(ossl111))] - { - Err(vm.new_not_implemented_error( - "Post-handshake auth is not supported by your OpenSSL version.".to_owned(), - )) - } + fn cipher(&self) -> Option<(String, String, i32)> { + let conn_guard = self.connection.lock(); + let conn = conn_guard.as_ref()?; + + let suite = conn.negotiated_cipher_suite()?; + + // Extract cipher information using unified helper + let cipher_info = extract_cipher_info(&suite); + + // Note: returns a 3-tuple (name, protocol_version, bits) + // The 'description' field is part of get_ciphers() output, not cipher() + Some(( + cipher_info.name, + cipher_info.protocol.to_string(), + cipher_info.bits, + )) } #[pymethod] - fn shutdown(&self, vm: &VirtualMachine) -> PyResult> { - let stream = self.connection.read(); + fn version(&self) -> Option { + let conn_guard = self.connection.lock(); + let conn = conn_guard.as_ref()?; - // BIO mode doesn't have an underlying socket - if stream.is_bio() { - return Err(vm.new_not_implemented_error( - "shutdown() is not supported for BIO-based SSL objects".to_owned(), - )); - } + let suite = conn.negotiated_cipher_suite()?; + + let version_str = match suite.version().version { + rustls::ProtocolVersion::TLSv1_2 => "TLSv1.2", + rustls::ProtocolVersion::TLSv1_3 => "TLSv1.3", + _ => "Unknown", + }; - let ssl_ptr = stream.ssl().as_ptr(); + Some(version_str.to_string()) + } - // Perform SSL shutdown - let ret = unsafe { sys::SSL_shutdown(ssl_ptr) }; + #[pymethod] + fn selected_alpn_protocol(&self) -> Option { + let conn_guard = self.connection.lock(); + let conn = conn_guard.as_ref()?; - if ret < 0 { - // Error occurred - let err = unsafe { sys::SSL_get_error(ssl_ptr, ret) }; + let alpn_bytes = conn.alpn_protocol()?; - if err == sys::SSL_ERROR_WANT_READ || err == sys::SSL_ERROR_WANT_WRITE { - // Non-blocking would block - this is okay for shutdown - // Return the underlying socket - } else { - return Err(vm.new_exception_msg( - PySslError::class(&vm.ctx).to_owned(), - format!("SSL shutdown failed: error code {}", err), - )); - } + // Null byte protocol (vec![0u8]) means no actual ALPN match (fallback protocol) + if alpn_bytes.is_empty() || alpn_bytes == [0u8] { + return None; } - // Return the underlying socket - // Get the socket from the stream (SocketStream wraps PyRef) - let socket = stream - .get_ref() - .expect("unwrap() called on bio mode; should only be called in socket mode"); - Ok(socket.0.clone()) + // Convert bytes to string + String::from_utf8(alpn_bytes.to_vec()).ok() } - #[cfg(osslconf = "OPENSSL_NO_COMP")] #[pymethod] - fn compression(&self) -> Option<&'static str> { + fn selected_npn_protocol(&self) -> Option { + // NPN (Next Protocol Negotiation) is the predecessor to ALPN + // It was deprecated in favor of ALPN (RFC 7301) + // Rustls doesn't support NPN, only ALPN + // Return None to indicate NPN is not supported None } - #[cfg(not(osslconf = "OPENSSL_NO_COMP"))] - #[pymethod] - fn compression(&self) -> Option<&'static str> { - let stream = self.connection.read(); - let comp_method = unsafe { sys::SSL_get_current_compression(stream.ssl().as_ptr()) }; - if comp_method.is_null() { - return None; - } - let typ = unsafe { sys::COMP_get_type(comp_method) }; - let nid = Nid::from_raw(typ); - if nid == Nid::UNDEF { - return None; - } - nid.short_name().ok() + + #[pygetset] + fn owner(&self) -> Option { + self.owner.read().clone() } - #[pymethod] - fn do_handshake(&self, vm: &VirtualMachine) -> PyResult<()> { - let mut stream = self.connection.write(); - let ssl_ptr = stream.ssl().as_ptr(); - - // BIO mode: no timeout/select logic, just do handshake - if stream.is_bio() { - return stream.do_handshake().map_err(|e| { - let exc = convert_ssl_error(vm, e); - // If it's a cert verification error, set verify info - if exc.class().is(PySslCertVerificationError::class(&vm.ctx)) { - set_verify_error_info(&exc, ssl_ptr, vm); - } - exc - }); - } + #[pygetset(setter)] + fn set_owner(&self, owner: PyObjectRef, _vm: &VirtualMachine) -> PyResult<()> { + *self.owner.write() = Some(owner); + Ok(()) + } - // Socket mode: handle timeout and blocking - let timeout = stream - .get_ref() - .expect("handshake called in bio mode; should only be called in socket mode") - .timeout_deadline(); - loop { - let err = match stream.do_handshake() { - Ok(()) => return Ok(()), - Err(e) => e, - }; - let (needs, state) = stream - .get_ref() - .expect("handshake called in bio mode; should only be called in socket mode") - .socket_needs(&err, &timeout); - match state { - SelectRet::TimedOut => { - return Err(socket::timeout_error_msg( - vm, - "The handshake operation timed out".to_owned(), - )); - } - SelectRet::Closed => return Err(socket_closed_error(vm)), - SelectRet::Nonblocking => {} - _ => { - if needs.is_some() { - continue; - } - } - } - let exc = convert_ssl_error(vm, err); - // If it's a cert verification error, set verify info - if exc.class().is(PySslCertVerificationError::class(&vm.ctx)) { - set_verify_error_info(&exc, ssl_ptr, vm); - } - return Err(exc); - } + #[pygetset] + fn server_side(&self) -> bool { + self.server_side } - #[pymethod] - fn write(&self, data: ArgBytesLike, vm: &VirtualMachine) -> PyResult { - let mut stream = self.connection.write(); - let data = data.borrow_buf(); - let data = &*data; - - // BIO mode: no timeout/select logic - if stream.is_bio() { - return stream.ssl_write(data).map_err(|e| convert_ssl_error(vm, e)); - } - - // Socket mode: handle timeout and blocking - let socket_ref = stream - .get_ref() - .expect("write called in bio mode; should only be called in socket mode"); - let timeout = socket_ref.timeout_deadline(); - let state = socket_ref.select(SslNeeds::Write, &timeout); - match state { - SelectRet::TimedOut => { - return Err(socket::timeout_error_msg( - vm, - "The write operation timed out".to_owned(), - )); - } - SelectRet::Closed => return Err(socket_closed_error(vm)), - _ => {} + #[pygetset] + fn context(&self) -> PyRef { + self.context.read().clone() + } + + #[pygetset(setter)] + fn set_context(&self, value: PyRef, _vm: &VirtualMachine) -> PyResult<()> { + // Update context reference immediately + // SSL_set_SSL_CTX allows context changes at any time, + // even after handshake completion + *self.context.write() = value; + + // Clear pending context as we've applied the change + *self.pending_context.write() = None; + + Ok(()) + } + + #[pygetset] + fn server_hostname(&self) -> Option { + self.server_hostname.read().clone() + } + + #[pygetset(setter)] + fn set_server_hostname( + &self, + value: Option, + vm: &VirtualMachine, + ) -> PyResult<()> { + // Check if handshake is already done + if *self.handshake_done.lock() { + return Err( + vm.new_value_error("Cannot set server_hostname on socket after handshake") + ); } - loop { - let err = match stream.ssl_write(data) { - Ok(len) => return Ok(len), - Err(e) => e, - }; - let (needs, state) = stream - .get_ref() - .expect("write called in bio mode; should only be called in socket mode") - .socket_needs(&err, &timeout); - match state { - SelectRet::TimedOut => { - return Err(socket::timeout_error_msg( - vm, - "The write operation timed out".to_owned(), - )); - } - SelectRet::Closed => return Err(socket_closed_error(vm)), - SelectRet::Nonblocking => {} - _ => { - if needs.is_some() { - continue; - } - } - } - return Err(convert_ssl_error(vm, err)); + + // Validate hostname + if let Some(hostname_str) = &value { + validate_hostname(hostname_str.as_str(), vm)?; } + + *self.server_hostname.write() = value.map(|s| s.as_str().to_string()); + Ok(()) } #[pygetset] - fn session(&self, _vm: &VirtualMachine) -> PyResult> { - let stream = self.connection.read(); - unsafe { - let session_ptr = sys::SSL_get_session(stream.ssl().as_ptr()); - if session_ptr.is_null() { - Ok(None) - } else { - // Increment reference count since SSL_get_session returns a borrowed reference - #[cfg(ossl110)] - let _session = sys::SSL_SESSION_up_ref(session_ptr); - - Ok(Some(PySslSession { - session: session_ptr, - ctx: self.ctx.read().clone(), - })) - } + fn session(&self, vm: &VirtualMachine) -> PyResult { + // Return the stored session object if any + let sess = self.session.read().clone(); + if let Some(s) = sess { + Ok(s) + } else { + Ok(vm.ctx.none()) } } #[pygetset(setter)] fn set_session(&self, value: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { - // Check if value is SSLSession type - let session = value - .downcast_ref::() - .ok_or_else(|| vm.new_type_error("Value is not a SSLSession.".to_owned()))?; - - // Check if session refers to the same SSLContext - if !std::ptr::eq( - self.ctx.read().ctx.read().as_ptr(), - session.ctx.ctx.read().as_ptr(), - ) { - return Err( - vm.new_value_error("Session refers to a different SSLContext.".to_owned()) - ); + // Validate that value is an SSLSession + if !value.is(vm.ctx.types.none_type) { + // Try to downcast to SSLSession to validate + let _ = value + .downcast_ref::() + .ok_or_else(|| vm.new_type_error("Value is not a SSLSession."))?; } // Check if this is a client socket - if self.socket_type != SslServerOrClient::Client { - return Err( - vm.new_value_error("Cannot set session for server-side SSLSocket.".to_owned()) - ); + if self.server_side { + return Err(vm.new_value_error("Cannot set session for server-side SSLSocket")); } - // Check if handshake is not finished - let stream = self.connection.read(); - unsafe { - if sys::SSL_is_init_finished(stream.ssl().as_ptr()) != 0 { - return Err( - vm.new_value_error("Cannot set session after handshake.".to_owned()) - ); - } - - if sys::SSL_set_session(stream.ssl().as_ptr(), session.session) == 0 { - return Err(convert_openssl_error(vm, ErrorStack::get())); - } + // Check if handshake is already done + if *self.handshake_done.lock() { + return Err(vm.new_value_error("Cannot set session after handshake.")); } + // Store the session for potential use during handshake + *self.session.write() = if value.is(vm.ctx.types.none_type) { + None + } else { + Some(value) + }; + Ok(()) } #[pygetset] fn session_reused(&self) -> bool { - let stream = self.connection.read(); - unsafe { sys::SSL_session_reused(stream.ssl().as_ptr()) != 0 } + // Return the tracked session reuse status + *self.session_was_reused.lock() } #[pymethod] - fn read( - &self, - n: usize, - buffer: OptionalArg, - vm: &VirtualMachine, - ) -> PyResult { - // Special case: reading 0 bytes should return empty bytes immediately - if n == 0 { - return if buffer.is_present() { - Ok(vm.ctx.new_int(0).into()) - } else { - Ok(vm.ctx.new_bytes(vec![]).into()) - }; - } + fn compression(&self) -> Option<&'static str> { + // rustls doesn't support compression + None + } - let mut stream = self.connection.write(); - let mut inner_buffer = if let OptionalArg::Present(buffer) = &buffer { - Either::A(buffer.borrow_buf_mut()) - } else { - Either::B(vec![0u8; n]) - }; - let buf = match &mut inner_buffer { - Either::A(b) => &mut **b, - Either::B(b) => b.as_mut_slice(), - }; - let buf = match buf.get_mut(..n) { - Some(b) => b, - None => buf, + #[pymethod] + fn get_unverified_chain(&self, vm: &VirtualMachine) -> PyResult> { + // Get peer certificates from the connection + let conn_guard = self.connection.lock(); + let conn = conn_guard + .as_ref() + .ok_or_else(|| vm.new_value_error("Handshake not completed"))?; + + let certs = conn.peer_certificates(); + + let Some(certs) = certs else { + return Ok(None); }; - // BIO mode: no timeout/select logic - let count = if stream.is_bio() { - match stream.ssl_read(buf) { - Ok(count) => count, - Err(e) => return Err(convert_ssl_error(vm, e)), - } - } else { - // Socket mode: handle timeout and blocking - let timeout = stream - .get_ref() - .expect("read called in bio mode; should only be called in socket mode") - .timeout_deadline(); - loop { - let err = match stream.ssl_read(buf) { - Ok(count) => break count, - Err(e) => e, - }; - if err.code() == ssl::ErrorCode::ZERO_RETURN - && stream.get_shutdown() == ssl::ShutdownState::RECEIVED - { - break 0; - } - let (needs, state) = stream - .get_ref() - .expect("read called in bio mode; should only be called in socket mode") - .socket_needs(&err, &timeout); - match state { - SelectRet::TimedOut => { - return Err(socket::timeout_error_msg( - vm, - "The read operation timed out".to_owned(), - )); - } - SelectRet::Nonblocking => {} - _ => { - if needs.is_some() { - continue; - } - } + // Convert to list of Certificate objects + let cert_list: Vec = certs + .iter() + .map(|cert_der| { + let cert_bytes = cert_der.as_ref().to_vec(); + PySSLCertificate { + der_bytes: cert_bytes, } - return Err(convert_ssl_error(vm, err)); - } + .into_ref(&vm.ctx) + .into() + }) + .collect(); + + Ok(Some(vm.ctx.new_list(cert_list))) + } + + #[pymethod] + fn get_verified_chain(&self, vm: &VirtualMachine) -> PyResult> { + // Get peer certificates (what peer sent during handshake) + let conn_guard = self.connection.lock(); + let Some(ref conn) = *conn_guard else { + return Ok(None); }; - let ret = match inner_buffer { - Either::A(_buf) => vm.ctx.new_int(count).into(), - Either::B(mut buf) => { - buf.truncate(count); - buf.shrink_to_fit(); - vm.ctx.new_bytes(buf).into() - } + + let peer_certs = conn.peer_certificates(); + + let Some(peer_certs_slice) = peer_certs else { + return Ok(None); }; - Ok(ret) - } - } - #[pyattr] - #[pyclass(module = "ssl", name = "SSLSession")] - #[derive(PyPayload)] - struct PySslSession { - session: *mut sys::SSL_SESSION, - ctx: PyRef, - } + // Build the verified chain using cert module + let ctx_guard = self.context.read(); + let ca_certs_der = ctx_guard.ca_certs_der.read(); + + let chain_der = cert::build_verified_chain(peer_certs_slice, &ca_certs_der); + + // Convert DER chain to Python list of Certificate objects + let cert_list: Vec = chain_der + .into_iter() + .map(|der_bytes| PySSLCertificate { der_bytes }.into_ref(&vm.ctx).into()) + .collect(); - impl fmt::Debug for PySslSession { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.pad("SSLSession") + Ok(Some(vm.ctx.new_list(cert_list))) } - } - impl Drop for PySslSession { - fn drop(&mut self) { - if !self.session.is_null() { - unsafe { - sys::SSL_SESSION_free(self.session); + #[pymethod] + fn shutdown(&self, vm: &VirtualMachine) -> PyResult { + // Check current shutdown state + let current_state = *self.shutdown_state.lock(); + + // If already completed, return immediately + if current_state == ShutdownState::Completed { + if self.is_bio_mode() { + return Ok(vm.ctx.none()); } + return Ok(self.sock.clone()); } - } - } - unsafe impl Send for PySslSession {} - unsafe impl Sync for PySslSession {} + // Get connection + let mut conn_guard = self.connection.lock(); + let conn = conn_guard + .as_mut() + .ok_or_else(|| vm.new_value_error("Connection not established"))?; - impl Comparable for PySslSession { - fn cmp( - zelf: &Py, - other: &crate::vm::PyObject, - op: PyComparisonOp, - _vm: &VirtualMachine, - ) -> PyResult { - let other = class_or_notimplemented!(Self, other); + // Step 1: Send our close_notify if not already sent + if current_state == ShutdownState::NotStarted { + conn.send_close_notify(); + + // Write close_notify to outgoing buffer/BIO + self.write_pending_tls(conn, vm)?; - if !matches!(op, PyComparisonOp::Eq | PyComparisonOp::Ne) { - return Ok(PyComparisonValue::NotImplemented); + // Update state + *self.shutdown_state.lock() = ShutdownState::SentCloseNotify; } - let mut eq = unsafe { - let mut self_len: libc::c_uint = 0; - let mut other_len: libc::c_uint = 0; - let self_id = sys::SSL_SESSION_get_id(zelf.session, &mut self_len); - let other_id = sys::SSL_SESSION_get_id(other.session, &mut other_len); - if self_len != other_len { - false + // Step 2: Try to read and process peer's close_notify + let is_bio = self.is_bio_mode(); + + // First check if we already have peer's close_notify + // This can happen if it was received during a previous read() call + let mut peer_closed = self.check_peer_closed(conn, vm)?; + + // If peer hasn't closed yet, try to read from socket + if !peer_closed { + // Check if socket is in blocking mode (timeout is None) + let is_blocking = if !is_bio { + // Get socket timeout + match self.sock.get_attr("gettimeout", vm) { + Ok(method) => match method.call((), vm) { + Ok(timeout) => vm.is_none(&timeout), + Err(_) => false, + }, + Err(_) => false, + } } else { - let self_slice = std::slice::from_raw_parts(self_id, self_len as usize); - let other_slice = std::slice::from_raw_parts(other_id, other_len as usize); - self_slice == other_slice + false + }; + + if is_bio { + // In BIO mode: non-blocking read attempt + let _ = self.try_read_close_notify(conn, vm); + } else if is_blocking { + // Blocking socket mode: Return immediately without waiting for peer + // + // Reasons we don't read from socket here: + // 1. STARTTLS scenario: application data may arrive before/instead of close_notify + // - Example: client sends ENDTLS, immediately sends plain "msg 5" + // - Server's unwrap() would read "msg 5" and try to parse as TLS → FAIL + // 2. CPython's SSL_shutdown() typically returns immediately without waiting + // 3. Bidirectional shutdown is the application's responsibility + // 4. Reading from socket would consume application data incorrectly + // + // Therefore: Just send our close_notify and return success immediately. + // The peer's close_notify (if any) will remain in the socket buffer. + // + // Mark shutdown as complete and return the underlying socket + drop(conn_guard); + *self.shutdown_state.lock() = ShutdownState::Completed; + *self.connection.lock() = None; + return Ok(self.sock.clone()); } - }; - if matches!(op, PyComparisonOp::Ne) { - eq = !eq; + + // Step 3: Check again if peer has sent close_notify (non-blocking/BIO mode only) + peer_closed = self.check_peer_closed(conn, vm)?; } - Ok(PyComparisonValue::Implemented(eq)) - } - } - #[pyattr] - #[pyclass(module = "ssl", name = "MemoryBIO")] - #[derive(PyPayload)] - struct PySslMemoryBio { - bio: *mut sys::BIO, - eof_written: AtomicCell, - } + drop(conn_guard); // Release lock before returning - impl fmt::Debug for PySslMemoryBio { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.pad("MemoryBIO") - } - } + if !peer_closed { + // Still waiting for peer's close-notify + // Raise SSLWantReadError to signal app needs to transfer data + // This is correct for non-blocking sockets and BIO mode + return Err(create_ssl_want_read_error(vm)); + } + // Both close-notify exchanged, shutdown complete + *self.shutdown_state.lock() = ShutdownState::Completed; - impl Drop for PySslMemoryBio { - fn drop(&mut self) { - if !self.bio.is_null() { - unsafe { - sys::BIO_free_all(self.bio); - } + if is_bio { + return Ok(vm.ctx.none()); } + Ok(self.sock.clone()) } - } - unsafe impl Send for PySslMemoryBio {} - unsafe impl Sync for PySslMemoryBio {} + // Helper: Write all pending TLS data (including close_notify) to outgoing buffer/BIO + fn write_pending_tls(&self, conn: &mut TlsConnection, vm: &VirtualMachine) -> PyResult<()> { + loop { + if !conn.wants_write() { + break; + } + + let mut buf = vec![0u8; SSL3_RT_MAX_PLAIN_LENGTH]; + let written = conn + .write_tls(&mut buf.as_mut_slice()) + .map_err(|e| vm.new_os_error(format!("TLS write failed: {e}")))?; - // OpenSSL functions not in openssl-sys + if written == 0 { + break; + } - unsafe extern "C" { - // X509_check_ca returns 1 for CA certificates, 0 otherwise - fn X509_check_ca(x: *const sys::X509) -> libc::c_int; - } + // Send to outgoing BIO or socket + self.sock_send(buf[..written].to_vec(), vm)?; + } - unsafe extern "C" { - fn SSL_get_ciphers(ssl: *const sys::SSL) -> *const sys::stack_st_SSL_CIPHER; - } + Ok(()) + } + + // Helper: Try to read incoming data from BIO (non-blocking) + fn try_read_close_notify( + &self, + conn: &mut TlsConnection, + vm: &VirtualMachine, + ) -> PyResult<()> { + // Try to read incoming data from BIO + // This is non-blocking in BIO mode - if no data, recv returns empty + match self.sock_recv(SSL3_RT_MAX_PLAIN_LENGTH, vm) { + Ok(bytes_obj) => { + let bytes = ArgBytesLike::try_from_object(vm, bytes_obj)?; + let data = bytes.borrow_buf(); + + if !data.is_empty() { + // Feed data to TLS connection + let data_slice: &[u8] = data.as_ref(); + let mut cursor = std::io::Cursor::new(data_slice); + let _ = conn.read_tls(&mut cursor); + + // Process packets + let _ = conn.process_new_packets(); + } + } + Err(_) => { + // No data available or error - that's OK in BIO mode + } + } - #[cfg(ossl110)] - unsafe extern "C" { - fn SSL_get_client_ciphers(ssl: *const sys::SSL) -> *const sys::stack_st_SSL_CIPHER; - } + Ok(()) + } - #[cfg(ossl111)] - unsafe extern "C" { - fn SSL_verify_client_post_handshake(ssl: *const sys::SSL) -> libc::c_int; - fn SSL_set_post_handshake_auth(ssl: *mut sys::SSL, val: libc::c_int); - } + // Helper: Check if peer has sent close_notify + fn check_peer_closed( + &self, + conn: &mut TlsConnection, + vm: &VirtualMachine, + ) -> PyResult { + // Process any remaining packets and check peer_has_closed + let io_state = conn + .process_new_packets() + .map_err(|e| vm.new_os_error(format!("Failed to process packets: {e}")))?; - #[cfg(ossl110)] - unsafe extern "C" { - fn SSL_CTX_get_security_level(ctx: *const sys::SSL_CTX) -> libc::c_int; - } + Ok(io_state.peer_has_closed()) + } - unsafe extern "C" { - fn SSL_set_SSL_CTX(ssl: *mut sys::SSL, ctx: *mut sys::SSL_CTX) -> *mut sys::SSL_CTX; - } + #[pymethod] + fn shared_ciphers(&self, vm: &VirtualMachine) -> Option { + // Return None for client-side sockets + if !self.server_side { + return None; + } - #[cfg(ossl110)] - unsafe extern "C" { - fn SSL_SESSION_has_ticket(session: *const sys::SSL_SESSION) -> libc::c_int; - fn SSL_SESSION_get_ticket_lifetime_hint(session: *const sys::SSL_SESSION) -> libc::c_ulong; - } + // Check if handshake completed + if !*self.handshake_done.lock() { + return None; + } - // X509 object types - const X509_LU_X509: libc::c_int = 1; - const X509_LU_CRL: libc::c_int = 2; + // Get negotiated cipher suite from rustls + let conn_guard = self.connection.lock(); + let conn = conn_guard.as_ref()?; - unsafe extern "C" { - fn X509_OBJECT_get_type(obj: *const sys::X509_OBJECT) -> libc::c_int; - } + let suite = conn.negotiated_cipher_suite()?; - // SSL session statistics constants (used with SSL_CTX_ctrl) - const SSL_CTRL_SESS_NUMBER: libc::c_int = 20; - const SSL_CTRL_SESS_CONNECT: libc::c_int = 21; - const SSL_CTRL_SESS_CONNECT_GOOD: libc::c_int = 22; - const SSL_CTRL_SESS_CONNECT_RENEGOTIATE: libc::c_int = 23; - const SSL_CTRL_SESS_ACCEPT: libc::c_int = 24; - const SSL_CTRL_SESS_ACCEPT_GOOD: libc::c_int = 25; - const SSL_CTRL_SESS_ACCEPT_RENEGOTIATE: libc::c_int = 26; - const SSL_CTRL_SESS_HIT: libc::c_int = 27; - const SSL_CTRL_SESS_MISSES: libc::c_int = 29; - const SSL_CTRL_SESS_TIMEOUTS: libc::c_int = 30; - const SSL_CTRL_SESS_CACHE_FULL: libc::c_int = 31; - - // SSL session statistics functions (implemented as macros in OpenSSL) - #[allow(non_snake_case)] - unsafe fn SSL_CTX_sess_number(ctx: *const sys::SSL_CTX) -> libc::c_long { - unsafe { sys::SSL_CTX_ctrl(ctx as *mut _, SSL_CTRL_SESS_NUMBER, 0, std::ptr::null_mut()) } - } + // Extract cipher information using unified helper + let cipher_info = extract_cipher_info(&suite); - #[allow(non_snake_case)] - unsafe fn SSL_CTX_sess_connect(ctx: *const sys::SSL_CTX) -> libc::c_long { - unsafe { - sys::SSL_CTX_ctrl( - ctx as *mut _, - SSL_CTRL_SESS_CONNECT, - 0, - std::ptr::null_mut(), - ) + // Return as list with single tuple (name, version, bits) + let tuple = vm.ctx.new_tuple(vec![ + vm.ctx.new_str(cipher_info.name).into(), + vm.ctx.new_str(cipher_info.protocol).into(), + vm.ctx.new_int(cipher_info.bits).into(), + ]); + Some(vm.ctx.new_list(vec![tuple.into()])) } - } - #[allow(non_snake_case)] - unsafe fn SSL_CTX_sess_connect_good(ctx: *const sys::SSL_CTX) -> libc::c_long { - unsafe { - sys::SSL_CTX_ctrl( - ctx as *mut _, - SSL_CTRL_SESS_CONNECT_GOOD, - 0, - std::ptr::null_mut(), - ) - } - } + #[pymethod] + fn verify_client_post_handshake(&self, vm: &VirtualMachine) -> PyResult<()> { + // TLS 1.3 post-handshake authentication + // This is only valid for server-side TLS 1.3 connections - #[allow(non_snake_case)] - unsafe fn SSL_CTX_sess_connect_renegotiate(ctx: *const sys::SSL_CTX) -> libc::c_long { - unsafe { - sys::SSL_CTX_ctrl( - ctx as *mut _, - SSL_CTRL_SESS_CONNECT_RENEGOTIATE, - 0, - std::ptr::null_mut(), - ) - } - } + // Check if this is a server-side socket + if !self.server_side { + return Err(vm.new_value_error( + "Cannot perform post-handshake authentication on client-side socket", + )); + } - #[allow(non_snake_case)] - unsafe fn SSL_CTX_sess_accept(ctx: *const sys::SSL_CTX) -> libc::c_long { - unsafe { sys::SSL_CTX_ctrl(ctx as *mut _, SSL_CTRL_SESS_ACCEPT, 0, std::ptr::null_mut()) } - } + // Check if handshake has been completed + if !*self.handshake_done.lock() { + return Err(vm.new_value_error( + "Handshake must be completed before post-handshake authentication", + )); + } - #[allow(non_snake_case)] - unsafe fn SSL_CTX_sess_accept_good(ctx: *const sys::SSL_CTX) -> libc::c_long { - unsafe { - sys::SSL_CTX_ctrl( - ctx as *mut _, - SSL_CTRL_SESS_ACCEPT_GOOD, - 0, - std::ptr::null_mut(), - ) - } - } + // Check connection exists and protocol version + let conn_guard = self.connection.lock(); + if let Some(conn) = conn_guard.as_ref() { + let version = match conn { + TlsConnection::Client(_) => { + return Err(vm.new_value_error( + "Post-handshake authentication requires server socket", + )); + } + TlsConnection::Server(server) => server.protocol_version(), + }; + + // Post-handshake auth is only available in TLS 1.3 + if version != Some(rustls::ProtocolVersion::TLSv1_3) { + // Get SSLError class from ssl module (not _ssl) + // ssl.py imports _ssl.SSLError as ssl.SSLError + let ssl_mod = vm.import("ssl", 0)?; + let ssl_error_class = ssl_mod.get_attr("SSLError", vm)?; + + // Create SSLError instance with message containing WRONG_SSL_VERSION + let msg = "[SSL: WRONG_SSL_VERSION] wrong ssl version"; + let args = vm.ctx.new_tuple(vec![vm.ctx.new_str(msg).into()]); + let exc = ssl_error_class.call((args,), vm)?; + + return Err(exc + .downcast() + .map_err(|_| vm.new_type_error("Failed to create SSLError"))?); + } + } else { + return Err(vm.new_value_error("No SSL connection established")); + } - #[allow(non_snake_case)] - unsafe fn SSL_CTX_sess_accept_renegotiate(ctx: *const sys::SSL_CTX) -> libc::c_long { - unsafe { - sys::SSL_CTX_ctrl( - ctx as *mut _, - SSL_CTRL_SESS_ACCEPT_RENEGOTIATE, - 0, - std::ptr::null_mut(), - ) + // rustls doesn't provide an API for post-handshake authentication. + // The rustls TLS library does not support requesting client certificates + // after the initial handshake is completed. + // Raise SSLError instead of NotImplementedError for compatibility + Err(vm.new_exception_msg( + PySSLError::class(&vm.ctx).to_owned(), + "Post-handshake authentication is not supported by the rustls backend. \ + The rustls TLS library does not provide an API to request client certificates \ + after the initial handshake. Consider requesting the client certificate \ + during the initial handshake by setting the appropriate verify_mode before \ + calling do_handshake()." + .to_owned(), + )) } - } - #[allow(non_snake_case)] - unsafe fn SSL_CTX_sess_hits(ctx: *const sys::SSL_CTX) -> libc::c_long { - unsafe { sys::SSL_CTX_ctrl(ctx as *mut _, SSL_CTRL_SESS_HIT, 0, std::ptr::null_mut()) } - } + #[pymethod] + fn get_channel_binding( + &self, + cb_type: OptionalArg, + vm: &VirtualMachine, + ) -> PyResult> { + let cb_type_str = cb_type.as_ref().map_or("tls-unique", |s| s.as_str()); - #[allow(non_snake_case)] - unsafe fn SSL_CTX_sess_misses(ctx: *const sys::SSL_CTX) -> libc::c_long { - unsafe { sys::SSL_CTX_ctrl(ctx as *mut _, SSL_CTRL_SESS_MISSES, 0, std::ptr::null_mut()) } - } + // rustls doesn't support channel binding (tls-unique, tls-server-end-point, etc.) + // This is because: + // 1. tls-unique requires access to TLS Finished messages, which rustls doesn't expose + // 2. tls-server-end-point requires the server certificate, which we don't track here + // 3. TLS 1.3 deprecated tls-unique anyway + // + // For compatibility, we'll return None (no channel binding available) + // rather than raising an error - #[allow(non_snake_case)] - unsafe fn SSL_CTX_sess_timeouts(ctx: *const sys::SSL_CTX) -> libc::c_long { - unsafe { - sys::SSL_CTX_ctrl( - ctx as *mut _, - SSL_CTRL_SESS_TIMEOUTS, - 0, - std::ptr::null_mut(), - ) - } - } + if cb_type_str != "tls-unique" { + return Err(vm.new_value_error(format!( + "Unsupported channel binding type '{cb_type_str}'", + ))); + } - #[allow(non_snake_case)] - unsafe fn SSL_CTX_sess_cache_full(ctx: *const sys::SSL_CTX) -> libc::c_long { - unsafe { - sys::SSL_CTX_ctrl( - ctx as *mut _, - SSL_CTRL_SESS_CACHE_FULL, - 0, - std::ptr::null_mut(), - ) + // Return None to indicate channel binding is not available + // This matches the behavior when the handshake hasn't completed yet + Ok(None) } } - // DH parameters functions - unsafe extern "C" { - fn PEM_read_DHparams( - fp: *mut libc::FILE, - x: *mut *mut sys::DH, - cb: *mut libc::c_void, - u: *mut libc::c_void, - ) -> *mut sys::DH; - } - - // OpenSSL BIO helper functions - // These are typically macros in OpenSSL, implemented via BIO_ctrl - const BIO_CTRL_PENDING: libc::c_int = 10; - const BIO_CTRL_SET_EOF: libc::c_int = 2; - - #[allow(non_snake_case)] - unsafe fn BIO_ctrl_pending(bio: *mut sys::BIO) -> usize { - unsafe { sys::BIO_ctrl(bio, BIO_CTRL_PENDING, 0, std::ptr::null_mut()) as usize } - } + impl Constructor for PySSLSocket { + type Args = (); - #[allow(non_snake_case)] - unsafe fn BIO_set_mem_eof_return(bio: *mut sys::BIO, eof: libc::c_int) -> libc::c_int { - unsafe { - sys::BIO_ctrl( - bio, - BIO_CTRL_SET_EOF, - eof as libc::c_long, - std::ptr::null_mut(), - ) as libc::c_int + fn py_new(_cls: PyTypeRef, _args: Self::Args, vm: &VirtualMachine) -> PyResult { + Err(vm.new_type_error( + "Cannot directly instantiate SSLSocket, use SSLContext.wrap_socket()", + )) } } - #[allow(non_snake_case)] - unsafe fn BIO_clear_retry_flags(bio: *mut sys::BIO) { - unsafe { - sys::BIO_clear_flags(bio, sys::BIO_FLAGS_RWS | sys::BIO_FLAGS_SHOULD_RETRY); - } + // MemoryBIO - provides in-memory buffer for SSL/TLS I/O + #[pyattr] + #[pyclass(name = "MemoryBIO", module = "ssl")] + #[derive(Debug, PyPayload)] + struct PyMemoryBIO { + // Internal buffer + buffer: PyMutex>, + // EOF flag + eof: PyRwLock, } - impl Constructor for PySslMemoryBio { - type Args = (); - - fn py_new(cls: PyTypeRef, _args: Self::Args, vm: &VirtualMachine) -> PyResult { - unsafe { - let bio = sys::BIO_new(sys::BIO_s_mem()); - if bio.is_null() { - return Err(vm.new_memory_error("failed to allocate BIO".to_owned())); - } + #[pyclass(with(Constructor), flags(BASETYPE))] + impl PyMemoryBIO { + #[pymethod] + fn read(&self, len: OptionalArg, vm: &VirtualMachine) -> PyResult { + let mut buffer = self.buffer.lock(); - sys::BIO_set_retry_read(bio); - BIO_set_mem_eof_return(bio, -1); + if buffer.is_empty() && *self.eof.read() { + // Return empty bytes at EOF + return Ok(vm.ctx.new_bytes(vec![])); + } - PySslMemoryBio { - bio, - eof_written: AtomicCell::new(false), + let read_len = match len { + OptionalArg::Present(n) if n >= 0 => n as usize, + OptionalArg::Present(n) => { + return Err(vm.new_value_error(format!("negative read length: {n}"))); } - .into_ref_with_type(vm, cls) - .map(Into::into) - } - } - } + OptionalArg::Missing => buffer.len(), // Read all available + }; - #[pyclass(flags(IMMUTABLETYPE), with(Constructor))] - impl PySslMemoryBio { - #[pygetset] - fn pending(&self) -> usize { - unsafe { BIO_ctrl_pending(self.bio) } - } + let actual_len = read_len.min(buffer.len()); + let data = buffer.drain(..actual_len).collect::>(); - #[pygetset] - fn eof(&self) -> bool { - let pending = unsafe { BIO_ctrl_pending(self.bio) }; - pending == 0 && self.eof_written.load() + Ok(vm.ctx.new_bytes(data)) } #[pymethod] - fn read(&self, size: OptionalArg, vm: &VirtualMachine) -> PyResult> { - unsafe { - let avail = BIO_ctrl_pending(self.bio).min(i32::MAX as usize) as i32; - let len = size.unwrap_or(-1); - let len = if len < 0 || len > avail { avail } else { len }; - - if len == 0 { - return Ok(Vec::new()); + fn write(&self, buf: PyObjectRef, vm: &VirtualMachine) -> PyResult { + // Check if it's a memoryview and if it's contiguous + if let Ok(mem_view) = buf.get_attr("c_contiguous", vm) { + // It's a memoryview, check if contiguous + let is_contiguous: bool = mem_view.try_to_bool(vm)?; + if !is_contiguous { + return Err(vm.new_exception_msg( + vm.ctx.exceptions.buffer_error.to_owned(), + "non-contiguous buffer is not supported".to_owned(), + )); } + } - let mut buf = vec![0u8; len as usize]; - let nbytes = sys::BIO_read(self.bio, buf.as_mut_ptr() as *mut _, len); + // Convert to bytes-like object + let bytes_like = ArgBytesLike::try_from_object(vm, buf)?; + let data = bytes_like.borrow_buf(); + let len = data.len(); - if nbytes < 0 { - return Err(convert_openssl_error(vm, ErrorStack::get())); - } + let mut buffer = self.buffer.lock(); + buffer.extend_from_slice(&data); - buf.truncate(nbytes as usize); - Ok(buf) - } + Ok(len) } #[pymethod] - fn write(&self, data: ArgBytesLike, vm: &VirtualMachine) -> PyResult { - if self.eof_written.load() { - return Err(vm.new_exception_msg( - PySslError::class(&vm.ctx).to_owned(), - "cannot write() after write_eof()".to_owned(), - )); - } + fn write_eof(&self, _vm: &VirtualMachine) -> PyResult<()> { + *self.eof.write() = true; + Ok(()) + } - data.with_ref(|buf| unsafe { - if buf.len() > i32::MAX as usize { - return Err( - vm.new_overflow_error(format!("string longer than {} bytes", i32::MAX)) - ); - } + #[pygetset] + fn pending(&self) -> i32 { + self.buffer.lock().len() as i32 + } - let nbytes = sys::BIO_write(self.bio, buf.as_ptr() as *const _, buf.len() as i32); - if nbytes < 0 { - return Err(convert_openssl_error(vm, ErrorStack::get())); - } + #[pygetset] + fn eof(&self) -> bool { + // EOF is true only when buffer is empty AND write_eof has been called + let pending = self.buffer.lock().len(); + pending == 0 && *self.eof.read() + } + } - Ok(nbytes) - }) + impl Representable for PyMemoryBIO { + #[inline] + fn repr_str(_zelf: &Py, _vm: &VirtualMachine) -> PyResult { + Ok("".to_owned()) } + } - #[pymethod] - fn write_eof(&self) { - self.eof_written.store(true); - unsafe { - BIO_clear_retry_flags(self.bio); - BIO_set_mem_eof_return(self.bio, 0); + impl Constructor for PyMemoryBIO { + type Args = (); + + fn py_new(cls: PyTypeRef, _args: Self::Args, vm: &VirtualMachine) -> PyResult { + let obj = PyMemoryBIO { + buffer: PyMutex::new(Vec::new()), + eof: PyRwLock::new(false), } + .into_ref_with_type(vm, cls)?; + Ok(obj.into()) } } - #[pyclass(flags(IMMUTABLETYPE), with(Comparable))] - impl PySslSession { + // SSLSession - represents a cached SSL session + // NOTE: This is an EMULATION - actual session data is managed by Rustls internally + #[pyattr] + #[pyclass(name = "SSLSession", module = "ssl")] + #[derive(Debug, PyPayload)] + struct PySSLSession { + // Session data - serialized rustls session (EMULATED - kept empty) + session_data: Vec, + // Session ID - synthetic ID generated from metadata (NOT actual TLS session ID) + #[allow(dead_code)] + session_id: Vec, + // Session metadata + creation_time: std::time::SystemTime, + // Lifetime in seconds (default 7200 = 2 hours) + lifetime: u64, + } + + #[pyclass(flags(BASETYPE))] + impl PySSLSession { #[pygetset] fn time(&self) -> i64 { - unsafe { - #[cfg(ossl330)] - { - sys::SSL_SESSION_get_time(self.session) as i64 - } - #[cfg(not(ossl330))] - { - sys::SSL_SESSION_get_time(self.session) as i64 - } - } + // Return session creation time as Unix timestamp + self.creation_time + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_secs() as i64 } #[pygetset] fn timeout(&self) -> i64 { - unsafe { sys::SSL_SESSION_get_timeout(self.session) as i64 } + // Return session timeout/lifetime in seconds + self.lifetime as i64 } #[pygetset] - fn ticket_lifetime_hint(&self) -> u64 { - // SSL_SESSION_get_ticket_lifetime_hint available in OpenSSL 1.1.0+ - #[cfg(ossl110)] - { - unsafe { SSL_SESSION_get_ticket_lifetime_hint(self.session) as u64 } - } - #[cfg(not(ossl110))] - { - // Not available in older OpenSSL versions - 0 - } + fn ticket_lifetime_hint(&self) -> i64 { + // Return ticket lifetime hint (same as timeout for rustls) + self.lifetime as i64 } #[pygetset] fn id(&self, vm: &VirtualMachine) -> PyBytesRef { - unsafe { - let mut len: libc::c_uint = 0; - let id_ptr = sys::SSL_SESSION_get_id(self.session, &mut len); - let id_slice = std::slice::from_raw_parts(id_ptr, len as usize); - vm.ctx.new_bytes(id_slice.to_vec()) - } + // Return session ID (hash of session data for uniqueness) + use std::collections::hash_map::DefaultHasher; + use std::hash::{Hash, Hasher}; + + let mut hasher = DefaultHasher::new(); + self.session_data.hash(&mut hasher); + let hash = hasher.finish(); + + // Convert hash to bytes + vm.ctx.new_bytes(hash.to_be_bytes().to_vec()) } #[pygetset] fn has_ticket(&self) -> bool { - // SSL_SESSION_has_ticket available in OpenSSL 1.1.0+ - #[cfg(ossl110)] - { - unsafe { SSL_SESSION_has_ticket(self.session) != 0 } - } - #[cfg(not(ossl110))] - { - // Not available in older OpenSSL versions - false - } + // For rustls, if we have session data, we have a ticket + !self.session_data.is_empty() } } - #[track_caller] - pub(crate) fn convert_openssl_error( - vm: &VirtualMachine, - err: ErrorStack, - ) -> PyBaseExceptionRef { - match err.errors().last() { - Some(e) => { - // Check if this is a system library error (errno-based) - // CPython: Modules/_ssl.c:667-671 - let lib = sys::ERR_GET_LIB(e.code()); - - if lib == sys::ERR_LIB_SYS { - // A system error is being reported; reason is set to errno - let reason = sys::ERR_GET_REASON(e.code()); - - // errno 2 = ENOENT = FileNotFoundError - let exc_type = if reason == 2 { - vm.ctx.exceptions.file_not_found_error.to_owned() - } else { - vm.ctx.exceptions.os_error.to_owned() - }; - let exc = vm.new_exception(exc_type, vec![vm.ctx.new_int(reason).into()]); - // Set errno attribute explicitly - let _ = exc - .as_object() - .set_attr("errno", vm.ctx.new_int(reason), vm); - return exc; - } - - let caller = std::panic::Location::caller(); - let (file, line) = (caller.file(), caller.line()); - let file = file - .rsplit_once(&['/', '\\'][..]) - .map_or(file, |(_, basename)| basename); - - // Get error codes - same approach as CPython - let lib = sys::ERR_GET_LIB(e.code()); - let reason = sys::ERR_GET_REASON(e.code()); - - // Look up error mnemonic from our static tables - // CPython uses dict lookup: err_codes_to_names[(lib, reason)] - let key = super::ssl_data::encode_error_key(lib, reason); - let errstr = super::ssl_data::ERROR_CODES - .get(&key) - .copied() - .or_else(|| { - // Fallback: use OpenSSL's error string - e.reason() - }) - .unwrap_or("unknown error"); - - // Check if this is a certificate verification error - // ERR_LIB_SSL = 20 (from _ssl_data_300.h) - // SSL_R_CERTIFICATE_VERIFY_FAILED = 134 (from _ssl_data_300.h) - let is_cert_verify_error = lib == 20 && reason == 134; - - // Look up library name from our static table - // CPython uses: lib_codes_to_names[lib] - let lib_name = super::ssl_data::LIBRARY_CODES.get(&(lib as u32)).copied(); - - // Use SSLCertVerificationError for certificate verification failures - let cls = if is_cert_verify_error { - PySslCertVerificationError::class(&vm.ctx).to_owned() - } else { - PySslError::class(&vm.ctx).to_owned() - }; - - // Build message - let msg = if let Some(lib_str) = lib_name { - format!("[{lib_str}] {errstr} ({file}:{line})") - } else { - format!("{errstr} ({file}:{line})") - }; - - // Create exception instance - let reason = sys::ERR_GET_REASON(e.code()); - let exc = vm.new_exception( - cls, - vec![vm.ctx.new_int(reason).into(), vm.ctx.new_str(msg).into()], - ); - - // Set attributes on instance, not class - let exc_obj: PyObjectRef = exc.into(); + impl Representable for PySSLSession { + #[inline] + fn repr_str(_zelf: &Py, _vm: &VirtualMachine) -> PyResult { + Ok("".to_owned()) + } + } - // Set reason attribute (always set, even if just the error string) - let reason_value = vm.ctx.new_str(errstr); - let _ = exc_obj.set_attr("reason", reason_value, vm); + // Helper functions - // Set library attribute (None if not available) - let library_value: PyObjectRef = if let Some(lib_str) = lib_name { - vm.ctx.new_str(lib_str).into() - } else { - vm.ctx.none() - }; - let _ = exc_obj.set_attr("library", library_value, vm); - - // For SSLCertVerificationError, set verify_code and verify_message - // Note: These will be set to None here, and can be updated by the caller - // if they have access to the SSL object - if is_cert_verify_error { - let _ = exc_obj.set_attr("verify_code", vm.ctx.none(), vm); - let _ = exc_obj.set_attr("verify_message", vm.ctx.none(), vm); - } + // OID module already imported at top of _ssl module - // Convert back to PyBaseExceptionRef - exc_obj.downcast().expect( - "exc_obj is created as PyBaseExceptionRef and must downcast successfully", - ) - } - None => { - let cls = PySslError::class(&vm.ctx).to_owned(); - vm.new_exception_empty(cls) - } - } + #[derive(FromArgs)] + struct Txt2ObjArgs { + txt: PyStrRef, + #[pyarg(named, optional)] + name: OptionalArg, } - // Helper function to set verify_code and verify_message on SSLCertVerificationError - fn set_verify_error_info( - exc: &PyBaseExceptionRef, - ssl_ptr: *const sys::SSL, - vm: &VirtualMachine, - ) { - // Get verify result - let verify_code = unsafe { sys::SSL_get_verify_result(ssl_ptr) }; - let verify_code_obj = vm.ctx.new_int(verify_code); - - // Get verify message - let verify_message = unsafe { - let verify_str = sys::X509_verify_cert_error_string(verify_code); - if verify_str.is_null() { - vm.ctx.none() - } else { - let c_str = std::ffi::CStr::from_ptr(verify_str); - vm.ctx.new_str(c_str.to_string_lossy()).into() - } + #[pyfunction] + fn txt2obj(args: Txt2ObjArgs, vm: &VirtualMachine) -> PyResult { + let txt = args.txt.as_str(); + let name = args.name.unwrap_or(false); + + // If name=False (default), only accept OID strings + // If name=True, accept both names and OID strings + let entry = if txt + .chars() + .next() + .map(|c| c.is_ascii_digit()) + .unwrap_or(false) + { + // Looks like an OID string (starts with digit) + oid::find_by_oid_string(txt) + } else if name { + // name=True: allow shortname/longname lookup + oid::find_by_name(txt) + } else { + // name=False: only OID strings allowed, not names + None }; - let exc_obj = exc.as_object(); - let _ = exc_obj.set_attr("verify_code", verify_code_obj, vm); - let _ = exc_obj.set_attr("verify_message", verify_message, vm); - } - #[track_caller] - fn convert_ssl_error( - vm: &VirtualMachine, - e: impl std::borrow::Borrow, - ) -> PyBaseExceptionRef { - let e = e.borrow(); - let (cls, msg) = match e.code() { - ssl::ErrorCode::WANT_READ => ( - PySslWantReadError::class(&vm.ctx).to_owned(), - "The operation did not complete (read)", - ), - ssl::ErrorCode::WANT_WRITE => ( - PySslWantWriteError::class(&vm.ctx).to_owned(), - "The operation did not complete (write)", - ), - ssl::ErrorCode::SYSCALL => match e.io_error() { - Some(io_err) => return io_err.to_pyexception(vm), - // When no I/O error and OpenSSL error queue is empty, - // this is an EOF in violation of protocol -> SSLEOFError - // Need to set args[0] = SSL_ERROR_EOF for suppress_ragged_eofs check - None => { - return vm.new_exception( - PySslEOFError::class(&vm.ctx).to_owned(), - vec![ - vm.ctx.new_int(SSL_ERROR_EOF).into(), - vm.ctx - .new_str("EOF occurred in violation of protocol") - .into(), - ], - ); - } - }, - ssl::ErrorCode::SSL => { - // Check for OpenSSL 3.0 SSL_R_UNEXPECTED_EOF_WHILE_READING - if let Some(ssl_err) = e.ssl_error() { - // In OpenSSL 3.0+, unexpected EOF is reported as SSL_ERROR_SSL - // with this specific reason code instead of SSL_ERROR_SYSCALL - unsafe { - let err_code = sys::ERR_peek_last_error(); - let reason = sys::ERR_GET_REASON(err_code); - let lib = sys::ERR_GET_LIB(err_code); - if lib == ERR_LIB_SSL && reason == SSL_R_UNEXPECTED_EOF_WHILE_READING { - return vm.new_exception( - PySslEOFError::class(&vm.ctx).to_owned(), - vec![ - vm.ctx.new_int(SSL_ERROR_EOF).into(), - vm.ctx - .new_str("EOF occurred in violation of protocol") - .into(), - ], - ); - } - } - return convert_openssl_error(vm, ssl_err.clone()); - } - ( - PySslError::class(&vm.ctx).to_owned(), - "A failure in the SSL library occurred", - ) - } - _ => ( - PySslError::class(&vm.ctx).to_owned(), - "A failure in the SSL library occurred", - ), - }; - vm.new_exception_msg(cls, msg.to_owned()) + let entry = entry.ok_or_else(|| vm.new_value_error(format!("unknown object '{txt}'")))?; + + // Return tuple: (nid, shortname, longname, oid) + Ok(vm + .new_tuple(( + vm.ctx.new_int(entry.nid), + vm.ctx.new_str(entry.short_name), + vm.ctx.new_str(entry.long_name), + vm.ctx.new_str(entry.oid_string()), + )) + .into()) } - // SSL_FILETYPE_ASN1 part of _add_ca_certs in CPython - fn x509_stack_from_der(der: &[u8]) -> Result, ErrorStack> { - unsafe { - openssl::init(); - let bio = bio::MemBioSlice::new(der)?; - - let mut certs = vec![]; - loop { - let cert = sys::d2i_X509_bio(bio.as_ptr(), std::ptr::null_mut()); - if cert.is_null() { - break; - } - certs.push(X509::from_ptr(cert)); - } + #[pyfunction] + fn nid2obj(nid: i32, vm: &VirtualMachine) -> PyResult { + let entry = oid::find_by_nid(nid) + .ok_or_else(|| vm.new_value_error(format!("unknown NID {nid}")))?; + + // Return tuple: (nid, shortname, longname, oid) + Ok(vm + .new_tuple(( + vm.ctx.new_int(entry.nid), + vm.ctx.new_str(entry.short_name), + vm.ctx.new_str(entry.long_name), + vm.ctx.new_str(entry.oid_string()), + )) + .into()) + } - let err = sys::ERR_peek_last_error(); + #[pyfunction] + fn get_default_verify_paths(vm: &VirtualMachine) -> PyResult { + // Return default certificate paths as a tuple + // Lib/ssl.py expects: (openssl_cafile_env, openssl_cafile, openssl_capath_env, openssl_capath) + // parts[0] = environment variable name for cafile + // parts[1] = default cafile path + // parts[2] = environment variable name for capath + // parts[3] = default capath path + + // Common default paths for different platforms + // These match the first candidates that rustls-native-certs/openssl-probe checks + #[cfg(target_os = "macos")] + let (default_cafile, default_capath) = { + // macOS primarily uses Keychain API, but provides fallback paths + // for compatibility and when Keychain access fails + (Some("/etc/ssl/cert.pem"), Some("/etc/ssl/certs")) + }; - if certs.is_empty() { - // let msg = if filetype == sys::SSL_FILETYPE_PEM { - // "no start line: cadata does not contain a certificate" - // } else { - // "not enough data: cadata does not contain a certificate" - // }; - return Err(ErrorStack::get()); - } - if err != 0 { - return Err(ErrorStack::get()); - } + #[cfg(target_os = "linux")] + let (default_cafile, default_capath) = { + // Linux: matches openssl-probe's first candidate (/etc/ssl/cert.pem) + // openssl-probe checks multiple locations at runtime, but we return + // OpenSSL's compile-time default + (Some("/etc/ssl/cert.pem"), Some("/etc/ssl/certs")) + }; - Ok(certs) - } + #[cfg(not(any(target_os = "macos", target_os = "linux")))] + let (default_cafile, default_capath): (Option<&str>, Option<&str>) = (None, None); + + let tuple = vm.ctx.new_tuple(vec![ + vm.ctx.new_str("SSL_CERT_FILE").into(), // openssl_cafile_env + default_cafile + .map(|s| vm.ctx.new_str(s).into()) + .unwrap_or_else(|| vm.ctx.none()), // openssl_cafile + vm.ctx.new_str("SSL_CERT_DIR").into(), // openssl_capath_env + default_capath + .map(|s| vm.ctx.new_str(s).into()) + .unwrap_or_else(|| vm.ctx.none()), // openssl_capath + ]); + Ok(tuple.into()) } - type CipherTuple = (&'static str, &'static str, i32); - - fn cipher_to_tuple(cipher: &ssl::SslCipherRef) -> CipherTuple { - (cipher.name(), cipher.version(), cipher.bits().secret) + #[pyfunction] + fn RAND_status() -> i32 { + 1 // Always have good randomness with aws-lc-rs } - fn cipher_description(cipher: *const sys::SSL_CIPHER) -> String { - unsafe { - // SSL_CIPHER_description writes up to 128 bytes - let mut buf = vec![0u8; 256]; - let result = sys::SSL_CIPHER_description( - cipher, - buf.as_mut_ptr() as *mut libc::c_char, - buf.len() as i32, - ); - if result.is_null() { - return String::from("No description available"); - } - // Find the null terminator - let len = buf.iter().position(|&c| c == 0).unwrap_or(buf.len()); - String::from_utf8_lossy(&buf[..len]).trim().to_string() - } + #[pyfunction] + fn RAND_add(_string: PyObjectRef, _entropy: f64) { + // No-op: aws-lc-rs handles its own entropy + // Accept any type (str, bytes, bytearray) } - impl Read for SocketStream { - fn read(&mut self, buf: &mut [u8]) -> std::io::Result { - let mut socket: &PySocket = &self.0; - socket.read(buf) + #[pyfunction] + fn RAND_bytes(n: i64, vm: &VirtualMachine) -> PyResult { + use aws_lc_rs::rand::{SecureRandom, SystemRandom}; + + // Validate n is not negative + if n < 0 { + return Err(vm.new_value_error("num must be positive")); } + + let n_usize = n as usize; + let rng = SystemRandom::new(); + let mut buf = vec![0u8; n_usize]; + rng.fill(&mut buf) + .map_err(|_| vm.new_os_error("Failed to generate random bytes"))?; + Ok(PyBytesRef::from(vm.ctx.new_bytes(buf))) } - impl Write for SocketStream { - fn write(&mut self, buf: &[u8]) -> std::io::Result { - let mut socket: &PySocket = &self.0; - socket.write(buf) - } - fn flush(&mut self) -> std::io::Result<()> { - let mut socket: &PySocket = &self.0; - socket.flush() - } + #[pyfunction] + fn RAND_pseudo_bytes(n: i64, vm: &VirtualMachine) -> PyResult<(PyBytesRef, bool)> { + // In rustls/aws-lc-rs, all random bytes are cryptographically strong + let bytes = RAND_bytes(n, vm)?; + Ok((bytes, true)) } - #[cfg(target_os = "android")] - mod android { - use super::convert_openssl_error; - use crate::vm::{VirtualMachine, builtins::PyBaseExceptionRef}; - use openssl::{ - ssl::SslContextBuilder, - x509::{X509, store::X509StoreBuilder}, - }; - use std::{ - fs::{File, read_dir}, - io::Read, - path::Path, + /// Test helper to decode a certificate from a file path + /// + /// This is a simplified wrapper around cert_der_to_dict_helper that handles + /// file reading and PEM/DER auto-detection. Used by test suite. + #[pyfunction] + fn _test_decode_cert(path: PyStrRef, vm: &VirtualMachine) -> PyResult { + // Read certificate file + let cert_data = std::fs::read(path.as_str()).map_err(|e| { + vm.new_os_error(format!( + "Failed to read certificate file {}: {}", + path.as_str(), + e + )) + })?; + + // Auto-detect PEM vs DER format + let cert_der = if cert_data + .windows(27) + .any(|w| w == b"-----BEGIN CERTIFICATE-----") + { + // Parse PEM format + let mut cursor = std::io::Cursor::new(&cert_data); + rustls_pemfile::certs(&mut cursor) + .find_map(|r| r.ok()) + .ok_or_else(|| vm.new_value_error("No valid certificate found in PEM file"))? + .to_vec() + } else { + // Assume DER format + cert_data }; - static CERT_DIR: &'static str = "/system/etc/security/cacerts"; - - pub(super) fn load_client_ca_list( - vm: &VirtualMachine, - b: &mut SslContextBuilder, - ) -> Result<(), PyBaseExceptionRef> { - let root = Path::new(CERT_DIR); - if !root.is_dir() { - return Err(vm.new_exception_msg( - vm.ctx.exceptions.file_not_found_error.to_owned(), - CERT_DIR.to_string(), - )); - } + // Reuse the comprehensive helper function + cert::cert_der_to_dict_helper(vm, &cert_der) + } - let mut combined_pem = String::new(); - let entries = read_dir(root) - .map_err(|err| vm.new_os_error(format!("read cert root: {}", err)))?; - for entry in entries { - let entry = - entry.map_err(|err| vm.new_os_error(format!("iter cert root: {}", err)))?; + #[pyfunction] + fn DER_cert_to_PEM_cert(der_cert: ArgBytesLike, vm: &VirtualMachine) -> PyResult { + let der_bytes = der_cert.borrow_buf(); + let bytes_slice: &[u8] = der_bytes.as_ref(); - let path = entry.path(); - if !path.is_file() { - continue; - } + // Use pem-rfc7468 for RFC 7468 compliant PEM encoding + let pem_str = encode_string("CERTIFICATE", LineEnding::LF, bytes_slice) + .map_err(|e| vm.new_value_error(format!("PEM encoding failed: {e}")))?; - File::open(&path) - .and_then(|mut file| file.read_to_string(&mut combined_pem)) - .map_err(|err| { - vm.new_os_error(format!("open cert file {}: {}", path.display(), err)) - })?; + Ok(vm.ctx.new_str(pem_str)) + } - combined_pem.push('\n'); - } + #[pyfunction] + fn PEM_cert_to_DER_cert(pem_cert: PyStrRef, vm: &VirtualMachine) -> PyResult { + let pem_str = pem_cert.as_str(); - let mut store_b = - X509StoreBuilder::new().map_err(|err| convert_openssl_error(vm, err))?; - let x509_vec = X509::stack_from_pem(combined_pem.as_bytes()) - .map_err(|err| convert_openssl_error(vm, err))?; - for x509 in x509_vec { - store_b - .add_cert(x509) - .map_err(|err| convert_openssl_error(vm, err))?; - } - b.set_cert_store(store_b.build()); + // Parse PEM format + let mut cursor = std::io::Cursor::new(pem_str.as_bytes()); + let mut certs = rustls_pemfile::certs(&mut cursor); - Ok(()) + if let Some(Ok(cert)) = certs.next() { + Ok(vm.ctx.new_bytes(cert.to_vec())) + } else { + Err(vm.new_value_error("Failed to parse PEM certificate")) } } -} - -#[cfg(not(ossl101))] -#[pymodule(sub)] -mod ossl101 {} -#[cfg(not(ossl111))] -#[pymodule(sub)] -mod ossl111 {} - -#[cfg(not(windows))] -#[pymodule(sub)] -mod windows {} - -#[allow(non_upper_case_globals)] -#[cfg(ossl101)] -#[pymodule(sub)] -mod ossl101 { - #[pyattr] - use openssl_sys::{ - SSL_OP_NO_COMPRESSION as OP_NO_COMPRESSION, SSL_OP_NO_TLSv1_1 as OP_NO_TLSv1_1, - SSL_OP_NO_TLSv1_2 as OP_NO_TLSv1_2, - }; -} - -#[allow(non_upper_case_globals)] -#[cfg(ossl111)] -#[pymodule(sub)] -mod ossl111 { + // Certificate type for SSL module (pure Rust implementation) #[pyattr] - use openssl_sys::SSL_OP_NO_TLSv1_3 as OP_NO_TLSv1_3; -} - -#[cfg(windows)] -#[pymodule(sub)] -mod windows { - use crate::{ - common::ascii, - vm::{ - PyObjectRef, PyPayload, PyResult, VirtualMachine, - builtins::{PyFrozenSet, PyStrRef}, - convert::ToPyException, - }, - }; - - #[pyfunction] - fn enum_certificates(store_name: PyStrRef, vm: &VirtualMachine) -> PyResult> { - use schannel::{RawPointer, cert_context::ValidUses, cert_store::CertStore}; - use windows_sys::Win32::Security::Cryptography; - - // TODO: check every store for it, not just 2 of them: - // https://github.com/python/cpython/blob/3.8/Modules/_ssl.c#L5603-L5610 - let open_fns = [CertStore::open_current_user, CertStore::open_local_machine]; - let stores = open_fns - .iter() - .filter_map(|open| open(store_name.as_str()).ok()) - .collect::>(); - let certs = stores.iter().flat_map(|s| s.certs()).map(|c| { - let cert = vm.ctx.new_bytes(c.to_der().to_owned()); - let enc_type = unsafe { - let ptr = c.as_ptr() as *const Cryptography::CERT_CONTEXT; - (*ptr).dwCertEncodingType - }; - let enc_type = match enc_type { - Cryptography::X509_ASN_ENCODING => vm.new_pyobj(ascii!("x509_asn")), - Cryptography::PKCS_7_ASN_ENCODING => vm.new_pyobj(ascii!("pkcs_7_asn")), - other => vm.new_pyobj(other), - }; - let usage: PyObjectRef = match c.valid_uses().map_err(|e| e.to_pyexception(vm))? { - ValidUses::All => vm.ctx.new_bool(true).into(), - ValidUses::Oids(oids) => PyFrozenSet::from_iter( - vm, - oids.into_iter().map(|oid| vm.ctx.new_str(oid).into()), - )? - .into_ref(&vm.ctx) - .into(), - }; - Ok(vm.new_tuple((cert, enc_type, usage)).into()) - }); - let certs: Vec = certs.collect::>>()?; - Ok(certs) + #[pyclass(module = "_ssl", name = "Certificate")] + #[derive(Debug, PyPayload)] + pub struct PySSLCertificate { + // Store the raw DER bytes + der_bytes: Vec, } -} - -mod bio { - //! based off rust-openssl's private `bio` module - use libc::c_int; - use openssl::error::ErrorStack; - use openssl_sys as sys; - use std::marker::PhantomData; + impl PySSLCertificate { + // Parse the certificate lazily + fn parse(&self) -> Result, String> { + match x509_parser::parse_x509_certificate(&self.der_bytes) { + Ok((_, cert)) => Ok(cert), + Err(e) => Err(format!("Failed to parse certificate: {e}")), + } + } + } - pub struct MemBioSlice<'a>(*mut sys::BIO, PhantomData<&'a [u8]>); + #[pyclass(with(Comparable, Hashable, Representable))] + impl PySSLCertificate { + #[pymethod] + fn public_bytes( + &self, + format: OptionalArg, + vm: &VirtualMachine, + ) -> PyResult { + let format = format.unwrap_or(ENCODING_PEM); - impl Drop for MemBioSlice<'_> { - fn drop(&mut self) { - unsafe { - sys::BIO_free_all(self.0); + match format { + x if x == ENCODING_DER => { + // Return DER bytes directly + Ok(vm.ctx.new_bytes(self.der_bytes.clone()).into()) + } + x if x == ENCODING_PEM => { + // Convert DER to PEM using RFC 7468 compliant encoding + let pem_str = encode_string("CERTIFICATE", LineEnding::LF, &self.der_bytes) + .map_err(|e| vm.new_value_error(format!("PEM encoding failed: {e}")))?; + Ok(vm.ctx.new_str(pem_str).into()) + } + _ => Err(vm.new_value_error("Unsupported format")), } } + + #[pymethod] + fn get_info(&self, vm: &VirtualMachine) -> PyResult { + let cert = self.parse().map_err(|e| vm.new_value_error(e))?; + cert::cert_to_dict(vm, &cert) + } } - impl<'a> MemBioSlice<'a> { - pub fn new(buf: &'a [u8]) -> Result, ErrorStack> { - openssl::init(); + // Implement Comparable trait for PySSLCertificate + impl Comparable for PySSLCertificate { + fn cmp( + zelf: &Py, + other: &PyObject, + op: PyComparisonOp, + _vm: &VirtualMachine, + ) -> PyResult { + op.eq_only(|| { + if let Some(other_cert) = other.downcast_ref::() { + Ok((zelf.der_bytes == other_cert.der_bytes).into()) + } else { + Ok(PyComparisonValue::NotImplemented) + } + }) + } + } - assert!(buf.len() <= c_int::MAX as usize); - let bio = unsafe { sys::BIO_new_mem_buf(buf.as_ptr() as *const _, buf.len() as c_int) }; - if bio.is_null() { - return Err(ErrorStack::get()); - } + // Implement Hashable trait for PySSLCertificate + impl Hashable for PySSLCertificate { + fn hash(zelf: &Py, _vm: &VirtualMachine) -> PyResult { + use std::collections::hash_map::DefaultHasher; + use std::hash::{Hash, Hasher}; - Ok(MemBioSlice(bio, PhantomData)) + let mut hasher = DefaultHasher::new(); + zelf.der_bytes.hash(&mut hasher); + Ok(hasher.finish() as PyHash) } + } - pub fn as_ptr(&self) -> *mut sys::BIO { - self.0 + // Implement Representable trait for PySSLCertificate + impl Representable for PySSLCertificate { + #[inline] + fn repr_str(zelf: &Py, _vm: &VirtualMachine) -> PyResult { + // Try to parse and show subject + match zelf.parse() { + Ok(cert) => { + let subject = cert.subject(); + // Get CN if available + let cn = subject + .iter_common_name() + .next() + .and_then(|attr| attr.as_str().ok()) + .unwrap_or("Unknown"); + Ok(format!("")) + } + Err(_) => Ok("".to_owned()), + } } } } diff --git a/stdlib/src/ssl/cert.rs b/stdlib/src/ssl/cert.rs index 19dd09f3379..2baefad700b 100644 --- a/stdlib/src/ssl/cert.rs +++ b/stdlib/src/ssl/cert.rs @@ -1,232 +1,1774 @@ -pub(super) use ssl_cert::{PySSLCertificate, cert_to_certificate, cert_to_py, obj2txt}; - -// Certificate type for SSL module - -#[pymodule(sub)] -pub(crate) mod ssl_cert { - use crate::{ - common::ascii, - vm::{ - PyObjectRef, PyPayload, PyResult, VirtualMachine, - convert::{ToPyException, ToPyObject}, - function::{FsPath, OptionalArg}, - }, +// cspell: ignore accessdescs + +//! Certificate parsing, validation, and conversion utilities for SSL/TLS +//! +//! This module provides reusable functions for working with X.509 certificates: +//! - Parsing PEM/DER encoded certificates +//! - Validating certificate properties (CA status, etc.) +//! - Converting certificates to Python dict format +//! - Building and verifying certificate chains +//! - Loading certificates from files, directories, and bytes + +use chrono::{DateTime, Utc}; +use parking_lot::RwLock as ParkingRwLock; +use rustls::{ + DigitallySignedStruct, RootCertStore, SignatureScheme, + client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier}, + pki_types::{CertificateDer, PrivateKeyDer, ServerName, UnixTime}, + server::danger::{ClientCertVerified, ClientCertVerifier}, +}; +use rustpython_vm::{PyObjectRef, PyResult, VirtualMachine}; +use std::collections::HashSet; +use std::sync::Arc; +use x509_parser::prelude::*; + +use super::compat::{VERIFY_X509_PARTIAL_CHAIN, VERIFY_X509_STRICT}; + +// Certificate Verification Constants + +/// All supported signature schemes for certificate verification +/// +/// This list includes all modern signature algorithms supported by rustls. +/// Used by verifiers that accept any signature scheme (NoVerifier, EmptyRootStoreVerifier). +const ALL_SIGNATURE_SCHEMES: &[SignatureScheme] = &[ + SignatureScheme::RSA_PKCS1_SHA256, + SignatureScheme::RSA_PKCS1_SHA384, + SignatureScheme::RSA_PKCS1_SHA512, + SignatureScheme::ECDSA_NISTP256_SHA256, + SignatureScheme::ECDSA_NISTP384_SHA384, + SignatureScheme::ECDSA_NISTP521_SHA512, + SignatureScheme::RSA_PSS_SHA256, + SignatureScheme::RSA_PSS_SHA384, + SignatureScheme::RSA_PSS_SHA512, + SignatureScheme::ED25519, +]; + +// Error Handling Utilities + +/// Certificate loading error types with specific error messages +/// +/// This module provides consistent error creation functions for certificate +/// operations, reducing code duplication and ensuring uniform error messages +/// across the codebase. +mod cert_error { + use std::io; + use std::sync::Arc; + + /// Create InvalidData error with formatted message + pub fn invalid_data(msg: impl Into) -> io::Error { + io::Error::new(io::ErrorKind::InvalidData, msg.into()) + } + + /// PEM parsing error variants + pub mod pem { + use super::*; + + pub fn no_start_line(context: &str) -> io::Error { + invalid_data(format!("no start line: {context}")) + } + + pub fn parse_failed(e: impl std::fmt::Display) -> io::Error { + invalid_data(format!("Failed to parse PEM certificate: {e}")) + } + + pub fn parse_failed_debug(e: impl std::fmt::Debug) -> io::Error { + invalid_data(format!("Failed to parse PEM certificate: {e:?}")) + } + + pub fn invalid_cert() -> io::Error { + invalid_data("No certificates found in certificate file") + } + } + + /// DER parsing error variants + pub mod der { + use super::*; + + pub fn not_enough_data(context: &str) -> io::Error { + invalid_data(format!("not enough data: {context}")) + } + + pub fn parse_failed(e: impl std::fmt::Display) -> io::Error { + invalid_data(format!("Failed to parse DER certificate: {e}")) + } + } + + /// Private key error variants + pub mod key { + use super::*; + + pub fn not_found(context: &str) -> io::Error { + invalid_data(format!("No private key found in {context}")) + } + + pub fn parse_failed(e: impl std::fmt::Display) -> io::Error { + invalid_data(format!("Failed to parse private key: {e}")) + } + + pub fn parse_encrypted_failed(e: impl std::fmt::Display) -> io::Error { + invalid_data(format!("Failed to parse encrypted private key: {e}")) + } + + pub fn decrypt_failed(e: impl std::fmt::Display) -> io::Error { + io::Error::other(format!( + "Failed to decrypt private key (wrong password?): {e}", + )) + } + } + + /// Convert error message to rustls::Error with InvalidCertificate wrapper + pub fn to_rustls_invalid_cert(msg: impl Into) -> rustls::Error { + rustls::Error::InvalidCertificate(rustls::CertificateError::Other(rustls::OtherError( + Arc::new(invalid_data(msg)), + ))) + } + + /// Convert error message to rustls::Error with InvalidCertificate wrapper and custom ErrorKind + pub fn to_rustls_cert_error(kind: io::ErrorKind, msg: impl Into) -> rustls::Error { + rustls::Error::InvalidCertificate(rustls::CertificateError::Other(rustls::OtherError( + Arc::new(io::Error::new(kind, msg.into())), + ))) + } +} + +// Helper Functions for Certificate Parsing + +/// Map X.509 OID to human-readable attribute name +/// +/// Converts common X.509 Distinguished Name OIDs to their standard names. +/// Returns the OID string itself if not recognized. +fn oid_to_attribute_name(oid_str: &str) -> &str { + match oid_str { + "2.5.4.3" => "commonName", + "2.5.4.6" => "countryName", + "2.5.4.7" => "localityName", + "2.5.4.8" => "stateOrProvinceName", + "2.5.4.10" => "organizationName", + "2.5.4.11" => "organizationalUnitName", + "1.2.840.113549.1.9.1" => "emailAddress", + _ => oid_str, + } +} + +/// Format IP address (IPv4 or IPv6) to string +/// +/// Formats raw IP address bytes according to standard notation: +/// - IPv4: dotted decimal (e.g., "192.0.2.1") +/// - IPv6: colon-separated hex (e.g., "2001:DB8:0:0:0:0:0:1") +fn format_ip_address(ip: &[u8]) -> String { + if ip.len() == 4 { + // IPv4 + format!("{}.{}.{}.{}", ip[0], ip[1], ip[2], ip[3]) + } else if ip.len() == 16 { + // IPv6 - format in full form without compression (uppercase) + // CPython returns IPv6 in full form: 2001:DB8:0:0:0:0:0:1 (not 2001:db8::1) + let segments = [ + u16::from_be_bytes([ip[0], ip[1]]), + u16::from_be_bytes([ip[2], ip[3]]), + u16::from_be_bytes([ip[4], ip[5]]), + u16::from_be_bytes([ip[6], ip[7]]), + u16::from_be_bytes([ip[8], ip[9]]), + u16::from_be_bytes([ip[10], ip[11]]), + u16::from_be_bytes([ip[12], ip[13]]), + u16::from_be_bytes([ip[14], ip[15]]), + ]; + format!( + "{:X}:{:X}:{:X}:{:X}:{:X}:{:X}:{:X}:{:X}", + segments[0], + segments[1], + segments[2], + segments[3], + segments[4], + segments[5], + segments[6], + segments[7] + ) + } else { + // Unknown format - return as debug string + format!("{ip:?}") + } +} + +/// Format ASN.1 time to string +/// +/// Formats certificate validity dates in the format: +/// "Mon DD HH:MM:SS YYYY GMT" +fn format_asn1_time(time: &x509_parser::time::ASN1Time) -> String { + let timestamp = time.timestamp(); + DateTime::::from_timestamp(timestamp, 0) + .expect("ASN1Time must be valid timestamp") + .format("%b %e %H:%M:%S %Y GMT") + .to_string() +} + +/// Format certificate serial number to hexadecimal string with even padding +/// +/// Converts a BigUint serial number to uppercase hex string, ensuring +/// even length by prepending '0' if necessary. +fn format_serial_number(serial: &num_bigint::BigUint) -> String { + let mut serial_str = serial.to_str_radix(16).to_uppercase(); + if serial_str.len() % 2 == 1 { + serial_str.insert(0, '0'); + } + serial_str +} + +/// Normalize wildcard hostname by stripping "*." prefix +/// +/// Returns the normalized hostname without the wildcard prefix. +/// Used for wildcard certificate matching. +fn normalize_wildcard_hostname(hostname: &str) -> &str { + hostname.strip_prefix("*.").unwrap_or(hostname) +} + +/// Process Subject Alternative Name (SAN) general names into Python tuples +/// +/// Converts X.509 GeneralName entries into Python tuple format. +/// Returns a vector of PyObjectRef tuples in the format: (type, value) +fn process_san_general_names( + vm: &VirtualMachine, + general_names: &[GeneralName<'_>], +) -> Vec { + general_names + .iter() + .filter_map(|name| match name { + GeneralName::DNSName(dns) => Some(vm.new_tuple(("DNS", *dns)).into()), + GeneralName::IPAddress(ip) => { + let ip_str = format_ip_address(ip); + Some(vm.new_tuple(("IP Address", ip_str)).into()) + } + GeneralName::RFC822Name(email) => Some(vm.new_tuple(("email", *email)).into()), + GeneralName::URI(uri) => Some(vm.new_tuple(("URI", *uri)).into()), + GeneralName::DirectoryName(dn) => { + let dn_str = format!("{dn}"); + Some(vm.new_tuple(("DirName", dn_str)).into()) + } + GeneralName::RegisteredID(oid) => { + let oid_str = oid.to_string(); + Some(vm.new_tuple(("Registered ID", oid_str)).into()) + } + GeneralName::OtherName(oid, value) => { + let oid_str = oid.to_string(); + let value_str = format!("{value:?}"); + Some( + vm.new_tuple(("othername", format!("{oid_str}:{value_str}"))) + .into(), + ) + } + _ => None, + }) + .collect() +} + +// Certificate Validation and Parsing + +/// Check if a certificate is a CA certificate by examining the Basic Constraints extension +/// +/// Returns `true` if the certificate has Basic Constraints with CA=true, +/// `false` otherwise (including parse errors or missing extension). +/// This matches OpenSSL's X509_check_ca() behavior. +pub fn is_ca_certificate(cert_der: &[u8]) -> bool { + // Parse the certificate + let Ok((_, cert)) = X509Certificate::from_der(cert_der) else { + return false; }; - use foreign_types_shared::ForeignTypeRef; - use openssl::{ - asn1::Asn1ObjectRef, - x509::{self, X509, X509Ref}, + + // Check Basic Constraints extension + // If extension exists and CA=true, it's a CA certificate + // Otherwise (no extension or CA=false), it's NOT a CA certificate + if let Ok(Some(ext)) = cert.basic_constraints() { + return ext.value.ca; + } + + // No Basic Constraints extension -> NOT a CA certificate + // (matches OpenSSL X509_check_ca() behavior) + false +} + +/// Convert an X509Name to Python nested tuple format for SSL certificate dicts +/// +/// Format: ((('CN', 'example.com'),), (('O', 'Example Org'),), ...) +fn name_to_py(vm: &VirtualMachine, name: &x509_parser::x509::X509Name<'_>) -> PyResult { + let list: Vec = name + .iter() + .flat_map(|rdn| { + // Each RDN can have multiple attributes + rdn.iter() + .map(|attr| { + let oid_str = attr.attr_type().to_id_string(); + let value_str = attr.attr_value().as_str().unwrap_or("").to_string(); + let key = oid_to_attribute_name(&oid_str); + + vm.new_tuple((vm.new_tuple((vm.ctx.new_str(key), vm.ctx.new_str(value_str))),)) + .into() + }) + .collect::>() + }) + .collect(); + + Ok(vm.ctx.new_tuple(list).into()) +} + +/// Convert DER-encoded certificate to Python dict (for getpeercert with binary_form=False) +/// +/// Returns a dict with fields: subject, issuer, version, serialNumber, +/// notBefore, notAfter, subjectAltName (if present) +pub fn cert_to_dict( + vm: &VirtualMachine, + cert: &x509_parser::certificate::X509Certificate<'_>, +) -> PyResult { + let dict = vm.ctx.new_dict(); + + // Subject and Issuer + dict.set_item("subject", name_to_py(vm, cert.subject())?, vm)?; + dict.set_item("issuer", name_to_py(vm, cert.issuer())?, vm)?; + + // Version (X.509 v3 = version 2 in the cert, but Python uses 3) + dict.set_item( + "version", + vm.ctx.new_int(cert.version().0 as i32 + 1).into(), + vm, + )?; + + // Serial number - hex format with even length + let serial = format_serial_number(&cert.serial); + dict.set_item("serialNumber", vm.ctx.new_str(serial).into(), vm)?; + + // Validity dates - format with GMT using chrono + dict.set_item( + "notBefore", + vm.ctx + .new_str(format_asn1_time(&cert.validity().not_before)) + .into(), + vm, + )?; + dict.set_item( + "notAfter", + vm.ctx + .new_str(format_asn1_time(&cert.validity().not_after)) + .into(), + vm, + )?; + + // Subject Alternative Names (if present) + if let Ok(Some(san_ext)) = cert.subject_alternative_name() { + let san_list = process_san_general_names(vm, &san_ext.value.general_names); + + if !san_list.is_empty() { + dict.set_item("subjectAltName", vm.ctx.new_tuple(san_list).into(), vm)?; + } + } + + Ok(dict.into()) +} + +/// Convert DER-encoded certificate to Python dict (for get_ca_certs) +/// +/// Similar to cert_to_dict but includes additional fields like crlDistributionPoints +/// and uses CPython's specific ordering: issuer, notAfter, notBefore, serialNumber, subject, version +pub fn cert_der_to_dict_helper(vm: &VirtualMachine, cert_der: &[u8]) -> PyResult { + // Parse the certificate using x509-parser + let (_, cert) = x509_parser::parse_x509_certificate(cert_der) + .map_err(|e| vm.new_value_error(format!("Failed to parse certificate: {e}")))?; + + // Helper to convert X509Name to nested tuple format + let name_to_tuple = |name: &x509_parser::x509::X509Name<'_>| -> PyResult { + let mut entries = Vec::new(); + for rdn in name.iter() { + for attr in rdn.iter() { + let oid_str = attr.attr_type().to_id_string(); + + // Get value as bytes and convert to string + let value_str = if let Ok(s) = attr.attr_value().as_str() { + s.to_string() + } else { + let value_bytes = attr.attr_value().data; + match std::str::from_utf8(value_bytes) { + Ok(s) => s.to_string(), + Err(_) => String::from_utf8_lossy(value_bytes).into_owned(), + } + }; + + let key = oid_to_attribute_name(&oid_str); + + let entry = + vm.new_tuple((vm.ctx.new_str(key.to_string()), vm.ctx.new_str(value_str))); + entries.push(vm.new_tuple((entry,)).into()); + } + } + Ok(vm.ctx.new_tuple(entries).into()) }; - use openssl_sys as sys; - use std::fmt; - - // Import constants and error converter from _ssl module - use crate::ssl::_ssl::{ENCODING_DER, ENCODING_PEM, convert_openssl_error}; - - pub(crate) fn obj2txt(obj: &Asn1ObjectRef, no_name: bool) -> Option { - let no_name = i32::from(no_name); - let ptr = obj.as_ptr(); - let b = unsafe { - let buflen = sys::OBJ_obj2txt(std::ptr::null_mut(), 0, ptr, no_name); - assert!(buflen >= 0); - if buflen == 0 { - return None; - } - let buflen = buflen as usize; - let mut buf = Vec::::with_capacity(buflen + 1); - let ret = sys::OBJ_obj2txt( - buf.as_mut_ptr() as *mut libc::c_char, - buf.capacity() as _, - ptr, - no_name, - ); - assert!(ret >= 0); - // SAFETY: OBJ_obj2txt initialized the buffer successfully - buf.set_len(buflen); - buf + + let dict = vm.ctx.new_dict(); + + // CPython ordering: issuer, notAfter, notBefore, serialNumber, subject, version + dict.set_item("issuer", name_to_tuple(cert.issuer())?, vm)?; + + // Validity - format with GMT using chrono + dict.set_item( + "notAfter", + vm.ctx + .new_str(format_asn1_time(&cert.validity().not_after)) + .into(), + vm, + )?; + dict.set_item( + "notBefore", + vm.ctx + .new_str(format_asn1_time(&cert.validity().not_before)) + .into(), + vm, + )?; + + // Serial number - hex format with even length + let serial = format_serial_number(&cert.serial); + dict.set_item("serialNumber", vm.ctx.new_str(serial).into(), vm)?; + + dict.set_item("subject", name_to_tuple(cert.subject())?, vm)?; + + // Version + dict.set_item( + "version", + vm.ctx.new_int(cert.version().0 as i32 + 1).into(), + vm, + )?; + + // Authority Information Access (OCSP and caIssuers) - use x509-parser's extensions_map + let mut ocsp_urls = Vec::new(); + let mut ca_issuer_urls = Vec::new(); + let mut crl_urls = Vec::new(); + + if let Ok(ext_map) = cert.tbs_certificate.extensions_map() { + use x509_parser::extensions::{GeneralName, ParsedExtension}; + use x509_parser::oid_registry::{ + OID_PKIX_AUTHORITY_INFO_ACCESS, OID_X509_EXT_CRL_DISTRIBUTION_POINTS, }; - let s = String::from_utf8(b) - .unwrap_or_else(|e| String::from_utf8_lossy(e.as_bytes()).into_owned()); - Some(s) + + // Authority Information Access + if let Some(ext) = ext_map.get(&OID_PKIX_AUTHORITY_INFO_ACCESS) + && let ParsedExtension::AuthorityInfoAccess(aia) = &ext.parsed_extension() + { + for desc in &aia.accessdescs { + if let GeneralName::URI(uri) = &desc.access_location { + let method_str = desc.access_method.to_id_string(); + if method_str == "1.3.6.1.5.5.7.48.1" { + // OCSP + ocsp_urls.push(vm.ctx.new_str(uri.to_string()).into()); + } else if method_str == "1.3.6.1.5.5.7.48.2" { + // caIssuers + ca_issuer_urls.push(vm.ctx.new_str(uri.to_string()).into()); + } + } + } + } + + // CRL Distribution Points + if let Some(ext) = ext_map.get(&OID_X509_EXT_CRL_DISTRIBUTION_POINTS) + && let ParsedExtension::CRLDistributionPoints(cdp) = &ext.parsed_extension() + { + for dp in cdp.points.iter() { + if let Some(dist_point) = &dp.distribution_point { + use x509_parser::extensions::DistributionPointName; + if let DistributionPointName::FullName(names) = dist_point { + for name in names { + if let GeneralName::URI(uri) = name { + crl_urls.push(vm.ctx.new_str(uri.to_string()).into()); + } + } + } + } + } + } } - #[pyattr] - #[pyclass(module = "ssl", name = "Certificate")] - #[derive(PyPayload)] - pub(crate) struct PySSLCertificate { - cert: X509, + if !ocsp_urls.is_empty() { + dict.set_item("OCSP", vm.ctx.new_tuple(ocsp_urls).into(), vm)?; + } + if !ca_issuer_urls.is_empty() { + dict.set_item("caIssuers", vm.ctx.new_tuple(ca_issuer_urls).into(), vm)?; + } + if !crl_urls.is_empty() { + dict.set_item( + "crlDistributionPoints", + vm.ctx.new_tuple(crl_urls).into(), + vm, + )?; } - impl fmt::Debug for PySSLCertificate { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.pad("Certificate") + // Subject Alternative Names + if let Ok(Some(san_ext)) = cert.subject_alternative_name() { + let mut san_entries = Vec::new(); + for name in &san_ext.value.general_names { + use x509_parser::extensions::GeneralName; + match name { + GeneralName::DNSName(dns) => { + san_entries.push(vm.new_tuple(("DNS", *dns)).into()); + } + GeneralName::IPAddress(ip) => { + let ip_str = format_ip_address(ip); + san_entries.push(vm.new_tuple(("IP Address", ip_str)).into()); + } + GeneralName::RFC822Name(email) => { + san_entries.push(vm.new_tuple(("email", *email)).into()); + } + GeneralName::URI(uri) => { + san_entries.push(vm.new_tuple(("URI", *uri)).into()); + } + GeneralName::OtherName(_oid, _data) => { + // OtherName is not fully supported, mark as unsupported + san_entries.push(vm.new_tuple(("othername", "")).into()); + } + GeneralName::DirectoryName(name) => { + // Convert X509Name to nested tuple format + let dir_tuple = name_to_tuple(name)?; + san_entries.push(vm.new_tuple(("DirName", dir_tuple)).into()); + } + GeneralName::RegisteredID(oid) => { + // Convert OID to string representation + let oid_str = oid.to_id_string(); + san_entries.push(vm.new_tuple(("Registered ID", oid_str)).into()); + } + _ => {} + } + } + if !san_entries.is_empty() { + dict.set_item("subjectAltName", vm.ctx.new_tuple(san_entries).into(), vm)?; } } - #[pyclass] - impl PySSLCertificate { - #[pymethod] - fn public_bytes( - &self, - format: OptionalArg, - vm: &VirtualMachine, - ) -> PyResult { - let format = format.unwrap_or(ENCODING_PEM); + Ok(dict.into()) +} + +/// Build a verified certificate chain by adding CA certificates from the trust store +/// +/// Takes peer certificates (from TLS handshake) and extends the chain by finding +/// issuer certificates from the trust store until reaching a root certificate. +/// +/// Returns the complete chain as DER-encoded bytes. +pub fn build_verified_chain( + peer_certs: &[CertificateDer<'static>], + ca_certs_der: &[Vec], +) -> Vec> { + let mut chain_der: Vec> = Vec::new(); + + // Start with peer certificates (what was sent during handshake) + for cert in peer_certs { + chain_der.push(cert.as_ref().to_vec()); + } + + // Keep adding issuers until we reach a root or can't find the issuer + while let Some(der) = chain_der.last() { + let last_cert_der = der; + + // Parse the last certificate in the chain + let (_, last_cert) = match X509Certificate::from_der(last_cert_der) { + Ok(parsed) => parsed, + Err(_) => break, + }; + + // Check if it's self-signed (root certificate) + if last_cert.subject() == last_cert.issuer() { + // This is a root certificate, we're done + break; + } - match format { - x if x == ENCODING_DER => { - // DER encoding - let der = self - .cert - .to_der() - .map_err(|e| convert_openssl_error(vm, e))?; - Ok(vm.ctx.new_bytes(der).into()) + // Try to find the issuer in the trust store + let issuer_name = last_cert.issuer(); + let mut found_issuer = false; + + for ca_der in ca_certs_der.iter() { + let (_, ca_cert) = match X509Certificate::from_der(ca_der) { + Ok(parsed) => parsed, + Err(_) => continue, + }; + + // Check if this CA's subject matches the certificate's issuer + if ca_cert.subject() == issuer_name { + // Check if we already have this certificate in the chain + if !chain_der.iter().any(|existing| existing == ca_der) { + chain_der.push(ca_der.clone()); + found_issuer = true; + break; } - x if x == ENCODING_PEM => { - // PEM encoding - let pem = self - .cert - .to_pem() - .map_err(|e| convert_openssl_error(vm, e))?; - Ok(vm.ctx.new_bytes(pem).into()) + } + } + + if !found_issuer { + // Can't find issuer, stop here + break; + } + } + + chain_der +} + +/// Statistics from certificate loading operations +#[derive(Debug, Clone, Default)] +pub struct CertStats { + pub total_certs: usize, + pub ca_certs: usize, +} + +/// Certificate loader that handles PEM/DER parsing and validation +/// +/// This structure encapsulates the common pattern of loading certificates +/// from various sources (files, directories, bytes) and adding them to +/// a RootCertStore while tracking statistics. +/// +/// Duplicate certificates are detected and only counted once. +pub struct CertLoader<'a> { + store: &'a mut RootCertStore, + ca_certs_der: &'a mut Vec>, + seen_certs: HashSet>, +} + +impl<'a> CertLoader<'a> { + /// Create a new CertLoader with references to the store and DER cache + pub fn new(store: &'a mut RootCertStore, ca_certs_der: &'a mut Vec>) -> Self { + // Initialize seen_certs with existing certificates + let seen_certs = ca_certs_der.iter().cloned().collect(); + Self { + store, + ca_certs_der, + seen_certs, + } + } + + /// Load certificates from a file (supports both PEM and DER formats) + /// + /// Returns statistics about loaded certificates + pub fn load_from_file(&mut self, path: &str) -> Result { + let contents = std::fs::read(path)?; + self.load_from_bytes(&contents) + } + + /// Load certificates from a directory + /// + /// Reads all files in the directory and attempts to parse them as certificates. + /// Invalid files are silently skipped (matches OpenSSL capath behavior). + pub fn load_from_dir(&mut self, dir_path: &str) -> Result { + let entries = std::fs::read_dir(dir_path)?; + let mut stats = CertStats::default(); + + for entry in entries { + let entry = entry?; + let path = entry.path(); + + // Skip directories and process all files + // OpenSSL capath uses hash-based naming like "4e1295a3.0" + if path.is_file() + && let Ok(contents) = std::fs::read(&path) + { + // Ignore errors for individual files (some may not be certs) + if let Ok(file_stats) = self.load_from_bytes(&contents) { + stats.total_certs += file_stats.total_certs; + stats.ca_certs += file_stats.ca_certs; } - _ => Err(vm.new_value_error("Unsupported format".to_owned())), } } - #[pymethod] - fn get_info(&self, vm: &VirtualMachine) -> PyResult { - cert_to_dict(vm, &self.cert) + Ok(stats) + } + + /// Helper: Add a certificate to the store with duplicate checking + /// + /// Returns true if the certificate was added (not a duplicate), false if it was a duplicate. + fn add_cert_to_store( + &mut self, + cert_bytes: Vec, + cert_der: CertificateDer<'static>, + treat_all_as_ca: bool, + stats: &mut CertStats, + ) -> bool { + // Check for duplicates using HashSet + if !self.seen_certs.insert(cert_bytes.clone()) { + return false; // Duplicate certificate - skip + } + + // Determine if this is a CA certificate + let is_ca = if treat_all_as_ca { + true + } else { + is_ca_certificate(&cert_bytes) + }; + + // Store full DER for get_ca_certs() + self.ca_certs_der.push(cert_bytes); + + // Add to trust store (rustls may handle duplicates internally) + let _ = self.store.add(cert_der); + + // Update statistics + stats.total_certs += 1; + if is_ca { + stats.ca_certs += 1; } + + true } - fn name_to_py(vm: &VirtualMachine, name: &x509::X509NameRef) -> PyResult { - let list = name - .entries() - .map(|entry| { - let txt = obj2txt(entry.object(), false).to_pyobject(vm); - let asn1_str = entry.data(); - let data_bytes = asn1_str.as_slice(); - let data = match std::str::from_utf8(data_bytes) { - Ok(s) => vm.ctx.new_str(s.to_owned()), - Err(_) => vm - .ctx - .new_str(String::from_utf8_lossy(data_bytes).into_owned()), - }; - Ok(vm.new_tuple(((txt, data),)).into()) - }) - .collect::>()?; - Ok(vm.ctx.new_tuple(list).into()) - } - - // Helper to convert X509 to dict (for getpeercert with binary=False) - fn cert_to_dict(vm: &VirtualMachine, cert: &X509Ref) -> PyResult { - let dict = vm.ctx.new_dict(); - - dict.set_item("subject", name_to_py(vm, cert.subject_name())?, vm)?; - dict.set_item("issuer", name_to_py(vm, cert.issuer_name())?, vm)?; - // X.509 version: OpenSSL uses 0-based (0=v1, 1=v2, 2=v3) but Python uses 1-based (1=v1, 2=v2, 3=v3) - dict.set_item("version", vm.new_pyobj(cert.version() + 1), vm)?; - - let serial_num = cert - .serial_number() - .to_bn() - .and_then(|bn| bn.to_hex_str()) - .map_err(|e| convert_openssl_error(vm, e))?; - dict.set_item( - "serialNumber", - vm.ctx.new_str(serial_num.to_owned()).into(), - vm, - )?; + /// Load certificates from byte slice (auto-detects PEM vs DER format) + /// + /// Tries to parse as PEM first, falls back to DER if that fails. + /// Duplicate certificates are detected and only counted once. + /// + /// If `treat_all_as_ca` is true, all certificates are counted as CA certificates + /// regardless of their Basic Constraints (this matches + /// load_verify_locations with cadata parameter). + /// + /// If `pem_only` is true, only PEM parsing is attempted (for string input) + pub fn load_from_bytes_ex( + &mut self, + data: &[u8], + treat_all_as_ca: bool, + pem_only: bool, + ) -> Result { + let mut stats = CertStats::default(); - dict.set_item( - "notBefore", - vm.ctx.new_str(cert.not_before().to_string()).into(), - vm, - )?; - dict.set_item( - "notAfter", - vm.ctx.new_str(cert.not_after().to_string()).into(), - vm, - )?; + // Try to parse as PEM first + let mut cursor = std::io::Cursor::new(data); + let certs_iter = rustls_pemfile::certs(&mut cursor); + + let mut found_any = false; + let mut first_pem_error = None; // Store first PEM parsing error + for cert_result in certs_iter { + match cert_result { + Ok(cert) => { + found_any = true; + let cert_bytes = cert.to_vec(); + + // Validate that this is actually a valid X.509 certificate + // rustls_pemfile only does base64 decoding, not X.509 validation + if let Err(e) = X509Certificate::from_der(&cert_bytes) { + // Invalid X.509 certificate + return Err(cert_error::pem::parse_failed_debug(e)); + } + + // Add certificate using helper method (handles duplicates) + self.add_cert_to_store(cert_bytes, cert, treat_all_as_ca, &mut stats); + // Helper returns false for duplicates (skip counting) + } + Err(e) if !found_any => { + // PEM parsing failed on first certificate + if pem_only { + // For string input (PEM only), return "no start line" error + return Err(cert_error::pem::no_start_line( + "cadata does not contain a certificate", + )); + } + // Store the error and break to try DER format below + first_pem_error = Some(e); + break; + } + Err(e) => { + // PEM parsing failed after some certs were loaded + return Err(cert_error::pem::parse_failed(e)); + } + } + } + + // If PEM parsing found nothing, try DER format (unless pem_only) + // DER can have multiple certificates concatenated, so parse them sequentially + if !found_any && stats.total_certs == 0 { + // If we had a PEM parsing error, return it instead of trying DER fallback + // This ensures that malformed PEM files (like badcert.pem) raise an error + if let Some(e) = first_pem_error { + return Err(cert_error::pem::parse_failed(e)); + } - if let Some(names) = cert.subject_alt_names() { - let san: Vec = names - .iter() - .map(|gen_name| { - if let Some(email) = gen_name.email() { - vm.new_tuple((ascii!("email"), email)).into() - } else if let Some(dnsname) = gen_name.dnsname() { - vm.new_tuple((ascii!("DNS"), dnsname)).into() - } else if let Some(ip) = gen_name.ipaddress() { - // Parse IP address properly (IPv4 or IPv6) - let ip_str = if ip.len() == 4 { - // IPv4 - format!("{}.{}.{}.{}", ip[0], ip[1], ip[2], ip[3]) - } else if ip.len() == 16 { - // IPv6 - format with all zeros visible (not compressed) - let ip_addr = std::net::Ipv6Addr::from([ - ip[0], ip[1], ip[2], ip[3], ip[4], ip[5], ip[6], ip[7], ip[8], - ip[9], ip[10], ip[11], ip[12], ip[13], ip[14], ip[15], - ]); - let s = ip_addr.segments(); - format!( - "{:X}:{:X}:{:X}:{:X}:{:X}:{:X}:{:X}:{:X}", - s[0], s[1], s[2], s[3], s[4], s[5], s[6], s[7] - ) + // For PEM-only mode (string input), don't fallback to DER + if pem_only { + return Err(cert_error::pem::no_start_line( + "cadata does not contain a certificate", + )); + } + let mut remaining = data; + let mut loaded_count = 0; + + while !remaining.is_empty() { + match X509Certificate::from_der(remaining) { + Ok((rest, _parsed_cert)) => { + // Extract the DER bytes for this certificate + // Length = total remaining - bytes left after parsing + let cert_len = remaining.len() - rest.len(); + let cert_bytes = &remaining[..cert_len]; + let cert_der = CertificateDer::from(cert_bytes.to_vec()); + + // Add certificate using helper method (handles duplicates) + self.add_cert_to_store( + cert_bytes.to_vec(), + cert_der, + treat_all_as_ca, + &mut stats, + ); + + loaded_count += 1; + remaining = rest; // Move to next certificate + } + Err(e) => { + if loaded_count == 0 { + // Failed to parse first certificate - invalid data + return Err(cert_error::der::not_enough_data( + "cadata does not contain a certificate", + )); } else { - // Fallback for unexpected length - String::from_utf8_lossy(ip).into_owned() - }; - vm.new_tuple((ascii!("IP Address"), ip_str)).into() - } else if let Some(uri) = gen_name.uri() { - vm.new_tuple((ascii!("URI"), uri)).into() - } else { - // Handle DirName, Registered ID, and othername - // Check if this is a directory name - if let Some(dirname) = gen_name.directory_name() - && let Ok(py_name) = name_to_py(vm, dirname) - { - return vm.new_tuple((ascii!("DirName"), py_name)).into(); + // Loaded some certificates but failed on subsequent data (garbage) + return Err(cert_error::der::parse_failed(e)); } + } + } + } + + // If we somehow got here with no certificates loaded + if loaded_count == 0 { + return Err(cert_error::der::not_enough_data( + "cadata does not contain a certificate", + )); + } + } + + Ok(stats) + } + + /// Load certificates from byte slice (auto-detects PEM vs DER format) + /// + /// This is a convenience wrapper that calls load_from_bytes_ex with treat_all_as_ca=false + /// and pem_only=false. + pub fn load_from_bytes(&mut self, data: &[u8]) -> Result { + self.load_from_bytes_ex(data, false, false) + } +} + +// NoVerifier: disables certificate verification (for CERT_NONE mode) +#[derive(Debug)] +pub struct NoVerifier; + +impl ServerCertVerifier for NoVerifier { + fn verify_server_cert( + &self, + _end_entity: &CertificateDer<'_>, + _intermediates: &[CertificateDer<'_>], + _server_name: &ServerName<'_>, + _ocsp_response: &[u8], + _now: UnixTime, + ) -> Result { + // Accept all certificates without verification + Ok(ServerCertVerified::assertion()) + } + + fn verify_tls12_signature( + &self, + _message: &[u8], + _cert: &CertificateDer<'_>, + _dss: &DigitallySignedStruct, + ) -> Result { + // Accept all signatures without verification + Ok(HandshakeSignatureValid::assertion()) + } + + fn verify_tls13_signature( + &self, + _message: &[u8], + _cert: &CertificateDer<'_>, + _dss: &DigitallySignedStruct, + ) -> Result { + // Accept all signatures without verification + Ok(HandshakeSignatureValid::assertion()) + } + + fn supported_verify_schemes(&self) -> Vec { + ALL_SIGNATURE_SCHEMES.to_vec() + } +} - // TODO: Handle Registered ID (GEN_RID) - // CPython implementation uses i2t_ASN1_OBJECT to convert OID - // This requires accessing GENERAL_NAME union which is complex in Rust - // For now, we return for unhandled types +// HostnameIgnoringVerifier: verifies certificate chain but ignores hostname +// This is used when check_hostname=False but verify_mode != CERT_NONE +// +// Unlike the previous implementation that used an inner WebPkiServerVerifier, +// this version uses webpki directly to verify only the certificate chain, +// completely bypassing hostname verification. +#[derive(Debug)] +pub struct HostnameIgnoringVerifier { + inner: Arc, +} - // For othername and other unsupported types - vm.new_tuple((ascii!("othername"), ascii!(""))) - .into() +impl HostnameIgnoringVerifier { + /// Create a new HostnameIgnoringVerifier with a pre-built verifier + /// This is useful when you need to configure the verifier with CRLs or other options + pub fn new_with_verifier(inner: Arc) -> Self { + Self { inner } + } +} + +impl ServerCertVerifier for HostnameIgnoringVerifier { + fn verify_server_cert( + &self, + end_entity: &CertificateDer<'_>, + intermediates: &[CertificateDer<'_>], + _server_name: &ServerName<'_>, // Intentionally ignored + ocsp_response: &[u8], + now: UnixTime, + ) -> Result { + // Extract a hostname from the certificate to pass to inner verifier + // The inner verifier will validate certificate chain, trust anchors, etc. + // but may fail on hostname mismatch - we'll catch and ignore that error + let dummy_hostname = extract_first_dns_name(end_entity) + .unwrap_or_else(|| ServerName::try_from("localhost").expect("localhost is valid")); + + // Call inner verifier for full certificate validation + match self.inner.verify_server_cert( + end_entity, + intermediates, + &dummy_hostname, + ocsp_response, + now, + ) { + Ok(verified) => Ok(verified), + Err(e) => { + // Check if the error is a hostname mismatch + // If so, ignore it (that's the whole point of HostnameIgnoringVerifier) + match e { + rustls::Error::InvalidCertificate( + rustls::CertificateError::NotValidForName, + ) + | rustls::Error::InvalidCertificate( + rustls::CertificateError::NotValidForNameContext { .. }, + ) => { + // Hostname mismatch - this is expected and acceptable + // The certificate chain, trust anchor, and expiry are valid + Ok(ServerCertVerified::assertion()) } - }) - .collect(); - dict.set_item("subjectAltName", vm.ctx.new_tuple(san).into(), vm)?; - }; + _ => { + // Other errors (expired cert, untrusted CA, etc.) should propagate + Err(e) + } + } + } + } + } + + fn verify_tls12_signature( + &self, + message: &[u8], + cert: &CertificateDer<'_>, + dss: &DigitallySignedStruct, + ) -> Result { + self.inner.verify_tls12_signature(message, cert, dss) + } + + fn verify_tls13_signature( + &self, + message: &[u8], + cert: &CertificateDer<'_>, + dss: &DigitallySignedStruct, + ) -> Result { + self.inner.verify_tls13_signature(message, cert, dss) + } + + fn supported_verify_schemes(&self) -> Vec { + self.inner.supported_verify_schemes() + } +} + +// Helper function to extract the first DNS name from a certificate +fn extract_first_dns_name(cert_der: &CertificateDer<'_>) -> Option> { + let (_, cert) = X509Certificate::from_der(cert_der.as_ref()).ok()?; + + // Try Subject Alternative Names first + if let Ok(Some(san_ext)) = cert.subject_alternative_name() { + for name in &san_ext.value.general_names { + if let x509_parser::extensions::GeneralName::DNSName(dns) = name { + // Remove wildcard prefix if present (e.g., "*.example.com" → "example.com") + // This allows us to use the domain for certificate chain verification + // when check_hostname=False + let dns_str = dns.to_string(); + let normalized_dns = normalize_wildcard_hostname(&dns_str); + + match ServerName::try_from(normalized_dns.to_string()) { + Ok(server_name) => { + return Some(server_name); + } + Err(_e) => { + // Continue to next + } + } + } + } + } - Ok(dict.into()) + // Fallback to Common Name + for rdn in cert.subject().iter() { + for attr in rdn.iter() { + if attr.attr_type() == &x509_parser::oid_registry::OID_X509_COMMON_NAME + && let Ok(cn) = attr.attr_value().as_str() + { + // Remove wildcard prefix if present + let normalized_cn = normalize_wildcard_hostname(cn); + + match ServerName::try_from(normalized_cn.to_string()) { + Ok(server_name) => { + return Some(server_name); + } + Err(_e) => {} + } + } + } } - // Helper to create Certificate object from X509 - pub(crate) fn cert_to_certificate(vm: &VirtualMachine, cert: X509) -> PyResult { - Ok(PySSLCertificate { cert }.into_ref(&vm.ctx).into()) + None +} + +// Custom client certificate verifier for TLS 1.3 deferred validation +// This verifier always succeeds during handshake but stores verification errors +// for later retrieval during I/O operations +#[derive(Debug)] +pub struct DeferredClientCertVerifier { + // The actual verifier that performs validation + inner: Arc, + // Shared storage for deferred error message + deferred_error: Arc>>, +} + +impl DeferredClientCertVerifier { + pub fn new( + inner: Arc, + deferred_error: Arc>>, + ) -> Self { + Self { + inner, + deferred_error, + } } +} - // For getpeercert() - returns bytes or dict depending on binary flag - pub(crate) fn cert_to_py(vm: &VirtualMachine, cert: &X509Ref, binary: bool) -> PyResult { - if binary { - let b = cert.to_der().map_err(|e| convert_openssl_error(vm, e))?; - Ok(vm.ctx.new_bytes(b).into()) +impl ClientCertVerifier for DeferredClientCertVerifier { + fn offer_client_auth(&self) -> bool { + self.inner.offer_client_auth() + } + + fn client_auth_mandatory(&self) -> bool { + // Delegate to inner verifier to respect CERT_REQUIRED mode + // This ensures client certificates are mandatory when verify_mode=CERT_REQUIRED + self.inner.client_auth_mandatory() + } + + fn root_hint_subjects(&self) -> &[rustls::DistinguishedName] { + self.inner.root_hint_subjects() + } + + fn verify_client_cert( + &self, + end_entity: &CertificateDer<'_>, + intermediates: &[CertificateDer<'_>], + now: UnixTime, + ) -> Result { + // Perform the actual verification + let result = self + .inner + .verify_client_cert(end_entity, intermediates, now); + + // If verification failed, store the error for later + if result.is_err() { + let error_msg = "TLS handshake failed: received fatal alert: UnknownCA".to_string(); + *self.deferred_error.write() = Some(error_msg); + } + + // Always return success to allow handshake to complete + // The error will be raised during the first I/O operation + Ok(ClientCertVerified::assertion()) + } + + fn verify_tls12_signature( + &self, + message: &[u8], + cert: &CertificateDer<'_>, + dss: &DigitallySignedStruct, + ) -> Result { + self.inner.verify_tls12_signature(message, cert, dss) + } + + fn verify_tls13_signature( + &self, + message: &[u8], + cert: &CertificateDer<'_>, + dss: &DigitallySignedStruct, + ) -> Result { + self.inner.verify_tls13_signature(message, cert, dss) + } + + fn supported_verify_schemes(&self) -> Vec { + self.inner.supported_verify_schemes() + } +} + +// Public Utility Functions + +/// Load certificate chain and private key from files +/// +/// This function loads a certificate chain from `cert_path` and a private key +/// from `key_path`. If `password` is provided, it will be used to decrypt +/// an encrypted private key. +/// +/// Returns (certificate_chain, private_key) on success. +/// +/// # Arguments +/// * `cert_path` - Path to certificate file (PEM or DER format) +/// * `key_path` - Path to private key file (PEM or DER format, optionally encrypted) +/// * `password` - Optional password for encrypted private key +/// +/// # Errors +/// Returns error if: +/// - Files cannot be read +/// - Certificate or key cannot be parsed +/// - Password is incorrect for encrypted key +pub(super) fn load_cert_chain_from_file( + cert_path: &str, + key_path: &str, + password: Option<&str>, +) -> Result<(Vec>, PrivateKeyDer<'static>), Box> { + // Load certificate file - preserve io::Error for errno + let cert_contents = std::fs::read(cert_path)?; + + // Parse certificates (PEM format) + let mut cert_cursor = std::io::Cursor::new(&cert_contents); + let certs: Vec> = rustls_pemfile::certs(&mut cert_cursor) + .collect::, _>>() + .map_err(cert_error::pem::parse_failed)?; + + if certs.is_empty() { + return Err(Box::new(cert_error::pem::invalid_cert())); + } + + // Load private key file - preserve io::Error for errno + let key_contents = std::fs::read(key_path)?; + + // Parse private key (supports PKCS8, RSA, EC formats) + let private_key = if let Some(pwd) = password { + // Try to parse as encrypted PKCS#8 + use der::SecretDocument; + use pkcs8::EncryptedPrivateKeyInfo; + use rustls::pki_types::{PrivateKeyDer, PrivatePkcs8KeyDer}; + + let pem_str = String::from_utf8_lossy(&key_contents); + + // Extract just the ENCRYPTED PRIVATE KEY block if present + // (file may contain multiple PEM blocks like key + certificate) + let encrypted_key_pem = if let Some(start) = + pem_str.find("-----BEGIN ENCRYPTED PRIVATE KEY-----") + { + if let Some(end_marker) = pem_str[start..].find("-----END ENCRYPTED PRIVATE KEY-----") { + let end = start + end_marker + "-----END ENCRYPTED PRIVATE KEY-----".len(); + Some(&pem_str[start..end]) + } else { + None + } } else { - cert_to_dict(vm, cert) + None + }; + + // Try to decode and decrypt PEM-encoded encrypted private key using pkcs8's PEM support + let decrypted_key_result = if let Some(key_pem) = encrypted_key_pem { + match SecretDocument::from_pem(key_pem) { + Ok((label, doc)) => { + if label == "ENCRYPTED PRIVATE KEY" { + // Parse encrypted key info from DER + match EncryptedPrivateKeyInfo::try_from(doc.as_bytes()) { + Ok(encrypted_key) => { + // Decrypt with password + match encrypted_key.decrypt(pwd.as_bytes()) { + Ok(decrypted) => { + // Convert decrypted SecretDocument to PrivateKeyDer + let key_vec: Vec = decrypted.as_bytes().to_vec(); + let pkcs8_key: PrivatePkcs8KeyDer<'static> = key_vec.into(); + Some(PrivateKeyDer::Pkcs8(pkcs8_key)) + } + Err(e) => { + return Err(Box::new(cert_error::key::decrypt_failed(e))); + } + } + } + Err(e) => { + return Err(Box::new(cert_error::key::parse_encrypted_failed(e))); + } + } + } else { + None + } + } + Err(_) => None, + } + } else { + None + }; + + match decrypted_key_result { + Some(key) => key, + None => { + // Not encrypted PKCS#8, try as unencrypted key + // (password might have been provided for an unencrypted key) + let mut key_cursor = std::io::Cursor::new(&key_contents); + match rustls_pemfile::private_key(&mut key_cursor) { + Ok(Some(key)) => key, + Ok(None) => { + return Err(Box::new(cert_error::key::not_found("key file"))); + } + Err(e) => { + return Err(Box::new(cert_error::key::parse_failed(e))); + } + } + } + } + } else { + // No password provided - try to parse unencrypted key + let mut key_cursor = std::io::Cursor::new(&key_contents); + match rustls_pemfile::private_key(&mut key_cursor) { + Ok(Some(key)) => key, + Ok(None) => { + return Err(Box::new(cert_error::key::not_found("key file"))); + } + Err(e) => { + return Err(Box::new(cert_error::key::parse_failed(e))); + } + } + }; + + Ok((certs, private_key)) +} + +/// Validate that a certificate and private key match +/// +/// This function checks that the public key in the certificate matches +/// the provided private key. This is a basic sanity check to prevent +/// configuration errors. +/// +/// # Arguments +/// * `certs` - Certificate chain (first certificate is the leaf) +/// * `private_key` - Private key to validate against +/// +/// # Errors +/// Returns error if: +/// - Certificate chain is empty +/// - Public key extraction fails +/// - Keys don't match +/// +/// Note: This is a simplified validation. Full validation would require +/// signing and verifying a test message, which is complex with rustls. +pub fn validate_cert_key_match( + certs: &[CertificateDer<'_>], + private_key: &PrivateKeyDer<'_>, +) -> Result<(), String> { + if certs.is_empty() { + return Err("Certificate chain is empty".to_string()); + } + + // For rustls, the actual validation happens when creating CertifiedKey + // We can attempt to create a signing key to verify the key is valid + use rustls::crypto::aws_lc_rs::sign::any_supported_type; + + match any_supported_type(private_key) { + Ok(_signing_key) => { + // If we can create a signing key, the private key is valid + // Rustls will validate the cert-key match when building config + Ok(()) + } + Err(_) => Err("PEM lib".to_string()), + } +} + +/// StrictCertVerifier: wraps a ServerCertVerifier and adds RFC 5280 strict validation +/// +/// When VERIFY_X509_STRICT flag is set, performs additional validation: +/// - Checks for Authority Key Identifier (AKI) extension (required by RFC 5280 Section 4.2.1.1) +/// - Validates other RFC 5280 compliance requirements +/// +/// This matches X509_V_FLAG_X509_STRICT behavior in OpenSSL. +#[derive(Debug)] +pub struct StrictCertVerifier { + inner: Arc, + verify_flags: i32, +} + +impl StrictCertVerifier { + /// Create a new StrictCertVerifier + /// + /// # Arguments + /// * `inner` - The underlying verifier to wrap + /// * `verify_flags` - SSL verification flags (e.g., VERIFY_X509_STRICT) + pub fn new(inner: Arc, verify_flags: i32) -> Self { + Self { + inner, + verify_flags, } } - #[pyfunction] - pub(crate) fn _test_decode_cert(path: FsPath, vm: &VirtualMachine) -> PyResult { - let path = path.to_path_buf(vm)?; - let pem = std::fs::read(path).map_err(|e| e.to_pyexception(vm))?; - let x509 = X509::from_pem(&pem).map_err(|e| convert_openssl_error(vm, e))?; - cert_to_py(vm, &x509, false) + /// Check if a certificate has the Authority Key Identifier extension + /// + /// RFC 5280 Section 4.2.1.1 states that conforming CAs MUST include this + /// extension in all certificates except self-signed certificates. + fn check_aki_present(cert_der: &[u8]) -> Result<(), String> { + let (_, cert) = X509Certificate::from_der(cert_der) + .map_err(|e| format!("Failed to parse certificate: {e}"))?; + + // Check for Authority Key Identifier extension (OID 2.5.29.35) + let has_aki = cert + .tbs_certificate + .extensions() + .iter() + .any(|ext| ext.oid == oid_registry::OID_X509_EXT_AUTHORITY_KEY_IDENTIFIER); + + if !has_aki { + return Err( + "certificate verification failed: certificate missing required Authority Key Identifier extension" + .to_string(), + ); + } + + Ok(()) } } + +impl ServerCertVerifier for StrictCertVerifier { + fn verify_server_cert( + &self, + end_entity: &CertificateDer<'_>, + intermediates: &[CertificateDer<'_>], + server_name: &ServerName<'_>, + ocsp_response: &[u8], + now: UnixTime, + ) -> Result { + // First, perform the standard verification + let result = self.inner.verify_server_cert( + end_entity, + intermediates, + server_name, + ocsp_response, + now, + )?; + + // If VERIFY_X509_STRICT flag is set, perform additional validation + if self.verify_flags & VERIFY_X509_STRICT != 0 { + // Check end entity certificate for AKI + // RFC 5280 Section 4.2.1.1: self-signed certificates are exempt from AKI requirement + if !is_self_signed(end_entity) { + Self::check_aki_present(end_entity.as_ref()) + .map_err(cert_error::to_rustls_invalid_cert)?; + } + + // Check intermediate certificates for AKI + for intermediate in intermediates { + Self::check_aki_present(intermediate.as_ref()) + .map_err(cert_error::to_rustls_invalid_cert)?; + } + } + + Ok(result) + } + + fn verify_tls12_signature( + &self, + message: &[u8], + cert: &CertificateDer<'_>, + dss: &DigitallySignedStruct, + ) -> Result { + self.inner.verify_tls12_signature(message, cert, dss) + } + + fn verify_tls13_signature( + &self, + message: &[u8], + cert: &CertificateDer<'_>, + dss: &DigitallySignedStruct, + ) -> Result { + self.inner.verify_tls13_signature(message, cert, dss) + } + + fn supported_verify_schemes(&self) -> Vec { + self.inner.supported_verify_schemes() + } +} + +/// EmptyRootStoreVerifier: used when verify_mode != CERT_NONE but no CA certs are loaded +/// +/// This verifier always fails certificate verification with UnknownIssuer error, +/// when no root certificates are available. +/// This allows the SSL context to be created successfully, but handshake will fail +/// with a proper SSLCertVerificationError (verify_code=20, UNABLE_TO_GET_ISSUER_CERT_LOCALLY). +#[derive(Debug)] +pub struct EmptyRootStoreVerifier; + +impl ServerCertVerifier for EmptyRootStoreVerifier { + fn verify_server_cert( + &self, + _end_entity: &CertificateDer<'_>, + _intermediates: &[CertificateDer<'_>], + _server_name: &ServerName<'_>, + _ocsp_response: &[u8], + _now: UnixTime, + ) -> Result { + // Always fail with UnknownIssuer - when no CA certs loaded + // This will be mapped to X509_V_ERR_UNABLE_TO_GET_ISSUER_CERT_LOCALLY (20) + Err(rustls::Error::InvalidCertificate( + rustls::CertificateError::UnknownIssuer, + )) + } + + fn verify_tls12_signature( + &self, + _message: &[u8], + _cert: &CertificateDer<'_>, + _dss: &DigitallySignedStruct, + ) -> Result { + // Accept signatures during handshake - the cert verification will fail anyway + Ok(HandshakeSignatureValid::assertion()) + } + + fn verify_tls13_signature( + &self, + _message: &[u8], + _cert: &CertificateDer<'_>, + _dss: &DigitallySignedStruct, + ) -> Result { + // Accept signatures during handshake - the cert verification will fail anyway + Ok(HandshakeSignatureValid::assertion()) + } + + fn supported_verify_schemes(&self) -> Vec { + ALL_SIGNATURE_SCHEMES.to_vec() + } +} + +/// CRLCheckVerifier: Wraps a verifier to enforce CRL checking when flags are set +/// +/// This verifier ensures that when CRL checking flags are set (VERIFY_CRL_CHECK_LEAF = 4) +/// but no CRLs have been loaded, the verification fails with UnknownRevocationStatus. +/// This matches X509_V_FLAG_CRL_CHECK without loaded CRLs +/// causes "unable to get CRL" error. +#[derive(Debug)] +pub struct CRLCheckVerifier { + inner: Arc, + has_crls: bool, + crl_check_enabled: bool, +} + +impl CRLCheckVerifier { + pub fn new( + inner: Arc, + has_crls: bool, + crl_check_enabled: bool, + ) -> Self { + Self { + inner, + has_crls, + crl_check_enabled, + } + } +} + +impl ServerCertVerifier for CRLCheckVerifier { + fn verify_server_cert( + &self, + end_entity: &CertificateDer<'_>, + intermediates: &[CertificateDer<'_>], + server_name: &ServerName<'_>, + ocsp_response: &[u8], + now: UnixTime, + ) -> Result { + // If CRL checking is enabled but no CRLs are loaded, fail with UnknownRevocationStatus + // X509_V_ERR_UNABLE_TO_GET_CRL (3) + if self.crl_check_enabled && !self.has_crls { + return Err(rustls::Error::InvalidCertificate( + rustls::CertificateError::UnknownRevocationStatus, + )); + } + + // Otherwise, delegate to inner verifier + self.inner + .verify_server_cert(end_entity, intermediates, server_name, ocsp_response, now) + } + + fn verify_tls12_signature( + &self, + message: &[u8], + cert: &CertificateDer<'_>, + dss: &DigitallySignedStruct, + ) -> Result { + self.inner.verify_tls12_signature(message, cert, dss) + } + + fn verify_tls13_signature( + &self, + message: &[u8], + cert: &CertificateDer<'_>, + dss: &DigitallySignedStruct, + ) -> Result { + self.inner.verify_tls13_signature(message, cert, dss) + } + + fn supported_verify_schemes(&self) -> Vec { + self.inner.supported_verify_schemes() + } +} + +/// Partial Chain Verifier - Handles VERIFY_X509_PARTIAL_CHAIN flag +/// +/// OpenSSL's X509_V_FLAG_PARTIAL_CHAIN allows verification to succeed if any certificate +/// in the presented chain is found in the trust store, not just the root CA. This is useful +/// for trusting intermediate certificates or self-signed certificates directly. +/// +/// rustls's WebPkiServerVerifier doesn't support this behavior by default, so we wrap it +/// to add partial chain support when the flag is set. +/// +/// Behavior: +/// 1. Try standard verification first (full chain to trusted root) +/// 2. If that fails and VERIFY_X509_PARTIAL_CHAIN is set: +/// - Check if the end-entity certificate is in the trust store +/// - If yes, accept the certificate as trusted +/// +/// This matches accepting self-signed certificates that +/// are explicitly loaded via load_verify_locations(). +#[derive(Debug)] +pub struct PartialChainVerifier { + inner: Arc, + ca_certs_der: Vec>, + verify_flags: i32, +} + +impl PartialChainVerifier { + pub fn new( + inner: Arc, + ca_certs_der: Vec>, + verify_flags: i32, + ) -> Self { + Self { + inner, + ca_certs_der, + verify_flags, + } + } +} + +impl ServerCertVerifier for PartialChainVerifier { + fn verify_server_cert( + &self, + end_entity: &CertificateDer<'_>, + intermediates: &[CertificateDer<'_>], + server_name: &ServerName<'_>, + ocsp_response: &[u8], + now: UnixTime, + ) -> Result { + // Try standard verification first + match self.inner.verify_server_cert( + end_entity, + intermediates, + server_name, + ocsp_response, + now, + ) { + Ok(result) => Ok(result), + Err(e) => { + // If verification failed, check if the end-entity certificate is in the trust store + // OpenSSL behavior: + // 1. Self-signed certs in trust store: ALWAYS trusted (flag not required) + // 2. Non-self-signed end-entity certs in trust store: require VERIFY_X509_PARTIAL_CHAIN + // 3. Intermediate certs in trust store: require VERIFY_X509_PARTIAL_CHAIN + let end_entity_der = end_entity.as_ref(); + if self + .ca_certs_der + .iter() + .any(|cert_der| cert_der.as_slice() == end_entity_der) + { + // End-entity certificate is in the trust store + // Check if this is a self-signed certificate + let is_self_signed_cert = is_self_signed(end_entity); + + // Self-signed: always trust (OpenSSL behavior) + // Non-self-signed: require VERIFY_X509_PARTIAL_CHAIN flag + if is_self_signed_cert || (self.verify_flags & VERIFY_X509_PARTIAL_CHAIN != 0) { + // Certificate is trusted, but still perform hostname verification + verify_hostname(end_entity, server_name)?; + return Ok(ServerCertVerified::assertion()); + } + } + // No match found or non-self-signed without flag - return original error + Err(e) + } + } + } + + fn verify_tls12_signature( + &self, + message: &[u8], + cert: &CertificateDer<'_>, + dss: &DigitallySignedStruct, + ) -> Result { + self.inner.verify_tls12_signature(message, cert, dss) + } + + fn verify_tls13_signature( + &self, + message: &[u8], + cert: &CertificateDer<'_>, + dss: &DigitallySignedStruct, + ) -> Result { + self.inner.verify_tls13_signature(message, cert, dss) + } + + fn supported_verify_schemes(&self) -> Vec { + self.inner.supported_verify_schemes() + } +} + +// Hostname Verification: + +/// Check if a certificate is self-signed by comparing issuer and subject. +/// Returns true if the certificate is self-signed (issuer == subject). +fn is_self_signed(cert_der: &CertificateDer<'_>) -> bool { + use x509_parser::prelude::*; + + // Parse the certificate + let Ok((_, cert)) = X509Certificate::from_der(cert_der.as_ref()) else { + // If we can't parse it, assume it's not self-signed (conservative approach) + return false; + }; + + // Compare issuer and subject + // A certificate is self-signed if issuer == subject + cert.issuer() == cert.subject() +} + +/// Verify that a certificate is valid for the given hostname/IP address. +/// This function checks Subject Alternative Names (SAN) and Common Name (CN). +fn verify_hostname( + cert_der: &CertificateDer<'_>, + server_name: &ServerName<'_>, +) -> Result<(), rustls::Error> { + use x509_parser::extensions::GeneralName; + use x509_parser::prelude::*; + + // Parse the certificate + let (_, cert) = X509Certificate::from_der(cert_der.as_ref()).map_err(|e| { + cert_error::to_rustls_invalid_cert(format!( + "Failed to parse certificate for hostname verification: {e}" + )) + })?; + + match server_name { + ServerName::DnsName(dns) => { + let expected_name = dns.as_ref(); + + // 1. Check Subject Alternative Names (SAN) - preferred method + if let Ok(Some(san_ext)) = cert.subject_alternative_name() { + for name in &san_ext.value.general_names { + if let GeneralName::DNSName(dns_name) = name + && hostname_matches(expected_name, dns_name) + { + return Ok(()); + } + } + } + + // 2. Fallback to Common Name (CN) - deprecated but still checked for compatibility + for rdn in cert.subject().iter() { + for attr in rdn.iter() { + if attr.attr_type() == &x509_parser::oid_registry::OID_X509_COMMON_NAME + && let Ok(cn) = attr.attr_value().as_str() + && hostname_matches(expected_name, cn) + { + return Ok(()); + } + } + } + + // No match found - return error + Err(cert_error::to_rustls_invalid_cert(format!( + "Hostname mismatch: certificate is not valid for '{expected_name}'", + ))) + } + ServerName::IpAddress(ip) => verify_ip_address(&cert, ip), + _ => { + // Unknown server name type + Err(cert_error::to_rustls_cert_error( + std::io::ErrorKind::InvalidInput, + "Unsupported server name type for hostname verification", + )) + } + } +} + +/// Match a hostname against a pattern, supporting wildcard certificates (*.example.com). +/// Implements RFC 6125 wildcard matching rules: +/// - Wildcard must be in the leftmost label +/// - Wildcard must be the only character in that label +/// - Wildcard must match at least one character +fn hostname_matches(expected: &str, pattern: &str) -> bool { + // Wildcard matching for *.example.com + if let Some(pattern_base) = pattern.strip_prefix("*.") { + // Find the first dot in expected hostname + if let Some(dot_pos) = expected.find('.') { + let expected_base = &expected[dot_pos + 1..]; + + // The base domains must match (case insensitive) + // and the leftmost label must not be empty + return dot_pos > 0 && expected_base.eq_ignore_ascii_case(pattern_base); + } + + // No dot in expected, can't match wildcard + return false; + } + + // Exact match (case insensitive per RFC 4343) + expected.eq_ignore_ascii_case(pattern) +} + +/// Verify that a certificate is valid for the given IP address. +/// Checks Subject Alternative Names for IP Address entries. +fn verify_ip_address( + cert: &X509Certificate<'_>, + expected_ip: &rustls::pki_types::IpAddr, +) -> Result<(), rustls::Error> { + use std::net::IpAddr; + use x509_parser::extensions::GeneralName; + + // Convert rustls IpAddr to std::net::IpAddr for comparison + let expected_std_ip: IpAddr = match expected_ip { + rustls::pki_types::IpAddr::V4(octets) => IpAddr::V4(std::net::Ipv4Addr::from(*octets)), + rustls::pki_types::IpAddr::V6(octets) => IpAddr::V6(std::net::Ipv6Addr::from(*octets)), + }; + + // Check Subject Alternative Names for IP addresses + if let Ok(Some(san_ext)) = cert.subject_alternative_name() { + for name in &san_ext.value.general_names { + if let GeneralName::IPAddress(cert_ip_bytes) = name { + // Parse the IP address from the certificate + let cert_ip = match cert_ip_bytes.len() { + 4 => { + // IPv4 + if let Ok(octets) = <[u8; 4]>::try_from(*cert_ip_bytes) { + IpAddr::V4(std::net::Ipv4Addr::from(octets)) + } else { + continue; + } + } + 16 => { + // IPv6 + if let Ok(octets) = <[u8; 16]>::try_from(*cert_ip_bytes) { + IpAddr::V6(std::net::Ipv6Addr::from(octets)) + } else { + continue; + } + } + _ => continue, // Invalid IP address length + }; + + if cert_ip == expected_std_ip { + return Ok(()); + } + } + } + } + + // No matching IP address found + Err(cert_error::to_rustls_invalid_cert(format!( + "IP address mismatch: certificate is not valid for '{expected_std_ip}'", + ))) +} diff --git a/stdlib/src/ssl/compat.rs b/stdlib/src/ssl/compat.rs new file mode 100644 index 00000000000..e4f979968e4 --- /dev/null +++ b/stdlib/src/ssl/compat.rs @@ -0,0 +1,1786 @@ +// spell-checker: ignore webpki ssleof sslerror akid certsign sslerr aesgcm + +// OpenSSL compatibility layer for rustls +// +// This module provides OpenSSL-like abstractions over rustls APIs, +// making the code more readable and maintainable. Each function is named +// after its OpenSSL equivalent (e.g., ssl_do_handshake corresponds to SSL_do_handshake). + +// SSL error code data tables (shared with OpenSSL backend for compatibility) +// These map OpenSSL error codes to human-readable strings +#[path = "../openssl/ssl_data_31.rs"] +mod ssl_data; + +use crate::socket::{SelectKind, timeout_error_msg}; +use crate::vm::VirtualMachine; +use parking_lot::RwLock as ParkingRwLock; +use rustls::RootCertStore; +use rustls::client::ClientConfig; +use rustls::client::ClientConnection; +use rustls::crypto::SupportedKxGroup; +use rustls::pki_types::{CertificateDer, CertificateRevocationListDer, PrivateKeyDer}; +use rustls::server::ResolvesServerCert; +use rustls::server::ServerConfig; +use rustls::server::ServerConnection; +use rustls::sign::CertifiedKey; +use rustpython_vm::builtins::PyBaseExceptionRef; +use rustpython_vm::function::ArgBytesLike; +use rustpython_vm::{AsObject, PyObjectRef, PyPayload, PyResult, TryFromObject}; +use std::io::Read; +use std::sync::{Arc, Once}; + +// Import PySSLSocket and helper functions from parent module +use super::_ssl::{ + PySSLCertVerificationError, PySSLError, PySSLSocket, create_ssl_eof_error, + create_ssl_want_read_error, create_ssl_want_write_error, create_ssl_zero_return_error, +}; + +// SSL Verification Flags +/// VERIFY_X509_STRICT flag for RFC 5280 strict compliance +/// When set, performs additional validation including AKI extension checks +pub const VERIFY_X509_STRICT: i32 = 0x20; + +/// VERIFY_X509_PARTIAL_CHAIN flag for partial chain validation +/// When set, accept certificates if any certificate in the chain is in the trust store +/// (not just root CAs). This matches OpenSSL's X509_V_FLAG_PARTIAL_CHAIN behavior. +pub const VERIFY_X509_PARTIAL_CHAIN: i32 = 0x80000; + +// CryptoProvider Initialization: + +/// Ensure the default CryptoProvider is installed (thread-safe, runs once) +/// +/// This is necessary because rustls 0.23+ requires a process-level CryptoProvider +/// to be installed before using default_provider(). We use Once to ensure this +/// happens exactly once, even if called from multiple threads. +static INIT_PROVIDER: Once = Once::new(); + +fn ensure_default_provider() { + INIT_PROVIDER.call_once(|| { + let _ = rustls::crypto::CryptoProvider::install_default( + rustls::crypto::aws_lc_rs::default_provider(), + ); + }); +} + +// OpenSSL Constants: + +// OpenSSL TLS record maximum plaintext size (ssl/ssl_local.h) +// #define SSL3_RT_MAX_PLAIN_LENGTH 16384 +const SSL3_RT_MAX_PLAIN_LENGTH: usize = 16384; + +// OpenSSL error library codes (include/openssl/err.h) +// #define ERR_LIB_SSL 20 +const ERR_LIB_SSL: i32 = 20; + +// OpenSSL SSL error reason codes (include/openssl/sslerr.h) +// #define SSL_R_NO_SHARED_CIPHER 193 +const SSL_R_NO_SHARED_CIPHER: i32 = 193; + +// OpenSSL X509 verification flags (include/openssl/x509_vfy.h) +// #define X509_V_FLAG_CRL_CHECK 4 +const X509_V_FLAG_CRL_CHECK: i32 = 4; + +// X509 Certificate Verification Error Codes (OpenSSL Compatible): +// +// These constants match OpenSSL's X509_V_ERR_* values for certificate +// verification. They are used to map rustls certificate errors to OpenSSL +// error codes for compatibility. + +pub use x509::{ + X509_V_ERR_CERT_HAS_EXPIRED, X509_V_ERR_CERT_NOT_YET_VALID, X509_V_ERR_CERT_REVOKED, + X509_V_ERR_HOSTNAME_MISMATCH, X509_V_ERR_INVALID_PURPOSE, X509_V_ERR_IP_ADDRESS_MISMATCH, + X509_V_ERR_UNABLE_TO_GET_CRL, X509_V_ERR_UNABLE_TO_GET_ISSUER_CERT_LOCALLY, + X509_V_ERR_UNSPECIFIED, +}; + +#[allow(dead_code)] +mod x509 { + pub const X509_V_OK: i32 = 0; + pub const X509_V_ERR_UNSPECIFIED: i32 = 1; + pub const X509_V_ERR_UNABLE_TO_GET_ISSUER_CERT: i32 = 2; + pub const X509_V_ERR_UNABLE_TO_GET_CRL: i32 = 3; + pub const X509_V_ERR_UNABLE_TO_DECRYPT_CERT_SIGNATURE: i32 = 4; + pub const X509_V_ERR_UNABLE_TO_DECRYPT_CRL_SIGNATURE: i32 = 5; + pub const X509_V_ERR_UNABLE_TO_DECODE_ISSUER_PUBLIC_KEY: i32 = 6; + pub const X509_V_ERR_CERT_SIGNATURE_FAILURE: i32 = 7; + pub const X509_V_ERR_CRL_SIGNATURE_FAILURE: i32 = 8; + pub const X509_V_ERR_CERT_NOT_YET_VALID: i32 = 9; + pub const X509_V_ERR_CERT_HAS_EXPIRED: i32 = 10; + pub const X509_V_ERR_CRL_NOT_YET_VALID: i32 = 11; + pub const X509_V_ERR_CRL_HAS_EXPIRED: i32 = 12; + pub const X509_V_ERR_ERROR_IN_CERT_NOT_BEFORE_FIELD: i32 = 13; + pub const X509_V_ERR_ERROR_IN_CERT_NOT_AFTER_FIELD: i32 = 14; + pub const X509_V_ERR_ERROR_IN_CRL_LAST_UPDATE_FIELD: i32 = 15; + pub const X509_V_ERR_ERROR_IN_CRL_NEXT_UPDATE_FIELD: i32 = 16; + pub const X509_V_ERR_OUT_OF_MEM: i32 = 17; + pub const X509_V_ERR_DEPTH_ZERO_SELF_SIGNED_CERT: i32 = 18; + pub const X509_V_ERR_SELF_SIGNED_CERT_IN_CHAIN: i32 = 19; + pub const X509_V_ERR_UNABLE_TO_GET_ISSUER_CERT_LOCALLY: i32 = 20; + pub const X509_V_ERR_UNABLE_TO_VERIFY_LEAF_SIGNATURE: i32 = 21; + pub const X509_V_ERR_CERT_CHAIN_TOO_LONG: i32 = 22; + pub const X509_V_ERR_CERT_REVOKED: i32 = 23; + pub const X509_V_ERR_INVALID_CA: i32 = 24; + pub const X509_V_ERR_PATH_LENGTH_EXCEEDED: i32 = 25; + pub const X509_V_ERR_INVALID_PURPOSE: i32 = 26; + pub const X509_V_ERR_CERT_UNTRUSTED: i32 = 27; + pub const X509_V_ERR_CERT_REJECTED: i32 = 28; + pub const X509_V_ERR_SUBJECT_ISSUER_MISMATCH: i32 = 29; + pub const X509_V_ERR_AKID_SKID_MISMATCH: i32 = 30; + pub const X509_V_ERR_AKID_ISSUER_SERIAL_MISMATCH: i32 = 31; + pub const X509_V_ERR_KEYUSAGE_NO_CERTSIGN: i32 = 32; + pub const X509_V_ERR_UNABLE_TO_GET_CRL_ISSUER: i32 = 33; + pub const X509_V_ERR_UNHANDLED_CRITICAL_EXTENSION: i32 = 34; + pub const X509_V_ERR_KEYUSAGE_NO_CRL_SIGN: i32 = 35; + pub const X509_V_ERR_UNHANDLED_CRITICAL_CRL_EXTENSION: i32 = 36; + pub const X509_V_ERR_INVALID_NON_CA: i32 = 37; + pub const X509_V_ERR_PROXY_PATH_LENGTH_EXCEEDED: i32 = 38; + pub const X509_V_ERR_KEYUSAGE_NO_DIGITAL_SIGNATURE: i32 = 39; + pub const X509_V_ERR_PROXY_CERTIFICATES_NOT_ALLOWED: i32 = 40; + pub const X509_V_ERR_INVALID_EXTENSION: i32 = 41; + pub const X509_V_ERR_INVALID_POLICY_EXTENSION: i32 = 42; + pub const X509_V_ERR_NO_EXPLICIT_POLICY: i32 = 43; + pub const X509_V_ERR_DIFFERENT_CRL_SCOPE: i32 = 44; + pub const X509_V_ERR_UNSUPPORTED_EXTENSION_FEATURE: i32 = 45; + pub const X509_V_ERR_UNNESTED_RESOURCE: i32 = 46; + pub const X509_V_ERR_PERMITTED_VIOLATION: i32 = 47; + pub const X509_V_ERR_EXCLUDED_VIOLATION: i32 = 48; + pub const X509_V_ERR_SUBTREE_MINMAX: i32 = 49; + pub const X509_V_ERR_APPLICATION_VERIFICATION: i32 = 50; + pub const X509_V_ERR_UNSUPPORTED_CONSTRAINT_TYPE: i32 = 51; + pub const X509_V_ERR_UNSUPPORTED_CONSTRAINT_SYNTAX: i32 = 52; + pub const X509_V_ERR_UNSUPPORTED_NAME_SYNTAX: i32 = 53; + pub const X509_V_ERR_CRL_PATH_VALIDATION_ERROR: i32 = 54; + pub const X509_V_ERR_HOSTNAME_MISMATCH: i32 = 62; + pub const X509_V_ERR_EMAIL_MISMATCH: i32 = 63; + pub const X509_V_ERR_IP_ADDRESS_MISMATCH: i32 = 64; +} + +// Certificate Error Conversion Functions: + +/// Convert rustls CertificateError to X509 verification code and message +/// +/// Maps rustls certificate errors to OpenSSL X509_V_ERR_* codes for compatibility. +/// Returns (verify_code, verify_message) tuple. +fn rustls_cert_error_to_verify_info(cert_err: &rustls::CertificateError) -> (i32, &'static str) { + use rustls::CertificateError; + + match cert_err { + CertificateError::UnknownIssuer => ( + X509_V_ERR_UNABLE_TO_GET_ISSUER_CERT_LOCALLY, + "unable to get local issuer certificate", + ), + CertificateError::Expired => (X509_V_ERR_CERT_HAS_EXPIRED, "certificate has expired"), + CertificateError::NotValidYet => ( + X509_V_ERR_CERT_NOT_YET_VALID, + "certificate is not yet valid", + ), + CertificateError::Revoked => (X509_V_ERR_CERT_REVOKED, "certificate revoked"), + CertificateError::UnknownRevocationStatus => ( + X509_V_ERR_UNABLE_TO_GET_CRL, + "unable to get certificate CRL", + ), + CertificateError::InvalidPurpose => ( + X509_V_ERR_INVALID_PURPOSE, + "unsupported certificate purpose", + ), + CertificateError::Other(other_err) => { + // Check if this is a hostname mismatch error from our verify_hostname function + let err_msg = format!("{other_err:?}"); + if err_msg.contains("Hostname mismatch") || err_msg.contains("not valid for") { + ( + X509_V_ERR_HOSTNAME_MISMATCH, + "Hostname mismatch, certificate is not valid for", + ) + } else if err_msg.contains("IP address mismatch") { + ( + X509_V_ERR_IP_ADDRESS_MISMATCH, + "IP address mismatch, certificate is not valid for", + ) + } else { + (X509_V_ERR_UNSPECIFIED, "certificate verification failed") + } + } + _ => (X509_V_ERR_UNSPECIFIED, "certificate verification failed"), + } +} + +/// Create SSLCertVerificationError with proper attributes +/// +/// Matches CPython's _ssl.c fill_and_set_sslerror() behavior. +/// This function creates a Python SSLCertVerificationError exception with verify_code +/// and verify_message attributes set appropriately for the given rustls certificate error. +/// +/// # Note +/// If attribute setting fails (extremely rare), returns the exception without attributes +pub(super) fn create_ssl_cert_verification_error( + vm: &VirtualMachine, + cert_err: &rustls::CertificateError, +) -> PyResult { + let (verify_code, verify_message) = rustls_cert_error_to_verify_info(cert_err); + + let msg = + format!("[SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: {verify_message}",); + + let exc = vm.new_exception_msg(PySSLCertVerificationError::class(&vm.ctx).to_owned(), msg); + + // Set verify_code and verify_message attributes + // Ignore errors as they're extremely rare (e.g., out of memory) + exc.as_object().set_attr( + "verify_code", + vm.ctx.new_int(verify_code).as_object().to_owned(), + vm, + )?; + exc.as_object().set_attr( + "verify_message", + vm.ctx.new_str(verify_message).as_object().to_owned(), + vm, + )?; + + exc.as_object() + .set_attr("library", vm.ctx.new_str("SSL").as_object().to_owned(), vm)?; + exc.as_object().set_attr( + "reason", + vm.ctx + .new_str("CERTIFICATE_VERIFY_FAILED") + .as_object() + .to_owned(), + vm, + )?; + + Ok(exc) +} + +/// Unified TLS connection type (client or server) +#[derive(Debug)] +pub(super) enum TlsConnection { + Client(ClientConnection), + Server(ServerConnection), +} + +impl TlsConnection { + /// Check if handshake is in progress + pub fn is_handshaking(&self) -> bool { + match self { + TlsConnection::Client(conn) => conn.is_handshaking(), + TlsConnection::Server(conn) => conn.is_handshaking(), + } + } + + /// Check if connection wants to read data + pub fn wants_read(&self) -> bool { + match self { + TlsConnection::Client(conn) => conn.wants_read(), + TlsConnection::Server(conn) => conn.wants_read(), + } + } + + /// Check if connection wants to write data + pub fn wants_write(&self) -> bool { + match self { + TlsConnection::Client(conn) => conn.wants_write(), + TlsConnection::Server(conn) => conn.wants_write(), + } + } + + /// Read TLS data from socket + pub fn read_tls(&mut self, reader: &mut dyn std::io::Read) -> std::io::Result { + match self { + TlsConnection::Client(conn) => conn.read_tls(reader), + TlsConnection::Server(conn) => conn.read_tls(reader), + } + } + + /// Write TLS data to socket + pub fn write_tls(&mut self, writer: &mut dyn std::io::Write) -> std::io::Result { + match self { + TlsConnection::Client(conn) => conn.write_tls(writer), + TlsConnection::Server(conn) => conn.write_tls(writer), + } + } + + /// Process new TLS packets + pub fn process_new_packets(&mut self) -> Result { + match self { + TlsConnection::Client(conn) => conn.process_new_packets(), + TlsConnection::Server(conn) => conn.process_new_packets(), + } + } + + /// Get reader for plaintext data (rustls native type) + pub fn reader(&mut self) -> rustls::Reader<'_> { + match self { + TlsConnection::Client(conn) => conn.reader(), + TlsConnection::Server(conn) => conn.reader(), + } + } + + /// Get writer for plaintext data (rustls native type) + pub fn writer(&mut self) -> rustls::Writer<'_> { + match self { + TlsConnection::Client(conn) => conn.writer(), + TlsConnection::Server(conn) => conn.writer(), + } + } + + /// Check if session was resumed + pub fn is_session_resumed(&self) -> bool { + use rustls::HandshakeKind; + match self { + TlsConnection::Client(conn) => { + matches!(conn.handshake_kind(), Some(HandshakeKind::Resumed)) + } + TlsConnection::Server(conn) => { + matches!(conn.handshake_kind(), Some(HandshakeKind::Resumed)) + } + } + } + + /// Send close_notify alert + pub fn send_close_notify(&mut self) { + match self { + TlsConnection::Client(conn) => conn.send_close_notify(), + TlsConnection::Server(conn) => conn.send_close_notify(), + } + } + + /// Get negotiated ALPN protocol + pub fn alpn_protocol(&self) -> Option<&[u8]> { + match self { + TlsConnection::Client(conn) => conn.alpn_protocol(), + TlsConnection::Server(conn) => conn.alpn_protocol(), + } + } + + /// Get negotiated cipher suite + pub fn negotiated_cipher_suite(&self) -> Option { + match self { + TlsConnection::Client(conn) => conn.negotiated_cipher_suite(), + TlsConnection::Server(conn) => conn.negotiated_cipher_suite(), + } + } + + /// Get peer certificates + pub fn peer_certificates(&self) -> Option<&[rustls::pki_types::CertificateDer<'static>]> { + match self { + TlsConnection::Client(conn) => conn.peer_certificates(), + TlsConnection::Server(conn) => conn.peer_certificates(), + } + } +} + +/// Error types matching OpenSSL error codes +#[derive(Debug)] +pub(super) enum SslError { + /// SSL_ERROR_WANT_READ + WantRead, + /// SSL_ERROR_WANT_WRITE + WantWrite, + /// SSL_ERROR_SYSCALL + Syscall(String), + /// SSL_ERROR_SSL + Ssl(String), + /// SSL_ERROR_ZERO_RETURN (clean closure with close_notify) + ZeroReturn, + /// Unexpected EOF without close_notify (protocol violation) + Eof, + /// Certificate verification error + CertVerification(rustls::CertificateError), + /// I/O error + Io(std::io::Error), + /// Timeout error (socket.timeout) + Timeout(String), + /// SNI callback triggered - need to restart handshake + SniCallbackRestart, + /// Python exception (pass through directly) + Py(PyBaseExceptionRef), + /// TLS alert received with OpenSSL-compatible error code + AlertReceived { lib: i32, reason: i32 }, + /// NO_SHARED_CIPHER error (OpenSSL SSL_R_NO_SHARED_CIPHER) + NoCipherSuites, +} + +impl SslError { + /// Convert TLS alert code to OpenSSL error reason code + /// OpenSSL uses reason = 1000 + alert_code for TLS alerts + fn alert_to_openssl_reason(alert: rustls::AlertDescription) -> i32 { + // AlertDescription can be converted to u8 via as u8 cast + 1000 + (u8::from(alert) as i32) + } + + /// Convert rustls error to SslError + pub fn from_rustls(err: rustls::Error) -> Self { + match err { + rustls::Error::InvalidCertificate(cert_err) => SslError::CertVerification(cert_err), + rustls::Error::AlertReceived(alert_desc) => { + // Map TLS alerts to OpenSSL-compatible error codes + // lib = 20 (ERR_LIB_SSL), reason = 1000 + alert_code + match alert_desc { + rustls::AlertDescription::CloseNotify => { + // Special case: close_notify is handled as ZeroReturn + SslError::ZeroReturn + } + _ => { + // All other alerts: convert to OpenSSL error code + // This includes InternalError (80 -> reason 1080) + SslError::AlertReceived { + lib: ERR_LIB_SSL, + reason: Self::alert_to_openssl_reason(alert_desc), + } + } + } + } + // OpenSSL 3.0 changed transport EOF from SSL_ERROR_SYSCALL with + // zero return value to SSL_ERROR_SSL with SSL_R_UNEXPECTED_EOF_WHILE_READING. + // In rustls, these cases correspond to unexpected connection closure: + rustls::Error::InvalidMessage(_) => { + // UnexpectedMessage, CorruptMessage, etc. → SSLEOFError + // Matches CPython's "EOF occurred in violation of protocol" + SslError::Eof + } + rustls::Error::PeerIncompatible(peer_err) => { + // Check for specific incompatibility types + use rustls::PeerIncompatible; + match peer_err { + PeerIncompatible::NoCipherSuitesInCommon => { + // Maps to OpenSSL SSL_R_NO_SHARED_CIPHER (lib=20, reason=193) + SslError::NoCipherSuites + } + _ => { + // Other protocol incompatibilities → SSLEOFError + SslError::Eof + } + } + } + _ => SslError::Ssl(format!("{err}")), + } + } + + /// Create SSLError with library and reason from string values + /// + /// This is the base helper for creating SSLError with _library and _reason + /// attributes when you already have the string values. + /// + /// # Arguments + /// * `vm` - Virtual machine reference + /// * `library` - Library name (e.g., "PEM", "SSL") + /// * `reason` - Error reason (e.g., "PEM lib", "NO_SHARED_CIPHER") + /// * `message` - Main error message + /// + /// # Returns + /// PyBaseExceptionRef with _library and _reason attributes set + /// + /// # Note + /// If attribute setting fails (extremely rare), returns the exception without attributes + pub(super) fn create_ssl_error_with_reason( + vm: &VirtualMachine, + library: &str, + reason: &str, + message: impl Into, + ) -> PyBaseExceptionRef { + let exc = vm.new_exception_msg(PySSLError::class(&vm.ctx).to_owned(), message.into()); + + // Set library and reason attributes + // Ignore errors as they're extremely rare (e.g., out of memory) + let _ = exc.as_object().set_attr( + "library", + vm.ctx.new_str(library).as_object().to_owned(), + vm, + ); + let _ = + exc.as_object() + .set_attr("reason", vm.ctx.new_str(reason).as_object().to_owned(), vm); + + exc + } + + /// Create SSLError with library and reason from ssl_data codes + /// + /// This helper converts OpenSSL numeric error codes to Python SSLError exceptions + /// with proper _library and _reason attributes by looking up the error strings + /// in ssl_data tables, then delegates to create_ssl_error_with_reason. + /// + /// # Arguments + /// * `vm` - Virtual machine reference + /// * `lib` - OpenSSL library code (e.g., ERR_LIB_SSL = 20) + /// * `reason` - OpenSSL reason code (e.g., SSL_R_NO_SHARED_CIPHER = 193) + /// + /// # Returns + /// PyBaseExceptionRef with _library and _reason attributes set + fn create_ssl_error_from_codes( + vm: &VirtualMachine, + lib: i32, + reason: i32, + ) -> PyBaseExceptionRef { + // Look up error strings from ssl_data tables + let key = ssl_data::encode_error_key(lib, reason); + let reason_str = ssl_data::ERROR_CODES + .get(&key) + .copied() + .unwrap_or("unknown error"); + + let lib_str = ssl_data::LIBRARY_CODES + .get(&(lib as u32)) + .copied() + .unwrap_or("UNKNOWN"); + + // Delegate to create_ssl_error_with_reason for actual exception creation + Self::create_ssl_error_with_reason(vm, lib_str, reason_str, format!("[SSL] {reason_str}")) + } + + /// Convert to Python exception + pub fn into_py_err(self, vm: &VirtualMachine) -> PyBaseExceptionRef { + match self { + SslError::WantRead => create_ssl_want_read_error(vm), + SslError::WantWrite => create_ssl_want_write_error(vm), + SslError::Timeout(msg) => timeout_error_msg(vm, msg), + SslError::Syscall(msg) => vm.new_os_error(msg), + SslError::Ssl(msg) => vm.new_exception_msg( + PySSLError::class(&vm.ctx).to_owned(), + format!("SSL error: {msg}"), + ), + SslError::ZeroReturn => create_ssl_zero_return_error(vm), + SslError::Eof => create_ssl_eof_error(vm), + SslError::CertVerification(cert_err) => { + // Use the proper cert verification error creator + create_ssl_cert_verification_error(vm, &cert_err).expect("unlikely to happen") + } + SslError::Io(err) => vm.new_os_error(format!("I/O error: {err}")), + SslError::SniCallbackRestart => { + // This should be handled at PySSLSocket level + unreachable!("SniCallbackRestart should not reach Python layer") + } + SslError::Py(exc) => exc, + SslError::AlertReceived { lib, reason } => { + Self::create_ssl_error_from_codes(vm, lib, reason) + } + SslError::NoCipherSuites => { + // OpenSSL error: lib=20 (ERR_LIB_SSL), reason=193 (SSL_R_NO_SHARED_CIPHER) + Self::create_ssl_error_from_codes(vm, ERR_LIB_SSL, SSL_R_NO_SHARED_CIPHER) + } + } + } +} + +pub type SslResult = Result; +/// Common protocol settings shared between client and server connections +#[derive(Debug)] +pub struct ProtocolSettings { + pub versions: &'static [&'static rustls::SupportedProtocolVersion], + pub kx_groups: Option>, + pub cipher_suites: Option>, + pub alpn_protocols: Vec>, +} + +/// Options for creating a server TLS configuration +#[derive(Debug)] +pub struct ServerConfigOptions { + /// Common protocol settings (versions, ALPN, KX groups, cipher suites) + pub protocol_settings: ProtocolSettings, + /// Server certificate chain + pub cert_chain: Vec>, + /// Server private key + pub private_key: PrivateKeyDer<'static>, + /// Root certificates for client verification (if required) + pub root_store: Option, + /// Whether to request client certificate + pub request_client_cert: bool, + /// Whether to use deferred client certificate validation (TLS 1.3) + pub use_deferred_validation: bool, + /// Custom certificate resolver (for SNI support) + pub cert_resolver: Option>, + /// Deferred certificate error storage (for TLS 1.3) + pub deferred_cert_error: Option>>>, + /// Session storage for server-side session resumption + pub session_storage: Option>, + /// Shared ticketer for TLS 1.2 session tickets (stateless resumption) + pub ticketer: Option>, +} + +/// Options for creating a client TLS configuration +#[derive(Debug)] +pub struct ClientConfigOptions { + /// Common protocol settings (versions, ALPN, KX groups, cipher suites) + pub protocol_settings: ProtocolSettings, + /// Root certificates for server verification + pub root_store: Option, + /// DER-encoded CA certificates (for partial chain verification) + pub ca_certs_der: Vec>, + /// Client certificate chain (for mTLS) + pub cert_chain: Option>>, + /// Client private key (for mTLS) + pub private_key: Option>, + /// Whether to verify server certificates (CERT_NONE disables verification) + pub verify_server_cert: bool, + /// Whether to check hostname against certificate (check_hostname) + pub check_hostname: bool, + /// SSL verification flags (e.g., VERIFY_X509_STRICT) + pub verify_flags: i32, + /// Session store for client-side session resumption + pub session_store: Option>, + /// Certificate Revocation Lists for CRL checking + pub crls: Vec>, +} + +/// Create custom CryptoProvider with specified cipher suites and key exchange groups +/// +/// This helper function consolidates the duplicated CryptoProvider creation logic +/// for both server and client configurations. +fn create_custom_crypto_provider( + cipher_suites: Option>, + kx_groups: Option>, +) -> Arc { + use rustls::crypto::aws_lc_rs::{ALL_CIPHER_SUITES, ALL_KX_GROUPS}; + let default_provider = rustls::crypto::aws_lc_rs::default_provider(); + + Arc::new(rustls::crypto::CryptoProvider { + cipher_suites: cipher_suites.unwrap_or_else(|| ALL_CIPHER_SUITES.to_vec()), + kx_groups: kx_groups.unwrap_or_else(|| ALL_KX_GROUPS.to_vec()), + signature_verification_algorithms: default_provider.signature_verification_algorithms, + secure_random: default_provider.secure_random, + key_provider: default_provider.key_provider, + }) +} + +/// Create a server TLS configuration +/// +/// This abstracts the complex rustls ServerConfig building logic, +/// matching SSL_CTX initialization for server sockets. +pub(super) fn create_server_config(options: ServerConfigOptions) -> Result { + use rustls::server::WebPkiClientVerifier; + + // Ensure default CryptoProvider is installed + ensure_default_provider(); + + // Create custom crypto provider using helper function + let custom_provider = create_custom_crypto_provider( + options.protocol_settings.cipher_suites.clone(), + options.protocol_settings.kx_groups.clone(), + ); + + // Step 1: Build the appropriate client cert verifier based on settings + let client_cert_verifier: Option> = + if let Some(root_store) = options.root_store { + if options.request_client_cert { + // Client certificate verification required + let base_verifier = WebPkiClientVerifier::builder(Arc::new(root_store)) + .build() + .map_err(|e| format!("Failed to create client verifier: {e}"))?; + + if options.use_deferred_validation { + // TLS 1.3: Use deferred validation + if let Some(deferred_error) = options.deferred_cert_error { + use crate::ssl::cert::DeferredClientCertVerifier; + let deferred_verifier = + DeferredClientCertVerifier::new(base_verifier, deferred_error); + Some(Arc::new(deferred_verifier)) + } else { + // No deferred error storage provided, use immediate validation + Some(base_verifier) + } + } else { + // TLS 1.2 or non-deferred: Use immediate validation + Some(base_verifier) + } + } else { + // No client authentication + None + } + } else { + // No root store - no client authentication + None + }; + + // Step 2: Create ServerConfig builder once with the selected verifier + let builder = ServerConfig::builder_with_provider(custom_provider.clone()) + .with_protocol_versions(options.protocol_settings.versions) + .map_err(|e| format!("Failed to create server config builder: {e}"))?; + + let builder = if let Some(verifier) = client_cert_verifier { + builder.with_client_cert_verifier(verifier) + } else { + builder.with_no_client_auth() + }; + + // Add certificate + let mut config = if let Some(resolver) = options.cert_resolver { + // Use custom cert resolver (e.g., for SNI) + builder.with_cert_resolver(resolver) + } else { + // Use single certificate + builder + .with_single_cert(options.cert_chain, options.private_key) + .map_err(|e| format!("Failed to set server certificate: {e}"))? + }; + + // Set ALPN protocols with fallback + apply_alpn_with_fallback( + &mut config.alpn_protocols, + &options.protocol_settings.alpn_protocols, + ); + + // Set session storage for server-side session resumption (TLS 1.3) + if let Some(session_storage) = options.session_storage { + config.session_storage = session_storage; + } + + // Set ticketer for TLS 1.2 session tickets (stateless resumption) + if let Some(ticketer) = options.ticketer { + config.ticketer = ticketer.clone(); + } + + Ok(config) +} + +/// Build WebPki verifier with CRL support +/// +/// This helper function consolidates the duplicated CRL setup logic for both +/// check_hostname=True and check_hostname=False cases. +fn build_webpki_verifier_with_crls( + root_store: Arc, + crls: Vec>, + verify_flags: i32, +) -> Result, String> { + use rustls::client::WebPkiServerVerifier; + + let mut verifier_builder = WebPkiServerVerifier::builder(root_store); + + // Check if CRL verification is requested + let crl_check_requested = verify_flags & X509_V_FLAG_CRL_CHECK != 0; + let has_crls = !crls.is_empty(); + + // Add CRLs if provided OR if CRL checking is explicitly requested + // (even with empty CRLs, rustls will fail verification if CRL checking is enabled) + if has_crls || crl_check_requested { + verifier_builder = verifier_builder.with_crls(crls); + + // Check if we should only verify end-entity (leaf) certificates + if verify_flags & X509_V_FLAG_CRL_CHECK != 0 { + verifier_builder = verifier_builder.only_check_end_entity_revocation(); + } + } + + let webpki_verifier = verifier_builder + .build() + .map_err(|e| format!("Failed to build WebPkiServerVerifier: {e}"))?; + + Ok(webpki_verifier as Arc) +} + +/// Apply verifier wrappers (CRLCheckVerifier and StrictCertVerifier) +/// +/// This helper function consolidates the duplicated verifier wrapping logic. +fn apply_verifier_wrappers( + verifier: Arc, + verify_flags: i32, + has_crls: bool, + ca_certs_der: Vec>, +) -> Arc { + let crl_check_requested = verify_flags & X509_V_FLAG_CRL_CHECK != 0; + + // Wrap with CRLCheckVerifier to enforce CRL checking when flags are set + let verifier = if crl_check_requested { + use crate::ssl::cert::CRLCheckVerifier; + Arc::new(CRLCheckVerifier::new( + verifier, + has_crls, + crl_check_requested, + )) + } else { + verifier + }; + + // Always use PartialChainVerifier when trust store is not empty + // This allows self-signed certificates in trust store to be trusted + // (OpenSSL behavior: self-signed certs are always trusted, non-self-signed require flag) + let verifier = if !ca_certs_der.is_empty() { + use crate::ssl::cert::PartialChainVerifier; + Arc::new(PartialChainVerifier::new( + verifier, + ca_certs_der, + verify_flags, + )) + } else { + verifier + }; + + // Wrap with StrictCertVerifier if VERIFY_X509_STRICT flag is set + if verify_flags & VERIFY_X509_STRICT != 0 { + Arc::new(super::cert::StrictCertVerifier::new(verifier, verify_flags)) + } else { + verifier + } +} + +/// Apply ALPN protocols +/// +/// OpenSSL 1.1.0f+ allows ALPN negotiation to fail without aborting handshake. +/// rustls follows RFC 7301 strictly and rejects connections with no matching protocol. +/// To emulate OpenSSL behavior, we add a special fallback protocol (null byte). +fn apply_alpn_with_fallback(config_alpn: &mut Vec>, alpn_protocols: &[Vec]) { + if !alpn_protocols.is_empty() { + *config_alpn = alpn_protocols.to_vec(); + config_alpn.push(vec![0u8]); // Add null byte as fallback marker + } +} + +/// Create a client TLS configuration +/// +/// This abstracts the complex rustls ClientConfig building logic, +/// matching SSL_CTX initialization for client sockets. +pub(super) fn create_client_config(options: ClientConfigOptions) -> Result { + // Ensure default CryptoProvider is installed + ensure_default_provider(); + + // Create custom crypto provider using helper function + let custom_provider = create_custom_crypto_provider( + options.protocol_settings.cipher_suites.clone(), + options.protocol_settings.kx_groups.clone(), + ); + + // Step 1: Build the appropriate verifier based on verification settings + let verifier: Arc = if options + .verify_server_cert + { + // Verify server certificates + let root_store = options + .root_store + .ok_or("Root store required for server verification")?; + + let root_store_arc = Arc::new(root_store); + + // Check if root_store is empty (no CA certs loaded) + // CPython allows this and fails during handshake with SSLCertVerificationError + if root_store_arc.is_empty() { + // Use EmptyRootStoreVerifier - always fails with UnknownIssuer during handshake + use crate::ssl::cert::EmptyRootStoreVerifier; + Arc::new(EmptyRootStoreVerifier) + } else { + // Calculate has_crls once for both hostname verification paths + let has_crls = !options.crls.is_empty(); + + if options.check_hostname { + // Default behavior: verify both certificate chain and hostname + let base_verifier = build_webpki_verifier_with_crls( + root_store_arc.clone(), + options.crls, + options.verify_flags, + )?; + + // Apply CRL and Strict verifier wrappers using helper function + apply_verifier_wrappers( + base_verifier, + options.verify_flags, + has_crls, + options.ca_certs_der.clone(), + ) + } else { + // check_hostname=False: verify certificate chain but ignore hostname + use crate::ssl::cert::HostnameIgnoringVerifier; + + // Build verifier with CRL support using helper function + let webpki_verifier = build_webpki_verifier_with_crls( + root_store_arc.clone(), + options.crls, + options.verify_flags, + )?; + + // Apply CRL verifier wrapper if needed (without Strict wrapper yet) + let crl_check_requested = options.verify_flags & X509_V_FLAG_CRL_CHECK != 0; + let verifier = if crl_check_requested { + use crate::ssl::cert::CRLCheckVerifier; + Arc::new(CRLCheckVerifier::new( + webpki_verifier, + has_crls, + crl_check_requested, + )) as Arc + } else { + webpki_verifier + }; + + // Wrap with PartialChainVerifier if VERIFY_X509_PARTIAL_CHAIN is set + const VERIFY_X509_PARTIAL_CHAIN: i32 = 0x80000; + let verifier = if options.verify_flags & VERIFY_X509_PARTIAL_CHAIN != 0 { + use crate::ssl::cert::PartialChainVerifier; + Arc::new(PartialChainVerifier::new( + verifier, + options.ca_certs_der.clone(), + options.verify_flags, + )) as Arc + } else { + verifier + }; + + // Wrap with HostnameIgnoringVerifier to bypass hostname checking + let hostname_ignoring_verifier: Arc< + dyn rustls::client::danger::ServerCertVerifier, + > = Arc::new(HostnameIgnoringVerifier::new_with_verifier(verifier)); + + // Apply Strict verifier wrapper once at the end if needed + if options.verify_flags & VERIFY_X509_STRICT != 0 { + Arc::new(crate::ssl::cert::StrictCertVerifier::new( + hostname_ignoring_verifier, + options.verify_flags, + )) + } else { + hostname_ignoring_verifier + } + } + } + } else { + // CERT_NONE: disable all verification + use crate::ssl::cert::NoVerifier; + Arc::new(NoVerifier) + }; + + // Step 2: Create ClientConfig builder once with the selected verifier + let builder = ClientConfig::builder_with_provider(custom_provider.clone()) + .with_protocol_versions(options.protocol_settings.versions) + .map_err(|e| format!("Failed to create client config builder: {e}"))? + .dangerous() + .with_custom_certificate_verifier(verifier); + + // Add client certificate if provided (mTLS) + let mut config = + if let (Some(cert_chain), Some(private_key)) = (options.cert_chain, options.private_key) { + builder + .with_client_auth_cert(cert_chain, private_key) + .map_err(|e| format!("Failed to set client certificate: {e}"))? + } else { + builder.with_no_client_auth() + }; + + // Set ALPN protocols + apply_alpn_with_fallback( + &mut config.alpn_protocols, + &options.protocol_settings.alpn_protocols, + ); + + // Set session resumption + if let Some(session_store) = options.session_store { + use rustls::client::Resumption; + config.resumption = Resumption::store(session_store); + } + + Ok(config) +} + +/// Helper function - check if error is BlockingIOError +pub(super) fn is_blocking_io_error(err: &PyBaseExceptionRef, vm: &VirtualMachine) -> bool { + err.fast_isinstance(vm.ctx.exceptions.blocking_io_error) +} + +// Handshake Helper Functions + +/// Write TLS handshake data to socket/BIO +/// +/// Drains all pending TLS data from rustls and sends it to the peer. +/// Returns whether any progress was made. +fn handshake_write_loop( + conn: &mut TlsConnection, + socket: &PySSLSocket, + force_initial_write: bool, + vm: &VirtualMachine, +) -> SslResult { + let mut made_progress = false; + + while conn.wants_write() || force_initial_write { + if force_initial_write && !conn.wants_write() { + // No data to write on first iteration - break to avoid infinite loop + break; + } + + let mut buf = Vec::new(); + let written = conn + .write_tls(&mut buf as &mut dyn std::io::Write) + .map_err(SslError::Io)?; + + if written > 0 && !buf.is_empty() { + // Send directly without select - blocking sockets will wait automatically + // Handle BlockingIOError from non-blocking sockets + match socket.sock_send(buf, vm) { + Ok(_) => { + made_progress = true; + } + Err(e) => { + if is_blocking_io_error(&e, vm) { + // Non-blocking socket would block - return SSLWantWriteError + return Err(SslError::WantWrite); + } + return Err(SslError::Py(e)); + } + } + } else if written == 0 { + // No data written but wants_write is true - should not happen normally + // Break to avoid infinite loop + break; + } + + // Check if there's more to write + if !conn.wants_write() { + break; + } + } + + Ok(made_progress) +} + +/// Read TLS handshake data from socket/BIO +/// +/// Waits for and reads TLS records from the peer, handling SNI callback setup. +/// Returns (made_progress, is_first_sni_read). +fn handshake_read_data( + conn: &mut TlsConnection, + socket: &PySSLSocket, + is_bio: bool, + is_server: bool, + vm: &VirtualMachine, +) -> SslResult<(bool, bool)> { + if !conn.wants_read() { + return Ok((false, false)); + } + + // SERVER-SPECIFIC: Check if this is the first read (for SNI callback) + // Must check BEFORE reading data, so we can detect first time + let is_first_sni_read = is_server && socket.is_first_sni_read(); + + // Wait for data in socket mode + if !is_bio { + let timed_out = socket + .sock_wait_for_io_impl(SelectKind::Read, vm) + .map_err(SslError::Py)?; + + if timed_out { + // This should rarely happen now - only if socket itself has a timeout + // and we're waiting for required handshake data + return Err(SslError::Timeout("timed out".to_string())); + } + } + + let data_obj = match socket.sock_recv(SSL3_RT_MAX_PLAIN_LENGTH, vm) { + Ok(d) => d, + Err(e) => { + if is_blocking_io_error(&e, vm) { + return Err(SslError::WantRead); + } + // In socket mode, if recv times out and we're only waiting for read + // (no wants_write), we might be waiting for optional NewSessionTicket in TLS 1.3 + // Consider the handshake complete + if !is_bio && !conn.wants_write() { + // Check if it's a timeout exception + if e.fast_isinstance(vm.ctx.exceptions.timeout_error) { + // Timeout waiting for optional data - handshake is complete + return Ok((false, false)); + } + } + return Err(SslError::Py(e)); + } + }; + + // SERVER-SPECIFIC: Save ClientHello on first read for potential connection recreation + if is_first_sni_read { + // Extract bytes from PyObjectRef + use rustpython_vm::builtins::PyBytes; + if let Some(bytes_obj) = data_obj.downcast_ref::() { + socket.save_client_hello_from_bytes(bytes_obj.as_bytes()); + } + } + + // Feed data to rustls + ssl_read_tls_records(conn, data_obj, is_bio, vm)?; + + Ok((true, is_first_sni_read)) +} + +/// Handle handshake completion for server-side TLS 1.3 +/// +/// Tries to send NewSessionTicket in non-blocking mode to avoid deadlocks. +/// Returns true if handshake is complete and we should exit. +fn handle_handshake_complete( + conn: &mut TlsConnection, + socket: &PySSLSocket, + _is_server: bool, + vm: &VirtualMachine, +) -> SslResult { + if conn.is_handshaking() { + return Ok(false); // Not complete yet + } + + // Handshake is complete! + // + // Different behavior for BIO mode vs socket mode: + // + // BIO mode (CPython-compatible): + // - Python code calls outgoing.read() to get pending data + // - We just return here and let Python handle the data + // + // Socket mode (rustls-specific): + // - OpenSSL automatically writes to socket in SSL_do_handshake() + // - We must explicitly call write_tls() to send pending data + // - Without this, client hangs waiting for server's NewSessionTicket + + if socket.is_bio_mode() { + // BIO mode: Write pending data to outgoing BIO (one-time drain) + // Python's ssl_io_loop will read from outgoing BIO + if conn.wants_write() { + // Call write_tls ONCE to drain pending data + // Do NOT loop on wants_write() - avoid infinite loop/deadlock + let tls_data = ssl_write_tls_records(conn)?; + if !tls_data.is_empty() { + socket.sock_send(tls_data, vm).map_err(SslError::Py)?; + } + + // IMPORTANT: Don't check wants_write() again! + // Python's ssl_io_loop will call do_handshake() again if needed + } + } else if conn.wants_write() { + // Send all pending data (e.g., TLS 1.3 NewSessionTicket) to socket + while conn.wants_write() { + let tls_data = ssl_write_tls_records(conn)?; + if tls_data.is_empty() { + break; + } + socket.sock_send(tls_data, vm).map_err(SslError::Py)?; + } + } + + Ok(true) +} + +/// Try to read plaintext data from TLS connection buffer +/// +/// Returns Ok(Some(n)) if n bytes were read, Ok(None) if would block, +/// or Err on real errors. +fn try_read_plaintext(conn: &mut TlsConnection, buf: &mut [u8]) -> SslResult> { + let mut reader = conn.reader(); + match reader.read(buf) { + Ok(0) => { + // EOF from TLS connection + Ok(Some(0)) + } + Ok(n) => { + // Successfully read n bytes + Ok(Some(n)) + } + Err(e) if e.kind() != std::io::ErrorKind::WouldBlock => { + // Real error + Err(SslError::Io(e)) + } + Err(_) => { + // WouldBlock - no plaintext available + Ok(None) + } + } +} + +/// Equivalent to OpenSSL's SSL_do_handshake() +/// +/// Performs TLS handshake by exchanging data with the peer until completion. +/// This abstracts away the low-level rustls read_tls/write_tls loop. +/// +/// = SSL_do_handshake() +pub(super) fn ssl_do_handshake( + conn: &mut TlsConnection, + socket: &PySSLSocket, + vm: &VirtualMachine, +) -> SslResult<()> { + // Check if handshake is already done + if !conn.is_handshaking() { + return Ok(()); + } + + let is_bio = socket.is_bio_mode(); + let is_server = matches!(conn, TlsConnection::Server(_)); + let mut first_iteration = true; // Track if this is the first loop iteration + let mut iteration_count = 0; + + loop { + iteration_count += 1; + let mut made_progress = false; + + // IMPORTANT: In BIO mode, force initial write even if wants_write() is false + // rustls requires write_tls() to be called to generate ClientHello/ServerHello + let force_initial_write = is_bio && first_iteration; + + // Write TLS handshake data to socket/BIO + let write_progress = handshake_write_loop(conn, socket, force_initial_write, vm)?; + made_progress |= write_progress; + + // Read TLS handshake data from socket/BIO + let (read_progress, is_first_sni_read) = + handshake_read_data(conn, socket, is_bio, is_server, vm)?; + made_progress |= read_progress; + + // Process TLS packets (state machine) + if let Err(e) = conn.process_new_packets() { + // Send close_notify on error + if !is_bio { + conn.send_close_notify(); + // Actually send the close_notify alert + if let Ok(alert_data) = ssl_write_tls_records(conn) + && !alert_data.is_empty() + { + let _ = socket.sock_send(alert_data, vm); + } + } + + // Certificate verification errors are already handled by from_rustls + + return Err(SslError::from_rustls(e)); + } + + // SERVER-SPECIFIC: Check SNI callback after processing packets + // SNI name is extracted during process_new_packets() + // Invoke callback on FIRST read if callback is configured, regardless of SNI presence + if is_server && is_first_sni_read && socket.has_sni_callback() { + // IMPORTANT: Do NOT call the callback here! + // The connection lock is still held, which would cause deadlock. + // Return SniCallbackRestart to signal do_handshake to: + // 1. Drop conn_guard + // 2. Call the callback (with Some(name) or None) + // 3. Restart handshake + return Err(SslError::SniCallbackRestart); + } + + // Check if handshake is complete and handle post-handshake processing + // CRITICAL: We do NOT check wants_read() - this matches CPython/OpenSSL behavior! + if handle_handshake_complete(conn, socket, is_server, vm)? { + return Ok(()); + } + + // In BIO mode, stop after one iteration + if is_bio { + // Before returning WANT error, write any pending TLS data to BIO + // This is critical: if wants_write is true after process_new_packets, + // we need to write that data to the outgoing BIO before returning + if conn.wants_write() { + // Write all pending TLS data to outgoing BIO + loop { + let mut buf = vec![0u8; SSL3_RT_MAX_PLAIN_LENGTH]; + let n = match conn.write_tls(&mut buf.as_mut_slice()) { + Ok(n) => n, + Err(_) => break, + }; + if n == 0 { + break; + } + // Send to outgoing BIO + socket + .sock_send(buf[..n].to_vec(), vm) + .map_err(SslError::Py)?; + // Check if there's more to write + if !conn.wants_write() { + break; + } + } + // After writing, check if we still want more + // If all data was written, wants_write may now be false + if conn.wants_write() { + // Still need more - return WANT_WRITE + return Err(SslError::WantWrite); + } + // Otherwise fall through to check wants_read + } + + // Check if we need to read + if conn.wants_read() { + return Err(SslError::WantRead); + } + break; + } + + // Mark that we've completed the first iteration + first_iteration = false; + + // Improved loop termination logic: + // Continue looping if: + // 1. Rustls wants more I/O (wants_read or wants_write), OR + // 2. We made progress in this iteration + // + // This is more robust than just checking made_progress, because: + // - Rustls may need multiple iterations to process TLS state machine + // - Network delays may cause temporary "no progress" situations + // - wants_read/wants_write accurately reflect Rustls internal state + let should_continue = conn.wants_read() || conn.wants_write() || made_progress; + + if !should_continue { + break; + } + + // Safety check: prevent truly infinite loops (should never happen) + if iteration_count > 1000 { + break; + } + } + + // If we exit the loop without completing handshake, return error + // Check rustls state to provide better error message + if conn.is_handshaking() { + Err(SslError::Syscall(format!( + "SSL handshake failed: incomplete after {iteration_count} iterations", + ))) + } else { + // Handshake completed successfully (shouldn't reach here normally) + Ok(()) + } +} + +/// Equivalent to OpenSSL's SSL_read() +/// +/// Reads application data from TLS connection. +/// Automatically handles TLS record I/O as needed. +/// +/// = SSL_read_ex() +pub(super) fn ssl_read( + conn: &mut TlsConnection, + buf: &mut [u8], + socket: &PySSLSocket, + vm: &VirtualMachine, +) -> SslResult { + let is_bio = socket.is_bio_mode(); + + // Get socket timeout and calculate deadline (= _PyDeadline_Init) + let deadline = if !is_bio { + match socket.get_socket_timeout(vm).map_err(SslError::Py)? { + Some(timeout) if !timeout.is_zero() => Some(std::time::Instant::now() + timeout), + _ => None, // None = blocking (no deadline), Some(0) = non-blocking (handled below) + } + } else { + None // BIO mode has no deadline + }; + + // Loop to handle TLS records and post-handshake messages + // Matches SSL_read behavior which loops until data is available + // - CPython uses OpenSSL's SSL_read which loops on SSL_ERROR_WANT_READ/WANT_WRITE + // - We use rustls which requires manual read_tls/process_new_packets loop + // - No iteration limit: relies on deadline and blocking I/O + // - Blocking sockets: sock_select() and recv() wait at kernel level (no CPU busy-wait) + // - Non-blocking sockets: immediate return on first WantRead + // - Deadline prevents timeout issues + loop { + // Check deadline + if let Some(deadline) = deadline + && std::time::Instant::now() >= deadline + { + // Timeout expired + return Err(SslError::Timeout( + "The read operation timed out".to_string(), + )); + } + // Check if we need to read more TLS records BEFORE trying plaintext read + // This ensures we don't miss data that's already been processed + let needs_more_tls = conn.wants_read(); + + // Try to read plaintext from rustls buffer + if let Some(n) = try_read_plaintext(conn, buf)? { + return Ok(n); + } + + // No plaintext available and cannot read more TLS records + if !needs_more_tls { + if is_bio && let Some(bio_obj) = socket.incoming_bio() { + let is_eof = bio_obj + .get_attr("eof", vm) + .and_then(|v| v.try_into_value::(vm)) + .unwrap_or(false); + if is_eof { + return Err(SslError::Eof); + } + } + return Err(SslError::WantRead); + } + + // Read and process TLS records + // This will block for blocking sockets + match ssl_ensure_data_available(conn, socket, vm) { + Ok(_bytes_read) => { + // Successfully read and processed TLS data + // Continue loop to try reading plaintext + } + Err(SslError::Io(ref io_err)) if io_err.to_string().contains("message buffer full") => { + // Buffer is full - we need to consume plaintext before reading more + // Try to read plaintext now + match try_read_plaintext(conn, buf)? { + Some(n) if n > 0 => { + // Have plaintext - return it + // Python will call read() again if it needs more data + return Ok(n); + } + _ => { + // No plaintext available yet - this is unusual + // Return WantRead to let Python retry + return Err(SslError::WantRead); + } + } + } + Err(e) => { + // Other errors - check for buffered plaintext before propagating + match try_read_plaintext(conn, buf)? { + Some(n) if n > 0 => { + // Have buffered plaintext - return it successfully + return Ok(n); + } + _ => { + // No buffered data - propagate the error + return Err(e); + } + } + } + } + } +} + +// Helper functions (private-ish, used by public SSL functions) + +/// Write TLS records from rustls to socket +fn ssl_write_tls_records(conn: &mut TlsConnection) -> SslResult> { + let mut buf = Vec::new(); + let n = conn + .write_tls(&mut buf as &mut dyn std::io::Write) + .map_err(SslError::Io)?; + + if n > 0 { Ok(buf) } else { Ok(Vec::new()) } +} + +/// Read TLS records from socket to rustls +fn ssl_read_tls_records( + conn: &mut TlsConnection, + data: PyObjectRef, + is_bio: bool, + vm: &VirtualMachine, +) -> SslResult<()> { + // Convert PyObject to bytes-like (supports bytes, bytearray, etc.) + let bytes = ArgBytesLike::try_from_object(vm, data) + .map_err(|_| SslError::Syscall("Expected bytes-like object".to_string()))?; + + let bytes_data = bytes.borrow_buf(); + + if bytes_data.is_empty() { + // different error for BIO vs socket mode + if is_bio { + // In BIO mode, no data means WANT_READ + return Err(SslError::WantRead); + } else { + // In socket mode, empty recv() means TCP EOF (FIN received) + // Need to distinguish: + // 1. Clean shutdown: received TLS close_notify → return ZeroReturn (0 bytes) + // 2. Unexpected EOF: no close_notify → return Eof (SSLEOFError) + // + // SSL_ERROR_ZERO_RETURN vs SSL_ERROR_SYSCALL(errno=0) logic + // CPython checks SSL_get_shutdown() & SSL_RECEIVED_SHUTDOWN + // + // Process any buffered TLS records (may contain close_notify) + let _ = conn.process_new_packets(); + + // IMPORTANT: CPython's default behavior (suppress_ragged_eofs=True) + // treats empty recv() as clean shutdown, returning 0 bytes instead of raising SSLEOFError. + // + // This is necessary for HTTP/1.0 servers that: + // 1. Send response without Content-Length header + // 2. Signal end-of-response by closing connection (TCP FIN) + // 3. Don't send TLS close_notify before TCP close + // + // While this could theoretically allow truncation attacks, + // it's the standard behavior for compatibility with real-world servers. + // Python only raises SSLEOFError when suppress_ragged_eofs=False is explicitly set. + // + // TODO: Implement suppress_ragged_eofs parameter if needed for strict security mode. + return Err(SslError::ZeroReturn); + } + } + + // Feed all received data to read_tls - loop to consume all data + // read_tls may not consume all data in one call + let mut offset = 0; + while offset < bytes_data.len() { + let remaining = &bytes_data[offset..]; + let mut cursor = std::io::Cursor::new(remaining); + + match conn.read_tls(&mut cursor) { + Ok(read_bytes) => { + if read_bytes == 0 { + // No more data can be consumed + break; + } + offset += read_bytes; + } + Err(e) => { + // Real error - propagate it + return Err(SslError::Io(e)); + } + } + } + + Ok(()) +} + +/// Ensure TLS data is available for reading +/// Returns the number of bytes read from the socket +fn ssl_ensure_data_available( + conn: &mut TlsConnection, + socket: &PySSLSocket, + vm: &VirtualMachine, +) -> SslResult { + // Unlike OpenSSL's SSL_read, rustls requires explicit I/O + if conn.wants_read() { + let is_bio = socket.is_bio_mode(); + + // For non-BIO mode (regular sockets), check if socket is ready first + // PERFORMANCE OPTIMIZATION: Only use select for sockets with timeout + // - Blocking sockets (timeout=None): Skip select, recv() will block naturally + // - Timeout sockets: Use select to enforce timeout + // - Non-blocking sockets: Skip select, recv() will return EAGAIN immediately + if !is_bio { + let timeout = socket.get_socket_timeout(vm).map_err(SslError::Py)?; + + // Only use select if socket has a positive timeout + if let Some(t) = timeout + && !t.is_zero() + { + // Socket has timeout - use select to enforce it + let timed_out = socket + .sock_wait_for_io_impl(SelectKind::Read, vm) + .map_err(SslError::Py)?; + if timed_out { + // Socket not ready within timeout + return Err(SslError::WantRead); + } + } + // else: non-blocking socket (timeout=0) or blocking socket (timeout=None) - skip select + } + + let data = socket.sock_recv(2048, vm).map_err(SslError::Py)?; + + // Get the size of received data + let bytes_read = data + .clone() + .try_into_value::(vm) + .map(|b| b.as_bytes().len()) + .unwrap_or(0); + + // Check if BIO has EOF set (incoming BIO closed) + let is_eof = if is_bio { + // Check incoming BIO's eof property + if let Some(bio_obj) = socket.incoming_bio() { + bio_obj + .get_attr("eof", vm) + .and_then(|v| v.try_into_value::(vm)) + .unwrap_or(false) + } else { + false + } + } else { + false + }; + + // If BIO EOF is set and no data available, treat as connection EOF + if is_eof && bytes_read == 0 { + return Err(SslError::Eof); + } + + // Feed data to rustls and process packets + ssl_read_tls_records(conn, data, is_bio, vm)?; + + // Process any packets we successfully read + conn.process_new_packets().map_err(SslError::from_rustls)?; + + Ok(bytes_read) + } else { + // No data to read + Ok(0) + } +} + +// Multi-Certificate Resolver for RSA/ECC Support + +/// Multi-certificate resolver that selects appropriate certificate based on client capabilities +/// +/// This resolver implements OpenSSL's behavior of supporting multiple certificate/key pairs +/// (e.g., one RSA and one ECC) and selecting the appropriate one based on the client's +/// supported signature algorithms during the TLS handshake. +/// +/// OpenSSL's SSL_CTX_use_certificate_chain_file can be called multiple +/// times to add different certificate types, and OpenSSL automatically selects the best one. +#[derive(Debug)] +pub(super) struct MultiCertResolver { + cert_keys: Vec>, +} + +impl MultiCertResolver { + /// Create a new multi-certificate resolver + pub fn new(cert_keys: Vec>) -> Self { + Self { cert_keys } + } +} + +impl ResolvesServerCert for MultiCertResolver { + fn resolve(&self, client_hello: rustls::server::ClientHello<'_>) -> Option> { + // Get the signature schemes supported by the client + let client_schemes = client_hello.signature_schemes(); + + // Try to find a certificate that matches the client's signature schemes + for cert_key in &self.cert_keys { + // Check if this certificate's signing key is compatible with any of the + // client's supported signature schemes + if let Some(_scheme) = cert_key.key.choose_scheme(client_schemes) { + return Some(cert_key.clone()); + } + } + + // If no perfect match, return the first certificate as fallback + // (This matches OpenSSL's behavior of using the first loaded cert if negotiation fails) + self.cert_keys.first().cloned() + } +} + +// Helper Functions for OpenSSL Compatibility: + +/// Normalize cipher suite name for OpenSSL compatibility +/// +/// Converts rustls cipher names to OpenSSL format: +/// - TLS_AES_256_GCM_SHA384 → AES256-GCM-SHA384 +/// - Replaces "AES-256" with "AES256" and "AES-128" with "AES128" +pub(super) fn normalize_cipher_name(rustls_name: &str) -> String { + rustls_name + .strip_prefix("TLS_") + .unwrap_or(rustls_name) + .replace("_WITH_", "_") + .replace('_', "-") + .replace("AES-256", "AES256") + .replace("AES-128", "AES128") +} + +/// Get cipher key size in bits from cipher suite name +/// +/// Returns: +/// - 256 for AES-256 and ChaCha20 +/// - 128 for AES-128 +/// - 0 for unknown ciphers +pub(super) fn get_cipher_key_bits(cipher_name: &str) -> i32 { + if cipher_name.contains("256") || cipher_name.contains("CHACHA20") { + 256 + } else if cipher_name.contains("128") { + 128 + } else { + 0 + } +} + +/// Get encryption algorithm description from cipher name +/// +/// Returns human-readable encryption description for OpenSSL compatibility +pub(super) fn get_cipher_encryption_desc(cipher_name: &str) -> &'static str { + if cipher_name.contains("AES256") { + "AESGCM(256)" + } else if cipher_name.contains("AES128") { + "AESGCM(128)" + } else if cipher_name.contains("CHACHA20") { + "CHACHA20-POLY1305(256)" + } else { + "Unknown" + } +} + +/// Normalize rustls cipher suite name to IANA standard format +/// +/// Converts rustls Debug format names to IANA standard: +/// - "TLS13_AES_256_GCM_SHA384" -> "TLS_AES_256_GCM_SHA384" +/// - Other names remain unchanged +pub(super) fn normalize_rustls_cipher_name(rustls_name: &str) -> String { + if rustls_name.starts_with("TLS13_") { + rustls_name.replace("TLS13_", "TLS_") + } else { + rustls_name.to_string() + } +} + +/// Convert rustls protocol version to string representation +/// +/// Returns the TLS version string +/// - TLSv1.2, TLSv1.3, or "Unknown" +pub(super) fn get_protocol_version_str(version: &rustls::SupportedProtocolVersion) -> &'static str { + match version.version { + rustls::ProtocolVersion::TLSv1_2 => "TLSv1.2", + rustls::ProtocolVersion::TLSv1_3 => "TLSv1.3", + _ => "Unknown", + } +} + +/// Cipher suite information +/// +/// Contains all relevant cipher information extracted from a rustls CipherSuite +pub(super) struct CipherInfo { + /// IANA standard cipher name (e.g., "TLS_AES_256_GCM_SHA384") + pub name: String, + /// TLS protocol version (e.g., "TLSv1.2", "TLSv1.3") + pub protocol: &'static str, + /// Key size in bits (e.g., 128, 256) + pub bits: i32, +} + +/// Extract cipher information from a rustls CipherSuite +/// +/// This consolidates the common cipher extraction logic used across +/// get_ciphers(), cipher(), and shared_ciphers() methods. +pub(super) fn extract_cipher_info(suite: &rustls::SupportedCipherSuite) -> CipherInfo { + let rustls_name = format!("{:?}", suite.suite()); + let name = normalize_rustls_cipher_name(&rustls_name); + let protocol = get_protocol_version_str(suite.version()); + let bits = get_cipher_key_bits(&name); + + CipherInfo { + name, + protocol, + bits, + } +} + +/// Convert curve name to rustls key exchange group +/// +/// Maps OpenSSL curve names (e.g., "prime256v1", "secp384r1") to rustls KxGroups. +/// Returns an error if the curve is not supported by rustls. +pub(super) fn curve_name_to_kx_group( + curve: &str, +) -> Result, String> { + // Get the default crypto provider's key exchange groups + let provider = rustls::crypto::aws_lc_rs::default_provider(); + let all_groups = &provider.kx_groups; + + match curve { + // P-256 (also known as secp256r1 or prime256v1) + "prime256v1" | "secp256r1" => { + // Find SECP256R1 in the provider's groups + all_groups + .iter() + .find(|g| g.name() == rustls::NamedGroup::secp256r1) + .map(|g| vec![*g]) + .ok_or_else(|| "secp256r1 not supported by crypto provider".to_owned()) + } + // P-384 (also known as secp384r1 or prime384v1) + "secp384r1" | "prime384v1" => all_groups + .iter() + .find(|g| g.name() == rustls::NamedGroup::secp384r1) + .map(|g| vec![*g]) + .ok_or_else(|| "secp384r1 not supported by crypto provider".to_owned()), + // X25519 + "X25519" | "x25519" => all_groups + .iter() + .find(|g| g.name() == rustls::NamedGroup::X25519) + .map(|g| vec![*g]) + .ok_or_else(|| "X25519 not supported by crypto provider".to_owned()), + // P-521 (also known as secp521r1 or prime521v1) + // Now supported with aws-lc-rs crypto provider + "prime521v1" | "secp521r1" => all_groups + .iter() + .find(|g| g.name() == rustls::NamedGroup::secp521r1) + .map(|g| vec![*g]) + .ok_or_else(|| "secp521r1 not supported by crypto provider".to_owned()), + // X448 + // Now supported with aws-lc-rs crypto provider + "X448" | "x448" => all_groups + .iter() + .find(|g| g.name() == rustls::NamedGroup::X448) + .map(|g| vec![*g]) + .ok_or_else(|| "X448 not supported by crypto provider".to_owned()), + _ => Err(format!("unknown curve name '{curve}'")), + } +} diff --git a/stdlib/src/ssl/oid.rs b/stdlib/src/ssl/oid.rs new file mode 100644 index 00000000000..2e13733a2a2 --- /dev/null +++ b/stdlib/src/ssl/oid.rs @@ -0,0 +1,464 @@ +// spell-checker: disable + +//! OID (Object Identifier) management for SSL/TLS +//! +//! This module provides OID lookup functionality compatible with CPython's ssl module. +//! It uses oid-registry crate for well-known OIDs while maintaining NID (Numerical Identifier) +//! mappings for CPython compatibility. + +use oid_registry::asn1_rs::Oid; +use std::collections::HashMap; + +/// OID entry with openssl-compatible metadata +#[derive(Debug, Clone)] +pub struct OidEntry { + /// NID (OpenSSL Numerical Identifier) - must match CPython/OpenSSL values + pub nid: i32, + /// Short name (e.g., "CN", "serverAuth") + pub short_name: &'static str, + /// Long name/description (e.g., "commonName", "TLS Web Server Authentication") + pub long_name: &'static str, + /// OID reference (static or dynamic) + pub oid: OidRef, +} + +/// OID reference - either from oid-registry or runtime-created +#[derive(Debug, Clone)] +pub enum OidRef { + /// Static OID from oid-registry crate (stored as value) + Static(Oid<'static>), + /// OID string (for OIDs not in oid-registry) - parsed on demand + String(&'static str), +} + +impl OidEntry { + /// Create entry from oid-registry static constant + pub fn from_static( + nid: i32, + short_name: &'static str, + long_name: &'static str, + oid: &Oid<'static>, + ) -> Self { + Self { + nid, + short_name, + long_name, + oid: OidRef::Static(oid.clone()), + } + } + + /// Create entry from OID string (for OIDs not in oid-registry) + pub const fn from_string( + nid: i32, + short_name: &'static str, + long_name: &'static str, + oid_str: &'static str, + ) -> Self { + Self { + nid, + short_name, + long_name, + oid: OidRef::String(oid_str), + } + } + + /// Get OID as string (e.g., "2.5.4.3") + pub fn oid_string(&self) -> String { + match &self.oid { + OidRef::Static(oid) => oid.to_id_string(), + OidRef::String(s) => s.to_string(), + } + } +} + +/// OID table with multiple indices for fast lookup +pub struct OidTable { + /// All entries + entries: Vec, + /// NID -> index mapping + nid_to_idx: HashMap, + /// Short name -> index mapping + short_name_to_idx: HashMap<&'static str, usize>, + /// Long name -> index mapping (case-insensitive) + long_name_to_idx: HashMap, + /// OID string -> index mapping + oid_str_to_idx: HashMap, +} + +impl OidTable { + fn build() -> Self { + let entries = build_oid_entries(); + let mut nid_to_idx = HashMap::with_capacity(entries.len()); + let mut short_name_to_idx = HashMap::with_capacity(entries.len()); + let mut long_name_to_idx = HashMap::with_capacity(entries.len()); + let mut oid_str_to_idx = HashMap::with_capacity(entries.len()); + + for (idx, entry) in entries.iter().enumerate() { + nid_to_idx.insert(entry.nid, idx); + short_name_to_idx.insert(entry.short_name, idx); + long_name_to_idx.insert(entry.long_name.to_lowercase(), idx); + oid_str_to_idx.insert(entry.oid_string(), idx); + } + + Self { + entries, + nid_to_idx, + short_name_to_idx, + long_name_to_idx, + oid_str_to_idx, + } + } + + pub fn find_by_nid(&self, nid: i32) -> Option<&OidEntry> { + self.nid_to_idx.get(&nid).map(|&idx| &self.entries[idx]) + } + + pub fn find_by_oid_string(&self, oid_str: &str) -> Option<&OidEntry> { + self.oid_str_to_idx + .get(oid_str) + .map(|&idx| &self.entries[idx]) + } + + pub fn find_by_name(&self, name: &str) -> Option<&OidEntry> { + // Try short name first (exact match) + self.short_name_to_idx + .get(name) + .or_else(|| { + // Try long name (case-insensitive) + self.long_name_to_idx.get(&name.to_lowercase()) + }) + .map(|&idx| &self.entries[idx]) + } +} + +/// Global OID table +static OID_TABLE: std::sync::LazyLock = std::sync::LazyLock::new(OidTable::build); + +/// Macro to define OID entry using oid-registry constant +macro_rules! oid_static { + ($nid:expr, $short:expr, $long:expr, $oid_const:path) => { + OidEntry::from_static($nid, $short, $long, &$oid_const) + }; +} + +/// Macro to define OID entry from string +macro_rules! oid_string { + ($nid:expr, $short:expr, $long:expr, $oid_str:expr) => { + OidEntry::from_string($nid, $short, $long, $oid_str) + }; +} + +/// Build the complete OID table +fn build_oid_entries() -> Vec { + vec![ + // Priority 1: X.509 DN Attributes (OpenSSL NID values) + // These NIDs MUST match OpenSSL for CPython compatibility + oid_static!(13, "CN", "commonName", oid_registry::OID_X509_COMMON_NAME), + oid_static!(14, "C", "countryName", oid_registry::OID_X509_COUNTRY_NAME), + oid_static!( + 15, + "L", + "localityName", + oid_registry::OID_X509_LOCALITY_NAME + ), + oid_static!( + 16, + "ST", + "stateOrProvinceName", + oid_registry::OID_X509_STATE_OR_PROVINCE_NAME + ), + oid_static!( + 17, + "O", + "organizationName", + oid_registry::OID_X509_ORGANIZATION_NAME + ), + oid_static!( + 18, + "OU", + "organizationalUnitName", + oid_registry::OID_X509_ORGANIZATIONAL_UNIT + ), + oid_static!(41, "name", "name", oid_registry::OID_X509_NAME), + oid_static!(42, "GN", "givenName", oid_registry::OID_X509_GIVEN_NAME), + oid_static!(43, "initials", "initials", oid_registry::OID_X509_INITIALS), + oid_static!( + 4, + "serialNumber", + "serialNumber", + oid_registry::OID_X509_SERIALNUMBER + ), + oid_static!(100, "surname", "surname", oid_registry::OID_X509_SURNAME), + // emailAddress is special - it's in PKCS#9, not X.509 + oid_static!( + 48, + "emailAddress", + "emailAddress", + oid_registry::OID_PKCS9_EMAIL_ADDRESS + ), + // Priority 2: X.509 Extensions (Critical ones) + oid_static!( + 82, + "subjectKeyIdentifier", + "X509v3 Subject Key Identifier", + oid_registry::OID_X509_EXT_SUBJECT_KEY_IDENTIFIER + ), + oid_static!( + 83, + "keyUsage", + "X509v3 Key Usage", + oid_registry::OID_X509_EXT_KEY_USAGE + ), + oid_static!( + 85, + "subjectAltName", + "X509v3 Subject Alternative Name", + oid_registry::OID_X509_EXT_SUBJECT_ALT_NAME + ), + oid_static!( + 86, + "issuerAltName", + "X509v3 Issuer Alternative Name", + oid_registry::OID_X509_EXT_ISSUER_ALT_NAME + ), + oid_static!( + 87, + "basicConstraints", + "X509v3 Basic Constraints", + oid_registry::OID_X509_EXT_BASIC_CONSTRAINTS + ), + oid_static!( + 88, + "crlNumber", + "X509v3 CRL Number", + oid_registry::OID_X509_EXT_CRL_NUMBER + ), + oid_static!( + 90, + "authorityKeyIdentifier", + "X509v3 Authority Key Identifier", + oid_registry::OID_X509_EXT_AUTHORITY_KEY_IDENTIFIER + ), + oid_static!( + 126, + "extendedKeyUsage", + "X509v3 Extended Key Usage", + oid_registry::OID_X509_EXT_EXTENDED_KEY_USAGE + ), + oid_static!( + 103, + "crlDistributionPoints", + "X509v3 CRL Distribution Points", + oid_registry::OID_X509_EXT_CRL_DISTRIBUTION_POINTS + ), + oid_static!( + 89, + "certificatePolicies", + "X509v3 Certificate Policies", + oid_registry::OID_X509_EXT_CERTIFICATE_POLICIES + ), + oid_static!( + 177, + "authorityInfoAccess", + "Authority Information Access", + oid_registry::OID_PKIX_AUTHORITY_INFO_ACCESS + ), + oid_static!( + 105, + "nameConstraints", + "X509v3 Name Constraints", + oid_registry::OID_X509_EXT_NAME_CONSTRAINTS + ), + // Priority 3: Extended Key Usage OIDs (not in oid-registry) + // These are defined in RFC 5280 but not in oid-registry, so we use strings + oid_string!( + 129, + "serverAuth", + "TLS Web Server Authentication", + "1.3.6.1.5.5.7.3.1" + ), + oid_string!( + 130, + "clientAuth", + "TLS Web Client Authentication", + "1.3.6.1.5.5.7.3.2" + ), + oid_string!(131, "codeSigning", "Code Signing", "1.3.6.1.5.5.7.3.3"), + oid_string!( + 132, + "emailProtection", + "E-mail Protection", + "1.3.6.1.5.5.7.3.4" + ), + oid_string!(133, "timeStamping", "Time Stamping", "1.3.6.1.5.5.7.3.8"), + oid_string!(180, "OCSPSigning", "OCSP Signing", "1.3.6.1.5.5.7.3.9"), + // Priority 4: Signature Algorithms + oid_static!( + 6, + "rsaEncryption", + "rsaEncryption", + oid_registry::OID_PKCS1_RSAENCRYPTION + ), + oid_static!( + 65, + "sha1WithRSAEncryption", + "sha1WithRSAEncryption", + oid_registry::OID_PKCS1_SHA1WITHRSA + ), + oid_static!( + 668, + "sha256WithRSAEncryption", + "sha256WithRSAEncryption", + oid_registry::OID_PKCS1_SHA256WITHRSA + ), + oid_static!( + 669, + "sha384WithRSAEncryption", + "sha384WithRSAEncryption", + oid_registry::OID_PKCS1_SHA384WITHRSA + ), + oid_static!( + 670, + "sha512WithRSAEncryption", + "sha512WithRSAEncryption", + oid_registry::OID_PKCS1_SHA512WITHRSA + ), + oid_static!( + 408, + "id-ecPublicKey", + "id-ecPublicKey", + oid_registry::OID_KEY_TYPE_EC_PUBLIC_KEY + ), + oid_static!( + 794, + "ecdsa-with-SHA256", + "ecdsa-with-SHA256", + oid_registry::OID_SIG_ECDSA_WITH_SHA256 + ), + oid_static!( + 795, + "ecdsa-with-SHA384", + "ecdsa-with-SHA384", + oid_registry::OID_SIG_ECDSA_WITH_SHA384 + ), + oid_static!( + 796, + "ecdsa-with-SHA512", + "ecdsa-with-SHA512", + oid_registry::OID_SIG_ECDSA_WITH_SHA512 + ), + // Priority 5: Hash Algorithms + oid_string!(64, "sha1", "sha1", "1.3.14.3.2.26"), + oid_static!(672, "sha256", "sha256", oid_registry::OID_NIST_HASH_SHA256), + oid_static!(673, "sha384", "sha384", oid_registry::OID_NIST_HASH_SHA384), + oid_static!(674, "sha512", "sha512", oid_registry::OID_NIST_HASH_SHA512), + oid_string!(675, "sha224", "sha224", "2.16.840.1.101.3.4.2.4"), + // Priority 6: Elliptic Curve OIDs + oid_static!(714, "secp256r1", "secp256r1", oid_registry::OID_EC_P256), + oid_string!(715, "secp384r1", "secp384r1", "1.3.132.0.34"), + oid_string!(716, "secp521r1", "secp521r1", "1.3.132.0.35"), + oid_string!(1172, "X25519", "X25519", "1.3.101.110"), + oid_string!(1173, "Ed25519", "Ed25519", "1.3.101.112"), + // Additional useful OIDs + oid_string!( + 183, + "subjectInfoAccess", + "Subject Information Access", + "1.3.6.1.5.5.7.1.11" + ), + oid_string!(920, "OCSP", "OCSP", "1.3.6.1.5.5.7.48.1"), + oid_string!(921, "caIssuers", "CA Issuers", "1.3.6.1.5.5.7.48.2"), + ] +} + +// Public API Functions + +/// Find OID entry by NID +pub fn find_by_nid(nid: i32) -> Option<&'static OidEntry> { + OID_TABLE.find_by_nid(nid) +} + +/// Find OID entry by OID string (e.g., "2.5.4.3") +pub fn find_by_oid_string(oid_str: &str) -> Option<&'static OidEntry> { + OID_TABLE.find_by_oid_string(oid_str) +} + +/// Find OID entry by name (short or long name) +pub fn find_by_name(name: &str) -> Option<&'static OidEntry> { + OID_TABLE.find_by_name(name) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_find_by_nid() { + let entry = find_by_nid(13).unwrap(); + assert_eq!(entry.short_name, "CN"); + assert_eq!(entry.long_name, "commonName"); + assert_eq!(entry.oid_string(), "2.5.4.3"); + } + + #[test] + fn test_find_by_oid_string() { + let entry = find_by_oid_string("2.5.4.3").unwrap(); + assert_eq!(entry.nid, 13); + assert_eq!(entry.short_name, "CN"); + } + + #[test] + fn test_find_by_name_short() { + let entry = find_by_name("CN").unwrap(); + assert_eq!(entry.nid, 13); + assert_eq!(entry.oid_string(), "2.5.4.3"); + } + + #[test] + fn test_find_by_name_long() { + let entry = find_by_name("commonName").unwrap(); + assert_eq!(entry.nid, 13); + assert_eq!(entry.short_name, "CN"); + } + + #[test] + fn test_find_by_name_case_insensitive() { + let entry = find_by_name("COMMONNAME").unwrap(); + assert_eq!(entry.nid, 13); + } + + #[test] + fn test_subject_alt_name() { + let entry = find_by_nid(85).unwrap(); + assert_eq!(entry.short_name, "subjectAltName"); + assert_eq!(entry.oid_string(), "2.5.29.17"); + } + + #[test] + fn test_server_auth_eku() { + let entry = find_by_nid(129).unwrap(); + assert_eq!(entry.short_name, "serverAuth"); + assert_eq!(entry.oid_string(), "1.3.6.1.5.5.7.3.1"); + } + + #[test] + fn test_no_duplicate_nids() { + let table = &*OID_TABLE; + assert_eq!( + table.entries.len(), + table.nid_to_idx.len(), + "Duplicate NIDs detected!" + ); + } + + #[test] + fn test_oid_count() { + let table = &*OID_TABLE; + // We should have 50+ OIDs defined + assert!( + table.entries.len() >= 50, + "Expected at least 50 OIDs, got {}", + table.entries.len() + ); + } +}