diff --git a/Lib/dataclasses.py b/Lib/dataclasses.py index bdda7cc6c00f5d..9aa84d0fa15718 100644 --- a/Lib/dataclasses.py +++ b/Lib/dataclasses.py @@ -1223,10 +1223,6 @@ def _get_slots(cls): def _update_func_cell_for__class__(f, oldcls, newcls): # Returns True if we update a cell, else False. - if f is None: - # f will be None in the case of a property where not all of - # fget, fset, and fdel are used. Nothing to do in that case. - return False try: idx = f.__code__.co_freevars.index("__class__") except ValueError: @@ -1235,13 +1231,57 @@ def _update_func_cell_for__class__(f, oldcls, newcls): # Fix the cell to point to the new class, if it's already pointing # at the old class. I'm not convinced that the "is oldcls" test # is needed, but other than performance can't hurt. - closure = f.__closure__[idx] - if closure.cell_contents is oldcls: - closure.cell_contents = newcls + cell = f.__closure__[idx] + if cell.cell_contents is oldcls: + cell.cell_contents = newcls return True return False +_object_members_values = { + value for name, value in + ( + *inspect.getmembers_static(object), + *inspect.getmembers_static(object()) + ) +} + + +def _is_not_object_member(v): + try: + return v not in _object_members_values + except TypeError: + return True + + +def _find_inner_functions(obj, seen=None, depth=0): + if seen is None: + seen = set() + if id(obj) in seen: + return None + seen.add(id(obj)) + + depth += 1 + # Normally just an inspection of a descriptor object itself should be enough, + # and we should encounter the function as its attribute, + # but in case function was wrapped (e.g. functools.partial was used), + # we want to dive at least one level deeper. + if depth > 2: + return None + + obj_is_type_instance = type in inspect._static_getmro(type(obj)) + for _, value in inspect.getmembers_static(obj, _is_not_object_member): + value_type = type(value) + if value_type is types.MemberDescriptorType and not obj_is_type_instance: + value = value.__get__(obj) + value_type = type(value) + + if value_type is types.FunctionType: + yield inspect.unwrap(value) + else: + yield from _find_inner_functions(value, seen, depth) + + def _create_slots(defined_fields, inherited_slots, field_names, weakref_slot): # The slots for our class. Remove slots from our base classes. Add # '__weakref__' if weakref_slot was given, unless it is already present. @@ -1317,7 +1357,11 @@ def _add_slots(cls, is_frozen, weakref_slot, defined_fields): # (the newly created one, which we're returning) and not the # original class. We can break out of this loop as soon as we # make an update, since all closures for a class will share a - # given cell. + # given cell. First we try to find a pure function or a property, + # and then fallback to inspecting custom descriptors + # if no pure function or property is found. + + custom_descriptors_to_check = [] for member in newcls.__dict__.values(): # If this is a wrapped function, unwrap it. member = inspect.unwrap(member) @@ -1325,11 +1369,29 @@ def _add_slots(cls, is_frozen, weakref_slot, defined_fields): if isinstance(member, types.FunctionType): if _update_func_cell_for__class__(member, cls, newcls): break - elif isinstance(member, property): - if (_update_func_cell_for__class__(member.fget, cls, newcls) - or _update_func_cell_for__class__(member.fset, cls, newcls) - or _update_func_cell_for__class__(member.fdel, cls, newcls)): - break + elif isinstance(member, property) and ( + any( + # Unwrap once more in case function + # was wrapped before it became property. + _update_func_cell_for__class__(inspect.unwrap(f), cls, newcls) + for f in (member.fget, member.fset, member.fdel) + if f is not None + ) + ): + break + elif hasattr(member, "__get__") and not inspect.ismemberdescriptor( + member + ): + # We don't want to inspect custom descriptors just yet + # there's still a chance we'll encounter a pure function + # or a property and won't have to use slower recursive search. + custom_descriptors_to_check.append(member) + else: + # Now let's ensure custom descriptors won't be left out. + for descriptor in custom_descriptors_to_check: + for f in _find_inner_functions(descriptor): + if _update_func_cell_for__class__(f, cls, newcls): + break return newcls diff --git a/Lib/test/test_dataclasses/__init__.py b/Lib/test/test_dataclasses/__init__.py index 2984f4261bd2c4..3cb2dc8ce5e0f5 100644 --- a/Lib/test/test_dataclasses/__init__.py +++ b/Lib/test/test_dataclasses/__init__.py @@ -13,6 +13,7 @@ import weakref import traceback import unittest +from functools import partial, update_wrapper from unittest.mock import Mock from typing import ClassVar, Any, List, Union, Tuple, Dict, Generic, TypeVar, Optional, Protocol, DefaultDict from typing import get_type_hints @@ -5031,6 +5032,194 @@ def foo(self): A().foo() + def test_wrapped_property(self): + def mydecorator(f): + @wraps(f) + def wrapper(*args, **kwargs): + return f(*args, **kwargs) + return wrapper + + class B: + @property + def foo(self): + return "bar" + + @dataclass(slots=True) + class A(B): + @property + @mydecorator + def foo(self): + return super().foo + + self.assertEqual(A().foo, "bar") + + def test_custom_descriptor(self): + class CustomDescriptor: + def __init__(self, f): + self._f = f + + def __get__(self, instance, owner): + return self._f(instance) + + class B: + def foo(self): + return "bar" + + @dataclass(slots=True) + class A(B): + @CustomDescriptor + def foo(cls): + return super().foo() + + self.assertEqual(A().foo, "bar") + + def test_custom_descriptor_wrapped(self): + class CustomDescriptor: + def __init__(self, f): + self._f = update_wrapper(lambda *args, **kwargs: f(*args, **kwargs), f) + + def __get__(self, instance, owner): + return self._f(instance) + + class B: + def foo(self): + return "bar" + + @dataclass(slots=True) + class A(B): + @CustomDescriptor + def foo(cls): + return super().foo() + + self.assertEqual(A().foo, "bar") + + def test_custom_nested_descriptor(self): + class CustomFunctionWrapper: + def __init__(self, f): + self._f = f + + def __call__(self, *args, **kwargs): + return self._f(*args, **kwargs) + + class CustomDescriptor: + def __init__(self, f): + self._wrapper = CustomFunctionWrapper(f) + + def __get__(self, instance, owner): + return self._wrapper(instance) + + class B: + def foo(self): + return "bar" + + @dataclass(slots=True) + class A(B): + @CustomDescriptor + def foo(cls): + return super().foo() + + self.assertEqual(A().foo, "bar") + + def test_custom_nested_descriptor_with_partial(self): + class CustomDescriptor: + def __init__(self, f): + self._wrapper = partial(f, value="bar") + + def __get__(self, instance, owner): + return self._wrapper(instance) + + class B: + def foo(self, value): + return value + + @dataclass(slots=True) + class A(B): + @CustomDescriptor + def foo(self, value): + return super().foo(value) + + self.assertEqual(A().foo, "bar") + + def test_custom_too_nested_descriptor(self): + class UnnecessaryNestedWrapper: + def __init__(self, wrapper): + self._wrapper = wrapper + + def __call__(self, *args, **kwargs): + return self._wrapper(*args, **kwargs) + + class CustomFunctionWrapper: + def __init__(self, f): + self._f = f + + def __call__(self, *args, **kwargs): + return self._f(*args, **kwargs) + + class CustomDescriptor: + def __init__(self, f): + self._wrapper = UnnecessaryNestedWrapper(CustomFunctionWrapper(f)) + + def __get__(self, instance, owner): + return self._wrapper(instance) + + class B: + def foo(self): + return "bar" + + @dataclass(slots=True) + class A(B): + @CustomDescriptor + def foo(cls): + return super().foo() + + with self.assertRaises(TypeError) as context: + A().foo + + expected_error_message = ( + 'super(type, obj): obj (instance of A) is not ' + 'an instance or subtype of type (A).' + ) + self.assertEqual(context.exception.args, (expected_error_message,)) + + def test_user_defined_code_execution(self): + class CustomDescriptor: + def __init__(self, f): + self._wrapper = partial(f, value="bar") + + def __get__(self, instance, owner): + return object.__getattribute__(self, "_wrapper")(instance) + + def __getattribute__(self, name): + if name in { + # these are the bare minimum for the feature to work + "__class__", # accessed on `isinstance(value, Field)` + "__wrapped__", # accessed by unwrap + "__get__", # is required for the descriptor protocol + "__dict__", # is accessed by dir() to work + }: + return object.__getattribute__(self, name) + raise RuntimeError(f"Never should be accessed: {name}") + + class B: + def foo(self, value): + return value + + @dataclass(slots=True) + class A(B): + @CustomDescriptor + def foo(self, value): + return super().foo(value) + + self.assertEqual(A().foo, "bar") + + @dataclass(slots=True) + class A(B): + @CustomDescriptor + def foo(self, value): + return super().foo(value) + + self.assertEqual(A().foo, "bar") + def test_remembered_class(self): # Apply the dataclass decorator manually (not when the class # is created), so that we can keep a reference to the diff --git a/Misc/NEWS.d/next/Library/2024-09-27-19-50-30.gh-issue-90562.HeL_JA.rst b/Misc/NEWS.d/next/Library/2024-09-27-19-50-30.gh-issue-90562.HeL_JA.rst new file mode 100644 index 00000000000000..b0bbdd406f63c7 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2024-09-27-19-50-30.gh-issue-90562.HeL_JA.rst @@ -0,0 +1,2 @@ +Modify dataclasses to enable zero argument support for ``super()`` when ``slots=True`` is +specified and custom descriptor is used or ``property`` function is wrapped.
Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.
Alternative Proxies: