Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 81 additions & 24 deletions stdlib/src/ssl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,6 @@ mod _ssl {
SSL_ERROR_WANT_CONNECT,
SSL_ERROR_WANT_READ,
SSL_ERROR_WANT_WRITE,
// X509_V_FLAG_CRL_CHECK as VERIFY_CRL_CHECK_LEAF,
// sys::X509_V_FLAG_CRL_CHECK|sys::X509_V_FLAG_CRL_CHECK_ALL as VERIFY_CRL_CHECK_CHAIN
// X509_V_FLAG_X509_STRICT as VERIFY_X509_STRICT,
SSL_ERROR_ZERO_RETURN,
SSL_OP_CIPHER_SERVER_PREFERENCE as OP_CIPHER_SERVER_PREFERENCE,
SSL_OP_ENABLE_MIDDLEBOX_COMPAT as OP_ENABLE_MIDDLEBOX_COMPAT,
Expand All @@ -114,6 +111,11 @@ mod _ssl {
X509_V_FLAG_X509_STRICT as VERIFY_X509_STRICT,
};

// CRL verification constants
#[pyattr]
const VERIFY_CRL_CHECK_CHAIN: libc::c_ulong =
sys::X509_V_FLAG_CRL_CHECK | sys::X509_V_FLAG_CRL_CHECK_ALL;

// taken from CPython, should probably be kept up to date with their version if it ever changes
#[pyattr]
const _DEFAULT_CIPHERS: &str =
Expand Down Expand Up @@ -631,6 +633,12 @@ mod _ssl {
Ok(())
}

#[cfg(ossl110)]
#[pygetset]
fn security_level(&self) -> i32 {
unsafe { SSL_CTX_get_security_level(self.ctx().as_ptr()) }
}

#[pymethod]
fn set_ciphers(&self, cipherlist: PyStrRef, vm: &VirtualMachine) -> PyResult<()> {
let ciphers = cipherlist.as_str();
Expand Down Expand Up @@ -677,19 +685,29 @@ mod _ssl {
}

#[pymethod]
fn set_ecdh_curve(&self, name: PyStrRef, vm: &VirtualMachine) -> PyResult<()> {
fn set_ecdh_curve(
&self,
name: Either<PyStrRef, ArgBytesLike>,
vm: &VirtualMachine,
) -> PyResult<()> {
use openssl::ec::{EcGroup, EcKey};

let curve_name = name.as_str();
if curve_name.contains('\0') {
return Err(exceptions::cstring_error(vm));
}
// Convert name to CString, supporting both str and bytes
let name_cstr = match name {
Either::A(s) => {
if s.as_str().contains('\0') {
return Err(exceptions::cstring_error(vm));
}
s.to_cstring(vm)?
}
Either::B(b) => std::ffi::CString::new(b.borrow_buf().to_vec())
.map_err(|_| exceptions::cstring_error(vm))?,
};

// Find the NID for the curve name using OBJ_sn2nid
let name_cstr = name.to_cstring(vm)?;
let nid_raw = unsafe { sys::OBJ_sn2nid(name_cstr.as_ptr()) };
if nid_raw == 0 {
return Err(vm.new_value_error(format!("unknown curve name: {}", curve_name)));
return Err(vm.new_value_error("unknown curve name"));
}
let nid = Nid::from_raw(nid_raw);

Expand Down Expand Up @@ -794,6 +812,47 @@ mod _ssl {
self.check_hostname.store(ch);
}

// PY_PROTO_MINIMUM_SUPPORTED = -2, PY_PROTO_MAXIMUM_SUPPORTED = -1
#[pygetset]
fn minimum_version(&self) -> i32 {
let ctx = self.ctx();
let version = unsafe { sys::SSL_CTX_get_min_proto_version(ctx.as_ptr()) };
if version == 0 {
-2 // PY_PROTO_MINIMUM_SUPPORTED
} else {
version
}
}
#[pygetset(setter)]
fn set_minimum_version(&self, value: i32, vm: &VirtualMachine) -> PyResult<()> {
let ctx = self.builder();
let result = unsafe { sys::SSL_CTX_set_min_proto_version(ctx.as_ptr(), value) };
if result == 0 {
return Err(vm.new_value_error("invalid protocol version"));
}
Ok(())
}

#[pygetset]
fn maximum_version(&self) -> i32 {
let ctx = self.ctx();
let version = unsafe { sys::SSL_CTX_get_max_proto_version(ctx.as_ptr()) };
if version == 0 {
-1 // PY_PROTO_MAXIMUM_SUPPORTED
} else {
version
}
}
#[pygetset(setter)]
fn set_maximum_version(&self, value: i32, vm: &VirtualMachine) -> PyResult<()> {
let ctx = self.builder();
let result = unsafe { sys::SSL_CTX_set_max_proto_version(ctx.as_ptr(), value) };
if result == 0 {
return Err(vm.new_value_error("invalid protocol version"));
}
Ok(())
}

#[pymethod]
fn set_default_verify_paths(&self, vm: &VirtualMachine) -> PyResult<()> {
cfg_if::cfg_if! {
Expand Down Expand Up @@ -852,12 +911,6 @@ mod _ssl {
if let (None, None, None) = (&args.cafile, &args.capath, &args.cadata) {
return Err(vm.new_type_error("cafile, capath and cadata cannot be all omitted"));
}
if let Some(cafile) = &args.cafile {
cafile.ensure_no_nul(vm)?
}
if let Some(capath) = &args.capath {
capath.ensure_no_nul(vm)?
}

#[cold]
fn invalid_cadata(vm: &VirtualMachine) -> PyBaseExceptionRef {
Expand Down Expand Up @@ -887,11 +940,10 @@ mod _ssl {
}

if args.cafile.is_some() || args.capath.is_some() {
ctx.load_verify_locations(
args.cafile.as_ref().map(|s| s.as_str().as_ref()),
args.capath.as_ref().map(|s| s.as_str().as_ref()),
)
.map_err(|e| convert_openssl_error(vm, e))?;
let cafile_path = args.cafile.map(|p| p.to_path_buf(vm)).transpose()?;
let capath_path = args.capath.map(|p| p.to_path_buf(vm)).transpose()?;
ctx.load_verify_locations(cafile_path.as_deref(), capath_path.as_deref())
.map_err(|e| convert_openssl_error(vm, e))?;
}

Ok(())
Expand Down Expand Up @@ -1064,9 +1116,9 @@ mod _ssl {
#[derive(FromArgs)]
struct LoadVerifyLocationsArgs {
#[pyarg(any, default)]
cafile: Option<PyStrRef>,
cafile: Option<FsPath>,
#[pyarg(any, default)]
capath: Option<PyStrRef>,
capath: Option<FsPath>,
#[pyarg(any, default)]
cadata: Option<Either<PyStrRef, ArgBytesLike>>,
}
Expand Down Expand Up @@ -1794,6 +1846,11 @@ mod _ssl {
fn SSL_verify_client_post_handshake(ssl: *const sys::SSL) -> libc::c_int;
}

#[cfg(ossl110)]
unsafe extern "C" {
fn SSL_CTX_get_security_level(ctx: *const sys::SSL_CTX) -> libc::c_int;
}

// OpenSSL BIO helper functions
// These are typically macros in OpenSSL, implemented via BIO_ctrl
const BIO_CTRL_PENDING: libc::c_int = 10;
Expand Down Expand Up @@ -2082,7 +2139,7 @@ mod _ssl {
let lib = sys::ERR_GET_LIB(err_code);
if lib == ERR_LIB_SSL && reason == SSL_R_UNEXPECTED_EOF_WHILE_READING {
return vm.new_exception(
vm.class("_ssl", "SSLEOFError"),
PySslEOFError::class(&vm.ctx).to_owned(),
vec![
vm.ctx.new_int(SSL_ERROR_EOF).into(),
vm.ctx
Expand Down
16 changes: 7 additions & 9 deletions stdlib/src/ssl/cert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,17 +164,15 @@ pub(crate) mod ssl_cert {
// IPv4
format!("{}.{}.{}.{}", ip[0], ip[1], ip[2], ip[3])
} else if ip.len() == 16 {
// IPv6 - format like: "X:X:X:X:X:X:X:X" (not compressed)
// IPv6 - format with all zeros visible (not compressed)
let ip_addr = std::net::Ipv6Addr::from([
ip[0], ip[1], ip[2], ip[3], ip[4], ip[5], ip[6], ip[7], ip[8],
ip[9], ip[10], ip[11], ip[12], ip[13], ip[14], ip[15],
]);
let s = ip_addr.segments();
format!(
"{:X}:{:X}:{:X}:{:X}:{:X}:{:X}:{:X}:{:X}",
(ip[0] as u16) << 8 | ip[1] as u16,
(ip[2] as u16) << 8 | ip[3] as u16,
(ip[4] as u16) << 8 | ip[5] as u16,
(ip[6] as u16) << 8 | ip[7] as u16,
(ip[8] as u16) << 8 | ip[9] as u16,
(ip[10] as u16) << 8 | ip[11] as u16,
(ip[12] as u16) << 8 | ip[13] as u16,
(ip[14] as u16) << 8 | ip[15] as u16
s[0], s[1], s[2], s[3], s[4], s[5], s[6], s[7]
)
} else {
// Fallback for unexpected length
Expand Down
Loading