Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 49 additions & 1 deletion crates/vm/src/builtins/type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,13 @@ use crate::{
protocol::{PyIterReturn, PyMappingMethods, PyNumberMethods, PySequenceMethods},
types::{
AsNumber, Callable, Constructor, GetAttr, PyTypeFlags, PyTypeSlots, Representable, SetAttr,
TypeDataRef, TypeDataRefMut, TypeDataSlot,
},
};
use indexmap::{IndexMap, map::Entry};
use itertools::Itertools;
use num_traits::ToPrimitive;
use std::{borrow::Borrow, collections::HashSet, ops::Deref, pin::Pin, ptr::NonNull};
use std::{any::Any, borrow::Borrow, collections::HashSet, ops::Deref, pin::Pin, ptr::NonNull};

#[pyclass(module = false, name = "type", traverse = "manual")]
pub struct PyType {
Expand Down Expand Up @@ -65,6 +66,7 @@ pub struct HeapTypeExt {
pub slots: Option<PyRef<PyTuple<PyStrRef>>>,
pub sequence_methods: PySequenceMethods,
pub mapping_methods: PyMappingMethods,
pub type_data: PyRwLock<Option<TypeDataSlot>>,
}

pub struct PointerSlot<T>(NonNull<T>);
Expand Down Expand Up @@ -203,6 +205,7 @@ impl PyType {
slots: None,
sequence_methods: PySequenceMethods::default(),
mapping_methods: PyMappingMethods::default(),
type_data: PyRwLock::new(None),
};
let base = bases[0].clone();

Expand Down Expand Up @@ -563,6 +566,50 @@ impl PyType {
|ext| PyRwLockReadGuard::map(ext.name.read(), |name| name.as_str()).into(),
)
}

// Type Data Slot API - CPython's PyObject_GetTypeData equivalent

/// Initialize type data for this type. Can only be called once.
/// Returns an error if the type is not a heap type or if data is already initialized.
pub fn init_type_data<T: Any + Send + Sync + 'static>(&self, data: T) -> Result<(), String> {
let ext = self
.heaptype_ext
.as_ref()
.ok_or_else(|| "Cannot set type data on non-heap types".to_string())?;

let mut type_data = ext.type_data.write();
if type_data.is_some() {
return Err("Type data already initialized".to_string());
}
*type_data = Some(TypeDataSlot::new(data));
Ok(())
}

/// Get a read guard to the type data.
/// Returns None if the type is not a heap type, has no data, or the data type doesn't match.
pub fn get_type_data<T: Any + 'static>(&self) -> Option<TypeDataRef<'_, T>> {
self.heaptype_ext
.as_ref()
.and_then(|ext| TypeDataRef::try_new(ext.type_data.read()))
}

/// Get a write guard to the type data.
/// Returns None if the type is not a heap type, has no data, or the data type doesn't match.
pub fn get_type_data_mut<T: Any + 'static>(&self) -> Option<TypeDataRefMut<'_, T>> {
self.heaptype_ext
.as_ref()
.and_then(|ext| TypeDataRefMut::try_new(ext.type_data.write()))
}

/// Check if this type has type data of the given type.
pub fn has_type_data<T: Any + 'static>(&self) -> bool {
self.heaptype_ext.as_ref().is_some_and(|ext| {
ext.type_data
.read()
.as_ref()
.is_some_and(|slot| slot.get::<T>().is_some())
})
}
}

