From dc066d2c229c69ed0c02c59fa2f78c4b0512d303 Mon Sep 17 00:00:00 2001 From: Eric Snow Date: Tue, 15 Apr 2025 16:45:55 -0600 Subject: [PATCH 1/4] Add _PyPickle_GetXIData(). --- Include/internal/pycore_crossinterp.h | 7 + Lib/test/_crossinterp_definitions.py | 56 ++- Lib/test/support/import_helper.py | 108 ++++++ Lib/test/test_crossinterp.py | 493 ++++++++++++++++++++++++-- Modules/_testinternalcapi.c | 5 + Python/crossinterp.c | 451 +++++++++++++++++++++++ 6 files changed, 1090 insertions(+), 30 deletions(-) diff --git a/Include/internal/pycore_crossinterp.h b/Include/internal/pycore_crossinterp.h index 5cf9f8fb5a0388..5b8939a89fdb12 100644 --- a/Include/internal/pycore_crossinterp.h +++ b/Include/internal/pycore_crossinterp.h @@ -171,6 +171,13 @@ PyAPI_FUNC(_PyBytes_data_t *) _PyBytes_GetXIDataWrapped( xid_newobjfunc, _PyXIData_t *); +// _PyObject_GetXIData() for pickle +PyAPI_DATA(PyObject *) _PyPickle_LoadFromXIData(_PyXIData_t *); +PyAPI_FUNC(int) _PyPickle_GetXIData( + PyThreadState *, + PyObject *, + _PyXIData_t *); + /* using cross-interpreter data */ diff --git a/Lib/test/_crossinterp_definitions.py b/Lib/test/_crossinterp_definitions.py index 9b52aea39522f5..455b206a326047 100644 --- a/Lib/test/_crossinterp_definitions.py +++ b/Lib/test/_crossinterp_definitions.py @@ -100,7 +100,7 @@ def ham_C_closure(z): ham_C_closure, *_ = eggs_closure_C(2) -FUNCTIONS = [ +TOP_FUNCTIONS = [ # shallow spam_minimal, spam_full, @@ -112,6 +112,8 @@ def ham_C_closure(z): spam_NC, spam_CN, spam_CC, +] +NESTED_FUNCTIONS = [ # inner func eggs_nested, eggs_closure, @@ -125,6 +127,13 @@ def ham_C_closure(z): ham_C_nested, ham_C_closure, ] +FUNCTIONS = [ + *TOP_FUNCTIONS, + *NESTED_FUNCTIONS, +] + + +# XXX set dishonest __file__, __module__, __name__ ####################################### @@ -202,6 +211,13 @@ def __init__(self, a, b, c): # __str__ # ... 
+ def __eq__(self, other): + if not isinstance(other, SpamFull): + return NotImplemented + return (self.a == other.a and + self.b == other.b and + self.c == other.c) + @property def prop(self): return True @@ -222,9 +238,47 @@ class EggsNested: EggsNested = class_eggs_inner() +TOP_CLASSES = { + Spam: (), + SpamOkay: (), + SpamFull: (1, 2, 3), + SubSpamFull: (1, 2, 3), + SubTuple: ([1, 2, 3],), +} +CLASSES_WITHOUT_EQUALITY = [ + Spam, + SpamOkay, +] +BUILTIN_SUBCLASSES = [ + SubTuple, +] +NESTED_CLASSES = { + EggsNested: (), +} +CLASSES = { + **TOP_CLASSES, + **NESTED_CLASSES, +} + ####################################### # exceptions class MimimalError(Exception): pass + + +class RichError(Exception): + def __init__(self, msg, value=None): + super().__init__(msg, value) + self.msg = msg + self.value = value + + def __eq__(self, other): + if not isinstance(other, RichError): + return NotImplemented + if self.msg != other.msg: + return False + if self.value != other.value: + return False + return True diff --git a/Lib/test/support/import_helper.py b/Lib/test/support/import_helper.py index 42cfe9cfa8cb72..1c827c5ded27b0 100644 --- a/Lib/test/support/import_helper.py +++ b/Lib/test/support/import_helper.py @@ -1,6 +1,7 @@ import contextlib import _imp import importlib +import importlib.machinery import importlib.util import os import shutil @@ -332,3 +333,110 @@ def ensure_lazy_imports(imported_module, modules_to_block): ) from .script_helper import assert_python_ok assert_python_ok("-S", "-c", script) + + +@contextlib.contextmanager +def module_restored(name): + """A context manager that restores a module to the original state.""" + missing = name in sys.modules + orig = sys.modules.get(name) + if orig is None: + mod = importlib.import_module(name) + else: + mod = type(sys)(name) + mod.__dict__.update(orig.__dict__) + sys.modules[name] = mod + try: + yield mod + finally: + if missing: + sys.modules.pop(name, None) + else: + sys.modules[name] = orig + + +def 
create_module(name, loader=None, *, ispkg=False): + """Return a new, empty module.""" + spec = importlib.machinery.ModuleSpec( + name, + loader, + origin='', + is_package=ispkg, + ) + return importlib.util.module_from_spec(spec) + + +def _ensure_module(name, ispkg, addparent, clearnone): + try: + mod = orig = sys.modules[name] + except KeyError: + mod = orig = None + missing = True + else: + missing = False + if mod is not None: + # It was already imported. + return mod, orig, missing + # Otherwise, None means it was explicitly disabled. + + assert name != '__main__' + if not missing: + assert orig is None, (name, sys.modules[name]) + if not clearnone: + raise ModuleNotFoundError(name) + del sys.modules[name] + # Try normal import, then fall back to adding the module. + try: + mod = importlib.import_module(name) + except ModuleNotFoundError: + if addparent and not clearnone: + addparent = None + mod = _add_module(name, ispkg, addparent) + return mod, orig, missing + + +def _add_module(spec, ispkg, addparent): + if isinstance(spec, str): + name = spec + mod = create_module(name, ispkg=ispkg) + spec = mod.__spec__ + else: + name = spec.name + mod = importlib.util.module_from_spec(spec) + sys.modules[name] = mod + if addparent is not False and spec.parent: + _ensure_module(spec.parent, True, addparent, bool(addparent)) + return mod + + +def add_module(spec, *, parents=True): + """Return the module after creating it and adding it to sys.modules. + + If parents is True then also create any missing parents. + """ + return _add_module(spec, False, parents) + + +def add_package(spec, *, parents=True): + """Return the module after creating it and adding it to sys.modules. + + If parents is True then also create any missing parents. + """ + return _add_module(spec, True, parents) + + +def ensure_module_imported(name, *, clearnone=True): + """Return the corresponding module. + + If it was already imported then return that. 
Otherwise, try + importing it (optionally clear it first if None). If that fails + then create a new empty module. + + It can be helpful to combine this with ready_to_import() and/or + isolated_modules(). + """ + if sys.modules.get(name) is not None: + mod = sys.modules[name] + else: + mod, _, _ = _ensure_module(name, False, True, clearnone) + return mod diff --git a/Lib/test/test_crossinterp.py b/Lib/test/test_crossinterp.py index e1d1998fefc7fb..038835ecffac71 100644 --- a/Lib/test/test_crossinterp.py +++ b/Lib/test/test_crossinterp.py @@ -1,3 +1,6 @@ +import contextlib +import importlib +import importlib.util import itertools import sys import types @@ -9,7 +12,6 @@ _interpreters = import_helper.import_module('_interpreters') from _interpreters import NotShareableError - from test import _crossinterp_definitions as defs @@ -18,6 +20,74 @@ EXCEPTION_TYPES = [cls for cls in BUILTIN_TYPES if issubclass(cls, BaseException)] +DEFS = defs +with open(DEFS.__file__) as infile: + DEFS_TEXT = infile.read() +del infile + + +def load_defs(module=None): + """Return a new copy of the test._crossinterp_definitions module. + + The module's __name__ matches the "module" arg, which is either + a str or a module. + + If the "module" arg is a module then the just-loaded defs are also + copied into that module. + + Note that the new module is not added to sys.modules. + """ + if module is None: + modname = DEFS.__name__ + elif isinstance(module, str): + modname = module + module = None + else: + modname = module.__name__ + # Create the new module and populate it. + defs = import_helper.create_module(modname) + defs.__file__ = DEFS.__file__ + exec(DEFS_TEXT, defs.__dict__) + # Copy the defs into the module arg, if any. 
+ if module is not None: + for name, value in defs.__dict__.items(): + if name.startswith('_'): + continue + assert not hasattr(module, name), (name, getattr(module, name)) + setattr(module, name, value) + return defs + + +@contextlib.contextmanager +def using___main__(): + """Make sure __main__ module exists (and clean up after).""" + modname = '__main__' + if modname not in sys.modules: + with import_helper.isolated_modules(): + yield import_helper.add_module(modname) + else: + with import_helper.module_restored(modname) as mod: + yield mod + + +@contextlib.contextmanager +def temp_module(modname): + """Create the module and add to sys.modules, then remove it after.""" + assert modname not in sys.modules, (modname,) + with import_helper.isolated_modules(): + yield import_helper.add_module(modname) + + +@contextlib.contextmanager +def missing_defs_module(modname, *, prep=False): + assert modname not in sys.modules, (modname,) + if prep: + with import_helper.ready_to_import(modname, DEFS_TEXT): + yield modname + else: + with import_helper.isolated_modules(): + yield modname + class _GetXIDataTests(unittest.TestCase): @@ -29,26 +99,44 @@ def get_xidata(self, obj, *, mode=None): def get_roundtrip(self, obj, *, mode=None): mode = self._resolve_mode(mode) - xid =_testinternalcapi.get_crossinterp_data(obj, mode) + return self._get_roundtrip(obj, mode) + + def _get_roundtrip(self, obj, mode): + xid = _testinternalcapi.get_crossinterp_data(obj, mode) return _testinternalcapi.restore_crossinterp_data(xid) - def iter_roundtrip_values(self, values, *, mode=None): + def assert_roundtrip_identical(self, values, *, mode=None): mode = self._resolve_mode(mode) for obj in values: with self.subTest(obj): - xid = _testinternalcapi.get_crossinterp_data(obj, mode) - got = _testinternalcapi.restore_crossinterp_data(xid) - yield obj, got + got = self._get_roundtrip(obj, mode) + self.assertIs(got, obj) def assert_roundtrip_equal(self, values, *, mode=None): - for obj, got in 
self.iter_roundtrip_values(values, mode=mode): - self.assertEqual(got, obj) - self.assertIs(type(got), type(obj)) + mode = self._resolve_mode(mode) + for obj in values: + with self.subTest(obj): + got = self._get_roundtrip(obj, mode) + self.assertEqual(got, obj) + self.assertIs(type(got), type(obj)) - def assert_roundtrip_identical(self, values, *, mode=None): - for obj, got in self.iter_roundtrip_values(values, mode=mode): - # XXX What about between interpreters? - self.assertIs(got, obj) + def assert_roundtrip_equal_not_identical(self, values, *, mode=None): + mode = self._resolve_mode(mode) + for obj in values: + with self.subTest(obj): + got = self._get_roundtrip(obj, mode) + self.assertIsNot(got, obj) + self.assertIs(type(got), type(obj)) + self.assertEqual(got, obj) + + def assert_roundtrip_not_equal(self, values, *, mode=None): + mode = self._resolve_mode(mode) + for obj in values: + with self.subTest(obj): + got = self._get_roundtrip(obj, mode) + self.assertIsNot(got, obj) + self.assertIs(type(got), type(obj)) + self.assertNotEqual(got, obj) def assert_not_shareable(self, values, exctype=None, *, mode=None): mode = self._resolve_mode(mode) @@ -66,6 +154,363 @@ def _resolve_mode(self, mode): return mode +class PickleTests(_GetXIDataTests): + + MODE = 'pickle' + + def test_shareable(self): + self.assert_roundtrip_equal([ + # singletons + None, + True, + False, + # bytes + *(i.to_bytes(2, 'little', signed=True) + for i in range(-1, 258)), + # str + 'hello world', + '你好世界', + '', + # int + sys.maxsize, + -sys.maxsize - 1, + *range(-1, 258), + # float + 0.0, + 1.1, + -1.0, + 0.12345678, + -0.12345678, + # tuple + (), + (1,), + ("hello", "world", ), + (1, True, "hello"), + ((1,),), + ((1, 2), (3, 4)), + ((1, 2), (3, 4), (5, 6)), + ]) + # not shareable using xidata + self.assert_roundtrip_equal([ + # int + sys.maxsize + 1, + -sys.maxsize - 2, + 2**1000, + # tuple + (0, 1.0, []), + (0, 1.0, {}), + (0, 1.0, ([],)), + (0, 1.0, ({},)), + ]) + + def test_list(self): + 
self.assert_roundtrip_equal_not_identical([ + [], + [1, 2, 3], + [[1], (2,), {3: 4}], + ]) + + def test_dict(self): + self.assert_roundtrip_equal_not_identical([ + {}, + {1: 7, 2: 8, 3: 9}, + {1: [1], 2: (2,), 3: {3: 4}}, + ]) + + def test_set(self): + self.assert_roundtrip_equal_not_identical([ + set(), + {1, 2, 3}, + {frozenset({1}), (2,)}, + ]) + + # classes + + def assert_class_defs_same(self, defs): + # Unpickle relative to the unchanged original module. + self.assert_roundtrip_identical(defs.TOP_CLASSES) + + instances = [] + for cls, args in defs.TOP_CLASSES.items(): + if cls in defs.CLASSES_WITHOUT_EQUALITY: + continue + instances.append(cls(*args)) + self.assert_roundtrip_equal_not_identical(instances) + + # these don't compare equal + instances = [] + for cls, args in defs.TOP_CLASSES.items(): + if cls not in defs.CLASSES_WITHOUT_EQUALITY: + continue + instances.append(cls(*args)) + self.assert_roundtrip_not_equal(instances) + + def assert_class_defs_other_pickle(self, defs, mod): + # Pickle relative to a different module than the original. + for cls in defs.TOP_CLASSES: + assert not hasattr(mod, cls.__name__), (cls, getattr(mod, cls.__name__)) + self.assert_not_shareable(defs.TOP_CLASSES) + + instances = [] + for cls, args in defs.TOP_CLASSES.items(): + instances.append(cls(*args)) + self.assert_not_shareable(instances) + + def assert_class_defs_other_unpickle(self, defs, mod, *, fail=False): + # Unpickle relative to a different module than the original. 
+ for cls in defs.TOP_CLASSES: + assert not hasattr(mod, cls.__name__), (cls, getattr(mod, cls.__name__)) + + instances = [] + for cls, args in defs.TOP_CLASSES.items(): + with self.subTest(cls): + setattr(mod, cls.__name__, cls) + xid = self.get_xidata(cls) + inst = cls(*args) + instxid = self.get_xidata(inst) + instances.append( + (cls, xid, inst, instxid)) + + for cls, xid, inst, instxid in instances: + with self.subTest(cls): + delattr(mod, cls.__name__) + if fail: + with self.assertRaises(NotShareableError): + _testinternalcapi.restore_crossinterp_data(xid) + continue + got = _testinternalcapi.restore_crossinterp_data(xid) + self.assertIsNot(got, cls) + self.assertNotEqual(got, cls) + + gotcls = got + got = _testinternalcapi.restore_crossinterp_data(instxid) + self.assertIsNot(got, inst) + self.assertIs(type(got), gotcls) + if cls in defs.CLASSES_WITHOUT_EQUALITY: + self.assertNotEqual(got, inst) + elif cls in defs.BUILTIN_SUBCLASSES: + self.assertEqual(got, inst) + else: + self.assertNotEqual(got, inst) + + def assert_class_defs_not_shareable(self, defs): + self.assert_not_shareable(defs.TOP_CLASSES) + + instances = [] + for cls, args in defs.TOP_CLASSES.items(): + instances.append(cls(*args)) + self.assert_not_shareable(instances) + + def test_user_class_normal(self): + self.assert_class_defs_same(defs) + + def test_user_class_in___main__(self): + with using___main__() as mod: + defs = load_defs(mod) + self.assert_class_defs_same(defs) + + def test_user_class_not_in___main___with_filename(self): + with using___main__() as mod: + defs = load_defs('__main__') + assert defs.__file__ + mod.__file__ = defs.__file__ + self.assert_class_defs_not_shareable(defs) + + def test_user_class_not_in___main___without_filename(self): + with using___main__() as mod: + defs = load_defs('__main__') + defs.__file__ = None + mod.__file__ = None + self.assert_class_defs_not_shareable(defs) + + def test_user_class_not_in___main___unpickle_with_filename(self): + with 
using___main__() as mod: + defs = load_defs('__main__') + assert defs.__file__ + mod.__file__ = defs.__file__ + self.assert_class_defs_other_unpickle(defs, mod) + + def test_user_class_not_in___main___unpickle_without_filename(self): + with using___main__() as mod: + defs = load_defs('__main__') + defs.__file__ = None + mod.__file__ = None + self.assert_class_defs_other_unpickle(defs, mod, fail=True) + + def test_user_class_in_module(self): + with temp_module('__spam__') as mod: + defs = load_defs(mod) + self.assert_class_defs_same(defs) + + def test_user_class_not_in_module_with_filename(self): + with temp_module('__spam__') as mod: + defs = load_defs(mod.__name__) + assert defs.__file__ + # For now, we only address this case for __main__. + self.assert_class_defs_not_shareable(defs) + + def test_user_class_not_in_module_without_filename(self): + with temp_module('__spam__') as mod: + defs = load_defs(mod.__name__) + defs.__file__ = None + self.assert_class_defs_not_shareable(defs) + + def test_user_class_module_missing_then_imported(self): + with missing_defs_module('__spam__', prep=True) as modname: + defs = load_defs(modname) + # For now, we only address this case for __main__. + self.assert_class_defs_not_shareable(defs) + + def test_user_class_module_missing_not_available(self): + with missing_defs_module('__spam__') as modname: + defs = load_defs(modname) + self.assert_class_defs_not_shareable(defs) + + def test_nested_class(self): + eggs = defs.EggsNested() + with self.assertRaises(NotShareableError): + self.get_roundtrip(eggs) + + # functions + + def assert_func_defs_same(self, defs): + # Unpickle relative to the unchanged original module. + self.assert_roundtrip_identical(defs.TOP_FUNCTIONS) + + def assert_func_defs_other_pickle(self, defs, mod): + # Pickle relative to a different module than the original. 
+ for func in defs.TOP_FUNCTIONS: + assert not hasattr(mod, func.__name__), (func, getattr(mod, func.__name__)) + self.assert_not_shareable(defs.TOP_FUNCTIONS) + + def assert_func_defs_other_unpickle(self, defs, mod, *, fail=False): + # Unpickle relative to a different module than the original. + for func in defs.TOP_FUNCTIONS: + assert not hasattr(mod, func.__name__), (func, getattr(mod, func.__name__)) + + captured = [] + for func in defs.TOP_FUNCTIONS: + with self.subTest(func): + setattr(mod, func.__name__, func) + xid = self.get_xidata(func) + captured.append( + (func, xid)) + + for func, xid in captured: + with self.subTest(func): + delattr(mod, func.__name__) + if fail: + with self.assertRaises(NotShareableError): + _testinternalcapi.restore_crossinterp_data(xid) + continue + got = _testinternalcapi.restore_crossinterp_data(xid) + self.assertIsNot(got, func) + self.assertNotEqual(got, func) + + def assert_func_defs_not_shareable(self, defs): + self.assert_not_shareable(defs.TOP_FUNCTIONS) + + def test_user_function_normal(self): +# self.assert_roundtrip_equal(defs.TOP_FUNCTIONS) + self.assert_func_defs_same(defs) + + def test_user_func_in___main__(self): + with using___main__() as mod: + defs = load_defs(mod) + self.assert_func_defs_same(defs) + + def test_user_func_not_in___main___with_filename(self): + with using___main__() as mod: + defs = load_defs('__main__') + assert defs.__file__ + mod.__file__ = defs.__file__ + self.assert_func_defs_not_shareable(defs) + + def test_user_func_not_in___main___without_filename(self): + with using___main__() as mod: + defs = load_defs('__main__') + defs.__file__ = None + mod.__file__ = None + self.assert_func_defs_not_shareable(defs) + + def test_user_func_not_in___main___unpickle_with_filename(self): + with using___main__() as mod: + defs = load_defs('__main__') + assert defs.__file__ + mod.__file__ = defs.__file__ + self.assert_func_defs_other_unpickle(defs, mod) + + def 
test_user_func_not_in___main___unpickle_without_filename(self): + with using___main__() as mod: + defs = load_defs('__main__') + defs.__file__ = None + mod.__file__ = None + self.assert_func_defs_other_unpickle(defs, mod, fail=True) + + def test_user_func_in_module(self): + with temp_module('__spam__') as mod: + defs = load_defs(mod) + self.assert_func_defs_same(defs) + + def test_user_func_not_in_module_with_filename(self): + with temp_module('__spam__') as mod: + defs = load_defs(mod.__name__) + assert defs.__file__ + # For now, we only address this case for __main__. + self.assert_func_defs_not_shareable(defs) + + def test_user_func_not_in_module_without_filename(self): + with temp_module('__spam__') as mod: + defs = load_defs(mod.__name__) + defs.__file__ = None + self.assert_func_defs_not_shareable(defs) + + def test_user_func_module_missing_then_imported(self): + with missing_defs_module('__spam__', prep=True) as modname: + defs = load_defs(modname) + # For now, we only address this case for __main__. + self.assert_func_defs_not_shareable(defs) + + def test_user_func_module_missing_not_available(self): + with missing_defs_module('__spam__') as modname: + defs = load_defs(modname) + self.assert_func_defs_not_shareable(defs) + + def test_nested_function(self): + self.assert_not_shareable(defs.NESTED_FUNCTIONS) + + # exceptions + + def test_user_exception_normal(self): + self.assert_roundtrip_not_equal([ + defs.MimimalError('error!'), + ]) + self.assert_roundtrip_equal_not_identical([ + defs.RichError('error!', 42), + ]) + + def test_builtin_exception(self): + msg = 'error!' 
+ try: + raise Exception + except Exception as exc: + caught = exc + special = { + BaseExceptionGroup: (msg, [caught]), + ExceptionGroup: (msg, [caught]), +# UnicodeError: (None, msg, None, None, None), + UnicodeEncodeError: ('utf-8', '', 1, 3, msg), + UnicodeDecodeError: ('utf-8', b'', 1, 3, msg), + UnicodeTranslateError: ('', 1, 3, msg), + } + exceptions = [] + for cls in EXCEPTION_TYPES: + args = special.get(cls) or (msg,) + exceptions.append(cls(*args)) + + self.assert_roundtrip_not_equal(exceptions) + + class ShareableTypeTests(_GetXIDataTests): MODE = 'xidata' @@ -223,22 +668,12 @@ def test_module(self): ]) def test_class(self): - self.assert_not_shareable([ - defs.Spam, - defs.SpamOkay, - defs.SpamFull, - defs.SubSpamFull, - defs.SubTuple, - defs.EggsNested, - ]) - self.assert_not_shareable([ - defs.Spam(), - defs.SpamOkay(), - defs.SpamFull(1, 2, 3), - defs.SubSpamFull(1, 2, 3), - defs.SubTuple([1, 2, 3]), - defs.EggsNested(), - ]) + self.assert_not_shareable(defs.CLASSES) + + instances = [] + for cls, args in defs.CLASSES.items(): + instances.append(cls(*args)) + self.assert_not_shareable(instances) def test_builtin_type(self): self.assert_not_shareable([ diff --git a/Modules/_testinternalcapi.c b/Modules/_testinternalcapi.c index 99dca9f77df6be..fd6d496a573213 100644 --- a/Modules/_testinternalcapi.c +++ b/Modules/_testinternalcapi.c @@ -1730,6 +1730,11 @@ get_crossinterp_data(PyObject *self, PyObject *args, PyObject *kwargs) goto error; } } + else if (strcmp(mode, "pickle") == 0) { + if (_PyPickle_GetXIData(tstate, obj, xidata) != 0) { + goto error; + } + } else { PyErr_Format(PyExc_ValueError, "unsupported mode %R", modeobj); goto error; diff --git a/Python/crossinterp.c b/Python/crossinterp.c index 662c9c72b15eb7..ddf35010c505a9 100644 --- a/Python/crossinterp.c +++ b/Python/crossinterp.c @@ -2,6 +2,7 @@ /* API for managing interactions between isolated interpreters */ #include "Python.h" +#include "osdefs.h" // MAXPATHLEN #include "pycore_ceval.h" // 
_Py_simple_func #include "pycore_crossinterp.h" // _PyXIData_t #include "pycore_initconfig.h" // _PyStatus_OK() @@ -9,6 +10,155 @@ #include "pycore_typeobject.h" // _PyStaticType_InitBuiltin() +static Py_ssize_t +_Py_GetMainfile(char *buffer, size_t maxlen) +{ + // We don't expect subinterpreters to have the __main__ module's + // __name__ set, but proceed just in case. + PyThreadState *tstate = _PyThreadState_GET(); + PyObject *module = _Py_GetMainModule(tstate); + if (_Py_CheckMainModule(module) < 0) { + return -1; + } + Py_ssize_t size = _PyModule_GetFilenameUTF8(module, buffer, maxlen); + Py_DECREF(module); + return size; +} + + +static PyObject * +import_get_module(PyThreadState *tstate, const char *modname) +{ + PyObject *module = NULL; + if (strcmp(modname, "__main__") == 0) { + module = _Py_GetMainModule(tstate); + if (_Py_CheckMainModule(module) < 0) { + assert(_PyErr_Occurred(tstate)); + return NULL; + } + } + else { + module = PyImport_ImportModule(modname); + if (module == NULL) { + return NULL; + } + } + return module; +} + + +static PyObject * +runpy_run_path(const char *filename, const char *modname) +{ + PyObject *run_path = PyImport_ImportModuleAttrString("runpy", "run_path"); + if (run_path == NULL) { + return NULL; + } + PyObject *args = Py_BuildValue("(sOs)", filename, Py_None, modname); + if (args == NULL) { + Py_DECREF(run_path); + return NULL; + } + PyObject *ns = PyObject_Call(run_path, args, NULL); + Py_DECREF(run_path); + Py_DECREF(args); + return ns; +} + + +static PyObject * +pyerr_get_message(PyObject *exc) +{ + assert(!PyErr_Occurred()); + PyObject *args = PyException_GetArgs(exc); + if (args == NULL || args == Py_None || PyObject_Size(args) < 1) { + return NULL; + } + if (PyUnicode_Check(args)) { + return args; + } + PyObject *msg = PySequence_GetItem(args, 0); + Py_DECREF(args); + if (msg == NULL) { + PyErr_Clear(); + return NULL; + } + if (!PyUnicode_Check(msg)) { + Py_DECREF(msg); + return NULL; + } + return msg; +} + +#define 
MAX_MODNAME (255) +#define MAX_ATTRNAME (255) + +struct attributeerror_info { + char modname[MAX_MODNAME+1]; + char attrname[MAX_ATTRNAME+1]; +}; + +static int +_parse_attributeerror(PyObject *exc, struct attributeerror_info *info) +{ + assert(exc != NULL); + assert(PyErr_GivenExceptionMatches(exc, PyExc_AttributeError)); + int res = -1; + + PyObject *msgobj = pyerr_get_message(exc); + if (msgobj == NULL) { + return -1; + } + const char *err = PyUnicode_AsUTF8(msgobj); + + if (strncmp(err, "module '", 8) != 0) { + goto finally; + } + err += 8; + + const char *matched = strchr(err, '\''); + if (matched == NULL) { + goto finally; + } + Py_ssize_t len = matched - err; + if (len > MAX_MODNAME) { + goto finally; + } + (void)strncpy(info->modname, err, len); + info->modname[len] = '\0'; + err = matched; + + if (strncmp(err, "' has no attribute '", 20) != 0) { + goto finally; + } + err += 20; + + matched = strchr(err, '\''); + if (matched == NULL) { + goto finally; + } + len = matched - err; + if (len > MAX_ATTRNAME) { + goto finally; + } + (void)strncpy(info->attrname, err, len); + info->attrname[len] = '\0'; + err = matched + 1; + + if (strlen(err) > 0) { + goto finally; + } + res = 0; + +finally: + Py_DECREF(msgobj); + return res; +} + +#undef MAX_MODNAME +#undef MAX_ATTRNAME + + /**************/ /* exceptions */ /**************/ @@ -286,6 +436,307 @@ _PyObject_GetXIData(PyThreadState *tstate, } +/* pickle C-API */ + +struct _pickle_context { + PyThreadState *tstate; +}; + +static PyObject * +_PyPickle_Dumps(struct _pickle_context *ctx, PyObject *obj) +{ + PyObject *dumps = PyImport_ImportModuleAttrString("pickle", "dumps"); + if (dumps == NULL) { + return NULL; + } + PyObject *bytes = PyObject_CallOneArg(dumps, obj); + Py_DECREF(dumps); + return bytes; +} + + +struct sync_module_result { + PyObject *module; + PyObject *loaded; + PyObject *failed; +}; + +struct sync_module { + const char *filename; + char _filename[MAXPATHLEN+1]; + struct sync_module_result cached; +}; 
+ +static void +sync_module_clear(struct sync_module *data) +{ + data->filename = NULL; + Py_CLEAR(data->cached.module); + Py_CLEAR(data->cached.loaded); + Py_CLEAR(data->cached.failed); +} + + +struct _unpickle_context { + PyThreadState *tstate; + // We only special-case the __main__ module, + // since other modules behave consistently. + struct sync_module main; +}; + +static void +_unpickle_context_clear(struct _unpickle_context *ctx) +{ + sync_module_clear(&ctx->main); +} + +static struct sync_module_result +_unpickle_context_get_module(struct _unpickle_context *ctx, + const char *modname) +{ + if (strcmp(modname, "__main__") == 0) { + return ctx->main.cached; + } + else { + return (struct sync_module_result){ + .failed = PyExc_NotImplementedError, + }; + } +} + +static struct sync_module_result +_unpickle_context_set_module(struct _unpickle_context *ctx, + const char *modname) +{ + struct sync_module_result res = {0}; + struct sync_module_result *cached = NULL; + const char *filename = NULL; + if (strcmp(modname, "__main__") == 0) { + cached = &ctx->main.cached; + filename = ctx->main.filename; + } + else { + res.failed = PyExc_NotImplementedError; + goto finally; + } + + res.module = import_get_module(ctx->tstate, modname); + if (res.module == NULL) { + res.failed = _PyErr_GetRaisedException(ctx->tstate); + assert(res.failed != NULL); + goto finally; + } + + if (filename == NULL) { + Py_CLEAR(res.module); + res.failed = PyExc_NotImplementedError; + goto finally; + } + res.loaded = runpy_run_path(filename, modname); + if (res.loaded == NULL) { + Py_CLEAR(res.module); + res.failed = _PyErr_GetRaisedException(ctx->tstate); + assert(res.failed != NULL); + goto finally; + } + +finally: + if (cached != NULL) { + assert(cached->module == NULL); + assert(cached->loaded == NULL); + assert(cached->failed == NULL); + *cached = res; + } + return res; +} + + +static int +_handle_unpickle_missing_attr(struct _unpickle_context *ctx, PyObject *exc) +{ + // The caller must 
check if an exception is set or not when -1 is returned. + assert(!_PyErr_Occurred(ctx->tstate)); + assert(PyErr_GivenExceptionMatches(exc, PyExc_AttributeError)); + struct attributeerror_info info; + if (_parse_attributeerror(exc, &info) < 0) { + return -1; + } + + // Get the module. + struct sync_module_result mod = _unpickle_context_get_module(ctx, info.modname); + if (mod.failed != NULL) { + // It must have failed previously. + return -1; + } + if (mod.module == NULL) { + mod = _unpickle_context_set_module(ctx, info.modname); + if (mod.failed != NULL) { + return -1; + } + assert(mod.module != NULL); + } + + // Bail out if it is unexpectedly set already. + if (PyObject_HasAttrString(mod.module, info.attrname)) { + return -1; + } + + // Try setting the attribute. + PyObject *value = NULL; + if (PyDict_GetItemStringRef(mod.loaded, info.attrname, &value) <= 0) { + return -1; + } + assert(value != NULL); + int res = PyObject_SetAttrString(mod.module, info.attrname, value); + Py_DECREF(value); + if (res < 0) { + return -1; + } + + return 0; +} + +static PyObject * +_PyPickle_Loads(struct _unpickle_context *ctx, PyObject *pickled) +{ + PyObject *loads = PyImport_ImportModuleAttrString("pickle", "loads"); + if (loads == NULL) { + return NULL; + } + PyObject *obj = PyObject_CallOneArg(loads, pickled); + if (ctx != NULL) { + while (obj == NULL) { + assert(_PyErr_Occurred(ctx->tstate)); + if (!PyErr_ExceptionMatches(PyExc_AttributeError)) { + // We leave other failures unhandled. + break; + } + // Try setting the attr if not set. + PyObject *exc = _PyErr_GetRaisedException(ctx->tstate); + if (_handle_unpickle_missing_attr(ctx, exc) < 0) { + // Any resulting exceptions are ignored + // in favor of the original. + _PyErr_SetRaisedException(ctx->tstate, exc); + break; + } + Py_CLEAR(exc); + // Retry with the attribute set. 
+ obj = PyObject_CallOneArg(loads, pickled); + } + } + Py_DECREF(loads); + return obj; +} + + +/* pickle wrapper */ + +struct _pickle_xid_context { + // __main__.__file__ + struct { + const char *utf8; + size_t len; + char _utf8[MAXPATHLEN+1]; + } mainfile; +}; + +static int +_set_pickle_xid_context(PyThreadState *tstate, struct _pickle_xid_context *ctx) +{ + // Set mainfile if possible. + Py_ssize_t len = _Py_GetMainfile(ctx->mainfile._utf8, MAXPATHLEN); + if (len < 0) { + // For now we ignore any exceptions. + PyErr_Clear(); + } + else if (len > 0) { + ctx->mainfile.utf8 = ctx->mainfile._utf8; + ctx->mainfile.len = (size_t)len; + } + + return 0; +} + + +struct _shared_pickle_data { + _PyBytes_data_t pickled; // Must be first if we use _PyBytes_FromXIData(). + struct _pickle_xid_context ctx; +}; + +PyObject * +_PyPickle_LoadFromXIData(_PyXIData_t *xidata) +{ + PyThreadState *tstate = _PyThreadState_GET(); + struct _shared_pickle_data *shared = + (struct _shared_pickle_data *)xidata->data; + // We avoid copying the pickled data by wrapping it in a memoryview. + // The alternative is to get a bytes object using _PyBytes_FromXIData(). + PyObject *pickled = PyMemoryView_FromMemory( + (char *)shared->pickled.bytes, shared->pickled.len, PyBUF_READ); + if (pickled == NULL) { + return NULL; + } + + // Unpickle the object. + struct _unpickle_context ctx = { + .tstate = tstate, + .main = { + .filename = shared->ctx.mainfile.utf8, + }, + }; + PyObject *obj = _PyPickle_Loads(&ctx, pickled); + Py_DECREF(pickled); + _unpickle_context_clear(&ctx); + if (obj == NULL) { + PyObject *cause = _PyErr_GetRaisedException(tstate); + assert(cause != NULL); + _set_xid_lookup_failure( + tstate, NULL, "object could not be unpickled", cause); + Py_DECREF(cause); + } + return obj; +} + +int +_PyPickle_GetXIData(PyThreadState *tstate, PyObject *obj, _PyXIData_t *xidata) +{ + // Pickle the object. 
+ struct _pickle_context ctx = { + .tstate = tstate, + }; + PyObject *bytes = _PyPickle_Dumps(&ctx, obj); + if (bytes == NULL) { + PyObject *cause = _PyErr_GetRaisedException(tstate); + assert(cause != NULL); + _set_xid_lookup_failure( + tstate, NULL, "object could not be pickled", cause); + Py_DECREF(cause); + return -1; + } + + // If we had an "unwrapper" mechanism, we could call + // _PyObject_GetXIData() on the bytes object directly and add + // a simple unwrapper to call pickle.loads() on the bytes. + size_t size = sizeof(struct _shared_pickle_data); + struct _shared_pickle_data *shared = + (struct _shared_pickle_data *)_PyBytes_GetXIDataWrapped( + tstate, bytes, size, _PyPickle_LoadFromXIData, xidata); + Py_DECREF(bytes); + if (shared == NULL) { + return -1; + } + + // If it mattered, we could skip getting __main__.__file__ + // when "__main__" doesn't show up in the pickle bytes. + if (_set_pickle_xid_context(tstate, &shared->ctx) < 0) { + _xidata_clear(xidata); + return -1; + } + + return 0; +} + + /* using cross-interpreter data */ PyObject * From 647c5384c0361257ed82abd0fc7f4af55a267ea5 Mon Sep 17 00:00:00 2001 From: Eric Snow Date: Mon, 28 Apr 2025 13:24:49 -0600 Subject: [PATCH 2/4] Drop a TODO comment. --- Lib/test/_crossinterp_definitions.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/Lib/test/_crossinterp_definitions.py b/Lib/test/_crossinterp_definitions.py index 455b206a326047..5a37c404b4c0dd 100644 --- a/Lib/test/_crossinterp_definitions.py +++ b/Lib/test/_crossinterp_definitions.py @@ -133,9 +133,6 @@ def ham_C_closure(z): ] -# XXX set dishonest __file__, __module__, __name__ - - ####################################### # function-like From bf0013c46e99e82f1252ab0706f364444d059ad9 Mon Sep 17 00:00:00 2001 From: Eric Snow Date: Wed, 30 Apr 2025 13:19:06 -0600 Subject: [PATCH 3/4] Fix the tests. 
--- Lib/test/test_crossinterp.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/Lib/test/test_crossinterp.py b/Lib/test/test_crossinterp.py index f34eb729f3a048..32d6fd4e94bf1b 100644 --- a/Lib/test/test_crossinterp.py +++ b/Lib/test/test_crossinterp.py @@ -12,6 +12,7 @@ _interpreters = import_helper.import_module('_interpreters') from _interpreters import NotShareableError +from test import _code_definitions as code_defs from test import _crossinterp_definitions as defs @@ -24,9 +25,23 @@ n not in ('DynamicClassAttribute', '_GeneratorWrapper'))] DEFS = defs +with open(code_defs.__file__) as infile: + _code_defs_text = infile.read() with open(DEFS.__file__) as infile: - DEFS_TEXT = infile.read() -del infile + _defs_text = infile.read() + _defs_text = _defs_text.replace('from ', '# from ') +DEFS_TEXT = f""" +####################################### +# from {code_defs.__file__} + +{_code_defs_text} + +####################################### +# from {defs.__file__} + +{_defs_text} +""" +del infile, _code_defs_text, _defs_text def load_defs(module=None): From a135db8c2558b319f222caadebd9b4d3911b7293 Mon Sep 17 00:00:00 2001 From: Eric Snow Date: Wed, 30 Apr 2025 13:30:24 -0600 Subject: [PATCH 4/4] Fix import_helper.module_restored(). 
--- Lib/test/support/import_helper.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/Lib/test/support/import_helper.py b/Lib/test/support/import_helper.py index 1c827c5ded27b0..edb734d294f287 100644 --- a/Lib/test/support/import_helper.py +++ b/Lib/test/support/import_helper.py @@ -338,8 +338,8 @@ def ensure_lazy_imports(imported_module, modules_to_block): @contextlib.contextmanager def module_restored(name): """A context manager that restores a module to the original state.""" - missing = name in sys.modules - orig = sys.modules.get(name) + missing = object() + orig = sys.modules.get(name, missing) if orig is None: mod = importlib.import_module(name) else: @@ -349,7 +349,7 @@ def module_restored(name): try: yield mod finally: - if missing: + if orig is missing: sys.modules.pop(name, None) else: sys.modules[name] = orig pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy