diff --git a/vm/src/builtins/tuple.rs b/vm/src/builtins/tuple.rs index 7f3855c529..dfe545488b 100644 --- a/vm/src/builtins/tuple.rs +++ b/vm/src/builtins/tuple.rs @@ -537,9 +537,9 @@ impl AsRef<[T]> for PyTupleTyped { } impl PyTupleTyped { - pub fn empty(vm: &VirtualMachine) -> Self { + pub fn empty(ctx: &Context) -> Self { Self { - tuple: vm.ctx.empty_tuple.clone(), + tuple: ctx.empty_tuple.clone(), _marker: PhantomData, } } diff --git a/vm/src/builtins/type.rs b/vm/src/builtins/type.rs index 94c4f2f668..c5792c6b2d 100644 --- a/vm/src/builtins/type.rs +++ b/vm/src/builtins/type.rs @@ -37,7 +37,7 @@ use std::{borrow::Borrow, collections::HashSet, ops::Deref, pin::Pin, ptr::NonNu pub struct PyType { pub base: Option, pub bases: PyRwLock>, - pub mro: PyRwLock>, + pub mro: PyRwLock>, // TODO: PyTypedTuple pub subclasses: PyRwLock>>, pub attributes: PyRwLock, pub slots: PyTypeSlots, @@ -48,7 +48,7 @@ unsafe impl crate::object::Traverse for PyType { fn traverse(&self, tracer_fn: &mut crate::object::TraverseFn<'_>) { self.base.traverse(tracer_fn); self.bases.traverse(tracer_fn); - self.mro.traverse(tracer_fn); + // self.mro.traverse(tracer_fn); self.subclasses.traverse(tracer_fn); self.attributes .read_recursive() @@ -158,6 +158,15 @@ fn downcast_qualname(value: PyObjectRef, vm: &VirtualMachine) -> PyResult, b: &Py) -> bool { + for item in a_mro { + if item.is(b) { + return true; + } + } + false +} + impl PyType { pub fn new_simple_heap( name: &str, @@ -197,6 +206,12 @@ impl PyType { Self::new_heap_inner(base, bases, attrs, slots, heaptype_ext, metaclass, ctx) } + /// Equivalent to CPython's PyType_Check macro + /// Checks if obj is an instance of type (or its subclass) + pub(crate) fn check(obj: &PyObject) -> Option<&Py> { + obj.downcast_ref::() + } + fn resolve_mro(bases: &[PyRef]) -> Result, String> { // Check for duplicates in bases. let mut unique_bases = HashSet::new(); @@ -223,8 +238,6 @@ impl PyType { metaclass: PyRef, ctx: &Context, ) -> Result, String> { - let mro = Self::resolve_mro(&bases)?; - if base.slots.flags.has_feature(PyTypeFlags::HAS_DICT) { slots.flags |= PyTypeFlags::HAS_DICT } @@ -241,6 +254,7 @@ impl PyType { } } + let mro = Self::resolve_mro(&bases)?; let new_type = PyRef::new_ref( PyType { base: Some(base), @@ -254,6 +268,7 @@ impl PyType { metaclass, None, ); + new_type.mro.write().insert(0, new_type.clone()); new_type.init_slots(ctx); @@ -285,7 +300,6 @@ impl PyType { let bases = PyRwLock::new(vec![base.clone()]); let mro = base.mro_map_collect(|x| x.to_owned()); - let new_type = PyRef::new_ref( PyType { base: Some(base), @@ -299,6 +313,7 @@ impl PyType { metaclass, None, ); + new_type.mro.write().insert(0, new_type.clone()); let weakref_type = super::PyWeak::static_type(); for base in new_type.bases.read().iter() { @@ -317,7 +332,7 @@ impl PyType { #[allow(clippy::mutable_key_type)] let mut slot_name_set = std::collections::HashSet::new(); - for cls in self.mro.read().iter() { + for cls in self.mro.read()[1..].iter() { for &name in cls.attributes.read().keys() { if name == identifier!(ctx, __new__) { continue; @@ -366,8 +381,7 @@ impl PyType { } pub fn get_super_attr(&self, attr_name: &'static PyStrInterned) -> Option { - self.mro - .read() + self.mro.read()[1..] .iter() .find_map(|class| class.attributes.read().get(attr_name).cloned()) } @@ -375,9 +389,7 @@ impl PyType { // This is the internal has_attr implementation for fast lookup on a class. pub fn has_attr(&self, attr_name: &'static PyStrInterned) -> bool { self.attributes.read().contains_key(attr_name) - || self - .mro - .read() + || self.mro.read()[1..] .iter() .any(|c| c.attributes.read().contains_key(attr_name)) } @@ -386,10 +398,7 @@ impl PyType { // Gather all members here: let mut attributes = PyAttributes::default(); - for bc in std::iter::once(self) - .chain(self.mro.read().iter().map(|cls| -> &PyType { cls })) - .rev() - { + for bc in self.mro.read().iter().map(|cls| -> &PyType { cls }).rev() { for (name, value) in bc.attributes.read().iter() { attributes.insert(name.to_owned(), value.clone()); } @@ -439,26 +448,35 @@ impl PyType { } impl Py { + pub(crate) fn is_subtype(&self, other: &Py) -> bool { + is_subtype_with_mro(&self.mro.read(), self, other) + } + + /// Equivalent to CPython's PyType_CheckExact macro + /// Checks if obj is exactly a type (not a subclass) + pub fn check_exact<'a>(obj: &'a PyObject, vm: &VirtualMachine) -> Option<&'a Py> { + obj.downcast_ref_if_exact::(vm) + } + /// Determines if `subclass` is actually a subclass of `cls`, this doesn't call __subclasscheck__, /// so only use this if `cls` is known to have not overridden the base __subclasscheck__ magic /// method. pub fn fast_issubclass(&self, cls: &impl Borrow) -> bool { - self.as_object().is(cls.borrow()) || self.mro.read().iter().any(|c| c.is(cls.borrow())) + self.as_object().is(cls.borrow()) || self.mro.read()[1..].iter().any(|c| c.is(cls.borrow())) } pub fn mro_map_collect(&self, f: F) -> Vec where F: Fn(&Self) -> R, { - std::iter::once(self) - .chain(self.mro.read().iter().map(|x| x.deref())) - .map(f) - .collect() + self.mro.read().iter().map(|x| x.deref()).map(f).collect() } pub fn mro_collect(&self) -> Vec> { - std::iter::once(self) - .chain(self.mro.read().iter().map(|x| x.deref())) + self.mro + .read() + .iter() + .map(|x| x.deref()) .map(|x| x.to_owned()) .collect() } @@ -472,7 +490,7 @@ impl Py { if let Some(r) = f(self) { Some(r) } else { - self.mro.read().iter().find_map(|cls| f(cls)) + self.mro.read()[1..].iter().find_map(|cls| f(cls)) } } @@ -531,8 +549,10 @@ impl PyType { *zelf.bases.write() = bases; // Recursively update the mros of this class and all subclasses fn update_mro_recursively(cls: &PyType, vm: &VirtualMachine) -> PyResult<()> { - *cls.mro.write() = + let mut mro = PyType::resolve_mro(&cls.bases.read()).map_err(|msg| vm.new_type_error(msg))?; + mro.insert(0, cls.mro.read()[0].to_owned()); + *cls.mro.write() = mro; for subclass in cls.subclasses.write().iter() { let subclass = subclass.upgrade().unwrap(); let subclass: &PyType = subclass.payload().unwrap(); diff --git a/vm/src/frame.rs b/vm/src/frame.rs index 3a02d582a1..123ed6a04f 100644 --- a/vm/src/frame.rs +++ b/vm/src/frame.rs @@ -1384,7 +1384,7 @@ impl ExecutingFrame<'_> { fn import(&mut self, vm: &VirtualMachine, module_name: Option<&Py>) -> PyResult<()> { let module_name = module_name.unwrap_or(vm.ctx.empty_str); let from_list = >>::try_from_object(vm, self.pop_value())? - .unwrap_or_else(|| PyTupleTyped::empty(vm)); + .unwrap_or_else(|| PyTupleTyped::empty(&vm.ctx)); let level = usize::try_from_object(vm, self.pop_value())?; let module = vm.import_from(module_name, from_list, level)?; diff --git a/vm/src/object/core.rs b/vm/src/object/core.rs index 253d8fda63..dca63f1192 100644 --- a/vm/src/object/core.rs +++ b/vm/src/object/core.rs @@ -1252,12 +1252,14 @@ pub(crate) fn init_type_hierarchy() -> (PyTypeRef, PyTypeRef, PyTypeRef) { ptr::write(&mut (*type_type_ptr).typ, PyAtomicRef::from(type_type)); let object_type = PyTypeRef::from_raw(object_type_ptr.cast()); + (*object_type_ptr).payload.mro = PyRwLock::new(vec![object_type.clone()]); - (*type_type_ptr).payload.mro = PyRwLock::new(vec![object_type.clone()]); (*type_type_ptr).payload.bases = PyRwLock::new(vec![object_type.clone()]); (*type_type_ptr).payload.base = Some(object_type.clone()); let type_type = PyTypeRef::from_raw(type_type_ptr.cast()); + (*type_type_ptr).payload.mro = + PyRwLock::new(vec![type_type.clone(), object_type.clone()]); (type_type, object_type) } @@ -1273,6 +1275,7 @@ pub(crate) fn init_type_hierarchy() -> (PyTypeRef, PyTypeRef, PyTypeRef) { heaptype_ext: None, }; let weakref_type = PyRef::new_ref(weakref_type, type_type.clone(), None); + weakref_type.mro.write().insert(0, weakref_type.clone()); object_type.subclasses.write().push( type_type diff --git a/vm/src/protocol/object.rs b/vm/src/protocol/object.rs index 804918abb3..86d2b33fe7 100644 --- a/vm/src/protocol/object.rs +++ b/vm/src/protocol/object.rs @@ -371,80 +371,112 @@ impl PyObject { }) } - // Equivalent to check_class. Masks Attribute errors (into TypeErrors) and lets everything - // else go through. - fn check_cls(&self, cls: &PyObject, vm: &VirtualMachine, msg: F) -> PyResult + // Equivalent to CPython's check_class. Returns Ok(()) if cls is a valid class, + // Err with TypeError if not. Uses abstract_get_bases internally. + fn check_class(&self, vm: &VirtualMachine, msg: F) -> PyResult<()> where F: Fn() -> String, { - cls.get_attr(identifier!(vm, __bases__), vm).map_err(|e| { - // Only mask AttributeErrors. - if e.class().is(vm.ctx.exceptions.attribute_error) { - vm.new_type_error(msg()) - } else { - e + let cls = self; + match cls.abstract_get_bases(vm)? { + Some(_bases) => Ok(()), // Has __bases__, it's a valid class + None => { + // No __bases__ or __bases__ is not a tuple + Err(vm.new_type_error(msg())) } - }) + } + } + + /// abstract_get_bases() has logically 4 return states: + /// 1. getattr(cls, '__bases__') could raise an AttributeError + /// 2. getattr(cls, '__bases__') could raise some other exception + /// 3. getattr(cls, '__bases__') could return a tuple + /// 4. getattr(cls, '__bases__') could return something other than a tuple + /// + /// Only state #3 returns Some(tuple). AttributeErrors are masked by returning None. + /// If an object other than a tuple comes out of __bases__, then again, None is returned. + /// Other exceptions are propagated. + fn abstract_get_bases(&self, vm: &VirtualMachine) -> PyResult> { + match vm.get_attribute_opt(self.to_owned(), identifier!(vm, __bases__))? { + Some(bases) => { + // Check if it's a tuple + match PyTupleRef::try_from_object(vm, bases) { + Ok(tuple) => Ok(Some(tuple)), + Err(_) => Ok(None), // Not a tuple, return None + } + } + None => Ok(None), // AttributeError was masked + } } fn abstract_issubclass(&self, cls: &PyObject, vm: &VirtualMachine) -> PyResult { let mut derived = self; - let mut first_item: PyObjectRef; - loop { + + // First loop: handle single inheritance without recursion + let bases = loop { if derived.is(cls) { return Ok(true); } - let bases = derived.get_attr(identifier!(vm, __bases__), vm)?; - let tuple = PyTupleRef::try_from_object(vm, bases)?; - - let n = tuple.len(); + let Some(bases) = derived.abstract_get_bases(vm)? else { + return Ok(false); + }; + let n = bases.len(); match n { - 0 => { - return Ok(false); - } + 0 => return Ok(false), 1 => { - first_item = tuple[0].clone(); - derived = &first_item; + // Avoid recursion in the single inheritance case + // # safety + // Intention: bases.as_slice()[0].as_object(); + // Though type-system cannot guarantee, derived does live long enough in the loop. + derived = unsafe { &*(bases.as_slice()[0].as_object() as *const _) }; continue; } _ => { - for i in 0..n { - let check = vm.with_recursion("in abstract_issubclass", || { - tuple[i].abstract_issubclass(cls, vm) - })?; - if check { - return Ok(true); - } - } + // Multiple inheritance - break out to handle recursively + break bases; } } + }; - return Ok(false); + // Second loop: handle multiple inheritance with recursion + // At this point we know n >= 2 + let n = bases.len(); + assert!(n >= 2); + + for i in 0..n { + let result = vm.with_recursion("in __issubclass__", || { + bases.as_slice()[i].abstract_issubclass(cls, vm) + })?; + if result { + return Ok(true); + } } + + Ok(false) } fn recursive_issubclass(&self, cls: &PyObject, vm: &VirtualMachine) -> PyResult { - if let (Ok(obj), Ok(cls)) = (self.try_to_ref::(vm), cls.try_to_ref::(vm)) { - Ok(obj.fast_issubclass(cls)) - } else { - // Check if derived is a class - self.check_cls(self, vm, || { - format!("issubclass() arg 1 must be a class, not {}", self.class()) + // Fast path for both being types (matches CPython's PyType_Check) + if let Some(cls) = PyType::check(cls) + && let Some(derived) = PyType::check(self) + { + // PyType_IsSubtype equivalent + return Ok(derived.is_subtype(cls)); + } + // Check if derived is a class + self.check_class(vm, || { + format!("issubclass() arg 1 must be a class, not {}", self.class()) + })?; + + // Check if cls is a class, tuple, or union (matches CPython's order and message) + if !cls.class().is(vm.ctx.types.union_type) { + cls.check_class(vm, || { + "issubclass() arg 2 must be a class, a tuple of classes, or a union".to_string() })?; - - // Check if cls is a class, tuple, or union - if !cls.class().is(vm.ctx.types.union_type) { - self.check_cls(cls, vm, || { - format!( - "issubclass() arg 2 must be a class, a tuple of classes, or a union, not {}", - cls.class() - ) - })?; - } - - self.abstract_issubclass(cls, vm) } + + self.abstract_issubclass(cls, vm) } /// Real issubclass check without going through __subclasscheck__ @@ -520,7 +552,7 @@ impl PyObject { Ok(retval) } else { // Not a type object, check if it's a valid class - self.check_cls(cls, vm, || { + cls.check_class(vm, || { format!( "isinstance() arg 2 must be a type, a tuple of types, or a union, not {}", cls.class() diff --git a/vm/src/vm/mod.rs b/vm/src/vm/mod.rs index 20c8161004..fd97aaabd8 100644 --- a/vm/src/vm/mod.rs +++ b/vm/src/vm/mod.rs @@ -580,7 +580,7 @@ impl VirtualMachine { #[inline] pub fn import<'a>(&self, module_name: impl AsPyStr<'a>, level: usize) -> PyResult { let module_name = module_name.as_pystr(&self.ctx); - let from_list = PyTupleTyped::empty(self); + let from_list = PyTupleTyped::empty(&self.ctx); self.import_inner(module_name, from_list, level) } pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy