diff --git a/taskiq/api/scheduler.py b/taskiq/api/scheduler.py index 97c59c78..6928b128 100644 --- a/taskiq/api/scheduler.py +++ b/taskiq/api/scheduler.py @@ -1,3 +1,6 @@ +from datetime import timedelta +from typing import Optional + from taskiq.cli.scheduler.run import run_scheduler_loop from taskiq.scheduler.scheduler import TaskiqScheduler @@ -5,6 +8,7 @@ async def run_scheduler_task( scheduler: TaskiqScheduler, run_startup: bool = False, + interval: Optional[timedelta] = None, ) -> None: """ Run scheduler task. @@ -20,4 +24,4 @@ async def run_scheduler_task( if run_startup: await scheduler.startup() while True: - await run_scheduler_loop(scheduler) + await run_scheduler_loop(scheduler, interval) diff --git a/taskiq/cli/scheduler/args.py b/taskiq/cli/scheduler/args.py index 1850f360..d1f6d821 100644 --- a/taskiq/cli/scheduler/args.py +++ b/taskiq/cli/scheduler/args.py @@ -17,6 +17,7 @@ class SchedulerArgs: fs_discover: bool = False tasks_pattern: Sequence[str] = ("**/tasks.py",) skip_first_run: bool = False + update_interval: Optional[int] = None @classmethod def from_cli(cls, args: Optional[Sequence[str]] = None) -> "SchedulerArgs": @@ -80,6 +81,15 @@ def from_cli(cls, args: Optional[Sequence[str]] = None) -> "SchedulerArgs": "This option skips running tasks immediately after scheduler start." ), ) + parser.add_argument( + "--update-interval", + type=int, + default=None, + help=( + "Interval in seconds to check for new tasks. " + "If not specified, scheduler will run once a minute." + ), + ) namespace = parser.parse_args(args) # If there are any patterns specified, remove default. diff --git a/taskiq/cli/scheduler/run.py b/taskiq/cli/scheduler/run.py index b1487599..09467dc0 100644 --- a/taskiq/cli/scheduler/run.py +++ b/taskiq/cli/scheduler/run.py @@ -3,7 +3,7 @@ import sys from datetime import datetime, timedelta from logging import basicConfig, getLevelName, getLogger -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional, Set, Tuple import pytz from pycron import is_now @@ -55,7 +55,7 @@ async def get_schedules(source: ScheduleSource) -> List[ScheduledTask]: async def get_all_schedules( scheduler: TaskiqScheduler, -) -> Dict[ScheduleSource, List[ScheduledTask]]: +) -> List[Tuple[ScheduleSource, List[ScheduledTask]]]: """ Task to update all schedules. @@ -71,7 +71,7 @@ async def get_all_schedules( schedules = await asyncio.gather( *[get_schedules(source) for source in scheduler.sources], ) - return dict(zip(scheduler.sources, schedules)) + return list(zip(scheduler.sources, schedules)) def get_task_delay(task: ScheduledTask) -> Optional[int]: @@ -98,12 +98,10 @@ def get_task_delay(task: ScheduledTask) -> Optional[int]: task_time = to_tz_aware(task.time) if task_time <= now: return 0 - one_min_ahead = (now + timedelta(minutes=1)).replace(second=1, microsecond=0) - if task_time <= one_min_ahead: - delay = task_time - now - if delay.microseconds: - return int(delay.total_seconds()) + 1 - return int(delay.total_seconds()) + delay = task_time - now + if delay.microseconds: + return int(delay.total_seconds()) + 1 + return int(delay.total_seconds()) return None @@ -145,7 +143,10 @@ async def delayed_send( await scheduler.on_ready(source, task) -async def run_scheduler_loop(scheduler: TaskiqScheduler) -> None: +async def run_scheduler_loop( # noqa: C901 + scheduler: TaskiqScheduler, + interval: Optional[timedelta] = None, +) -> None: """ Runs scheduler loop. @@ -153,13 +154,30 @@ async def run_scheduler_loop(scheduler: TaskiqScheduler) -> None: and runs tasks when needed. :param scheduler: current scheduler. + :param interval: interval to check for schedule updates. """ loop = asyncio.get_event_loop() - running_schedules = set() + running_schedules: Dict[str, asyncio.Task[Any]] = {} + ran_cron_jobs: Set[str] = set() + current_minute = datetime.now(tz=pytz.UTC).minute while True: - # We use this method to correctly sleep for one minute. + now = datetime.now(tz=pytz.UTC) + # If minute changed, we need to clear + # ran_cron_jobs set and update current minute. + if now.minute != current_minute: + current_minute = now.minute + ran_cron_jobs.clear() + # If interval is not None, we need to + # calculate next run time using it. + if interval is not None: + next_run = now + interval + # otherwise we need assume that + # we will run it at the start of the next minute. + # as crontab does. + else: + next_run = (now + timedelta(minutes=1)).replace(second=1, microsecond=0) scheduled_tasks = await get_all_schedules(scheduler) - for source, task_list in scheduled_tasks.items(): + for source, task_list in scheduled_tasks: logger.debug("Got %d schedules from source %s.", len(task_list), source) for task in task_list: try: @@ -172,16 +190,37 @@ async def run_scheduler_loop(scheduler: TaskiqScheduler) -> None: task.schedule_id, ) continue - if task_delay is not None: - send_task = loop.create_task( - delayed_send(scheduler, source, task, task_delay), - ) - running_schedules.add(send_task) - send_task.add_done_callback(running_schedules.discard) - next_minute = datetime.now().replace(second=0, microsecond=0) + timedelta( - minutes=1, - ) - delay = next_minute - datetime.now() + # If task delay is None, we don't need to run it. + if task_delay is None: + continue + # If task is delayed for more than next_run, + # we don't need to run it, because we will + # run it in the next iteration. + if now + timedelta(seconds=task_delay) >= next_run: + continue + # If task is already running, we don't need to run it again. + if task.schedule_id in running_schedules and task_delay < 1: + continue + # If task is cron job, we need to check if + # we already ran it this minute. + if task.cron is not None: + if task.schedule_id in ran_cron_jobs: + continue + ran_cron_jobs.add(task.schedule_id) + send_task = loop.create_task( + delayed_send(scheduler, source, task, task_delay), + # We need to set the name of the task + # to be able to discard its reference + # after it is done. + name=f"schedule_{task.schedule_id}", + ) + running_schedules[task.schedule_id] = send_task + send_task.add_done_callback( + lambda task_future: running_schedules.pop( + task_future.get_name().removeprefix("schedule_"), + ), + ) + delay = next_run - datetime.now(tz=pytz.UTC) logger.debug( "Sleeping for %.2f seconds before getting schedules.", delay.total_seconds(), @@ -226,6 +265,10 @@ async def run_scheduler(args: SchedulerArgs) -> None: for source in scheduler.sources: await source.startup() + interval = None + if args.update_interval: + interval = timedelta(seconds=args.update_interval) + logger.info("Starting scheduler.") await scheduler.startup() logger.info("Startup completed.") @@ -239,7 +282,7 @@ async def run_scheduler(args: SchedulerArgs) -> None: await asyncio.sleep(delay.total_seconds()) logger.info("First run skipped. The scheduler is now running.") try: - await run_scheduler_loop(scheduler) + await run_scheduler_loop(scheduler, interval) except asyncio.CancelledError: logger.warning("Shutting down scheduler.") await scheduler.shutdown() diff --git a/taskiq/schedule_sources/label_based.py b/taskiq/schedule_sources/label_based.py index 1f313717..94fd42de 100644 --- a/taskiq/schedule_sources/label_based.py +++ b/taskiq/schedule_sources/label_based.py @@ -1,5 +1,6 @@ +import uuid from logging import getLogger -from typing import List +from typing import Dict, List from taskiq.abc.broker import AsyncBroker from taskiq.abc.schedule_source import ScheduleSource @@ -13,20 +14,26 @@ class LabelScheduleSource(ScheduleSource): def __init__(self, broker: AsyncBroker) -> None: self.broker = broker + self.schedules: Dict[str, ScheduledTask] = {} - async def get_schedules(self) -> List["ScheduledTask"]: + async def startup(self) -> None: """ - Collect schedules for all tasks. - - this function checks labels for all - tasks available to the broker. + Startup the schedule source. + This function iterates over all tasks + available to the broker and collects + schedules from their labels. If task has a schedule label, - it will be parsed and returned. + it will be parsed and added to the + scheduler list. - :return: list of schedules. + Every time schedule is added, the random + schedule id is generated. Please be aware that + they are different for every startup. + + :return: None """ - schedules = [] + self.schedules.clear() for task_name, task in self.broker.get_all_tasks().items(): if task.broker != self.broker: # if task broker doesn't match self, something is probably wrong @@ -40,20 +47,36 @@ async def get_schedules(self) -> List["ScheduledTask"]: continue labels = schedule.get("labels", {}) labels.update(task.labels) - schedules.append( - ScheduledTask( - task_name=task_name, - labels=labels, - args=schedule.get("args", []), - kwargs=schedule.get("kwargs", {}), - cron=schedule.get("cron"), - time=schedule.get("time"), - cron_offset=schedule.get("cron_offset"), - ), + schedule_id = uuid.uuid4().hex + + self.schedules[schedule_id] = ScheduledTask( + task_name=task_name, + labels=labels, + schedule_id=schedule_id, + args=schedule.get("args", []), + kwargs=schedule.get("kwargs", {}), + cron=schedule.get("cron"), + time=schedule.get("time"), + cron_offset=schedule.get("cron_offset"), ) - return schedules - def post_send(self, scheduled_task: ScheduledTask) -> None: + return await super().startup() + + async def get_schedules(self) -> List["ScheduledTask"]: + """ + Collect schedules for all tasks. + + this function checks labels for all + tasks available to the broker. + + If task has a schedule label, + it will be parsed and returned. + + :return: list of schedules. + """ + return list(self.schedules.values()) + + def post_send(self, task: "ScheduledTask") -> None: """ Remove `time` schedule from task's scheduler list. @@ -62,22 +85,7 @@ def post_send(self, scheduled_task: ScheduledTask) -> None: :param scheduled_task: task that just have sent """ - if scheduled_task.cron or not scheduled_task.time: + if task.cron or not task.time: return # it's scheduled task with cron label, do not remove this trigger. - for task_name, task in self.broker.get_all_tasks().items(): - if task.broker != self.broker: - # if task broker doesn't match self, something is probably wrong - logger.warning( - f"Broker for {task_name} `{task.broker}` doesn't " - f"match scheduler's broker `{self.broker}`", - ) - continue - if scheduled_task.task_name != task_name: - continue - - schedule_list = task.labels.get("schedule", []).copy() - for idx, schedule in enumerate(schedule_list): - if schedule.get("time") == scheduled_task.time: - task.labels.get("schedule", []).pop(idx) - return + self.schedules.pop(task.schedule_id, None) diff --git a/tests/cli/scheduler/test_task_delays.py b/tests/cli/scheduler/test_task_delays.py index 1af3c783..2e00fe37 100644 --- a/tests/cli/scheduler/test_task_delays.py +++ b/tests/cli/scheduler/test_task_delays.py @@ -9,7 +9,7 @@ def test_should_run_success() -> None: - hour = datetime.datetime.utcnow().hour + hour = datetime.datetime.now(datetime.timezone.utc).hour delay = get_task_delay( ScheduledTask( task_name="", @@ -97,18 +97,26 @@ def test_time_utc_with_local_zone() -> None: assert delay is not None and delay >= 0 +@freeze_time("2023-01-14 12:00:00") def test_time_localtime_without_zone() -> None: time = datetime.datetime.now(tz=pytz.FixedOffset(240)).replace(tzinfo=None) + time_to_run = time - datetime.timedelta(seconds=1) + delay = get_task_delay( ScheduledTask( task_name="", labels={}, args=[], kwargs={}, - time=time - datetime.timedelta(seconds=1), + time=time_to_run, ), ) - assert delay is None + + expected_delay = time_to_run.replace(tzinfo=pytz.UTC) - datetime.datetime.now( + pytz.UTC, + ) + + assert delay == int(expected_delay.total_seconds()) @freeze_time("2023-01-14 12:00:00") diff --git a/tests/cli/scheduler/test_updater.py b/tests/cli/scheduler/test_updater.py index c2a7b9e5..2ac9ef8d 100644 --- a/tests/cli/scheduler/test_updater.py +++ b/tests/cli/scheduler/test_updater.py @@ -56,10 +56,10 @@ async def test_get_schedules_success() -> None: schedules = await get_all_schedules( TaskiqScheduler(InMemoryBroker(), sources), ) - assert schedules == { - sources[0]: schedules1, - sources[1]: schedules2, - } + assert schedules == [ + (sources[0], schedules1), + (sources[1], schedules2), + ] @pytest.mark.anyio @@ -81,7 +81,7 @@ async def test_get_schedules_error() -> None: schedules = await get_all_schedules( TaskiqScheduler(InMemoryBroker(), [source1, source2]), ) - assert schedules == { - source1: source1.schedules, - source2: [], - } + assert schedules == [ + (source1, source1.schedules), + (source2, []), + ] diff --git a/tests/schedule_sources/test_label_based.py b/tests/schedule_sources/test_label_based.py index 9e683917..fa621b07 100644 --- a/tests/schedule_sources/test_label_based.py +++ b/tests/schedule_sources/test_label_based.py @@ -2,6 +2,7 @@ from typing import Any, Dict, List import pytest +import pytz from taskiq.brokers.inmemory_broker import InMemoryBroker from taskiq.schedule_sources.label_based import LabelScheduleSource @@ -13,7 +14,7 @@ "schedule_label", [ pytest.param([{"cron": "* * * * *"}], id="cron"), - pytest.param([{"time": datetime.utcnow()}], id="time"), + pytest.param([{"time": datetime.now(pytz.UTC)}], id="time"), ], ) async def test_label_discovery(schedule_label: List[Dict[str, Any]]) -> None: @@ -27,6 +28,7 @@ def task() -> None: pass source = LabelScheduleSource(broker) + await source.startup() schedules = await source.get_schedules() assert schedules == [ ScheduledTask( @@ -53,5 +55,6 @@ def task() -> None: pass source = LabelScheduleSource(broker) + await source.startup() schedules = await source.get_schedules() assert schedules == [] diff --git a/tests/scheduler/test_label_based_sched.py b/tests/scheduler/test_label_based_sched.py index 156e8498..506990b0 100644 --- a/tests/scheduler/test_label_based_sched.py +++ b/tests/scheduler/test_label_based_sched.py @@ -3,6 +3,7 @@ from typing import Any, Dict, List import pytest +import pytz from freezegun import freeze_time from taskiq.brokers.inmemory_broker import InMemoryBroker @@ -18,7 +19,7 @@ "schedule_label", [ pytest.param([{"cron": "* * * * *"}], id="cron"), - pytest.param([{"time": datetime.utcnow()}], id="time"), + pytest.param([{"time": datetime.now(pytz.UTC)}], id="time"), ], ) async def test_label_discovery(schedule_label: List[Dict[str, Any]]) -> None: @@ -31,7 +32,9 @@ async def test_label_discovery(schedule_label: List[Dict[str, Any]]) -> None: def task() -> None: pass - schedules = await LabelScheduleSource(broker).get_schedules() + source = LabelScheduleSource(broker) + await source.startup() + schedules = await source.get_schedules() assert schedules == [ ScheduledTask( schedule_id=schedules[0].schedule_id, @@ -57,6 +60,7 @@ def task() -> None: pass source = LabelScheduleSource(broker) + await source.startup() schedules = await source.get_schedules() assert schedules == [] @@ -69,6 +73,8 @@ async def test_task_scheduled_at_time_runs_only_once(mock_sleep: None) -> None: broker=broker, sources=[LabelScheduleSource(broker)], ) + for source in scheduler.sources: + await source.startup() # NOTE: # freeze time to 00:00, so task won't be scheduled by `cron`, only by `time` @@ -77,8 +83,8 @@ async def test_task_scheduled_at_time_runs_only_once(mock_sleep: None) -> None: @broker.task( task_name="test_task", schedule=[ - {"time": datetime.utcnow(), "args": [1]}, - {"time": datetime.utcnow() + timedelta(days=1), "args": [2]}, + {"time": datetime.now(pytz.UTC), "args": [1]}, + {"time": datetime.now(pytz.UTC) + timedelta(days=1), "args": [2]}, {"cron": "1 * * * *", "args": [3]}, ], ) pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy