Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
243 changes: 117 additions & 126 deletions crates/stdlib/src/ssl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2598,24 +2598,26 @@ mod _ssl {
fn complete_handshake(&self, vm: &VirtualMachine) -> PyResult<()> {
*self.handshake_done.lock() = true;

// Check if session was resumed before creating session object
let conn_guard = self.connection.lock();
if let Some(ref conn) = *conn_guard {
let was_resumed = conn.is_session_resumed();
*self.session_was_reused.lock() = was_resumed;
// Check if session was resumed - get value and release lock immediately
let was_resumed = self
.connection
.lock()
.as_ref()
.map(|conn| conn.is_session_resumed())
.unwrap_or(false);

// Update context session statistics if server-side
if self.server_side {
let context = self.context.read();
// Increment accept count for every successful server handshake
context.accept_count.fetch_add(1, Ordering::SeqCst);
// Increment hits count if session was resumed
if was_resumed {
context.session_hits.fetch_add(1, Ordering::SeqCst);
}
*self.session_was_reused.lock() = was_resumed;

// Update context session statistics if server-side
if self.server_side {
let context = self.context.read();
// Increment accept count for every successful server handshake
context.accept_count.fetch_add(1, Ordering::SeqCst);
// Increment hits count if session was resumed
if was_resumed {
context.session_hits.fetch_add(1, Ordering::SeqCst);
}
}
drop(conn_guard);

// Track CA certificate used during handshake (client-side only)
// This simulates lazy loading behavior for capath certificates
Expand Down Expand Up @@ -3209,62 +3211,46 @@ mod _ssl {
}

// Perform the actual handshake by exchanging data with the socket/BIO
match conn_guard.as_mut() {
Some(TlsConnection::Client(_conn)) => {
// CLIENT is simple - no SNI callback handling needed
ssl_do_handshake(conn_guard.as_mut().unwrap(), self, vm)
.map_err(|e| e.into_py_err(vm))?;

drop(conn_guard);
self.complete_handshake(vm)?;
Ok(())
}
Some(TlsConnection::Server(_conn)) => {
// Use OpenSSL-compatible handshake for server
// Handle SNI callback restart
match ssl_do_handshake(conn_guard.as_mut().unwrap(), self, vm) {
Ok(()) => {
// Handshake completed successfully
drop(conn_guard);
self.complete_handshake(vm)?;
Ok(())
}
Err(SslError::SniCallbackRestart) => {
// SNI detected - need to call callback and recreate connection

// CRITICAL: Drop connection lock BEFORE calling Python callback to avoid deadlock
//
// Deadlock scenario if we keep the lock:
// 1. This thread holds self.connection.lock()
// 2. Python callback invokes other SSL methods (e.g., getpeercert(), cipher())
// 3. Those methods try to acquire self.connection.lock() again
// 4. PyMutex (parking_lot::Mutex) is not reentrant -> DEADLOCK
//
// Trade-off: By dropping the lock, we lose the ability to send TLS alerts
// because Rustls doesn't provide a send_fatal_alert() API. See detailed
// explanation in invoke_sni_callback() where we set _reason attribute.
drop(conn_guard);

// Get the SNI name that was extracted (may be None if client didn't send SNI)
let sni_name = self.get_extracted_sni_name();

// Now safe to call Python callback (no locks held)
self.invoke_sni_callback(sni_name.as_deref(), vm)?;

// Clear connection to trigger recreation
*self.connection.lock() = None;

// Recursively call do_handshake to recreate with new context
self.do_handshake(vm)
}
Err(e) => {
// Other errors - convert to Python exception
drop(conn_guard);
Err(e.into_py_err(vm))
}
let conn = conn_guard.as_mut().expect("unreachable");
let is_client = matches!(conn, TlsConnection::Client(_));
let handshake_result = ssl_do_handshake(conn, self, vm);
drop(conn_guard);

if is_client {
// CLIENT is simple - no SNI callback handling needed
handshake_result.map_err(|e| e.into_py_err(vm))?;
self.complete_handshake(vm)?;
Ok(())
} else {
// Use OpenSSL-compatible handshake for server
// Handle SNI callback restart
match handshake_result {
Ok(()) => {
// Handshake completed successfully
self.complete_handshake(vm)?;
Ok(())
}
Err(SslError::SniCallbackRestart) => {
// SNI detected - need to call callback and recreate connection

// Get the SNI name that was extracted (may be None if client didn't send SNI)
let sni_name = self.get_extracted_sni_name();

// Now safe to call Python callback (no locks held)
self.invoke_sni_callback(sni_name.as_deref(), vm)?;

// Clear connection to trigger recreation
*self.connection.lock() = None;

// Recursively call do_handshake to recreate with new context
self.do_handshake(vm)
}
Err(e) => {
// Other errors - convert to Python exception
Err(e.into_py_err(vm))
}
}
None => unreachable!(),
}
}

Expand Down Expand Up @@ -3323,9 +3309,6 @@ mod _ssl {
));
}

// Check for deferred certificate verification errors (TLS 1.3)
self.check_deferred_cert_error(vm)?;

// Helper function to handle return value based on buffer presence
let return_data = |data: Vec<u8>,
buffer_arg: &OptionalArg<ArgMemoryBuffer>,
Expand All @@ -3350,17 +3333,21 @@ mod _ssl {
}
};

let mut conn_guard = self.connection.lock();
let conn = conn_guard
.as_mut()
.ok_or_else(|| vm.new_value_error("Connection not established"))?;

// Use compat layer for unified read logic with proper EOF handling
// This matches CPython's SSL_read_ex() approach
let mut buf = vec![0u8; len];

match crate::ssl::compat::ssl_read(conn, &mut buf, self, vm) {
let read_result = {
let mut conn_guard = self.connection.lock();
let conn = conn_guard
.as_mut()
.ok_or_else(|| vm.new_value_error("Connection not established"))?;
crate::ssl::compat::ssl_read(conn, &mut buf, self, vm)
};
match read_result {
Ok(n) => {
// Check for deferred certificate verification errors (TLS 1.3)
// Must be checked AFTER ssl_read, as the error is set during I/O
self.check_deferred_cert_error(vm)?;
buf.truncate(n);
return_data(buf, &buffer, vm)
}
Expand Down Expand Up @@ -3445,62 +3432,62 @@ mod _ssl {
));
}

// Check for deferred certificate verification errors (TLS 1.3)
self.check_deferred_cert_error(vm)?;

let mut conn_guard = self.connection.lock();
let conn = conn_guard
.as_mut()
.ok_or_else(|| vm.new_value_error("Connection not established"))?;
{
let mut conn_guard = self.connection.lock();
let conn = conn_guard
.as_mut()
.ok_or_else(|| vm.new_value_error("Connection not established"))?;

let is_bio = self.is_bio_mode();
let data: &[u8] = data_bytes.as_ref();
let is_bio = self.is_bio_mode();
let data: &[u8] = data_bytes.as_ref();

// Write data in chunks to avoid filling the internal TLS buffer
// rustls has a limited internal buffer, so we need to flush periodically
const CHUNK_SIZE: usize = 16384; // 16KB chunks (typical TLS record size)
let mut written = 0;
// Write data in chunks to avoid filling the internal TLS buffer
// rustls has a limited internal buffer, so we need to flush periodically
const CHUNK_SIZE: usize = 16384; // 16KB chunks (typical TLS record size)
let mut written = 0;

while written < data.len() {
let chunk_end = std::cmp::min(written + CHUNK_SIZE, data.len());
let chunk = &data[written..chunk_end];
while written < data.len() {
let chunk_end = std::cmp::min(written + CHUNK_SIZE, data.len());
let chunk = &data[written..chunk_end];

// Write chunk to TLS layer
{
let mut writer = conn.writer();
use std::io::Write;
writer
.write_all(chunk)
.map_err(|e| vm.new_os_error(format!("Write failed: {e}")))?;
}
// Write chunk to TLS layer
{
let mut writer = conn.writer();
use std::io::Write;
writer
.write_all(chunk)
.map_err(|e| vm.new_os_error(format!("Write failed: {e}")))?;
}

written = chunk_end;
written = chunk_end;

// Flush TLS data to socket after each chunk
if conn.wants_write() {
if is_bio {
self.write_pending_tls(conn, vm)?;
} else {
// Socket mode: flush all pending TLS data
while conn.wants_write() {
let mut buf = Vec::new();
conn.write_tls(&mut buf)
.map_err(|e| vm.new_os_error(format!("TLS write failed: {e}")))?;

if !buf.is_empty() {
let timed_out =
self.sock_wait_for_io_impl(SelectKind::Write, vm)?;
if timed_out {
return Err(vm.new_os_error("Write operation timed out"));
}
// Flush TLS data to socket after each chunk
if conn.wants_write() {
if is_bio {
self.write_pending_tls(conn, vm)?;
} else {
// Socket mode: flush all pending TLS data
while conn.wants_write() {
let mut buf = Vec::new();
conn.write_tls(&mut buf).map_err(|e| {
vm.new_os_error(format!("TLS write failed: {e}"))
})?;

if !buf.is_empty() {
let timed_out =
self.sock_wait_for_io_impl(SelectKind::Write, vm)?;
if timed_out {
return Err(vm.new_os_error("Write operation timed out"));
}

match self.sock_send(buf, vm) {
Ok(_) => {}
Err(e) => {
if is_blocking_io_error(&e, vm) {
return Err(create_ssl_want_write_error(vm));
match self.sock_send(buf, vm) {
Ok(_) => {}
Err(e) => {
if is_blocking_io_error(&e, vm) {
return Err(create_ssl_want_write_error(vm));
}
return Err(e);
}
return Err(e);
}
}
}
Expand All @@ -3509,6 +3496,10 @@ mod _ssl {
}
}

// Check for deferred certificate verification errors (TLS 1.3)
// Must be checked AFTER write completes, as the error may be set during I/O
self.check_deferred_cert_error(vm)?;

Ok(data_len)
}

Expand Down
13 changes: 7 additions & 6 deletions crates/stdlib/src/ssl/cert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1067,15 +1067,16 @@ impl ClientCertVerifier for DeferredClientCertVerifier {
.inner
.verify_client_cert(end_entity, intermediates, now);

// If verification failed, store the error for later
if result.is_err() {
let error_msg = "TLS handshake failed: received fatal alert: UnknownCA".to_string();
// If verification failed, store the error for the server's Python code
// AND return the error so rustls sends the appropriate TLS alert
if let Err(ref e) = result {
let error_msg = format!("certificate verify failed: {e}");
*self.deferred_error.write() = Some(error_msg);
// Return the error to rustls so it sends the alert to the client
return result;
}

// Always return success to allow handshake to complete
// The error will be raised during the first I/O operation
Ok(ClientCertVerified::assertion())
result
}

fn verify_tls12_signature(
Expand Down
Loading
Loading