# -- Additions to distributed/diagnostics/plugin.py ---------------------------
# Plugins that upload a local Python module (a single file or a top-level
# package) to the scheduler and/or workers as a zip-based ``.egg`` archive.
#
# New imports required by the code below; the rest of the file's import block
# is unchanged.
import shutil
import tempfile  # NOTE(review): presumably already imported by plugin.py -- confirm
from collections.abc import Iterator
from contextlib import contextmanager
from importlib.util import find_spec
from io import BytesIO
from pathlib import Path
from types import ModuleType

if TYPE_CHECKING:
    from distributed.node import ServerNode


@contextmanager
def serialize_module(
    module: ModuleType, exclude: tuple[str, ...] = ("__pycache__", ".DS_Store")
) -> Iterator[Path]:
    """Package *module* into a zip-based ``.egg`` archive on disk.

    Only single-file modules and top-level packages are supported.  Yields
    the path to the archive; the archive lives inside a temporary directory
    and is deleted when the context exits, so callers must read it before
    leaving the ``with`` block.

    Parameters
    ----------
    module:
        An imported module object; ``module.__file__`` must be set.
    exclude:
        Glob patterns skipped while copying a package tree.

    Raises
    ------
    ValueError
        If *module* is a submodule of a package (``package.module``).
    """
    module_path = Path(module.__file__)

    if module_path.stem == "__init__":
        # A package: archive the whole package directory.
        module_path = module_path.parent
        if "." in module.__name__:
            # TODO: the problem is that we serialize `package.module` as
            # module.egg that contains module.py, but it should contain the
            # whole structure of the package (package/module.py).
            raise ValueError(
                f"Plugin supports only top-level packages or single-file modules. You provided `{module.__name__}`, try `{module.__name__.split('.')[0]}`."
            )

    # In case of a single file there is nothing further to resolve: the file
    # itself is copied and archived below.

    with tempfile.TemporaryDirectory() as tmp:
        package_name = module_path.name
        package_copy_path = Path(tmp).joinpath(package_name)

        if module_path.is_dir():
            copied_package = Path(
                shutil.copytree(
                    module_path,
                    package_copy_path,
                    # Never pick up a stale archive or excluded cruft.
                    ignore=shutil.ignore_patterns(f"{package_name}.zip", *exclude),
                )
            )
        else:
            copied_package = Path(shutil.copy2(module_path, package_copy_path))

        archive_path = shutil.make_archive(
            # Output path including a name w/o extension.
            base_name=str(copied_package),
            format="zip",
            # chroot for the archive.
            root_dir=copied_package.parent,
            # Name of the directory (or file) to archive; a common prefix of
            # all files and directories in the archive.
            base_dir=package_name,
        )

        egg_file = shutil.move(archive_path, package_copy_path.with_suffix(".egg"))

        # Log the archive contents for debugging; close the handle promptly
        # instead of leaking an open ZipFile.
        with zipfile.ZipFile(egg_file) as zf:
            logger.debug(
                "The egg file %s contains the following files %s",
                str(egg_file),
                str(zf.namelist()),
            )

        logger.info("Created an egg file %s from %s", str(egg_file), str(module_path))

        yield Path(egg_file)


class AbstractUploadModulePlugin:
    """Shared logic for uploading a local module to a scheduler or worker.

    Serializes the module once at construction time (on the client) and ships
    the resulting egg to the node when the plugin starts there.
    """

    def __init__(self, module: ModuleType):
        self._module_name = module.__name__
        with serialize_module(module) as filepath:
            # Archive name (e.g. ``mypkg.egg``) and its raw bytes; the file on
            # disk disappears when the context manager exits.
            self._filename: str = filepath.name
            self._data: bytes = filepath.read_bytes()

    async def _upload_file(self, node: ServerNode) -> None:
        """Send the serialized egg to *node* and have it loaded there."""
        response = await node.upload_file(self._filename, self._data, load=True)
        if response["nbytes"] != len(self._data):
            # Not an `assert`: that would be stripped under `python -O`.
            raise RuntimeError(
                f"Upload of {self._filename} was truncated: "
                f"sent {len(self._data)} bytes, node received {response['nbytes']}"
            )

    async def _upload(self, node: ServerNode) -> None:
        """Upload the module to *node*, picking an update strategy.

        * Module unknown on the node -> ship the egg and load it.
        * Module previously shipped as an egg -> ship a replacement egg.
        * Module present as source on the node -> extract the egg over the
          existing sources and reload the module in place.
        """
        try:
            from IPython.extensions.autoreload import superreload
        except ImportError:
            # Without IPython fall back to a no-op "reload".
            def superreload(mod):
                return mod

        # Already-imported module object on this node, if any.
        module = sys.modules.get(self._module_name)

        # Module importable from disk, if any.  `find_spec` may raise for
        # broken or partially-initialized packages; treat that as "absent".
        try:
            module_spec = find_spec(self._module_name)
        except (ImportError, ValueError):
            module_spec = None

        role = "worker" if isinstance(node, Worker) else "scheduler"

        if not module_spec and not module:
            # The module does not exist on the node: keep it as an egg file
            # and load it.
            logger.info(
                'Uploading a new module "%s" to "%s" on %s "%s"',
                self._module_name,
                str(self._filename),
                role,
                node.id,
            )
            await self._upload_file(node)
            return

        if module:
            module_path = self._get_module_dir(module)
        else:
            module_path = Path(module_spec.origin)

        if ".egg" in str(module_path):
            # Update the previously uploaded egg module and reload it.
            logger.info(
                'Uploading an update for a previously uploaded module "%s" to "%s" on %s "%s"',
                self._module_name,
                str(self._filename),
                role,
                node.id,
            )
            await self._upload_file(node)
            return

        # The module exists on the node as source code: overwrite each file
        # by extracting it from the egg in place.
        logger.info(
            'Uploading an update for an existing module "%s" in "%s" on %s "%s"',
            self._module_name,
            str(module_path.parent),
            role,
            node.id,
        )
        with zipfile.ZipFile(BytesIO(self._data), "r") as zip_ref:
            zip_ref.extractall(module_path.parent)

        # TODO: Do we really need Jupyter's `superreload` here instead of the
        # built-in `importlib.reload`?
        if self._module_name in sys.modules:
            # Reload the module if it is already loaded.
            superreload(sys.modules[self._module_name])

    @staticmethod
    def _get_module_dir(module: ModuleType) -> Path:
        """Return the package directory, or the file path for a plain module."""
        module_path = Path(module.__file__)

        if module_path.stem == "__init__":
            # A package: the directory containing __init__.py.
            return module_path.parent

        # A single-file module: the file itself.
        return module_path


class UploadModule(WorkerPlugin, AbstractUploadModulePlugin):
    """Worker plugin that uploads a local module to every worker."""

    name = "upload_module"

    async def setup(self, worker: Worker) -> None:
        await self._upload(worker)


class SchedulerUploadModule(SchedulerPlugin, AbstractUploadModulePlugin):
    """Scheduler plugin that uploads a local module to the scheduler."""

    name = "upload_module"

    async def start(self, scheduler: Scheduler) -> None:
        await self._upload(scheduler)
Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.
Alternative Proxies: