| import logging |
| import threading |
| import time |
| from dataclasses import dataclass, field |
| from typing import Callable, Optional |
|
|
# Shared module logger used by Job and Scheduler for all their messages.
logger: logging.Logger = logging.getLogger(name=__name__)
|
|
|
|
@dataclass
class Job:
    """A named callable that the scheduler runs every ``interval`` seconds.

    Bookkeeping fields (``_runs``, ``_last``, ``_errors``, ``_last_attempt``)
    are maintained internally and are excluded from ``__init__`` and ``repr``.

    Attributes:
        name: Label used in log messages and in :meth:`stats`.
        fn: Zero-argument callable invoked by :meth:`run`.
        interval: Minimum number of seconds between the end of one attempt
            and the moment the job is considered due again.
        max_runs: Number of successful runs after which the job is
            exhausted; ``0`` (or any non-positive value) means unlimited.
        on_error: Optional callback that receives the exception raised by
            ``fn``.
    """

    name: str
    fn: Callable
    interval: float
    max_runs: int = 0
    on_error: Optional[Callable[[Exception], None]] = None
    # Number of calls to ``fn`` that returned normally.
    _runs: int = field(default=0, init=False, repr=False)
    # time.time() taken after the last *successful* run (0.0 = never).
    _last: float = field(default=0.0, init=False, repr=False)
    # Number of calls to ``fn`` that raised an Exception.
    _errors: int = field(default=0, init=False, repr=False)
    # time.time() taken after the last attempt, successful or not. Scheduling
    # uses this so a failing job waits a full interval before it is retried
    # instead of firing again on every scheduler tick.
    _last_attempt: float = field(default=0.0, init=False, repr=False)

    def due(self, now: float) -> bool:
        """Return True if ``interval`` seconds have passed since the last attempt.

        With wall-clock ``now`` values, a job that was never attempted is due
        immediately.
        """
        return now - self._last_attempt >= self.interval

    def run(self) -> None:
        """Invoke ``fn`` once and record the outcome.

        ``Exception`` subclasses raised by ``fn`` are counted, logged with a
        traceback, and passed to ``on_error`` (if set); they are not
        propagated. An exception raised by ``on_error`` itself is logged and
        suppressed.
        """
        try:
            self.fn()
        except Exception as exc:
            self._last_attempt = time.time()
            self._errors += 1
            logger.error("Job '%s' error #%d: %s", self.name, self._errors, exc,
                         exc_info=True)
            if self.on_error is not None:
                try:
                    self.on_error(exc)
                except Exception:
                    # Best-effort: a broken handler must not kill the caller,
                    # but it should not fail silently either.
                    logger.exception("Job '%s' on_error handler failed",
                                     self.name)
        else:
            self._last = self._last_attempt = time.time()
            self._runs += 1
            logger.debug("Job '%s' run #%d OK", self.name, self._runs)

    def exhausted(self) -> bool:
        """Return True once ``max_runs`` (> 0) successful runs have completed."""
        return self.max_runs > 0 and self._runs >= self.max_runs

    def stats(self) -> dict:
        """Return a snapshot of the counters.

        ``last`` is the time.time() value recorded after the last successful
        run (0.0 if the job has never succeeded).
        """
        return {"name": self.name, "runs": self._runs,
                "errors": self._errors, "last": self._last}
|
|
|
|
class Scheduler:
    """Run registered :class:`Job` objects from one background daemon thread.

    Every ``tick`` seconds the worker scans the job list, runs each job whose
    :meth:`Job.due` returns True (sequentially, outside the lock), and drops
    jobs that have become exhausted.
    """

    def __init__(self, tick: float = 1.0) -> None:
        """Create an idle scheduler.

        Args:
            tick: Seconds the worker waits between scans of the job list.
        """
        self._jobs: list[Job] = []
        self._tick = tick
        self._running = False
        # Guards ``_jobs`` and the start/stop state (_running, _thread,
        # _stop_event). Non-reentrant: never call add()/remove() while held.
        self._lock = threading.Lock()
        self._thread: Optional[threading.Thread] = None
        # Set to tell the current worker to exit. start() creates a fresh
        # Event for each worker so a thread left over from a previous run
        # cannot be revived by a quick stop()/start() cycle.
        self._stop_event = threading.Event()

    def add(self, name: str, fn: Callable, interval: float,
            max_runs: int = 0, on_error: Optional[Callable] = None) -> Job:
        """Register a new job and return it.

        Names need not be unique; :meth:`remove` removes every job with the
        given name.
        """
        job = Job(name=name, fn=fn, interval=interval,
                  max_runs=max_runs, on_error=on_error)
        with self._lock:
            self._jobs.append(job)
        logger.info("Scheduled '%s' every %.1fs", name, interval)
        return job

    def remove(self, name: str) -> bool:
        """Remove all jobs named ``name``; return True if any were removed."""
        with self._lock:
            before = len(self._jobs)
            self._jobs = [j for j in self._jobs if j.name != name]
            return len(self._jobs) < before

    def _tick_once(self) -> None:
        """Run every due job once, then drop the jobs that became exhausted.

        Jobs run on a snapshot taken under the lock, with the lock released,
        so a job may itself call add(), remove() or stop().
        """
        now = time.time()
        with self._lock:
            jobs = list(self._jobs)
        exhausted: list[Job] = []
        for job in jobs:
            if job.due(now):
                job.run()
                if job.exhausted():
                    exhausted.append(job)
        if exhausted:
            # Remove by identity, not by name: names need not be unique, and
            # a job re-added under the same name during this tick must stay.
            # (Job is an eq-dataclass and therefore unhashable, hence id().)
            gone = {id(job) for job in exhausted}
            with self._lock:
                self._jobs = [j for j in self._jobs if id(j) not in gone]

    def _loop(self, stop_event: Optional[threading.Event] = None) -> None:
        """Worker body: tick until this worker's stop event is set."""
        if stop_event is None:
            stop_event = self._stop_event
        while self._running and not stop_event.is_set():
            self._tick_once()
            # Event.wait instead of time.sleep so stop() is honoured at once
            # rather than after up to a full tick.
            stop_event.wait(self._tick)

    def start(self) -> None:
        """Start the background worker; no-op if already running."""
        with self._lock:
            if self._running:
                return
            self._running = True
            self._stop_event = stop_event = threading.Event()
            self._thread = threading.Thread(target=self._loop,
                                            args=(stop_event,),
                                            daemon=True, name="Scheduler")
            self._thread.start()
        logger.info("Scheduler started (tick=%.1fs)", self._tick)

    def stop(self, timeout: float = 5.0) -> None:
        """Signal the worker to exit and wait up to ``timeout`` seconds.

        May be called from inside a job; the worker then does not try to
        join itself (Thread.join on the current thread raises RuntimeError).
        """
        with self._lock:
            self._running = False
            self._stop_event.set()
            thread = self._thread
        # Join outside the lock: the worker needs the lock to finish a tick.
        if thread is not None and thread is not threading.current_thread():
            thread.join(timeout)
            if thread.is_alive():
                logger.warning("Scheduler thread still running after %.1fs",
                               timeout)
                return
        logger.info("Scheduler stopped")

    def all_stats(self) -> list[dict]:
        """Return :meth:`Job.stats` for every currently registered job."""
        with self._lock:
            return [j.stats() for j in self._jobs]
|
|