impl Py<PyType> {
Expand Down Expand Up @@ -1167,6 +1214,7 @@ impl Constructor for PyType {
slots: heaptype_slots.clone(),
sequence_methods: PySequenceMethods::default(),
mapping_methods: PyMappingMethods::default(),
type_data: PyRwLock::new(None),
};
(slots, heaptype_ext)
};
Expand Down
58 changes: 29 additions & 29 deletions crates/vm/src/stdlib/ctypes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -389,39 +389,35 @@ pub(crate) mod _ctypes {
/// Get the size of a ctypes type or instance
#[pyfunction(name = "sizeof")]
pub fn size_of(obj: PyObjectRef, vm: &VirtualMachine) -> PyResult<usize> {
use super::array::{PyCArray, PyCArrayType};
use super::pointer::PyCPointer;
use super::structure::{PyCStructType, PyCStructure};
use super::union::{PyCUnion, PyCUnionType};
use super::union::PyCUnionType;
use super::util::StgInfo;
use crate::builtins::PyType;

// 1. Instances with stg_info
if obj.fast_isinstance(PyCArray::static_type()) {
// Get stg_info from the type
if let Some(type_obj) = obj.class().as_object().downcast_ref::<PyCArrayType>() {
return Ok(type_obj.stg_info.size);
}
// 1. Check TypeDataSlot on class (for instances)
if let Some(stg_info) = obj.class().get_type_data::<StgInfo>() {
return Ok(stg_info.size);
}

// 2. Check TypeDataSlot on type itself (for type objects)
if let Some(type_obj) = obj.downcast_ref::<PyType>()
&& let Some(stg_info) = type_obj.get_type_data::<StgInfo>()
{
return Ok(stg_info.size);
}

// 3. Instances with cdata buffer
if let Some(structure) = obj.downcast_ref::<PyCStructure>() {
return Ok(structure.cdata.read().size());
}
if obj.fast_isinstance(PyCUnion::static_type()) {
// Get stg_info from the type
if let Some(type_obj) = obj.class().as_object().downcast_ref::<PyCUnionType>() {
return Ok(type_obj.stg_info.size);
}
}
if let Some(simple) = obj.downcast_ref::<PyCSimple>() {
return Ok(simple.cdata.read().size());
}
if obj.fast_isinstance(PyCPointer::static_type()) {
return Ok(std::mem::size_of::<usize>());
}

// 2. Types (metatypes with stg_info)
if let Some(array_type) = obj.downcast_ref::<PyCArrayType>() {
return Ok(array_type.stg_info.size);
}

// 3. Type objects
if let Ok(type_ref) = obj.clone().downcast::<crate::builtins::PyType>() {
// Structure types - check if metaclass is or inherits from PyCStructType
Expand Down Expand Up @@ -659,33 +655,37 @@ pub(crate) mod _ctypes {

#[pyfunction]
fn alignment(tp: Either<PyTypeRef, PyObjectRef>, vm: &VirtualMachine) -> PyResult<usize> {
use super::array::{PyCArray, PyCArrayType};
use super::base::PyCSimpleType;
use super::pointer::PyCPointer;
use super::structure::PyCStructure;
use super::union::PyCUnion;
use super::util::StgInfo;
use crate::builtins::PyType;

let obj = match &tp {
Either::A(t) => t.as_object(),
Either::B(o) => o.as_ref(),
};

// Try to get alignment from stg_info directly (for instances)
if let Some(array_type) = obj.downcast_ref::<PyCArrayType>() {
return Ok(array_type.stg_info.align);
// 1. Check TypeDataSlot on class (for instances)
if let Some(stg_info) = obj.class().get_type_data::<StgInfo>() {
return Ok(stg_info.align);
}

// 2. Check TypeDataSlot on type itself (for type objects)
if let Some(type_obj) = obj.downcast_ref::<PyType>()
&& let Some(stg_info) = type_obj.get_type_data::<StgInfo>()
{
return Ok(stg_info.align);
}

// 3. Fallback for simple types without TypeDataSlot
if obj.fast_isinstance(PyCSimple::static_type()) {
// Get stg_info from the type by reading _type_ attribute
let cls = obj.class().to_owned();
let stg_info = PyCSimpleType::get_stg_info(&cls, vm);
return Ok(stg_info.align);
}
if obj.fast_isinstance(PyCArray::static_type()) {
// Get stg_info from the type
if let Some(type_obj) = obj.class().as_object().downcast_ref::<PyCArrayType>() {
return Ok(type_obj.stg_info.align);
}
}
if obj.fast_isinstance(PyCStructure::static_type()) {
// Calculate alignment from _fields_
let cls = obj.class();
Expand Down
Loading
Loading