From 3c084b09e7c62c7067109d99e8b09cc76c6b36ee Mon Sep 17 00:00:00 2001 From: Joe Birr-Pixton Date: Mon, 19 May 2025 08:48:07 +0100 Subject: [PATCH 01/14] Introduce server::test for low-level protocol tests --- rustls/src/lib.rs | 2 ++ rustls/src/server/test.rs | 54 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 56 insertions(+) create mode 100644 rustls/src/server/test.rs diff --git a/rustls/src/lib.rs b/rustls/src/lib.rs index 51fa5741f3..cf428cef27 100644 --- a/rustls/src/lib.rs +++ b/rustls/src/lib.rs @@ -629,6 +629,8 @@ pub mod server { pub(crate) mod handy; mod hs; mod server_conn; + #[cfg(test)] + mod test; #[cfg(feature = "tls12")] mod tls12; mod tls13; diff --git a/rustls/src/server/test.rs b/rustls/src/server/test.rs new file mode 100644 index 0000000000..6f1b598065 --- /dev/null +++ b/rustls/src/server/test.rs @@ -0,0 +1,54 @@ +use std::prelude::v1::*; +use std::vec; + +use super::ServerConnectionData; +use crate::common_state::Context; +use crate::msgs::enums::Compression; +use crate::msgs::handshake::{ + ClientExtension, ClientHelloPayload, HandshakeMessagePayload, HandshakePayload, Random, + SessionId, +}; +use crate::msgs::message::{Message, MessagePayload}; +use crate::{CommonState, Error, HandshakeType, PeerIncompatible, ProtocolVersion, Side}; + +#[test] +fn null_compression_required() { + assert_eq!( + test_process_client_hello(ClientHelloPayload { + compression_methods: vec![], + ..minimal_client_hello() + }), + Err(PeerIncompatible::NullCompressionRequired.into()), + ); +} + +fn test_process_client_hello(hello: ClientHelloPayload) -> Result<(), Error> { + let m = Message { + version: ProtocolVersion::TLSv1_2, + payload: MessagePayload::handshake(HandshakeMessagePayload { + typ: HandshakeType::ClientHello, + payload: HandshakePayload::ClientHello(hello), + }), + }; + super::hs::process_client_hello( + &m, + false, + &mut Context { + common: &mut CommonState::new(Side::Server), + data: &mut ServerConnectionData::default(), + sendable_plaintext: None, + }, + ) + .map(|_| ()) +} + +fn minimal_client_hello() -> ClientHelloPayload { + ClientHelloPayload { + client_version: ProtocolVersion::TLSv1_3, + random: Random::from([0u8; 32]), + session_id: SessionId::empty(), + cipher_suites: vec![], + compression_methods: vec![Compression::Null], + extensions: vec![ClientExtension::SignatureAlgorithms(vec![])], + } +} From 739617399e68d6645b5a96aff4eff1a5ae6b6964 Mon Sep 17 00:00:00 2001 From: Joe Birr-Pixton Date: Mon, 19 May 2025 09:03:05 +0100 Subject: [PATCH 02/14] server::test: port `server_ignores_sni_with_ip_address` --- rustls/src/server/test.rs | 28 +++++++++++++++++++++++++++- rustls/tests/api.rs | 8 -------- 2 files changed, 27 insertions(+), 9 deletions(-) diff --git a/rustls/src/server/test.rs b/rustls/src/server/test.rs index 6f1b598065..09e91d08dc 100644 --- a/rustls/src/server/test.rs +++ b/rustls/src/server/test.rs @@ -3,7 +3,8 @@ use std::vec; use super::ServerConnectionData; use crate::common_state::Context; -use crate::msgs::enums::Compression; +use crate::msgs::codec::{Codec, LengthPrefixedBuffer, ListLength}; +use crate::msgs::enums::{Compression, ExtensionType}; use crate::msgs::handshake::{ ClientExtension, ClientHelloPayload, HandshakeMessagePayload, HandshakePayload, Random, SessionId, @@ -22,6 +23,15 @@ fn null_compression_required() { ); } +#[test] +fn server_ignores_sni_with_ip_address() { + let mut ch = minimal_client_hello(); + ch.extensions + .push(ClientExtension::read_bytes(&sni_extension(&[b"1.1.1.1"])).unwrap()); + std::println!("{:?}", ch.extensions); + assert_eq!(test_process_client_hello(ch), Ok(())); +} + fn test_process_client_hello(hello: ClientHelloPayload) -> Result<(), Error> { let m = Message { version: ProtocolVersion::TLSv1_2, @@ -52,3 +62,19 @@ fn minimal_client_hello() -> ClientHelloPayload { extensions: vec![ClientExtension::SignatureAlgorithms(vec![])], } } + +fn sni_extension(names: &[&[u8]]) -> Vec { + let mut r = Vec::new(); + ExtensionType::ServerName.encode(&mut r); + let outer = LengthPrefixedBuffer::new(ListLength::U16, &mut r); + let name_items = LengthPrefixedBuffer::new(ListLength::U16, outer.buf); + for name in names { + name_items.buf.push(0); + let host_name = LengthPrefixedBuffer::new(ListLength::U16, name_items.buf); + host_name.buf.extend_from_slice(name); + drop(host_name); + } + drop(name_items); + drop(outer); + r +} diff --git a/rustls/tests/api.rs b/rustls/tests/api.rs index e26a98edc4..495b0449c4 100644 --- a/rustls/tests/api.rs +++ b/rustls/tests/api.rs @@ -1542,14 +1542,6 @@ fn client_trims_terminating_dot() { } } -#[test] -fn server_ignores_sni_with_ip_address() { - check_sni_error( - encoding::Extension::new_sni(b"1.1.1.1"), - Error::General("no server certificate chain resolved".to_string()), - ); -} - #[test] fn server_rejects_sni_with_illegal_dns_name() { check_sni_error( From 20e1d8dce8bf3ef7042cba45aa4045fae039219f Mon Sep 17 00:00:00 2001 From: Joe Birr-Pixton Date: Mon, 19 May 2025 09:12:28 +0100 Subject: [PATCH 03/14] server::test: port `server_rejects_sni_with_illegal_dns_name` --- rustls/src/server/test.rs | 16 +++++++++++++++- rustls/tests/api.rs | 8 -------- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/rustls/src/server/test.rs b/rustls/src/server/test.rs index 09e91d08dc..52ea61316e 100644 --- a/rustls/src/server/test.rs +++ b/rustls/src/server/test.rs @@ -10,7 +10,9 @@ use crate::msgs::handshake::{ SessionId, }; use crate::msgs::message::{Message, MessagePayload}; -use crate::{CommonState, Error, HandshakeType, PeerIncompatible, ProtocolVersion, Side}; +use crate::{ + CommonState, Error, HandshakeType, PeerIncompatible, PeerMisbehaved, ProtocolVersion, Side, +}; #[test] fn null_compression_required() { @@ -32,6 +34,18 @@ fn server_ignores_sni_with_ip_address() { assert_eq!(test_process_client_hello(ch), Ok(())); } +#[test] +fn server_rejects_sni_with_illegal_dns_name() { + let mut ch = minimal_client_hello(); + ch.extensions + .push(ClientExtension::read_bytes(&sni_extension(&[b"ab@cd.com"])).unwrap()); + std::println!("{:?}", ch.extensions); + assert_eq!( + test_process_client_hello(ch), + Err(PeerMisbehaved::ServerNameMustContainOneHostName.into()) + ); +} + fn test_process_client_hello(hello: ClientHelloPayload) -> Result<(), Error> { let m = Message { version: ProtocolVersion::TLSv1_2, diff --git a/rustls/tests/api.rs b/rustls/tests/api.rs index 495b0449c4..234205f8d4 100644 --- a/rustls/tests/api.rs +++ b/rustls/tests/api.rs @@ -1542,14 +1542,6 @@ fn client_trims_terminating_dot() { } } -#[test] -fn server_rejects_sni_with_illegal_dns_name() { - check_sni_error( - encoding::Extension::new_sni(b"ab@cd.com"), - Error::PeerMisbehaved(PeerMisbehaved::ServerNameMustContainOneHostName), - ); -} - fn check_sni_error(sni_extension: encoding::Extension, expected_error: Error) { for kt in ALL_KEY_TYPES { let mut server_config = make_server_config(*kt); From 44dcab1f4b4c73fa278a2da013844b86a16d79b5 Mon Sep 17 00:00:00 2001 From: Joe Birr-Pixton Date: Mon, 19 May 2025 09:30:40 +0100 Subject: [PATCH 04/14] Lower `test_server_rejects_empty_sni_extension` This moves from an integration test to a unit test of `ClientExtension`. --- rustls/src/msgs/handshake_test.rs | 8 ++++++++ rustls/tests/api.rs | 11 ----------- 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/rustls/src/msgs/handshake_test.rs b/rustls/src/msgs/handshake_test.rs index f5cb1ae633..2341686569 100644 --- a/rustls/src/msgs/handshake_test.rs +++ b/rustls/src/msgs/handshake_test.rs @@ -228,6 +228,14 @@ fn rejects_truncated_sni() { assert!(ClientExtension::read(&mut Reader::init(&bytes)).is_err()); } +#[test] +fn rejects_empty_sni_extension() { + assert_eq!( + ClientExtension::read_bytes(&[0, 0, 0, 2, 0, 0]).unwrap_err(), + InvalidMessage::IllegalEmptyList("ServerNames") + ); +} + #[test] fn can_round_trip_psk_identity() { let bytes = [0, 1, 0x99, 0x11, 0x22, 0x33, 0x44]; diff --git a/rustls/tests/api.rs b/rustls/tests/api.rs index 234205f8d4..80d969d70e 100644 --- a/rustls/tests/api.rs +++ b/rustls/tests/api.rs @@ -6263,17 +6263,6 @@ fn test_server_rejects_duplicate_sni_names() { ); } -#[test] -fn test_server_rejects_empty_sni_extension() { - check_sni_error( - encoding::Extension { - typ: ExtensionType::ServerName, - body: encoding::len_u16(vec![]), - }, - Error::InvalidMessage(InvalidMessage::IllegalEmptyList("ServerNames")), - ); -} - #[test] fn test_server_rejects_clients_without_any_kx_groups() { let (_, mut server) = make_pair(KeyType::Rsa2048); From 7e381258612f0a30e8b8191c5cf34719592e8a27 Mon Sep 17 00:00:00 2001 From: Joe Birr-Pixton Date: Mon, 19 May 2025 09:39:48 +0100 Subject: [PATCH 05/14] Lower `test_server_rejects_duplicate_sni_names` This becomes a unit test on `ClientExtension` decoding. --- rustls/src/msgs/handshake_test.rs | 9 ++++++++ rustls/tests/api.rs | 35 ------------------------------- rustls/tests/common/mod.rs | 17 --------------- 3 files changed, 9 insertions(+), 52 deletions(-) diff --git a/rustls/src/msgs/handshake_test.rs b/rustls/src/msgs/handshake_test.rs index 2341686569..e90546bde2 100644 --- a/rustls/src/msgs/handshake_test.rs +++ b/rustls/src/msgs/handshake_test.rs @@ -236,6 +236,15 @@ fn rejects_empty_sni_extension() { ); } +#[test] +fn rejects_duplicate_names_in_sni_extension() { + assert_eq!( + ClientExtension::read_bytes(&[0, 0, 0, 10, 0, 8, 0, 0, 1, b'a', 0, 0, 1, b'b',]) + .unwrap_err(), + InvalidMessage::InvalidServerName + ); +} + #[test] fn can_round_trip_psk_identity() { let bytes = [0, 1, 0x99, 0x11, 0x22, 0x33, 0x44]; diff --git a/rustls/tests/api.rs b/rustls/tests/api.rs index 80d969d70e..767c94696d 100644 --- a/rustls/tests/api.rs +++ b/rustls/tests/api.rs @@ -1542,28 +1542,6 @@ fn client_trims_terminating_dot() { } } -fn check_sni_error(sni_extension: encoding::Extension, expected_error: Error) { - for kt in ALL_KEY_TYPES { - let mut server_config = make_server_config(*kt); - server_config.cert_resolver = Arc::new(ServerCheckNoSni {}); - - let mut server = ServerConnection::new(Arc::new(server_config)).unwrap(); - server - .read_tls( - &mut encoding::message_framing( - ContentType::Handshake, - ProtocolVersion::TLSv1_2, - encoding::basic_client_hello(vec![sni_extension.clone()]), - ) - .as_slice(), - ) - .unwrap(); - - assert_eq!(server.process_new_packets(), Err(expected_error.clone()),); - assert_eq!(None, server.server_name()); - } -} - #[cfg(feature = "tls12")] fn check_sigalgs_reduced_by_ciphersuite( kt: KeyType, @@ -6250,19 +6228,6 @@ fn connection_types_are_not_huge() { ); } -#[test] -fn test_server_rejects_duplicate_sni_names() { - let mut body = encoding::Extension::sni_dns_hostname(b"example.com"); - body.extend_from_slice(&encoding::Extension::sni_dns_hostname(b"example.com")); - check_sni_error( - encoding::Extension { - typ: ExtensionType::ServerName, - body: encoding::len_u16(body), - }, - Error::InvalidMessage(InvalidMessage::InvalidServerName), - ); -} - #[test] fn test_server_rejects_clients_without_any_kx_groups() { let (_, mut server) = make_pair(KeyType::Rsa2048); diff --git a/rustls/tests/common/mod.rs b/rustls/tests/common/mod.rs index 34169c32bc..31c7b2b1fb 100644 --- a/rustls/tests/common/mod.rs +++ b/rustls/tests/common/mod.rs @@ -1642,23 +1642,6 @@ pub mod encoding { } impl Extension { - pub fn new_sni(name: &[u8]) -> Self { - let body = Self::sni_dns_hostname(name); - let body = len_u16(body); - Self { - typ: ExtensionType::ServerName, - body, - } - } - - pub fn sni_dns_hostname(name: &[u8]) -> Vec { - const SNI_HOSTNAME_TYPE: u8 = 0; - - let mut out = len_u16(name.to_vec()); - out.insert(0, SNI_HOSTNAME_TYPE); - out - } - pub fn new_sig_algs() -> Extension { Extension { typ: ExtensionType::SignatureAlgorithms, From f338cc7bac302efe3c2d4b417c60becc8801341a Mon Sep 17 00:00:00 2001 From: Joe Birr-Pixton Date: Mon, 19 May 2025 10:52:06 +0100 Subject: [PATCH 06/14] Move `test_no_session_ticket_request_on_tls_1_3` into crate --- rustls/src/client/test.rs | 68 +++++++++++++++++++++++++++++++++++++++ rustls/src/lib.rs | 2 ++ rustls/tests/api.rs | 34 -------------------- 3 files changed, 70 insertions(+), 34 deletions(-) create mode 100644 rustls/src/client/test.rs diff --git a/rustls/src/client/test.rs b/rustls/src/client/test.rs new file mode 100644 index 0000000000..dbe68d26c8 --- /dev/null +++ b/rustls/src/client/test.rs @@ -0,0 +1,68 @@ +#![cfg(any(feature = "ring", feature = "aws_lc_rs"))] +use std::prelude::v1::*; + +use pki_types::{CertificateDer, ServerName}; + +use crate::client::{ClientConfig, ClientConnection, Resumption, Tls12Resumption}; +use crate::msgs::codec::Reader; +use crate::msgs::handshake::{ClientHelloPayload, HandshakeMessagePayload, HandshakePayload}; +use crate::msgs::message::{Message, MessagePayload, OutboundOpaqueMessage}; +use crate::{Error, RootCertStore}; + +#[macro_rules_attribute::apply(test_for_each_provider)] +mod tests { + use super::super::*; + use crate::version; + + /// Tests that session_ticket(35) extension + /// is not sent if the client does not support TLS 1.2. + #[test] + fn test_no_session_ticket_request_on_tls_1_3() { + let mut config = + ClientConfig::builder_with_provider(super::provider::default_provider().into()) + .with_protocol_versions(&[&version::TLS13]) + .unwrap() + .with_root_certificates(roots()) + .with_no_client_auth(); + config.resumption = Resumption::in_memory_sessions(128) + .tls12_resumption(Tls12Resumption::SessionIdOrTickets); + let ch = client_hello_sent_for_config(config).unwrap(); + assert!(ch.ticket_extension().is_none()); + } +} + +fn client_hello_sent_for_config(config: ClientConfig) -> Result { + let mut conn = + ClientConnection::new(config.into(), ServerName::try_from("localhost").unwrap())?; + let mut bytes = Vec::new(); + conn.write_tls(&mut bytes).unwrap(); + + let message = OutboundOpaqueMessage::read(&mut Reader::init(&bytes)) + .unwrap() + .into_plain_message(); + + match Message::try_from(message).unwrap() { + Message { + payload: + MessagePayload::Handshake { + parsed: + HandshakeMessagePayload { + payload: HandshakePayload::ClientHello(ch), + .. + }, + .. + }, + .. + } => Ok(ch), + other => panic!("unexpected message {other:?}"), + } +} + +fn roots() -> RootCertStore { + let mut r = RootCertStore::empty(); + r.add(CertificateDer::from_slice(include_bytes!( + "../../../test-ca/rsa-2048/ca.der" + ))) + .unwrap(); + r +} diff --git a/rustls/src/lib.rs b/rustls/src/lib.rs index cf428cef27..1898ba2f40 100644 --- a/rustls/src/lib.rs +++ b/rustls/src/lib.rs @@ -588,6 +588,8 @@ pub mod client { mod ech; pub(super) mod handy; mod hs; + #[cfg(test)] + mod test; #[cfg(feature = "tls12")] mod tls12; mod tls13; diff --git a/rustls/tests/api.rs b/rustls/tests/api.rs index 767c94696d..11ee94666b 100644 --- a/rustls/tests/api.rs +++ b/rustls/tests/api.rs @@ -6259,40 +6259,6 @@ fn test_server_rejects_clients_without_any_kx_groups() { ); } -/// Tests that session_ticket(35) extension -/// is not sent if the client does not support TLS 1.2. -#[test] -fn test_no_session_ticket_request_on_tls_1_3() { - /// Panics if TLS 1.2 session_ticket(35) extension is detected. - /// - /// Does not actually alter the payload. - fn panic_on_session_ticket(msg: &mut Message) -> Altered { - let MessagePayload::Handshake { parsed, encoded: _ } = &msg.payload else { - return Altered::InPlace; - }; - - let HandshakePayload::ClientHello(ch) = &parsed.payload else { - return Altered::InPlace; - }; - - for ext in &ch.extensions { - if matches!(ext, ClientExtension::SessionTicket(_)) { - panic!("TLS 1.2 session_ticket extension in TLS 1.3 handshake detected."); - } - } - - Altered::InPlace - } - - let client_config = - make_client_config_with_versions(KeyType::Rsa2048, &[&rustls::version::TLS13]); - let server_config = make_server_config(KeyType::Rsa2048); - - let (client, server) = make_pair_for_configs(client_config, server_config); - let (mut client, mut server) = (client.into(), server.into()); - transfer_altered(&mut client, panic_on_session_ticket, &mut server); -} - #[test] fn test_server_rejects_clients_without_any_kx_group_overlap() { for version in rustls::ALL_VERSIONS { From b2a1a738bf70dedae88d65c5b82f04fd64b3a9d0 Mon Sep 17 00:00:00 2001 From: Joe Birr-Pixton Date: Mon, 19 May 2025 12:05:02 +0100 Subject: [PATCH 07/14] Make provider explicit in RPK test functions --- rustls/tests/api.rs | 28 ++++++---- rustls/tests/client_cert_verifier.rs | 12 +++-- rustls/tests/common/mod.rs | 76 ++++++++++++++++++++-------- rustls/tests/server_cert_verifier.rs | 2 +- 4 files changed, 81 insertions(+), 37 deletions(-) diff --git a/rustls/tests/api.rs b/rustls/tests/api.rs index 11ee94666b..3230274e24 100644 --- a/rustls/tests/api.rs +++ b/rustls/tests/api.rs @@ -60,9 +60,10 @@ mod test_raw_keys { #[test] fn successful_raw_key_connection_and_correct_peer_certificates() { + let provider = provider::default_provider(); for kt in ALL_KEY_TYPES { - let client_config = make_client_config_with_raw_key_support(*kt); - let server_config = make_server_config_with_raw_key_support(*kt); + let client_config = make_client_config_with_raw_key_support(*kt, &provider); + let server_config = make_server_config_with_raw_key_support(*kt, &provider); let (mut client, mut server) = make_pair_for_configs(client_config, server_config); do_handshake(&mut client, &mut server); @@ -95,9 +96,10 @@ mod test_raw_keys { #[test] fn correct_certificate_type_extensions_from_client_hello() { + let provider = provider::default_provider(); for kt in ALL_KEY_TYPES { - let client_config = make_client_config_with_raw_key_support(*kt); - let mut server_config = make_server_config_with_raw_key_support(*kt); + let client_config = make_client_config_with_raw_key_support(*kt, &provider); + let mut server_config = make_server_config_with_raw_key_support(*kt, &provider); server_config.cert_resolver = Arc::new(ServerCheckCertResolve { expected_client_cert_types: Some(vec![CertificateType::RawPublicKey]), @@ -113,8 +115,9 @@ mod test_raw_keys { #[test] fn only_client_supports_raw_keys() { + let provider = provider::default_provider(); for kt in ALL_KEY_TYPES { - let client_config_rpk = make_client_config_with_raw_key_support(*kt); + let client_config_rpk = make_client_config_with_raw_key_support(*kt, &provider); let server_config = make_server_config(*kt); let (mut client_rpk, mut server) = @@ -139,9 +142,10 @@ mod test_raw_keys { #[test] fn only_server_supports_raw_keys() { + let provider = provider::default_provider(); for kt in ALL_KEY_TYPES { let client_config = make_client_config_with_versions(*kt, &[&rustls::version::TLS13]); - let server_config_rpk = make_server_config_with_raw_key_support(*kt); + let server_config_rpk = make_server_config_with_raw_key_support(*kt, &provider); let (mut client, mut server_rpk) = make_pair_for_configs(client_config, server_config_rpk); @@ -188,10 +192,11 @@ mod test_raw_keys { client_cert_types: Option<&Vec>, expected_result: Result<(), ErrorFromPeer>, ) { + let provider = provider::default_provider(); for kt in ALL_KEY_TYPES { let client_config = Arc::new(make_client_config(*kt)); let server_config_rpk = match server_requires_raw_keys { - true => Arc::new(make_server_config_with_raw_key_support(*kt)), + true => Arc::new(make_server_config_with_raw_key_support(*kt, &provider)), false => Arc::new(make_server_config(*kt)), }; @@ -211,12 +216,13 @@ mod test_raw_keys { #[test] fn incorrectly_alter_server_hello() { + let provider = provider::default_provider(); for kt in ALL_KEY_TYPES { let supported_suite = cipher_suite::TLS13_AES_256_GCM_SHA384; // Alter Server Hello server certificate extension and expect IncorrectCertificateTypeExtension error - let client_config_rpk = make_client_config_with_raw_key_support(*kt); - let server_config_rpk = make_server_config_with_raw_key_support(*kt); + let client_config_rpk = make_client_config_with_raw_key_support(*kt, &provider); + let server_config_rpk = make_server_config_with_raw_key_support(*kt, &provider); add_keylog_and_do_altered_handshake( client_config_rpk, server_config_rpk, @@ -227,8 +233,8 @@ mod test_raw_keys { ); // Alter Server Hello client certificate extension and expect IncorrectCertificateTypeExtension error - let client_config_rpk = make_client_config_with_raw_key_support(*kt); - let server_config_rpk = make_server_config_with_raw_key_support(*kt); + let client_config_rpk = make_client_config_with_raw_key_support(*kt, &provider); + let server_config_rpk = make_server_config_with_raw_key_support(*kt, &provider); add_keylog_and_do_altered_handshake( client_config_rpk, server_config_rpk, diff --git a/rustls/tests/client_cert_verifier.rs b/rustls/tests/client_cert_verifier.rs index a0915d9d31..ea9ef1a121 100644 --- a/rustls/tests/client_cert_verifier.rs +++ b/rustls/tests/client_cert_verifier.rs @@ -46,8 +46,9 @@ fn server_config_with_verifier( #[test] // Happy path, we resolve to a root, it is verified OK, should be able to connect fn client_verifier_works() { + let provider = provider::default_provider(); for kt in ALL_KEY_TYPES.iter() { - let client_verifier = MockClientVerifier::new(ver_ok, *kt); + let client_verifier = MockClientVerifier::new(ver_ok, *kt, &provider); let server_config = server_config_with_verifier(*kt, client_verifier); let server_config = Arc::new(server_config); @@ -64,8 +65,9 @@ fn client_verifier_works() { // Server offers no verification schemes #[test] fn client_verifier_no_schemes() { + let provider = provider::default_provider(); for kt in ALL_KEY_TYPES.iter() { - let mut client_verifier = MockClientVerifier::new(ver_ok, *kt); + let mut client_verifier = MockClientVerifier::new(ver_ok, *kt, &provider); client_verifier.offered_schemes = Some(vec![]); let server_config = server_config_with_verifier(*kt, client_verifier); let server_config = Arc::new(server_config); @@ -88,8 +90,9 @@ fn client_verifier_no_schemes() { // If we do have a root, we must do auth #[test] fn client_verifier_no_auth_yes_root() { + let provider = provider::default_provider(); for kt in ALL_KEY_TYPES.iter() { - let client_verifier = MockClientVerifier::new(ver_unreachable, *kt); + let client_verifier = MockClientVerifier::new(ver_unreachable, *kt, &provider); let server_config = server_config_with_verifier(*kt, client_verifier); let server_config = Arc::new(server_config); @@ -115,8 +118,9 @@ fn client_verifier_no_auth_yes_root() { #[test] // Triple checks we propagate the rustls::Error through fn client_verifier_fails_properly() { + let provider = provider::default_provider(); for kt in ALL_KEY_TYPES.iter() { - let client_verifier = MockClientVerifier::new(ver_err, *kt); + let client_verifier = MockClientVerifier::new(ver_err, *kt, &provider); let server_config = server_config_with_verifier(*kt, client_verifier); let server_config = Arc::new(server_config); diff --git a/rustls/tests/common/mod.rs b/rustls/tests/common/mod.rs index 31c7b2b1fb..aa40837c8b 100644 --- a/rustls/tests/common/mod.rs +++ b/rustls/tests/common/mod.rs @@ -17,7 +17,9 @@ use rustls::client::{ WebPkiServerVerifier, }; use rustls::crypto::cipher::{InboundOpaqueMessage, MessageDecrypter, MessageEncrypter}; -use rustls::crypto::{CryptoProvider, verify_tls13_signature_with_raw_key}; +use rustls::crypto::{ + CryptoProvider, WebPkiSupportedAlgorithms, verify_tls13_signature_with_raw_key, +}; use rustls::internal::msgs::codec::{Codec, Reader}; use rustls::internal::msgs::message::{Message, OutboundOpaqueMessage, PlainMessage}; use rustls::server::danger::{ClientCertVerified, ClientCertVerifier}; @@ -388,8 +390,11 @@ impl KeyType { SubjectPublicKeyInfoDer::from_pem_slice(self.bytes_for("client.spki.pem")).unwrap() } - pub fn get_certified_client_key(&self) -> Result, Error> { - let private_key = provider::default_provider() + pub fn get_certified_client_key( + &self, + provider: &CryptoProvider, + ) -> Result, Error> { + let private_key = provider .key_provider .load_private_key(self.get_client_key())?; let public_key = private_key @@ -402,8 +407,11 @@ impl KeyType { ))) } - pub fn certified_key_with_raw_pub_key(&self) -> Result, Error> { - let private_key = provider::default_provider() + pub fn certified_key_with_raw_pub_key( + &self, + provider: &CryptoProvider, + ) -> Result, Error> { + let private_key = provider .key_provider .load_private_key(self.get_key())?; let public_key = private_key @@ -416,8 +424,11 @@ impl KeyType { ))) } - pub fn certified_key_with_cert_chain(&self) -> Result, Error> { - let private_key = provider::default_provider() + pub fn certified_key_with_cert_chain( + &self, + provider: &CryptoProvider, + ) -> Result, Error> { + let private_key = provider .key_provider .load_private_key(self.get_key())?; Ok(Arc::new(CertifiedKey::new(self.get_chain(), private_key))) @@ -584,10 +595,14 @@ pub fn make_server_config_with_client_verifier( .unwrap() } -pub fn make_server_config_with_raw_key_support(kt: KeyType) -> ServerConfig { - let mut client_verifier = MockClientVerifier::new(|| Ok(ClientCertVerified::assertion()), kt); +pub fn make_server_config_with_raw_key_support( + kt: KeyType, + provider: &CryptoProvider, +) -> ServerConfig { + let mut client_verifier = + MockClientVerifier::new(|| Ok(ClientCertVerified::assertion()), kt, provider); let server_cert_resolver = Arc::new(AlwaysResolvesServerRawPublicKeys::new( - kt.certified_key_with_raw_pub_key() + kt.certified_key_with_raw_pub_key(provider) .unwrap(), )); client_verifier.expect_raw_public_keys = true; @@ -597,10 +612,14 @@ pub fn make_server_config_with_raw_key_support(kt: KeyType) -> ServerConfig { .with_cert_resolver(server_cert_resolver) } -pub fn make_client_config_with_raw_key_support(kt: KeyType) -> ClientConfig { - let server_verifier = Arc::new(MockServerVerifier::expects_raw_public_keys()); +pub fn make_client_config_with_raw_key_support( + kt: KeyType, + provider: &CryptoProvider, +) -> ClientConfig { + let server_verifier = Arc::new(MockServerVerifier::expects_raw_public_keys(provider)); let client_cert_resolver = Arc::new(AlwaysResolvesClientRawPublicKeys::new( - kt.get_certified_client_key().unwrap(), + kt.get_certified_client_key(provider) + .unwrap(), )); // We don't support tls1.2 for Raw Public Keys, hence the version is hard-coded. client_config_builder_with_versions(&[&rustls::version::TLS13]) @@ -612,15 +631,17 @@ pub fn make_client_config_with_raw_key_support(kt: KeyType) -> ClientConfig { pub fn make_client_config_with_cipher_suite_and_raw_key_support( kt: KeyType, cipher_suite: SupportedCipherSuite, + provider: &CryptoProvider, ) -> ClientConfig { - let server_verifier = Arc::new(MockServerVerifier::expects_raw_public_keys()); + let server_verifier = Arc::new(MockServerVerifier::expects_raw_public_keys(provider)); let client_cert_resolver = Arc::new(AlwaysResolvesClientRawPublicKeys::new( - kt.get_certified_client_key().unwrap(), + kt.get_certified_client_key(provider) + .unwrap(), )); ClientConfig::builder_with_provider( CryptoProvider { cipher_suites: vec![cipher_suite], - ..provider::default_provider() + ..provider.clone() } .into(), ) @@ -1072,6 +1093,7 @@ pub struct MockServerVerifier { signature_schemes: Vec, expected_ocsp_response: Option>, requires_raw_public_keys: bool, + raw_public_key_algorithms: Option, } impl ServerCertVerifier for MockServerVerifier { @@ -1121,7 +1143,9 @@ impl ServerCertVerifier for MockServerVerifier { message, &SubjectPublicKeyInfoDer::from(cert.as_ref()), dss, - &provider::default_provider().signature_verification_algorithms, + self.raw_public_key_algorithms + .as_ref() + .unwrap(), ), _ => Ok(HandshakeSignatureValid::assertion()), } @@ -1179,9 +1203,10 @@ impl MockServerVerifier { } } - pub fn expects_raw_public_keys() -> Self { + pub fn expects_raw_public_keys(provider: &CryptoProvider) -> Self { MockServerVerifier { requires_raw_public_keys: true, + raw_public_key_algorithms: Some(provider.signature_verification_algorithms), ..Default::default() } } @@ -1203,6 +1228,7 @@ impl Default for MockServerVerifier { ], expected_ocsp_response: None, requires_raw_public_keys: false, + raw_public_key_algorithms: None, } } } @@ -1213,12 +1239,17 @@ pub struct MockClientVerifier { pub subjects: Vec, pub mandatory: bool, pub offered_schemes: Option>, - pub expect_raw_public_keys: bool, + expect_raw_public_keys: bool, + raw_public_key_algorithms: Option, parent: Arc, } impl MockClientVerifier { - pub fn new(verified: fn() -> Result, kt: KeyType) -> Self { + pub fn new( + verified: fn() -> Result, + kt: KeyType, + provider: &CryptoProvider, + ) -> Self { Self { parent: webpki_client_verifier_builder(get_client_root_store(kt)) .build() @@ -1228,6 +1259,7 @@ impl MockClientVerifier { mandatory: true, offered_schemes: None, expect_raw_public_keys: false, + raw_public_key_algorithms: Some(provider.signature_verification_algorithms), } } } @@ -1275,7 +1307,9 @@ impl ClientCertVerifier for MockClientVerifier { message, &SubjectPublicKeyInfoDer::from(cert.as_ref()), dss, - &provider::default_provider().signature_verification_algorithms, + self.raw_public_key_algorithms + .as_ref() + .unwrap(), ) } else { self.parent diff --git a/rustls/tests/server_cert_verifier.rs b/rustls/tests/server_cert_verifier.rs index 1f0a6bb7d2..d0eadac386 100644 --- a/rustls/tests/server_cert_verifier.rs +++ b/rustls/tests/server_cert_verifier.rs @@ -238,7 +238,7 @@ fn client_can_request_certain_trusted_cas() { kt.ca_distinguished_name() .to_vec() .into(), - kt.certified_key_with_cert_chain() + kt.certified_key_with_cert_chain(&provider::default_provider()) .unwrap(), ) }) From 549d0f9ec09f8661ab176b0f42b66d01fa6e0f01 Mon Sep 17 00:00:00 2001 From: Joe Birr-Pixton Date: Mon, 19 May 2025 12:22:37 +0100 Subject: [PATCH 08/14] Make provider explicit in general testing helpers --- rustls/benches/benchmarks.rs | 2 +- rustls/tests/api.rs | 938 +++++++++++++++++---------- rustls/tests/client_cert_verifier.rs | 13 +- rustls/tests/common/mod.rs | 94 ++- rustls/tests/key_log_file_env.rs | 12 +- rustls/tests/server_cert_verifier.rs | 48 +- rustls/tests/unbuffered.rs | 86 ++- 7 files changed, 766 insertions(+), 427 deletions(-) diff --git a/rustls/benches/benchmarks.rs b/rustls/benches/benchmarks.rs index 28d596ff3c..85c2013ff9 100644 --- a/rustls/benches/benchmarks.rs +++ b/rustls/benches/benchmarks.rs @@ -13,7 +13,7 @@ use rustls::ServerConnection; use test_utils::*; fn bench_ewouldblock(c: &mut Bencher) { - let server_config = make_server_config(KeyType::Rsa2048); + let server_config = make_server_config(KeyType::Rsa2048, &provider::default_provider()); let mut server = ServerConnection::new(Arc::new(server_config)).unwrap(); let mut read_ewouldblock = FailsReads::new(io::ErrorKind::WouldBlock); c.iter(|| server.read_tls(&mut read_ewouldblock)); diff --git a/rustls/tests/api.rs b/rustls/tests/api.rs index 3230274e24..75a4aac4c8 100644 --- a/rustls/tests/api.rs +++ b/rustls/tests/api.rs @@ -118,7 +118,7 @@ mod test_raw_keys { let provider = provider::default_provider(); for kt in ALL_KEY_TYPES { let client_config_rpk = make_client_config_with_raw_key_support(*kt, &provider); - let server_config = make_server_config(*kt); + let server_config = make_server_config(*kt, &provider); let (mut client_rpk, mut server) = make_pair_for_configs(client_config_rpk, server_config); @@ -144,7 +144,8 @@ mod test_raw_keys { fn only_server_supports_raw_keys() { let provider = provider::default_provider(); for kt in ALL_KEY_TYPES { - let client_config = make_client_config_with_versions(*kt, &[&rustls::version::TLS13]); + let client_config = + make_client_config_with_versions(*kt, &[&rustls::version::TLS13], &provider); let server_config_rpk = make_server_config_with_raw_key_support(*kt, &provider); let (mut client, mut server_rpk) = @@ -194,10 +195,10 @@ mod test_raw_keys { ) { let provider = provider::default_provider(); for kt in ALL_KEY_TYPES { - let client_config = Arc::new(make_client_config(*kt)); + let client_config = Arc::new(make_client_config(*kt, &provider)); let server_config_rpk = match server_requires_raw_keys { true => Arc::new(make_server_config_with_raw_key_support(*kt, &provider)), - false => Arc::new(make_server_config(*kt)), + false => Arc::new(make_server_config(*kt, &provider)), }; // Alter Client Hello client certificate extension @@ -245,8 +246,8 @@ mod test_raw_keys { ); // Alter Server Hello server certificate extension and expect UnexpectedCertificateTypeExtension error - let client_config = make_client_config(*kt); - let server_config_rpk = make_server_config(*kt); + let client_config = make_client_config(*kt, &provider); + let server_config_rpk = make_server_config(*kt, &provider); add_keylog_and_do_altered_handshake( client_config, server_config_rpk, @@ -257,8 +258,8 @@ mod test_raw_keys { ); // Alter Server Hello client certificate extension and expect UnexpectedCertificateTypeExtension error - let client_config = make_client_config(*kt); - let server_config_rpk = make_server_config(*kt); + let client_config = make_client_config(*kt, &provider); + let server_config_rpk = make_server_config(*kt, &provider); add_keylog_and_do_altered_handshake( client_config, server_config_rpk, @@ -456,13 +457,15 @@ fn alpn_test_error( agreed: Option<&[u8]>, expected_error: Option, ) { - let mut server_config = make_server_config(KeyType::Rsa2048); + let provider = provider::default_provider(); + let mut server_config = make_server_config(KeyType::Rsa2048, &provider); server_config.alpn_protocols = server_protos; let server_config = Arc::new(server_config); for version in rustls::ALL_VERSIONS { - let mut client_config = make_client_config_with_versions(KeyType::Rsa2048, &[version]); + let mut client_config = + make_client_config_with_versions(KeyType::Rsa2048, &[version], &provider); client_config .alpn_protocols .clone_from(&client_protos); @@ -520,12 +523,13 @@ fn alpn() { #[test] fn connection_level_alpn_protocols() { - let mut server_config = make_server_config(KeyType::Rsa2048); + let provider = provider::default_provider(); + let mut server_config = make_server_config(KeyType::Rsa2048, &provider); server_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; let server_config = Arc::new(server_config); // Config specifies `h2` - let mut client_config = make_client_config(KeyType::Rsa2048); + let mut client_config = make_client_config(KeyType::Rsa2048, &provider); client_config.alpn_protocols = vec![b"h2".to_vec()]; let client_config = Arc::new(client_config); @@ -553,6 +557,7 @@ fn version_test( server_versions: &[&'static rustls::SupportedProtocolVersion], result: Option, ) { + let provider = provider::default_provider(); let client_versions = if client_versions.is_empty() { rustls::ALL_VERSIONS } else { @@ -564,8 +569,10 @@ fn version_test( server_versions }; - let client_config = make_client_config_with_versions(KeyType::Rsa2048, client_versions); - let server_config = make_server_config_with_versions(KeyType::Rsa2048, server_versions); + let client_config = + make_client_config_with_versions(KeyType::Rsa2048, client_versions, &provider); + let server_config = + make_server_config_with_versions(KeyType::Rsa2048, server_versions, &provider); println!("version {client_versions:?} {server_versions:?} -> {result:?}"); @@ -797,10 +804,12 @@ fn config_builder_for_server_with_time() { #[test] fn buffered_client_data_sent() { - let server_config = Arc::new(make_server_config(KeyType::Rsa2048)); + let provider = provider::default_provider(); + let server_config = Arc::new(make_server_config(KeyType::Rsa2048, &provider)); for version in rustls::ALL_VERSIONS { - let client_config = make_client_config_with_versions(KeyType::Rsa2048, &[version]); + let client_config = + make_client_config_with_versions(KeyType::Rsa2048, &[version], &provider); let (mut client, mut server) = make_pair_for_arc_configs(&Arc::new(client_config), &server_config); @@ -816,10 +825,12 @@ fn buffered_client_data_sent() { #[test] fn buffered_server_data_sent() { - let server_config = Arc::new(make_server_config(KeyType::Rsa2048)); + let provider = provider::default_provider(); + let server_config = Arc::new(make_server_config(KeyType::Rsa2048, &provider)); for version in rustls::ALL_VERSIONS { - let client_config = make_client_config_with_versions(KeyType::Rsa2048, &[version]); + let client_config = + make_client_config_with_versions(KeyType::Rsa2048, &[version], &provider); let (mut client, mut server) = make_pair_for_arc_configs(&Arc::new(client_config), &server_config); @@ -835,10 +846,12 @@ fn buffered_server_data_sent() { #[test] fn buffered_both_data_sent() { - let server_config = Arc::new(make_server_config(KeyType::Rsa2048)); + let provider = provider::default_provider(); + let server_config = Arc::new(make_server_config(KeyType::Rsa2048, &provider)); for version in rustls::ALL_VERSIONS { - let client_config = make_client_config_with_versions(KeyType::Rsa2048, &[version]); + let client_config = + make_client_config_with_versions(KeyType::Rsa2048, &[version], &provider); let (mut client, mut server) = make_pair_for_arc_configs(&Arc::new(client_config), &server_config); @@ -871,11 +884,12 @@ fn buffered_both_data_sent() { #[test] fn client_can_get_server_cert() { + let provider = provider::default_provider(); for kt in ALL_KEY_TYPES { for version in rustls::ALL_VERSIONS { - let client_config = make_client_config_with_versions(*kt, &[version]); + let client_config = make_client_config_with_versions(*kt, &[version], &provider); let (mut client, mut server) = - make_pair_for_configs(client_config, make_server_config(*kt)); + make_pair_for_configs(client_config, make_server_config(*kt, &provider)); do_handshake(&mut client, &mut server); let certs = client.peer_certificates(); @@ -886,10 +900,11 @@ fn client_can_get_server_cert() { #[test] fn client_can_get_server_cert_after_resumption() { + let provider = provider::default_provider(); for kt in ALL_KEY_TYPES { - let server_config = make_server_config(*kt); + let server_config = make_server_config(*kt, &provider); for version in rustls::ALL_VERSIONS { - let client_config = make_client_config_with_versions(*kt, &[version]); + let client_config = make_client_config_with_versions(*kt, &[version], &provider); let (mut client, mut server) = make_pair_for_configs(client_config.clone(), server_config.clone()); do_handshake(&mut client, &mut server); @@ -911,13 +926,14 @@ fn client_can_get_server_cert_after_resumption() { #[test] fn client_only_attempts_resumption_with_compatible_security() { + let provider = provider::default_provider(); let kt = KeyType::Rsa2048; CountingLogger::install(); CountingLogger::reset(); - let server_config = make_server_config(kt); + let server_config = make_server_config(kt, &provider); for version in rustls::ALL_VERSIONS { - let base_client_config = make_client_config_with_versions(kt, &[version]); + let base_client_config = make_client_config_with_versions(kt, &[version], &provider); let (mut client, mut server) = make_pair_for_configs(base_client_config.clone(), server_config.clone()); do_handshake(&mut client, &mut server); @@ -939,7 +955,8 @@ fn client_only_attempts_resumption_with_compatible_security() { // disallowed case: unmatching `client_auth_cert_resolver` let mut client_config = ClientConfig::clone(&base_client_config); client_config.client_auth_cert_resolver = - make_client_config_with_versions_with_auth(kt, &[version]).client_auth_cert_resolver; + make_client_config_with_versions_with_auth(kt, &[version], &provider) + .client_auth_cert_resolver; CountingLogger::reset(); let (mut client, mut server) = @@ -954,7 +971,8 @@ fn client_only_attempts_resumption_with_compatible_security() { })); // disallowed case: unmatching `verifier` - let mut client_config = make_client_config_with_versions_with_auth(kt, &[version]); + let mut client_config = + make_client_config_with_versions_with_auth(kt, &[version], &provider); client_config.resumption = base_client_config.resumption.clone(); client_config.client_auth_cert_resolver = Arc::clone(&base_client_config.client_auth_cert_resolver); @@ -976,11 +994,15 @@ fn client_only_attempts_resumption_with_compatible_security() { #[test] fn server_can_get_client_cert() { + let provider = provider::default_provider(); for kt in ALL_KEY_TYPES { - let server_config = Arc::new(make_server_config_with_mandatory_client_auth(*kt)); + let server_config = Arc::new(make_server_config_with_mandatory_client_auth( + *kt, &provider, + )); for version in rustls::ALL_VERSIONS { - let client_config = make_client_config_with_versions_with_auth(*kt, &[version]); + let client_config = + make_client_config_with_versions_with_auth(*kt, &[version], &provider); let (mut client, mut server) = make_pair_for_arc_configs(&Arc::new(client_config), &server_config); do_handshake(&mut client, &mut server); @@ -993,11 +1015,15 @@ fn server_can_get_client_cert() { #[test] fn server_can_get_client_cert_after_resumption() { + let provider = provider::default_provider(); for kt in ALL_KEY_TYPES { - let server_config = Arc::new(make_server_config_with_mandatory_client_auth(*kt)); + let server_config = Arc::new(make_server_config_with_mandatory_client_auth( + *kt, &provider, + )); for version in rustls::ALL_VERSIONS { - let client_config = make_client_config_with_versions_with_auth(*kt, &[version]); + let client_config = + make_client_config_with_versions_with_auth(*kt, &[version], &provider); let client_config = Arc::new(client_config); let (mut client, mut server) = make_pair_for_arc_configs(&client_config, &server_config); @@ -1015,10 +1041,11 @@ fn server_can_get_client_cert_after_resumption() { #[test] fn resumption_combinations() { + let provider = provider::default_provider(); for kt in ALL_KEY_TYPES { - let server_config = make_server_config(*kt); + let server_config = make_server_config(*kt, &provider); for version in rustls::ALL_VERSIONS { - let client_config = make_client_config_with_versions(*kt, &[version]); + let client_config = make_client_config_with_versions(*kt, &[version], &provider); let (mut client, mut server) = make_pair_for_configs(client_config.clone(), server_config.clone()); do_handshake(&mut client, &mut server); @@ -1094,7 +1121,10 @@ fn test_config_builders_debug() { .into(), ); let _ = format!("{b:?}"); - let b = server_config_builder_with_versions(&[&rustls::version::TLS13]); + let b = server_config_builder_with_versions( + &[&rustls::version::TLS13], + &provider::default_provider(), + ); let _ = format!("{b:?}"); let b = b.with_no_client_auth(); let _ = format!("{b:?}"); @@ -1108,7 +1138,10 @@ fn test_config_builders_debug() { .into(), ); let _ = format!("{b:?}"); - let b = client_config_builder_with_versions(&[&rustls::version::TLS13]); + let b = client_config_builder_with_versions( + &[&rustls::version::TLS13], + &provider::default_provider(), + ); let _ = format!("{b:?}"); } @@ -1118,15 +1151,16 @@ fn test_config_builders_debug() { /// certificate and not being given one. #[test] fn server_allow_any_anonymous_or_authenticated_client() { + let provider = provider::default_provider(); let kt = KeyType::Rsa2048; for client_cert_chain in [None, Some(kt.get_client_chain())] { let client_auth_roots = get_client_root_store(kt); - let client_auth = webpki_client_verifier_builder(client_auth_roots.clone()) + let client_auth = webpki_client_verifier_builder(client_auth_roots.clone(), &provider) .allow_unauthenticated() .build() .unwrap(); - let server_config = server_config_builder() + let server_config = server_config_builder(&provider) .with_client_cert_verifier(client_auth) .with_single_cert(kt.get_chain(), kt.get_key()) .unwrap(); @@ -1134,9 +1168,9 @@ fn server_allow_any_anonymous_or_authenticated_client() { for version in rustls::ALL_VERSIONS { let client_config = if client_cert_chain.is_some() { - make_client_config_with_versions_with_auth(kt, &[version]) + make_client_config_with_versions_with_auth(kt, &[version], &provider) } else { - make_client_config_with_versions(kt, &[version]) + make_client_config_with_versions(kt, &[version], &provider) }; let (mut client, mut server) = make_pair_for_arc_configs(&Arc::new(client_config), &server_config); @@ -1155,11 +1189,12 @@ fn check_read_and_close(reader: &mut dyn io::Read, expect: &[u8]) { #[test] fn server_close_notify() { + let provider = provider::default_provider(); let kt = KeyType::Rsa2048; - let server_config = Arc::new(make_server_config_with_mandatory_client_auth(kt)); + let server_config = Arc::new(make_server_config_with_mandatory_client_auth(kt, &provider)); for version in rustls::ALL_VERSIONS { - let client_config = make_client_config_with_versions_with_auth(kt, &[version]); + let client_config = make_client_config_with_versions_with_auth(kt, &[version], &provider); let (mut client, mut server) = make_pair_for_arc_configs(&Arc::new(client_config), &server_config); do_handshake(&mut client, &mut server); @@ -1194,11 +1229,12 @@ fn server_close_notify() { #[test] fn client_close_notify() { + let provider = provider::default_provider(); let kt = KeyType::Rsa2048; - let server_config = Arc::new(make_server_config_with_mandatory_client_auth(kt)); + let server_config = Arc::new(make_server_config_with_mandatory_client_auth(kt, &provider)); for version in rustls::ALL_VERSIONS { - let client_config = make_client_config_with_versions_with_auth(kt, &[version]); + let client_config = make_client_config_with_versions_with_auth(kt, &[version], &provider); let (mut client, mut server) = make_pair_for_arc_configs(&Arc::new(client_config), &server_config); do_handshake(&mut client, &mut server); @@ -1233,11 +1269,12 @@ fn client_close_notify() { #[test] fn server_closes_uncleanly() { + let provider = provider::default_provider(); let kt = KeyType::Rsa2048; - let server_config = Arc::new(make_server_config(kt)); + let server_config = Arc::new(make_server_config(kt, &provider)); for version in rustls::ALL_VERSIONS { - let client_config = make_client_config_with_versions(kt, &[version]); + let client_config = make_client_config_with_versions(kt, &[version], &provider); let (mut client, mut server) = make_pair_for_arc_configs(&Arc::new(client_config), &server_config); do_handshake(&mut client, &mut server); @@ -1278,11 +1315,12 @@ fn server_closes_uncleanly() { #[test] fn client_closes_uncleanly() { + let provider = provider::default_provider(); let kt = KeyType::Rsa2048; - let server_config = Arc::new(make_server_config(kt)); + let server_config = Arc::new(make_server_config(kt, &provider)); for version in rustls::ALL_VERSIONS { - let client_config = make_client_config_with_versions(kt, &[version]); + let client_config = make_client_config_with_versions(kt, &[version], &provider); let (mut client, mut server) = make_pair_for_arc_configs(&Arc::new(client_config), &server_config); do_handshake(&mut client, &mut server); @@ -1323,7 +1361,7 @@ fn client_closes_uncleanly() { #[test] fn test_tls13_valid_early_plaintext_alert() { - let (mut client, mut server) = make_pair(KeyType::Rsa2048); + let (mut client, mut server) = make_pair(KeyType::Rsa2048, &provider::default_provider()); // Perform the start of a TLS 1.3 handshake, sending a client hello to the server. // The client will not have written a CCS or any encrypted messages to the server yet. @@ -1352,7 +1390,7 @@ fn test_tls13_valid_early_plaintext_alert() { #[test] fn test_tls13_too_short_early_plaintext_alert() { - let (mut client, mut server) = make_pair(KeyType::Rsa2048); + let (mut client, mut server) = make_pair(KeyType::Rsa2048, &provider::default_provider()); // Perform the start of a TLS 1.3 handshake, sending a client hello to the server. // The client will not have written a CCS or any encrypted messages to the server yet. @@ -1375,7 +1413,7 @@ fn test_tls13_too_short_early_plaintext_alert() { #[test] fn test_tls13_late_plaintext_alert() { - let (mut client, mut server) = make_pair(KeyType::Rsa2048); + let (mut client, mut server) = make_pair(KeyType::Rsa2048, &provider::default_provider()); // Complete a bi-directional TLS1.3 handshake. After this point no plaintext messages // should occur. @@ -1488,9 +1526,10 @@ impl ResolvesServerCert for ServerCheckCertResolve { #[test] fn server_cert_resolve_with_sni() { + let provider = provider::default_provider(); for kt in ALL_KEY_TYPES { - let client_config = make_client_config(*kt); - let mut server_config = make_server_config(*kt); + let client_config = make_client_config(*kt, &provider); + let mut server_config = make_server_config(*kt, &provider); server_config.cert_resolver = Arc::new(ServerCheckCertResolve { expected_sni: Some("the-value-from-sni".into()), @@ -1509,11 +1548,12 @@ fn server_cert_resolve_with_sni() { #[test] fn server_cert_resolve_with_alpn() { + let provider = provider::default_provider(); for kt in ALL_KEY_TYPES { - let mut client_config = make_client_config(*kt); + let mut client_config = make_client_config(*kt, &provider); client_config.alpn_protocols = vec!["foo".into(), "bar".into()]; - let mut server_config = make_server_config(*kt); + let mut server_config = make_server_config(*kt, &provider); server_config.cert_resolver = Arc::new(ServerCheckCertResolve { expected_alpn: Some(vec![b"foo".to_vec(), b"bar".to_vec()]), ..Default::default() @@ -1530,9 +1570,10 @@ fn server_cert_resolve_with_alpn() { #[test] fn client_trims_terminating_dot() { + let provider = provider::default_provider(); for kt in ALL_KEY_TYPES { - let client_config = make_client_config(*kt); - let mut server_config = make_server_config(*kt); + let client_config = make_client_config(*kt, &provider); + let mut server_config = make_server_config(*kt, &provider); server_config.cert_resolver = Arc::new(ServerCheckCertResolve { expected_sni: Some("some-host.com".into()), @@ -1567,7 +1608,7 @@ fn check_sigalgs_reduced_by_ciphersuite( .unwrap(), ); - let mut server_config = make_server_config(kt); + let mut server_config = make_server_config(kt, &provider::default_provider()); server_config.cert_resolver = Arc::new(ServerCheckCertResolve { expected_sigalgs: Some(expected_sigalgs), @@ -1636,13 +1677,14 @@ impl ResolvesServerCert for ServerCheckNoSni { #[test] fn client_with_sni_disabled_does_not_send_sni() { + let provider = provider::default_provider(); for kt in ALL_KEY_TYPES { - let mut server_config = make_server_config(*kt); + let mut server_config = make_server_config(*kt, &provider); server_config.cert_resolver = Arc::new(ServerCheckNoSni {}); let server_config = Arc::new(server_config); for version in rustls::ALL_VERSIONS { - let mut client_config = make_client_config_with_versions(*kt, &[version]); + let mut client_config = make_client_config_with_versions(*kt, &[version], &provider); client_config.enable_sni = false; let mut client = @@ -1658,11 +1700,12 @@ fn client_with_sni_disabled_does_not_send_sni() { #[test] fn client_checks_server_certificate_with_given_name() { + let provider = provider::default_provider(); for kt in ALL_KEY_TYPES { - let server_config = Arc::new(make_server_config(*kt)); + let server_config = Arc::new(make_server_config(*kt, &provider)); for version in rustls::ALL_VERSIONS { - let client_config = make_client_config_with_versions(*kt, &[version]); + let client_config = make_client_config_with_versions(*kt, &[version], &provider); let mut client = ClientConnection::new( Arc::new(client_config), server_name("not-the-right-hostname.com"), @@ -1693,11 +1736,13 @@ fn client_checks_server_certificate_with_given_ip_address() { do_handshake_until_error(&mut client, &mut server) } + let provider = provider::default_provider(); for kt in ALL_KEY_TYPES { - let server_config = Arc::new(make_server_config(*kt)); + let server_config = Arc::new(make_server_config(*kt, &provider)); for version in rustls::ALL_VERSIONS { - let client_config = Arc::new(make_client_config_with_versions(*kt, &[version])); + let client_config = + Arc::new(make_client_config_with_versions(*kt, &[version], &provider)); // positive ipv4 case assert_eq!( @@ -1732,17 +1777,19 @@ fn client_checks_server_certificate_with_given_ip_address() { #[test] fn client_check_server_certificate_ee_revoked() { + let provider = provider::default_provider(); for kt in ALL_KEY_TYPES { - let server_config = Arc::new(make_server_config(*kt)); + let server_config = Arc::new(make_server_config(*kt, &provider)); // Setup a server verifier that will check the EE certificate's revocation status. let crls = vec![kt.end_entity_crl()]; - let builder = webpki_server_verifier_builder(get_client_root_store(*kt)) + let builder = webpki_server_verifier_builder(get_client_root_store(*kt), &provider) .with_crls(crls) .only_check_end_entity_revocation(); for version in rustls::ALL_VERSIONS { - let client_config = make_client_config_with_verifier(&[version], builder.clone()); + let client_config = + make_client_config_with_verifier(&[version], builder.clone(), &provider); let mut client = ClientConnection::new(Arc::new(client_config), server_name("localhost")).unwrap(); let mut server = ServerConnection::new(Arc::clone(&server_config)).unwrap(); @@ -1761,26 +1808,32 @@ fn client_check_server_certificate_ee_revoked() { #[test] fn client_check_server_certificate_ee_unknown_revocation() { + let provider = provider::default_provider(); for kt in ALL_KEY_TYPES { - let server_config = Arc::new(make_server_config(*kt)); + let server_config = Arc::new(make_server_config(*kt, &provider)); // Setup a server verifier builder that will check the EE certificate's revocation status, but not // allow unknown revocation status (the default). We'll provide CRLs that are not relevant // to the EE cert to ensure its status is unknown. let unrelated_crls = vec![kt.intermediate_crl()]; - let forbid_unknown_verifier = webpki_server_verifier_builder(get_client_root_store(*kt)) - .with_crls(unrelated_crls.clone()) - .only_check_end_entity_revocation(); + let forbid_unknown_verifier = + webpki_server_verifier_builder(get_client_root_store(*kt), &provider) + .with_crls(unrelated_crls.clone()) + .only_check_end_entity_revocation(); // Also set up a verifier builder that will allow unknown revocation status. - let allow_unknown_verifier = webpki_server_verifier_builder(get_client_root_store(*kt)) - .with_crls(unrelated_crls) - .only_check_end_entity_revocation() - .allow_unknown_revocation_status(); + let allow_unknown_verifier = + webpki_server_verifier_builder(get_client_root_store(*kt), &provider) + .with_crls(unrelated_crls) + .only_check_end_entity_revocation() + .allow_unknown_revocation_status(); for version in rustls::ALL_VERSIONS { - let client_config = - make_client_config_with_verifier(&[version], forbid_unknown_verifier.clone()); + let client_config = make_client_config_with_verifier( + &[version], + forbid_unknown_verifier.clone(), + &provider, + ); let mut client = ClientConnection::new(Arc::new(client_config), server_name("localhost")).unwrap(); let mut server = ServerConnection::new(Arc::clone(&server_config)).unwrap(); @@ -1796,8 +1849,11 @@ fn client_check_server_certificate_ee_unknown_revocation() { ); // We expect if we use the allow_unknown_verifier that the handshake will not fail. - let client_config = - make_client_config_with_verifier(&[version], allow_unknown_verifier.clone()); + let client_config = make_client_config_with_verifier( + &[version], + allow_unknown_verifier.clone(), + &provider, + ); let mut client = ClientConnection::new(Arc::new(client_config), server_name("localhost")).unwrap(); let mut server = ServerConnection::new(Arc::clone(&server_config)).unwrap(); @@ -1809,28 +1865,33 @@ fn client_check_server_certificate_ee_unknown_revocation() { #[test] fn client_check_server_certificate_intermediate_revoked() { + let provider = provider::default_provider(); for kt in ALL_KEY_TYPES { - let server_config = Arc::new(make_server_config(*kt)); + let server_config = Arc::new(make_server_config(*kt, &provider)); // Setup a server verifier builder that will check the full chain revocation status against a CRL // that marks the intermediate certificate as revoked. We allow unknown revocation status // so the EE cert's unknown status doesn't cause an error. let crls = vec![kt.intermediate_crl()]; let full_chain_verifier_builder = - webpki_server_verifier_builder(get_client_root_store(*kt)) + webpki_server_verifier_builder(get_client_root_store(*kt), &provider) .with_crls(crls.clone()) .allow_unknown_revocation_status(); // Also set up a verifier builder that will use the same CRL, but only check the EE certificate // revocation status. - let ee_verifier_builder = webpki_server_verifier_builder(get_client_root_store(*kt)) - .with_crls(crls.clone()) - .only_check_end_entity_revocation() - .allow_unknown_revocation_status(); + let ee_verifier_builder = + webpki_server_verifier_builder(get_client_root_store(*kt), &provider) + .with_crls(crls.clone()) + .only_check_end_entity_revocation() + .allow_unknown_revocation_status(); for version in rustls::ALL_VERSIONS { - let client_config = - make_client_config_with_verifier(&[version], full_chain_verifier_builder.clone()); + let client_config = make_client_config_with_verifier( + &[version], + full_chain_verifier_builder.clone(), + &provider, + ); let mut client = ClientConnection::new(Arc::new(client_config), server_name("localhost")).unwrap(); let mut server = ServerConnection::new(Arc::clone(&server_config)).unwrap(); @@ -1845,8 +1906,11 @@ fn client_check_server_certificate_intermediate_revoked() { ))) ); - let client_config = - make_client_config_with_verifier(&[version], ee_verifier_builder.clone()); + let client_config = make_client_config_with_verifier( + &[version], + ee_verifier_builder.clone(), + &provider, + ); let mut client = ClientConnection::new(Arc::new(client_config), server_name("localhost")).unwrap(); let mut server = ServerConnection::new(Arc::clone(&server_config)).unwrap(); @@ -1860,25 +1924,31 @@ fn client_check_server_certificate_intermediate_revoked() { #[test] fn client_check_server_certificate_ee_crl_expired() { + let provider = provider::default_provider(); for kt in ALL_KEY_TYPES { - let server_config = Arc::new(make_server_config(*kt)); + let server_config = Arc::new(make_server_config(*kt, &provider)); // Setup a server verifier that will check the EE certificate's revocation status, with CRL expiration enforced. let crls = vec![kt.end_entity_crl_expired()]; - let enforce_expiration_builder = webpki_server_verifier_builder(get_client_root_store(*kt)) - .with_crls(crls) - .only_check_end_entity_revocation() - .enforce_revocation_expiration(); + let enforce_expiration_builder = + webpki_server_verifier_builder(get_client_root_store(*kt), &provider) + .with_crls(crls) + .only_check_end_entity_revocation() + .enforce_revocation_expiration(); // Also setup a server verifier without CRL expiration enforced. let crls = vec![kt.end_entity_crl_expired()]; - let ignore_expiration_builder = webpki_server_verifier_builder(get_client_root_store(*kt)) - .with_crls(crls) - .only_check_end_entity_revocation(); + let ignore_expiration_builder = + webpki_server_verifier_builder(get_client_root_store(*kt), &provider) + .with_crls(crls) + .only_check_end_entity_revocation(); for version in rustls::ALL_VERSIONS { - let client_config = - make_client_config_with_verifier(&[version], enforce_expiration_builder.clone()); + let client_config = make_client_config_with_verifier( + &[version], + enforce_expiration_builder.clone(), + &provider, + ); let mut client = ClientConnection::new(Arc::new(client_config), server_name("localhost")).unwrap(); let mut server = ServerConnection::new(Arc::clone(&server_config)).unwrap(); @@ -1892,8 +1962,11 @@ fn client_check_server_certificate_ee_crl_expired() { ))) )); - let client_config = - make_client_config_with_verifier(&[version], ignore_expiration_builder.clone()); + let client_config = make_client_config_with_verifier( + &[version], + ignore_expiration_builder.clone(), + &provider, + ); let mut client = ClientConnection::new(Arc::new(client_config), server_name("localhost")).unwrap(); let mut server = ServerConnection::new(Arc::clone(&server_config)).unwrap(); @@ -2039,10 +2112,11 @@ fn test_client_cert_resolve( server_config: Arc, expected_root_hint_subjects: Vec>, ) { + let provider = provider::default_provider(); for version in rustls::ALL_VERSIONS { println!("{:?} {:?}:", version.version, key_type); - let mut client_config = make_client_config_with_versions(key_type, &[version]); + let mut client_config = make_client_config_with_versions(key_type, &[version], &provider); client_config.client_auth_cert_resolver = Arc::new(ClientCheckCertResolve::new( 1, expected_root_hint_subjects.clone(), @@ -2090,8 +2164,11 @@ fn default_signature_schemes(version: ProtocolVersion) -> Vec { fn client_cert_resolve_default() { // Test that in the default configuration that a client cert resolver gets the expected // CA subject hints, and supported signature algorithms. + let provider = provider::default_provider(); for key_type in ALL_KEY_TYPES { - let server_config = Arc::new(make_server_config_with_mandatory_client_auth(*key_type)); + let server_config = Arc::new(make_server_config_with_mandatory_client_auth( + *key_type, &provider, + )); // In a default configuration we expect that the verifier's trust anchors are used // for the hint subjects. @@ -2109,11 +2186,12 @@ fn client_cert_resolve_default() { fn client_cert_resolve_server_no_hints() { // Test that a server can provide no hints and the client cert resolver gets the expected // arguments. + let provider = provider::default_provider(); for key_type in ALL_KEY_TYPES { // Build a verifier with no hint subjects. - let verifier = webpki_client_verifier_builder(get_client_root_store(*key_type)) + let verifier = webpki_client_verifier_builder(get_client_root_store(*key_type), &provider) .clear_root_hint_subjects(); - let server_config = make_server_config_with_client_verifier(*key_type, verifier); + let server_config = make_server_config_with_client_verifier(*key_type, verifier, &provider); let expected_root_hint_subjects = Vec::default(); // no hints expected. test_client_cert_resolve(*key_type, server_config.into(), expected_root_hint_subjects); } @@ -2123,6 +2201,7 @@ fn client_cert_resolve_server_no_hints() { fn client_cert_resolve_server_added_hint() { // Test that a server can add an extra subject above/beyond those found in its trust store // and the client cert resolver gets the expected arguments. + let provider = provider::default_provider(); let extra_name = b"0\x1a1\x180\x16\x06\x03U\x04\x03\x0c\x0fponyland IDK CA".to_vec(); for key_type in ALL_KEY_TYPES { let expected_hint_subjects = vec![ @@ -2133,20 +2212,24 @@ fn client_cert_resolve_server_added_hint() { ]; // Create a verifier that adds the extra_name as a hint subject in addition to the ones // from the root cert store. - let verifier = webpki_client_verifier_builder(get_client_root_store(*key_type)) + let verifier = webpki_client_verifier_builder(get_client_root_store(*key_type), &provider) .add_root_hint_subjects([DistinguishedName::from(extra_name.clone())].into_iter()); - let server_config = make_server_config_with_client_verifier(*key_type, verifier); + let server_config = make_server_config_with_client_verifier(*key_type, verifier, &provider); test_client_cert_resolve(*key_type, server_config.into(), expected_hint_subjects); } } #[test] fn client_auth_works() { + let provider = provider::default_provider(); for kt in ALL_KEY_TYPES { - let server_config = Arc::new(make_server_config_with_mandatory_client_auth(*kt)); + let server_config = Arc::new(make_server_config_with_mandatory_client_auth( + *kt, &provider, + )); for version in rustls::ALL_VERSIONS { - let client_config = make_client_config_with_versions_with_auth(*kt, &[version]); + let client_config = + make_client_config_with_versions_with_auth(*kt, &[version], &provider); let (mut client, mut server) = make_pair_for_arc_configs(&Arc::new(client_config), &server_config); do_handshake(&mut client, &mut server); @@ -2159,43 +2242,52 @@ fn client_mandatory_auth_client_revocation_works() { for kt in ALL_KEY_TYPES { // Create a server configuration that includes a CRL that specifies the client certificate // is revoked. + let provider = provider::default_provider(); let relevant_crls = vec![kt.client_crl()]; // Only check the EE certificate status. See client_mandatory_auth_intermediate_revocation_works // for testing revocation status of the whole chain. - let ee_verifier_builder = webpki_client_verifier_builder(get_client_root_store(*kt)) - .with_crls(relevant_crls) - .only_check_end_entity_revocation(); + let ee_verifier_builder = + webpki_client_verifier_builder(get_client_root_store(*kt), &provider) + .with_crls(relevant_crls) + .only_check_end_entity_revocation(); let revoked_server_config = Arc::new(make_server_config_with_client_verifier( *kt, ee_verifier_builder, + &provider, )); // Create a server configuration that includes a CRL that doesn't cover the client certificate, // and uses the default behaviour of treating unknown revocation status as an error. let unrelated_crls = vec![kt.intermediate_crl()]; - let ee_verifier_builder = webpki_client_verifier_builder(get_client_root_store(*kt)) - .with_crls(unrelated_crls.clone()) - .only_check_end_entity_revocation(); + let ee_verifier_builder = + webpki_client_verifier_builder(get_client_root_store(*kt), &provider) + .with_crls(unrelated_crls.clone()) + .only_check_end_entity_revocation(); let missing_client_crl_server_config = Arc::new(make_server_config_with_client_verifier( *kt, ee_verifier_builder, + &provider, )); // Create a server configuration that includes a CRL that doesn't cover the client certificate, // but change the builder to allow unknown revocation status. - let ee_verifier_builder = webpki_client_verifier_builder(get_client_root_store(*kt)) - .with_crls(unrelated_crls.clone()) - .only_check_end_entity_revocation() - .allow_unknown_revocation_status(); + let ee_verifier_builder = + webpki_client_verifier_builder(get_client_root_store(*kt), &provider) + .with_crls(unrelated_crls.clone()) + .only_check_end_entity_revocation() + .allow_unknown_revocation_status(); let allow_missing_client_crl_server_config = Arc::new( - make_server_config_with_client_verifier(*kt, ee_verifier_builder), + make_server_config_with_client_verifier(*kt, ee_verifier_builder, &provider), ); for version in rustls::ALL_VERSIONS { // Connecting to the server with a CRL that indicates the client certificate is revoked // should fail with the expected error. - let client_config = - Arc::new(make_client_config_with_versions_with_auth(*kt, &[version])); + let client_config = Arc::new(make_client_config_with_versions_with_auth( + *kt, + &[version], + &provider, + )); let (mut client, mut server) = make_pair_for_arc_configs(&client_config, &revoked_server_config); let err = do_handshake_until_error(&mut client, &mut server); @@ -2228,35 +2320,42 @@ fn client_mandatory_auth_client_revocation_works() { #[test] fn client_mandatory_auth_intermediate_revocation_works() { + let provider = provider::default_provider(); for kt in ALL_KEY_TYPES { // Create a server configuration that includes a CRL that specifies the intermediate certificate // is revoked. We check the full chain for revocation status (default), and allow unknown // revocation status so the EE's unknown revocation status isn't an error. let crls = vec![kt.intermediate_crl()]; let full_chain_verifier_builder = - webpki_client_verifier_builder(get_client_root_store(*kt)) + webpki_client_verifier_builder(get_client_root_store(*kt), &provider) .with_crls(crls.clone()) .allow_unknown_revocation_status(); let full_chain_server_config = Arc::new(make_server_config_with_client_verifier( *kt, full_chain_verifier_builder, + &provider, )); // Also create a server configuration that uses the same CRL, but that only checks the EE // cert revocation status. - let ee_only_verifier_builder = webpki_client_verifier_builder(get_client_root_store(*kt)) - .with_crls(crls) - .only_check_end_entity_revocation() - .allow_unknown_revocation_status(); + let ee_only_verifier_builder = + webpki_client_verifier_builder(get_client_root_store(*kt), &provider) + .with_crls(crls) + .only_check_end_entity_revocation() + .allow_unknown_revocation_status(); let ee_server_config = Arc::new(make_server_config_with_client_verifier( *kt, ee_only_verifier_builder, + &provider, )); for version in rustls::ALL_VERSIONS { // When checking the full chain, we expect an error - the intermediate is revoked. - let client_config = - Arc::new(make_client_config_with_versions_with_auth(*kt, &[version])); + let client_config = Arc::new(make_client_config_with_versions_with_auth( + *kt, + &[version], + &provider, + )); let (mut client, mut server) = make_pair_for_arc_configs(&client_config, &full_chain_server_config); let err = do_handshake_until_error(&mut client, &mut server); @@ -2277,14 +2376,18 @@ fn client_mandatory_auth_intermediate_revocation_works() { #[test] fn client_optional_auth_client_revocation_works() { + let provider = provider::default_provider(); for kt in ALL_KEY_TYPES { // Create a server configuration that includes a CRL that specifies the client certificate // is revoked. let crls = vec![kt.client_crl()]; - let server_config = Arc::new(make_server_config_with_optional_client_auth(*kt, crls)); + let server_config = Arc::new(make_server_config_with_optional_client_auth( + *kt, crls, &provider, + )); for version in rustls::ALL_VERSIONS { - let client_config = make_client_config_with_versions_with_auth(*kt, &[version]); + let client_config = + make_client_config_with_versions_with_auth(*kt, &[version], &provider); let (mut client, mut server) = make_pair_for_arc_configs(&Arc::new(client_config), &server_config); // Because the client certificate is revoked, the handshake should fail. @@ -2301,7 +2404,7 @@ fn client_optional_auth_client_revocation_works() { #[test] fn client_error_is_sticky() { - let (mut client, _) = make_pair(KeyType::Rsa2048); + let (mut client, _) = make_pair(KeyType::Rsa2048, &provider::default_provider()); client .read_tls(&mut b"\x16\x03\x03\x00\x08\x0f\x00\x00\x04junk".as_ref()) .unwrap(); @@ -2313,7 +2416,7 @@ fn client_error_is_sticky() { #[test] fn server_error_is_sticky() { - let (_, mut server) = make_pair(KeyType::Rsa2048); + let (_, mut server) = make_pair(KeyType::Rsa2048, &provider::default_provider()); server .read_tls(&mut b"\x16\x03\x03\x00\x08\x0f\x00\x00\x04junk".as_ref()) .unwrap(); @@ -2325,20 +2428,20 @@ fn server_error_is_sticky() { #[test] fn server_flush_does_nothing() { - let (_, mut server) = make_pair(KeyType::Rsa2048); + let (_, mut server) = make_pair(KeyType::Rsa2048, &provider::default_provider()); assert!(matches!(server.writer().flush(), Ok(()))); } #[test] fn client_flush_does_nothing() { - let (mut client, _) = make_pair(KeyType::Rsa2048); + let (mut client, _) = make_pair(KeyType::Rsa2048, &provider::default_provider()); assert!(matches!(client.writer().flush(), Ok(()))); } #[allow(clippy::no_effect)] #[test] fn server_is_send_and_sync() { - let (_, server) = make_pair(KeyType::Rsa2048); + let (_, server) = make_pair(KeyType::Rsa2048, &provider::default_provider()); &server as &dyn Send; &server as &dyn Sync; } @@ -2346,14 +2449,14 @@ fn server_is_send_and_sync() { #[allow(clippy::no_effect)] #[test] fn client_is_send_and_sync() { - let (client, _) = make_pair(KeyType::Rsa2048); + let (client, _) = make_pair(KeyType::Rsa2048, &provider::default_provider()); &client as &dyn Send; &client as &dyn Sync; } #[test] fn server_respects_buffer_limit_pre_handshake() { - let (mut client, mut server) = make_pair(KeyType::Rsa2048); + let (mut client, mut server) = make_pair(KeyType::Rsa2048, &provider::default_provider()); server.set_buffer_limit(Some(32)); @@ -2381,7 +2484,7 @@ fn server_respects_buffer_limit_pre_handshake() { #[test] fn server_respects_buffer_limit_pre_handshake_with_vectored_write() { - let (mut client, mut server) = make_pair(KeyType::Rsa2048); + let (mut client, mut server) = make_pair(KeyType::Rsa2048, &provider::default_provider()); server.set_buffer_limit(Some(32)); @@ -2405,7 +2508,7 @@ fn server_respects_buffer_limit_pre_handshake_with_vectored_write() { #[test] fn server_respects_buffer_limit_post_handshake() { - let (mut client, mut server) = make_pair(KeyType::Rsa2048); + let (mut client, mut server) = make_pair(KeyType::Rsa2048, &provider::default_provider()); // this test will vary in behaviour depending on the default suites do_handshake(&mut client, &mut server); @@ -2434,7 +2537,7 @@ fn server_respects_buffer_limit_post_handshake() { #[test] fn client_respects_buffer_limit_pre_handshake() { - let (mut client, mut server) = make_pair(KeyType::Rsa2048); + let (mut client, mut server) = make_pair(KeyType::Rsa2048, &provider::default_provider()); client.set_buffer_limit(Some(32)); @@ -2462,7 +2565,7 @@ fn client_respects_buffer_limit_pre_handshake() { #[test] fn client_respects_buffer_limit_pre_handshake_with_vectored_write() { - let (mut client, mut server) = make_pair(KeyType::Rsa2048); + let (mut client, mut server) = make_pair(KeyType::Rsa2048, &provider::default_provider()); client.set_buffer_limit(Some(32)); @@ -2486,7 +2589,7 @@ fn client_respects_buffer_limit_pre_handshake_with_vectored_write() { #[test] fn client_respects_buffer_limit_post_handshake() { - let (mut client, mut server) = make_pair(KeyType::Rsa2048); + let (mut client, mut server) = make_pair(KeyType::Rsa2048, &provider::default_provider()); do_handshake(&mut client, &mut server); client.set_buffer_limit(Some(48)); @@ -2515,7 +2618,7 @@ fn client_respects_buffer_limit_post_handshake() { #[test] fn client_detects_broken_write_vectored_impl() { // see https://github.com/rustls/rustls/issues/2316 - let (mut client, _) = make_pair(KeyType::Rsa2048); + let (mut client, _) = make_pair(KeyType::Rsa2048, &provider::default_provider()); let err = client .write_tls(&mut BrokenWriteVectored) .unwrap_err(); @@ -2543,7 +2646,7 @@ fn client_detects_broken_write_vectored_impl() { #[test] fn buf_read() { - let (mut client, mut server) = make_pair(KeyType::Rsa2048); + let (mut client, mut server) = make_pair(KeyType::Rsa2048, &provider::default_provider()); do_handshake(&mut client, &mut server); @@ -2697,35 +2800,35 @@ where #[test] fn server_read_returns_wouldblock_when_no_data() { - let (_, mut server) = make_pair(KeyType::Rsa2048); + let (_, mut server) = make_pair(KeyType::Rsa2048, &provider::default_provider()); assert!(matches!(server.reader().read(&mut [0u8; 1]), Err(err) if err.kind() == io::ErrorKind::WouldBlock)); } #[test] fn client_read_returns_wouldblock_when_no_data() { - let (mut client, _) = make_pair(KeyType::Rsa2048); + let (mut client, _) = make_pair(KeyType::Rsa2048, &provider::default_provider()); assert!(matches!(client.reader().read(&mut [0u8; 1]), Err(err) if err.kind() == io::ErrorKind::WouldBlock)); } #[test] fn server_fill_buf_returns_wouldblock_when_no_data() { - let (_, mut server) = make_pair(KeyType::Rsa2048); + let (_, mut server) = make_pair(KeyType::Rsa2048, &provider::default_provider()); assert!(matches!(server.reader().fill_buf(), Err(err) if err.kind() == io::ErrorKind::WouldBlock)); } #[test] fn client_fill_buf_returns_wouldblock_when_no_data() { - let (mut client, _) = make_pair(KeyType::Rsa2048); + let (mut client, _) = make_pair(KeyType::Rsa2048, &provider::default_provider()); assert!(matches!(client.reader().fill_buf(), Err(err) if err.kind() == io::ErrorKind::WouldBlock)); } #[test] fn new_server_returns_initial_io_state() { - let (_, mut server) = make_pair(KeyType::Rsa2048); + let (_, mut server) = make_pair(KeyType::Rsa2048, &provider::default_provider()); let io_state = server.process_new_packets().unwrap(); println!("IoState is Debug {io_state:?}"); assert_eq!(io_state.plaintext_bytes_to_read(), 0); @@ -2735,7 +2838,7 @@ fn new_server_returns_initial_io_state() { #[test] fn new_client_returns_initial_io_state() { - let (mut client, _) = make_pair(KeyType::Rsa2048); + let (mut client, _) = make_pair(KeyType::Rsa2048, &provider::default_provider()); let io_state = client.process_new_packets().unwrap(); println!("IoState is Debug {io_state:?}"); assert_eq!(io_state.plaintext_bytes_to_read(), 0); @@ -2745,7 +2848,7 @@ fn new_client_returns_initial_io_state() { #[test] fn client_complete_io_for_handshake() { - let (mut client, mut server) = make_pair(KeyType::Rsa2048); + let (mut client, mut server) = make_pair(KeyType::Rsa2048, &provider::default_provider()); assert!(client.is_handshaking()); let (rdlen, wrlen) = client @@ -2758,7 +2861,7 @@ fn client_complete_io_for_handshake() { #[test] fn buffered_client_complete_io_for_handshake() { - let (mut client, mut server) = make_pair(KeyType::Rsa2048); + let (mut client, mut server) = make_pair(KeyType::Rsa2048, &provider::default_provider()); assert!(client.is_handshaking()); let (rdlen, wrlen) = client @@ -2771,7 +2874,7 @@ fn buffered_client_complete_io_for_handshake() { #[test] fn client_complete_io_for_handshake_eof() { - let (mut client, _) = make_pair(KeyType::Rsa2048); + let (mut client, _) = make_pair(KeyType::Rsa2048, &provider::default_provider()); let mut input = io::Cursor::new(Vec::new()); assert!(client.is_handshaking()); @@ -2783,8 +2886,9 @@ fn client_complete_io_for_handshake_eof() { #[test] fn client_complete_io_for_write() { + let provider = provider::default_provider(); for kt in ALL_KEY_TYPES { - let (mut client, mut server) = make_pair(*kt); + let (mut client, mut server) = make_pair(*kt, &provider); do_handshake(&mut client, &mut server); @@ -2812,8 +2916,9 @@ fn client_complete_io_for_write() { #[test] fn buffered_client_complete_io_for_write() { + let provider = provider::default_provider(); for kt in ALL_KEY_TYPES { - let (mut client, mut server) = make_pair(*kt); + let (mut client, mut server) = make_pair(*kt, &provider); do_handshake(&mut client, &mut server); @@ -2841,8 +2946,9 @@ fn buffered_client_complete_io_for_write() { #[test] fn client_complete_io_for_read() { + let provider = provider::default_provider(); for kt in ALL_KEY_TYPES { - let (mut client, mut server) = make_pair(*kt); + let (mut client, mut server) = make_pair(*kt, &provider); do_handshake(&mut client, &mut server); @@ -2862,8 +2968,9 @@ fn client_complete_io_for_read() { #[test] fn server_complete_io_for_handshake() { + let provider = provider::default_provider(); for kt in ALL_KEY_TYPES { - let (mut client, mut server) = make_pair(*kt); + let (mut client, mut server) = make_pair(*kt, &provider); assert!(server.is_handshaking()); let (rdlen, wrlen) = server @@ -2877,7 +2984,7 @@ fn server_complete_io_for_handshake() { #[test] fn server_complete_io_for_handshake_eof() { - let (_, mut server) = make_pair(KeyType::Rsa2048); + let (_, mut server) = make_pair(KeyType::Rsa2048, &provider::default_provider()); let mut input = io::Cursor::new(Vec::new()); assert!(server.is_handshaking()); @@ -2889,8 +2996,9 @@ fn server_complete_io_for_handshake_eof() { #[test] fn server_complete_io_for_write() { + let provider = provider::default_provider(); for kt in ALL_KEY_TYPES { - let (mut client, mut server) = make_pair(*kt); + let (mut client, mut server) = make_pair(*kt, &provider); do_handshake(&mut client, &mut server); @@ -2917,8 +3025,9 @@ fn server_complete_io_for_write() { #[test] fn server_complete_io_for_write_eof() { + let provider = provider::default_provider(); for kt in ALL_KEY_TYPES { - let (mut client, mut server) = make_pair(*kt); + let (mut client, mut server) = make_pair(*kt, &provider); do_handshake(&mut client, &mut server); @@ -2973,8 +3082,9 @@ impl std::io::Read for EofWriter { #[test] fn server_complete_io_for_read() { + let provider = provider::default_provider(); for kt in ALL_KEY_TYPES { - let (mut client, mut server) = make_pair(*kt); + let (mut client, mut server) = make_pair(*kt, &provider); do_handshake(&mut client, &mut server); @@ -3011,8 +3121,9 @@ enum StreamKind { } fn test_client_stream_write(stream_kind: StreamKind) { + let provider = provider::default_provider(); for kt in ALL_KEY_TYPES { - let (mut client, mut server) = make_pair(*kt); + let (mut client, mut server) = make_pair(*kt, &provider); let data = b"hello"; { let mut pipe = OtherSession::new(&mut server); @@ -3027,8 +3138,9 @@ fn test_client_stream_write(stream_kind: StreamKind) { } fn test_server_stream_write(stream_kind: StreamKind) { + let provider = provider::default_provider(); for kt in ALL_KEY_TYPES { - let (mut client, mut server) = make_pair(*kt); + let (mut client, mut server) = make_pair(*kt, &provider); let data = b"hello"; { let mut pipe = OtherSession::new(&mut client); @@ -3095,8 +3207,9 @@ fn test_stream_read(read_kind: ReadKind, mut stream: impl BufRead, data: &[u8]) } fn test_client_stream_read(stream_kind: StreamKind, read_kind: ReadKind) { + let provider = provider::default_provider(); for kt in ALL_KEY_TYPES { - let (mut client, mut server) = make_pair(*kt); + let (mut client, mut server) = make_pair(*kt, &provider); let data = b"world"; server.writer().write_all(data).unwrap(); @@ -3115,8 +3228,9 @@ fn test_client_stream_read(stream_kind: StreamKind, read_kind: ReadKind) { } fn test_server_stream_read(stream_kind: StreamKind, read_kind: ReadKind) { + let provider = provider::default_provider(); for kt in ALL_KEY_TYPES { - let (mut client, mut server) = make_pair(*kt); + let (mut client, mut server) = make_pair(*kt, &provider); let data = b"world"; client.writer().write_all(data).unwrap(); @@ -3136,7 +3250,7 @@ fn test_server_stream_read(stream_kind: StreamKind, read_kind: ReadKind) { #[test] fn test_client_write_and_vectored_write_equivalence() { - let (mut client, mut server) = make_pair(KeyType::Rsa2048); + let (mut client, mut server) = make_pair(KeyType::Rsa2048, &provider::default_provider()); do_handshake(&mut client, &mut server); const N: usize = 1000; @@ -3191,7 +3305,7 @@ impl io::Write for FailsWrites { #[test] fn stream_write_reports_underlying_io_error_before_plaintext_processed() { - let (mut client, mut server) = make_pair(KeyType::Rsa2048); + let (mut client, mut server) = make_pair(KeyType::Rsa2048, &provider::default_provider()); do_handshake(&mut client, &mut server); let mut pipe = FailsWrites { @@ -3211,7 +3325,7 @@ fn stream_write_reports_underlying_io_error_before_plaintext_processed() { #[test] fn stream_write_swallows_underlying_io_error_after_plaintext_processed() { - let (mut client, mut server) = make_pair(KeyType::Rsa2048); + let (mut client, mut server) = make_pair(KeyType::Rsa2048, &provider::default_provider()); do_handshake(&mut client, &mut server); let mut pipe = FailsWrites { @@ -3346,23 +3460,23 @@ fn server_streamowned_handshake_error() { #[test] fn server_config_is_clone() { - let _ = make_server_config(KeyType::Rsa2048); + let _ = make_server_config(KeyType::Rsa2048, &provider::default_provider()); } #[test] fn client_config_is_clone() { - let _ = make_client_config(KeyType::Rsa2048); + let _ = make_client_config(KeyType::Rsa2048, &provider::default_provider()); } #[test] fn client_connection_is_debug() { - let (client, _) = make_pair(KeyType::Rsa2048); + let (client, _) = make_pair(KeyType::Rsa2048, &provider::default_provider()); println!("{client:?}"); } #[test] fn server_connection_is_debug() { - let (_, server) = make_pair(KeyType::Rsa2048); + let (_, server) = make_pair(KeyType::Rsa2048, &provider::default_provider()); println!("{server:?}"); } @@ -3387,14 +3501,16 @@ fn server_complete_io_for_handshake_ending_with_alert() { #[test] fn server_exposes_offered_sni() { let kt = KeyType::Rsa2048; + let provider = provider::default_provider(); for version in rustls::ALL_VERSIONS { - let client_config = make_client_config_with_versions(kt, &[version]); + let client_config = make_client_config_with_versions(kt, &[version], &provider); let mut client = ClientConnection::new( Arc::new(client_config), server_name("second.testserver.com"), ) .unwrap(); - let mut server = ServerConnection::new(Arc::new(make_server_config(kt))).unwrap(); + let mut server = + ServerConnection::new(Arc::new(make_server_config(kt, &provider))).unwrap(); assert_eq!(None, server.server_name()); do_handshake(&mut client, &mut server); @@ -3406,14 +3522,16 @@ fn server_exposes_offered_sni() { fn server_exposes_offered_sni_smashed_to_lowercase() { // webpki actually does this for us in its DnsName type let kt = KeyType::Rsa2048; + let provider = provider::default_provider(); for version in rustls::ALL_VERSIONS { - let client_config = make_client_config_with_versions(kt, &[version]); + let client_config = make_client_config_with_versions(kt, &[version], &provider); let mut client = ClientConnection::new( Arc::new(client_config), server_name("SECOND.TESTServer.com"), ) .unwrap(); - let mut server = ServerConnection::new(Arc::new(make_server_config(kt))).unwrap(); + let mut server = + ServerConnection::new(Arc::new(make_server_config(kt, &provider))).unwrap(); assert_eq!(None, server.server_name()); do_handshake(&mut client, &mut server); @@ -3424,14 +3542,15 @@ fn server_exposes_offered_sni_smashed_to_lowercase() { #[test] fn server_exposes_offered_sni_even_if_resolver_fails() { let kt = KeyType::Rsa2048; + let provider = provider::default_provider(); let resolver = rustls::server::ResolvesServerCertUsingSni::new(); - let mut server_config = make_server_config(kt); + let mut server_config = make_server_config(kt, &provider); server_config.cert_resolver = Arc::new(resolver); let server_config = Arc::new(server_config); for version in rustls::ALL_VERSIONS { - let client_config = make_client_config_with_versions(kt, &[version]); + let client_config = make_client_config_with_versions(kt, &[version], &provider); let mut server = ServerConnection::new(Arc::clone(&server_config)).unwrap(); let mut client = ClientConnection::new(Arc::new(client_config), server_name("thisdoesNOTexist.com")) @@ -3452,6 +3571,7 @@ fn server_exposes_offered_sni_even_if_resolver_fails() { #[test] fn sni_resolver_works() { let kt = KeyType::Rsa2048; + let provider = provider::default_provider(); let mut resolver = rustls::server::ResolvesServerCertUsingSni::new(); let signing_key = RsaSigningKey::new(&kt.get_key()).unwrap(); let signing_key: Arc = Arc::new(signing_key); @@ -3462,19 +3582,22 @@ fn sni_resolver_works() { ) .unwrap(); - let mut server_config = make_server_config(kt); + let mut server_config = make_server_config(kt, &provider); server_config.cert_resolver = Arc::new(resolver); let server_config = Arc::new(server_config); let mut server1 = ServerConnection::new(Arc::clone(&server_config)).unwrap(); - let mut client1 = - ClientConnection::new(Arc::new(make_client_config(kt)), server_name("localhost")).unwrap(); + let mut client1 = ClientConnection::new( + Arc::new(make_client_config(kt, &provider)), + server_name("localhost"), + ) + .unwrap(); let err = do_handshake_until_error(&mut client1, &mut server1); assert_eq!(err, Ok(())); let mut server2 = ServerConnection::new(Arc::clone(&server_config)).unwrap(); let mut client2 = ClientConnection::new( - Arc::new(make_client_config(kt)), + Arc::new(make_client_config(kt, &provider)), server_name("notlocalhost"), ) .unwrap(); @@ -3538,6 +3661,7 @@ fn certificate_error_expecting_name(expected: &str) -> CertificateError { #[test] fn sni_resolver_lower_cases_configured_names() { let kt = KeyType::Rsa2048; + let provider = provider::default_provider(); let mut resolver = rustls::server::ResolvesServerCertUsingSni::new(); let signing_key = RsaSigningKey::new(&kt.get_key()).unwrap(); let signing_key: Arc = Arc::new(signing_key); @@ -3550,13 +3674,16 @@ fn sni_resolver_lower_cases_configured_names() { ) ); - let mut server_config = make_server_config(kt); + let mut server_config = make_server_config(kt, &provider); server_config.cert_resolver = Arc::new(resolver); let server_config = Arc::new(server_config); let mut server1 = ServerConnection::new(Arc::clone(&server_config)).unwrap(); - let mut client1 = - ClientConnection::new(Arc::new(make_client_config(kt)), server_name("localhost")).unwrap(); + let mut client1 = ClientConnection::new( + Arc::new(make_client_config(kt, &provider)), + server_name("localhost"), + ) + .unwrap(); let err = do_handshake_until_error(&mut client1, &mut server1); assert_eq!(err, Ok(())); } @@ -3565,6 +3692,7 @@ fn sni_resolver_lower_cases_configured_names() { fn sni_resolver_lower_cases_queried_names() { // actually, the handshake parser does this, but the effect is the same. let kt = KeyType::Rsa2048; + let provider = provider::default_provider(); let mut resolver = rustls::server::ResolvesServerCertUsingSni::new(); let signing_key = RsaSigningKey::new(&kt.get_key()).unwrap(); let signing_key: Arc = Arc::new(signing_key); @@ -3577,13 +3705,16 @@ fn sni_resolver_lower_cases_queried_names() { ) ); - let mut server_config = make_server_config(kt); + let mut server_config = make_server_config(kt, &provider); server_config.cert_resolver = Arc::new(resolver); let server_config = Arc::new(server_config); let mut server1 = ServerConnection::new(Arc::clone(&server_config)).unwrap(); - let mut client1 = - ClientConnection::new(Arc::new(make_client_config(kt)), server_name("LOCALHOST")).unwrap(); + let mut client1 = ClientConnection::new( + Arc::new(make_client_config(kt, &provider)), + server_name("LOCALHOST"), + ) + .unwrap(); let err = do_handshake_until_error(&mut client1, &mut server1); assert_eq!(err, Ok(())); } @@ -3737,9 +3868,11 @@ fn do_exporter_test(client_config: ClientConfig, server_config: ServerConfig) { #[cfg(feature = "tls12")] #[test] fn test_tls12_exporter() { + let provider = provider::default_provider(); for kt in ALL_KEY_TYPES { - let client_config = make_client_config_with_versions(*kt, &[&rustls::version::TLS12]); - let server_config = make_server_config(*kt); + let client_config = + make_client_config_with_versions(*kt, &[&rustls::version::TLS12], &provider); + let server_config = make_server_config(*kt, &provider); do_exporter_test(client_config, server_config); } @@ -3747,9 +3880,11 @@ fn test_tls12_exporter() { #[test] fn test_tls13_exporter() { + let provider = provider::default_provider(); for kt in ALL_KEY_TYPES { - let client_config = make_client_config_with_versions(*kt, &[&rustls::version::TLS13]); - let server_config = make_server_config(*kt); + let client_config = + make_client_config_with_versions(*kt, &[&rustls::version::TLS13], &provider); + let server_config = make_server_config(*kt, &provider); do_exporter_test(client_config, server_config); } @@ -3757,9 +3892,10 @@ fn test_tls13_exporter() { #[test] fn test_tls13_exporter_maximum_output_length() { + let provider = provider::default_provider(); let client_config = - make_client_config_with_versions(KeyType::EcdsaP256, &[&rustls::version::TLS13]); - let server_config = make_server_config(KeyType::EcdsaP256); + make_client_config_with_versions(KeyType::EcdsaP256, &[&rustls::version::TLS13], &provider); + let server_config = make_server_config(KeyType::EcdsaP256, &provider); let (mut client, mut server) = make_pair_for_configs(client_config, server_config); do_handshake(&mut client, &mut server); @@ -3885,10 +4021,11 @@ fn test_ciphersuites() -> Vec<( #[test] fn negotiated_ciphersuite_default() { + let provider = provider::default_provider(); for kt in ALL_KEY_TYPES { do_suite_and_kx_test( - make_client_config(*kt), - make_server_config(*kt), + make_client_config(*kt, &provider), + make_server_config(*kt, &provider), find_suite(CipherSuite::TLS13_AES_256_GCM_SHA384), expected_kx_for_version(&rustls::version::TLS13), ProtocolVersion::TLSv1_3, @@ -3923,7 +4060,7 @@ fn negotiated_ciphersuite_client() { do_suite_and_kx_test( client_config, - make_server_config(kt), + make_server_config(kt, &provider::default_provider()), scs, expected_kx_for_version(version), version.version, @@ -3949,7 +4086,7 @@ fn negotiated_ciphersuite_server() { ); do_suite_and_kx_test( - make_client_config(kt), + make_client_config(kt, &provider::default_provider()), server_config, scs, expected_kx_for_version(version), @@ -4063,12 +4200,14 @@ fn key_log_for_tls12() { let client_key_log = Arc::new(KeyLogToVec::new("client")); let server_key_log = Arc::new(KeyLogToVec::new("server")); + let provider = provider::default_provider(); let kt = KeyType::Rsa2048; - let mut client_config = make_client_config_with_versions(kt, &[&rustls::version::TLS12]); + let mut client_config = + make_client_config_with_versions(kt, &[&rustls::version::TLS12], &provider); client_config.key_log = client_key_log.clone(); let client_config = Arc::new(client_config); - let mut server_config = make_server_config(kt); + let mut server_config = make_server_config(kt, &provider); server_config.key_log = server_key_log.clone(); let server_config = Arc::new(server_config); @@ -4099,12 +4238,14 @@ fn key_log_for_tls13() { let client_key_log = Arc::new(KeyLogToVec::new("client")); let server_key_log = Arc::new(KeyLogToVec::new("server")); + let provider = provider::default_provider(); let kt = KeyType::Rsa2048; - let mut client_config = make_client_config_with_versions(kt, &[&rustls::version::TLS13]); + let mut client_config = + make_client_config_with_versions(kt, &[&rustls::version::TLS13], &provider); client_config.key_log = client_key_log.clone(); let client_config = Arc::new(client_config); - let mut server_config = make_server_config(kt); + let mut server_config = make_server_config(kt, &provider); server_config.key_log = server_key_log.clone(); let server_config = Arc::new(server_config); @@ -4171,7 +4312,7 @@ fn key_log_for_tls13() { #[test] fn vectored_write_for_server_appdata() { - let (mut client, mut server) = make_pair(KeyType::Rsa2048); + let (mut client, mut server) = make_pair(KeyType::Rsa2048, &provider::default_provider()); do_handshake(&mut client, &mut server); server @@ -4196,7 +4337,7 @@ fn vectored_write_for_server_appdata() { #[test] fn vectored_write_for_client_appdata() { - let (mut client, mut server) = make_pair(KeyType::Rsa2048); + let (mut client, mut server) = make_pair(KeyType::Rsa2048, &provider::default_provider()); do_handshake(&mut client, &mut server); client @@ -4221,10 +4362,11 @@ fn vectored_write_for_client_appdata() { #[test] fn vectored_write_for_server_handshake_with_half_rtt_data() { - let mut server_config = make_server_config(KeyType::Rsa2048); + let provider = provider::default_provider(); + let mut server_config = make_server_config(KeyType::Rsa2048, &provider); server_config.send_half_rtt_data = true; let (mut client, mut server) = make_pair_for_configs( - make_client_config_with_auth(KeyType::Rsa2048), + make_client_config_with_auth(KeyType::Rsa2048, &provider), server_config, ); @@ -4266,7 +4408,7 @@ fn vectored_write_for_server_handshake_with_half_rtt_data() { fn check_half_rtt_does_not_work(server_config: ServerConfig) { let (mut client, mut server) = make_pair_for_configs( - make_client_config_with_auth(KeyType::Rsa2048), + make_client_config_with_auth(KeyType::Rsa2048, &provider::default_provider()), server_config, ); @@ -4312,21 +4454,24 @@ fn check_half_rtt_does_not_work(server_config: ServerConfig) { #[test] fn vectored_write_for_server_handshake_no_half_rtt_with_client_auth() { - let mut server_config = make_server_config_with_mandatory_client_auth(KeyType::Rsa2048); + let mut server_config = make_server_config_with_mandatory_client_auth( + KeyType::Rsa2048, + &provider::default_provider(), + ); server_config.send_half_rtt_data = true; // ask even though it will be ignored check_half_rtt_does_not_work(server_config); } #[test] fn vectored_write_for_server_handshake_no_half_rtt_by_default() { - let server_config = make_server_config(KeyType::Rsa2048); + let server_config = make_server_config(KeyType::Rsa2048, &provider::default_provider()); assert!(!server_config.send_half_rtt_data); check_half_rtt_does_not_work(server_config); } #[test] fn vectored_write_for_client_handshake() { - let (mut client, mut server) = make_pair(KeyType::Rsa2048); + let (mut client, mut server) = make_pair(KeyType::Rsa2048, &provider::default_provider()); client .writer() @@ -4363,7 +4508,7 @@ fn vectored_write_for_client_handshake() { #[test] fn vectored_write_with_slow_client() { - let (mut client, mut server) = make_pair(KeyType::Rsa2048); + let (mut client, mut server) = make_pair(KeyType::Rsa2048, &provider::default_provider()); client.set_buffer_limit(Some(32)); @@ -4596,10 +4741,11 @@ impl rustls::client::ClientSessionStore for ClientStorage { #[test] fn tls13_stateful_resumption() { let kt = KeyType::Rsa2048; - let client_config = make_client_config_with_versions(kt, &[&rustls::version::TLS13]); + let provider = provider::default_provider(); + let client_config = make_client_config_with_versions(kt, &[&rustls::version::TLS13], &provider); let client_config = Arc::new(client_config); - let mut server_config = make_server_config(kt); + let mut server_config = make_server_config(kt, &provider); let storage = Arc::new(ServerStorage::new()); server_config.session_storage = storage.clone(); let server_config = Arc::new(server_config); @@ -4657,10 +4803,11 @@ fn tls13_stateful_resumption() { #[test] fn tls13_stateless_resumption() { let kt = KeyType::Rsa2048; - let client_config = make_client_config_with_versions(kt, &[&rustls::version::TLS13]); + let provider = provider::default_provider(); + let client_config = make_client_config_with_versions(kt, &[&rustls::version::TLS13], &provider); let client_config = Arc::new(client_config); - let mut server_config = make_server_config(kt); + let mut server_config = make_server_config(kt, &provider); server_config.ticketer = provider::Ticketer::new().unwrap(); let storage = Arc::new(ServerStorage::new()); server_config.session_storage = storage.clone(); @@ -4718,17 +4865,18 @@ fn tls13_stateless_resumption() { #[test] fn early_data_not_available() { - let (mut client, _) = make_pair(KeyType::Rsa2048); + let (mut client, _) = make_pair(KeyType::Rsa2048, &provider::default_provider()); assert!(client.early_data().is_none()); } fn early_data_configs() -> (Arc, Arc) { let kt = KeyType::Rsa2048; - let mut client_config = make_client_config(kt); + let provider = provider::default_provider(); + let mut client_config = make_client_config(kt, &provider); client_config.enable_early_data = true; client_config.resumption = Resumption::store(Arc::new(ClientStorage::new())); - let mut server_config = make_server_config(kt); + let mut server_config = make_server_config(kt, &provider); server_config.max_early_data_size = 1234; (Arc::new(client_config), Arc::new(server_config)) } @@ -4778,7 +4926,11 @@ fn early_data_is_available_on_resumption() { #[test] fn early_data_not_available_on_server_before_client_hello() { - let mut server = ServerConnection::new(Arc::new(make_server_config(KeyType::Rsa2048))).unwrap(); + let mut server = ServerConnection::new(Arc::new(make_server_config( + KeyType::Rsa2048, + &provider::default_provider(), + ))) + .unwrap(); assert!(server.early_data().is_none()); } @@ -5029,10 +5181,13 @@ mod test_quic { } let kt = KeyType::Rsa2048; - let mut client_config = make_client_config_with_versions(kt, &[&rustls::version::TLS13]); + let provider = provider::default_provider(); + let mut client_config = + make_client_config_with_versions(kt, &[&rustls::version::TLS13], &provider); client_config.enable_early_data = true; let client_config = Arc::new(client_config); - let mut server_config = make_server_config_with_versions(kt, &[&rustls::version::TLS13]); + let mut server_config = + make_server_config_with_versions(kt, &[&rustls::version::TLS13], &provider); server_config.max_early_data_size = 0xffffffff; let server_config = Arc::new(server_config); let client_params = &b"client params"[..]; @@ -5233,13 +5388,15 @@ mod test_quic { fn test_quic_rejects_missing_alpn() { let client_params = &b"client params"[..]; let server_params = &b"server params"[..]; + let provider = provider::default_provider(); for &kt in ALL_KEY_TYPES { - let client_config = make_client_config_with_versions(kt, &[&rustls::version::TLS13]); + let client_config = + make_client_config_with_versions(kt, &[&rustls::version::TLS13], &provider); let client_config = Arc::new(client_config); let mut server_config = - make_server_config_with_versions(kt, &[&rustls::version::TLS13]); + make_server_config_with_versions(kt, &[&rustls::version::TLS13], &provider); server_config.alpn_protocols = vec!["foo".into()]; let server_config = Arc::new(server_config); @@ -5271,8 +5428,12 @@ mod test_quic { #[cfg(feature = "tls12")] #[test] fn test_quic_no_tls13_error() { - let mut client_config = - make_client_config_with_versions(KeyType::Ed25519, &[&rustls::version::TLS12]); + let provider = provider::default_provider(); + let mut client_config = make_client_config_with_versions( + KeyType::Ed25519, + &[&rustls::version::TLS12], + &provider, + ); client_config.alpn_protocols = vec!["foo".into()]; let client_config = Arc::new(client_config); @@ -5286,8 +5447,11 @@ mod test_quic { .is_err() ); - let mut server_config = - make_server_config_with_versions(KeyType::Ed25519, &[&rustls::version::TLS12]); + let mut server_config = make_server_config_with_versions( + KeyType::Ed25519, + &[&rustls::version::TLS12], + &provider, + ); server_config.alpn_protocols = vec!["foo".into()]; let server_config = Arc::new(server_config); @@ -5303,8 +5467,12 @@ mod test_quic { #[test] fn test_quic_invalid_early_data_size() { - let mut server_config = - make_server_config_with_versions(KeyType::Ed25519, &[&rustls::version::TLS13]); + let provider = provider::default_provider(); + let mut server_config = make_server_config_with_versions( + KeyType::Ed25519, + &[&rustls::version::TLS13], + &provider, + ); server_config.alpn_protocols = vec!["foo".into()]; let cases = [ @@ -5331,8 +5499,12 @@ mod test_quic { #[test] fn test_quic_server_no_params_received() { - let server_config = - make_server_config_with_versions(KeyType::Ed25519, &[&rustls::version::TLS13]); + let provider = provider::default_provider(); + let server_config = make_server_config_with_versions( + KeyType::Ed25519, + &[&rustls::version::TLS13], + &provider, + ); let server_config = Arc::new(server_config); let mut server = quic::ServerConnection::new( @@ -5353,8 +5525,12 @@ mod test_quic { #[test] fn test_quic_server_no_tls12() { - let mut server_config = - make_server_config_with_versions(KeyType::Ed25519, &[&rustls::version::TLS13]); + let provider = provider::default_provider(); + let mut server_config = make_server_config_with_versions( + KeyType::Ed25519, + &[&rustls::version::TLS13], + &provider, + ); server_config.alpn_protocols = vec!["foo".into()]; let server_config = Arc::new(server_config); @@ -5577,9 +5753,12 @@ mod test_quic { #[test] fn test_quic_exporter() { + let provider = provider::default_provider(); for &kt in ALL_KEY_TYPES { - let client_config = make_client_config_with_versions(kt, &[&rustls::version::TLS13]); - let server_config = make_server_config_with_versions(kt, &[&rustls::version::TLS13]); + let client_config = + make_client_config_with_versions(kt, &[&rustls::version::TLS13], &provider); + let server_config = + make_server_config_with_versions(kt, &[&rustls::version::TLS13], &provider); do_exporter_test(client_config, server_config); } @@ -5588,8 +5767,11 @@ mod test_quic { #[test] fn test_fragmented_append() { // Create a QUIC client connection. - let client_config = - make_client_config_with_versions(KeyType::Rsa2048, &[&rustls::version::TLS13]); + let client_config = make_client_config_with_versions( + KeyType::Rsa2048, + &[&rustls::version::TLS13], + &provider::default_provider(), + ); let client_config = Arc::new(client_config); let mut client = quic::ClientConnection::new( Arc::clone(&client_config), @@ -5627,10 +5809,12 @@ fn test_client_does_not_offer_sha1() { use rustls::internal::msgs::handshake::HandshakePayload; use rustls::internal::msgs::message::{MessagePayload, OutboundOpaqueMessage}; + let provider = provider::default_provider(); for kt in ALL_KEY_TYPES { for version in rustls::ALL_VERSIONS { - let client_config = make_client_config_with_versions(*kt, &[version]); - let (mut client, _) = make_pair_for_configs(client_config, make_server_config(*kt)); + let client_config = make_client_config_with_versions(*kt, &[version], &provider); + let (mut client, _) = + make_pair_for_configs(client_config, make_server_config(*kt, &provider)); assert!(client.wants_write()); let mut buf = [0u8; 262144]; @@ -5662,19 +5846,28 @@ fn test_client_does_not_offer_sha1() { #[test] fn test_client_config_keyshare() { + let provider = provider::default_provider(); let kx_groups = vec![provider::kx_group::SECP384R1]; - let client_config = make_client_config_with_kx_groups(KeyType::Rsa2048, kx_groups.clone()); - let server_config = make_server_config_with_kx_groups(KeyType::Rsa2048, kx_groups); + let client_config = + make_client_config_with_kx_groups(KeyType::Rsa2048, kx_groups.clone(), &provider); + let server_config = make_server_config_with_kx_groups(KeyType::Rsa2048, kx_groups, &provider); let (mut client, mut server) = make_pair_for_configs(client_config, server_config); do_handshake_until_error(&mut client, &mut server).unwrap(); } #[test] fn test_client_config_keyshare_mismatch() { - let client_config = - make_client_config_with_kx_groups(KeyType::Rsa2048, vec![provider::kx_group::SECP384R1]); - let server_config = - make_server_config_with_kx_groups(KeyType::Rsa2048, vec![provider::kx_group::X25519]); + let provider = provider::default_provider(); + let client_config = make_client_config_with_kx_groups( + KeyType::Rsa2048, + vec![provider::kx_group::SECP384R1], + &provider, + ); + let server_config = make_server_config_with_kx_groups( + KeyType::Rsa2048, + vec![provider::kx_group::X25519], + &provider, + ); let (mut client, mut server) = make_pair_for_configs(client_config, server_config); assert!(do_handshake_until_error(&mut client, &mut server).is_err()); } @@ -5682,18 +5875,23 @@ fn test_client_config_keyshare_mismatch() { #[cfg(feature = "tls12")] #[test] fn test_client_sends_helloretryrequest() { + let provider = provider::default_provider(); // client sends a secp384r1 key share let mut client_config = make_client_config_with_kx_groups( KeyType::Rsa2048, vec![provider::kx_group::SECP384R1, provider::kx_group::X25519], + &provider, ); let storage = Arc::new(ClientStorage::new()); client_config.resumption = Resumption::store(storage.clone()); // but server only accepts x25519, so a HRR is required - let server_config = - make_server_config_with_kx_groups(KeyType::Rsa2048, vec![provider::kx_group::X25519]); + let server_config = make_server_config_with_kx_groups( + KeyType::Rsa2048, + vec![provider::kx_group::X25519], + &provider, + ); let (mut client, mut server) = make_pair_for_configs(client_config, server_config); @@ -5840,13 +6038,18 @@ fn test_client_rejects_hrr_with_varied_session_id() { }; // client prefers a secp384r1 key share, server only accepts x25519 + let provider = provider::default_provider(); let client_config = make_client_config_with_kx_groups( KeyType::Rsa2048, vec![provider::kx_group::SECP384R1, provider::kx_group::X25519], + &provider, ); - let server_config = - make_server_config_with_kx_groups(KeyType::Rsa2048, vec![provider::kx_group::X25519]); + let server_config = make_server_config_with_kx_groups( + KeyType::Rsa2048, + vec![provider::kx_group::X25519], + &provider, + ); let (client, server) = make_pair_for_configs(client_config, server_config); let (mut client, mut server) = (client.into(), server.into()); @@ -5874,20 +6077,27 @@ fn test_client_rejects_hrr_with_varied_session_id() { fn test_client_attempts_to_use_unsupported_kx_group() { // common to both client configs let shared_storage = Arc::new(ClientStorage::new()); + let provider = provider::default_provider(); // first, client sends a secp-256 share and server agrees. secp-256 is inserted // into kx group cache. - let mut client_config_1 = - make_client_config_with_kx_groups(KeyType::Rsa2048, vec![provider::kx_group::SECP256R1]); + let mut client_config_1 = make_client_config_with_kx_groups( + KeyType::Rsa2048, + vec![provider::kx_group::SECP256R1], + &provider, + ); client_config_1.resumption = Resumption::store(shared_storage.clone()); // second, client only supports secp-384 and so kx group cache // contains an unusable value. - let mut client_config_2 = - make_client_config_with_kx_groups(KeyType::Rsa2048, vec![provider::kx_group::SECP384R1]); + let mut client_config_2 = make_client_config_with_kx_groups( + KeyType::Rsa2048, + vec![provider::kx_group::SECP384R1], + &provider, + ); client_config_2.resumption = Resumption::store(shared_storage.clone()); - let server_config = make_server_config(KeyType::Rsa2048); + let server_config = make_server_config(KeyType::Rsa2048, &provider); // first handshake let (mut client_1, mut server) = make_pair_for_configs(client_config_1, server_config.clone()); @@ -5927,11 +6137,15 @@ fn test_client_sends_share_for_less_preferred_group() { // common to both client configs let shared_storage = Arc::new(ClientStorage::new()); + let provider = provider::default_provider(); // first, client sends a secp384r1 share and server agrees. secp384r1 is inserted // into kx group cache. - let mut client_config_1 = - make_client_config_with_kx_groups(KeyType::Rsa2048, vec![provider::kx_group::SECP384R1]); + let mut client_config_1 = make_client_config_with_kx_groups( + KeyType::Rsa2048, + vec![provider::kx_group::SECP384R1], + &provider, + ); client_config_1.resumption = Resumption::store(shared_storage.clone()); // second, client supports (x25519, secp384r1) and so kx group cache @@ -5939,11 +6153,15 @@ fn test_client_sends_share_for_less_preferred_group() { let mut client_config_2 = make_client_config_with_kx_groups( KeyType::Rsa2048, vec![provider::kx_group::X25519, provider::kx_group::SECP384R1], + &provider, ); client_config_2.resumption = Resumption::store(shared_storage.clone()); - let server_config = - make_server_config_with_kx_groups(KeyType::Rsa2048, provider::ALL_KX_GROUPS.to_vec()); + let server_config = make_server_config_with_kx_groups( + KeyType::Rsa2048, + provider::ALL_KX_GROUPS.to_vec(), + &provider, + ); // first handshake let (mut client_1, mut server) = make_pair_for_configs(client_config_1, server_config.clone()); @@ -6010,12 +6228,13 @@ fn test_client_sends_share_for_less_preferred_group() { #[test] fn test_tls13_client_resumption_does_not_reuse_tickets() { let shared_storage = Arc::new(ClientStorage::new()); + let provider = provider::default_provider(); - let mut client_config = make_client_config(KeyType::Rsa2048); + let mut client_config = make_client_config(KeyType::Rsa2048, &provider); client_config.resumption = Resumption::store(shared_storage.clone()); let client_config = Arc::new(client_config); - let mut server_config = make_server_config(KeyType::Rsa2048); + let mut server_config = make_server_config(KeyType::Rsa2048, &provider); server_config.send_tls13_tickets = 5; let server_config = Arc::new(server_config); @@ -6093,8 +6312,9 @@ fn test_client_mtu_reduction() { collector.writevs[0].clone() } + let provider = provider::default_provider(); for kt in ALL_KEY_TYPES { - let mut client_config = make_client_config(*kt); + let mut client_config = make_client_config(*kt, &provider); client_config.max_fragment_size = Some(64); let mut client = ClientConnection::new(Arc::new(client_config), server_name("localhost")).unwrap(); @@ -6107,11 +6327,14 @@ fn test_client_mtu_reduction() { #[test] fn test_server_mtu_reduction() { - let mut server_config = make_server_config(KeyType::Rsa2048); + let provider = provider::default_provider(); + let mut server_config = make_server_config(KeyType::Rsa2048, &provider); server_config.max_fragment_size = Some(64); server_config.send_half_rtt_data = true; - let (mut client, mut server) = - make_pair_for_configs(make_client_config(KeyType::Rsa2048), server_config); + let (mut client, mut server) = make_pair_for_configs( + make_client_config(KeyType::Rsa2048, &provider), + server_config, + ); let big_data = [0u8; 2048]; server @@ -6154,7 +6377,8 @@ fn test_server_mtu_reduction() { } fn check_client_max_fragment_size(size: usize) -> Option { - let mut client_config = make_client_config(KeyType::Ed25519); + let provider = provider::default_provider(); + let mut client_config = make_client_config(KeyType::Ed25519, &provider); client_config.max_fragment_size = Some(size); ClientConnection::new(Arc::new(client_config), server_name("localhost")).err() } @@ -6183,14 +6407,16 @@ fn bad_client_max_fragment_sizes() { #[test] fn handshakes_complete_and_data_flows_with_gratuitious_max_fragment_sizes() { // general exercising of msgs::fragmenter and msgs::deframer + let provider = provider::default_provider(); for kt in ALL_KEY_TYPES { for version in rustls::ALL_VERSIONS { // no hidden significance to these numbers for frag_size in [37, 61, 101, 257] { println!("test kt={kt:?} version={version:?} frag={frag_size:?}"); - let mut client_config = make_client_config_with_versions(*kt, &[version]); + let mut client_config = + make_client_config_with_versions(*kt, &[version], &provider); client_config.max_fragment_size = Some(frag_size); - let mut server_config = make_server_config(*kt); + let mut server_config = make_server_config(*kt, &provider); server_config.max_fragment_size = Some(frag_size); let (mut client, mut server) = make_pair_for_configs(client_config, server_config); @@ -6236,7 +6462,7 @@ fn connection_types_are_not_huge() { #[test] fn test_server_rejects_clients_without_any_kx_groups() { - let (_, mut server) = make_pair(KeyType::Rsa2048); + let (_, mut server) = make_pair(KeyType::Rsa2048, &provider::default_provider()); server .read_tls( &mut encoding::message_framing( @@ -6269,7 +6495,11 @@ fn test_server_rejects_clients_without_any_kx_groups() { fn test_server_rejects_clients_without_any_kx_group_overlap() { for version in rustls::ALL_VERSIONS { let (mut client, mut server) = make_pair_for_configs( - make_client_config_with_kx_groups(KeyType::Rsa2048, vec![provider::kx_group::X25519]), + make_client_config_with_kx_groups( + KeyType::Rsa2048, + vec![provider::kx_group::X25519], + &provider::default_provider(), + ), finish_server_config( KeyType::Rsa2048, ServerConfig::builder_with_provider( @@ -6312,7 +6542,7 @@ fn test_client_rejects_illegal_tls13_ccs() { Altered::InPlace } - let (mut client, mut server) = make_pair(KeyType::Rsa2048); + let (mut client, mut server) = make_pair(KeyType::Rsa2048, &provider::default_provider()); transfer(&mut client, &mut server); server.process_new_packets().unwrap(); @@ -6331,7 +6561,8 @@ fn test_client_rejects_illegal_tls13_ccs() { #[test] fn test_client_rejects_no_extended_master_secret_extension_when_require_ems_or_fips() { let key_type = KeyType::Rsa2048; - let mut client_config = make_client_config(key_type); + let provider = provider::default_provider(); + let mut client_config = make_client_config(key_type, &provider); if provider_is_fips() { assert!(client_config.require_ems); } else { @@ -6339,7 +6570,7 @@ fn test_client_rejects_no_extended_master_secret_extension_when_require_ems_or_f } let mut server_config = finish_server_config( key_type, - server_config_builder_with_versions(&[&rustls::version::TLS12]), + server_config_builder_with_versions(&[&rustls::version::TLS12], &provider), ); server_config.require_ems = false; let (client, server) = make_pair_for_configs(client_config, server_config); @@ -6359,10 +6590,11 @@ fn test_client_rejects_no_extended_master_secret_extension_when_require_ems_or_f #[test] fn test_server_rejects_no_extended_master_secret_extension_when_require_ems_or_fips() { let key_type = KeyType::Rsa2048; - let client_config = make_client_config(key_type); + let provider = provider::default_provider(); + let client_config = make_client_config(key_type, &provider); let mut server_config = finish_server_config( key_type, - server_config_builder_with_versions(&[&rustls::version::TLS12]), + server_config_builder_with_versions(&[&rustls::version::TLS12], &provider), ); if provider_is_fips() { assert!(server_config.require_ems); @@ -6398,19 +6630,20 @@ fn remove_ems_request(msg: &mut Message) -> Altered { #[cfg(feature = "tls12")] #[test] fn test_client_tls12_no_resume_after_server_downgrade() { - let mut client_config = common::make_client_config(KeyType::Ed25519); + let provider = provider::default_provider(); + let mut client_config = common::make_client_config(KeyType::Ed25519, &provider); let client_storage = Arc::new(ClientStorage::new()); client_config.resumption = Resumption::store(client_storage.clone()); let client_config = Arc::new(client_config); let server_config_1 = Arc::new(common::finish_server_config( KeyType::Ed25519, - server_config_builder_with_versions(&[&rustls::version::TLS13]), + server_config_builder_with_versions(&[&rustls::version::TLS13], &provider), )); let mut server_config_2 = common::finish_server_config( KeyType::Ed25519, - server_config_builder_with_versions(&[&rustls::version::TLS12]), + server_config_builder_with_versions(&[&rustls::version::TLS12], &provider), ); server_config_2.session_storage = Arc::new(rustls::server::NoServerSessionStorage {}); @@ -6493,7 +6726,11 @@ fn test_client_with_custom_verifier_can_accept_ecdsa_sha1_signatures() { .dangerous() .with_custom_certificate_verifier(Arc::new(MockServerVerifier::accepts_anything())) .with_no_client_auth(); - let server_config = make_server_config_with_kx_groups(KeyType::EcdsaP256, kx_groups.to_vec()); + let server_config = make_server_config_with_kx_groups( + KeyType::EcdsaP256, + kx_groups.to_vec(), + &provider::default_provider(), + ); let (mut client, mut server) = make_pair_for_configs(client_config, server_config); transfer(&mut client, &mut server); server.process_new_packets().unwrap(); @@ -6510,12 +6747,13 @@ fn test_client_with_custom_verifier_can_accept_ecdsa_sha1_signatures() { fn test_acceptor() { use rustls::server::Acceptor; - let client_config = Arc::new(make_client_config(KeyType::Ed25519)); + let provider = provider::default_provider(); + let client_config = Arc::new(make_client_config(KeyType::Ed25519, &provider)); let mut client = ClientConnection::new(client_config, server_name("localhost")).unwrap(); let mut buf = Vec::new(); client.write_tls(&mut buf).unwrap(); - let server_config = Arc::new(make_server_config(KeyType::Ed25519)); + let server_config = Arc::new(make_server_config(KeyType::Ed25519, &provider)); let mut acceptor = Acceptor::default(); acceptor .read_tls(&mut buf.as_slice()) @@ -6648,11 +6886,12 @@ fn test_no_warning_logging_during_successful_sessions() { CountingLogger::install(); CountingLogger::reset(); + let provider = provider::default_provider(); for kt in ALL_KEY_TYPES { for version in rustls::ALL_VERSIONS { - let client_config = make_client_config_with_versions(*kt, &[version]); + let client_config = make_client_config_with_versions(*kt, &[version], &provider); let (mut client, mut server) = - make_pair_for_configs(client_config, make_server_config(*kt)); + make_pair_for_configs(client_config, make_server_config(*kt, &provider)); do_handshake(&mut client, &mut server); } } @@ -6691,6 +6930,7 @@ fn test_secret_extraction_enabled() { // We support 3 different AEAD algorithms (AES-128-GCM mode, AES-256-GCM, and // Chacha20Poly1305), so that's 2*3 = 6 combinations to test. let kt = KeyType::Rsa2048; + let provider = provider::default_provider(); for suite in [ cipher_suite::TLS13_AES_128_GCM_SHA256, cipher_suite::TLS13_AES_256_GCM_SHA384, @@ -6708,7 +6948,7 @@ fn test_secret_extraction_enabled() { let mut server_config = ServerConfig::builder_with_provider( CryptoProvider { cipher_suites: vec![suite], - ..provider::default_provider() + ..provider.clone() } .into(), ) @@ -6721,7 +6961,7 @@ fn test_secret_extraction_enabled() { server_config.enable_secret_extraction = true; let server_config = Arc::new(server_config); - let mut client_config = make_client_config(kt); + let mut client_config = make_client_config(kt, &provider); client_config.enable_secret_extraction = true; let (mut client, mut server) = @@ -6854,7 +7094,7 @@ fn test_secret_extraction_disabled_or_too_early() { server_config.enable_secret_extraction = server_enable; let server_config = Arc::new(server_config); - let mut client_config = make_client_config(kt); + let mut client_config = make_client_config(kt, &provider); client_config.enable_secret_extraction = client_enable; let client_config = Arc::new(client_config); @@ -6896,12 +7136,13 @@ fn test_secret_extraction_disabled_or_too_early() { #[test] fn test_received_plaintext_backpressure() { let kt = KeyType::Rsa2048; + let provider = provider::default_provider(); let server_config = Arc::new( ServerConfig::builder_with_provider( CryptoProvider { cipher_suites: vec![cipher_suite::TLS13_AES_128_GCM_SHA256], - ..provider::default_provider() + ..provider.clone() } .into(), ) @@ -6912,7 +7153,7 @@ fn test_received_plaintext_backpressure() { .unwrap(), ); - let client_config = Arc::new(make_client_config(kt)); + let client_config = Arc::new(make_client_config(kt, &provider)); let (mut client, mut server) = make_pair_for_arc_configs(&client_config, &server_config); do_handshake(&mut client, &mut server); @@ -7153,12 +7394,13 @@ fn test_client_removes_tls12_session_if_server_sends_undecryptable_first_message } } + let provider = provider::default_provider(); let mut client_config = - make_client_config_with_versions(KeyType::Rsa2048, &[&rustls::version::TLS12]); + make_client_config_with_versions(KeyType::Rsa2048, &[&rustls::version::TLS12], &provider); let storage = Arc::new(ClientStorage::new()); client_config.resumption = Resumption::store(storage.clone()); let client_config = Arc::new(client_config); - let server_config = Arc::new(make_server_config(KeyType::Rsa2048)); + let server_config = Arc::new(make_server_config(KeyType::Rsa2048, &provider)); // successful handshake to allow resumption let (mut client, mut server) = make_pair_for_arc_configs(&client_config, &server_config); @@ -7194,7 +7436,7 @@ fn test_client_removes_tls12_session_if_server_sends_undecryptable_first_message #[test] fn test_client_fips_service_indicator() { assert_eq!( - make_client_config(KeyType::Rsa2048).fips(), + make_client_config(KeyType::Rsa2048, &provider::default_provider()).fips(), provider_is_fips() ); } @@ -7202,15 +7444,16 @@ fn test_client_fips_service_indicator() { #[test] fn test_server_fips_service_indicator() { assert_eq!( - make_server_config(KeyType::Rsa2048).fips(), + make_server_config(KeyType::Rsa2048, &provider::default_provider()).fips(), provider_is_fips() ); } #[test] fn test_connection_fips_service_indicator() { - let client_config = Arc::new(make_client_config(KeyType::Rsa2048)); - let server_config = Arc::new(make_server_config(KeyType::Rsa2048)); + let provider = provider::default_provider(); + let client_config = Arc::new(make_client_config(KeyType::Rsa2048, &provider)); + let server_config = Arc::new(make_server_config(KeyType::Rsa2048, &provider)); let conn_pair = make_pair_for_arc_configs(&client_config, &server_config); // Each connection's FIPS status should reflect the FIPS status of the config it was created // from. @@ -7224,7 +7467,7 @@ fn test_client_fips_service_indicator_includes_require_ems() { return; } - let mut client_config = make_client_config(KeyType::Rsa2048); + let mut client_config = make_client_config(KeyType::Rsa2048, &provider::default_provider()); assert!(client_config.fips()); client_config.require_ems = false; assert!(!client_config.fips()); @@ -7236,7 +7479,7 @@ fn test_server_fips_service_indicator_includes_require_ems() { return; } - let mut server_config = make_server_config(KeyType::Rsa2048); + let mut server_config = make_server_config(KeyType::Rsa2048, &provider::default_provider()); assert!(server_config.fips()); server_config.require_ems = false; assert!(!server_config.fips()); @@ -7300,7 +7543,11 @@ fn test_client_fips_service_indicator_includes_ech_hpke_suite() { #[test] fn test_complete_io_errors_if_close_notify_received_too_early() { - let mut server = ServerConnection::new(Arc::new(make_server_config(KeyType::Rsa2048))).unwrap(); + let mut server = ServerConnection::new(Arc::new(make_server_config( + KeyType::Rsa2048, + &provider::default_provider(), + ))) + .unwrap(); let client_hello_followed_by_close_notify_alert = b"\ \x16\x03\x01\x00\xc8\x01\x00\x00\xc4\x03\x03\xec\x12\xdd\x17\x64\ \xa4\x39\xfd\x7e\x8c\x85\x46\xb8\x4d\x1e\xa0\x6e\xb3\xd7\xa0\x51\ @@ -7329,7 +7576,7 @@ fn test_complete_io_errors_if_close_notify_received_too_early() { #[test] fn test_complete_io_with_no_io_needed() { - let (mut client, mut server) = make_pair(KeyType::Rsa2048); + let (mut client, mut server) = make_pair(KeyType::Rsa2048, &provider::default_provider()); do_handshake(&mut client, &mut server); client .writer() @@ -7367,7 +7614,7 @@ fn test_complete_io_with_no_io_needed() { #[test] fn test_junk_after_close_notify_received() { - let (mut client, mut server) = make_pair(KeyType::Rsa2048); + let (mut client, mut server) = make_pair(KeyType::Rsa2048, &provider::default_provider()); do_handshake(&mut client, &mut server); client .writer() @@ -7410,7 +7657,7 @@ fn test_junk_after_close_notify_received() { #[test] fn test_data_after_close_notify_is_ignored() { - let (mut client, mut server) = make_pair(KeyType::Rsa2048); + let (mut client, mut server) = make_pair(KeyType::Rsa2048, &provider::default_provider()); do_handshake(&mut client, &mut server); client @@ -7442,7 +7689,7 @@ fn test_data_after_close_notify_is_ignored() { #[test] fn test_close_notify_sent_prior_to_handshake_complete() { - let (mut client, mut server) = make_pair(KeyType::Rsa2048); + let (mut client, mut server) = make_pair(KeyType::Rsa2048, &provider::default_provider()); client.send_close_notify(); assert_eq!( do_handshake_until_error(&mut client, &mut server), @@ -7454,7 +7701,7 @@ fn test_close_notify_sent_prior_to_handshake_complete() { #[test] fn test_subsequent_close_notify_ignored() { - let (mut client, mut server) = make_pair(KeyType::Rsa2048); + let (mut client, mut server) = make_pair(KeyType::Rsa2048, &provider::default_provider()); client.send_close_notify(); assert!(transfer(&mut client, &mut server) > 0); @@ -7465,7 +7712,7 @@ fn test_subsequent_close_notify_ignored() { #[test] fn test_second_close_notify_after_handshake() { - let (mut client, mut server) = make_pair(KeyType::Rsa2048); + let (mut client, mut server) = make_pair(KeyType::Rsa2048, &provider::default_provider()); do_handshake(&mut client, &mut server); client.send_close_notify(); assert!(transfer(&mut client, &mut server) > 0); @@ -7478,7 +7725,7 @@ fn test_second_close_notify_after_handshake() { #[test] fn test_read_tls_artificial_eof_after_close_notify() { - let (mut client, mut server) = make_pair(KeyType::Rsa2048); + let (mut client, mut server) = make_pair(KeyType::Rsa2048, &provider::default_provider()); do_handshake(&mut client, &mut server); client.send_close_notify(); assert!(transfer(&mut client, &mut server) > 0); @@ -7497,14 +7744,15 @@ fn test_read_tls_artificial_eof_after_close_notify() { fn test_pinned_ocsp_response_given_to_custom_server_cert_verifier() { let ocsp_response = b"hello-ocsp-world!"; let kt = KeyType::EcdsaP256; + let provider = provider::default_provider(); for version in rustls::ALL_VERSIONS { - let server_config = server_config_builder() + let server_config = server_config_builder(&provider) .with_no_client_auth() .with_single_cert_with_ocsp(kt.get_chain(), kt.get_key(), ocsp_response.to_vec()) .unwrap(); - let client_config = client_config_builder_with_versions(&[version]) + let client_config = client_config_builder_with_versions(&[version], &provider) .dangerous() .with_custom_certificate_verifier(Arc::new(MockServerVerifier::expects_ocsp_response( ocsp_response, @@ -7521,9 +7769,10 @@ fn test_pinned_ocsp_response_given_to_custom_server_cert_verifier() { fn test_server_uses_cached_compressed_certificates() { static COMPRESS_COUNT: AtomicUsize = AtomicUsize::new(0); - let mut server_config = make_server_config(KeyType::Rsa2048); + let provider = provider::default_provider(); + let mut server_config = make_server_config(&provider, KeyType::Rsa2048); server_config.cert_compressors = vec![&CountingCompressor]; - let mut client_config = make_client_config(KeyType::Rsa2048); + let mut client_config = make_client_config(&provider, KeyType::Rsa2048); client_config.resumption = Resumption::disabled(); let server_config = Arc::new(server_config); @@ -7559,9 +7808,10 @@ fn test_server_uses_cached_compressed_certificates() { #[test] fn test_server_uses_uncompressed_certificate_if_compression_fails() { - let mut server_config = make_server_config(KeyType::Rsa2048); + let provider = provider::default_provider(); + let mut server_config = make_server_config(KeyType::Rsa2048, &provider); server_config.cert_compressors = vec![&FailingCompressor]; - let mut client_config = make_client_config(KeyType::Rsa2048); + let mut client_config = make_client_config(KeyType::Rsa2048, &provider); client_config.cert_decompressors = vec![&NeverDecompressor]; let (mut client, mut server) = make_pair_for_configs(client_config, server_config); @@ -7570,9 +7820,11 @@ fn test_server_uses_uncompressed_certificate_if_compression_fails() { #[test] fn test_client_uses_uncompressed_certificate_if_compression_fails() { - let mut server_config = make_server_config_with_mandatory_client_auth(KeyType::Rsa2048); + let provider = provider::default_provider(); + let mut server_config = + make_server_config_with_mandatory_client_auth(KeyType::Rsa2048, &provider); server_config.cert_decompressors = vec![&NeverDecompressor]; - let mut client_config = make_client_config_with_auth(KeyType::Rsa2048); + let mut client_config = make_client_config_with_auth(KeyType::Rsa2048, &provider); client_config.cert_compressors = vec![&FailingCompressor]; let (mut client, mut server) = make_pair_for_configs(client_config, server_config); @@ -7619,10 +7871,11 @@ impl rustls::compress::CertDecompressor for NeverDecompressor { fn test_server_can_opt_out_of_compression_cache() { static COMPRESS_COUNT: AtomicUsize = AtomicUsize::new(0); - let mut server_config = make_server_config(KeyType::Rsa2048); + let provider = provider::default_provider(); + let mut server_config = make_server_config(&provider, KeyType::Rsa2048); server_config.cert_compressors = vec![&AlwaysInteractiveCompressor]; server_config.cert_compression_cache = Arc::new(rustls::compress::CompressionCache::Disabled); - let mut client_config = make_client_config(KeyType::Rsa2048); + let mut client_config = make_client_config(&provider, KeyType::Rsa2048); client_config.resumption = Resumption::disabled(); let server_config = Arc::new(server_config); @@ -7659,9 +7912,10 @@ fn test_server_can_opt_out_of_compression_cache() { #[test] fn test_cert_decompression_by_client_produces_invalid_cert_payload() { - let mut server_config = make_server_config(KeyType::Rsa2048); + let provider = provider::default_provider(); + let mut server_config = make_server_config(KeyType::Rsa2048, &provider); server_config.cert_compressors = vec![&IdentityCompressor]; - let mut client_config = make_client_config(KeyType::Rsa2048); + let mut client_config = make_client_config(KeyType::Rsa2048, &provider); client_config.cert_decompressors = vec![&GarbageDecompressor]; let (mut client, mut server) = make_pair_for_configs(client_config, server_config); @@ -7680,9 +7934,11 @@ fn test_cert_decompression_by_client_produces_invalid_cert_payload() { #[test] fn test_cert_decompression_by_server_produces_invalid_cert_payload() { - let mut server_config = make_server_config_with_mandatory_client_auth(KeyType::Rsa2048); + let provider = provider::default_provider(); + let mut server_config = + make_server_config_with_mandatory_client_auth(KeyType::Rsa2048, &provider); server_config.cert_decompressors = vec![&GarbageDecompressor]; - let mut client_config = make_client_config_with_auth(KeyType::Rsa2048); + let mut client_config = make_client_config_with_auth(KeyType::Rsa2048, &provider); client_config.cert_compressors = vec![&IdentityCompressor]; let (mut client, mut server) = make_pair_for_configs(client_config, server_config); @@ -7701,9 +7957,11 @@ fn test_cert_decompression_by_server_produces_invalid_cert_payload() { #[test] fn test_cert_decompression_by_server_fails() { - let mut server_config = make_server_config_with_mandatory_client_auth(KeyType::Rsa2048); + let provider = provider::default_provider(); + let mut server_config = + make_server_config_with_mandatory_client_auth(KeyType::Rsa2048, &provider); server_config.cert_decompressors = vec![&FailingDecompressor]; - let mut client_config = make_client_config_with_auth(KeyType::Rsa2048); + let mut client_config = make_client_config_with_auth(KeyType::Rsa2048, &provider); client_config.cert_compressors = vec![&IdentityCompressor]; let (mut client, mut server) = make_pair_for_configs(client_config, server_config); @@ -7723,8 +7981,9 @@ fn test_cert_decompression_by_server_fails() { #[cfg(feature = "zlib")] #[test] fn test_cert_decompression_by_server_would_result_in_excessively_large_cert() { - let server_config = make_server_config_with_mandatory_client_auth(KeyType::Rsa2048); - let mut client_config = make_client_config_with_auth(KeyType::Rsa2048); + let provider = provider::default_provider(); + let server_config = make_server_config_with_mandatory_client_auth(&provider, KeyType::Rsa2048); + let mut client_config = make_client_config_with_auth(&provider, KeyType::Rsa2048); let big_cert = CertificateDer::from(vec![0u8; 0xffff]); let key = provider::default_provider() @@ -7825,9 +8084,10 @@ impl io::Write for FakeStream<'_> { #[test] fn test_illegal_server_renegotiation_attempt_after_tls13_handshake() { + let provider = provider::default_provider(); let client_config = - make_client_config_with_versions(KeyType::Rsa2048, &[&rustls::version::TLS13]); - let mut server_config = make_server_config(KeyType::Rsa2048); + make_client_config_with_versions(KeyType::Rsa2048, &[&rustls::version::TLS13], &provider); + let mut server_config = make_server_config(KeyType::Rsa2048, &provider); server_config.enable_secret_extraction = true; let (mut client, mut server) = make_pair_for_configs(client_config, server_config); @@ -7859,9 +8119,10 @@ fn test_illegal_server_renegotiation_attempt_after_tls13_handshake() { #[cfg(feature = "tls12")] #[test] fn test_illegal_server_renegotiation_attempt_after_tls12_handshake() { + let provider = provider::default_provider(); let client_config = - make_client_config_with_versions(KeyType::Rsa2048, &[&rustls::version::TLS12]); - let mut server_config = make_server_config(KeyType::Rsa2048); + make_client_config_with_versions(KeyType::Rsa2048, &[&rustls::version::TLS12], &provider); + let mut server_config = make_server_config(KeyType::Rsa2048, &provider); server_config.enable_secret_extraction = true; let (mut client, mut server) = make_pair_for_configs(client_config, server_config); @@ -7898,10 +8159,11 @@ fn test_illegal_server_renegotiation_attempt_after_tls12_handshake() { #[test] fn test_illegal_client_renegotiation_attempt_after_tls13_handshake() { + let provider = provider::default_provider(); let mut client_config = - make_client_config_with_versions(KeyType::Rsa2048, &[&rustls::version::TLS13]); + make_client_config_with_versions(KeyType::Rsa2048, &[&rustls::version::TLS13], &provider); client_config.enable_secret_extraction = true; - let server_config = make_server_config(KeyType::Rsa2048); + let server_config = make_server_config(KeyType::Rsa2048, &provider); let (mut client, mut server) = make_pair_for_configs(client_config, server_config); do_handshake(&mut client, &mut server); @@ -7926,9 +8188,10 @@ fn test_illegal_client_renegotiation_attempt_after_tls13_handshake() { #[cfg(feature = "tls12")] #[test] fn test_illegal_client_renegotiation_attempt_during_tls12_handshake() { - let server_config = make_server_config(KeyType::Rsa2048); + let provider = provider::default_provider(); + let server_config = make_server_config(KeyType::Rsa2048, &provider); let client_config = - make_client_config_with_versions(KeyType::Rsa2048, &[&rustls::version::TLS12]); + make_client_config_with_versions(KeyType::Rsa2048, &[&rustls::version::TLS12], &provider); let (mut client, mut server) = make_pair_for_configs(client_config, server_config); let mut client_hello = vec![]; @@ -7955,7 +8218,7 @@ fn test_illegal_client_renegotiation_attempt_during_tls12_handshake() { #[test] fn test_refresh_traffic_keys_during_handshake() { - let (mut client, mut server) = make_pair(KeyType::Ed25519); + let (mut client, mut server) = make_pair(KeyType::Ed25519, &provider::default_provider()); assert_eq!( client .refresh_traffic_keys() @@ -7972,7 +8235,7 @@ fn test_refresh_traffic_keys_during_handshake() { #[test] fn test_refresh_traffic_keys() { - let (mut client, mut server) = make_pair(KeyType::Ed25519); + let (mut client, mut server) = make_pair(KeyType::Ed25519, &provider::default_provider()); do_handshake(&mut client, &mut server); fn check_both_directions(client: &mut ClientConnection, server: &mut ServerConnection) { @@ -8181,7 +8444,7 @@ fn tls13_packed_handshake() { #[test] fn large_client_hello() { - let (_, mut server) = make_pair(KeyType::Rsa2048); + let (_, mut server) = make_pair(KeyType::Rsa2048, &provider::default_provider()); let hello = include_bytes!("data/bug2227-clienthello.bin"); let mut cursor = io::Cursor::new(hello); loop { @@ -8222,7 +8485,7 @@ fn hybrid_kx_component_share_offered_if_supported_seperately() { .with_safe_default_protocol_versions() .unwrap(), ); - let server_config = make_server_config(kt); + let server_config = make_server_config(kt, &provider::default_provider()); let (client, server) = make_pair_for_configs(client_config, server_config); let (mut client, mut server) = (client.into(), server.into()); @@ -8248,7 +8511,7 @@ fn hybrid_kx_component_share_not_offered_unless_supported_seperately() { .with_safe_default_protocol_versions() .unwrap(), ); - let server_config = make_server_config(kt); + let server_config = make_server_config(kt, &provider::default_provider()); let (client, server) = make_pair_for_configs(client_config, server_config); let (mut client, mut server) = (client.into(), server.into()); @@ -8274,10 +8537,11 @@ fn hybrid_kx_component_share_offered_but_server_chooses_something_else() { .with_safe_default_protocol_versions() .unwrap(), ); - let server_config = make_server_config(kt); + let provider = provider::default_provider(); + let server_config = make_server_config(kt, &provider); let (mut client_1, mut server) = make_pair_for_configs(client_config, server_config); - let (mut client_2, _) = make_pair(kt); + let (mut client_2, _) = make_pair(kt, &provider); // client_2 supplies the ClientHello, client_1 receives the ServerHello transfer(&mut client_2, &mut server); diff --git a/rustls/tests/client_cert_verifier.rs b/rustls/tests/client_cert_verifier.rs index ea9ef1a121..381c16d430 100644 --- a/rustls/tests/client_cert_verifier.rs +++ b/rustls/tests/client_cert_verifier.rs @@ -37,7 +37,7 @@ fn server_config_with_verifier( kt: KeyType, client_cert_verifier: MockClientVerifier, ) -> ServerConfig { - server_config_builder() + server_config_builder(&provider::default_provider()) .with_client_cert_verifier(Arc::new(client_cert_verifier)) .with_single_cert(kt.get_chain(), kt.get_key()) .unwrap() @@ -53,7 +53,8 @@ fn client_verifier_works() { let server_config = Arc::new(server_config); for version in rustls::ALL_VERSIONS { - let client_config = make_client_config_with_versions_with_auth(*kt, &[version]); + let client_config = + make_client_config_with_versions_with_auth(*kt, &[version], &provider); let (mut client, mut server) = make_pair_for_arc_configs(&Arc::new(client_config.clone()), &server_config); let err = do_handshake_until_error(&mut client, &mut server); @@ -73,7 +74,8 @@ fn client_verifier_no_schemes() { let server_config = Arc::new(server_config); for version in rustls::ALL_VERSIONS { - let client_config = make_client_config_with_versions_with_auth(*kt, &[version]); + let client_config = + make_client_config_with_versions_with_auth(*kt, &[version], &provider); let (mut client, mut server) = make_pair_for_arc_configs(&Arc::new(client_config.clone()), &server_config); let err = do_handshake_until_error(&mut client, &mut server); @@ -97,7 +99,7 @@ fn client_verifier_no_auth_yes_root() { let server_config = Arc::new(server_config); for version in rustls::ALL_VERSIONS { - let client_config = make_client_config_with_versions(*kt, &[version]); + let client_config = make_client_config_with_versions(*kt, &[version], &provider); let mut server = ServerConnection::new(Arc::clone(&server_config)).unwrap(); let mut client = ClientConnection::new(Arc::new(client_config), server_name("localhost")).unwrap(); @@ -125,7 +127,8 @@ fn client_verifier_fails_properly() { let server_config = Arc::new(server_config); for version in rustls::ALL_VERSIONS { - let client_config = make_client_config_with_versions_with_auth(*kt, &[version]); + let client_config = + make_client_config_with_versions_with_auth(*kt, &[version], &provider); let mut server = ServerConnection::new(Arc::clone(&server_config)).unwrap(); let mut client = ClientConnection::new(Arc::new(client_config), server_name("localhost")).unwrap(); diff --git a/rustls/tests/common/mod.rs b/rustls/tests/common/mod.rs index aa40837c8b..f377605ee6 100644 --- a/rustls/tests/common/mod.rs +++ b/rustls/tests/common/mod.rs @@ -454,13 +454,15 @@ impl KeyType { } } -pub fn server_config_builder() -> rustls::ConfigBuilder { +pub fn server_config_builder( + provider: &CryptoProvider, +) -> rustls::ConfigBuilder { // ensure `ServerConfig::builder()` is covered, even though it is // equivalent to `builder_with_provider(provider::provider().into())`. if exactly_one_provider() { rustls::ServerConfig::builder() } else { - rustls::ServerConfig::builder_with_provider(provider::default_provider().into()) + rustls::ServerConfig::builder_with_provider(provider.clone().into()) .with_safe_default_protocol_versions() .unwrap() } @@ -468,23 +470,26 @@ pub fn server_config_builder() -> rustls::ConfigBuilder rustls::ConfigBuilder { if exactly_one_provider() { rustls::ServerConfig::builder_with_protocol_versions(versions) } else { - rustls::ServerConfig::builder_with_provider(provider::default_provider().into()) + rustls::ServerConfig::builder_with_provider(provider.clone().into()) .with_protocol_versions(versions) .unwrap() } } -pub fn client_config_builder() -> rustls::ConfigBuilder { +pub fn client_config_builder( + provider: &CryptoProvider, +) -> rustls::ConfigBuilder { // ensure `ClientConfig::builder()` is covered, even though it is // equivalent to `builder_with_provider(provider::provider().into())`. if exactly_one_provider() { rustls::ClientConfig::builder() } else { - rustls::ClientConfig::builder_with_provider(provider::default_provider().into()) + rustls::ClientConfig::builder_with_provider(provider.clone().into()) .with_safe_default_protocol_versions() .unwrap() } @@ -492,11 +497,12 @@ pub fn client_config_builder() -> rustls::ConfigBuilder rustls::ConfigBuilder { if exactly_one_provider() { rustls::ClientConfig::builder_with_protocol_versions(versions) } else { - rustls::ClientConfig::builder_with_provider(provider::default_provider().into()) + rustls::ClientConfig::builder_with_provider(provider.clone().into()) .with_protocol_versions(versions) .unwrap() } @@ -511,27 +517,29 @@ pub fn finish_server_config( .unwrap() } -pub fn make_server_config(kt: KeyType) -> ServerConfig { - finish_server_config(kt, server_config_builder()) +pub fn make_server_config(kt: KeyType, provider: &CryptoProvider) -> ServerConfig { + finish_server_config(kt, server_config_builder(provider)) } pub fn make_server_config_with_versions( kt: KeyType, versions: &[&'static rustls::SupportedProtocolVersion], + provider: &CryptoProvider, ) -> ServerConfig { - finish_server_config(kt, server_config_builder_with_versions(versions)) + finish_server_config(kt, server_config_builder_with_versions(versions, provider)) } pub fn make_server_config_with_kx_groups( kt: KeyType, kx_groups: Vec<&'static dyn rustls::crypto::SupportedKxGroup>, + provider: &CryptoProvider, ) -> ServerConfig { finish_server_config( kt, ServerConfig::builder_with_provider( CryptoProvider { kx_groups, - ..provider::default_provider() + ..provider.clone() } .into(), ) @@ -558,38 +566,47 @@ pub fn get_client_root_store(kt: KeyType) -> Arc { pub fn make_server_config_with_mandatory_client_auth_crls( kt: KeyType, crls: Vec>, + provider: &CryptoProvider, ) -> ServerConfig { make_server_config_with_client_verifier( kt, - webpki_client_verifier_builder(get_client_root_store(kt)).with_crls(crls), + webpki_client_verifier_builder(get_client_root_store(kt), provider).with_crls(crls), + provider, ) } -pub fn make_server_config_with_mandatory_client_auth(kt: KeyType) -> ServerConfig { +pub fn make_server_config_with_mandatory_client_auth( + kt: KeyType, + provider: &CryptoProvider, +) -> ServerConfig { make_server_config_with_client_verifier( kt, - webpki_client_verifier_builder(get_client_root_store(kt)), + webpki_client_verifier_builder(get_client_root_store(kt), provider), + provider, ) } pub fn make_server_config_with_optional_client_auth( kt: KeyType, crls: Vec>, + provider: &CryptoProvider, ) -> ServerConfig { make_server_config_with_client_verifier( kt, - webpki_client_verifier_builder(get_client_root_store(kt)) + webpki_client_verifier_builder(get_client_root_store(kt), provider) .with_crls(crls) .allow_unknown_revocation_status() .allow_unauthenticated(), + provider, ) } pub fn make_server_config_with_client_verifier( kt: KeyType, verifier_builder: ClientCertVerifierBuilder, + provider: &CryptoProvider, ) -> ServerConfig { - server_config_builder() + server_config_builder(provider) .with_client_cert_verifier(verifier_builder.build().unwrap()) .with_single_cert(kt.get_chain(), kt.get_key()) .unwrap() @@ -607,7 +624,7 @@ pub fn make_server_config_with_raw_key_support( )); client_verifier.expect_raw_public_keys = true; // We don't support tls1.2 for Raw Public Keys, hence the version is hard-coded. - server_config_builder_with_versions(&[&rustls::version::TLS13]) + server_config_builder_with_versions(&[&rustls::version::TLS13], provider) .with_client_cert_verifier(Arc::new(client_verifier)) .with_cert_resolver(server_cert_resolver) } @@ -622,7 +639,7 @@ pub fn make_client_config_with_raw_key_support( .unwrap(), )); // We don't support tls1.2 for Raw Public Keys, hence the version is hard-coded. - client_config_builder_with_versions(&[&rustls::version::TLS13]) + client_config_builder_with_versions(&[&rustls::version::TLS13], provider) .dangerous() .with_custom_certificate_verifier(server_verifier) .with_client_cert_resolver(client_cert_resolver) @@ -681,18 +698,19 @@ pub fn finish_client_config_with_creds( .unwrap() } -pub fn make_client_config(kt: KeyType) -> ClientConfig { - finish_client_config(kt, client_config_builder()) +pub fn make_client_config(kt: KeyType, provider: &CryptoProvider) -> ClientConfig { + finish_client_config(kt, client_config_builder(provider)) } pub fn make_client_config_with_kx_groups( kt: KeyType, kx_groups: Vec<&'static dyn rustls::crypto::SupportedKxGroup>, + provider: &CryptoProvider, ) -> ClientConfig { let builder = ClientConfig::builder_with_provider( CryptoProvider { kx_groups, - ..provider::default_provider() + ..provider.clone() } .into(), ) @@ -704,49 +722,61 @@ pub fn make_client_config_with_kx_groups( pub fn make_client_config_with_versions( kt: KeyType, versions: &[&'static rustls::SupportedProtocolVersion], + provider: &CryptoProvider, ) -> ClientConfig { - finish_client_config(kt, client_config_builder_with_versions(versions)) + finish_client_config(kt, client_config_builder_with_versions(versions, provider)) } -pub fn make_client_config_with_auth(kt: KeyType) -> ClientConfig { - finish_client_config_with_creds(kt, client_config_builder()) +pub fn make_client_config_with_auth(kt: KeyType, provider: &CryptoProvider) -> ClientConfig { + finish_client_config_with_creds(kt, client_config_builder(provider)) } pub fn make_client_config_with_versions_with_auth( kt: KeyType, versions: &[&'static rustls::SupportedProtocolVersion], + provider: &CryptoProvider, ) -> ClientConfig { - finish_client_config_with_creds(kt, client_config_builder_with_versions(versions)) + finish_client_config_with_creds(kt, client_config_builder_with_versions(versions, provider)) } pub fn make_client_config_with_verifier( versions: &[&'static rustls::SupportedProtocolVersion], verifier_builder: ServerCertVerifierBuilder, + provider: &CryptoProvider, ) -> ClientConfig { - client_config_builder_with_versions(versions) + client_config_builder_with_versions(versions, provider) .dangerous() .with_custom_certificate_verifier(verifier_builder.build().unwrap()) .with_no_client_auth() } -pub fn webpki_client_verifier_builder(roots: Arc) -> ClientCertVerifierBuilder { +pub fn webpki_client_verifier_builder( + roots: Arc, + provider: &CryptoProvider, +) -> ClientCertVerifierBuilder { if exactly_one_provider() { WebPkiClientVerifier::builder(roots) } else { - WebPkiClientVerifier::builder_with_provider(roots, provider::default_provider().into()) + WebPkiClientVerifier::builder_with_provider(roots, provider.clone().into()) } } -pub fn webpki_server_verifier_builder(roots: Arc) -> ServerCertVerifierBuilder { +pub fn webpki_server_verifier_builder( + roots: Arc, + provider: &CryptoProvider, +) -> ServerCertVerifierBuilder { if exactly_one_provider() { WebPkiServerVerifier::builder(roots) } else { - WebPkiServerVerifier::builder_with_provider(roots, provider::default_provider().into()) + WebPkiServerVerifier::builder_with_provider(roots, provider.clone().into()) } } -pub fn make_pair(kt: KeyType) -> (ClientConnection, ServerConnection) { - make_pair_for_configs(make_client_config(kt), make_server_config(kt)) +pub fn make_pair(kt: KeyType, provider: &CryptoProvider) -> (ClientConnection, ServerConnection) { + make_pair_for_configs( + make_client_config(kt, provider), + make_server_config(kt, provider), + ) } pub fn make_pair_for_configs( @@ -1251,7 +1281,7 @@ impl MockClientVerifier { provider: &CryptoProvider, ) -> Self { Self { - parent: webpki_client_verifier_builder(get_client_root_store(kt)) + parent: webpki_client_verifier_builder(get_client_root_store(kt), provider) .build() .unwrap(), verified, diff --git a/rustls/tests/key_log_file_env.rs b/rustls/tests/key_log_file_env.rs index 0f2c889549..8f18c1a141 100644 --- a/rustls/tests/key_log_file_env.rs +++ b/rustls/tests/key_log_file_env.rs @@ -37,11 +37,13 @@ use common::{ #[test] fn exercise_key_log_file_for_client() { serialized(|| { - let server_config = Arc::new(make_server_config(KeyType::Rsa2048)); + let provider = provider::default_provider(); + let server_config = Arc::new(make_server_config(KeyType::Rsa2048, &provider)); unsafe { env::set_var("SSLKEYLOGFILE", "./sslkeylogfile.txt") }; for version in rustls::ALL_VERSIONS { - let mut client_config = make_client_config_with_versions(KeyType::Rsa2048, &[version]); + let mut client_config = + make_client_config_with_versions(KeyType::Rsa2048, &[version], &provider); client_config.key_log = Arc::new(rustls::KeyLogFile::new()); let (mut client, mut server) = @@ -59,7 +61,8 @@ fn exercise_key_log_file_for_client() { #[test] fn exercise_key_log_file_for_server() { serialized(|| { - let mut server_config = make_server_config(KeyType::Rsa2048); + let provider = provider::default_provider(); + let mut server_config = make_server_config(KeyType::Rsa2048, &provider); unsafe { env::set_var("SSLKEYLOGFILE", "./sslkeylogfile.txt") }; server_config.key_log = Arc::new(rustls::KeyLogFile::new()); @@ -67,7 +70,8 @@ fn exercise_key_log_file_for_server() { let server_config = Arc::new(server_config); for version in rustls::ALL_VERSIONS { - let client_config = make_client_config_with_versions(KeyType::Rsa2048, &[version]); + let client_config = + make_client_config_with_versions(KeyType::Rsa2048, &[version], &provider); let (mut client, mut server) = make_pair_for_arc_configs(&Arc::new(client_config), &server_config); diff --git a/rustls/tests/server_cert_verifier.rs b/rustls/tests/server_cert_verifier.rs index d0eadac386..b0d2508b65 100644 --- a/rustls/tests/server_cert_verifier.rs +++ b/rustls/tests/server_cert_verifier.rs @@ -32,13 +32,14 @@ use x509_parser::x509::X509Name; #[test] fn client_can_override_certificate_verification() { + let provider = provider::default_provider(); for kt in ALL_KEY_TYPES.iter() { let verifier = Arc::new(MockServerVerifier::accepts_anything()); - let server_config = Arc::new(make_server_config(*kt)); + let server_config = Arc::new(make_server_config(*kt, &provider)); for version in rustls::ALL_VERSIONS { - let mut client_config = make_client_config_with_versions(*kt, &[version]); + let mut client_config = make_client_config_with_versions(*kt, &[version], &provider); client_config .dangerous() .set_certificate_verifier(verifier.clone()); @@ -52,15 +53,16 @@ fn client_can_override_certificate_verification() { #[test] fn client_can_override_certificate_verification_and_reject_certificate() { + let provider = provider::default_provider(); for kt in ALL_KEY_TYPES.iter() { let verifier = Arc::new(MockServerVerifier::rejects_certificate( Error::InvalidMessage(InvalidMessage::HandshakePayloadTooLarge), )); - let server_config = Arc::new(make_server_config(*kt)); + let server_config = Arc::new(make_server_config(*kt, &provider)); for version in rustls::ALL_VERSIONS { - let mut client_config = make_client_config_with_versions(*kt, &[version]); + let mut client_config = make_client_config_with_versions(*kt, &[version], &provider); client_config .dangerous() .set_certificate_verifier(verifier.clone()); @@ -84,8 +86,10 @@ fn client_can_override_certificate_verification_and_reject_certificate() { #[cfg(feature = "tls12")] #[test] fn client_can_override_certificate_verification_and_reject_tls12_signatures() { + let provider = provider::default_provider(); for kt in ALL_KEY_TYPES.iter() { - let mut client_config = make_client_config_with_versions(*kt, &[&rustls::version::TLS12]); + let mut client_config = + make_client_config_with_versions(*kt, &[&rustls::version::TLS12], &provider); let verifier = Arc::new(MockServerVerifier::rejects_tls12_signatures( Error::InvalidMessage(InvalidMessage::HandshakePayloadTooLarge), )); @@ -94,7 +98,7 @@ fn client_can_override_certificate_verification_and_reject_tls12_signatures() { .dangerous() .set_certificate_verifier(verifier); - let server_config = Arc::new(make_server_config(*kt)); + let server_config = Arc::new(make_server_config(*kt, &provider)); let (mut client, mut server) = make_pair_for_arc_configs(&Arc::new(client_config), &server_config); @@ -113,8 +117,13 @@ fn client_can_override_certificate_verification_and_reject_tls12_signatures() { #[test] fn client_can_override_certificate_verification_and_reject_tls13_signatures() { + let provider = provider::default_provider(); for kt in ALL_KEY_TYPES.iter() { - let mut client_config = make_client_config_with_versions(*kt, &[&rustls::version::TLS13]); + let mut client_config = make_client_config_with_versions( + *kt, + &[&rustls::version::TLS13], + &provider::default_provider(), + ); let verifier = Arc::new(MockServerVerifier::rejects_tls13_signatures( Error::InvalidMessage(InvalidMessage::HandshakePayloadTooLarge), )); @@ -123,7 +132,7 @@ fn client_can_override_certificate_verification_and_reject_tls13_signatures() { .dangerous() .set_certificate_verifier(verifier); - let server_config = Arc::new(make_server_config(*kt)); + let server_config = Arc::new(make_server_config(*kt, &provider)); let (mut client, mut server) = make_pair_for_arc_configs(&Arc::new(client_config), &server_config); @@ -142,13 +151,14 @@ fn client_can_override_certificate_verification_and_reject_tls13_signatures() { #[test] fn client_can_override_certificate_verification_and_offer_no_signature_schemes() { + let provider = provider::default_provider(); for kt in ALL_KEY_TYPES.iter() { let verifier = Arc::new(MockServerVerifier::offers_no_signature_schemes()); - let server_config = Arc::new(make_server_config(*kt)); + let server_config = Arc::new(make_server_config(*kt, &provider)); for version in rustls::ALL_VERSIONS { - let mut client_config = make_client_config_with_versions(*kt, &[version]); + let mut client_config = make_client_config_with_versions(*kt, &[version], &provider); client_config .dangerous() .set_certificate_verifier(verifier.clone()); @@ -171,7 +181,8 @@ fn client_can_override_certificate_verification_and_offer_no_signature_schemes() #[test] fn cas_extension_in_client_hello_if_server_verifier_requests_it() { - let server_config = Arc::new(make_server_config(KeyType::Rsa2048)); + let provider = provider::default_provider(); + let server_config = Arc::new(make_server_config(KeyType::Rsa2048, &provider)); let mut root_cert_store = RootCertStore::empty(); root_cert_store @@ -180,7 +191,7 @@ fn cas_extension_in_client_hello_if_server_verifier_requests_it() { let server_verifier = WebPkiServerVerifier::builder_with_provider( Arc::new(root_cert_store), - Arc::new(provider::default_provider()), + Arc::new(provider.clone()), ) .build() .unwrap(); @@ -196,7 +207,7 @@ fn cas_extension_in_client_hello_if_server_verifier_requests_it() { for (protocol_version, cas_extension_expected) in [(&TLS12, false), (&TLS13, true)] { let client_config = Arc::new( - client_config_builder_with_versions(&[protocol_version]) + client_config_builder_with_versions(&[protocol_version], &provider) .dangerous() .with_custom_certificate_verifier(cas_sending_server_verifier.clone()) .with_no_client_auth(), @@ -227,6 +238,7 @@ fn cas_extension_in_client_hello_if_server_verifier_requests_it() { #[test] fn client_can_request_certain_trusted_cas() { + let provider = provider::default_provider(); // These keys have CAs with different names, which our test needs. // They also share the same sigalgs, so the server won't pick one over the other based on sigalgs. let key_types = [KeyType::Rsa2048, KeyType::Rsa3072, KeyType::Rsa4096]; @@ -238,7 +250,7 @@ fn client_can_request_certain_trusted_cas() { kt.ca_distinguished_name() .to_vec() .into(), - kt.certified_key_with_cert_chain(&provider::default_provider()) + kt.certified_key_with_cert_chain(&provider) .unwrap(), ) }) @@ -246,7 +258,7 @@ fn client_can_request_certain_trusted_cas() { ); let server_config = Arc::new( - server_config_builder() + server_config_builder(&provider) .with_no_client_auth() .with_cert_resolver(Arc::new(cert_resolver.clone())), ); @@ -260,7 +272,7 @@ fn client_can_request_certain_trusted_cas() { .unwrap(); let server_verifier = WebPkiServerVerifier::builder_with_provider( Arc::new(root_store), - Arc::new(provider::default_provider()), + Arc::new(provider.clone()), ) .build() .unwrap(); @@ -274,7 +286,7 @@ fn client_can_request_certain_trusted_cas() { )], }); - let cas_sending_client_config = client_config_builder() + let cas_sending_client_config = client_config_builder(&provider) .dangerous() .with_custom_certificate_verifier(cas_sending_server_verifier) .with_no_client_auth(); @@ -283,7 +295,7 @@ fn client_can_request_certain_trusted_cas() { make_pair_for_arc_configs(&Arc::new(cas_sending_client_config), &server_config); do_handshake(&mut client, &mut server); - let cas_unaware_client_config = client_config_builder() + let cas_unaware_client_config = client_config_builder(&provider) .dangerous() .with_custom_certificate_verifier(server_verifier) .with_no_client_auth(); diff --git a/rustls/tests/unbuffered.rs b/rustls/tests/unbuffered.rs index 706f7e16e9..f38d16a284 100644 --- a/rustls/tests/unbuffered.rs +++ b/rustls/tests/unbuffered.rs @@ -115,8 +115,10 @@ fn handshake_config( version: &'static rustls::SupportedProtocolVersion, editor: impl Fn(&mut ClientConfig, &mut ServerConfig), ) -> Outcome { - let mut server_config = make_server_config_with_versions(KeyType::Rsa2048, &[version]); - let mut client_config = make_client_config(KeyType::Rsa2048); + let provider = provider::default_provider(); + let mut server_config = + make_server_config_with_versions(KeyType::Rsa2048, &[version], &provider); + let mut client_config = make_client_config(KeyType::Rsa2048, &provider); editor(&mut client_config, &mut server_config); run( @@ -129,11 +131,13 @@ fn handshake_config( #[test] fn app_data_client_to_server() { + let provider = provider::default_provider(); let expected: &[_] = b"hello"; for version in rustls::ALL_VERSIONS { eprintln!("{version:?}"); - let server_config = make_server_config_with_versions(KeyType::Rsa2048, &[version]); - let client_config = make_client_config(KeyType::Rsa2048); + let server_config = + make_server_config_with_versions(KeyType::Rsa2048, &[version], &provider); + let client_config = make_client_config(KeyType::Rsa2048, &provider); let mut client_actions = Actions { app_data_to_send: Some(expected), @@ -163,11 +167,13 @@ fn app_data_client_to_server() { #[test] fn app_data_server_to_client() { + let provider = provider::default_provider(); let expected: &[_] = b"hello"; for version in rustls::ALL_VERSIONS { eprintln!("{version:?}"); - let server_config = make_server_config_with_versions(KeyType::Rsa2048, &[version]); - let client_config = make_client_config(KeyType::Rsa2048); + let server_config = + make_server_config_with_versions(KeyType::Rsa2048, &[version], &provider); + let client_config = make_client_config(KeyType::Rsa2048, &provider); let mut server_actions = Actions { app_data_to_send: Some(expected), @@ -197,13 +203,15 @@ fn app_data_server_to_client() { #[test] fn early_data() { + let provider = provider::default_provider(); let expected: &[_] = b"hello"; - let mut server_config = make_server_config(KeyType::Rsa2048); + let mut server_config = make_server_config(KeyType::Rsa2048, &provider); server_config.max_early_data_size = 128; let server_config = Arc::new(server_config); - let mut client_config = make_client_config_with_versions(KeyType::Rsa2048, &[&TLS13]); + let mut client_config = + make_client_config_with_versions(KeyType::Rsa2048, &[&TLS13], &provider); client_config.enable_early_data = true; let client_config = Arc::new(client_config); @@ -426,10 +434,12 @@ fn run( #[test] fn close_notify_client_to_server() { + let provider = provider::default_provider(); for version in rustls::ALL_VERSIONS { eprintln!("{version:?}"); - let server_config = make_server_config_with_versions(KeyType::Rsa2048, &[version]); - let client_config = make_client_config(KeyType::Rsa2048); + let server_config = + make_server_config_with_versions(KeyType::Rsa2048, &[version], &provider); + let client_config = make_client_config(KeyType::Rsa2048, &provider); let mut client_actions = Actions { send_close_notify: true, @@ -450,10 +460,12 @@ fn close_notify_client_to_server() { #[test] fn close_notify_server_to_client() { + let provider = provider::default_provider(); for version in rustls::ALL_VERSIONS { eprintln!("{version:?}"); - let server_config = make_server_config_with_versions(KeyType::Rsa2048, &[version]); - let client_config = make_client_config(KeyType::Rsa2048); + let server_config = + make_server_config_with_versions(KeyType::Rsa2048, &[version], &provider); + let client_config = make_client_config(KeyType::Rsa2048, &provider); let mut server_actions = Actions { send_close_notify: true, @@ -725,7 +737,7 @@ fn refresh_traffic_keys_automatically() { .unwrap(), ); - let server_config = make_server_config(KeyType::Rsa2048); + let server_config = make_server_config(KeyType::Rsa2048, &provider::default_provider()); let mut outcome = run( Arc::new(client_config), &mut NO_ACTIONS.clone(), @@ -794,7 +806,7 @@ fn tls12_connection_fails_after_key_reaches_confidentiality_limit() { .unwrap(), ); - let server_config = make_server_config(KeyType::Ed25519); + let server_config = make_server_config(KeyType::Ed25519, &provider::default_provider()); let mut outcome = run( Arc::new(client_config), &mut NO_ACTIONS.clone(), @@ -891,8 +903,11 @@ fn tls13_packed_handshake() { #[test] fn rejects_junk() { - let mut server = - UnbufferedServerConnection::new(Arc::new(make_server_config(KeyType::Rsa2048))).unwrap(); + let mut server = UnbufferedServerConnection::new(Arc::new(make_server_config( + KeyType::Rsa2048, + &provider::default_provider(), + ))) + .unwrap(); let mut buf = [0xff; 5]; let UnbufferedStatus { discard, state } = server.process_tls_records(&mut buf); @@ -1397,8 +1412,9 @@ impl Buffer { fn make_connection_pair( version: &'static rustls::SupportedProtocolVersion, ) -> (UnbufferedClientConnection, UnbufferedServerConnection) { - let server_config = make_server_config(KeyType::Rsa2048); - let client_config = make_client_config_with_versions(KeyType::Rsa2048, &[version]); + let provider = provider::default_provider(); + let server_config = make_server_config(KeyType::Rsa2048, &provider); + let client_config = make_client_config_with_versions(KeyType::Rsa2048, &[version], &provider); let client = UnbufferedClientConnection::new(Arc::new(client_config), server_name("localhost")).unwrap(); @@ -1481,6 +1497,7 @@ fn test_secret_extraction_enabled() { // We support 3 different AEAD algorithms (AES-128-GCM mode, AES-256-GCM, and // Chacha20Poly1305), so that's 2*3 = 6 combinations to test. let kt = KeyType::Rsa2048; + let provider = provider::default_provider(); for suite in [ cipher_suite::TLS13_AES_128_GCM_SHA256, cipher_suite::TLS13_AES_256_GCM_SHA384, @@ -1498,7 +1515,7 @@ fn test_secret_extraction_enabled() { let mut server_config = ServerConfig::builder_with_provider( CryptoProvider { cipher_suites: vec![suite], - ..provider::default_provider() + ..provider.clone() } .into(), ) @@ -1511,7 +1528,7 @@ fn test_secret_extraction_enabled() { server_config.enable_secret_extraction = true; let server_config = Arc::new(server_config); - let mut client_config = make_client_config(kt); + let mut client_config = make_client_config(kt, &provider); client_config.enable_secret_extraction = true; let mut outcome = run( @@ -1570,10 +1587,11 @@ fn test_secret_extraction_enabled() { #[test] fn kernel_err_on_secret_extraction_not_enabled() { - let server_config = make_server_config(KeyType::Rsa2048); + let provider = provider::default_provider(); + let server_config = make_server_config(KeyType::Rsa2048, &provider); let server_config = Arc::new(server_config); - let client_config = make_client_config(KeyType::Rsa2048); + let client_config = make_client_config(KeyType::Rsa2048, &provider); let client_config = Arc::new(client_config); let mut server = UnbufferedServerConnection::new(server_config).unwrap(); @@ -1596,11 +1614,12 @@ fn kernel_err_on_secret_extraction_not_enabled() { #[test] fn kernel_err_on_handshake_not_complete() { - let mut server_config = make_server_config(KeyType::Rsa2048); + let provider = provider::default_provider(); + let mut server_config = make_server_config(KeyType::Rsa2048, &provider); server_config.enable_secret_extraction = true; let server_config = Arc::new(server_config); - let mut client_config = make_client_config(KeyType::Rsa2048); + let mut client_config = make_client_config(KeyType::Rsa2048, &provider); client_config.enable_secret_extraction = true; let client_config = Arc::new(client_config); @@ -1620,11 +1639,12 @@ fn kernel_err_on_handshake_not_complete() { #[test] fn kernel_initial_traffic_secrets_match() { - let mut server_config = make_server_config(KeyType::Rsa2048); + let provider = provider::default_provider(); + let mut server_config = make_server_config(KeyType::Rsa2048, &provider); server_config.enable_secret_extraction = true; let server_config = Arc::new(server_config); - let mut client_config = make_client_config(KeyType::Rsa2048); + let mut client_config = make_client_config(KeyType::Rsa2048, &provider); client_config.enable_secret_extraction = true; let client_config = Arc::new(client_config); @@ -1647,11 +1667,14 @@ fn kernel_initial_traffic_secrets_match() { #[test] fn kernel_key_updates_tls13() { - let mut server_config = make_server_config_with_versions(KeyType::Rsa2048, &[&TLS13]); + let provider = provider::default_provider(); + let mut server_config = + make_server_config_with_versions(KeyType::Rsa2048, &[&TLS13], &provider); server_config.enable_secret_extraction = true; let server_config = Arc::new(server_config); - let mut client_config = make_client_config_with_versions(KeyType::Rsa2048, &[&TLS13]); + let mut client_config = + make_client_config_with_versions(KeyType::Rsa2048, &[&TLS13], &provider); client_config.enable_secret_extraction = true; let client_config = Arc::new(client_config); @@ -1685,11 +1708,14 @@ fn kernel_key_updates_tls12() { let _ = env_logger::try_init(); - let mut server_config = make_server_config_with_versions(KeyType::Rsa2048, &[&TLS12]); + let provider = provider::default_provider(); + let mut server_config = + make_server_config_with_versions(KeyType::Rsa2048, &[&TLS12], &provider); server_config.enable_secret_extraction = true; let server_config = Arc::new(server_config); - let mut client_config = make_client_config_with_versions(KeyType::Rsa2048, &[&TLS12]); + let mut client_config = + make_client_config_with_versions(KeyType::Rsa2048, &[&TLS12], &provider); client_config.enable_secret_extraction = true; let client_config = Arc::new(client_config); From 69df2a29bbde724e5e0aaaef53c8a05726fe90bd Mon Sep 17 00:00:00 2001 From: Joe Birr-Pixton Date: Mon, 19 May 2025 13:50:45 +0100 Subject: [PATCH 09/14] Make provider explicit in AEAD limit tests --- rustls/tests/api.rs | 4 ++-- rustls/tests/common/mod.rs | 21 +++++++++++++++------ rustls/tests/unbuffered.rs | 16 ++++++++++------ 3 files changed, 27 insertions(+), 14 deletions(-) diff --git a/rustls/tests/api.rs b/rustls/tests/api.rs index 75a4aac4c8..e2da5ec03d 100644 --- a/rustls/tests/api.rs +++ b/rustls/tests/api.rs @@ -8278,7 +8278,7 @@ fn test_automatic_refresh_traffic_keys() { } const KEY_UPDATE_SIZE: usize = encrypted_size(5); - let provider = aes_128_gcm_with_1024_confidentiality_limit(); + let provider = aes_128_gcm_with_1024_confidentiality_limit(provider::default_provider()); let client_config = finish_client_config( KeyType::Ed25519, @@ -8343,7 +8343,7 @@ fn test_automatic_refresh_traffic_keys() { #[cfg(feature = "tls12")] #[test] fn tls12_connection_fails_after_key_reaches_confidentiality_limit() { - let provider = aes_128_gcm_with_1024_confidentiality_limit(); + let provider = aes_128_gcm_with_1024_confidentiality_limit(provider::default_provider()); let client_config = finish_client_config( KeyType::Ed25519, diff --git a/rustls/tests/common/mod.rs b/rustls/tests/common/mod.rs index f377605ee6..797908be7a 100644 --- a/rustls/tests/common/mod.rs +++ b/rustls/tests/common/mod.rs @@ -32,7 +32,7 @@ use rustls::unbuffered::{ ConnectionState, EncodeError, UnbufferedConnectionCommon, UnbufferedStatus, }; use rustls::{ - ClientConfig, ClientConnection, Connection, ConnectionCommon, ContentType, + CipherSuite, ClientConfig, ClientConnection, Connection, ConnectionCommon, ContentType, DigitallySignedStruct, DistinguishedName, Error, InconsistentKeys, NamedGroup, ProtocolVersion, RootCertStore, ServerConfig, ServerConnection, SideData, SignatureScheme, SupportedCipherSuite, }; @@ -1483,7 +1483,9 @@ impl RawTls { } } -pub fn aes_128_gcm_with_1024_confidentiality_limit() -> Arc { +pub fn aes_128_gcm_with_1024_confidentiality_limit( + provider: CryptoProvider, +) -> Arc { const CONFIDENTIALITY_LIMIT: u64 = 1024; // needed to extend lifetime of Tls13CipherSuite to 'static @@ -1491,7 +1493,11 @@ pub fn aes_128_gcm_with_1024_confidentiality_limit() -> Arc { static TLS12_LIMITED_SUITE: OnceLock = OnceLock::new(); let tls13_limited = TLS13_LIMITED_SUITE.get_or_init(|| { - let tls13 = provider::cipher_suite::TLS13_AES_128_GCM_SHA256 + let tls13 = provider + .cipher_suites + .iter() + .find(|cs| cs.suite() == CipherSuite::TLS13_AES_128_GCM_SHA256) + .unwrap() .tls13() .unwrap(); @@ -1505,8 +1511,11 @@ pub fn aes_128_gcm_with_1024_confidentiality_limit() -> Arc { }); let tls12_limited = TLS12_LIMITED_SUITE.get_or_init(|| { - let SupportedCipherSuite::Tls12(tls12) = - provider::cipher_suite::TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 + let SupportedCipherSuite::Tls12(tls12) = *provider + .cipher_suites + .iter() + .find(|cs| cs.suite() == CipherSuite::TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256) + .unwrap() else { unreachable!(); }; @@ -1525,7 +1534,7 @@ pub fn aes_128_gcm_with_1024_confidentiality_limit() -> Arc { SupportedCipherSuite::Tls13(tls13_limited), SupportedCipherSuite::Tls12(tls12_limited), ], - ..provider::default_provider() + ..provider } .into() } diff --git a/rustls/tests/unbuffered.rs b/rustls/tests/unbuffered.rs index f38d16a284..53c280e6b1 100644 --- a/rustls/tests/unbuffered.rs +++ b/rustls/tests/unbuffered.rs @@ -732,9 +732,11 @@ fn refresh_traffic_keys_automatically() { let client_config = finish_client_config( KeyType::Rsa2048, - ClientConfig::builder_with_provider(aes_128_gcm_with_1024_confidentiality_limit()) - .with_safe_default_protocol_versions() - .unwrap(), + ClientConfig::builder_with_provider(aes_128_gcm_with_1024_confidentiality_limit( + provider::default_provider(), + )) + .with_safe_default_protocol_versions() + .unwrap(), ); let server_config = make_server_config(KeyType::Rsa2048, &provider::default_provider()); @@ -801,9 +803,11 @@ fn tls12_connection_fails_after_key_reaches_confidentiality_limit() { let client_config = finish_client_config( KeyType::Ed25519, - ClientConfig::builder_with_provider(aes_128_gcm_with_1024_confidentiality_limit()) - .with_protocol_versions(&[&rustls::version::TLS12]) - .unwrap(), + ClientConfig::builder_with_provider(aes_128_gcm_with_1024_confidentiality_limit( + provider::default_provider(), + )) + .with_protocol_versions(&[&rustls::version::TLS12]) + .unwrap(), ); let server_config = make_server_config(KeyType::Ed25519, &provider::default_provider()); From 1085522a76e48bd1f9a512ee71653a0c513bbbf6 Mon Sep 17 00:00:00 2001 From: Joe Birr-Pixton Date: Mon, 19 May 2025 13:53:43 +0100 Subject: [PATCH 10/14] Make provider explicit in tests using plaintext suite --- rustls/tests/api.rs | 18 ++++++++++-------- rustls/tests/common/mod.rs | 12 +++++++----- rustls/tests/unbuffered.rs | 18 ++++++++++-------- 3 files changed, 27 insertions(+), 21 deletions(-) diff --git a/rustls/tests/api.rs b/rustls/tests/api.rs index e2da5ec03d..a689bf1a9f 100644 --- a/rustls/tests/api.rs +++ b/rustls/tests/api.rs @@ -8407,14 +8407,16 @@ fn tls13_packed_handshake() { // regression test for https://github.com/rustls/rustls/issues/2040 // (did not affect the buffered api) - let client_config = ClientConfig::builder_with_provider(unsafe_plaintext_crypto_provider()) - .with_safe_default_protocol_versions() - .unwrap() - .dangerous() - .with_custom_certificate_verifier(Arc::new(MockServerVerifier::rejects_certificate( - CertificateError::UnknownIssuer.into(), - ))) - .with_no_client_auth(); + let client_config = ClientConfig::builder_with_provider(unsafe_plaintext_crypto_provider( + provider::default_provider(), + )) + .with_safe_default_protocol_versions() + .unwrap() + .dangerous() + .with_custom_certificate_verifier(Arc::new(MockServerVerifier::rejects_certificate( + CertificateError::UnknownIssuer.into(), + ))) + .with_no_client_auth(); let mut client = ClientConnection::new(Arc::new(client_config), server_name("localhost")).unwrap(); diff --git a/rustls/tests/common/mod.rs b/rustls/tests/common/mod.rs index 797908be7a..26f1c7f829 100644 --- a/rustls/tests/common/mod.rs +++ b/rustls/tests/common/mod.rs @@ -39,8 +39,6 @@ use rustls::{ use webpki::anchor_from_trusted_cert; -use super::provider; - // Import `Arc` here for tests - can be overwritten to test with another `Arc` such as `portable_atomic_util::Arc` pub use std::sync::Arc; @@ -1539,11 +1537,15 @@ pub fn aes_128_gcm_with_1024_confidentiality_limit( .into() } -pub fn unsafe_plaintext_crypto_provider() -> Arc { +pub fn unsafe_plaintext_crypto_provider(provider: CryptoProvider) -> Arc { static TLS13_PLAIN_SUITE: OnceLock = OnceLock::new(); let tls13 = TLS13_PLAIN_SUITE.get_or_init(|| { - let tls13 = provider::cipher_suite::TLS13_AES_256_GCM_SHA384 + let tls13 = provider + .cipher_suites + .iter() + .find(|cs| cs.suite() == CipherSuite::TLS13_AES_256_GCM_SHA384) + .unwrap() .tls13() .unwrap(); @@ -1556,7 +1558,7 @@ pub fn unsafe_plaintext_crypto_provider() -> Arc { CryptoProvider { cipher_suites: vec![SupportedCipherSuite::Tls13(tls13)], - ..provider::default_provider() + ..provider } .into() } diff --git a/rustls/tests/unbuffered.rs b/rustls/tests/unbuffered.rs index 53c280e6b1..e32b051371 100644 --- a/rustls/tests/unbuffered.rs +++ b/rustls/tests/unbuffered.rs @@ -878,14 +878,16 @@ fn tls13_packed_handshake() { } // regression test for https://github.com/rustls/rustls/issues/2040 - let client_config = ClientConfig::builder_with_provider(unsafe_plaintext_crypto_provider()) - .with_safe_default_protocol_versions() - .unwrap() - .dangerous() - .with_custom_certificate_verifier(Arc::new(MockServerVerifier::rejects_certificate( - CertificateError::UnknownIssuer.into(), - ))) - .with_no_client_auth(); + let client_config = ClientConfig::builder_with_provider(unsafe_plaintext_crypto_provider( + provider::default_provider(), + )) + .with_safe_default_protocol_versions() + .unwrap() + .dangerous() + .with_custom_certificate_verifier(Arc::new(MockServerVerifier::rejects_certificate( + CertificateError::UnknownIssuer.into(), + ))) + .with_no_client_auth(); let mut client = UnbufferedClientConnection::new(Arc::new(client_config), server_name("localhost")).unwrap(); From a2e76d154b0847972008192a7e6e93235b20654b Mon Sep 17 00:00:00 2001 From: Joe Birr-Pixton Date: Mon, 19 May 2025 14:32:25 +0100 Subject: [PATCH 11/14] Eliminate final use of webpki crate in rustls tests --- rustls/tests/common/mod.rs | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/rustls/tests/common/mod.rs b/rustls/tests/common/mod.rs index 26f1c7f829..9186a451e2 100644 --- a/rustls/tests/common/mod.rs +++ b/rustls/tests/common/mod.rs @@ -37,8 +37,6 @@ use rustls::{ RootCertStore, ServerConfig, ServerConnection, SideData, SignatureScheme, SupportedCipherSuite, }; -use webpki::anchor_from_trusted_cert; - // Import `Arc` here for tests - can be overwritten to test with another `Arc` such as `portable_atomic_util::Arc` pub use std::sync::Arc; @@ -550,15 +548,11 @@ pub fn get_client_root_store(kt: KeyType) -> Arc { // The key type's chain file contains the DER encoding of the EE cert, the intermediate cert, // and the root trust anchor. We want only the trust anchor to build the root cert store. let chain = kt.get_chain(); - let trust_anchor = chain.last().unwrap(); - RootCertStore { - roots: vec![ - anchor_from_trusted_cert(trust_anchor) - .unwrap() - .to_owned(), - ], - } - .into() + let mut roots = RootCertStore::empty(); + roots + .add(chain.last().unwrap().clone()) + .unwrap(); + roots.into() } pub fn make_server_config_with_mandatory_client_auth_crls( From c84ca91873c3f7ad0833661a8e1afcf517201b31 Mon Sep 17 00:00:00 2001 From: Joe Birr-Pixton Date: Mon, 19 May 2025 11:58:06 +0100 Subject: [PATCH 12/14] Move tests/common/mod.rs to rustls-test crate Everything moves, with the exception of parts that depend on the cargo features of the rustls crate. --- Cargo.lock | 9 + Cargo.toml | 3 + ci-bench/Cargo.toml | 1 + ci-bench/src/benchmark.rs | 2 +- ci-bench/src/main.rs | 12 +- ci-bench/src/util.rs | 31 - rustls-test/Cargo.toml | 8 + rustls-test/src/lib.rs | 1748 ++++++++++++++++++++++++++++++++++ rustls/Cargo.toml | 1 + rustls/benches/benchmarks.rs | 4 +- rustls/tests/common/mod.rs | 1706 +-------------------------------- 11 files changed, 1781 insertions(+), 1744 deletions(-) create mode 100644 rustls-test/Cargo.toml create mode 100644 rustls-test/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index 0b6edc4141..4b545ee4bc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2589,6 +2589,7 @@ dependencies = [ "rcgen", "ring", "rustls-pki-types", + "rustls-test", "rustls-webpki", "rustversion", "serde", @@ -2639,6 +2640,7 @@ dependencies = [ "itertools 0.14.0", "rayon", "rustls 0.23.27", + "rustls-test", "tikv-jemallocator", ] @@ -2741,6 +2743,13 @@ dependencies = [ "serde_json", ] +[[package]] +name = "rustls-test" +version = "0.1.0" +dependencies = [ + "rustls 0.23.27", +] + [[package]] name = "rustls-webpki" version = "0.103.2" diff --git a/Cargo.toml b/Cargo.toml index d560df2d56..0e3fb6a9a6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,6 +14,8 @@ members = [ "provider-example", # the main library and tests "rustls", + # common code for testing the core crate + "rustls-test", # benchmarking tool "rustls-bench", # experimental post-quantum algorithm support @@ -79,6 +81,7 @@ rcgen = { version = "0.13", features = ["pem", "aws_lc_rs"], default-features = regex = "1" ring = "0.17" rsa = { version = "0.9", features = ["sha2"], default-features = false } +rustls-test = { path = "rustls-test/" } serde = { version = "1", features = ["derive"] } serde_json = "1" sha2 = { version = "0.10", default-features = false } diff --git a/ci-bench/Cargo.toml b/ci-bench/Cargo.toml index 6fde65b814..ff92bfb507 100644 --- a/ci-bench/Cargo.toml +++ b/ci-bench/Cargo.toml @@ -15,6 +15,7 @@ fxhash = { workspace = true } itertools = { workspace = true } rayon = { workspace = true } rustls = { path = "../rustls", features = ["ring", "aws_lc_rs"] } +rustls-test = { workspace = true } [target.'cfg(not(target_env = "msvc"))'.dependencies] tikv-jemallocator = { workspace = true } diff --git a/ci-bench/src/benchmark.rs b/ci-bench/src/benchmark.rs index 1e38cc5401..325de04b79 100644 --- a/ci-bench/src/benchmark.rs +++ b/ci-bench/src/benchmark.rs @@ -5,7 +5,7 @@ use itertools::Itertools; use crate::Side; use crate::callgrind::InstructionCounts; -use crate::util::KeyType; +use rustls_test::KeyType; /// Validates a benchmark collection, returning an error if the provided benchmarks are invalid /// diff --git a/ci-bench/src/main.rs b/ci-bench/src/main.rs index 2b768bab87..3d2d359596 100644 --- a/ci-bench/src/main.rs +++ b/ci-bench/src/main.rs @@ -17,20 +17,18 @@ use rayon::iter::Either; use rayon::prelude::*; use rustls::client::Resumption; use rustls::crypto::{CryptoProvider, GetRandomFailed, SecureRandom, aws_lc_rs, ring}; -use rustls::pki_types::CertificateDer; -use rustls::pki_types::pem::PemObject; use rustls::server::{NoServerSessionStorage, ServerSessionMemoryCache, WebPkiClientVerifier}; use rustls::{ CipherSuite, ClientConfig, ClientConnection, HandshakeKind, ProtocolVersion, RootCertStore, ServerConfig, ServerConnection, }; +use rustls_test::KeyType; use crate::benchmark::{ Benchmark, BenchmarkKind, BenchmarkParams, ResumptionKind, get_reported_instr_count, validate_benchmarks, }; use crate::callgrind::{CallgrindRunner, CountInstructions}; -use crate::util::KeyType; use crate::util::async_io::{self, AsyncRead, AsyncWrite}; use crate::util::transport::{ read_handshake_message, read_plaintext_to_end_bounded, send_handshake_message, @@ -504,11 +502,9 @@ impl ClientSideStepper<'_> { fn make_config(params: &BenchmarkParams, resume: ResumptionKind) -> Arc { assert_eq!(params.ciphersuite.version(), params.version); let mut root_store = RootCertStore::empty(); - root_store.add_parsable_certificates( - CertificateDer::pem_file_iter(params.key_type.path_for("ca.cert")) - .unwrap() - .map(|result| result.unwrap()), - ); + root_store + .add(params.key_type.ca_cert()) + .unwrap(); let mut cfg = ClientConfig::builder_with_provider( CryptoProvider { diff --git a/ci-bench/src/util.rs b/ci-bench/src/util.rs index f883ff294b..f924d0d641 100644 --- a/ci-bench/src/util.rs +++ b/ci-bench/src/util.rs @@ -1,34 +1,3 @@ -use rustls::pki_types::pem::PemObject; -use rustls::pki_types::{CertificateDer, PrivateKeyDer}; - -#[derive(PartialEq, Clone, Copy, Debug)] -pub enum KeyType { - Rsa2048, - EcdsaP256, - EcdsaP384, -} - -impl KeyType { - pub(crate) fn path_for(&self, part: &str) -> String { - match self { - Self::Rsa2048 => format!("../test-ca/rsa-2048/{part}"), - Self::EcdsaP256 => format!("../test-ca/ecdsa-p256/{part}"), - Self::EcdsaP384 => format!("../test-ca/ecdsa-p384/{part}"), - } - } - - pub(crate) fn get_chain(&self) -> Vec> { - CertificateDer::pem_file_iter(self.path_for("end.fullchain")) - .unwrap() - .map(|result| result.unwrap()) - .collect() - } - - pub(crate) fn get_key(&self) -> PrivateKeyDer<'static> { - PrivateKeyDer::from_pem_file(self.path_for("end.key")).unwrap() - } -} - pub mod async_io { //! Async IO building blocks required for sharing code between the instruction count and //! wall-time benchmarks diff --git a/rustls-test/Cargo.toml b/rustls-test/Cargo.toml new file mode 100644 index 0000000000..877ec34833 --- /dev/null +++ b/rustls-test/Cargo.toml @@ -0,0 +1,8 @@ +[package] +name = "rustls-test" +version = "0.1.0" +edition = "2021" +publish = false + +[dependencies] +rustls = { path = "../rustls", default-features = false, features = ["std", "tls12"] } diff --git a/rustls-test/src/lib.rs b/rustls-test/src/lib.rs new file mode 100644 index 0000000000..64e922bad1 --- /dev/null +++ b/rustls-test/src/lib.rs @@ -0,0 +1,1748 @@ +use std::io; +use std::ops::DerefMut; +use std::sync::Arc; +use std::sync::OnceLock; + +use rustls::pki_types::pem::PemObject; +use rustls::pki_types::{ + CertificateDer, CertificateRevocationListDer, PrivateKeyDer, PrivatePkcs8KeyDer, ServerName, + SubjectPublicKeyInfoDer, UnixTime, +}; + +use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier}; +use rustls::client::{ + AlwaysResolvesClientRawPublicKeys, ServerCertVerifierBuilder, UnbufferedClientConnection, + WebPkiServerVerifier, +}; +use rustls::crypto::cipher::{InboundOpaqueMessage, MessageDecrypter, MessageEncrypter}; +use rustls::crypto::{ + CryptoProvider, WebPkiSupportedAlgorithms, verify_tls13_signature_with_raw_key, +}; +use rustls::internal::msgs::codec::{Codec, Reader}; +use rustls::internal::msgs::message::{Message, OutboundOpaqueMessage, PlainMessage}; +use rustls::server::danger::{ClientCertVerified, ClientCertVerifier}; +use rustls::server::{ + AlwaysResolvesServerRawPublicKeys, ClientCertVerifierBuilder, UnbufferedServerConnection, + WebPkiClientVerifier, +}; +use rustls::sign::CertifiedKey; +use rustls::unbuffered::{ + ConnectionState, EncodeError, UnbufferedConnectionCommon, UnbufferedStatus, +}; +use rustls::{ + CipherSuite, ClientConfig, ClientConnection, Connection, ConnectionCommon, ContentType, + DigitallySignedStruct, DistinguishedName, Error, InconsistentKeys, NamedGroup, ProtocolVersion, + RootCertStore, ServerConfig, ServerConnection, SideData, SignatureScheme, SupportedCipherSuite, +}; + +macro_rules! embed_files { + ( + $( + ($name:ident, $keytype:expr, $path:expr); + )+ + ) => { + $( + const $name: &'static [u8] = include_bytes!( + concat!("../../test-ca/", $keytype, "/", $path)); + )+ + + pub fn bytes_for(keytype: &str, path: &str) -> &'static [u8] { + match (keytype, path) { + $( + ($keytype, $path) => $name, + )+ + _ => panic!("unknown keytype {} with path {}", keytype, path), + } + } + } +} + +embed_files! { + (ECDSA_P256_END_PEM_SPKI, "ecdsa-p256", "end.spki.pem"); + (ECDSA_P256_CLIENT_PEM_SPKI, "ecdsa-p256", "client.spki.pem"); + (ECDSA_P256_CA_CERT, "ecdsa-p256", "ca.cert"); + (ECDSA_P256_CA_DER, "ecdsa-p256", "ca.der"); + (ECDSA_P256_CA_KEY, "ecdsa-p256", "ca.key"); + (ECDSA_P256_CLIENT_CERT, "ecdsa-p256", "client.cert"); + (ECDSA_P256_CLIENT_CHAIN, "ecdsa-p256", "client.chain"); + (ECDSA_P256_CLIENT_FULLCHAIN, "ecdsa-p256", "client.fullchain"); + (ECDSA_P256_CLIENT_KEY, "ecdsa-p256", "client.key"); + (ECDSA_P256_END_CRL_PEM, "ecdsa-p256", "end.revoked.crl.pem"); + (ECDSA_P256_CLIENT_CRL_PEM, "ecdsa-p256", "client.revoked.crl.pem"); + (ECDSA_P256_INTERMEDIATE_CRL_PEM, "ecdsa-p256", "inter.revoked.crl.pem"); + (ECDSA_P256_EXPIRED_CRL_PEM, "ecdsa-p256", "end.expired.crl.pem"); + (ECDSA_P256_END_CERT, "ecdsa-p256", "end.cert"); + (ECDSA_P256_END_CHAIN, "ecdsa-p256", "end.chain"); + (ECDSA_P256_END_FULLCHAIN, "ecdsa-p256", "end.fullchain"); + (ECDSA_P256_END_KEY, "ecdsa-p256", "end.key"); + (ECDSA_P256_INTER_CERT, "ecdsa-p256", "inter.cert"); + (ECDSA_P256_INTER_KEY, "ecdsa-p256", "inter.key"); + + (ECDSA_P384_END_PEM_SPKI, "ecdsa-p384", "end.spki.pem"); + (ECDSA_P384_CLIENT_PEM_SPKI, "ecdsa-p384", "client.spki.pem"); + (ECDSA_P384_CA_CERT, "ecdsa-p384", "ca.cert"); + (ECDSA_P384_CA_DER, "ecdsa-p384", "ca.der"); + (ECDSA_P384_CA_KEY, "ecdsa-p384", "ca.key"); + (ECDSA_P384_CLIENT_CERT, "ecdsa-p384", "client.cert"); + (ECDSA_P384_CLIENT_CHAIN, "ecdsa-p384", "client.chain"); + (ECDSA_P384_CLIENT_FULLCHAIN, "ecdsa-p384", "client.fullchain"); + (ECDSA_P384_CLIENT_KEY, "ecdsa-p384", "client.key"); + (ECDSA_P384_END_CRL_PEM, "ecdsa-p384", "end.revoked.crl.pem"); + (ECDSA_P384_CLIENT_CRL_PEM, "ecdsa-p384", "client.revoked.crl.pem"); + (ECDSA_P384_INTERMEDIATE_CRL_PEM, "ecdsa-p384", "inter.revoked.crl.pem"); + (ECDSA_P384_EXPIRED_CRL_PEM, "ecdsa-p384", "end.expired.crl.pem"); + (ECDSA_P384_END_CERT, "ecdsa-p384", "end.cert"); + (ECDSA_P384_END_CHAIN, "ecdsa-p384", "end.chain"); + (ECDSA_P384_END_FULLCHAIN, "ecdsa-p384", "end.fullchain"); + (ECDSA_P384_END_KEY, "ecdsa-p384", "end.key"); + (ECDSA_P384_INTER_CERT, "ecdsa-p384", "inter.cert"); + (ECDSA_P384_INTER_KEY, "ecdsa-p384", "inter.key"); + + (ECDSA_P521_END_PEM_SPKI, "ecdsa-p521", "end.spki.pem"); + (ECDSA_P521_CLIENT_PEM_SPKI, "ecdsa-p521", "client.spki.pem"); + (ECDSA_P521_CA_CERT, "ecdsa-p521", "ca.cert"); + (ECDSA_P521_CA_DER, "ecdsa-p521", "ca.der"); + (ECDSA_P521_CA_KEY, "ecdsa-p521", "ca.key"); + (ECDSA_P521_CLIENT_CERT, "ecdsa-p521", "client.cert"); + (ECDSA_P521_CLIENT_CHAIN, "ecdsa-p521", "client.chain"); + (ECDSA_P521_CLIENT_FULLCHAIN, "ecdsa-p521", "client.fullchain"); + (ECDSA_P521_CLIENT_KEY, "ecdsa-p521", "client.key"); + (ECDSA_P521_END_CRL_PEM, "ecdsa-p521", "end.revoked.crl.pem"); + (ECDSA_P521_CLIENT_CRL_PEM, "ecdsa-p521", "client.revoked.crl.pem"); + (ECDSA_P521_INTERMEDIATE_CRL_PEM, "ecdsa-p521", "inter.revoked.crl.pem"); + (ECDSA_P521_EXPIRED_CRL_PEM, "ecdsa-p521", "end.expired.crl.pem"); + (ECDSA_P521_END_CERT, "ecdsa-p521", "end.cert"); + (ECDSA_P521_END_CHAIN, "ecdsa-p521", "end.chain"); + (ECDSA_P521_END_FULLCHAIN, "ecdsa-p521", "end.fullchain"); + (ECDSA_P521_END_KEY, "ecdsa-p521", "end.key"); + (ECDSA_P521_INTER_CERT, "ecdsa-p521", "inter.cert"); + (ECDSA_P521_INTER_KEY, "ecdsa-p521", "inter.key"); + + (EDDSA_END_PEM_SPKI, "eddsa", "end.spki.pem"); + (EDDSA_CLIENT_PEM_SPKI, "eddsa", "client.spki.pem"); + (EDDSA_CA_CERT, "eddsa", "ca.cert"); + (EDDSA_CA_DER, "eddsa", "ca.der"); + (EDDSA_CA_KEY, "eddsa", "ca.key"); + (EDDSA_CLIENT_CERT, "eddsa", "client.cert"); + (EDDSA_CLIENT_CHAIN, "eddsa", "client.chain"); + (EDDSA_CLIENT_FULLCHAIN, "eddsa", "client.fullchain"); + (EDDSA_CLIENT_KEY, "eddsa", "client.key"); + (EDDSA_END_CRL_PEM, "eddsa", "end.revoked.crl.pem"); + (EDDSA_CLIENT_CRL_PEM, "eddsa", "client.revoked.crl.pem"); + (EDDSA_INTERMEDIATE_CRL_PEM, "eddsa", "inter.revoked.crl.pem"); + (EDDSA_EXPIRED_CRL_PEM, "eddsa", "end.expired.crl.pem"); + (EDDSA_END_CERT, "eddsa", "end.cert"); + (EDDSA_END_CHAIN, "eddsa", "end.chain"); + (EDDSA_END_FULLCHAIN, "eddsa", "end.fullchain"); + (EDDSA_END_KEY, "eddsa", "end.key"); + (EDDSA_INTER_CERT, "eddsa", "inter.cert"); + (EDDSA_INTER_KEY, "eddsa", "inter.key"); + + (RSA_2048_END_PEM_SPKI, "rsa-2048", "end.spki.pem"); + (RSA_2048_CLIENT_PEM_SPKI, "rsa-2048", "client.spki.pem"); + (RSA_2048_CA_CERT, "rsa-2048", "ca.cert"); + (RSA_2048_CA_DER, "rsa-2048", "ca.der"); + (RSA_2048_CA_KEY, "rsa-2048", "ca.key"); + (RSA_2048_CLIENT_CERT, "rsa-2048", "client.cert"); + (RSA_2048_CLIENT_CHAIN, "rsa-2048", "client.chain"); + (RSA_2048_CLIENT_FULLCHAIN, "rsa-2048", "client.fullchain"); + (RSA_2048_CLIENT_KEY, "rsa-2048", "client.key"); + (RSA_2048_END_CRL_PEM, "rsa-2048", "end.revoked.crl.pem"); + (RSA_2048_CLIENT_CRL_PEM, "rsa-2048", "client.revoked.crl.pem"); + (RSA_2048_INTERMEDIATE_CRL_PEM, "rsa-2048", "inter.revoked.crl.pem"); + (RSA_2048_EXPIRED_CRL_PEM, "rsa-2048", "end.expired.crl.pem"); + (RSA_2048_END_CERT, "rsa-2048", "end.cert"); + (RSA_2048_END_CHAIN, "rsa-2048", "end.chain"); + (RSA_2048_END_FULLCHAIN, "rsa-2048", "end.fullchain"); + (RSA_2048_END_KEY, "rsa-2048", "end.key"); + (RSA_2048_INTER_CERT, "rsa-2048", "inter.cert"); + (RSA_2048_INTER_KEY, "rsa-2048", "inter.key"); + + (RSA_3072_END_PEM_SPKI, "rsa-3072", "end.spki.pem"); + (RSA_3072_CLIENT_PEM_SPKI, "rsa-3072", "client.spki.pem"); + (RSA_3072_CA_CERT, "rsa-3072", "ca.cert"); + (RSA_3072_CA_DER, "rsa-3072", "ca.der"); + (RSA_3072_CA_KEY, "rsa-3072", "ca.key"); + (RSA_3072_CLIENT_CERT, "rsa-3072", "client.cert"); + (RSA_3072_CLIENT_CHAIN, "rsa-3072", "client.chain"); + (RSA_3072_CLIENT_FULLCHAIN, "rsa-3072", "client.fullchain"); + (RSA_3072_CLIENT_KEY, "rsa-3072", "client.key"); + (RSA_3072_END_CRL_PEM, "rsa-3072", "end.revoked.crl.pem"); + (RSA_3072_CLIENT_CRL_PEM, "rsa-3072", "client.revoked.crl.pem"); + (RSA_3072_INTERMEDIATE_CRL_PEM, "rsa-3072", "inter.revoked.crl.pem"); + (RSA_3072_EXPIRED_CRL_PEM, "rsa-3072", "end.expired.crl.pem"); + (RSA_3072_END_CERT, "rsa-3072", "end.cert"); + (RSA_3072_END_CHAIN, "rsa-3072", "end.chain"); + (RSA_3072_END_FULLCHAIN, "rsa-3072", "end.fullchain"); + (RSA_3072_END_KEY, "rsa-3072", "end.key"); + (RSA_3072_INTER_CERT, "rsa-3072", "inter.cert"); + (RSA_3072_INTER_KEY, "rsa-3072", "inter.key"); + + (RSA_4096_END_PEM_SPKI, "rsa-4096", "end.spki.pem"); + (RSA_4096_CLIENT_PEM_SPKI, "rsa-4096", "client.spki.pem"); + (RSA_4096_CA_CERT, "rsa-4096", "ca.cert"); + (RSA_4096_CA_DER, "rsa-4096", "ca.der"); + (RSA_4096_CA_KEY, "rsa-4096", "ca.key"); + (RSA_4096_CLIENT_CERT, "rsa-4096", "client.cert"); + (RSA_4096_CLIENT_CHAIN, "rsa-4096", "client.chain"); + (RSA_4096_CLIENT_FULLCHAIN, "rsa-4096", "client.fullchain"); + (RSA_4096_CLIENT_KEY, "rsa-4096", "client.key"); + (RSA_4096_END_CRL_PEM, "rsa-4096", "end.revoked.crl.pem"); + (RSA_4096_CLIENT_CRL_PEM, "rsa-4096", "client.revoked.crl.pem"); + (RSA_4096_INTERMEDIATE_CRL_PEM, "rsa-4096", "inter.revoked.crl.pem"); + (RSA_4096_EXPIRED_CRL_PEM, "rsa-4096", "end.expired.crl.pem"); + (RSA_4096_END_CERT, "rsa-4096", "end.cert"); + (RSA_4096_END_CHAIN, "rsa-4096", "end.chain"); + (RSA_4096_END_FULLCHAIN, "rsa-4096", "end.fullchain"); + (RSA_4096_END_KEY, "rsa-4096", "end.key"); + (RSA_4096_INTER_CERT, "rsa-4096", "inter.cert"); + (RSA_4096_INTER_KEY, "rsa-4096", "inter.key"); +} + +pub fn transfer( + left: &mut impl DerefMut>, + right: &mut impl DerefMut>, +) -> usize { + let mut buf = [0u8; 262144]; + let mut total = 0; + + while left.wants_write() { + let sz = { + let into_buf: &mut dyn io::Write = &mut &mut buf[..]; + left.write_tls(into_buf).unwrap() + }; + total += sz; + if sz == 0 { + return total; + } + + let mut offs = 0; + loop { + let from_buf: &mut dyn io::Read = &mut &buf[offs..sz]; + offs += right.read_tls(from_buf).unwrap(); + if sz == offs { + break; + } + } + } + + total +} + +pub fn transfer_eof(conn: &mut impl DerefMut>) { + let empty_buf = [0u8; 0]; + let empty_cursor: &mut dyn io::Read = &mut &empty_buf[..]; + let sz = conn.read_tls(empty_cursor).unwrap(); + assert_eq!(sz, 0); +} + +pub enum Altered { + /// message has been edited in-place (or is unchanged) + InPlace, + /// send these raw bytes instead of the message. + Raw(Vec), +} + +pub fn transfer_altered(left: &mut Connection, filter: F, right: &mut Connection) -> usize +where + F: Fn(&mut Message) -> Altered, +{ + let mut buf = [0u8; 262144]; + let mut total = 0; + + while left.wants_write() { + let sz = { + let into_buf: &mut dyn io::Write = &mut &mut buf[..]; + left.write_tls(into_buf).unwrap() + }; + total += sz; + if sz == 0 { + return total; + } + + let mut reader = Reader::init(&buf[..sz]); + while reader.any_left() { + let message = OutboundOpaqueMessage::read(&mut reader).unwrap(); + + // this is a bit of a falsehood: we don't know whether message + // is encrypted. it is quite unlikely that a genuine encrypted + // message can be decoded by `Message::try_from`. + let plain = message.into_plain_message(); + + let message_enc = match Message::try_from(plain.clone()) { + Ok(mut message) => match filter(&mut message) { + Altered::InPlace => PlainMessage::from(message) + .into_unencrypted_opaque() + .encode(), + Altered::Raw(data) => data, + }, + // pass through encrypted/undecodable messages + Err(_) => plain.into_unencrypted_opaque().encode(), + }; + + let message_enc_reader: &mut dyn io::Read = &mut &message_enc[..]; + let len = right + .read_tls(message_enc_reader) + .unwrap(); + assert_eq!(len, message_enc.len()); + } + } + + total +} + +#[derive(Clone, Copy, Debug, PartialEq)] +pub enum KeyType { + Rsa2048, + Rsa3072, + Rsa4096, + EcdsaP256, + EcdsaP384, + EcdsaP521, + Ed25519, +} + +pub static ALL_KEY_TYPES: &[KeyType] = &[ + KeyType::Rsa2048, + KeyType::Rsa3072, + KeyType::Rsa4096, + KeyType::EcdsaP256, + KeyType::EcdsaP384, + KeyType::EcdsaP521, + KeyType::Ed25519, +]; + +impl KeyType { + fn bytes_for(&self, part: &str) -> &'static [u8] { + match self { + Self::Rsa2048 => bytes_for("rsa-2048", part), + Self::Rsa3072 => bytes_for("rsa-3072", part), + Self::Rsa4096 => bytes_for("rsa-4096", part), + Self::EcdsaP256 => bytes_for("ecdsa-p256", part), + Self::EcdsaP384 => bytes_for("ecdsa-p384", part), + Self::EcdsaP521 => bytes_for("ecdsa-p521", part), + Self::Ed25519 => bytes_for("eddsa", part), + } + } + + pub fn ca_cert(&self) -> CertificateDer<'_> { + self.get_chain() + .into_iter() + .next_back() + .expect("cert chain cannot be empty") + } + + pub fn get_chain(&self) -> Vec> { + CertificateDer::pem_slice_iter(self.bytes_for("end.fullchain")) + .map(|result| result.unwrap()) + .collect() + } + + pub fn get_spki(&self) -> SubjectPublicKeyInfoDer<'static> { + SubjectPublicKeyInfoDer::from_pem_slice(self.bytes_for("end.spki.pem")).unwrap() + } + + pub fn get_key(&self) -> PrivateKeyDer<'static> { + PrivatePkcs8KeyDer::from_pem_slice(self.bytes_for("end.key")) + .unwrap() + .into() + } + + pub fn get_client_chain(&self) -> Vec> { + CertificateDer::pem_slice_iter(self.bytes_for("client.fullchain")) + .map(|result| result.unwrap()) + .collect() + } + + pub fn end_entity_crl(&self) -> CertificateRevocationListDer<'static> { + self.get_crl("end", "revoked") + } + + pub fn client_crl(&self) -> CertificateRevocationListDer<'static> { + self.get_crl("client", "revoked") + } + + pub fn intermediate_crl(&self) -> CertificateRevocationListDer<'static> { + self.get_crl("inter", "revoked") + } + + pub fn end_entity_crl_expired(&self) -> CertificateRevocationListDer<'static> { + self.get_crl("end", "expired") + } + + pub fn get_client_key(&self) -> PrivateKeyDer<'static> { + PrivatePkcs8KeyDer::from_pem_slice(self.bytes_for("client.key")) + .unwrap() + .into() + } + + pub fn get_client_spki(&self) -> SubjectPublicKeyInfoDer<'static> { + SubjectPublicKeyInfoDer::from_pem_slice(self.bytes_for("client.spki.pem")).unwrap() + } + + pub fn get_certified_client_key( + &self, + provider: &CryptoProvider, + ) -> Result, Error> { + let private_key = provider + .key_provider + .load_private_key(self.get_client_key())?; + let public_key = private_key + .public_key() + .ok_or(Error::InconsistentKeys(InconsistentKeys::Unknown))?; + let public_key_as_cert = CertificateDer::from(public_key.to_vec()); + Ok(Arc::new(CertifiedKey::new( + vec![public_key_as_cert], + private_key, + ))) + } + + pub fn certified_key_with_raw_pub_key( + &self, + provider: &CryptoProvider, + ) -> Result, Error> { + let private_key = provider + .key_provider + .load_private_key(self.get_key())?; + let public_key = private_key + .public_key() + .ok_or(Error::InconsistentKeys(InconsistentKeys::Unknown))?; + let public_key_as_cert = CertificateDer::from(public_key.to_vec()); + Ok(Arc::new(CertifiedKey::new( + vec![public_key_as_cert], + private_key, + ))) + } + + pub fn certified_key_with_cert_chain( + &self, + provider: &CryptoProvider, + ) -> Result, Error> { + let private_key = provider + .key_provider + .load_private_key(self.get_key())?; + Ok(Arc::new(CertifiedKey::new(self.get_chain(), private_key))) + } + + fn get_crl(&self, role: &str, r#type: &str) -> CertificateRevocationListDer<'static> { + CertificateRevocationListDer::from_pem_slice( + self.bytes_for(&format!("{role}.{type}.crl.pem")), + ) + .unwrap() + } + + pub fn ca_distinguished_name(&self) -> &'static [u8] { + match self { + KeyType::Rsa2048 => b"0\x1f1\x1d0\x1b\x06\x03U\x04\x03\x0c\x14ponytown RSA 2048 CA", + KeyType::Rsa3072 => b"0\x1f1\x1d0\x1b\x06\x03U\x04\x03\x0c\x14ponytown RSA 3072 CA", + KeyType::Rsa4096 => b"0\x1f1\x1d0\x1b\x06\x03U\x04\x03\x0c\x14ponytown RSA 4096 CA", + KeyType::EcdsaP256 => b"0\x211\x1f0\x1d\x06\x03U\x04\x03\x0c\x16ponytown ECDSA p256 CA", + KeyType::EcdsaP384 => b"0\x211\x1f0\x1d\x06\x03U\x04\x03\x0c\x16ponytown ECDSA p384 CA", + KeyType::EcdsaP521 => b"0\x211\x1f0\x1d\x06\x03U\x04\x03\x0c\x16ponytown ECDSA p521 CA", + KeyType::Ed25519 => b"0\x1c1\x1a0\x18\x06\x03U\x04\x03\x0c\x11ponytown EdDSA CA", + } + } +} + +pub fn server_config_builder( + provider: &CryptoProvider, +) -> rustls::ConfigBuilder { + rustls::ServerConfig::builder_with_provider(provider.clone().into()) + .with_safe_default_protocol_versions() + .unwrap() +} + +pub fn server_config_builder_with_versions( + versions: &[&'static rustls::SupportedProtocolVersion], + provider: &CryptoProvider, +) -> rustls::ConfigBuilder { + rustls::ServerConfig::builder_with_provider(provider.clone().into()) + .with_protocol_versions(versions) + .unwrap() +} + +pub fn client_config_builder( + provider: &CryptoProvider, +) -> rustls::ConfigBuilder { + rustls::ClientConfig::builder_with_provider(provider.clone().into()) + .with_safe_default_protocol_versions() + .unwrap() +} + +pub fn client_config_builder_with_versions( + versions: &[&'static rustls::SupportedProtocolVersion], + provider: &CryptoProvider, +) -> rustls::ConfigBuilder { + rustls::ClientConfig::builder_with_provider(provider.clone().into()) + .with_protocol_versions(versions) + .unwrap() +} + +pub fn finish_server_config( + kt: KeyType, + conf: rustls::ConfigBuilder, +) -> ServerConfig { + conf.with_no_client_auth() + .with_single_cert(kt.get_chain(), kt.get_key()) + .unwrap() +} + +pub fn make_server_config(kt: KeyType, provider: &CryptoProvider) -> ServerConfig { + finish_server_config(kt, server_config_builder(provider)) +} + +pub fn make_server_config_with_versions( + kt: KeyType, + versions: &[&'static rustls::SupportedProtocolVersion], + provider: &CryptoProvider, +) -> ServerConfig { + finish_server_config(kt, server_config_builder_with_versions(versions, provider)) +} + +pub fn make_server_config_with_kx_groups( + kt: KeyType, + kx_groups: Vec<&'static dyn rustls::crypto::SupportedKxGroup>, + provider: &CryptoProvider, +) -> ServerConfig { + finish_server_config( + kt, + ServerConfig::builder_with_provider( + CryptoProvider { + kx_groups, + ..provider.clone() + } + .into(), + ) + .with_safe_default_protocol_versions() + .unwrap(), + ) +} + +pub fn get_client_root_store(kt: KeyType) -> Arc { + // The key type's chain file contains the DER encoding of the EE cert, the intermediate cert, + // and the root trust anchor. We want only the trust anchor to build the root cert store. + let chain = kt.get_chain(); + let mut roots = RootCertStore::empty(); + roots + .add(chain.last().unwrap().clone()) + .unwrap(); + roots.into() +} + +pub fn make_server_config_with_mandatory_client_auth_crls( + kt: KeyType, + crls: Vec>, + provider: &CryptoProvider, +) -> ServerConfig { + make_server_config_with_client_verifier( + kt, + webpki_client_verifier_builder(get_client_root_store(kt), provider).with_crls(crls), + provider, + ) +} + +pub fn make_server_config_with_mandatory_client_auth( + kt: KeyType, + provider: &CryptoProvider, +) -> ServerConfig { + make_server_config_with_client_verifier( + kt, + webpki_client_verifier_builder(get_client_root_store(kt), provider), + provider, + ) +} + +pub fn make_server_config_with_optional_client_auth( + kt: KeyType, + crls: Vec>, + provider: &CryptoProvider, +) -> ServerConfig { + make_server_config_with_client_verifier( + kt, + webpki_client_verifier_builder(get_client_root_store(kt), provider) + .with_crls(crls) + .allow_unknown_revocation_status() + .allow_unauthenticated(), + provider, + ) +} + +pub fn make_server_config_with_client_verifier( + kt: KeyType, + verifier_builder: ClientCertVerifierBuilder, + provider: &CryptoProvider, +) -> ServerConfig { + server_config_builder(provider) + .with_client_cert_verifier(verifier_builder.build().unwrap()) + .with_single_cert(kt.get_chain(), kt.get_key()) + .unwrap() +} + +pub fn make_server_config_with_raw_key_support( + kt: KeyType, + provider: &CryptoProvider, +) -> ServerConfig { + let mut client_verifier = + MockClientVerifier::new(|| Ok(ClientCertVerified::assertion()), kt, provider); + let server_cert_resolver = Arc::new(AlwaysResolvesServerRawPublicKeys::new( + kt.certified_key_with_raw_pub_key(provider) + .unwrap(), + )); + client_verifier.expect_raw_public_keys = true; + // We don't support tls1.2 for Raw Public Keys, hence the version is hard-coded. + server_config_builder_with_versions(&[&rustls::version::TLS13], provider) + .with_client_cert_verifier(Arc::new(client_verifier)) + .with_cert_resolver(server_cert_resolver) +} + +pub fn make_client_config_with_raw_key_support( + kt: KeyType, + provider: &CryptoProvider, +) -> ClientConfig { + let server_verifier = Arc::new(MockServerVerifier::expects_raw_public_keys(provider)); + let client_cert_resolver = Arc::new(AlwaysResolvesClientRawPublicKeys::new( + kt.get_certified_client_key(provider) + .unwrap(), + )); + // We don't support tls1.2 for Raw Public Keys, hence the version is hard-coded. + client_config_builder_with_versions(&[&rustls::version::TLS13], provider) + .dangerous() + .with_custom_certificate_verifier(server_verifier) + .with_client_cert_resolver(client_cert_resolver) +} + +pub fn make_client_config_with_cipher_suite_and_raw_key_support( + kt: KeyType, + cipher_suite: SupportedCipherSuite, + provider: &CryptoProvider, +) -> ClientConfig { + let server_verifier = Arc::new(MockServerVerifier::expects_raw_public_keys(provider)); + let client_cert_resolver = Arc::new(AlwaysResolvesClientRawPublicKeys::new( + kt.get_certified_client_key(provider) + .unwrap(), + )); + ClientConfig::builder_with_provider( + CryptoProvider { + cipher_suites: vec![cipher_suite], + ..provider.clone() + } + .into(), + ) + .with_protocol_versions(&[&rustls::version::TLS13]) + .unwrap() + .dangerous() + .with_custom_certificate_verifier(server_verifier) + .with_client_cert_resolver(client_cert_resolver) +} + +pub fn finish_client_config( + kt: KeyType, + config: rustls::ConfigBuilder, +) -> ClientConfig { + let mut root_store = RootCertStore::empty(); + root_store.add_parsable_certificates( + CertificateDer::pem_slice_iter(kt.bytes_for("ca.cert")).map(|result| result.unwrap()), + ); + + config + .with_root_certificates(root_store) + .with_no_client_auth() +} + +pub fn finish_client_config_with_creds( + kt: KeyType, + config: rustls::ConfigBuilder, +) -> ClientConfig { + let mut root_store = RootCertStore::empty(); + root_store.add_parsable_certificates( + CertificateDer::pem_slice_iter(kt.bytes_for("ca.cert")).map(|result| result.unwrap()), + ); + + config + .with_root_certificates(root_store) + .with_client_auth_cert(kt.get_client_chain(), kt.get_client_key()) + .unwrap() +} + +pub fn make_client_config(kt: KeyType, provider: &CryptoProvider) -> ClientConfig { + finish_client_config(kt, client_config_builder(provider)) +} + +pub fn make_client_config_with_kx_groups( + kt: KeyType, + kx_groups: Vec<&'static dyn rustls::crypto::SupportedKxGroup>, + provider: &CryptoProvider, +) -> ClientConfig { + let builder = ClientConfig::builder_with_provider( + CryptoProvider { + kx_groups, + ..provider.clone() + } + .into(), + ) + .with_safe_default_protocol_versions() + .unwrap(); + finish_client_config(kt, builder) +} + +pub fn make_client_config_with_versions( + kt: KeyType, + versions: &[&'static rustls::SupportedProtocolVersion], + provider: &CryptoProvider, +) -> ClientConfig { + finish_client_config(kt, client_config_builder_with_versions(versions, provider)) +} + +pub fn make_client_config_with_auth(kt: KeyType, provider: &CryptoProvider) -> ClientConfig { + finish_client_config_with_creds(kt, client_config_builder(provider)) +} + +pub fn make_client_config_with_versions_with_auth( + kt: KeyType, + versions: &[&'static rustls::SupportedProtocolVersion], + provider: &CryptoProvider, +) -> ClientConfig { + finish_client_config_with_creds(kt, client_config_builder_with_versions(versions, provider)) +} + +pub fn make_client_config_with_verifier( + versions: &[&'static rustls::SupportedProtocolVersion], + verifier_builder: ServerCertVerifierBuilder, + provider: &CryptoProvider, +) -> ClientConfig { + client_config_builder_with_versions(versions, provider) + .dangerous() + .with_custom_certificate_verifier(verifier_builder.build().unwrap()) + .with_no_client_auth() +} + +pub fn webpki_client_verifier_builder( + roots: Arc, + provider: &CryptoProvider, +) -> ClientCertVerifierBuilder { + WebPkiClientVerifier::builder_with_provider(roots, provider.clone().into()) +} + +pub fn webpki_server_verifier_builder( + roots: Arc, + provider: &CryptoProvider, +) -> ServerCertVerifierBuilder { + WebPkiServerVerifier::builder_with_provider(roots, provider.clone().into()) +} + +pub fn make_pair(kt: KeyType, provider: &CryptoProvider) -> (ClientConnection, ServerConnection) { + make_pair_for_configs( + make_client_config(kt, provider), + make_server_config(kt, provider), + ) +} + +pub fn make_pair_for_configs( + client_config: ClientConfig, + server_config: ServerConfig, +) -> (ClientConnection, ServerConnection) { + make_pair_for_arc_configs(&Arc::new(client_config), &Arc::new(server_config)) +} + +pub fn make_pair_for_arc_configs( + client_config: &Arc, + server_config: &Arc, +) -> (ClientConnection, ServerConnection) { + ( + ClientConnection::new(Arc::clone(client_config), server_name("localhost")).unwrap(), + ServerConnection::new(Arc::clone(server_config)).unwrap(), + ) +} + +pub fn do_handshake( + client: &mut impl DerefMut>, + server: &mut impl DerefMut>, +) -> (usize, usize) { + let (mut to_client, mut to_server) = (0, 0); + while server.is_handshaking() || client.is_handshaking() { + to_server += transfer(client, server); + server.process_new_packets().unwrap(); + to_client += transfer(server, client); + client.process_new_packets().unwrap(); + } + (to_server, to_client) +} + +// Drive a handshake using unbuffered connections. +// +// Note that this drives the connection beyond the handshake until both +// connections are idle and there is no pending data waiting to be processed +// by either. In practice this just means that session tickets are processed +// by the client. +pub fn do_unbuffered_handshake( + client: &mut UnbufferedClientConnection, + server: &mut UnbufferedServerConnection, +) { + fn is_idle(conn: &UnbufferedConnectionCommon, data: &[u8]) -> bool { + !conn.is_handshaking() && !conn.wants_write() && data.is_empty() + } + + let mut client_data = Vec::with_capacity(1024); + let mut server_data = Vec::with_capacity(1024); + + while !is_idle(client, &client_data) || !is_idle(server, &server_data) { + loop { + let UnbufferedStatus { discard, state } = client.process_tls_records(&mut client_data); + let state = state.unwrap(); + + match state { + ConnectionState::BlockedHandshake | ConnectionState::WriteTraffic(_) => { + client_data.drain(..discard); + break; + } + ConnectionState::Closed | ConnectionState::PeerClosed => unreachable!(), + ConnectionState::ReadEarlyData(_) => (), + ConnectionState::EncodeTlsData(mut data) => { + let required = match data.encode(&mut []) { + Err(EncodeError::InsufficientSize(err)) => err.required_size, + _ => unreachable!(), + }; + + let old_len = server_data.len(); + server_data.resize(old_len + required, 0); + data.encode(&mut server_data[old_len..]) + .unwrap(); + } + ConnectionState::TransmitTlsData(data) => data.done(), + st => unreachable!("unexpected connection state: {st:?}"), + } + + client_data.drain(..discard); + } + + loop { + let UnbufferedStatus { discard, state } = server.process_tls_records(&mut server_data); + let state = state.unwrap(); + + match state { + ConnectionState::BlockedHandshake | ConnectionState::WriteTraffic(_) => { + server_data.drain(..discard); + break; + } + ConnectionState::Closed | ConnectionState::PeerClosed => unreachable!(), + ConnectionState::ReadEarlyData(_) => unreachable!(), + ConnectionState::EncodeTlsData(mut data) => { + let required = match data.encode(&mut []) { + Err(EncodeError::InsufficientSize(err)) => err.required_size, + _ => unreachable!(), + }; + + let old_len = client_data.len(); + client_data.resize(old_len + required, 0); + data.encode(&mut client_data[old_len..]) + .unwrap(); + } + ConnectionState::TransmitTlsData(data) => data.done(), + _ => unreachable!(), + } + + server_data.drain(..discard); + } + } + + assert!(server_data.is_empty()); + assert!(client_data.is_empty()); +} + +#[derive(PartialEq, Debug)] +pub enum ErrorFromPeer { + Client(Error), + Server(Error), +} + +pub fn do_handshake_until_error( + client: &mut ClientConnection, + server: &mut ServerConnection, +) -> Result<(), ErrorFromPeer> { + while server.is_handshaking() || client.is_handshaking() { + transfer(client, server); + server + .process_new_packets() + .map_err(ErrorFromPeer::Server)?; + transfer(server, client); + client + .process_new_packets() + .map_err(ErrorFromPeer::Client)?; + } + + Ok(()) +} + +pub fn do_handshake_altered( + client: ClientConnection, + alter_server_message: impl Fn(&mut Message) -> Altered, + alter_client_message: impl Fn(&mut Message) -> Altered, + server: ServerConnection, +) -> Result<(), ErrorFromPeer> { + let mut client: Connection = Connection::Client(client); + let mut server: Connection = Connection::Server(server); + + while server.is_handshaking() || client.is_handshaking() { + transfer_altered(&mut client, &alter_client_message, &mut server); + + server + .process_new_packets() + .map_err(ErrorFromPeer::Server)?; + + transfer_altered(&mut server, &alter_server_message, &mut client); + + client + .process_new_packets() + .map_err(ErrorFromPeer::Client)?; + } + + Ok(()) +} + +pub fn do_handshake_until_both_error( + client: &mut ClientConnection, + server: &mut ServerConnection, +) -> Result<(), Vec> { + match do_handshake_until_error(client, server) { + Err(server_err @ ErrorFromPeer::Server(_)) => { + let mut errors = vec![server_err]; + transfer(server, client); + let client_err = client + .process_new_packets() + .map_err(ErrorFromPeer::Client) + .expect_err("client didn't produce error after server error"); + errors.push(client_err); + Err(errors) + } + + Err(client_err @ ErrorFromPeer::Client(_)) => { + let mut errors = vec![client_err]; + transfer(client, server); + let server_err = server + .process_new_packets() + .map_err(ErrorFromPeer::Server) + .expect_err("server didn't produce error after client error"); + errors.push(server_err); + Err(errors) + } + + Ok(()) => Ok(()), + } +} + +pub fn server_name(name: &'static str) -> ServerName<'static> { + name.try_into().unwrap() +} + +pub struct FailsReads { + errkind: io::ErrorKind, +} + +impl FailsReads { + pub fn new(errkind: io::ErrorKind) -> Self { + Self { errkind } + } +} + +impl io::Read for FailsReads { + fn read(&mut self, _b: &mut [u8]) -> io::Result { + Err(io::Error::from(self.errkind)) + } +} + +pub fn do_suite_and_kx_test( + client_config: ClientConfig, + server_config: ServerConfig, + expect_suite: SupportedCipherSuite, + expect_kx: NamedGroup, + expect_version: ProtocolVersion, +) { + println!( + "do_suite_test {:?} {:?}", + expect_version, + expect_suite.suite() + ); + let (mut client, mut server) = make_pair_for_configs(client_config, server_config); + + assert_eq!(None, client.negotiated_cipher_suite()); + assert_eq!(None, server.negotiated_cipher_suite()); + assert!( + client + .negotiated_key_exchange_group() + .is_none() + ); + assert!( + server + .negotiated_key_exchange_group() + .is_none() + ); + assert_eq!(None, client.protocol_version()); + assert_eq!(None, server.protocol_version()); + assert!(client.is_handshaking()); + assert!(server.is_handshaking()); + + transfer(&mut client, &mut server); + server.process_new_packets().unwrap(); + + assert!(client.is_handshaking()); + assert!(server.is_handshaking()); + assert_eq!(None, client.protocol_version()); + assert_eq!(Some(expect_version), server.protocol_version()); + assert_eq!(None, client.negotiated_cipher_suite()); + assert_eq!(Some(expect_suite), server.negotiated_cipher_suite()); + assert!( + client + .negotiated_key_exchange_group() + .is_none() + ); + if matches!(expect_version, ProtocolVersion::TLSv1_2) { + assert!( + server + .negotiated_key_exchange_group() + .is_none() + ); + } else { + assert_eq!( + expect_kx, + server + .negotiated_key_exchange_group() + .unwrap() + .name() + ); + } + + transfer(&mut server, &mut client); + client.process_new_packets().unwrap(); + + assert_eq!(Some(expect_suite), client.negotiated_cipher_suite()); + assert_eq!(Some(expect_suite), server.negotiated_cipher_suite()); + assert_eq!( + expect_kx, + client + .negotiated_key_exchange_group() + .unwrap() + .name() + ); + if matches!(expect_version, ProtocolVersion::TLSv1_2) { + assert!( + server + .negotiated_key_exchange_group() + .is_none() + ); + } else { + assert_eq!( + expect_kx, + server + .negotiated_key_exchange_group() + .unwrap() + .name() + ); + } + + transfer(&mut client, &mut server); + server.process_new_packets().unwrap(); + transfer(&mut server, &mut client); + client.process_new_packets().unwrap(); + + assert!(!client.is_handshaking()); + assert!(!server.is_handshaking()); + assert_eq!(Some(expect_version), client.protocol_version()); + assert_eq!(Some(expect_version), server.protocol_version()); + assert_eq!(Some(expect_suite), client.negotiated_cipher_suite()); + assert_eq!(Some(expect_suite), server.negotiated_cipher_suite()); + assert_eq!( + expect_kx, + client + .negotiated_key_exchange_group() + .unwrap() + .name() + ); + assert_eq!( + expect_kx, + server + .negotiated_key_exchange_group() + .unwrap() + .name() + ); +} + +#[derive(Debug)] +pub struct MockServerVerifier { + cert_rejection_error: Option, + tls12_signature_error: Option, + tls13_signature_error: Option, + signature_schemes: Vec, + expected_ocsp_response: Option>, + requires_raw_public_keys: bool, + raw_public_key_algorithms: Option, +} + +impl ServerCertVerifier for MockServerVerifier { + fn verify_server_cert( + &self, + end_entity: &CertificateDer<'_>, + intermediates: &[CertificateDer<'_>], + server_name: &ServerName<'_>, + ocsp_response: &[u8], + now: UnixTime, + ) -> Result { + println!( + "verify_server_cert({end_entity:?}, {intermediates:?}, {server_name:?}, {ocsp_response:?}, {now:?})" + ); + if let Some(expected_ocsp) = &self.expected_ocsp_response { + assert_eq!(expected_ocsp, ocsp_response); + } + match &self.cert_rejection_error { + Some(error) => Err(error.clone()), + _ => Ok(ServerCertVerified::assertion()), + } + } + + fn verify_tls12_signature( + &self, + message: &[u8], + cert: &CertificateDer<'_>, + dss: &DigitallySignedStruct, + ) -> Result { + println!("verify_tls12_signature({message:?}, {cert:?}, {dss:?})"); + match &self.tls12_signature_error { + Some(error) => Err(error.clone()), + _ => Ok(HandshakeSignatureValid::assertion()), + } + } + + fn verify_tls13_signature( + &self, + message: &[u8], + cert: &CertificateDer<'_>, + dss: &DigitallySignedStruct, + ) -> Result { + println!("verify_tls13_signature({message:?}, {cert:?}, {dss:?})"); + match &self.tls13_signature_error { + Some(error) => Err(error.clone()), + _ if self.requires_raw_public_keys => verify_tls13_signature_with_raw_key( + message, + &SubjectPublicKeyInfoDer::from(cert.as_ref()), + dss, + self.raw_public_key_algorithms + .as_ref() + .unwrap(), + ), + _ => Ok(HandshakeSignatureValid::assertion()), + } + } + + fn supported_verify_schemes(&self) -> Vec { + self.signature_schemes.clone() + } + + fn requires_raw_public_keys(&self) -> bool { + self.requires_raw_public_keys + } +} + +impl MockServerVerifier { + pub fn accepts_anything() -> Self { + MockServerVerifier { + cert_rejection_error: None, + ..Default::default() + } + } + + pub fn expects_ocsp_response(response: &[u8]) -> Self { + MockServerVerifier { + expected_ocsp_response: Some(response.to_vec()), + ..Default::default() + } + } + + pub fn rejects_certificate(err: Error) -> Self { + MockServerVerifier { + cert_rejection_error: Some(err), + ..Default::default() + } + } + + pub fn rejects_tls12_signatures(err: Error) -> Self { + MockServerVerifier { + tls12_signature_error: Some(err), + ..Default::default() + } + } + + pub fn rejects_tls13_signatures(err: Error) -> Self { + MockServerVerifier { + tls13_signature_error: Some(err), + ..Default::default() + } + } + + pub fn offers_no_signature_schemes() -> Self { + MockServerVerifier { + signature_schemes: vec![], + ..Default::default() + } + } + + pub fn expects_raw_public_keys(provider: &CryptoProvider) -> Self { + MockServerVerifier { + requires_raw_public_keys: true, + raw_public_key_algorithms: Some(provider.signature_verification_algorithms), + ..Default::default() + } + } +} + +impl Default for MockServerVerifier { + fn default() -> Self { + MockServerVerifier { + cert_rejection_error: None, + tls12_signature_error: None, + tls13_signature_error: None, + signature_schemes: vec![ + SignatureScheme::RSA_PSS_SHA256, + SignatureScheme::RSA_PKCS1_SHA256, + SignatureScheme::ED25519, + SignatureScheme::ECDSA_NISTP256_SHA256, + SignatureScheme::ECDSA_NISTP384_SHA384, + SignatureScheme::ECDSA_NISTP521_SHA512, + ], + expected_ocsp_response: None, + requires_raw_public_keys: false, + raw_public_key_algorithms: None, + } + } +} + +#[derive(Debug)] +pub struct MockClientVerifier { + pub verified: fn() -> Result, + pub subjects: Vec, + pub mandatory: bool, + pub offered_schemes: Option>, + expect_raw_public_keys: bool, + raw_public_key_algorithms: Option, + parent: Arc, +} + +impl MockClientVerifier { + pub fn new( + verified: fn() -> Result, + kt: KeyType, + provider: &CryptoProvider, + ) -> Self { + Self { + parent: webpki_client_verifier_builder(get_client_root_store(kt), provider) + .build() + .unwrap(), + verified, + subjects: get_client_root_store(kt).subjects(), + mandatory: true, + offered_schemes: None, + expect_raw_public_keys: false, + raw_public_key_algorithms: Some(provider.signature_verification_algorithms), + } + } +} + +impl ClientCertVerifier for MockClientVerifier { + fn client_auth_mandatory(&self) -> bool { + self.mandatory + } + + fn root_hint_subjects(&self) -> &[DistinguishedName] { + &self.subjects + } + + fn verify_client_cert( + &self, + _end_entity: &CertificateDer<'_>, + _intermediates: &[CertificateDer<'_>], + _now: UnixTime, + ) -> Result { + (self.verified)() + } + + fn verify_tls12_signature( + &self, + message: &[u8], + cert: &CertificateDer<'_>, + dss: &DigitallySignedStruct, + ) -> Result { + if self.expect_raw_public_keys { + Ok(HandshakeSignatureValid::assertion()) + } else { + self.parent + .verify_tls12_signature(message, cert, dss) + } + } + + fn verify_tls13_signature( + &self, + message: &[u8], + cert: &CertificateDer<'_>, + dss: &DigitallySignedStruct, + ) -> Result { + if self.expect_raw_public_keys { + verify_tls13_signature_with_raw_key( + message, + &SubjectPublicKeyInfoDer::from(cert.as_ref()), + dss, + self.raw_public_key_algorithms + .as_ref() + .unwrap(), + ) + } else { + self.parent + .verify_tls13_signature(message, cert, dss) + } + } + + fn supported_verify_schemes(&self) -> Vec { + if let Some(schemes) = &self.offered_schemes { + schemes.clone() + } else { + self.parent.supported_verify_schemes() + } + } + + fn requires_raw_public_keys(&self) -> bool { + self.expect_raw_public_keys + } +} + +/// This allows injection/receipt of raw messages into a post-handshake connection. +/// +/// It consumes one of the peers, extracts its secrets, and then reconstitutes the +/// message encrypter/decrypter. It does not do fragmentation/joining. +pub struct RawTls { + encrypter: Box, + enc_seq: u64, + decrypter: Box, + dec_seq: u64, +} + +impl RawTls { + /// conn must be post-handshake, and must have been created with `enable_secret_extraction` + pub fn new_client(conn: ClientConnection) -> Self { + let suite = conn.negotiated_cipher_suite().unwrap(); + Self::new( + suite, + conn.dangerous_extract_secrets() + .unwrap(), + ) + } + + /// conn must be post-handshake, and must have been created with `enable_secret_extraction` + pub fn new_server(conn: ServerConnection) -> Self { + let suite = conn.negotiated_cipher_suite().unwrap(); + Self::new( + suite, + conn.dangerous_extract_secrets() + .unwrap(), + ) + } + + fn new(suite: SupportedCipherSuite, secrets: rustls::ExtractedSecrets) -> Self { + let rustls::ExtractedSecrets { + tx: (tx_seq, tx_keys), + rx: (rx_seq, rx_keys), + } = secrets; + + let encrypter = match (tx_keys, suite) { + ( + rustls::ConnectionTrafficSecrets::Aes256Gcm { key, iv }, + SupportedCipherSuite::Tls13(tls13), + ) => tls13.aead_alg.encrypter(key, iv), + + ( + rustls::ConnectionTrafficSecrets::Aes256Gcm { key, iv }, + SupportedCipherSuite::Tls12(tls12), + ) => tls12 + .aead_alg + .encrypter(key, &iv.as_ref()[..4], &iv.as_ref()[4..]), + + _ => todo!(), + }; + + let decrypter = match (rx_keys, suite) { + ( + rustls::ConnectionTrafficSecrets::Aes256Gcm { key, iv }, + SupportedCipherSuite::Tls13(tls13), + ) => tls13.aead_alg.decrypter(key, iv), + + ( + rustls::ConnectionTrafficSecrets::Aes256Gcm { key, iv }, + SupportedCipherSuite::Tls12(tls12), + ) => tls12 + .aead_alg + .decrypter(key, &iv.as_ref()[..4]), + + _ => todo!(), + }; + + Self { + encrypter, + enc_seq: tx_seq, + decrypter, + dec_seq: rx_seq, + } + } + + pub fn encrypt_and_send( + &mut self, + msg: &PlainMessage, + peer: &mut impl DerefMut>, + ) { + let data = self + .encrypter + .encrypt(msg.borrow_outbound(), self.enc_seq) + .unwrap() + .encode(); + self.enc_seq += 1; + peer.read_tls(&mut io::Cursor::new(data)) + .unwrap(); + } + + pub fn receive_and_decrypt( + &mut self, + peer: &mut impl DerefMut>, + f: impl Fn(Message), + ) { + let mut data = vec![]; + peer.write_tls(&mut io::Cursor::new(&mut data)) + .unwrap(); + + let mut reader = Reader::init(&data); + let content_type = ContentType::read(&mut reader).unwrap(); + let version = ProtocolVersion::read(&mut reader).unwrap(); + let len = u16::read(&mut reader).unwrap(); + let left = &mut data[5..]; + assert_eq!(len as usize, left.len()); + + let inbound = InboundOpaqueMessage::new(content_type, version, left); + let plain = self + .decrypter + .decrypt(inbound, self.dec_seq) + .unwrap(); + self.dec_seq += 1; + + let msg = Message::try_from(plain).unwrap(); + println!("receive_and_decrypt: {msg:?}"); + + f(msg); + } +} + +pub fn aes_128_gcm_with_1024_confidentiality_limit( + provider: CryptoProvider, +) -> Arc { + const CONFIDENTIALITY_LIMIT: u64 = 1024; + + // needed to extend lifetime of Tls13CipherSuite to 'static + static TLS13_LIMITED_SUITE: OnceLock = OnceLock::new(); + static TLS12_LIMITED_SUITE: OnceLock = OnceLock::new(); + + let tls13_limited = TLS13_LIMITED_SUITE.get_or_init(|| { + let tls13 = provider + .cipher_suites + .iter() + .find(|cs| cs.suite() == CipherSuite::TLS13_AES_128_GCM_SHA256) + .unwrap() + .tls13() + .unwrap(); + + rustls::Tls13CipherSuite { + common: rustls::crypto::CipherSuiteCommon { + confidentiality_limit: CONFIDENTIALITY_LIMIT, + ..tls13.common + }, + ..*tls13 + } + }); + + let tls12_limited = TLS12_LIMITED_SUITE.get_or_init(|| { + let SupportedCipherSuite::Tls12(tls12) = *provider + .cipher_suites + .iter() + .find(|cs| cs.suite() == CipherSuite::TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256) + .unwrap() + else { + unreachable!(); + }; + + rustls::Tls12CipherSuite { + common: rustls::crypto::CipherSuiteCommon { + confidentiality_limit: CONFIDENTIALITY_LIMIT, + ..tls12.common + }, + ..*tls12 + } + }); + + CryptoProvider { + cipher_suites: vec![ + SupportedCipherSuite::Tls13(tls13_limited), + SupportedCipherSuite::Tls12(tls12_limited), + ], + ..provider + } + .into() +} + +pub fn unsafe_plaintext_crypto_provider(provider: CryptoProvider) -> Arc { + static TLS13_PLAIN_SUITE: OnceLock = OnceLock::new(); + + let tls13 = TLS13_PLAIN_SUITE.get_or_init(|| { + let tls13 = provider + .cipher_suites + .iter() + .find(|cs| cs.suite() == CipherSuite::TLS13_AES_256_GCM_SHA384) + .unwrap() + .tls13() + .unwrap(); + + rustls::Tls13CipherSuite { + aead_alg: &plaintext::Aead, + common: rustls::crypto::CipherSuiteCommon { ..tls13.common }, + ..*tls13 + } + }); + + CryptoProvider { + cipher_suites: vec![SupportedCipherSuite::Tls13(tls13)], + ..provider + } + .into() +} + +mod plaintext { + use rustls::ConnectionTrafficSecrets; + use rustls::crypto::cipher::{ + AeadKey, InboundOpaqueMessage, InboundPlainMessage, Iv, MessageDecrypter, MessageEncrypter, + OutboundPlainMessage, PrefixedPayload, Tls13AeadAlgorithm, UnsupportedOperationError, + }; + + use super::*; + + pub(super) struct Aead; + + impl Tls13AeadAlgorithm for Aead { + fn encrypter(&self, _key: AeadKey, _iv: Iv) -> Box { + Box::new(Encrypter) + } + + fn decrypter(&self, _key: AeadKey, _iv: Iv) -> Box { + Box::new(Decrypter) + } + + fn key_len(&self) -> usize { + 32 + } + + fn extract_keys( + &self, + _key: AeadKey, + _iv: Iv, + ) -> Result { + Err(UnsupportedOperationError) + } + } + + struct Encrypter; + + impl MessageEncrypter for Encrypter { + fn encrypt( + &mut self, + msg: OutboundPlainMessage<'_>, + _seq: u64, + ) -> Result { + let mut payload = PrefixedPayload::with_capacity(msg.payload.len()); + payload.extend_from_chunks(&msg.payload); + + Ok(OutboundOpaqueMessage::new( + ContentType::ApplicationData, + ProtocolVersion::TLSv1_2, + payload, + )) + } + + fn encrypted_payload_len(&self, payload_len: usize) -> usize { + payload_len + } + } + + struct Decrypter; + + impl MessageDecrypter for Decrypter { + fn decrypt<'a>( + &mut self, + msg: InboundOpaqueMessage<'a>, + _seq: u64, + ) -> Result, Error> { + Ok(msg.into_plain_message()) + } + } +} + +/// Deeply inefficient, test-only TLS encoding helpers +pub mod encoding { + use rustls::internal::msgs::codec::Codec; + use rustls::internal::msgs::enums::ExtensionType; + use rustls::{ + CipherSuite, ContentType, HandshakeType, NamedGroup, ProtocolVersion, SignatureScheme, + }; + + /// Return a client hello with mandatory extensions added to `extensions` + /// + /// The returned bytes are handshake-framed, but not message-framed. + pub fn basic_client_hello(mut extensions: Vec) -> Vec { + extensions.push(Extension::new_kx_groups()); + extensions.push(Extension::new_sig_algs()); + extensions.push(Extension::new_versions()); + extensions.push(Extension::new_dummy_key_share()); + client_hello_with_extensions(extensions) + } + + /// Return a client hello with exactly `extensions` + /// + /// The returned bytes are handshake-framed, but not message-framed. + pub fn client_hello_with_extensions(extensions: Vec) -> Vec { + client_hello( + ProtocolVersion::TLSv1_2, + &[0u8; 32], + &[0], + vec![ + CipherSuite::TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + CipherSuite::TLS13_AES_128_GCM_SHA256, + ], + extensions, + ) + } + + pub fn client_hello( + legacy_version: ProtocolVersion, + random: &[u8; 32], + session_id: &[u8], + cipher_suites: Vec, + extensions: Vec, + ) -> Vec { + let mut out = vec![]; + + legacy_version.encode(&mut out); + out.extend_from_slice(random); + out.extend_from_slice(session_id); + cipher_suites.to_vec().encode(&mut out); + out.extend_from_slice(&[0x01, 0x00]); // only null compression + + let mut exts = vec![]; + for e in extensions { + e.typ.encode(&mut exts); + exts.extend_from_slice(&(e.body.len() as u16).to_be_bytes()); + exts.extend_from_slice(&e.body); + } + + out.extend(len_u16(exts)); + handshake_framing(HandshakeType::ClientHello, out) + } + + /// Apply handshake framing to `body`. + /// + /// This does not do fragmentation. + pub fn handshake_framing(ty: HandshakeType, body: Vec) -> Vec { + let mut body = len_u24(body); + body.splice(0..0, ty.to_array()); + body + } + + /// Apply message framing to `body`. + pub fn message_framing(ty: ContentType, vers: ProtocolVersion, body: Vec) -> Vec { + let mut body = len_u16(body); + body.splice(0..0, vers.to_array()); + body.splice(0..0, ty.to_array()); + body + } + + #[derive(Clone)] + pub struct Extension { + pub typ: ExtensionType, + pub body: Vec, + } + + impl Extension { + pub fn new_sig_algs() -> Extension { + Extension { + typ: ExtensionType::SignatureAlgorithms, + body: len_u16( + SignatureScheme::RSA_PKCS1_SHA256 + .to_array() + .to_vec(), + ), + } + } + + pub fn new_kx_groups() -> Extension { + Extension { + typ: ExtensionType::EllipticCurves, + body: len_u16(vector_of([NamedGroup::secp256r1].into_iter())), + } + } + + pub fn new_versions() -> Extension { + Extension { + typ: ExtensionType::SupportedVersions, + body: len_u8(vector_of( + [ProtocolVersion::TLSv1_3, ProtocolVersion::TLSv1_2].into_iter(), + )), + } + } + + pub fn new_dummy_key_share() -> Extension { + const SOME_POINT_ON_P256: &[u8] = &[ + 4, 41, 39, 177, 5, 18, 186, 227, 237, 220, 254, 70, 120, 40, 18, 139, 173, 41, 3, + 38, 153, 25, 247, 8, 96, 105, 200, 196, 223, 108, 115, 40, 56, 199, 120, 121, 100, + 234, 172, 0, 229, 146, 31, 177, 73, 138, 96, 244, 96, 103, 102, 179, 217, 104, 80, + 1, 85, 141, 26, 151, 78, 115, 65, 81, 62, + ]; + + let mut share = len_u16(SOME_POINT_ON_P256.to_vec()); + share.splice(0..0, NamedGroup::secp256r1.to_array()); + + Extension { + typ: ExtensionType::KeyShare, + body: len_u16(share), + } + } + } + + /// Prefix with u8 length + pub fn len_u8(mut body: Vec) -> Vec { + body.splice(0..0, [body.len() as u8]); + body + } + + /// Prefix with u16 length + pub fn len_u16(mut body: Vec) -> Vec { + body.splice(0..0, (body.len() as u16).to_be_bytes()); + body + } + + /// Prefix with u24 length + pub fn len_u24(mut body: Vec) -> Vec { + let len = (body.len() as u32).to_be_bytes(); + body.insert(0, len[1]); + body.insert(1, len[2]); + body.insert(2, len[3]); + body + } + + /// Encode each of `items` + pub fn vector_of<'a, T: Codec<'a>>(items: impl Iterator) -> Vec { + let mut body = Vec::new(); + + for i in items { + i.encode(&mut body); + } + body + } +} diff --git a/rustls/Cargo.toml b/rustls/Cargo.toml index 7d93128a7e..b0ab257849 100644 --- a/rustls/Cargo.toml +++ b/rustls/Cargo.toml @@ -57,6 +57,7 @@ log = { workspace = true } macro_rules_attribute = { workspace = true } num-bigint = { workspace = true } rcgen = { workspace = true } +rustls-test = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } time = { workspace = true } diff --git a/rustls/benches/benchmarks.rs b/rustls/benches/benchmarks.rs index 85c2013ff9..070e6cf8e5 100644 --- a/rustls/benches/benchmarks.rs +++ b/rustls/benches/benchmarks.rs @@ -4,13 +4,11 @@ use bencher::{Bencher, benchmark_group, benchmark_main}; use rustls::crypto::ring as provider; -#[path = "../tests/common/mod.rs"] -mod test_utils; use std::io; use std::sync::Arc; use rustls::ServerConnection; -use test_utils::*; +use rustls_test::{FailsReads, KeyType, make_server_config}; fn bench_ewouldblock(c: &mut Bencher) { let server_config = make_server_config(KeyType::Rsa2048, &provider::default_provider()); diff --git a/rustls/tests/common/mod.rs b/rustls/tests/common/mod.rs index 9186a451e2..02541e9182 100644 --- a/rustls/tests/common/mod.rs +++ b/rustls/tests/common/mod.rs @@ -1,454 +1,13 @@ #![allow(dead_code)] #![allow(clippy::disallowed_types, clippy::duplicate_mod)] -use std::io; -use std::ops::DerefMut; -use std::sync::OnceLock; - -use pki_types::pem::PemObject; -use pki_types::{ - CertificateDer, CertificateRevocationListDer, PrivateKeyDer, PrivatePkcs8KeyDer, ServerName, - SubjectPublicKeyInfoDer, UnixTime, -}; - -use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier}; -use rustls::client::{ - AlwaysResolvesClientRawPublicKeys, ServerCertVerifierBuilder, UnbufferedClientConnection, - WebPkiServerVerifier, -}; -use rustls::crypto::cipher::{InboundOpaqueMessage, MessageDecrypter, MessageEncrypter}; -use rustls::crypto::{ - CryptoProvider, WebPkiSupportedAlgorithms, verify_tls13_signature_with_raw_key, -}; -use rustls::internal::msgs::codec::{Codec, Reader}; -use rustls::internal::msgs::message::{Message, OutboundOpaqueMessage, PlainMessage}; -use rustls::server::danger::{ClientCertVerified, ClientCertVerifier}; -use rustls::server::{ - AlwaysResolvesServerRawPublicKeys, ClientCertVerifierBuilder, UnbufferedServerConnection, - WebPkiClientVerifier, -}; -use rustls::sign::CertifiedKey; -use rustls::unbuffered::{ - ConnectionState, EncodeError, UnbufferedConnectionCommon, UnbufferedStatus, -}; -use rustls::{ - CipherSuite, ClientConfig, ClientConnection, Connection, ConnectionCommon, ContentType, - DigitallySignedStruct, DistinguishedName, Error, InconsistentKeys, NamedGroup, ProtocolVersion, - RootCertStore, ServerConfig, ServerConnection, SideData, SignatureScheme, SupportedCipherSuite, -}; - -// Import `Arc` here for tests - can be overwritten to test with another `Arc` such as `portable_atomic_util::Arc` pub use std::sync::Arc; -macro_rules! embed_files { - ( - $( - ($name:ident, $keytype:expr, $path:expr); - )+ - ) => { - $( - const $name: &'static [u8] = include_bytes!( - concat!("../../../test-ca/", $keytype, "/", $path)); - )+ - - pub fn bytes_for(keytype: &str, path: &str) -> &'static [u8] { - match (keytype, path) { - $( - ($keytype, $path) => $name, - )+ - _ => panic!("unknown keytype {} with path {}", keytype, path), - } - } - } -} - -embed_files! { - (ECDSA_P256_END_PEM_SPKI, "ecdsa-p256", "end.spki.pem"); - (ECDSA_P256_CLIENT_PEM_SPKI, "ecdsa-p256", "client.spki.pem"); - (ECDSA_P256_CA_CERT, "ecdsa-p256", "ca.cert"); - (ECDSA_P256_CA_DER, "ecdsa-p256", "ca.der"); - (ECDSA_P256_CA_KEY, "ecdsa-p256", "ca.key"); - (ECDSA_P256_CLIENT_CERT, "ecdsa-p256", "client.cert"); - (ECDSA_P256_CLIENT_CHAIN, "ecdsa-p256", "client.chain"); - (ECDSA_P256_CLIENT_FULLCHAIN, "ecdsa-p256", "client.fullchain"); - (ECDSA_P256_CLIENT_KEY, "ecdsa-p256", "client.key"); - (ECDSA_P256_END_CRL_PEM, "ecdsa-p256", "end.revoked.crl.pem"); - (ECDSA_P256_CLIENT_CRL_PEM, "ecdsa-p256", "client.revoked.crl.pem"); - (ECDSA_P256_INTERMEDIATE_CRL_PEM, "ecdsa-p256", "inter.revoked.crl.pem"); - (ECDSA_P256_EXPIRED_CRL_PEM, "ecdsa-p256", "end.expired.crl.pem"); - (ECDSA_P256_END_CERT, "ecdsa-p256", "end.cert"); - (ECDSA_P256_END_CHAIN, "ecdsa-p256", "end.chain"); - (ECDSA_P256_END_FULLCHAIN, "ecdsa-p256", "end.fullchain"); - (ECDSA_P256_END_KEY, "ecdsa-p256", "end.key"); - (ECDSA_P256_INTER_CERT, "ecdsa-p256", "inter.cert"); - (ECDSA_P256_INTER_KEY, "ecdsa-p256", "inter.key"); - - (ECDSA_P384_END_PEM_SPKI, "ecdsa-p384", "end.spki.pem"); - (ECDSA_P384_CLIENT_PEM_SPKI, "ecdsa-p384", "client.spki.pem"); - (ECDSA_P384_CA_CERT, "ecdsa-p384", "ca.cert"); - (ECDSA_P384_CA_DER, "ecdsa-p384", "ca.der"); - (ECDSA_P384_CA_KEY, "ecdsa-p384", "ca.key"); - (ECDSA_P384_CLIENT_CERT, "ecdsa-p384", "client.cert"); - (ECDSA_P384_CLIENT_CHAIN, "ecdsa-p384", "client.chain"); - (ECDSA_P384_CLIENT_FULLCHAIN, "ecdsa-p384", "client.fullchain"); - (ECDSA_P384_CLIENT_KEY, "ecdsa-p384", "client.key"); - (ECDSA_P384_END_CRL_PEM, "ecdsa-p384", "end.revoked.crl.pem"); - (ECDSA_P384_CLIENT_CRL_PEM, "ecdsa-p384", "client.revoked.crl.pem"); - (ECDSA_P384_INTERMEDIATE_CRL_PEM, "ecdsa-p384", "inter.revoked.crl.pem"); - (ECDSA_P384_EXPIRED_CRL_PEM, "ecdsa-p384", "end.expired.crl.pem"); - (ECDSA_P384_END_CERT, "ecdsa-p384", "end.cert"); - (ECDSA_P384_END_CHAIN, "ecdsa-p384", "end.chain"); - (ECDSA_P384_END_FULLCHAIN, "ecdsa-p384", "end.fullchain"); - (ECDSA_P384_END_KEY, "ecdsa-p384", "end.key"); - (ECDSA_P384_INTER_CERT, "ecdsa-p384", "inter.cert"); - (ECDSA_P384_INTER_KEY, "ecdsa-p384", "inter.key"); - - (ECDSA_P521_END_PEM_SPKI, "ecdsa-p521", "end.spki.pem"); - (ECDSA_P521_CLIENT_PEM_SPKI, "ecdsa-p521", "client.spki.pem"); - (ECDSA_P521_CA_CERT, "ecdsa-p521", "ca.cert"); - (ECDSA_P521_CA_DER, "ecdsa-p521", "ca.der"); - (ECDSA_P521_CA_KEY, "ecdsa-p521", "ca.key"); - (ECDSA_P521_CLIENT_CERT, "ecdsa-p521", "client.cert"); - (ECDSA_P521_CLIENT_CHAIN, "ecdsa-p521", "client.chain"); - (ECDSA_P521_CLIENT_FULLCHAIN, "ecdsa-p521", "client.fullchain"); - (ECDSA_P521_CLIENT_KEY, "ecdsa-p521", "client.key"); - (ECDSA_P521_END_CRL_PEM, "ecdsa-p521", "end.revoked.crl.pem"); - (ECDSA_P521_CLIENT_CRL_PEM, "ecdsa-p521", "client.revoked.crl.pem"); - (ECDSA_P521_INTERMEDIATE_CRL_PEM, "ecdsa-p521", "inter.revoked.crl.pem"); - (ECDSA_P521_EXPIRED_CRL_PEM, "ecdsa-p521", "end.expired.crl.pem"); - (ECDSA_P521_END_CERT, "ecdsa-p521", "end.cert"); - (ECDSA_P521_END_CHAIN, "ecdsa-p521", "end.chain"); - (ECDSA_P521_END_FULLCHAIN, "ecdsa-p521", "end.fullchain"); - (ECDSA_P521_END_KEY, "ecdsa-p521", "end.key"); - (ECDSA_P521_INTER_CERT, "ecdsa-p521", "inter.cert"); - (ECDSA_P521_INTER_KEY, "ecdsa-p521", "inter.key"); - - (EDDSA_END_PEM_SPKI, "eddsa", "end.spki.pem"); - (EDDSA_CLIENT_PEM_SPKI, "eddsa", "client.spki.pem"); - (EDDSA_CA_CERT, "eddsa", "ca.cert"); - (EDDSA_CA_DER, "eddsa", "ca.der"); - (EDDSA_CA_KEY, "eddsa", "ca.key"); - (EDDSA_CLIENT_CERT, "eddsa", "client.cert"); - (EDDSA_CLIENT_CHAIN, "eddsa", "client.chain"); - (EDDSA_CLIENT_FULLCHAIN, "eddsa", "client.fullchain"); - (EDDSA_CLIENT_KEY, "eddsa", "client.key"); - (EDDSA_END_CRL_PEM, "eddsa", "end.revoked.crl.pem"); - (EDDSA_CLIENT_CRL_PEM, "eddsa", "client.revoked.crl.pem"); - (EDDSA_INTERMEDIATE_CRL_PEM, "eddsa", "inter.revoked.crl.pem"); - (EDDSA_EXPIRED_CRL_PEM, "eddsa", "end.expired.crl.pem"); - (EDDSA_END_CERT, "eddsa", "end.cert"); - (EDDSA_END_CHAIN, "eddsa", "end.chain"); - (EDDSA_END_FULLCHAIN, "eddsa", "end.fullchain"); - (EDDSA_END_KEY, "eddsa", "end.key"); - (EDDSA_INTER_CERT, "eddsa", "inter.cert"); - (EDDSA_INTER_KEY, "eddsa", "inter.key"); - - (RSA_2048_END_PEM_SPKI, "rsa-2048", "end.spki.pem"); - (RSA_2048_CLIENT_PEM_SPKI, "rsa-2048", "client.spki.pem"); - (RSA_2048_CA_CERT, "rsa-2048", "ca.cert"); - (RSA_2048_CA_DER, "rsa-2048", "ca.der"); - (RSA_2048_CA_KEY, "rsa-2048", "ca.key"); - (RSA_2048_CLIENT_CERT, "rsa-2048", "client.cert"); - (RSA_2048_CLIENT_CHAIN, "rsa-2048", "client.chain"); - (RSA_2048_CLIENT_FULLCHAIN, "rsa-2048", "client.fullchain"); - (RSA_2048_CLIENT_KEY, "rsa-2048", "client.key"); - (RSA_2048_END_CRL_PEM, "rsa-2048", "end.revoked.crl.pem"); - (RSA_2048_CLIENT_CRL_PEM, "rsa-2048", "client.revoked.crl.pem"); - (RSA_2048_INTERMEDIATE_CRL_PEM, "rsa-2048", "inter.revoked.crl.pem"); - (RSA_2048_EXPIRED_CRL_PEM, "rsa-2048", "end.expired.crl.pem"); - (RSA_2048_END_CERT, "rsa-2048", "end.cert"); - (RSA_2048_END_CHAIN, "rsa-2048", "end.chain"); - (RSA_2048_END_FULLCHAIN, "rsa-2048", "end.fullchain"); - (RSA_2048_END_KEY, "rsa-2048", "end.key"); - (RSA_2048_INTER_CERT, "rsa-2048", "inter.cert"); - (RSA_2048_INTER_KEY, "rsa-2048", "inter.key"); - - (RSA_3072_END_PEM_SPKI, "rsa-3072", "end.spki.pem"); - (RSA_3072_CLIENT_PEM_SPKI, "rsa-3072", "client.spki.pem"); - (RSA_3072_CA_CERT, "rsa-3072", "ca.cert"); - (RSA_3072_CA_DER, "rsa-3072", "ca.der"); - (RSA_3072_CA_KEY, "rsa-3072", "ca.key"); - (RSA_3072_CLIENT_CERT, "rsa-3072", "client.cert"); - (RSA_3072_CLIENT_CHAIN, "rsa-3072", "client.chain"); - (RSA_3072_CLIENT_FULLCHAIN, "rsa-3072", "client.fullchain"); - (RSA_3072_CLIENT_KEY, "rsa-3072", "client.key"); - (RSA_3072_END_CRL_PEM, "rsa-3072", "end.revoked.crl.pem"); - (RSA_3072_CLIENT_CRL_PEM, "rsa-3072", "client.revoked.crl.pem"); - (RSA_3072_INTERMEDIATE_CRL_PEM, "rsa-3072", "inter.revoked.crl.pem"); - (RSA_3072_EXPIRED_CRL_PEM, "rsa-3072", "end.expired.crl.pem"); - (RSA_3072_END_CERT, "rsa-3072", "end.cert"); - (RSA_3072_END_CHAIN, "rsa-3072", "end.chain"); - (RSA_3072_END_FULLCHAIN, "rsa-3072", "end.fullchain"); - (RSA_3072_END_KEY, "rsa-3072", "end.key"); - (RSA_3072_INTER_CERT, "rsa-3072", "inter.cert"); - (RSA_3072_INTER_KEY, "rsa-3072", "inter.key"); - - (RSA_4096_END_PEM_SPKI, "rsa-4096", "end.spki.pem"); - (RSA_4096_CLIENT_PEM_SPKI, "rsa-4096", "client.spki.pem"); - (RSA_4096_CA_CERT, "rsa-4096", "ca.cert"); - (RSA_4096_CA_DER, "rsa-4096", "ca.der"); - (RSA_4096_CA_KEY, "rsa-4096", "ca.key"); - (RSA_4096_CLIENT_CERT, "rsa-4096", "client.cert"); - (RSA_4096_CLIENT_CHAIN, "rsa-4096", "client.chain"); - (RSA_4096_CLIENT_FULLCHAIN, "rsa-4096", "client.fullchain"); - (RSA_4096_CLIENT_KEY, "rsa-4096", "client.key"); - (RSA_4096_END_CRL_PEM, "rsa-4096", "end.revoked.crl.pem"); - (RSA_4096_CLIENT_CRL_PEM, "rsa-4096", "client.revoked.crl.pem"); - (RSA_4096_INTERMEDIATE_CRL_PEM, "rsa-4096", "inter.revoked.crl.pem"); - (RSA_4096_EXPIRED_CRL_PEM, "rsa-4096", "end.expired.crl.pem"); - (RSA_4096_END_CERT, "rsa-4096", "end.cert"); - (RSA_4096_END_CHAIN, "rsa-4096", "end.chain"); - (RSA_4096_END_FULLCHAIN, "rsa-4096", "end.fullchain"); - (RSA_4096_END_KEY, "rsa-4096", "end.key"); - (RSA_4096_INTER_CERT, "rsa-4096", "inter.cert"); - (RSA_4096_INTER_KEY, "rsa-4096", "inter.key"); -} - -pub fn transfer( - left: &mut impl DerefMut>, - right: &mut impl DerefMut>, -) -> usize { - let mut buf = [0u8; 262144]; - let mut total = 0; - - while left.wants_write() { - let sz = { - let into_buf: &mut dyn io::Write = &mut &mut buf[..]; - left.write_tls(into_buf).unwrap() - }; - total += sz; - if sz == 0 { - return total; - } - - let mut offs = 0; - loop { - let from_buf: &mut dyn io::Read = &mut &buf[offs..sz]; - offs += right.read_tls(from_buf).unwrap(); - if sz == offs { - break; - } - } - } - - total -} - -pub fn transfer_eof(conn: &mut impl DerefMut>) { - let empty_buf = [0u8; 0]; - let empty_cursor: &mut dyn io::Read = &mut &empty_buf[..]; - let sz = conn.read_tls(empty_cursor).unwrap(); - assert_eq!(sz, 0); -} - -pub enum Altered { - /// message has been edited in-place (or is unchanged) - InPlace, - /// send these raw bytes instead of the message. - Raw(Vec), -} - -pub fn transfer_altered(left: &mut Connection, filter: F, right: &mut Connection) -> usize -where - F: Fn(&mut Message) -> Altered, -{ - let mut buf = [0u8; 262144]; - let mut total = 0; - - while left.wants_write() { - let sz = { - let into_buf: &mut dyn io::Write = &mut &mut buf[..]; - left.write_tls(into_buf).unwrap() - }; - total += sz; - if sz == 0 { - return total; - } - - let mut reader = Reader::init(&buf[..sz]); - while reader.any_left() { - let message = OutboundOpaqueMessage::read(&mut reader).unwrap(); - - // this is a bit of a falsehood: we don't know whether message - // is encrypted. it is quite unlikely that a genuine encrypted - // message can be decoded by `Message::try_from`. - let plain = message.into_plain_message(); - - let message_enc = match Message::try_from(plain.clone()) { - Ok(mut message) => match filter(&mut message) { - Altered::InPlace => PlainMessage::from(message) - .into_unencrypted_opaque() - .encode(), - Altered::Raw(data) => data, - }, - // pass through encrypted/undecodable messages - Err(_) => plain.into_unencrypted_opaque().encode(), - }; - - let message_enc_reader: &mut dyn io::Read = &mut &message_enc[..]; - let len = right - .read_tls(message_enc_reader) - .unwrap(); - assert_eq!(len, message_enc.len()); - } - } - - total -} - -#[derive(Clone, Copy, Debug, PartialEq)] -pub enum KeyType { - Rsa2048, - Rsa3072, - Rsa4096, - EcdsaP256, - EcdsaP384, - EcdsaP521, - Ed25519, -} - -pub static ALL_KEY_TYPES: &[KeyType] = &[ - KeyType::Rsa2048, - KeyType::Rsa3072, - KeyType::Rsa4096, - KeyType::EcdsaP256, - KeyType::EcdsaP384, - #[cfg(all(not(feature = "ring"), feature = "aws_lc_rs"))] - KeyType::EcdsaP521, - KeyType::Ed25519, -]; - -impl KeyType { - fn bytes_for(&self, part: &str) -> &'static [u8] { - match self { - Self::Rsa2048 => bytes_for("rsa-2048", part), - Self::Rsa3072 => bytes_for("rsa-3072", part), - Self::Rsa4096 => bytes_for("rsa-4096", part), - Self::EcdsaP256 => bytes_for("ecdsa-p256", part), - Self::EcdsaP384 => bytes_for("ecdsa-p384", part), - Self::EcdsaP521 => bytes_for("ecdsa-p521", part), - Self::Ed25519 => bytes_for("eddsa", part), - } - } - - pub fn ca_cert(&self) -> CertificateDer<'_> { - self.get_chain() - .into_iter() - .next_back() - .expect("cert chain cannot be empty") - } - - pub fn get_chain(&self) -> Vec> { - CertificateDer::pem_slice_iter(self.bytes_for("end.fullchain")) - .map(|result| result.unwrap()) - .collect() - } - - pub fn get_spki(&self) -> SubjectPublicKeyInfoDer<'static> { - SubjectPublicKeyInfoDer::from_pem_slice(self.bytes_for("end.spki.pem")).unwrap() - } - - pub fn get_key(&self) -> PrivateKeyDer<'static> { - PrivatePkcs8KeyDer::from_pem_slice(self.bytes_for("end.key")) - .unwrap() - .into() - } - - pub fn get_client_chain(&self) -> Vec> { - CertificateDer::pem_slice_iter(self.bytes_for("client.fullchain")) - .map(|result| result.unwrap()) - .collect() - } - - pub fn end_entity_crl(&self) -> CertificateRevocationListDer<'static> { - self.get_crl("end", "revoked") - } - - pub fn client_crl(&self) -> CertificateRevocationListDer<'static> { - self.get_crl("client", "revoked") - } - - pub fn intermediate_crl(&self) -> CertificateRevocationListDer<'static> { - self.get_crl("inter", "revoked") - } - - pub fn end_entity_crl_expired(&self) -> CertificateRevocationListDer<'static> { - self.get_crl("end", "expired") - } - - pub fn get_client_key(&self) -> PrivateKeyDer<'static> { - PrivatePkcs8KeyDer::from_pem_slice(self.bytes_for("client.key")) - .unwrap() - .into() - } - - pub fn get_client_spki(&self) -> SubjectPublicKeyInfoDer<'static> { - SubjectPublicKeyInfoDer::from_pem_slice(self.bytes_for("client.spki.pem")).unwrap() - } - - pub fn get_certified_client_key( - &self, - provider: &CryptoProvider, - ) -> Result, Error> { - let private_key = provider - .key_provider - .load_private_key(self.get_client_key())?; - let public_key = private_key - .public_key() - .ok_or(Error::InconsistentKeys(InconsistentKeys::Unknown))?; - let public_key_as_cert = CertificateDer::from(public_key.to_vec()); - Ok(Arc::new(CertifiedKey::new( - vec![public_key_as_cert], - private_key, - ))) - } - - pub fn certified_key_with_raw_pub_key( - &self, - provider: &CryptoProvider, - ) -> Result, Error> { - let private_key = provider - .key_provider - .load_private_key(self.get_key())?; - let public_key = private_key - .public_key() - .ok_or(Error::InconsistentKeys(InconsistentKeys::Unknown))?; - let public_key_as_cert = CertificateDer::from(public_key.to_vec()); - Ok(Arc::new(CertifiedKey::new( - vec![public_key_as_cert], - private_key, - ))) - } - - pub fn certified_key_with_cert_chain( - &self, - provider: &CryptoProvider, - ) -> Result, Error> { - let private_key = provider - .key_provider - .load_private_key(self.get_key())?; - Ok(Arc::new(CertifiedKey::new(self.get_chain(), private_key))) - } - - fn get_crl(&self, role: &str, r#type: &str) -> CertificateRevocationListDer<'static> { - CertificateRevocationListDer::from_pem_slice( - self.bytes_for(&format!("{role}.{type}.crl.pem")), - ) - .unwrap() - } - - pub fn ca_distinguished_name(&self) -> &'static [u8] { - match self { - KeyType::Rsa2048 => b"0\x1f1\x1d0\x1b\x06\x03U\x04\x03\x0c\x14ponytown RSA 2048 CA", - KeyType::Rsa3072 => b"0\x1f1\x1d0\x1b\x06\x03U\x04\x03\x0c\x14ponytown RSA 3072 CA", - KeyType::Rsa4096 => b"0\x1f1\x1d0\x1b\x06\x03U\x04\x03\x0c\x14ponytown RSA 4096 CA", - KeyType::EcdsaP256 => b"0\x211\x1f0\x1d\x06\x03U\x04\x03\x0c\x16ponytown ECDSA p256 CA", - KeyType::EcdsaP384 => b"0\x211\x1f0\x1d\x06\x03U\x04\x03\x0c\x16ponytown ECDSA p384 CA", - KeyType::EcdsaP521 => b"0\x211\x1f0\x1d\x06\x03U\x04\x03\x0c\x16ponytown ECDSA p521 CA", - KeyType::Ed25519 => b"0\x1c1\x1a0\x18\x06\x03U\x04\x03\x0c\x11ponytown EdDSA CA", - } - } -} +use rustls::RootCertStore; +use rustls::client::{ClientConfig, ServerCertVerifierBuilder, WebPkiServerVerifier}; +use rustls::crypto::CryptoProvider; +use rustls::server::{ClientCertVerifierBuilder, ServerConfig, WebPkiClientVerifier}; +pub use rustls_test::*; pub fn server_config_builder( provider: &CryptoProvider, @@ -504,244 +63,6 @@ pub fn client_config_builder_with_versions( } } -pub fn finish_server_config( - kt: KeyType, - conf: rustls::ConfigBuilder, -) -> ServerConfig { - conf.with_no_client_auth() - .with_single_cert(kt.get_chain(), kt.get_key()) - .unwrap() -} - -pub fn make_server_config(kt: KeyType, provider: &CryptoProvider) -> ServerConfig { - finish_server_config(kt, server_config_builder(provider)) -} - -pub fn make_server_config_with_versions( - kt: KeyType, - versions: &[&'static rustls::SupportedProtocolVersion], - provider: &CryptoProvider, -) -> ServerConfig { - finish_server_config(kt, server_config_builder_with_versions(versions, provider)) -} - -pub fn make_server_config_with_kx_groups( - kt: KeyType, - kx_groups: Vec<&'static dyn rustls::crypto::SupportedKxGroup>, - provider: &CryptoProvider, -) -> ServerConfig { - finish_server_config( - kt, - ServerConfig::builder_with_provider( - CryptoProvider { - kx_groups, - ..provider.clone() - } - .into(), - ) - .with_safe_default_protocol_versions() - .unwrap(), - ) -} - -pub fn get_client_root_store(kt: KeyType) -> Arc { - // The key type's chain file contains the DER encoding of the EE cert, the intermediate cert, - // and the root trust anchor. We want only the trust anchor to build the root cert store. - let chain = kt.get_chain(); - let mut roots = RootCertStore::empty(); - roots - .add(chain.last().unwrap().clone()) - .unwrap(); - roots.into() -} - -pub fn make_server_config_with_mandatory_client_auth_crls( - kt: KeyType, - crls: Vec>, - provider: &CryptoProvider, -) -> ServerConfig { - make_server_config_with_client_verifier( - kt, - webpki_client_verifier_builder(get_client_root_store(kt), provider).with_crls(crls), - provider, - ) -} - -pub fn make_server_config_with_mandatory_client_auth( - kt: KeyType, - provider: &CryptoProvider, -) -> ServerConfig { - make_server_config_with_client_verifier( - kt, - webpki_client_verifier_builder(get_client_root_store(kt), provider), - provider, - ) -} - -pub fn make_server_config_with_optional_client_auth( - kt: KeyType, - crls: Vec>, - provider: &CryptoProvider, -) -> ServerConfig { - make_server_config_with_client_verifier( - kt, - webpki_client_verifier_builder(get_client_root_store(kt), provider) - .with_crls(crls) - .allow_unknown_revocation_status() - .allow_unauthenticated(), - provider, - ) -} - -pub fn make_server_config_with_client_verifier( - kt: KeyType, - verifier_builder: ClientCertVerifierBuilder, - provider: &CryptoProvider, -) -> ServerConfig { - server_config_builder(provider) - .with_client_cert_verifier(verifier_builder.build().unwrap()) - .with_single_cert(kt.get_chain(), kt.get_key()) - .unwrap() -} - -pub fn make_server_config_with_raw_key_support( - kt: KeyType, - provider: &CryptoProvider, -) -> ServerConfig { - let mut client_verifier = - MockClientVerifier::new(|| Ok(ClientCertVerified::assertion()), kt, provider); - let server_cert_resolver = Arc::new(AlwaysResolvesServerRawPublicKeys::new( - kt.certified_key_with_raw_pub_key(provider) - .unwrap(), - )); - client_verifier.expect_raw_public_keys = true; - // We don't support tls1.2 for Raw Public Keys, hence the version is hard-coded. - server_config_builder_with_versions(&[&rustls::version::TLS13], provider) - .with_client_cert_verifier(Arc::new(client_verifier)) - .with_cert_resolver(server_cert_resolver) -} - -pub fn make_client_config_with_raw_key_support( - kt: KeyType, - provider: &CryptoProvider, -) -> ClientConfig { - let server_verifier = Arc::new(MockServerVerifier::expects_raw_public_keys(provider)); - let client_cert_resolver = Arc::new(AlwaysResolvesClientRawPublicKeys::new( - kt.get_certified_client_key(provider) - .unwrap(), - )); - // We don't support tls1.2 for Raw Public Keys, hence the version is hard-coded. - client_config_builder_with_versions(&[&rustls::version::TLS13], provider) - .dangerous() - .with_custom_certificate_verifier(server_verifier) - .with_client_cert_resolver(client_cert_resolver) -} - -pub fn make_client_config_with_cipher_suite_and_raw_key_support( - kt: KeyType, - cipher_suite: SupportedCipherSuite, - provider: &CryptoProvider, -) -> ClientConfig { - let server_verifier = Arc::new(MockServerVerifier::expects_raw_public_keys(provider)); - let client_cert_resolver = Arc::new(AlwaysResolvesClientRawPublicKeys::new( - kt.get_certified_client_key(provider) - .unwrap(), - )); - ClientConfig::builder_with_provider( - CryptoProvider { - cipher_suites: vec![cipher_suite], - ..provider.clone() - } - .into(), - ) - .with_protocol_versions(&[&rustls::version::TLS13]) - .unwrap() - .dangerous() - .with_custom_certificate_verifier(server_verifier) - .with_client_cert_resolver(client_cert_resolver) -} - -pub fn finish_client_config( - kt: KeyType, - config: rustls::ConfigBuilder, -) -> ClientConfig { - let mut root_store = RootCertStore::empty(); - root_store.add_parsable_certificates( - CertificateDer::pem_slice_iter(kt.bytes_for("ca.cert")).map(|result| result.unwrap()), - ); - - config - .with_root_certificates(root_store) - .with_no_client_auth() -} - -pub fn finish_client_config_with_creds( - kt: KeyType, - config: rustls::ConfigBuilder, -) -> ClientConfig { - let mut root_store = RootCertStore::empty(); - root_store.add_parsable_certificates( - CertificateDer::pem_slice_iter(kt.bytes_for("ca.cert")).map(|result| result.unwrap()), - ); - - config - .with_root_certificates(root_store) - .with_client_auth_cert(kt.get_client_chain(), kt.get_client_key()) - .unwrap() -} - -pub fn make_client_config(kt: KeyType, provider: &CryptoProvider) -> ClientConfig { - finish_client_config(kt, client_config_builder(provider)) -} - -pub fn make_client_config_with_kx_groups( - kt: KeyType, - kx_groups: Vec<&'static dyn rustls::crypto::SupportedKxGroup>, - provider: &CryptoProvider, -) -> ClientConfig { - let builder = ClientConfig::builder_with_provider( - CryptoProvider { - kx_groups, - ..provider.clone() - } - .into(), - ) - .with_safe_default_protocol_versions() - .unwrap(); - finish_client_config(kt, builder) -} - -pub fn make_client_config_with_versions( - kt: KeyType, - versions: &[&'static rustls::SupportedProtocolVersion], - provider: &CryptoProvider, -) -> ClientConfig { - finish_client_config(kt, client_config_builder_with_versions(versions, provider)) -} - -pub fn make_client_config_with_auth(kt: KeyType, provider: &CryptoProvider) -> ClientConfig { - finish_client_config_with_creds(kt, client_config_builder(provider)) -} - -pub fn make_client_config_with_versions_with_auth( - kt: KeyType, - versions: &[&'static rustls::SupportedProtocolVersion], - provider: &CryptoProvider, -) -> ClientConfig { - finish_client_config_with_creds(kt, client_config_builder_with_versions(versions, provider)) -} - -pub fn make_client_config_with_verifier( - versions: &[&'static rustls::SupportedProtocolVersion], - verifier_builder: ServerCertVerifierBuilder, - provider: &CryptoProvider, -) -> ClientConfig { - client_config_builder_with_versions(versions, provider) - .dangerous() - .with_custom_certificate_verifier(verifier_builder.build().unwrap()) - .with_no_client_auth() -} - pub fn webpki_client_verifier_builder( roots: Arc, provider: &CryptoProvider, @@ -764,1026 +85,9 @@ pub fn webpki_server_verifier_builder( } } -pub fn make_pair(kt: KeyType, provider: &CryptoProvider) -> (ClientConnection, ServerConnection) { - make_pair_for_configs( - make_client_config(kt, provider), - make_server_config(kt, provider), - ) -} - -pub fn make_pair_for_configs( - client_config: ClientConfig, - server_config: ServerConfig, -) -> (ClientConnection, ServerConnection) { - make_pair_for_arc_configs(&Arc::new(client_config), &Arc::new(server_config)) -} - -pub fn make_pair_for_arc_configs( - client_config: &Arc, - server_config: &Arc, -) -> (ClientConnection, ServerConnection) { - ( - ClientConnection::new(Arc::clone(client_config), server_name("localhost")).unwrap(), - ServerConnection::new(Arc::clone(server_config)).unwrap(), - ) -} - -pub fn do_handshake( - client: &mut impl DerefMut>, - server: &mut impl DerefMut>, -) -> (usize, usize) { - let (mut to_client, mut to_server) = (0, 0); - while server.is_handshaking() || client.is_handshaking() { - to_server += transfer(client, server); - server.process_new_packets().unwrap(); - to_client += transfer(server, client); - client.process_new_packets().unwrap(); - } - (to_server, to_client) -} - -// Drive a handshake using unbuffered connections. -// -// Note that this drives the connection beyond the handshake until both -// connections are idle and there is no pending data waiting to be processed -// by either. In practice this just means that session tickets are processed -// by the client. -pub fn do_unbuffered_handshake( - client: &mut UnbufferedClientConnection, - server: &mut UnbufferedServerConnection, -) { - fn is_idle(conn: &UnbufferedConnectionCommon, data: &[u8]) -> bool { - !conn.is_handshaking() && !conn.wants_write() && data.is_empty() - } - - let mut client_data = Vec::with_capacity(1024); - let mut server_data = Vec::with_capacity(1024); - - while !is_idle(client, &client_data) || !is_idle(server, &server_data) { - loop { - let UnbufferedStatus { discard, state } = client.process_tls_records(&mut client_data); - let state = state.unwrap(); - - match state { - ConnectionState::BlockedHandshake | ConnectionState::WriteTraffic(_) => { - client_data.drain(..discard); - break; - } - ConnectionState::Closed | ConnectionState::PeerClosed => unreachable!(), - ConnectionState::ReadEarlyData(_) => (), - ConnectionState::EncodeTlsData(mut data) => { - let required = match data.encode(&mut []) { - Err(EncodeError::InsufficientSize(err)) => err.required_size, - _ => unreachable!(), - }; - - let old_len = server_data.len(); - server_data.resize(old_len + required, 0); - data.encode(&mut server_data[old_len..]) - .unwrap(); - } - ConnectionState::TransmitTlsData(data) => data.done(), - st => unreachable!("unexpected connection state: {st:?}"), - } - - client_data.drain(..discard); - } - - loop { - let UnbufferedStatus { discard, state } = server.process_tls_records(&mut server_data); - let state = state.unwrap(); - - match state { - ConnectionState::BlockedHandshake | ConnectionState::WriteTraffic(_) => { - server_data.drain(..discard); - break; - } - ConnectionState::Closed | ConnectionState::PeerClosed => unreachable!(), - ConnectionState::ReadEarlyData(_) => unreachable!(), - ConnectionState::EncodeTlsData(mut data) => { - let required = match data.encode(&mut []) { - Err(EncodeError::InsufficientSize(err)) => err.required_size, - _ => unreachable!(), - }; - - let old_len = client_data.len(); - client_data.resize(old_len + required, 0); - data.encode(&mut client_data[old_len..]) - .unwrap(); - } - ConnectionState::TransmitTlsData(data) => data.done(), - _ => unreachable!(), - } - - server_data.drain(..discard); - } - } - - assert!(server_data.is_empty()); - assert!(client_data.is_empty()); -} - -#[derive(PartialEq, Debug)] -pub enum ErrorFromPeer { - Client(Error), - Server(Error), -} - -pub fn do_handshake_until_error( - client: &mut ClientConnection, - server: &mut ServerConnection, -) -> Result<(), ErrorFromPeer> { - while server.is_handshaking() || client.is_handshaking() { - transfer(client, server); - server - .process_new_packets() - .map_err(ErrorFromPeer::Server)?; - transfer(server, client); - client - .process_new_packets() - .map_err(ErrorFromPeer::Client)?; - } - - Ok(()) -} - -pub fn do_handshake_altered( - client: ClientConnection, - alter_server_message: impl Fn(&mut Message) -> Altered, - alter_client_message: impl Fn(&mut Message) -> Altered, - server: ServerConnection, -) -> Result<(), ErrorFromPeer> { - let mut client: Connection = Connection::Client(client); - let mut server: Connection = Connection::Server(server); - - while server.is_handshaking() || client.is_handshaking() { - transfer_altered(&mut client, &alter_client_message, &mut server); - - server - .process_new_packets() - .map_err(ErrorFromPeer::Server)?; - - transfer_altered(&mut server, &alter_server_message, &mut client); - - client - .process_new_packets() - .map_err(ErrorFromPeer::Client)?; - } - - Ok(()) -} - -pub fn do_handshake_until_both_error( - client: &mut ClientConnection, - server: &mut ServerConnection, -) -> Result<(), Vec> { - match do_handshake_until_error(client, server) { - Err(server_err @ ErrorFromPeer::Server(_)) => { - let mut errors = vec![server_err]; - transfer(server, client); - let client_err = client - .process_new_packets() - .map_err(ErrorFromPeer::Client) - .expect_err("client didn't produce error after server error"); - errors.push(client_err); - Err(errors) - } - - Err(client_err @ ErrorFromPeer::Client(_)) => { - let mut errors = vec![client_err]; - transfer(client, server); - let server_err = server - .process_new_packets() - .map_err(ErrorFromPeer::Server) - .expect_err("server didn't produce error after client error"); - errors.push(server_err); - Err(errors) - } - - Ok(()) => Ok(()), - } -} - -pub fn server_name(name: &'static str) -> ServerName<'static> { - name.try_into().unwrap() -} - -pub struct FailsReads { - errkind: io::ErrorKind, -} - -impl FailsReads { - pub fn new(errkind: io::ErrorKind) -> Self { - Self { errkind } - } -} - -impl io::Read for FailsReads { - fn read(&mut self, _b: &mut [u8]) -> io::Result { - Err(io::Error::from(self.errkind)) - } -} - -pub fn do_suite_and_kx_test( - client_config: ClientConfig, - server_config: ServerConfig, - expect_suite: SupportedCipherSuite, - expect_kx: NamedGroup, - expect_version: ProtocolVersion, -) { - println!( - "do_suite_test {:?} {:?}", - expect_version, - expect_suite.suite() - ); - let (mut client, mut server) = make_pair_for_configs(client_config, server_config); - - assert_eq!(None, client.negotiated_cipher_suite()); - assert_eq!(None, server.negotiated_cipher_suite()); - assert!( - client - .negotiated_key_exchange_group() - .is_none() - ); - assert!( - server - .negotiated_key_exchange_group() - .is_none() - ); - assert_eq!(None, client.protocol_version()); - assert_eq!(None, server.protocol_version()); - assert!(client.is_handshaking()); - assert!(server.is_handshaking()); - - transfer(&mut client, &mut server); - server.process_new_packets().unwrap(); - - assert!(client.is_handshaking()); - assert!(server.is_handshaking()); - assert_eq!(None, client.protocol_version()); - assert_eq!(Some(expect_version), server.protocol_version()); - assert_eq!(None, client.negotiated_cipher_suite()); - assert_eq!(Some(expect_suite), server.negotiated_cipher_suite()); - assert!( - client - .negotiated_key_exchange_group() - .is_none() - ); - if matches!(expect_version, ProtocolVersion::TLSv1_2) { - assert!( - server - .negotiated_key_exchange_group() - .is_none() - ); - } else { - assert_eq!( - expect_kx, - server - .negotiated_key_exchange_group() - .unwrap() - .name() - ); - } - - transfer(&mut server, &mut client); - client.process_new_packets().unwrap(); - - assert_eq!(Some(expect_suite), client.negotiated_cipher_suite()); - assert_eq!(Some(expect_suite), server.negotiated_cipher_suite()); - assert_eq!( - expect_kx, - client - .negotiated_key_exchange_group() - .unwrap() - .name() - ); - if matches!(expect_version, ProtocolVersion::TLSv1_2) { - assert!( - server - .negotiated_key_exchange_group() - .is_none() - ); - } else { - assert_eq!( - expect_kx, - server - .negotiated_key_exchange_group() - .unwrap() - .name() - ); - } - - transfer(&mut client, &mut server); - server.process_new_packets().unwrap(); - transfer(&mut server, &mut client); - client.process_new_packets().unwrap(); - - assert!(!client.is_handshaking()); - assert!(!server.is_handshaking()); - assert_eq!(Some(expect_version), client.protocol_version()); - assert_eq!(Some(expect_version), server.protocol_version()); - assert_eq!(Some(expect_suite), client.negotiated_cipher_suite()); - assert_eq!(Some(expect_suite), server.negotiated_cipher_suite()); - assert_eq!( - expect_kx, - client - .negotiated_key_exchange_group() - .unwrap() - .name() - ); - assert_eq!( - expect_kx, - server - .negotiated_key_exchange_group() - .unwrap() - .name() - ); -} - fn exactly_one_provider() -> bool { cfg!(any( all(feature = "ring", not(feature = "aws_lc_rs")), all(feature = "aws_lc_rs", not(feature = "ring")) )) } - -#[derive(Debug)] -pub struct MockServerVerifier { - cert_rejection_error: Option, - tls12_signature_error: Option, - tls13_signature_error: Option, - signature_schemes: Vec, - expected_ocsp_response: Option>, - requires_raw_public_keys: bool, - raw_public_key_algorithms: Option, -} - -impl ServerCertVerifier for MockServerVerifier { - fn verify_server_cert( - &self, - end_entity: &CertificateDer<'_>, - intermediates: &[CertificateDer<'_>], - server_name: &ServerName<'_>, - ocsp_response: &[u8], - now: UnixTime, - ) -> Result { - println!( - "verify_server_cert({end_entity:?}, {intermediates:?}, {server_name:?}, {ocsp_response:?}, {now:?})" - ); - if let Some(expected_ocsp) = &self.expected_ocsp_response { - assert_eq!(expected_ocsp, ocsp_response); - } - match &self.cert_rejection_error { - Some(error) => Err(error.clone()), - _ => Ok(ServerCertVerified::assertion()), - } - } - - fn verify_tls12_signature( - &self, - message: &[u8], - cert: &CertificateDer<'_>, - dss: &DigitallySignedStruct, - ) -> Result { - println!("verify_tls12_signature({message:?}, {cert:?}, {dss:?})"); - match &self.tls12_signature_error { - Some(error) => Err(error.clone()), - _ => Ok(HandshakeSignatureValid::assertion()), - } - } - - fn verify_tls13_signature( - &self, - message: &[u8], - cert: &CertificateDer<'_>, - dss: &DigitallySignedStruct, - ) -> Result { - println!("verify_tls13_signature({message:?}, {cert:?}, {dss:?})"); - match &self.tls13_signature_error { - Some(error) => Err(error.clone()), - _ if self.requires_raw_public_keys => verify_tls13_signature_with_raw_key( - message, - &SubjectPublicKeyInfoDer::from(cert.as_ref()), - dss, - self.raw_public_key_algorithms - .as_ref() - .unwrap(), - ), - _ => Ok(HandshakeSignatureValid::assertion()), - } - } - - fn supported_verify_schemes(&self) -> Vec { - self.signature_schemes.clone() - } - - fn requires_raw_public_keys(&self) -> bool { - self.requires_raw_public_keys - } -} - -impl MockServerVerifier { - pub fn accepts_anything() -> Self { - MockServerVerifier { - cert_rejection_error: None, - ..Default::default() - } - } - - pub fn expects_ocsp_response(response: &[u8]) -> Self { - MockServerVerifier { - expected_ocsp_response: Some(response.to_vec()), - ..Default::default() - } - } - - pub fn rejects_certificate(err: Error) -> Self { - MockServerVerifier { - cert_rejection_error: Some(err), - ..Default::default() - } - } - - pub fn rejects_tls12_signatures(err: Error) -> Self { - MockServerVerifier { - tls12_signature_error: Some(err), - ..Default::default() - } - } - - pub fn rejects_tls13_signatures(err: Error) -> Self { - MockServerVerifier { - tls13_signature_error: Some(err), - ..Default::default() - } - } - - pub fn offers_no_signature_schemes() -> Self { - MockServerVerifier { - signature_schemes: vec![], - ..Default::default() - } - } - - pub fn expects_raw_public_keys(provider: &CryptoProvider) -> Self { - MockServerVerifier { - requires_raw_public_keys: true, - raw_public_key_algorithms: Some(provider.signature_verification_algorithms), - ..Default::default() - } - } -} - -impl Default for MockServerVerifier { - fn default() -> Self { - MockServerVerifier { - cert_rejection_error: None, - tls12_signature_error: None, - tls13_signature_error: None, - signature_schemes: vec![ - SignatureScheme::RSA_PSS_SHA256, - SignatureScheme::RSA_PKCS1_SHA256, - SignatureScheme::ED25519, - SignatureScheme::ECDSA_NISTP256_SHA256, - SignatureScheme::ECDSA_NISTP384_SHA384, - SignatureScheme::ECDSA_NISTP521_SHA512, - ], - expected_ocsp_response: None, - requires_raw_public_keys: false, - raw_public_key_algorithms: None, - } - } -} - -#[derive(Debug)] -pub struct MockClientVerifier { - pub verified: fn() -> Result, - pub subjects: Vec, - pub mandatory: bool, - pub offered_schemes: Option>, - expect_raw_public_keys: bool, - raw_public_key_algorithms: Option, - parent: Arc, -} - -impl MockClientVerifier { - pub fn new( - verified: fn() -> Result, - kt: KeyType, - provider: &CryptoProvider, - ) -> Self { - Self { - parent: webpki_client_verifier_builder(get_client_root_store(kt), provider) - .build() - .unwrap(), - verified, - subjects: get_client_root_store(kt).subjects(), - mandatory: true, - offered_schemes: None, - expect_raw_public_keys: false, - raw_public_key_algorithms: Some(provider.signature_verification_algorithms), - } - } -} - -impl ClientCertVerifier for MockClientVerifier { - fn client_auth_mandatory(&self) -> bool { - self.mandatory - } - - fn root_hint_subjects(&self) -> &[DistinguishedName] { - &self.subjects - } - - fn verify_client_cert( - &self, - _end_entity: &CertificateDer<'_>, - _intermediates: &[CertificateDer<'_>], - _now: UnixTime, - ) -> Result { - (self.verified)() - } - - fn verify_tls12_signature( - &self, - message: &[u8], - cert: &CertificateDer<'_>, - dss: &DigitallySignedStruct, - ) -> Result { - if self.expect_raw_public_keys { - Ok(HandshakeSignatureValid::assertion()) - } else { - self.parent - .verify_tls12_signature(message, cert, dss) - } - } - - fn verify_tls13_signature( - &self, - message: &[u8], - cert: &CertificateDer<'_>, - dss: &DigitallySignedStruct, - ) -> Result { - if self.expect_raw_public_keys { - verify_tls13_signature_with_raw_key( - message, - &SubjectPublicKeyInfoDer::from(cert.as_ref()), - dss, - self.raw_public_key_algorithms - .as_ref() - .unwrap(), - ) - } else { - self.parent - .verify_tls13_signature(message, cert, dss) - } - } - - fn supported_verify_schemes(&self) -> Vec { - if let Some(schemes) = &self.offered_schemes { - schemes.clone() - } else { - self.parent.supported_verify_schemes() - } - } - - fn requires_raw_public_keys(&self) -> bool { - self.expect_raw_public_keys - } -} - -/// This allows injection/receipt of raw messages into a post-handshake connection. -/// -/// It consumes one of the peers, extracts its secrets, and then reconstitutes the -/// message encrypter/decrypter. It does not do fragmentation/joining. -pub struct RawTls { - encrypter: Box, - enc_seq: u64, - decrypter: Box, - dec_seq: u64, -} - -impl RawTls { - /// conn must be post-handshake, and must have been created with `enable_secret_extraction` - pub fn new_client(conn: ClientConnection) -> Self { - let suite = conn.negotiated_cipher_suite().unwrap(); - Self::new( - suite, - conn.dangerous_extract_secrets() - .unwrap(), - ) - } - - /// conn must be post-handshake, and must have been created with `enable_secret_extraction` - pub fn new_server(conn: ServerConnection) -> Self { - let suite = conn.negotiated_cipher_suite().unwrap(); - Self::new( - suite, - conn.dangerous_extract_secrets() - .unwrap(), - ) - } - - fn new(suite: SupportedCipherSuite, secrets: rustls::ExtractedSecrets) -> Self { - let rustls::ExtractedSecrets { - tx: (tx_seq, tx_keys), - rx: (rx_seq, rx_keys), - } = secrets; - - let encrypter = match (tx_keys, suite) { - ( - rustls::ConnectionTrafficSecrets::Aes256Gcm { key, iv }, - SupportedCipherSuite::Tls13(tls13), - ) => tls13.aead_alg.encrypter(key, iv), - - ( - rustls::ConnectionTrafficSecrets::Aes256Gcm { key, iv }, - SupportedCipherSuite::Tls12(tls12), - ) => tls12 - .aead_alg - .encrypter(key, &iv.as_ref()[..4], &iv.as_ref()[4..]), - - _ => todo!(), - }; - - let decrypter = match (rx_keys, suite) { - ( - rustls::ConnectionTrafficSecrets::Aes256Gcm { key, iv }, - SupportedCipherSuite::Tls13(tls13), - ) => tls13.aead_alg.decrypter(key, iv), - - ( - rustls::ConnectionTrafficSecrets::Aes256Gcm { key, iv }, - SupportedCipherSuite::Tls12(tls12), - ) => tls12 - .aead_alg - .decrypter(key, &iv.as_ref()[..4]), - - _ => todo!(), - }; - - Self { - encrypter, - enc_seq: tx_seq, - decrypter, - dec_seq: rx_seq, - } - } - - pub fn encrypt_and_send( - &mut self, - msg: &PlainMessage, - peer: &mut impl DerefMut>, - ) { - let data = self - .encrypter - .encrypt(msg.borrow_outbound(), self.enc_seq) - .unwrap() - .encode(); - self.enc_seq += 1; - peer.read_tls(&mut io::Cursor::new(data)) - .unwrap(); - } - - pub fn receive_and_decrypt( - &mut self, - peer: &mut impl DerefMut>, - f: impl Fn(Message), - ) { - let mut data = vec![]; - peer.write_tls(&mut io::Cursor::new(&mut data)) - .unwrap(); - - let mut reader = Reader::init(&data); - let content_type = ContentType::read(&mut reader).unwrap(); - let version = ProtocolVersion::read(&mut reader).unwrap(); - let len = u16::read(&mut reader).unwrap(); - let left = &mut data[5..]; - assert_eq!(len as usize, left.len()); - - let inbound = InboundOpaqueMessage::new(content_type, version, left); - let plain = self - .decrypter - .decrypt(inbound, self.dec_seq) - .unwrap(); - self.dec_seq += 1; - - let msg = Message::try_from(plain).unwrap(); - println!("receive_and_decrypt: {msg:?}"); - - f(msg); - } -} - -pub fn aes_128_gcm_with_1024_confidentiality_limit( - provider: CryptoProvider, -) -> Arc { - const CONFIDENTIALITY_LIMIT: u64 = 1024; - - // needed to extend lifetime of Tls13CipherSuite to 'static - static TLS13_LIMITED_SUITE: OnceLock = OnceLock::new(); - static TLS12_LIMITED_SUITE: OnceLock = OnceLock::new(); - - let tls13_limited = TLS13_LIMITED_SUITE.get_or_init(|| { - let tls13 = provider - .cipher_suites - .iter() - .find(|cs| cs.suite() == CipherSuite::TLS13_AES_128_GCM_SHA256) - .unwrap() - .tls13() - .unwrap(); - - rustls::Tls13CipherSuite { - common: rustls::crypto::CipherSuiteCommon { - confidentiality_limit: CONFIDENTIALITY_LIMIT, - ..tls13.common - }, - ..*tls13 - } - }); - - let tls12_limited = TLS12_LIMITED_SUITE.get_or_init(|| { - let SupportedCipherSuite::Tls12(tls12) = *provider - .cipher_suites - .iter() - .find(|cs| cs.suite() == CipherSuite::TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256) - .unwrap() - else { - unreachable!(); - }; - - rustls::Tls12CipherSuite { - common: rustls::crypto::CipherSuiteCommon { - confidentiality_limit: CONFIDENTIALITY_LIMIT, - ..tls12.common - }, - ..*tls12 - } - }); - - CryptoProvider { - cipher_suites: vec![ - SupportedCipherSuite::Tls13(tls13_limited), - SupportedCipherSuite::Tls12(tls12_limited), - ], - ..provider - } - .into() -} - -pub fn unsafe_plaintext_crypto_provider(provider: CryptoProvider) -> Arc { - static TLS13_PLAIN_SUITE: OnceLock = OnceLock::new(); - - let tls13 = TLS13_PLAIN_SUITE.get_or_init(|| { - let tls13 = provider - .cipher_suites - .iter() - .find(|cs| cs.suite() == CipherSuite::TLS13_AES_256_GCM_SHA384) - .unwrap() - .tls13() - .unwrap(); - - rustls::Tls13CipherSuite { - aead_alg: &plaintext::Aead, - common: rustls::crypto::CipherSuiteCommon { ..tls13.common }, - ..*tls13 - } - }); - - CryptoProvider { - cipher_suites: vec![SupportedCipherSuite::Tls13(tls13)], - ..provider - } - .into() -} - -mod plaintext { - use rustls::ConnectionTrafficSecrets; - use rustls::crypto::cipher::{ - AeadKey, InboundOpaqueMessage, InboundPlainMessage, Iv, MessageDecrypter, MessageEncrypter, - OutboundPlainMessage, PrefixedPayload, Tls13AeadAlgorithm, UnsupportedOperationError, - }; - - use super::*; - - pub(super) struct Aead; - - impl Tls13AeadAlgorithm for Aead { - fn encrypter(&self, _key: AeadKey, _iv: Iv) -> Box { - Box::new(Encrypter) - } - - fn decrypter(&self, _key: AeadKey, _iv: Iv) -> Box { - Box::new(Decrypter) - } - - fn key_len(&self) -> usize { - 32 - } - - fn extract_keys( - &self, - _key: AeadKey, - _iv: Iv, - ) -> Result { - Err(UnsupportedOperationError) - } - } - - struct Encrypter; - - impl MessageEncrypter for Encrypter { - fn encrypt( - &mut self, - msg: OutboundPlainMessage<'_>, - _seq: u64, - ) -> Result { - let mut payload = PrefixedPayload::with_capacity(msg.payload.len()); - payload.extend_from_chunks(&msg.payload); - - Ok(OutboundOpaqueMessage::new( - ContentType::ApplicationData, - ProtocolVersion::TLSv1_2, - payload, - )) - } - - fn encrypted_payload_len(&self, payload_len: usize) -> usize { - payload_len - } - } - - struct Decrypter; - - impl MessageDecrypter for Decrypter { - fn decrypt<'a>( - &mut self, - msg: InboundOpaqueMessage<'a>, - _seq: u64, - ) -> Result, Error> { - Ok(msg.into_plain_message()) - } - } -} - -/// Deeply inefficient, test-only TLS encoding helpers -pub mod encoding { - use rustls::internal::msgs::codec::Codec; - use rustls::internal::msgs::enums::ExtensionType; - use rustls::{ - CipherSuite, ContentType, HandshakeType, NamedGroup, ProtocolVersion, SignatureScheme, - }; - - /// Return a client hello with mandatory extensions added to `extensions` - /// - /// The returned bytes are handshake-framed, but not message-framed. - pub fn basic_client_hello(mut extensions: Vec) -> Vec { - extensions.push(Extension::new_kx_groups()); - extensions.push(Extension::new_sig_algs()); - extensions.push(Extension::new_versions()); - extensions.push(Extension::new_dummy_key_share()); - client_hello_with_extensions(extensions) - } - - /// Return a client hello with exactly `extensions` - /// - /// The returned bytes are handshake-framed, but not message-framed. - pub fn client_hello_with_extensions(extensions: Vec) -> Vec { - client_hello( - ProtocolVersion::TLSv1_2, - &[0u8; 32], - &[0], - vec![ - CipherSuite::TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, - CipherSuite::TLS13_AES_128_GCM_SHA256, - ], - extensions, - ) - } - - pub fn client_hello( - legacy_version: ProtocolVersion, - random: &[u8; 32], - session_id: &[u8], - cipher_suites: Vec, - extensions: Vec, - ) -> Vec { - let mut out = vec![]; - - legacy_version.encode(&mut out); - out.extend_from_slice(random); - out.extend_from_slice(session_id); - cipher_suites.to_vec().encode(&mut out); - out.extend_from_slice(&[0x01, 0x00]); // only null compression - - let mut exts = vec![]; - for e in extensions { - e.typ.encode(&mut exts); - exts.extend_from_slice(&(e.body.len() as u16).to_be_bytes()); - exts.extend_from_slice(&e.body); - } - - out.extend(len_u16(exts)); - handshake_framing(HandshakeType::ClientHello, out) - } - - /// Apply handshake framing to `body`. - /// - /// This does not do fragmentation. - pub fn handshake_framing(ty: HandshakeType, body: Vec) -> Vec { - let mut body = len_u24(body); - body.splice(0..0, ty.to_array()); - body - } - - /// Apply message framing to `body`. - pub fn message_framing(ty: ContentType, vers: ProtocolVersion, body: Vec) -> Vec { - let mut body = len_u16(body); - body.splice(0..0, vers.to_array()); - body.splice(0..0, ty.to_array()); - body - } - - #[derive(Clone)] - pub struct Extension { - pub typ: ExtensionType, - pub body: Vec, - } - - impl Extension { - pub fn new_sig_algs() -> Extension { - Extension { - typ: ExtensionType::SignatureAlgorithms, - body: len_u16( - SignatureScheme::RSA_PKCS1_SHA256 - .to_array() - .to_vec(), - ), - } - } - - pub fn new_kx_groups() -> Extension { - Extension { - typ: ExtensionType::EllipticCurves, - body: len_u16(vector_of([NamedGroup::secp256r1].into_iter())), - } - } - - pub fn new_versions() -> Extension { - Extension { - typ: ExtensionType::SupportedVersions, - body: len_u8(vector_of( - [ProtocolVersion::TLSv1_3, ProtocolVersion::TLSv1_2].into_iter(), - )), - } - } - - pub fn new_dummy_key_share() -> Extension { - const SOME_POINT_ON_P256: &[u8] = &[ - 4, 41, 39, 177, 5, 18, 186, 227, 237, 220, 254, 70, 120, 40, 18, 139, 173, 41, 3, - 38, 153, 25, 247, 8, 96, 105, 200, 196, 223, 108, 115, 40, 56, 199, 120, 121, 100, - 234, 172, 0, 229, 146, 31, 177, 73, 138, 96, 244, 96, 103, 102, 179, 217, 104, 80, - 1, 85, 141, 26, 151, 78, 115, 65, 81, 62, - ]; - - let mut share = len_u16(SOME_POINT_ON_P256.to_vec()); - share.splice(0..0, NamedGroup::secp256r1.to_array()); - - Extension { - typ: ExtensionType::KeyShare, - body: len_u16(share), - } - } - } - - /// Prefix with u8 length - pub fn len_u8(mut body: Vec) -> Vec { - body.splice(0..0, [body.len() as u8]); - body - } - - /// Prefix with u16 length - pub fn len_u16(mut body: Vec) -> Vec { - body.splice(0..0, (body.len() as u16).to_be_bytes()); - body - } - - /// Prefix with u24 length - pub fn len_u24(mut body: Vec) -> Vec { - let len = (body.len() as u32).to_be_bytes(); - body.insert(0, len[1]); - body.insert(1, len[2]); - body.insert(2, len[3]); - body - } - - /// Encode each of `items` - pub fn vector_of<'a, T: Codec<'a>>(items: impl Iterator) -> Vec { - let mut body = Vec::new(); - - for i in items { - i.encode(&mut body); - } - body - } -} From 48790d512b5a5bce0f9d62f45a819691ba38b454 Mon Sep 17 00:00:00 2001 From: Joe Birr-Pixton Date: Mon, 19 May 2025 14:55:37 +0100 Subject: [PATCH 13/14] rustls-bench: use rustls-test keys/certificates --- Cargo.lock | 1 + rustls-bench/Cargo.toml | 1 + rustls-bench/src/main.rs | 60 ++++++++++------------------------------ 3 files changed, 17 insertions(+), 45 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 4b545ee4bc..ee91ef1f79 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2624,6 +2624,7 @@ dependencies = [ "clap", "rustls 0.23.27", "rustls-post-quantum", + "rustls-test", "tikv-jemallocator", ] diff --git a/rustls-bench/Cargo.toml b/rustls-bench/Cargo.toml index 36f70198b1..0b7ff90f7e 100644 --- a/rustls-bench/Cargo.toml +++ b/rustls-bench/Cargo.toml @@ -8,6 +8,7 @@ publish = false clap = { workspace = true } rustls = { path = "../rustls" } rustls-post-quantum = { path = "../rustls-post-quantum", optional = true } +rustls-test = { workspace = true } [features] default = [] diff --git a/rustls-bench/src/main.rs b/rustls-bench/src/main.rs index 73385f74f7..707c0319d7 100644 --- a/rustls-bench/src/main.rs +++ b/rustls-bench/src/main.rs @@ -14,8 +14,6 @@ use std::{mem, thread}; use clap::{Parser, ValueEnum}; use rustls::client::{Resumption, UnbufferedClientConnection}; use rustls::crypto::CryptoProvider; -use rustls::pki_types::pem::PemObject; -use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer}; use rustls::server::{ NoServerSessionStorage, ProducesTickets, ServerSessionMemoryCache, UnbufferedServerConnection, WebPkiClientVerifier, @@ -25,6 +23,7 @@ use rustls::{ CipherSuite, ClientConfig, ClientConnection, ConnectionCommon, Error, HandshakeKind, RootCertStore, ServerConfig, ServerConnection, SideData, }; +use rustls_test::KeyType; pub fn main() { let args = Args::parse(); @@ -111,7 +110,7 @@ struct Args { long, help = "Which key type to use for server and client authentication. The default is to run tests once for each key type." )] - key_type: Option, + key_type: Option, #[arg(long, help = "Which provider to test")] provider: Option, @@ -669,13 +668,13 @@ fn bench_memory( fn lookup_matching_benches( ciphersuite_name: &str, - key_type: Option, + key_type: Option, ) -> Vec { let r: Vec = ALL_BENCHMARKS .iter() .filter(|params| { format!("{:?}", params.ciphersuite).to_lowercase() == ciphersuite_name.to_lowercase() - && (key_type.is_none() || Some(params.key_type) == key_type) + && (key_type.is_none() || Some(params.key_type) == key_type.map(KeyType::from)) }) .cloned() .collect(); @@ -808,11 +807,9 @@ impl Parameters { fn client_config(&self) -> Arc { let mut root_store = RootCertStore::empty(); - root_store.add_parsable_certificates( - CertificateDer::pem_file_iter(self.proto.key_type.path_for("ca.cert")) - .unwrap() - .map(|result| result.unwrap()), - ); + root_store + .add(self.proto.key_type.ca_cert()) + .unwrap(); let cfg = ClientConfig::builder_with_provider( CryptoProvider { @@ -1035,50 +1032,23 @@ impl BenchmarkParam { } } -// copied from tests/api.rs #[derive(PartialEq, Clone, Copy, Debug, ValueEnum)] -enum KeyType { +enum RequestedKeyType { Rsa2048, EcdsaP256, EcdsaP384, Ed25519, } -impl KeyType { - fn path_for(&self, part: &str) -> String { - match self { - Self::Rsa2048 => format!("test-ca/rsa-2048/{part}"), - Self::EcdsaP256 => format!("test-ca/ecdsa-p256/{part}"), - Self::EcdsaP384 => format!("test-ca/ecdsa-p384/{part}"), - Self::Ed25519 => format!("test-ca/eddsa/{part}"), +impl From for KeyType { + fn from(val: RequestedKeyType) -> Self { + match val { + RequestedKeyType::Rsa2048 => Self::Rsa2048, + RequestedKeyType::EcdsaP256 => Self::EcdsaP256, + RequestedKeyType::EcdsaP384 => Self::EcdsaP384, + RequestedKeyType::Ed25519 => Self::Ed25519, } } - - fn get_chain(&self) -> Vec> { - CertificateDer::pem_file_iter(self.path_for("end.fullchain")) - .unwrap() - .map(|result| result.unwrap()) - .collect() - } - - fn get_key(&self) -> PrivateKeyDer<'static> { - PrivatePkcs8KeyDer::from_pem_file(self.path_for("end.key")) - .unwrap() - .into() - } - - fn get_client_chain(&self) -> Vec> { - CertificateDer::pem_file_iter(self.path_for("client.fullchain")) - .unwrap() - .map(|result| result.unwrap()) - .collect() - } - - fn get_client_key(&self) -> PrivateKeyDer<'static> { - PrivatePkcs8KeyDer::from_pem_file(self.path_for("client.key")) - .unwrap() - .into() - } } struct Unbuffered { From 9dc5d32706256d588dd8824155b1d31e04b608b3 Mon Sep 17 00:00:00 2001 From: Joe Birr-Pixton Date: Mon, 19 May 2025 16:41:57 +0100 Subject: [PATCH 14/14] rustls-test: determine supported keytypes at runtime --- rustls-test/src/lib.rs | 22 +++++- rustls/tests/api.rs | 105 ++++++++++++++------------- rustls/tests/client_cert_verifier.rs | 11 +-- rustls/tests/server_cert_verifier.rs | 12 +-- 4 files changed, 86 insertions(+), 64 deletions(-) diff --git a/rustls-test/src/lib.rs b/rustls-test/src/lib.rs index 64e922bad1..1556c8ae7b 100644 --- a/rustls-test/src/lib.rs +++ b/rustls-test/src/lib.rs @@ -302,7 +302,7 @@ pub enum KeyType { Ed25519, } -pub static ALL_KEY_TYPES: &[KeyType] = &[ +static ALL_KEY_TYPES: &[KeyType] = &[ KeyType::Rsa2048, KeyType::Rsa3072, KeyType::Rsa4096, @@ -312,7 +312,27 @@ pub static ALL_KEY_TYPES: &[KeyType] = &[ KeyType::Ed25519, ]; +static ALL_KEY_TYPES_EXCEPT_P521: &[KeyType] = &[ + KeyType::Rsa2048, + KeyType::Rsa3072, + KeyType::Rsa4096, + KeyType::EcdsaP256, + KeyType::EcdsaP384, + KeyType::Ed25519, +]; + impl KeyType { + pub fn all_for_provider(provider: &CryptoProvider) -> &'static [KeyType] { + match provider + .key_provider + .load_private_key(Self::EcdsaP521.get_key()) + .is_ok() + { + true => ALL_KEY_TYPES, + false => ALL_KEY_TYPES_EXCEPT_P521, + } + } + fn bytes_for(&self, part: &str) -> &'static [u8] { match self { Self::Rsa2048 => bytes_for("rsa-2048", part), diff --git a/rustls/tests/api.rs b/rustls/tests/api.rs index a689bf1a9f..b999ab5b20 100644 --- a/rustls/tests/api.rs +++ b/rustls/tests/api.rs @@ -61,7 +61,7 @@ mod test_raw_keys { #[test] fn successful_raw_key_connection_and_correct_peer_certificates() { let provider = provider::default_provider(); - for kt in ALL_KEY_TYPES { + for kt in KeyType::all_for_provider(&provider) { let client_config = make_client_config_with_raw_key_support(*kt, &provider); let server_config = make_server_config_with_raw_key_support(*kt, &provider); @@ -97,7 +97,7 @@ mod test_raw_keys { #[test] fn correct_certificate_type_extensions_from_client_hello() { let provider = provider::default_provider(); - for kt in ALL_KEY_TYPES { + for kt in KeyType::all_for_provider(&provider) { let client_config = make_client_config_with_raw_key_support(*kt, &provider); let mut server_config = make_server_config_with_raw_key_support(*kt, &provider); @@ -116,7 +116,7 @@ mod test_raw_keys { #[test] fn only_client_supports_raw_keys() { let provider = provider::default_provider(); - for kt in ALL_KEY_TYPES { + for kt in KeyType::all_for_provider(&provider) { let client_config_rpk = make_client_config_with_raw_key_support(*kt, &provider); let server_config = make_server_config(*kt, &provider); @@ -143,7 +143,7 @@ mod test_raw_keys { #[test] fn only_server_supports_raw_keys() { let provider = provider::default_provider(); - for kt in ALL_KEY_TYPES { + for kt in KeyType::all_for_provider(&provider) { let client_config = make_client_config_with_versions(*kt, &[&rustls::version::TLS13], &provider); let server_config_rpk = make_server_config_with_raw_key_support(*kt, &provider); @@ -194,7 +194,7 @@ mod test_raw_keys { expected_result: Result<(), ErrorFromPeer>, ) { let provider = provider::default_provider(); - for kt in ALL_KEY_TYPES { + for kt in KeyType::all_for_provider(&provider) { let client_config = Arc::new(make_client_config(*kt, &provider)); let server_config_rpk = match server_requires_raw_keys { true => Arc::new(make_server_config_with_raw_key_support(*kt, &provider)), @@ -218,7 +218,7 @@ mod test_raw_keys { #[test] fn incorrectly_alter_server_hello() { let provider = provider::default_provider(); - for kt in ALL_KEY_TYPES { + for kt in KeyType::all_for_provider(&provider) { let supported_suite = cipher_suite::TLS13_AES_256_GCM_SHA384; // Alter Server Hello server certificate extension and expect IncorrectCertificateTypeExtension error @@ -885,7 +885,7 @@ fn buffered_both_data_sent() { #[test] fn client_can_get_server_cert() { let provider = provider::default_provider(); - for kt in ALL_KEY_TYPES { + for kt in KeyType::all_for_provider(&provider) { for version in rustls::ALL_VERSIONS { let client_config = make_client_config_with_versions(*kt, &[version], &provider); let (mut client, mut server) = @@ -901,7 +901,7 @@ fn client_can_get_server_cert() { #[test] fn client_can_get_server_cert_after_resumption() { let provider = provider::default_provider(); - for kt in ALL_KEY_TYPES { + for kt in KeyType::all_for_provider(&provider) { let server_config = make_server_config(*kt, &provider); for version in rustls::ALL_VERSIONS { let client_config = make_client_config_with_versions(*kt, &[version], &provider); @@ -995,7 +995,7 @@ fn client_only_attempts_resumption_with_compatible_security() { #[test] fn server_can_get_client_cert() { let provider = provider::default_provider(); - for kt in ALL_KEY_TYPES { + for kt in KeyType::all_for_provider(&provider) { let server_config = Arc::new(make_server_config_with_mandatory_client_auth( *kt, &provider, )); @@ -1016,7 +1016,7 @@ fn server_can_get_client_cert() { #[test] fn server_can_get_client_cert_after_resumption() { let provider = provider::default_provider(); - for kt in ALL_KEY_TYPES { + for kt in KeyType::all_for_provider(&provider) { let server_config = Arc::new(make_server_config_with_mandatory_client_auth( *kt, &provider, )); @@ -1042,7 +1042,7 @@ fn server_can_get_client_cert_after_resumption() { #[test] fn resumption_combinations() { let provider = provider::default_provider(); - for kt in ALL_KEY_TYPES { + for kt in KeyType::all_for_provider(&provider) { let server_config = make_server_config(*kt, &provider); for version in rustls::ALL_VERSIONS { let client_config = make_client_config_with_versions(*kt, &[version], &provider); @@ -1527,7 +1527,7 @@ impl ResolvesServerCert for ServerCheckCertResolve { #[test] fn server_cert_resolve_with_sni() { let provider = provider::default_provider(); - for kt in ALL_KEY_TYPES { + for kt in KeyType::all_for_provider(&provider) { let client_config = make_client_config(*kt, &provider); let mut server_config = make_server_config(*kt, &provider); @@ -1549,7 +1549,7 @@ fn server_cert_resolve_with_sni() { #[test] fn server_cert_resolve_with_alpn() { let provider = provider::default_provider(); - for kt in ALL_KEY_TYPES { + for kt in KeyType::all_for_provider(&provider) { let mut client_config = make_client_config(*kt, &provider); client_config.alpn_protocols = vec!["foo".into(), "bar".into()]; @@ -1571,7 +1571,7 @@ fn server_cert_resolve_with_alpn() { #[test] fn client_trims_terminating_dot() { let provider = provider::default_provider(); - for kt in ALL_KEY_TYPES { + for kt in KeyType::all_for_provider(&provider) { let client_config = make_client_config(*kt, &provider); let mut server_config = make_server_config(*kt, &provider); @@ -1678,7 +1678,7 @@ impl ResolvesServerCert for ServerCheckNoSni { #[test] fn client_with_sni_disabled_does_not_send_sni() { let provider = provider::default_provider(); - for kt in ALL_KEY_TYPES { + for kt in KeyType::all_for_provider(&provider) { let mut server_config = make_server_config(*kt, &provider); server_config.cert_resolver = Arc::new(ServerCheckNoSni {}); let server_config = Arc::new(server_config); @@ -1701,7 +1701,7 @@ fn client_with_sni_disabled_does_not_send_sni() { #[test] fn client_checks_server_certificate_with_given_name() { let provider = provider::default_provider(); - for kt in ALL_KEY_TYPES { + for kt in KeyType::all_for_provider(&provider) { let server_config = Arc::new(make_server_config(*kt, &provider)); for version in rustls::ALL_VERSIONS { @@ -1737,7 +1737,7 @@ fn client_checks_server_certificate_with_given_ip_address() { } let provider = provider::default_provider(); - for kt in ALL_KEY_TYPES { + for kt in KeyType::all_for_provider(&provider) { let server_config = Arc::new(make_server_config(*kt, &provider)); for version in rustls::ALL_VERSIONS { @@ -1778,7 +1778,7 @@ fn client_checks_server_certificate_with_given_ip_address() { #[test] fn client_check_server_certificate_ee_revoked() { let provider = provider::default_provider(); - for kt in ALL_KEY_TYPES { + for kt in KeyType::all_for_provider(&provider) { let server_config = Arc::new(make_server_config(*kt, &provider)); // Setup a server verifier that will check the EE certificate's revocation status. @@ -1809,7 +1809,7 @@ fn client_check_server_certificate_ee_revoked() { #[test] fn client_check_server_certificate_ee_unknown_revocation() { let provider = provider::default_provider(); - for kt in ALL_KEY_TYPES { + for kt in KeyType::all_for_provider(&provider) { let server_config = Arc::new(make_server_config(*kt, &provider)); // Setup a server verifier builder that will check the EE certificate's revocation status, but not @@ -1866,7 +1866,7 @@ fn client_check_server_certificate_ee_unknown_revocation() { #[test] fn client_check_server_certificate_intermediate_revoked() { let provider = provider::default_provider(); - for kt in ALL_KEY_TYPES { + for kt in KeyType::all_for_provider(&provider) { let server_config = Arc::new(make_server_config(*kt, &provider)); // Setup a server verifier builder that will check the full chain revocation status against a CRL @@ -1925,7 +1925,7 @@ fn client_check_server_certificate_intermediate_revoked() { #[test] fn client_check_server_certificate_ee_crl_expired() { let provider = provider::default_provider(); - for kt in ALL_KEY_TYPES { + for kt in KeyType::all_for_provider(&provider) { let server_config = Arc::new(make_server_config(*kt, &provider)); // Setup a server verifier that will check the EE certificate's revocation status, with CRL expiration enforced. @@ -1983,7 +1983,7 @@ fn client_check_server_certificate_ee_crl_expired() { /// so isn't used by the other existing verifier tests. #[test] fn client_check_server_certificate_helper_api() { - for kt in ALL_KEY_TYPES { + for kt in KeyType::all_for_provider(&provider::default_provider()) { let chain = kt.get_chain(); let correct_roots = get_client_root_store(*kt); let incorrect_roots = get_client_root_store(match kt { @@ -2165,7 +2165,7 @@ fn client_cert_resolve_default() { // Test that in the default configuration that a client cert resolver gets the expected // CA subject hints, and supported signature algorithms. let provider = provider::default_provider(); - for key_type in ALL_KEY_TYPES { + for key_type in KeyType::all_for_provider(&provider) { let server_config = Arc::new(make_server_config_with_mandatory_client_auth( *key_type, &provider, )); @@ -2187,7 +2187,7 @@ fn client_cert_resolve_server_no_hints() { // Test that a server can provide no hints and the client cert resolver gets the expected // arguments. let provider = provider::default_provider(); - for key_type in ALL_KEY_TYPES { + for key_type in KeyType::all_for_provider(&provider) { // Build a verifier with no hint subjects. let verifier = webpki_client_verifier_builder(get_client_root_store(*key_type), &provider) .clear_root_hint_subjects(); @@ -2203,7 +2203,7 @@ fn client_cert_resolve_server_added_hint() { // and the client cert resolver gets the expected arguments. let provider = provider::default_provider(); let extra_name = b"0\x1a1\x180\x16\x06\x03U\x04\x03\x0c\x0fponyland IDK CA".to_vec(); - for key_type in ALL_KEY_TYPES { + for key_type in KeyType::all_for_provider(&provider) { let expected_hint_subjects = vec![ key_type .ca_distinguished_name() @@ -2222,7 +2222,7 @@ fn client_cert_resolve_server_added_hint() { #[test] fn client_auth_works() { let provider = provider::default_provider(); - for kt in ALL_KEY_TYPES { + for kt in KeyType::all_for_provider(&provider) { let server_config = Arc::new(make_server_config_with_mandatory_client_auth( *kt, &provider, )); @@ -2239,10 +2239,10 @@ fn client_auth_works() { #[test] fn client_mandatory_auth_client_revocation_works() { - for kt in ALL_KEY_TYPES { + let provider = provider::default_provider(); + for kt in KeyType::all_for_provider(&provider) { // Create a server configuration that includes a CRL that specifies the client certificate // is revoked. - let provider = provider::default_provider(); let relevant_crls = vec![kt.client_crl()]; // Only check the EE certificate status. See client_mandatory_auth_intermediate_revocation_works // for testing revocation status of the whole chain. @@ -2321,7 +2321,7 @@ fn client_mandatory_auth_client_revocation_works() { #[test] fn client_mandatory_auth_intermediate_revocation_works() { let provider = provider::default_provider(); - for kt in ALL_KEY_TYPES { + for kt in KeyType::all_for_provider(&provider) { // Create a server configuration that includes a CRL that specifies the intermediate certificate // is revoked. We check the full chain for revocation status (default), and allow unknown // revocation status so the EE's unknown revocation status isn't an error. @@ -2377,7 +2377,7 @@ fn client_mandatory_auth_intermediate_revocation_works() { #[test] fn client_optional_auth_client_revocation_works() { let provider = provider::default_provider(); - for kt in ALL_KEY_TYPES { + for kt in KeyType::all_for_provider(&provider) { // Create a server configuration that includes a CRL that specifies the client certificate // is revoked. let crls = vec![kt.client_crl()]; @@ -2887,7 +2887,7 @@ fn client_complete_io_for_handshake_eof() { #[test] fn client_complete_io_for_write() { let provider = provider::default_provider(); - for kt in ALL_KEY_TYPES { + for kt in KeyType::all_for_provider(&provider) { let (mut client, mut server) = make_pair(*kt, &provider); do_handshake(&mut client, &mut server); @@ -2917,7 +2917,7 @@ fn client_complete_io_for_write() { #[test] fn buffered_client_complete_io_for_write() { let provider = provider::default_provider(); - for kt in ALL_KEY_TYPES { + for kt in KeyType::all_for_provider(&provider) { let (mut client, mut server) = make_pair(*kt, &provider); do_handshake(&mut client, &mut server); @@ -2947,7 +2947,7 @@ fn buffered_client_complete_io_for_write() { #[test] fn client_complete_io_for_read() { let provider = provider::default_provider(); - for kt in ALL_KEY_TYPES { + for kt in KeyType::all_for_provider(&provider) { let (mut client, mut server) = make_pair(*kt, &provider); do_handshake(&mut client, &mut server); @@ -2969,7 +2969,7 @@ fn client_complete_io_for_read() { #[test] fn server_complete_io_for_handshake() { let provider = provider::default_provider(); - for kt in ALL_KEY_TYPES { + for kt in KeyType::all_for_provider(&provider) { let (mut client, mut server) = make_pair(*kt, &provider); assert!(server.is_handshaking()); @@ -2997,7 +2997,7 @@ fn server_complete_io_for_handshake_eof() { #[test] fn server_complete_io_for_write() { let provider = provider::default_provider(); - for kt in ALL_KEY_TYPES { + for kt in KeyType::all_for_provider(&provider) { let (mut client, mut server) = make_pair(*kt, &provider); do_handshake(&mut client, &mut server); @@ -3026,7 +3026,7 @@ fn server_complete_io_for_write() { #[test] fn server_complete_io_for_write_eof() { let provider = provider::default_provider(); - for kt in ALL_KEY_TYPES { + for kt in KeyType::all_for_provider(&provider) { let (mut client, mut server) = make_pair(*kt, &provider); do_handshake(&mut client, &mut server); @@ -3083,7 +3083,7 @@ impl std::io::Read for EofWriter { #[test] fn server_complete_io_for_read() { let provider = provider::default_provider(); - for kt in ALL_KEY_TYPES { + for kt in KeyType::all_for_provider(&provider) { let (mut client, mut server) = make_pair(*kt, &provider); do_handshake(&mut client, &mut server); @@ -3122,7 +3122,7 @@ enum StreamKind { fn test_client_stream_write(stream_kind: StreamKind) { let provider = provider::default_provider(); - for kt in ALL_KEY_TYPES { + for kt in KeyType::all_for_provider(&provider) { let (mut client, mut server) = make_pair(*kt, &provider); let data = b"hello"; { @@ -3139,7 +3139,7 @@ fn test_client_stream_write(stream_kind: StreamKind) { fn test_server_stream_write(stream_kind: StreamKind) { let provider = provider::default_provider(); - for kt in ALL_KEY_TYPES { + for kt in KeyType::all_for_provider(&provider) { let (mut client, mut server) = make_pair(*kt, &provider); let data = b"hello"; { @@ -3208,7 +3208,7 @@ fn test_stream_read(read_kind: ReadKind, mut stream: impl BufRead, data: &[u8]) fn test_client_stream_read(stream_kind: StreamKind, read_kind: ReadKind) { let provider = provider::default_provider(); - for kt in ALL_KEY_TYPES { + for kt in KeyType::all_for_provider(&provider) { let (mut client, mut server) = make_pair(*kt, &provider); let data = b"world"; server.writer().write_all(data).unwrap(); @@ -3229,7 +3229,7 @@ fn test_client_stream_read(stream_kind: StreamKind, read_kind: ReadKind) { fn test_server_stream_read(stream_kind: StreamKind, read_kind: ReadKind) { let provider = provider::default_provider(); - for kt in ALL_KEY_TYPES { + for kt in KeyType::all_for_provider(&provider) { let (mut client, mut server) = make_pair(*kt, &provider); let data = b"world"; client.writer().write_all(data).unwrap(); @@ -3869,7 +3869,7 @@ fn do_exporter_test(client_config: ClientConfig, server_config: ServerConfig) { #[test] fn test_tls12_exporter() { let provider = provider::default_provider(); - for kt in ALL_KEY_TYPES { + for kt in KeyType::all_for_provider(&provider) { let client_config = make_client_config_with_versions(*kt, &[&rustls::version::TLS12], &provider); let server_config = make_server_config(*kt, &provider); @@ -3881,7 +3881,7 @@ fn test_tls12_exporter() { #[test] fn test_tls13_exporter() { let provider = provider::default_provider(); - for kt in ALL_KEY_TYPES { + for kt in KeyType::all_for_provider(&provider) { let client_config = make_client_config_with_versions(*kt, &[&rustls::version::TLS13], &provider); let server_config = make_server_config(*kt, &provider); @@ -4022,7 +4022,7 @@ fn test_ciphersuites() -> Vec<( #[test] fn negotiated_ciphersuite_default() { let provider = provider::default_provider(); - for kt in ALL_KEY_TYPES { + for kt in KeyType::all_for_provider(&provider) { do_suite_and_kx_test( make_client_config(*kt, &provider), make_server_config(*kt, &provider), @@ -5390,7 +5390,7 @@ mod test_quic { let server_params = &b"server params"[..]; let provider = provider::default_provider(); - for &kt in ALL_KEY_TYPES { + for &kt in KeyType::all_for_provider(&provider) { let client_config = make_client_config_with_versions(kt, &[&rustls::version::TLS13], &provider); let client_config = Arc::new(client_config); @@ -5754,7 +5754,7 @@ mod test_quic { #[test] fn test_quic_exporter() { let provider = provider::default_provider(); - for &kt in ALL_KEY_TYPES { + for &kt in KeyType::all_for_provider(&provider) { let client_config = make_client_config_with_versions(kt, &[&rustls::version::TLS13], &provider); let server_config = @@ -5810,7 +5810,7 @@ fn test_client_does_not_offer_sha1() { use rustls::internal::msgs::message::{MessagePayload, OutboundOpaqueMessage}; let provider = provider::default_provider(); - for kt in ALL_KEY_TYPES { + for kt in KeyType::all_for_provider(&provider) { for version in rustls::ALL_VERSIONS { let client_config = make_client_config_with_versions(*kt, &[version], &provider); let (mut client, _) = @@ -6313,7 +6313,7 @@ fn test_client_mtu_reduction() { } let provider = provider::default_provider(); - for kt in ALL_KEY_TYPES { + for kt in KeyType::all_for_provider(&provider) { let mut client_config = make_client_config(*kt, &provider); client_config.max_fragment_size = Some(64); let mut client = @@ -6408,7 +6408,7 @@ fn bad_client_max_fragment_sizes() { fn handshakes_complete_and_data_flows_with_gratuitious_max_fragment_sizes() { // general exercising of msgs::fragmenter and msgs::deframer let provider = provider::default_provider(); - for kt in ALL_KEY_TYPES { + for kt in KeyType::all_for_provider(&provider) { for version in rustls::ALL_VERSIONS { // no hidden significance to these numbers for frag_size in [37, 61, 101, 257] { @@ -6887,7 +6887,7 @@ fn test_no_warning_logging_during_successful_sessions() { CountingLogger::reset(); let provider = provider::default_provider(); - for kt in ALL_KEY_TYPES { + for kt in KeyType::all_for_provider(&provider) { for version in rustls::ALL_VERSIONS { let client_config = make_client_config_with_versions(*kt, &[version], &provider); let (mut client, mut server) = @@ -8387,8 +8387,9 @@ fn tls12_connection_fails_after_key_reaches_confidentiality_limit() { #[test] fn test_keys_match_for_all_signing_key_types() { - for kt in ALL_KEY_TYPES { - let key = provider::default_provider() + let provider = provider::default_provider(); + for kt in KeyType::all_for_provider(&provider) { + let key = provider .key_provider .load_private_key(kt.get_client_key()) .unwrap(); diff --git a/rustls/tests/client_cert_verifier.rs b/rustls/tests/client_cert_verifier.rs index 381c16d430..d25dcb3bf0 100644 --- a/rustls/tests/client_cert_verifier.rs +++ b/rustls/tests/client_cert_verifier.rs @@ -7,7 +7,7 @@ use super::*; mod common; use common::{ - ALL_KEY_TYPES, Arc, ErrorFromPeer, KeyType, MockClientVerifier, do_handshake_until_both_error, + Arc, ErrorFromPeer, KeyType, MockClientVerifier, do_handshake_until_both_error, do_handshake_until_error, make_client_config_with_versions, make_client_config_with_versions_with_auth, make_pair_for_arc_configs, server_config_builder, server_name, @@ -47,7 +47,7 @@ fn server_config_with_verifier( // Happy path, we resolve to a root, it is verified OK, should be able to connect fn client_verifier_works() { let provider = provider::default_provider(); - for kt in ALL_KEY_TYPES.iter() { + for kt in KeyType::all_for_provider(&provider).iter() { let client_verifier = MockClientVerifier::new(ver_ok, *kt, &provider); let server_config = server_config_with_verifier(*kt, client_verifier); let server_config = Arc::new(server_config); @@ -67,7 +67,7 @@ fn client_verifier_works() { #[test] fn client_verifier_no_schemes() { let provider = provider::default_provider(); - for kt in ALL_KEY_TYPES.iter() { + for kt in KeyType::all_for_provider(&provider).iter() { let mut client_verifier = MockClientVerifier::new(ver_ok, *kt, &provider); client_verifier.offered_schemes = Some(vec![]); let server_config = server_config_with_verifier(*kt, client_verifier); @@ -93,8 +93,9 @@ fn client_verifier_no_schemes() { #[test] fn client_verifier_no_auth_yes_root() { let provider = provider::default_provider(); - for kt in ALL_KEY_TYPES.iter() { + for kt in KeyType::all_for_provider(&provider).iter() { let client_verifier = MockClientVerifier::new(ver_unreachable, *kt, &provider); + let server_config = server_config_with_verifier(*kt, client_verifier); let server_config = Arc::new(server_config); @@ -121,7 +122,7 @@ fn client_verifier_no_auth_yes_root() { // Triple checks we propagate the rustls::Error through fn client_verifier_fails_properly() { let provider = provider::default_provider(); - for kt in ALL_KEY_TYPES.iter() { + for kt in KeyType::all_for_provider(&provider).iter() { let client_verifier = MockClientVerifier::new(ver_err, *kt, &provider); let server_config = server_config_with_verifier(*kt, client_verifier); let server_config = Arc::new(server_config); diff --git a/rustls/tests/server_cert_verifier.rs b/rustls/tests/server_cert_verifier.rs index b0d2508b65..402ffe551a 100644 --- a/rustls/tests/server_cert_verifier.rs +++ b/rustls/tests/server_cert_verifier.rs @@ -7,7 +7,7 @@ use super::*; mod common; use common::{ - ALL_KEY_TYPES, Altered, Arc, ErrorFromPeer, KeyType, MockServerVerifier, client_config_builder, + Altered, Arc, ErrorFromPeer, KeyType, MockServerVerifier, client_config_builder, client_config_builder_with_versions, do_handshake, do_handshake_until_both_error, do_handshake_until_error, make_client_config_with_versions, make_pair_for_arc_configs, make_server_config, server_config_builder, transfer_altered, @@ -33,7 +33,7 @@ use x509_parser::x509::X509Name; #[test] fn client_can_override_certificate_verification() { let provider = provider::default_provider(); - for kt in ALL_KEY_TYPES.iter() { + for kt in KeyType::all_for_provider(&provider).iter() { let verifier = Arc::new(MockServerVerifier::accepts_anything()); let server_config = Arc::new(make_server_config(*kt, &provider)); @@ -54,7 +54,7 @@ fn client_can_override_certificate_verification() { #[test] fn client_can_override_certificate_verification_and_reject_certificate() { let provider = provider::default_provider(); - for kt in ALL_KEY_TYPES.iter() { + for kt in KeyType::all_for_provider(&provider).iter() { let verifier = Arc::new(MockServerVerifier::rejects_certificate( Error::InvalidMessage(InvalidMessage::HandshakePayloadTooLarge), )); @@ -87,7 +87,7 @@ fn client_can_override_certificate_verification_and_reject_certificate() { #[test] fn client_can_override_certificate_verification_and_reject_tls12_signatures() { let provider = provider::default_provider(); - for kt in ALL_KEY_TYPES.iter() { + for kt in KeyType::all_for_provider(&provider).iter() { let mut client_config = make_client_config_with_versions(*kt, &[&rustls::version::TLS12], &provider); let verifier = Arc::new(MockServerVerifier::rejects_tls12_signatures( @@ -118,7 +118,7 @@ fn client_can_override_certificate_verification_and_reject_tls12_signatures() { #[test] fn client_can_override_certificate_verification_and_reject_tls13_signatures() { let provider = provider::default_provider(); - for kt in ALL_KEY_TYPES.iter() { + for kt in KeyType::all_for_provider(&provider).iter() { let mut client_config = make_client_config_with_versions( *kt, &[&rustls::version::TLS13], @@ -152,7 +152,7 @@ fn client_can_override_certificate_verification_and_reject_tls13_signatures() { #[test] fn client_can_override_certificate_verification_and_offer_no_signature_schemes() { let provider = provider::default_provider(); - for kt in ALL_KEY_TYPES.iter() { + for kt in KeyType::all_for_provider(&provider).iter() { let verifier = Arc::new(MockServerVerifier::offers_no_signature_schemes()); let server_config = Arc::new(make_server_config(*kt, &provider)); pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy