diff --git a/Include/internal/pycore_crossinterp.h b/Include/internal/pycore_crossinterp.h index 4b7446a1f40ccf..4b4617fdbcb2ad 100644 --- a/Include/internal/pycore_crossinterp.h +++ b/Include/internal/pycore_crossinterp.h @@ -171,6 +171,13 @@ PyAPI_FUNC(_PyBytes_data_t *) _PyBytes_GetXIDataWrapped( xid_newobjfunc, _PyXIData_t *); +// _PyObject_GetXIData() for pickle +PyAPI_DATA(PyObject *) _PyPickle_LoadFromXIData(_PyXIData_t *); +PyAPI_FUNC(int) _PyPickle_GetXIData( + PyThreadState *, + PyObject *, + _PyXIData_t *); + // _PyObject_GetXIData() for marshal PyAPI_FUNC(PyObject *) _PyMarshal_ReadObjectFromXIData(_PyXIData_t *); PyAPI_FUNC(int) _PyMarshal_GetXIData( diff --git a/Lib/test/support/import_helper.py b/Lib/test/support/import_helper.py index 42cfe9cfa8cb72..edb734d294f287 100644 --- a/Lib/test/support/import_helper.py +++ b/Lib/test/support/import_helper.py @@ -1,6 +1,7 @@ import contextlib import _imp import importlib +import importlib.machinery import importlib.util import os import shutil @@ -332,3 +333,110 @@ def ensure_lazy_imports(imported_module, modules_to_block): ) from .script_helper import assert_python_ok assert_python_ok("-S", "-c", script) + + +@contextlib.contextmanager +def module_restored(name): + """A context manager that restores a module to the original state.""" + missing = object() + orig = sys.modules.get(name, missing) + if orig is None: + mod = importlib.import_module(name) + else: + mod = type(sys)(name) + mod.__dict__.update(orig.__dict__) + sys.modules[name] = mod + try: + yield mod + finally: + if orig is missing: + sys.modules.pop(name, None) + else: + sys.modules[name] = orig + + +def create_module(name, loader=None, *, ispkg=False): + """Return a new, empty module.""" + spec = importlib.machinery.ModuleSpec( + name, + loader, + origin='', + is_package=ispkg, + ) + return importlib.util.module_from_spec(spec) + + +def _ensure_module(name, ispkg, addparent, clearnone): + try: + mod = orig = sys.modules[name] + except KeyError: + mod = orig = None + missing = True + else: + missing = False + if mod is not None: + # It was already imported. + return mod, orig, missing + # Otherwise, None means it was explicitly disabled. + + assert name != '__main__' + if not missing: + assert orig is None, (name, sys.modules[name]) + if not clearnone: + raise ModuleNotFoundError(name) + del sys.modules[name] + # Try normal import, then fall back to adding the module. + try: + mod = importlib.import_module(name) + except ModuleNotFoundError: + if addparent and not clearnone: + addparent = None + mod = _add_module(name, ispkg, addparent) + return mod, orig, missing + + +def _add_module(spec, ispkg, addparent): + if isinstance(spec, str): + name = spec + mod = create_module(name, ispkg=ispkg) + spec = mod.__spec__ + else: + name = spec.name + mod = importlib.util.module_from_spec(spec) + sys.modules[name] = mod + if addparent is not False and spec.parent: + _ensure_module(spec.parent, True, addparent, bool(addparent)) + return mod + + +def add_module(spec, *, parents=True): + """Return the module after creating it and adding it to sys.modules. + + If parents is True then also create any missing parents. + """ + return _add_module(spec, False, parents) + + +def add_package(spec, *, parents=True): + """Return the module after creating it and adding it to sys.modules. + + If parents is True then also create any missing parents. + """ + return _add_module(spec, True, parents) + + +def ensure_module_imported(name, *, clearnone=True): + """Return the corresponding module. + + If it was already imported then return that. Otherwise, try + importing it (optionally clear it first if None). If that fails + then create a new empty module. + + It can be helpful to combine this with ready_to_import() and/or + isolated_modules(). + """ + if sys.modules.get(name) is not None: + mod = sys.modules[name] + else: + mod, _, _ = _force_import(name, False, True, clearnone) + return mod diff --git a/Lib/test/test_crossinterp.py b/Lib/test/test_crossinterp.py index 5ebb78b0ea9e3b..32d6fd4e94bf1b 100644 --- a/Lib/test/test_crossinterp.py +++ b/Lib/test/test_crossinterp.py @@ -1,3 +1,6 @@ +import contextlib +import importlib +import importlib.util import itertools import sys import types @@ -9,7 +12,7 @@ _interpreters = import_helper.import_module('_interpreters') from _interpreters import NotShareableError - +from test import _code_definitions as code_defs from test import _crossinterp_definitions as defs @@ -21,6 +24,88 @@ if (isinstance(o, type) and n not in ('DynamicClassAttribute', '_GeneratorWrapper'))] +DEFS = defs +with open(code_defs.__file__) as infile: + _code_defs_text = infile.read() +with open(DEFS.__file__) as infile: + _defs_text = infile.read() + _defs_text = _defs_text.replace('from ', '# from ') +DEFS_TEXT = f""" +####################################### +# from {code_defs.__file__} + +{_code_defs_text} + +####################################### +# from {defs.__file__} + +{_defs_text} +""" +del infile, _code_defs_text, _defs_text + + +def load_defs(module=None): + """Return a new copy of the test._crossinterp_definitions module. + + The module's __name__ matches the "module" arg, which is either + a str or a module. + + If the "module" arg is a module then the just-loaded defs are also + copied into that module. + + Note that the new module is not added to sys.modules. + """ + if module is None: + modname = DEFS.__name__ + elif isinstance(module, str): + modname = module + module = None + else: + modname = module.__name__ + # Create the new module and populate it. + defs = import_helper.create_module(modname) + defs.__file__ = DEFS.__file__ + exec(DEFS_TEXT, defs.__dict__) + # Copy the defs into the module arg, if any. + if module is not None: + for name, value in defs.__dict__.items(): + if name.startswith('_'): + continue + assert not hasattr(module, name), (name, getattr(module, name)) + setattr(module, name, value) + return defs + + +@contextlib.contextmanager +def using___main__(): + """Make sure __main__ module exists (and clean up after).""" + modname = '__main__' + if modname not in sys.modules: + with import_helper.isolated_modules(): + yield import_helper.add_module(modname) + else: + with import_helper.module_restored(modname) as mod: + yield mod + + +@contextlib.contextmanager +def temp_module(modname): + """Create the module and add to sys.modules, then remove it after.""" + assert modname not in sys.modules, (modname,) + with import_helper.isolated_modules(): + yield import_helper.add_module(modname) + + +@contextlib.contextmanager +def missing_defs_module(modname, *, prep=False): + assert modname not in sys.modules, (modname,) + if prep: + with import_helper.ready_to_import(modname, DEFS_TEXT): + yield modname + else: + with import_helper.isolated_modules(): + yield modname + class _GetXIDataTests(unittest.TestCase): @@ -32,52 +117,49 @@ def get_xidata(self, obj, *, mode=None): def get_roundtrip(self, obj, *, mode=None): mode = self._resolve_mode(mode) - xid =_testinternalcapi.get_crossinterp_data(obj, mode) + return self._get_roundtrip(obj, mode) + + def _get_roundtrip(self, obj, mode): + xid = _testinternalcapi.get_crossinterp_data(obj, mode) return _testinternalcapi.restore_crossinterp_data(xid) - def iter_roundtrip_values(self, values, *, mode=None): + def assert_roundtrip_identical(self, values, *, mode=None): mode = self._resolve_mode(mode) for obj in values: with self.subTest(obj): - xid = _testinternalcapi.get_crossinterp_data(obj, mode) - got = _testinternalcapi.restore_crossinterp_data(xid) - yield obj, got - - def assert_roundtrip_identical(self, values, *, mode=None): - for obj, got in self.iter_roundtrip_values(values, mode=mode): - # XXX What about between interpreters? - self.assertIs(got, obj) + got = self._get_roundtrip(obj, mode) + self.assertIs(got, obj) def assert_roundtrip_equal(self, values, *, mode=None, expecttype=None): - for obj, got in self.iter_roundtrip_values(values, mode=mode): - self.assertEqual(got, obj) - self.assertIs(type(got), - type(obj) if expecttype is None else expecttype) - -# def assert_roundtrip_equal_not_identical(self, values, *, -# mode=None, expecttype=None): -# mode = self._resolve_mode(mode) -# for obj in values: -# cls = type(obj) -# with self.subTest(obj): -# got = self._get_roundtrip(obj, mode) -# self.assertIsNot(got, obj) -# self.assertIs(type(got), type(obj)) -# self.assertEqual(got, obj) -# self.assertIs(type(got), -# cls if expecttype is None else expecttype) -# -# def assert_roundtrip_not_equal(self, values, *, mode=None, expecttype=None): -# mode = self._resolve_mode(mode) -# for obj in values: -# cls = type(obj) -# with self.subTest(obj): -# got = self._get_roundtrip(obj, mode) -# self.assertIsNot(got, obj) -# self.assertIs(type(got), type(obj)) -# self.assertNotEqual(got, obj) -# self.assertIs(type(got), -# cls if expecttype is None else expecttype) + mode = self._resolve_mode(mode) + for obj in values: + with self.subTest(obj): + got = self._get_roundtrip(obj, mode) + self.assertEqual(got, obj) + self.assertIs(type(got), + type(obj) if expecttype is None else expecttype) + + def assert_roundtrip_equal_not_identical(self, values, *, + mode=None, expecttype=None): + mode = self._resolve_mode(mode) + for obj in values: + with self.subTest(obj): + got = self._get_roundtrip(obj, mode) + self.assertIsNot(got, obj) + self.assertIs(type(got), + type(obj) if expecttype is None else expecttype) + self.assertEqual(got, obj) + + def assert_roundtrip_not_equal(self, values, *, + mode=None, expecttype=None): + mode = self._resolve_mode(mode) + for obj in values: + with self.subTest(obj): + got = self._get_roundtrip(obj, mode) + self.assertIsNot(got, obj) + self.assertIs(type(got), + type(obj) if expecttype is None else expecttype) + self.assertNotEqual(got, obj) def assert_not_shareable(self, values, exctype=None, *, mode=None): mode = self._resolve_mode(mode) @@ -95,6 +177,363 @@ def _resolve_mode(self, mode): return mode +class PickleTests(_GetXIDataTests): + + MODE = 'pickle' + + def test_shareable(self): + self.assert_roundtrip_equal([ + # singletons + None, + True, + False, + # bytes + *(i.to_bytes(2, 'little', signed=True) + for i in range(-1, 258)), + # str + 'hello world', + '你好世界', + '', + # int + sys.maxsize, + -sys.maxsize - 1, + *range(-1, 258), + # float + 0.0, + 1.1, + -1.0, + 0.12345678, + -0.12345678, + # tuple + (), + (1,), + ("hello", "world", ), + (1, True, "hello"), + ((1,),), + ((1, 2), (3, 4)), + ((1, 2), (3, 4), (5, 6)), + ]) + # not shareable using xidata + self.assert_roundtrip_equal([ + # int + sys.maxsize + 1, + -sys.maxsize - 2, + 2**1000, + # tuple + (0, 1.0, []), + (0, 1.0, {}), + (0, 1.0, ([],)), + (0, 1.0, ({},)), + ]) + + def test_list(self): + self.assert_roundtrip_equal_not_identical([ + [], + [1, 2, 3], + [[1], (2,), {3: 4}], + ]) + + def test_dict(self): + self.assert_roundtrip_equal_not_identical([ + {}, + {1: 7, 2: 8, 3: 9}, + {1: [1], 2: (2,), 3: {3: 4}}, + ]) + + def test_set(self): + self.assert_roundtrip_equal_not_identical([ + set(), + {1, 2, 3}, + {frozenset({1}), (2,)}, + ]) + + # classes + + def assert_class_defs_same(self, defs): + # Unpickle relative to the unchanged original module. + self.assert_roundtrip_identical(defs.TOP_CLASSES) + + instances = [] + for cls, args in defs.TOP_CLASSES.items(): + if cls in defs.CLASSES_WITHOUT_EQUALITY: + continue + instances.append(cls(*args)) + self.assert_roundtrip_equal_not_identical(instances) + + # these don't compare equal + instances = [] + for cls, args in defs.TOP_CLASSES.items(): + if cls not in defs.CLASSES_WITHOUT_EQUALITY: + continue + instances.append(cls(*args)) + self.assert_roundtrip_not_equal(instances) + + def assert_class_defs_other_pickle(self, defs, mod): + # Pickle relative to a different module than the original. + for cls in defs.TOP_CLASSES: + assert not hasattr(mod, cls.__name__), (cls, getattr(mod, cls.__name__)) + self.assert_not_shareable(defs.TOP_CLASSES) + + instances = [] + for cls, args in defs.TOP_CLASSES.items(): + instances.append(cls(*args)) + self.assert_not_shareable(instances) + + def assert_class_defs_other_unpickle(self, defs, mod, *, fail=False): + # Unpickle relative to a different module than the original. + for cls in defs.TOP_CLASSES: + assert not hasattr(mod, cls.__name__), (cls, getattr(mod, cls.__name__)) + + instances = [] + for cls, args in defs.TOP_CLASSES.items(): + with self.subTest(cls): + setattr(mod, cls.__name__, cls) + xid = self.get_xidata(cls) + inst = cls(*args) + instxid = self.get_xidata(inst) + instances.append( + (cls, xid, inst, instxid)) + + for cls, xid, inst, instxid in instances: + with self.subTest(cls): + delattr(mod, cls.__name__) + if fail: + with self.assertRaises(NotShareableError): + _testinternalcapi.restore_crossinterp_data(xid) + continue + got = _testinternalcapi.restore_crossinterp_data(xid) + self.assertIsNot(got, cls) + self.assertNotEqual(got, cls) + + gotcls = got + got = _testinternalcapi.restore_crossinterp_data(instxid) + self.assertIsNot(got, inst) + self.assertIs(type(got), gotcls) + if cls in defs.CLASSES_WITHOUT_EQUALITY: + self.assertNotEqual(got, inst) + elif cls in defs.BUILTIN_SUBCLASSES: + self.assertEqual(got, inst) + else: + self.assertNotEqual(got, inst) + + def assert_class_defs_not_shareable(self, defs): + self.assert_not_shareable(defs.TOP_CLASSES) + + instances = [] + for cls, args in defs.TOP_CLASSES.items(): + instances.append(cls(*args)) + self.assert_not_shareable(instances) + + def test_user_class_normal(self): + self.assert_class_defs_same(defs) + + def test_user_class_in___main__(self): + with using___main__() as mod: + defs = load_defs(mod) + self.assert_class_defs_same(defs) + + def test_user_class_not_in___main___with_filename(self): + with using___main__() as mod: + defs = load_defs('__main__') + assert defs.__file__ + mod.__file__ = defs.__file__ + self.assert_class_defs_not_shareable(defs) + + def test_user_class_not_in___main___without_filename(self): + with using___main__() as mod: + defs = load_defs('__main__') + defs.__file__ = None + mod.__file__ = None + self.assert_class_defs_not_shareable(defs) + + def test_user_class_not_in___main___unpickle_with_filename(self): + with using___main__() as mod: + defs = load_defs('__main__') + assert defs.__file__ + mod.__file__ = defs.__file__ + self.assert_class_defs_other_unpickle(defs, mod) + + def test_user_class_not_in___main___unpickle_without_filename(self): + with using___main__() as mod: + defs = load_defs('__main__') + defs.__file__ = None + mod.__file__ = None + self.assert_class_defs_other_unpickle(defs, mod, fail=True) + + def test_user_class_in_module(self): + with temp_module('__spam__') as mod: + defs = load_defs(mod) + self.assert_class_defs_same(defs) + + def test_user_class_not_in_module_with_filename(self): + with temp_module('__spam__') as mod: + defs = load_defs(mod.__name__) + assert defs.__file__ + # For now, we only address this case for __main__. + self.assert_class_defs_not_shareable(defs) + + def test_user_class_not_in_module_without_filename(self): + with temp_module('__spam__') as mod: + defs = load_defs(mod.__name__) + defs.__file__ = None + self.assert_class_defs_not_shareable(defs) + + def test_user_class_module_missing_then_imported(self): + with missing_defs_module('__spam__', prep=True) as modname: + defs = load_defs(modname) + # For now, we only address this case for __main__. + self.assert_class_defs_not_shareable(defs) + + def test_user_class_module_missing_not_available(self): + with missing_defs_module('__spam__') as modname: + defs = load_defs(modname) + self.assert_class_defs_not_shareable(defs) + + def test_nested_class(self): + eggs = defs.EggsNested() + with self.assertRaises(NotShareableError): + self.get_roundtrip(eggs) + + # functions + + def assert_func_defs_same(self, defs): + # Unpickle relative to the unchanged original module. + self.assert_roundtrip_identical(defs.TOP_FUNCTIONS) + + def assert_func_defs_other_pickle(self, defs, mod): + # Pickle relative to a different module than the original. + for func in defs.TOP_FUNCTIONS: + assert not hasattr(mod, func.__name__), (cls, getattr(mod, func.__name__)) + self.assert_not_shareable(defs.TOP_FUNCTIONS) + + def assert_func_defs_other_unpickle(self, defs, mod, *, fail=False): + # Unpickle relative to a different module than the original. + for func in defs.TOP_FUNCTIONS: + assert not hasattr(mod, func.__name__), (cls, getattr(mod, func.__name__)) + + captured = [] + for func in defs.TOP_FUNCTIONS: + with self.subTest(func): + setattr(mod, func.__name__, func) + xid = self.get_xidata(func) + captured.append( + (func, xid)) + + for func, xid in captured: + with self.subTest(func): + delattr(mod, func.__name__) + if fail: + with self.assertRaises(NotShareableError): + _testinternalcapi.restore_crossinterp_data(xid) + continue + got = _testinternalcapi.restore_crossinterp_data(xid) + self.assertIsNot(got, func) + self.assertNotEqual(got, func) + + def assert_func_defs_not_shareable(self, defs): + self.assert_not_shareable(defs.TOP_FUNCTIONS) + + def test_user_function_normal(self): +# self.assert_roundtrip_equal(defs.TOP_FUNCTIONS) + self.assert_func_defs_same(defs) + + def test_user_func_in___main__(self): + with using___main__() as mod: + defs = load_defs(mod) + self.assert_func_defs_same(defs) + + def test_user_func_not_in___main___with_filename(self): + with using___main__() as mod: + defs = load_defs('__main__') + assert defs.__file__ + mod.__file__ = defs.__file__ + self.assert_func_defs_not_shareable(defs) + + def test_user_func_not_in___main___without_filename(self): + with using___main__() as mod: + defs = load_defs('__main__') + defs.__file__ = None + mod.__file__ = None + self.assert_func_defs_not_shareable(defs) + + def test_user_func_not_in___main___unpickle_with_filename(self): + with using___main__() as mod: + defs = load_defs('__main__') + assert defs.__file__ + mod.__file__ = defs.__file__ + self.assert_func_defs_other_unpickle(defs, mod) + + def test_user_func_not_in___main___unpickle_without_filename(self): + with using___main__() as mod: + defs = load_defs('__main__') + defs.__file__ = None + mod.__file__ = None + self.assert_func_defs_other_unpickle(defs, mod, fail=True) + + def test_user_func_in_module(self): + with temp_module('__spam__') as mod: + defs = load_defs(mod) + self.assert_func_defs_same(defs) + + def test_user_func_not_in_module_with_filename(self): + with temp_module('__spam__') as mod: + defs = load_defs(mod.__name__) + assert defs.__file__ + # For now, we only address this case for __main__. + self.assert_func_defs_not_shareable(defs) + + def test_user_func_not_in_module_without_filename(self): + with temp_module('__spam__') as mod: + defs = load_defs(mod.__name__) + defs.__file__ = None + self.assert_func_defs_not_shareable(defs) + + def test_user_func_module_missing_then_imported(self): + with missing_defs_module('__spam__', prep=True) as modname: + defs = load_defs(modname) + # For now, we only address this case for __main__. + self.assert_func_defs_not_shareable(defs) + + def test_user_func_module_missing_not_available(self): + with missing_defs_module('__spam__') as modname: + defs = load_defs(modname) + self.assert_func_defs_not_shareable(defs) + + def test_nested_function(self): + self.assert_not_shareable(defs.NESTED_FUNCTIONS) + + # exceptions + + def test_user_exception_normal(self): + self.assert_roundtrip_not_equal([ + defs.MimimalError('error!'), + ]) + self.assert_roundtrip_equal_not_identical([ + defs.RichError('error!', 42), + ]) + + def test_builtin_exception(self): + msg = 'error!' + try: + raise Exception + except Exception as exc: + caught = exc + special = { + BaseExceptionGroup: (msg, [caught]), + ExceptionGroup: (msg, [caught]), +# UnicodeError: (None, msg, None, None, None), + UnicodeEncodeError: ('utf-8', '', 1, 3, msg), + UnicodeDecodeError: ('utf-8', b'', 1, 3, msg), + UnicodeTranslateError: ('', 1, 3, msg), + } + exceptions = [] + for cls in EXCEPTION_TYPES: + args = special.get(cls) or (msg,) + exceptions.append(cls(*args)) + + self.assert_roundtrip_not_equal(exceptions) + + class MarshalTests(_GetXIDataTests): MODE = 'marshal' @@ -444,22 +883,12 @@ def test_module(self): ]) def test_class(self): - self.assert_not_shareable([ - defs.Spam, - defs.SpamOkay, - defs.SpamFull, - defs.SubSpamFull, - defs.SubTuple, - defs.EggsNested, - ]) - self.assert_not_shareable([ - defs.Spam(), - defs.SpamOkay(), - defs.SpamFull(1, 2, 3), - defs.SubSpamFull(1, 2, 3), - defs.SubTuple([1, 2, 3]), - defs.EggsNested(), - ]) + self.assert_not_shareable(defs.CLASSES) + + instances = [] + for cls, args in defs.CLASSES.items(): + instances.append(cls(*args)) + self.assert_not_shareable(instances) def test_builtin_type(self): self.assert_not_shareable([ diff --git a/Modules/_testinternalcapi.c b/Modules/_testinternalcapi.c index 4bfe88f2cf920c..812737e294fcb7 100644 --- a/Modules/_testinternalcapi.c +++ b/Modules/_testinternalcapi.c @@ -1939,6 +1939,11 @@ get_crossinterp_data(PyObject *self, PyObject *args, PyObject *kwargs) goto error; } } + else if (strcmp(mode, "pickle") == 0) { + if (_PyPickle_GetXIData(tstate, obj, xidata) != 0) { + goto error; + } + } else if (strcmp(mode, "marshal") == 0) { if (_PyMarshal_GetXIData(tstate, obj, xidata) != 0) { goto error; diff --git a/Python/crossinterp.c b/Python/crossinterp.c index 753d784a503467..a9f9b78562917e 100644 --- a/Python/crossinterp.c +++ b/Python/crossinterp.c @@ -3,6 +3,7 @@ #include "Python.h" #include "marshal.h" // PyMarshal_WriteObjectToString() +#include "osdefs.h" // MAXPATHLEN #include "pycore_ceval.h" // _Py_simple_func #include "pycore_crossinterp.h" // _PyXIData_t #include "pycore_initconfig.h" // _PyStatus_OK() @@ -10,6 +11,155 @@ #include "pycore_typeobject.h" // _PyStaticType_InitBuiltin() +static Py_ssize_t +_Py_GetMainfile(char *buffer, size_t maxlen) +{ + // We don't expect subinterpreters to have the __main__ module's + // __name__ set, but proceed just in case. + PyThreadState *tstate = _PyThreadState_GET(); + PyObject *module = _Py_GetMainModule(tstate); + if (_Py_CheckMainModule(module) < 0) { + return -1; + } + Py_ssize_t size = _PyModule_GetFilenameUTF8(module, buffer, maxlen); + Py_DECREF(module); + return size; +} + + +static PyObject * +import_get_module(PyThreadState *tstate, const char *modname) +{ + PyObject *module = NULL; + if (strcmp(modname, "__main__") == 0) { + module = _Py_GetMainModule(tstate); + if (_Py_CheckMainModule(module) < 0) { + assert(_PyErr_Occurred(tstate)); + return NULL; + } + } + else { + module = PyImport_ImportModule(modname); + if (module == NULL) { + return NULL; + } + } + return module; +} + + +static PyObject * +runpy_run_path(const char *filename, const char *modname) +{ + PyObject *run_path = PyImport_ImportModuleAttrString("runpy", "run_path"); + if (run_path == NULL) { + return NULL; + } + PyObject *args = Py_BuildValue("(sOs)", filename, Py_None, modname); + if (args == NULL) { + Py_DECREF(run_path); + return NULL; + } + PyObject *ns = PyObject_Call(run_path, args, NULL); + Py_DECREF(run_path); + Py_DECREF(args); + return ns; +} + + +static PyObject * +pyerr_get_message(PyObject *exc) +{ + assert(!PyErr_Occurred()); + PyObject *args = PyException_GetArgs(exc); + if (args == NULL || args == Py_None || PyObject_Size(args) < 1) { + return NULL; + } + if (PyUnicode_Check(args)) { + return args; + } + PyObject *msg = PySequence_GetItem(args, 0); + Py_DECREF(args); + if (msg == NULL) { + PyErr_Clear(); + return NULL; + } + if (!PyUnicode_Check(msg)) { + Py_DECREF(msg); + return NULL; + } + return msg; +} + +#define MAX_MODNAME (255) +#define MAX_ATTRNAME (255) + +struct attributeerror_info { + char modname[MAX_MODNAME+1]; + char attrname[MAX_ATTRNAME+1]; +}; + +static int +_parse_attributeerror(PyObject *exc, struct attributeerror_info *info) +{ + assert(exc != NULL); + assert(PyErr_GivenExceptionMatches(exc, PyExc_AttributeError)); + int res = -1; + + PyObject *msgobj = pyerr_get_message(exc); + if (msgobj == NULL) { + return -1; + } + const char *err = PyUnicode_AsUTF8(msgobj); + + if (strncmp(err, "module '", 8) != 0) { + goto finally; + } + err += 8; + + const char *matched = strchr(err, '\''); + if (matched == NULL) { + goto finally; + } + Py_ssize_t len = matched - err; + if (len > MAX_MODNAME) { + goto finally; + } + (void)strncpy(info->modname, err, len); + info->modname[len] = '\0'; + err = matched; + + if (strncmp(err, "' has no attribute '", 20) != 0) { + goto finally; + } + err += 20; + + matched = strchr(err, '\''); + if (matched == NULL) { + goto finally; + } + len = matched - err; + if (len > MAX_ATTRNAME) { + goto finally; + } + (void)strncpy(info->attrname, err, len); + info->attrname[len] = '\0'; + err = matched + 1; + + if (strlen(err) > 0) { + goto finally; + } + res = 0; + +finally: + Py_DECREF(msgobj); + return res; +} + +#undef MAX_MODNAME +#undef MAX_ATTRNAME + + /**************/ /* exceptions */ /**************/ @@ -287,6 +437,308 @@ _PyObject_GetXIData(PyThreadState *tstate, } +/* pickle C-API */ + +struct _pickle_context { + PyThreadState *tstate; +}; + +static PyObject * +_PyPickle_Dumps(struct _pickle_context *ctx, PyObject *obj) +{ + PyObject *dumps = PyImport_ImportModuleAttrString("pickle", "dumps"); + if (dumps == NULL) { + return NULL; + } + PyObject *bytes = PyObject_CallOneArg(dumps, obj); + Py_DECREF(dumps); + return bytes; +} + + +struct sync_module_result { + PyObject *module; + PyObject *loaded; + PyObject *failed; +}; + +struct sync_module { + const char *filename; + char _filename[MAXPATHLEN+1]; + struct sync_module_result cached; +}; + +static void +sync_module_clear(struct sync_module *data) +{ + data->filename = NULL; + Py_CLEAR(data->cached.module); + Py_CLEAR(data->cached.loaded); + Py_CLEAR(data->cached.failed); +} + + +struct _unpickle_context { + PyThreadState *tstate; + // We only special-case the __main__ module, + // since other modules behave consistently. + struct sync_module main; +}; + +static void +_unpickle_context_clear(struct _unpickle_context *ctx) +{ + sync_module_clear(&ctx->main); +} + +static struct sync_module_result +_unpickle_context_get_module(struct _unpickle_context *ctx, + const char *modname) +{ + if (strcmp(modname, "__main__") == 0) { + return ctx->main.cached; + } + else { + return (struct sync_module_result){ + .failed = PyExc_NotImplementedError, + }; + } +} + +static struct sync_module_result +_unpickle_context_set_module(struct _unpickle_context *ctx, + const char *modname) +{ + struct sync_module_result res = {0}; + struct sync_module_result *cached = NULL; + const char *filename = NULL; + if (strcmp(modname, "__main__") == 0) { + cached = &ctx->main.cached; + filename = ctx->main.filename; + } + else { + res.failed = PyExc_NotImplementedError; + goto finally; + } + + res.module = import_get_module(ctx->tstate, modname); + if (res.module == NULL) { + res.failed = _PyErr_GetRaisedException(ctx->tstate); + assert(res.failed != NULL); + goto finally; + } + + if (filename == NULL) { + Py_CLEAR(res.module); + res.failed = PyExc_NotImplementedError; + goto finally; + } + res.loaded = runpy_run_path(filename, modname); + if (res.loaded == NULL) { + Py_CLEAR(res.module); + res.failed = _PyErr_GetRaisedException(ctx->tstate); + assert(res.failed != NULL); + goto finally; + } + +finally: + if (cached != NULL) { + assert(cached->module == NULL); + assert(cached->loaded == NULL); + assert(cached->failed == NULL); + *cached = res; + } + return res; +} + + +static int +_handle_unpickle_missing_attr(struct _unpickle_context *ctx, PyObject *exc) +{ + // The caller must check if an exception is set or not when -1 is returned. + assert(!_PyErr_Occurred(ctx->tstate)); + assert(PyErr_GivenExceptionMatches(exc, PyExc_AttributeError)); + struct attributeerror_info info; + if (_parse_attributeerror(exc, &info) < 0) { + return -1; + } + + // Get the module. + struct sync_module_result mod = _unpickle_context_get_module(ctx, info.modname); + if (mod.failed != NULL) { + // It must have failed previously. + return -1; + } + if (mod.module == NULL) { + mod = _unpickle_context_set_module(ctx, info.modname); + if (mod.failed != NULL) { + return -1; + } + assert(mod.module != NULL); + } + + // Bail out if it is unexpectedly set already. + if (PyObject_HasAttrString(mod.module, info.attrname)) { + return -1; + } + + // Try setting the attribute. + PyObject *value = NULL; + if (PyDict_GetItemStringRef(mod.loaded, info.attrname, &value) <= 0) { + return -1; + } + assert(value != NULL); + int res = PyObject_SetAttrString(mod.module, info.attrname, value); + Py_DECREF(value); + if (res < 0) { + return -1; + } + + return 0; +} + +static PyObject * +_PyPickle_Loads(struct _unpickle_context *ctx, PyObject *pickled) +{ + PyObject *loads = PyImport_ImportModuleAttrString("pickle", "loads"); + if (loads == NULL) { + return NULL; + } + PyObject *obj = PyObject_CallOneArg(loads, pickled); + if (ctx != NULL) { + while (obj == NULL) { + assert(_PyErr_Occurred(ctx->tstate)); + if (!PyErr_ExceptionMatches(PyExc_AttributeError)) { + // We leave other failures unhandled. + break; + } + // Try setting the attr if not set. + PyObject *exc = _PyErr_GetRaisedException(ctx->tstate); + if (_handle_unpickle_missing_attr(ctx, exc) < 0) { + // Any resulting exceptions are ignored + // in favor of the original. + _PyErr_SetRaisedException(ctx->tstate, exc); + break; + } + Py_CLEAR(exc); + // Retry with the attribute set. + obj = PyObject_CallOneArg(loads, pickled); + } + } + Py_DECREF(loads); + return obj; +} + + +/* pickle wrapper */ + +struct _pickle_xid_context { + // __main__.__file__ + struct { + const char *utf8; + size_t len; + char _utf8[MAXPATHLEN+1]; + } mainfile; +}; + +static int +_set_pickle_xid_context(PyThreadState *tstate, struct _pickle_xid_context *ctx) +{ + // Set mainfile if possible. + Py_ssize_t len = _Py_GetMainfile(ctx->mainfile._utf8, MAXPATHLEN); + if (len < 0) { + // For now we ignore any exceptions. + PyErr_Clear(); + } + else if (len > 0) { + ctx->mainfile.utf8 = ctx->mainfile._utf8; + ctx->mainfile.len = (size_t)len; + } + + return 0; +} + + +struct _shared_pickle_data { + _PyBytes_data_t pickled; // Must be first if we use _PyBytes_FromXIData(). + struct _pickle_xid_context ctx; +}; + +PyObject * +_PyPickle_LoadFromXIData(_PyXIData_t *xidata) +{ + PyThreadState *tstate = _PyThreadState_GET(); + struct _shared_pickle_data *shared = + (struct _shared_pickle_data *)xidata->data; + // We avoid copying the pickled data by wrapping it in a memoryview. + // The alternative is to get a bytes object using _PyBytes_FromXIData(). + PyObject *pickled = PyMemoryView_FromMemory( + (char *)shared->pickled.bytes, shared->pickled.len, PyBUF_READ); + if (pickled == NULL) { + return NULL; + } + + // Unpickle the object. + struct _unpickle_context ctx = { + .tstate = tstate, + .main = { + .filename = shared->ctx.mainfile.utf8, + }, + }; + PyObject *obj = _PyPickle_Loads(&ctx, pickled); + Py_DECREF(pickled); + _unpickle_context_clear(&ctx); + if (obj == NULL) { + PyObject *cause = _PyErr_GetRaisedException(tstate); + assert(cause != NULL); + _set_xid_lookup_failure( + tstate, NULL, "object could not be unpickled", cause); + Py_DECREF(cause); + } + return obj; +} + + +int +_PyPickle_GetXIData(PyThreadState *tstate, PyObject *obj, _PyXIData_t *xidata) +{ + // Pickle the object. + struct _pickle_context ctx = { + .tstate = tstate, + }; + PyObject *bytes = _PyPickle_Dumps(&ctx, obj); + if (bytes == NULL) { + PyObject *cause = _PyErr_GetRaisedException(tstate); + assert(cause != NULL); + _set_xid_lookup_failure( + tstate, NULL, "object could not be pickled", cause); + Py_DECREF(cause); + return -1; + } + + // If we had an "unwrapper" mechnanism, we could call + // _PyObject_GetXIData() on the bytes object directly and add + // a simple unwrapper to call pickle.loads() on the bytes. + size_t size = sizeof(struct _shared_pickle_data); + struct _shared_pickle_data *shared = + (struct _shared_pickle_data *)_PyBytes_GetXIDataWrapped( + tstate, bytes, size, _PyPickle_LoadFromXIData, xidata); + Py_DECREF(bytes); + if (shared == NULL) { + return -1; + } + + // If it mattered, we could skip getting __main__.__file__ + // when "__main__" doesn't show up in the pickle bytes. + if (_set_pickle_xid_context(tstate, &shared->ctx) < 0) { + _xidata_clear(xidata); + return -1; + } + + return 0; +} + + /* marshal wrapper */ PyObject * pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy