diff --git a/vm/src/builtins/tuple.rs b/vm/src/builtins/tuple.rs index 7f3855c529..dfe545488b 100644 --- a/vm/src/builtins/tuple.rs +++ b/vm/src/builtins/tuple.rs @@ -537,9 +537,9 @@ impl AsRef<[T]> for PyTupleTyped { } impl PyTupleTyped { - pub fn empty(vm: &VirtualMachine) -> Self { + pub fn empty(ctx: &Context) -> Self { Self { - tuple: vm.ctx.empty_tuple.clone(), + tuple: ctx.empty_tuple.clone(), _marker: PhantomData, } } diff --git a/vm/src/builtins/type.rs b/vm/src/builtins/type.rs index 94c4f2f668..c5792c6b2d 100644 --- a/vm/src/builtins/type.rs +++ b/vm/src/builtins/type.rs @@ -37,7 +37,7 @@ use std::{borrow::Borrow, collections::HashSet, ops::Deref, pin::Pin, ptr::NonNu pub struct PyType { pub base: Option, pub bases: PyRwLock>, - pub mro: PyRwLock>, + pub mro: PyRwLock>, // TODO: PyTypedTuple pub subclasses: PyRwLock>>, pub attributes: PyRwLock, pub slots: PyTypeSlots, @@ -48,7 +48,7 @@ unsafe impl crate::object::Traverse for PyType { fn traverse(&self, tracer_fn: &mut crate::object::TraverseFn<'_>) { self.base.traverse(tracer_fn); self.bases.traverse(tracer_fn); - self.mro.traverse(tracer_fn); + // self.mro.traverse(tracer_fn); self.subclasses.traverse(tracer_fn); self.attributes .read_recursive() @@ -158,6 +158,15 @@ fn downcast_qualname(value: PyObjectRef, vm: &VirtualMachine) -> PyResult, b: &Py) -> bool { + for item in a_mro { + if item.is(b) { + return true; + } + } + false +} + impl PyType { pub fn new_simple_heap( name: &str, @@ -197,6 +206,12 @@ impl PyType { Self::new_heap_inner(base, bases, attrs, slots, heaptype_ext, metaclass, ctx) } + /// Equivalent to CPython's PyType_Check macro + /// Checks if obj is an instance of type (or its subclass) + pub(crate) fn check(obj: &PyObject) -> Option<&Py> { + obj.downcast_ref::() + } + fn resolve_mro(bases: &[PyRef]) -> Result, String> { // Check for duplicates in bases. let mut unique_bases = HashSet::new(); @@ -223,8 +238,6 @@ impl PyType { metaclass: PyRef, ctx: &Context, ) -> Result, String> { - let mro = Self::resolve_mro(&bases)?; - if base.slots.flags.has_feature(PyTypeFlags::HAS_DICT) { slots.flags |= PyTypeFlags::HAS_DICT } @@ -241,6 +254,7 @@ impl PyType { } } + let mro = Self::resolve_mro(&bases)?; let new_type = PyRef::new_ref( PyType { base: Some(base), @@ -254,6 +268,7 @@ impl PyType { metaclass, None, ); + new_type.mro.write().insert(0, new_type.clone()); new_type.init_slots(ctx); @@ -285,7 +300,6 @@ impl PyType { let bases = PyRwLock::new(vec![base.clone()]); let mro = base.mro_map_collect(|x| x.to_owned()); - let new_type = PyRef::new_ref( PyType { base: Some(base), @@ -299,6 +313,7 @@ impl PyType { metaclass, None, ); + new_type.mro.write().insert(0, new_type.clone()); let weakref_type = super::PyWeak::static_type(); for base in new_type.bases.read().iter() { @@ -317,7 +332,7 @@ impl PyType { #[allow(clippy::mutable_key_type)] let mut slot_name_set = std::collections::HashSet::new(); - for cls in self.mro.read().iter() { + for cls in self.mro.read()[1..].iter() { for &name in cls.attributes.read().keys() { if name == identifier!(ctx, __new__) { continue; @@ -366,8 +381,7 @@ impl PyType { } pub fn get_super_attr(&self, attr_name: &'static PyStrInterned) -> Option { - self.mro - .read() + self.mro.read()[1..] .iter() .find_map(|class| class.attributes.read().get(attr_name).cloned()) } @@ -375,9 +389,7 @@ impl PyType { // This is the internal has_attr implementation for fast lookup on a class. pub fn has_attr(&self, attr_name: &'static PyStrInterned) -> bool { self.attributes.read().contains_key(attr_name) - || self - .mro - .read() + || self.mro.read()[1..] .iter() .any(|c| c.attributes.read().contains_key(attr_name)) } @@ -386,10 +398,7 @@ impl PyType { // Gather all members here: let mut attributes = PyAttributes::default(); - for bc in std::iter::once(self) - .chain(self.mro.read().iter().map(|cls| -> &PyType { cls })) - .rev() - { + for bc in self.mro.read().iter().map(|cls| -> &PyType { cls }).rev() { for (name, value) in bc.attributes.read().iter() { attributes.insert(name.to_owned(), value.clone()); } @@ -439,26 +448,35 @@ impl PyType { } impl Py { + pub(crate) fn is_subtype(&self, other: &Py) -> bool { + is_subtype_with_mro(&self.mro.read(), self, other) + } + + /// Equivalent to CPython's PyType_CheckExact macro + /// Checks if obj is exactly a type (not a subclass) + pub fn check_exact<'a>(obj: &'a PyObject, vm: &VirtualMachine) -> Option<&'a Py> { + obj.downcast_ref_if_exact::(vm) + } + /// Determines if `subclass` is actually a subclass of `cls`, this doesn't call __subclasscheck__, /// so only use this if `cls` is known to have not overridden the base __subclasscheck__ magic /// method. pub fn fast_issubclass(&self, cls: &impl Borrow) -> bool { - self.as_object().is(cls.borrow()) || self.mro.read().iter().any(|c| c.is(cls.borrow())) + self.as_object().is(cls.borrow()) || self.mro.read()[1..].iter().any(|c| c.is(cls.borrow())) } pub fn mro_map_collect(&self, f: F) -> Vec where F: Fn(&Self) -> R, { - std::iter::once(self) - .chain(self.mro.read().iter().map(|x| x.deref())) - .map(f) - .collect() + self.mro.read().iter().map(|x| x.deref()).map(f).collect() } pub fn mro_collect(&self) -> Vec> { - std::iter::once(self) - .chain(self.mro.read().iter().map(|x| x.deref())) + self.mro + .read() + .iter() + .map(|x| x.deref()) .map(|x| x.to_owned()) .collect() } @@ -472,7 +490,7 @@ impl Py { if let Some(r) = f(self) { Some(r) } else { - self.mro.read().iter().find_map(|cls| f(cls)) + self.mro.read()[1..].iter().find_map(|cls| f(cls)) } } @@ -531,8 +549,10 @@ impl PyType { *zelf.bases.write() = bases; // Recursively update the mros of this class and all subclasses fn update_mro_recursively(cls: &PyType, vm: &VirtualMachine) -> PyResult<()> { - *cls.mro.write() = + let mut mro = PyType::resolve_mro(&cls.bases.read()).map_err(|msg| vm.new_type_error(msg))?; + mro.insert(0, cls.mro.read()[0].to_owned()); + *cls.mro.write() = mro; for subclass in cls.subclasses.write().iter() { let subclass = subclass.upgrade().unwrap(); let subclass: &PyType = subclass.payload().unwrap(); diff --git a/vm/src/frame.rs b/vm/src/frame.rs index 3a02d582a1..123ed6a04f 100644 --- a/vm/src/frame.rs +++ b/vm/src/frame.rs @@ -1384,7 +1384,7 @@ impl ExecutingFrame<'_> { fn import(&mut self, vm: &VirtualMachine, module_name: Option<&Py>) -> PyResult<()> { let module_name = module_name.unwrap_or(vm.ctx.empty_str); let from_list = >>::try_from_object(vm, self.pop_value())? - .unwrap_or_else(|| PyTupleTyped::empty(vm)); + .unwrap_or_else(|| PyTupleTyped::empty(&vm.ctx)); let level = usize::try_from_object(vm, self.pop_value())?; let module = vm.import_from(module_name, from_list, level)?; diff --git a/vm/src/object/core.rs b/vm/src/object/core.rs index 253d8fda63..dca63f1192 100644 --- a/vm/src/object/core.rs +++ b/vm/src/object/core.rs @@ -1252,12 +1252,14 @@ pub(crate) fn init_type_hierarchy() -> (PyTypeRef, PyTypeRef, PyTypeRef) { ptr::write(&mut (*type_type_ptr).typ, PyAtomicRef::from(type_type)); let object_type = PyTypeRef::from_raw(object_type_ptr.cast()); + (*object_type_ptr).payload.mro = PyRwLock::new(vec![object_type.clone()]); - (*type_type_ptr).payload.mro = PyRwLock::new(vec![object_type.clone()]); (*type_type_ptr).payload.bases = PyRwLock::new(vec![object_type.clone()]); (*type_type_ptr).payload.base = Some(object_type.clone()); let type_type = PyTypeRef::from_raw(type_type_ptr.cast()); + (*type_type_ptr).payload.mro = + PyRwLock::new(vec![type_type.clone(), object_type.clone()]); (type_type, object_type) } @@ -1273,6 +1275,7 @@ pub(crate) fn init_type_hierarchy() -> (PyTypeRef, PyTypeRef, PyTypeRef) { heaptype_ext: None, }; let weakref_type = PyRef::new_ref(weakref_type, type_type.clone(), None); + weakref_type.mro.write().insert(0, weakref_type.clone()); object_type.subclasses.write().push( type_type diff --git a/vm/src/protocol/object.rs b/vm/src/protocol/object.rs index 804918abb3..86d2b33fe7 100644 --- a/vm/src/protocol/object.rs +++ b/vm/src/protocol/object.rs @@ -371,80 +371,112 @@ impl PyObject { }) } - // Equivalent to check_class. Masks Attribute errors (into TypeErrors) and lets everything - // else go through. - fn check_cls(&self, cls: &PyObject, vm: &VirtualMachine, msg: F) -> PyResult + // Equivalent to CPython's check_class. Returns Ok(()) if cls is a valid class, + // Err with TypeError if not. Uses abstract_get_bases internally. + fn check_class(&self, vm: &VirtualMachine, msg: F) -> PyResult<()> where F: Fn() -> String, { - cls.get_attr(identifier!(vm, __bases__), vm).map_err(|e| { - // Only mask AttributeErrors. - if e.class().is(vm.ctx.exceptions.attribute_error) { - vm.new_type_error(msg()) - } else { - e + let cls = self; + match cls.abstract_get_bases(vm)? { + Some(_bases) => Ok(()), // Has __bases__, it's a valid class + None => { + // No __bases__ or __bases__ is not a tuple + Err(vm.new_type_error(msg())) } - }) + } + } + + /// abstract_get_bases() has logically 4 return states: + /// 1. getattr(cls, '__bases__') could raise an AttributeError + /// 2. getattr(cls, '__bases__') could raise some other exception + /// 3. getattr(cls, '__bases__') could return a tuple + /// 4. getattr(cls, '__bases__') could return something other than a tuple + /// + /// Only state #3 returns Some(tuple). AttributeErrors are masked by returning None. + /// If an object other than a tuple comes out of __bases__, then again, None is returned. + /// Other exceptions are propagated. + fn abstract_get_bases(&self, vm: &VirtualMachine) -> PyResult> { + match vm.get_attribute_opt(self.to_owned(), identifier!(vm, __bases__))? { + Some(bases) => { + // Check if it's a tuple + match PyTupleRef::try_from_object(vm, bases) { + Ok(tuple) => Ok(Some(tuple)), + Err(_) => Ok(None), // Not a tuple, return None + } + } + None => Ok(None), // AttributeError was masked + } } fn abstract_issubclass(&self, cls: &PyObject, vm: &VirtualMachine) -> PyResult { let mut derived = self; - let mut first_item: PyObjectRef; - loop { + + // First loop: handle single inheritance without recursion + let bases = loop { if derived.is(cls) { return Ok(true); } - let bases = derived.get_attr(identifier!(vm, __bases__), vm)?; - let tuple = PyTupleRef::try_from_object(vm, bases)?; - - let n = tuple.len(); + let Some(bases) = derived.abstract_get_bases(vm)? else { + return Ok(false); + }; + let n = bases.len(); match n { - 0 => { - return Ok(false); - } + 0 => return Ok(false), 1 => { - first_item = tuple[0].clone(); - derived = &first_item; + // Avoid recursion in the single inheritance case + // # safety + // Intention: bases.as_slice()[0].as_object(); + // Though type-system cannot guarantee, derived does live long enough in the loop. + derived = unsafe { &*(bases.as_slice()[0].as_object() as *const _) }; continue; } _ => { - for i in 0..n { - let check = vm.with_recursion("in abstract_issubclass", || { - tuple[i].abstract_issubclass(cls, vm) - })?; - if check { - return Ok(true); - } - } + // Multiple inheritance - break out to handle recursively + break bases; } } + }; - return Ok(false); + // Second loop: handle multiple inheritance with recursion + // At this point we know n >= 2 + let n = bases.len(); + assert!(n >= 2); + + for i in 0..n { + let result = vm.with_recursion("in __issubclass__", || { + bases.as_slice()[i].abstract_issubclass(cls, vm) + })?; + if result { + return Ok(true); + } } + + Ok(false) } fn recursive_issubclass(&self, cls: &PyObject, vm: &VirtualMachine) -> PyResult { - if let (Ok(obj), Ok(cls)) = (self.try_to_ref::(vm), cls.try_to_ref::(vm)) { - Ok(obj.fast_issubclass(cls)) - } else { - // Check if derived is a class - self.check_cls(self, vm, || { - format!("issubclass() arg 1 must be a class, not {}", self.class()) + // Fast path for both being types (matches CPython's PyType_Check) + if let Some(cls) = PyType::check(cls) + && let Some(derived) = PyType::check(self) + { + // PyType_IsSubtype equivalent + return Ok(derived.is_subtype(cls)); + } + // Check if derived is a class + self.check_class(vm, || { + format!("issubclass() arg 1 must be a class, not {}", self.class()) + })?; + + // Check if cls is a class, tuple, or union (matches CPython's order and message) + if !cls.class().is(vm.ctx.types.union_type) { + cls.check_class(vm, || { + "issubclass() arg 2 must be a class, a tuple of classes, or a union".to_string() })?; - - // Check if cls is a class, tuple, or union - if !cls.class().is(vm.ctx.types.union_type) { - self.check_cls(cls, vm, || { - format!( - "issubclass() arg 2 must be a class, a tuple of classes, or a union, not {}", - cls.class() - ) - })?; - } - - self.abstract_issubclass(cls, vm) } + + self.abstract_issubclass(cls, vm) } /// Real issubclass check without going through __subclasscheck__ @@ -520,7 +552,7 @@ impl PyObject { Ok(retval) } else { // Not a type object, check if it's a valid class - self.check_cls(cls, vm, || { + cls.check_class(vm, || { format!( "isinstance() arg 2 must be a type, a tuple of types, or a union, not {}", cls.class() diff --git a/vm/src/vm/mod.rs b/vm/src/vm/mod.rs index 20c8161004..fd97aaabd8 100644 --- a/vm/src/vm/mod.rs +++ b/vm/src/vm/mod.rs @@ -580,7 +580,7 @@ impl VirtualMachine { #[inline] pub fn import<'a>(&self, module_name: impl AsPyStr<'a>, level: usize) -> PyResult { let module_name = module_name.as_pystr(&self.ctx); - let from_list = PyTupleTyped::empty(self); + let from_list = PyTupleTyped::empty(&self.ctx); self.import_inner(module_name, from_list, level) }