"""Aubm background worker.

Claims queued tasks, runs the assigned agent for each one, and reports
worker heartbeats.
"""

import asyncio
import logging
import os
import signal
import socket
import uuid
from services.task_queue import TaskQueueService
from services.supabase_service import supabase
from services.agent_runner_service import AgentRunnerService
from services.budget_service import BudgetExceededError
from services.config import settings

# Configure the root logger at INFO level when this module is loaded, and
# create the logger (named "worker") used by everything in this file.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("worker")

class AubmWorker:
    def __init__(self):
        self.running = True
        suffix = uuid.uuid4().hex[:8]
        self.worker_id = os.getenv("AUBM_WORKER_ID") or f"{socket.gethostname()}-{suffix}"
        self.lease_seconds = int(os.getenv("AUBM_WORKER_LEASE_SECONDS", "300"))
        self.max_attempts = int(os.getenv("AUBM_WORKER_MAX_ATTEMPTS", "3"))
        self.retry_delay_seconds = int(os.getenv("AUBM_WORKER_RETRY_DELAY_SECONDS", "30"))
        self.processed_count = 0
        self.failed_count = 0

    async def heartbeat(self, status: str, current_task_id: str | None = None):
        if not settings.TASK_QUEUE_HEARTBEAT_ENABLED:
            return
            
        await TaskQueueService.heartbeat(
            self.worker_id,
            status=status,
            current_task_id=current_task_id,
            processed_count=self.processed_count,
            failed_count=self.failed_count,
            metadata={
                "lease_seconds": self.lease_seconds,
                "max_attempts": self.max_attempts,
                "retry_delay_seconds": self.retry_delay_seconds,
            },
        )

    async def _heartbeat_loop(self):
        """Separate loop to send heartbeat at a fixed interval."""
        while self.running:
            try:
                # We use a longer interval for regular heartbeats
                await self.heartbeat("idle")
            except Exception as e:
                logger.warning("Background heartbeat failed: %s", e)
            await asyncio.sleep(30) # Regular heartbeat every 30 seconds

    async def start(self):
        mode_suffix = "" if settings.TASK_QUEUE_HEARTBEAT_ENABLED else " (HEARTBEAT DISABLED)"
        logger.info(f"Aubm Background Worker started{mode_suffix}: {self.worker_id}")
        
        # Start the background heartbeat task if enabled
        heartbeat_task = None
        if settings.TASK_QUEUE_HEARTBEAT_ENABLED:
            heartbeat_task = asyncio.create_task(self._heartbeat_loop())
        
        try:
            while self.running:
                task = await TaskQueueService.claim_next_queued_task(
                    self.worker_id,
                    lease_seconds=self.lease_seconds,
                    max_attempts=self.max_attempts,
                )
                
                if task:
                    task_id = task['id']
                    logger.info("Processing task: %s", task_id)
                    await self.heartbeat("processing", task_id)
                    
                    try:
                        # Fetch agent data for this task
                        agent_id = task.get("assigned_agent_id")
                        if not agent_id:
                            raise RuntimeError("No agent assigned to queued task")

                        agent_res = supabase.table("agents").select("*").eq("id", agent_id).single().execute()
                        if agent_res.data:
                            await AgentRunnerService.execute_agent_logic(task, agent_res.data)
                            await TaskQueueService.clear_lease(task_id)
                            self.processed_count += 1
                            await self.heartbeat("idle")
                            logger.info("Task %s completed successfully.", task_id)
                        else:
                            raise RuntimeError(f"Assigned agent not found: {agent_id}")
                    except BudgetExceededError as e:
                        logger.warning("Budget blocked queued task %s: %s", task_id, e)
                        self.failed_count += 1
                        await TaskQueueService.mark_failed(task_id, str(e))
                        await self.heartbeat("error")
                    except Exception as e:
                        logger.error("Failed to process task %s: %s", task_id, e)
                        self.failed_count += 1
                        await TaskQueueService.mark_attempt_failed(
                            task,
                            str(e),
                            self.max_attempts,
                            self.retry_delay_seconds,
                        )
                else:
                    # No tasks, sleep for an hour as requested (3600s)
                    # This prevents filling the DB with poll/heartbeat logs
                    await asyncio.sleep(3600)
        finally:
            if heartbeat_task:
                heartbeat_task.cancel()
            await self.heartbeat("stopping")

    def stop(self):
        logger.info("Stopping worker...")
        self.running = False

async def main() -> None:
    """Create a worker, wire up shutdown signals, and run it until stopped."""
    worker = AubmWorker()

    # Handle shutdown signals
    loop = asyncio.get_running_loop()
    for sig in (signal.SIGINT, signal.SIGTERM):
        try:
            loop.add_signal_handler(sig, worker.stop)
        except NotImplementedError:
            # The running event loop does not support add_signal_handler
            # (e.g. on Windows); fall back to a plain signal handler.
            signal.signal(sig, lambda *_: worker.stop())

    # start() sends the final "stopping" heartbeat from its own finally
    # block, so it is not repeated here.
    await worker.start()

# Run the worker only when this file is executed as a script, not on import.
if __name__ == "__main__":
    asyncio.run(main())