File size: 5,776 Bytes
5f3e9f5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
"""Shared workflow event publishing and run lifecycle helpers."""
from __future__ import annotations

import queue
import threading
from typing import Any

from core import run_manager


# Registry of live subscriber queues, keyed by run_id. publish_run_event puts
# each event for a run onto every queue registered here for that run.
_SUBSCRIBERS: dict[str, list[queue.Queue]] = {}
# Guards every read/write of _SUBSCRIBERS. An RLock, so a thread that already
# holds it may acquire it again without deadlocking.
_LOCK = threading.RLock()


def publish_run_event(run_id: str, event: dict[str, Any]) -> None:
    """Deliver *event* to every queue currently subscribed to *run_id*.

    A shallow copy of *event* is delivered (the caller's dict is left as is),
    tagged with ``run_id`` unless it already carries one. All subscribers
    receive the same copy. The subscriber list is snapshotted under the lock
    and the ``put`` calls happen after the lock is released.
    """
    message = dict(event)
    if "run_id" not in message:
        message["run_id"] = run_id

    with _LOCK:
        targets = tuple(_SUBSCRIBERS.get(run_id, ()))

    for target in targets:
        target.put(message)


def _replay_event(run_id: str, run: dict[str, Any]) -> dict[str, Any]:
    """Build one synthetic event describing *run*'s current state.

    *run* is the record returned by ``run_manager.get_run``; only ``.get`` is
    used on it. Terminal statuses (completed/failed/cancelled) map to the same
    event types WorkflowContext emits for them; any other status yields a
    ``started`` (for ``running``) or ``<status>`` snapshot event.
    """
    status = str(run.get("status") or "queued")
    operation_id = run.get("operation_id") or run_id
    message = run.get("message")

    if status == "completed":
        final_message = message or "Process completed"
        return {
            "type": "complete",
            "run_id": run_id,
            "operation_id": operation_id,
            "message": final_message,
            # Same shape as the live event from WorkflowContext.complete, which
            # also carries the message inside ``data``.
            "data": {"success": True, "message": final_message, **(run.get("outputs") or {})},
        }
    if status == "failed":
        return {
            "type": "error",
            "run_id": run_id,
            "operation_id": operation_id,
            "message": message or "Process failed",
        }
    if status == "cancelled":
        return {
            "type": "cancelled",
            "run_id": run_id,
            "operation_id": operation_id,
            "message": message or "Operation cancelled",
        }

    running = status == "running"
    return {
        "type": "started" if running else status,
        "run_id": run_id,
        "operation_id": operation_id,
        "stage": run.get("stage"),
        "progress": run.get("progress", 0),
        # Fallback text must agree with the event type: a running run is not queued.
        "message": message or ("Process started" if running else "Process queued"),
        "queue_position": run.get("queue_position"),
    }


def subscribe_run(run_id: str, replay: bool = True) -> queue.Queue:
    """Register a new event queue for *run_id*.

    Args:
        run_id: Run whose published events the caller wants to receive.
        replay: When true and ``run_manager.get_run`` finds the run, one
            synthetic event describing the run's current state is put on the
            queue, so a late subscriber learns where the run stands.

    Returns:
        The queue that ``publish_run_event`` will feed. Callers should hand it
        back to ``unsubscribe_run`` when done; it stays registered until then,
        even when the replayed state is terminal.
    """
    subscriber: queue.Queue = queue.Queue()
    with _LOCK:
        _SUBSCRIBERS.setdefault(run_id, []).append(subscriber)

    # NOTE(review): the queue is registered before the snapshot is read, so a
    # live event published in between can be enqueued ahead of (or duplicated
    # by) the replayed snapshot. Consumers should tolerate that.
    if replay:
        run = run_manager.get_run(run_id)
        if run:
            subscriber.put(_replay_event(run_id, run))
    return subscriber


def unsubscribe_run(run_id: str, subscriber: queue.Queue) -> None:
    """Detach *subscriber* from *run_id*; silently ignore unknown pairs.

    Once the last queue for a run is removed, the run's registry entry is
    dropped so the dict does not accumulate empty lists.
    """
    with _LOCK:
        queues = _SUBSCRIBERS.get(run_id) or []
        if subscriber not in queues:
            return
        queues.remove(subscriber)
        if not queues:
            del _SUBSCRIBERS[run_id]


class WorkflowCancelled(Exception):
    """Signals that a workflow's cancel event has been set.

    Raised by ``WorkflowContext.check_cancelled`` with the message
    ``"Operation cancelled"``.
    """


class WorkflowContext:
    """Convenience wrapper used by queued workflow jobs."""

    def __init__(self, run_id: str, operation_id: str, cancel_event: threading.Event):
        self.run_id = run_id
        self.operation_id = operation_id
        self.cancel_event = cancel_event

    def emit(self, event: dict[str, Any]) -> None:
        event.setdefault("run_id", self.run_id)
        event.setdefault("operation_id", self.operation_id)
        publish_run_event(self.run_id, event)

    def started(self, message: str = "Process started") -> None:
        run_manager.update_run(self.run_id, status="running", stage="running", message=message, progress=0)
        run_manager.add_event(self.run_id, event_type="started", stage="running", progress=0, message=message)
        self.emit({"type": "started", "message": message, "progress": 0})

    def progress(self, stage: str, progress: int, message: str, data: dict[str, Any] | None = None) -> None:
        run_manager.add_event(
            self.run_id,
            event_type="progress",
            stage=stage,
            progress=progress,
            message=message,
            data=data or {},
        )
        print(f"[{self.operation_id}] {stage} {progress}% - {message}", flush=True)
        self.emit({
            "type": "progress",
            "stage": stage,
            "progress": progress,
            "message": message,
            **(data or {}),
        })

    def output(self, key: str, value: Any) -> None:
        run_manager.attach_output(self.run_id, key, value)

    def metrics(self, values: dict[str, Any]) -> None:
        run_manager.update_metrics(self.run_id, values)

    def complete(
        self,
        message: str,
        outputs: dict[str, Any] | None = None,
        metrics: dict[str, Any] | None = None,
    ) -> None:
        run_manager.finish_run(
            self.run_id,
            status="completed",
            message=message,
            progress=100,
            outputs=outputs,
            metrics=metrics,
        )
        self.emit({
            "type": "complete",
            "message": message,
            "data": {"success": True, "message": message, **(outputs or {})},
        })

    def fail(self, message: str, progress: int | None = None) -> None:
        run_manager.finish_run(self.run_id, status="failed", message=message, progress=progress)
        self.emit({"type": "error", "message": message})

    def cancel(self, message: str = "Operation cancelled", outputs: dict[str, Any] | None = None) -> None:
        run_manager.finish_run(self.run_id, status="cancelled", message=message, progress=0, outputs=outputs)
        self.emit({"type": "cancelled", "message": message})

    def check_cancelled(self) -> None:
        if self.cancel_event.is_set():
            raise WorkflowCancelled("Operation cancelled")