Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions derive-impl/src/pyclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -551,7 +551,13 @@ pub(crate) fn impl_pyexception(attr: PunctuatedNestedMeta, item: Item) -> Result
}
}
} else {
quote! {}
quote! {
impl ::rustpython_vm::PyPayload for #ident {
fn class(_ctx: &::rustpython_vm::vm::Context) -> &'static ::rustpython_vm::Py<::rustpython_vm::builtins::PyType> {
<Self as ::rustpython_vm::class::StaticType>::static_type()
}
}
}
};
let impl_pyclass = if class_meta.has_impl()? {
quote! {
Expand Down Expand Up @@ -613,7 +619,7 @@ pub(crate) fn impl_pyexception_impl(attr: PunctuatedNestedMeta, item: Item) -> R
} else {
quote! {
#[pyslot]
pub(crate) fn slot_new(
pub fn slot_new(
cls: ::rustpython_vm::builtins::PyTypeRef,
args: ::rustpython_vm::function::FuncArgs,
vm: &::rustpython_vm::VirtualMachine,
Expand All @@ -634,7 +640,7 @@ pub(crate) fn impl_pyexception_impl(attr: PunctuatedNestedMeta, item: Item) -> R
quote! {
#[pyslot]
#[pymethod(name="__init__")]
pub(crate) fn slot_init(
pub fn slot_init(
zelf: ::rustpython_vm::PyObjectRef,
args: ::rustpython_vm::function::FuncArgs,
vm: &::rustpython_vm::VirtualMachine,
Expand Down
10 changes: 7 additions & 3 deletions derive-impl/src/pymodule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ enum AttrName {
Function,
Attr,
Class,
Exception,
}

impl std::fmt::Display for AttrName {
Expand All @@ -24,6 +25,7 @@ impl std::fmt::Display for AttrName {
Self::Function => "pyfunction",
Self::Attr => "pyattr",
Self::Class => "pyclass",
Self::Exception => "pyexception",
};
s.fmt(f)
}
Expand All @@ -37,6 +39,7 @@ impl FromStr for AttrName {
"pyfunction" => Self::Function,
"pyattr" => Self::Attr,
"pyclass" => Self::Class,
"pyexception" => Self::Exception,
s => {
return Err(s.to_owned());
}
Expand Down Expand Up @@ -232,7 +235,8 @@ fn module_item_new(
inner: ContentItemInner { index, attr_name },
py_attrs,
}),
AttrName::Class => Box::new(ClassItem {
// pyexception is treated like pyclass - both define types
AttrName::Class | AttrName::Exception => Box::new(ClassItem {
inner: ContentItemInner { index, attr_name },
py_attrs,
}),
Expand Down Expand Up @@ -302,13 +306,13 @@ where
result.push(item_new(i, attr_name, Vec::new()));
} else {
match attr_name {
AttrName::Class | AttrName::Function => {
AttrName::Class | AttrName::Function | AttrName::Exception => {
result.push(item_new(i, attr_name, py_attrs.clone()));
}
_ => {
bail_span!(
attr,
"#[pyclass] or #[pyfunction] only can follow #[pyattr]",
"#[pyclass], #[pyfunction], or #[pyexception] can follow #[pyattr]",
)
}
}
Expand Down
3 changes: 2 additions & 1 deletion derive-impl/src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,8 @@ impl ClassItemMeta {
pub(crate) struct ExceptionItemMeta(ClassItemMeta);

impl ItemMeta for ExceptionItemMeta {
const ALLOWED_NAMES: &'static [&'static str] = &["name", "base", "unhashable", "ctx", "impl"];
const ALLOWED_NAMES: &'static [&'static str] =
&["module", "name", "base", "unhashable", "ctx", "impl"];

fn from_inner(inner: ItemMetaInner) -> Self {
Self(ClassItemMeta(inner))
Expand Down
156 changes: 94 additions & 62 deletions stdlib/src/ssl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,10 @@ mod _ssl {
},
socket::{self, PySocket},
vm::{
Py, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine,
builtins::{PyBaseExceptionRef, PyBytesRef, PyListRef, PyStrRef, PyTypeRef, PyWeak},
AsObject, Py, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine,
builtins::{
PyBaseExceptionRef, PyBytesRef, PyListRef, PyOSError, PyStrRef, PyTypeRef, PyWeak,
},
class_or_notimplemented,
convert::{ToPyException, ToPyObject},
exceptions,
Expand Down Expand Up @@ -198,63 +200,84 @@ mod _ssl {
parse_version_info(openssl_api_version)
}

// SSL Exception Types

/// An error occurred in the SSL implementation.
#[pyattr(name = "SSLError", once)]
fn ssl_error(vm: &VirtualMachine) -> PyTypeRef {
vm.ctx.new_exception_type(
"ssl",
"SSLError",
Some(vec![vm.ctx.exceptions.os_error.to_owned()]),
)
#[pyattr]
#[pyexception(name = "SSLError", base = "PyOSError")]
#[derive(Debug)]
pub struct PySslError {}

#[pyexception]
impl PySslError {
// Returns strerror attribute if available, otherwise str(args)
#[pymethod]
fn __str__(exc: PyBaseExceptionRef, vm: &VirtualMachine) -> PyResult<PyStrRef> {
// Try to get strerror attribute first (OSError compatibility)
if let Ok(strerror) = exc.as_object().get_attr("strerror", vm)
&& !vm.is_none(&strerror)
{
return strerror.str(vm);
}

// Otherwise return str(args)
exc.args().as_object().str(vm)
}
}

/// A certificate could not be verified.
#[pyattr(name = "SSLCertVerificationError", once)]
fn ssl_cert_verification_error(vm: &VirtualMachine) -> PyTypeRef {
vm.ctx.new_exception_type(
"ssl",
"SSLCertVerificationError",
Some(vec![
ssl_error(vm),
vm.ctx.exceptions.value_error.to_owned(),
]),
)
}
#[pyattr]
#[pyexception(name = "SSLCertVerificationError", base = "PySslError")]
#[derive(Debug)]
pub struct PySslCertVerificationError {}

#[pyexception]
impl PySslCertVerificationError {}

/// SSL/TLS session closed cleanly.
#[pyattr(name = "SSLZeroReturnError", once)]
fn ssl_zero_return_error(vm: &VirtualMachine) -> PyTypeRef {
vm.ctx
.new_exception_type("ssl", "SSLZeroReturnError", Some(vec![ssl_error(vm)]))
}
#[pyattr]
#[pyexception(name = "SSLZeroReturnError", base = "PySslError")]
#[derive(Debug)]
pub struct PySslZeroReturnError {}

/// Non-blocking SSL socket needs to read more data before the requested operation can be completed.
#[pyattr(name = "SSLWantReadError", once)]
fn ssl_want_read_error(vm: &VirtualMachine) -> PyTypeRef {
vm.ctx
.new_exception_type("ssl", "SSLWantReadError", Some(vec![ssl_error(vm)]))
}
#[pyexception]
impl PySslZeroReturnError {}

/// Non-blocking SSL socket needs to write more data before the requested operation can be completed.
#[pyattr(name = "SSLWantWriteError", once)]
fn ssl_want_write_error(vm: &VirtualMachine) -> PyTypeRef {
vm.ctx
.new_exception_type("ssl", "SSLWantWriteError", Some(vec![ssl_error(vm)]))
}
/// Non-blocking SSL socket needs to read more data.
#[pyattr]
#[pyexception(name = "SSLWantReadError", base = "PySslError")]
#[derive(Debug)]
pub struct PySslWantReadError {}

#[pyexception]
impl PySslWantReadError {}

/// Non-blocking SSL socket needs to write more data.
#[pyattr]
#[pyexception(name = "SSLWantWriteError", base = "PySslError")]
#[derive(Debug)]
pub struct PySslWantWriteError {}

#[pyexception]
impl PySslWantWriteError {}

/// System error when attempting SSL operation.
#[pyattr(name = "SSLSyscallError", once)]
fn ssl_syscall_error(vm: &VirtualMachine) -> PyTypeRef {
vm.ctx
.new_exception_type("ssl", "SSLSyscallError", Some(vec![ssl_error(vm)]))
}
#[pyattr]
#[pyexception(name = "SSLSyscallError", base = "PySslError")]
#[derive(Debug)]
pub struct PySslSyscallError {}

#[pyexception]
impl PySslSyscallError {}

/// SSL/TLS connection terminated abruptly.
#[pyattr(name = "SSLEOFError", once)]
fn ssl_eof_error(vm: &VirtualMachine) -> PyTypeRef {
vm.ctx
.new_exception_type("ssl", "SSLEOFError", Some(vec![ssl_error(vm)]))
}
#[pyattr]
#[pyexception(name = "SSLEOFError", base = "PySslError")]
#[derive(Debug)]
pub struct PySslEOFError {}

#[pyexception]
impl PySslEOFError {}

type OpensslVersionInfo = (u8, u8, u8, u8, u8);
const fn parse_version_info(mut n: i64) -> OpensslVersionInfo {
Expand Down Expand Up @@ -617,7 +640,10 @@ mod _ssl {
return Err(exceptions::cstring_error(vm));
}
self.builder().set_cipher_list(ciphers).map_err(|_| {
vm.new_exception_msg(ssl_error(vm), "No cipher can be selected.".to_owned())
vm.new_exception_msg(
PySslError::class(&vm.ctx).to_owned(),
"No cipher can be selected.".to_owned(),
)
})
}

Expand Down Expand Up @@ -744,13 +770,13 @@ mod _ssl {

if clear != 0 && sys::X509_VERIFY_PARAM_clear_flags(param, clear) == 0 {
return Err(vm.new_exception_msg(
ssl_error(vm),
PySslError::class(&vm.ctx).to_owned(),
"Failed to clear verify flags".to_owned(),
));
}
if set != 0 && sys::X509_VERIFY_PARAM_set_flags(param, set) == 0 {
return Err(vm.new_exception_msg(
ssl_error(vm),
PySslError::class(&vm.ctx).to_owned(),
"Failed to set verify flags".to_owned(),
));
}
Expand Down Expand Up @@ -934,13 +960,13 @@ mod _ssl {
// validate socket type and context protocol
if !args.server_side && zelf.protocol == SslVersion::TlsServer {
return Err(vm.new_exception_msg(
ssl_error(vm),
PySslError::class(&vm.ctx).to_owned(),
"Cannot create a client socket with a PROTOCOL_TLS_SERVER context".to_owned(),
));
}
if args.server_side && zelf.protocol == SslVersion::TlsClient {
return Err(vm.new_exception_msg(
ssl_error(vm),
PySslError::class(&vm.ctx).to_owned(),
"Cannot create a server socket with a PROTOCOL_TLS_CLIENT context".to_owned(),
));
}
Expand Down Expand Up @@ -1124,7 +1150,7 @@ mod _ssl {

fn socket_closed_error(vm: &VirtualMachine) -> PyBaseExceptionRef {
vm.new_exception_msg(
ssl_error(vm),
PySslError::class(&vm.ctx).to_owned(),
"Underlying socket has been closed.".to_owned(),
)
}
Expand Down Expand Up @@ -1390,7 +1416,7 @@ mod _ssl {
let result = unsafe { SSL_verify_client_post_handshake(stream.ssl().as_ptr()) };
if result == 0 {
Err(vm.new_exception_msg(
ssl_error(vm),
PySslError::class(&vm.ctx).to_owned(),
"Post-handshake authentication failed".to_owned(),
))
} else {
Expand Down Expand Up @@ -1422,7 +1448,7 @@ mod _ssl {
// Return the underlying socket
} else {
return Err(vm.new_exception_msg(
ssl_error(vm),
PySslError::class(&vm.ctx).to_owned(),
format!("SSL shutdown failed: error code {}", err),
));
}
Expand Down Expand Up @@ -1854,7 +1880,7 @@ mod _ssl {
fn write(&self, data: ArgBytesLike, vm: &VirtualMachine) -> PyResult<i32> {
if self.eof_written.load() {
return Err(vm.new_exception_msg(
ssl_error(vm),
PySslError::class(&vm.ctx).to_owned(),
"cannot write() after write_eof()".to_owned(),
));
}
Expand Down Expand Up @@ -1953,7 +1979,7 @@ mod _ssl {

#[track_caller]
fn convert_openssl_error(vm: &VirtualMachine, err: ErrorStack) -> PyBaseExceptionRef {
let cls = ssl_error(vm);
let cls = PySslError::class(&vm.ctx).to_owned();
match err.errors().last() {
Some(e) => {
let caller = std::panic::Location::caller();
Expand Down Expand Up @@ -2012,25 +2038,31 @@ mod _ssl {
let e = e.borrow();
let (cls, msg) = match e.code() {
ssl::ErrorCode::WANT_READ => (
vm.class("_ssl", "SSLWantReadError"),
PySslWantReadError::class(&vm.ctx).to_owned(),
"The operation did not complete (read)",
),
ssl::ErrorCode::WANT_WRITE => (
vm.class("_ssl", "SSLWantWriteError"),
PySslWantWriteError::class(&vm.ctx).to_owned(),
"The operation did not complete (write)",
),
ssl::ErrorCode::SYSCALL => match e.io_error() {
Some(io_err) => return io_err.to_pyexception(vm),
None => (
vm.class("_ssl", "SSLSyscallError"),
PySslSyscallError::class(&vm.ctx).to_owned(),
"EOF occurred in violation of protocol",
),
},
ssl::ErrorCode::SSL => match e.ssl_error() {
Some(e) => return convert_openssl_error(vm, e.clone()),
None => (ssl_error(vm), "A failure in the SSL library occurred"),
None => (
PySslError::class(&vm.ctx).to_owned(),
"A failure in the SSL library occurred",
),
},
_ => (ssl_error(vm), "A failure in the SSL library occurred"),
_ => (
PySslError::class(&vm.ctx).to_owned(),
"A failure in the SSL library occurred",
),
};
vm.new_exception_msg(cls, msg.to_owned())
}
Expand Down
6 changes: 3 additions & 3 deletions vm/src/exceptions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1450,7 +1450,7 @@ pub(super) mod types {
}
#[cfg(not(target_arch = "wasm32"))]
#[pyslot]
fn slot_new(cls: PyTypeRef, args: FuncArgs, vm: &VirtualMachine) -> PyResult {
pub fn slot_new(cls: PyTypeRef, args: FuncArgs, vm: &VirtualMachine) -> PyResult {
// We need this method, because of how `CPython` copies `init`
// from `BaseException` in `SimpleExtendsException` macro.
// See: `BaseException_new`
Expand All @@ -1465,12 +1465,12 @@ pub(super) mod types {
}
#[cfg(target_arch = "wasm32")]
#[pyslot]
fn slot_new(cls: PyTypeRef, args: FuncArgs, vm: &VirtualMachine) -> PyResult {
pub fn slot_new(cls: PyTypeRef, args: FuncArgs, vm: &VirtualMachine) -> PyResult {
PyBaseException::slot_new(cls, args, vm)
}
#[pyslot]
#[pymethod(name = "__init__")]
fn slot_init(zelf: PyObjectRef, args: FuncArgs, vm: &VirtualMachine) -> PyResult<()> {
pub fn slot_init(zelf: PyObjectRef, args: FuncArgs, vm: &VirtualMachine) -> PyResult<()> {
let len = args.args.len();
let mut new_args = args;
if (3..=5).contains(&len) {
Expand Down
Loading