diff --git a/derive-impl/src/pyclass.rs b/derive-impl/src/pyclass.rs index 6c589896e6..99cddc8462 100644 --- a/derive-impl/src/pyclass.rs +++ b/derive-impl/src/pyclass.rs @@ -551,7 +551,13 @@ pub(crate) fn impl_pyexception(attr: PunctuatedNestedMeta, item: Item) -> Result } } } else { - quote! {} + quote! { + impl ::rustpython_vm::PyPayload for #ident { + fn class(_ctx: &::rustpython_vm::vm::Context) -> &'static ::rustpython_vm::Py<::rustpython_vm::builtins::PyType> { + ::static_type() + } + } + } }; let impl_pyclass = if class_meta.has_impl()? { quote! { @@ -613,7 +619,7 @@ pub(crate) fn impl_pyexception_impl(attr: PunctuatedNestedMeta, item: Item) -> R } else { quote! { #[pyslot] - pub(crate) fn slot_new( + pub fn slot_new( cls: ::rustpython_vm::builtins::PyTypeRef, args: ::rustpython_vm::function::FuncArgs, vm: &::rustpython_vm::VirtualMachine, @@ -634,7 +640,7 @@ pub(crate) fn impl_pyexception_impl(attr: PunctuatedNestedMeta, item: Item) -> R quote! { #[pyslot] #[pymethod(name="__init__")] - pub(crate) fn slot_init( + pub fn slot_init( zelf: ::rustpython_vm::PyObjectRef, args: ::rustpython_vm::function::FuncArgs, vm: &::rustpython_vm::VirtualMachine, diff --git a/derive-impl/src/pymodule.rs b/derive-impl/src/pymodule.rs index 9e052cd1ff..a8588a762e 100644 --- a/derive-impl/src/pymodule.rs +++ b/derive-impl/src/pymodule.rs @@ -16,6 +16,7 @@ enum AttrName { Function, Attr, Class, + Exception, } impl std::fmt::Display for AttrName { @@ -24,6 +25,7 @@ impl std::fmt::Display for AttrName { Self::Function => "pyfunction", Self::Attr => "pyattr", Self::Class => "pyclass", + Self::Exception => "pyexception", }; s.fmt(f) } @@ -37,6 +39,7 @@ impl FromStr for AttrName { "pyfunction" => Self::Function, "pyattr" => Self::Attr, "pyclass" => Self::Class, + "pyexception" => Self::Exception, s => { return Err(s.to_owned()); } @@ -232,7 +235,8 @@ fn module_item_new( inner: ContentItemInner { index, attr_name }, py_attrs, }), - AttrName::Class => Box::new(ClassItem { + // pyexception is treated like pyclass - both define types + AttrName::Class | AttrName::Exception => Box::new(ClassItem { inner: ContentItemInner { index, attr_name }, py_attrs, }), @@ -302,13 +306,13 @@ where result.push(item_new(i, attr_name, Vec::new())); } else { match attr_name { - AttrName::Class | AttrName::Function => { + AttrName::Class | AttrName::Function | AttrName::Exception => { result.push(item_new(i, attr_name, py_attrs.clone())); } _ => { bail_span!( attr, - "#[pyclass] or #[pyfunction] only can follow #[pyattr]", + "#[pyclass], #[pyfunction], or #[pyexception] can follow #[pyattr]", ) } } diff --git a/derive-impl/src/util.rs b/derive-impl/src/util.rs index 47a0606f05..f7f0d28fb9 100644 --- a/derive-impl/src/util.rs +++ b/derive-impl/src/util.rs @@ -446,7 +446,8 @@ impl ClassItemMeta { pub(crate) struct ExceptionItemMeta(ClassItemMeta); impl ItemMeta for ExceptionItemMeta { - const ALLOWED_NAMES: &'static [&'static str] = &["name", "base", "unhashable", "ctx", "impl"]; + const ALLOWED_NAMES: &'static [&'static str] = + &["module", "name", "base", "unhashable", "ctx", "impl"]; fn from_inner(inner: ItemMetaInner) -> Self { Self(ClassItemMeta(inner)) diff --git a/stdlib/src/ssl.rs b/stdlib/src/ssl.rs index bf77d6b690..c642d6e087 100644 --- a/stdlib/src/ssl.rs +++ b/stdlib/src/ssl.rs @@ -38,8 +38,10 @@ mod _ssl { }, socket::{self, PySocket}, vm::{ - Py, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, - builtins::{PyBaseExceptionRef, PyBytesRef, PyListRef, PyStrRef, PyTypeRef, PyWeak}, + AsObject, Py, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, + builtins::{ + PyBaseExceptionRef, PyBytesRef, PyListRef, PyOSError, PyStrRef, PyTypeRef, PyWeak, + }, class_or_notimplemented, convert::{ToPyException, ToPyObject}, exceptions, @@ -198,63 +200,84 @@ mod _ssl { parse_version_info(openssl_api_version) } + // SSL Exception Types + /// An error occurred in the SSL implementation. - #[pyattr(name = "SSLError", once)] - fn ssl_error(vm: &VirtualMachine) -> PyTypeRef { - vm.ctx.new_exception_type( - "ssl", - "SSLError", - Some(vec![vm.ctx.exceptions.os_error.to_owned()]), - ) + #[pyattr] + #[pyexception(name = "SSLError", base = "PyOSError")] + #[derive(Debug)] + pub struct PySslError {} + + #[pyexception] + impl PySslError { + // Returns strerror attribute if available, otherwise str(args) + #[pymethod] + fn __str__(exc: PyBaseExceptionRef, vm: &VirtualMachine) -> PyResult { + // Try to get strerror attribute first (OSError compatibility) + if let Ok(strerror) = exc.as_object().get_attr("strerror", vm) + && !vm.is_none(&strerror) + { + return strerror.str(vm); + } + + // Otherwise return str(args) + exc.args().as_object().str(vm) + } } /// A certificate could not be verified. - #[pyattr(name = "SSLCertVerificationError", once)] - fn ssl_cert_verification_error(vm: &VirtualMachine) -> PyTypeRef { - vm.ctx.new_exception_type( - "ssl", - "SSLCertVerificationError", - Some(vec![ - ssl_error(vm), - vm.ctx.exceptions.value_error.to_owned(), - ]), - ) - } + #[pyattr] + #[pyexception(name = "SSLCertVerificationError", base = "PySslError")] + #[derive(Debug)] + pub struct PySslCertVerificationError {} + + #[pyexception] + impl PySslCertVerificationError {} /// SSL/TLS session closed cleanly. - #[pyattr(name = "SSLZeroReturnError", once)] - fn ssl_zero_return_error(vm: &VirtualMachine) -> PyTypeRef { - vm.ctx - .new_exception_type("ssl", "SSLZeroReturnError", Some(vec![ssl_error(vm)])) - } + #[pyattr] + #[pyexception(name = "SSLZeroReturnError", base = "PySslError")] + #[derive(Debug)] + pub struct PySslZeroReturnError {} - /// Non-blocking SSL socket needs to read more data before the requested operation can be completed. - #[pyattr(name = "SSLWantReadError", once)] - fn ssl_want_read_error(vm: &VirtualMachine) -> PyTypeRef { - vm.ctx - .new_exception_type("ssl", "SSLWantReadError", Some(vec![ssl_error(vm)])) - } + #[pyexception] + impl PySslZeroReturnError {} - /// Non-blocking SSL socket needs to write more data before the requested operation can be completed. - #[pyattr(name = "SSLWantWriteError", once)] - fn ssl_want_write_error(vm: &VirtualMachine) -> PyTypeRef { - vm.ctx - .new_exception_type("ssl", "SSLWantWriteError", Some(vec![ssl_error(vm)])) - } + /// Non-blocking SSL socket needs to read more data. + #[pyattr] + #[pyexception(name = "SSLWantReadError", base = "PySslError")] + #[derive(Debug)] + pub struct PySslWantReadError {} + + #[pyexception] + impl PySslWantReadError {} + + /// Non-blocking SSL socket needs to write more data. + #[pyattr] + #[pyexception(name = "SSLWantWriteError", base = "PySslError")] + #[derive(Debug)] + pub struct PySslWantWriteError {} + + #[pyexception] + impl PySslWantWriteError {} /// System error when attempting SSL operation. - #[pyattr(name = "SSLSyscallError", once)] - fn ssl_syscall_error(vm: &VirtualMachine) -> PyTypeRef { - vm.ctx - .new_exception_type("ssl", "SSLSyscallError", Some(vec![ssl_error(vm)])) - } + #[pyattr] + #[pyexception(name = "SSLSyscallError", base = "PySslError")] + #[derive(Debug)] + pub struct PySslSyscallError {} + + #[pyexception] + impl PySslSyscallError {} /// SSL/TLS connection terminated abruptly. - #[pyattr(name = "SSLEOFError", once)] - fn ssl_eof_error(vm: &VirtualMachine) -> PyTypeRef { - vm.ctx - .new_exception_type("ssl", "SSLEOFError", Some(vec![ssl_error(vm)])) - } + #[pyattr] + #[pyexception(name = "SSLEOFError", base = "PySslError")] + #[derive(Debug)] + pub struct PySslEOFError {} + + #[pyexception] + impl PySslEOFError {} type OpensslVersionInfo = (u8, u8, u8, u8, u8); const fn parse_version_info(mut n: i64) -> OpensslVersionInfo { @@ -617,7 +640,10 @@ mod _ssl { return Err(exceptions::cstring_error(vm)); } self.builder().set_cipher_list(ciphers).map_err(|_| { - vm.new_exception_msg(ssl_error(vm), "No cipher can be selected.".to_owned()) + vm.new_exception_msg( + PySslError::class(&vm.ctx).to_owned(), + "No cipher can be selected.".to_owned(), + ) }) } @@ -744,13 +770,13 @@ mod _ssl { if clear != 0 && sys::X509_VERIFY_PARAM_clear_flags(param, clear) == 0 { return Err(vm.new_exception_msg( - ssl_error(vm), + PySslError::class(&vm.ctx).to_owned(), "Failed to clear verify flags".to_owned(), )); } if set != 0 && sys::X509_VERIFY_PARAM_set_flags(param, set) == 0 { return Err(vm.new_exception_msg( - ssl_error(vm), + PySslError::class(&vm.ctx).to_owned(), "Failed to set verify flags".to_owned(), )); } @@ -934,13 +960,13 @@ mod _ssl { // validate socket type and context protocol if !args.server_side && zelf.protocol == SslVersion::TlsServer { return Err(vm.new_exception_msg( - ssl_error(vm), + PySslError::class(&vm.ctx).to_owned(), "Cannot create a client socket with a PROTOCOL_TLS_SERVER context".to_owned(), )); } if args.server_side && zelf.protocol == SslVersion::TlsClient { return Err(vm.new_exception_msg( - ssl_error(vm), + PySslError::class(&vm.ctx).to_owned(), "Cannot create a server socket with a PROTOCOL_TLS_CLIENT context".to_owned(), )); } @@ -1124,7 +1150,7 @@ mod _ssl { fn socket_closed_error(vm: &VirtualMachine) -> PyBaseExceptionRef { vm.new_exception_msg( - ssl_error(vm), + PySslError::class(&vm.ctx).to_owned(), "Underlying socket has been closed.".to_owned(), ) } @@ -1390,7 +1416,7 @@ mod _ssl { let result = unsafe { SSL_verify_client_post_handshake(stream.ssl().as_ptr()) }; if result == 0 { Err(vm.new_exception_msg( - ssl_error(vm), + PySslError::class(&vm.ctx).to_owned(), "Post-handshake authentication failed".to_owned(), )) } else { @@ -1422,7 +1448,7 @@ mod _ssl { // Return the underlying socket } else { return Err(vm.new_exception_msg( - ssl_error(vm), + PySslError::class(&vm.ctx).to_owned(), format!("SSL shutdown failed: error code {}", err), )); } @@ -1854,7 +1880,7 @@ mod _ssl { fn write(&self, data: ArgBytesLike, vm: &VirtualMachine) -> PyResult { if self.eof_written.load() { return Err(vm.new_exception_msg( - ssl_error(vm), + PySslError::class(&vm.ctx).to_owned(), "cannot write() after write_eof()".to_owned(), )); } @@ -1953,7 +1979,7 @@ mod _ssl { #[track_caller] fn convert_openssl_error(vm: &VirtualMachine, err: ErrorStack) -> PyBaseExceptionRef { - let cls = ssl_error(vm); + let cls = PySslError::class(&vm.ctx).to_owned(); match err.errors().last() { Some(e) => { let caller = std::panic::Location::caller(); @@ -2012,25 +2038,31 @@ mod _ssl { let e = e.borrow(); let (cls, msg) = match e.code() { ssl::ErrorCode::WANT_READ => ( - vm.class("_ssl", "SSLWantReadError"), + PySslWantReadError::class(&vm.ctx).to_owned(), "The operation did not complete (read)", ), ssl::ErrorCode::WANT_WRITE => ( - vm.class("_ssl", "SSLWantWriteError"), + PySslWantWriteError::class(&vm.ctx).to_owned(), "The operation did not complete (write)", ), ssl::ErrorCode::SYSCALL => match e.io_error() { Some(io_err) => return io_err.to_pyexception(vm), None => ( - vm.class("_ssl", "SSLSyscallError"), + PySslSyscallError::class(&vm.ctx).to_owned(), "EOF occurred in violation of protocol", ), }, ssl::ErrorCode::SSL => match e.ssl_error() { Some(e) => return convert_openssl_error(vm, e.clone()), - None => (ssl_error(vm), "A failure in the SSL library occurred"), + None => ( + PySslError::class(&vm.ctx).to_owned(), + "A failure in the SSL library occurred", + ), }, - _ => (ssl_error(vm), "A failure in the SSL library occurred"), + _ => ( + PySslError::class(&vm.ctx).to_owned(), + "A failure in the SSL library occurred", + ), }; vm.new_exception_msg(cls, msg.to_owned()) } diff --git a/vm/src/exceptions.rs b/vm/src/exceptions.rs index 20bf467c45..029b12f5a4 100644 --- a/vm/src/exceptions.rs +++ b/vm/src/exceptions.rs @@ -1450,7 +1450,7 @@ pub(super) mod types { } #[cfg(not(target_arch = "wasm32"))] #[pyslot] - fn slot_new(cls: PyTypeRef, args: FuncArgs, vm: &VirtualMachine) -> PyResult { + pub fn slot_new(cls: PyTypeRef, args: FuncArgs, vm: &VirtualMachine) -> PyResult { // We need this method, because of how `CPython` copies `init` // from `BaseException` in `SimpleExtendsException` macro. // See: `BaseException_new` @@ -1465,12 +1465,12 @@ pub(super) mod types { } #[cfg(target_arch = "wasm32")] #[pyslot] - fn slot_new(cls: PyTypeRef, args: FuncArgs, vm: &VirtualMachine) -> PyResult { + pub fn slot_new(cls: PyTypeRef, args: FuncArgs, vm: &VirtualMachine) -> PyResult { PyBaseException::slot_new(cls, args, vm) } #[pyslot] #[pymethod(name = "__init__")] - fn slot_init(zelf: PyObjectRef, args: FuncArgs, vm: &VirtualMachine) -> PyResult<()> { + pub fn slot_init(zelf: PyObjectRef, args: FuncArgs, vm: &VirtualMachine) -> PyResult<()> { let len = args.args.len(); let mut new_args = args; if (3..=5).contains(&len) {