"""Background worker for ml-intern production system."""
import asyncio
import json
import logging
import os
import signal
import time
from datetime import datetime, timedelta

import aioredis
import asyncpg

# Module-level logger; handlers/level are configured by the entry point below.
logger = logging.getLogger("ml_intern.worker")
# Connection targets, overridable via environment for deployment.
REDIS_URL = os.environ.get("REDIS_URL", "redis://localhost:6379")
DATABASE_URL = os.environ.get("DATABASE_URL", "postgresql://localhost/ml_intern")
# How often expired sessions are swept (seconds; default hourly).
CLEANUP_INTERVAL_SECONDS = int(os.environ.get("CLEANUP_INTERVAL_SECONDS", "3600"))
# How often session budgets are checked for 80%/exceeded alerts (seconds; default 5 min).
BUDGET_CHECK_INTERVAL_SECONDS = int(os.environ.get("BUDGET_CHECK_INTERVAL_SECONDS", "300"))

class Worker:
    """Background worker running four periodic loops: expired-session cleanup,
    budget-threshold alerting, daily cost aggregation, and usage analytics.

    Call :meth:`start` from an event loop; set ``running = False`` (e.g. from a
    signal handler) to request a graceful stop. Loops wake within ~1 second of
    the flag flipping.
    """

    def __init__(self):
        # Connections are established lazily in start(); None until then.
        self.redis: aioredis.Redis | None = None
        self.db: asyncpg.Pool | None = None
        self.running = True

    async def start(self):
        """Connect to Redis and Postgres, then run all loops until stopped."""
        self.redis = aioredis.from_url(REDIS_URL, decode_responses=True)
        await self.redis.ping()  # fail fast if Redis is unreachable
        self.db = await asyncpg.create_pool(DATABASE_URL, min_size=2, max_size=5, command_timeout=60)
        logger.info("Worker started")
        tasks = [
            asyncio.create_task(self._cleanup_expired_sessions()),
            asyncio.create_task(self._check_budget_alerts()),
            asyncio.create_task(self._aggregate_costs()),
            asyncio.create_task(self._refresh_analytics()),
        ]
        # return_exceptions=True keeps one crashed loop from cancelling the rest.
        await asyncio.gather(*tasks, return_exceptions=True)

    async def _sleep(self, seconds: float):
        """Sleep up to ``seconds``, waking within ~1s once ``self.running`` is False.

        The original loops slept the full interval (up to an hour), which made
        SIGTERM shutdown take that long; this keeps shutdown responsive.
        """
        deadline = time.monotonic() + seconds
        while self.running and time.monotonic() < deadline:
            await asyncio.sleep(min(1.0, deadline - time.monotonic()))

    async def _cleanup_expired_sessions(self):
        """Archive sessions past their expiry, mark them expired, and purge their Redis keys."""
        while self.running:
            try:
                expired = await self.db.fetch(
                    "SELECT id, tenant_id FROM sessions WHERE expires_at < NOW() AND status = 'active' LIMIT 1000"
                )
                if expired:
                    session_ids = [str(s["id"]) for s in expired]
                    # Archive + status flip must be atomic: a crash between the two
                    # statements would leave sessions archived but still 'active',
                    # so they would be archived again on the next sweep.
                    async with self.db.acquire() as conn:
                        async with conn.transaction():
                            await conn.execute(
                                "INSERT INTO archived_sessions (id, tenant_id, data, archived_at) SELECT id, tenant_id, jsonb_build_object('metadata', metadata, 'context', context_json, 'spent_usd', spent_usd, 'budget_usd', budget_usd), NOW() FROM sessions WHERE id = ANY($1)",
                                session_ids,
                            )
                            await conn.execute("UPDATE sessions SET status = 'expired' WHERE id = ANY($1)", session_ids)
                    # BUG FIX: Redis DEL takes literal key names and does NOT expand
                    # glob patterns, so pipe.delete("ratelimit:*:<sid>") deleted
                    # nothing. Resolve each pattern with SCAN, then delete the
                    # concrete keys in one call.
                    stale_keys = []
                    for sid in session_ids:
                        for pattern in (f"ratelimit:*:{sid}", f"session:{sid}:*"):
                            async for key in self.redis.scan_iter(match=pattern):
                                stale_keys.append(key)
                    if stale_keys:
                        await self.redis.delete(*stale_keys)
                    logger.info("Cleaned up %d expired sessions", len(expired))
            except Exception:
                # Best-effort loop: log and retry next interval rather than die.
                logger.exception("Error in cleanup")
            await self._sleep(CLEANUP_INTERVAL_SECONDS)

    async def _check_budget_alerts(self):
        """Emit one-shot budget events at the 80% threshold and on budget exhaustion.

        One-shot behavior is enforced via metadata flags ('budget_alert_80_sent',
        'budget_exceeded') set in the same pass that records the event.
        """
        while self.running:
            try:
                # Sessions between 80% and 95% of budget that have not been alerted yet.
                nearing_limit = await self.db.fetch(
                    """SELECT id, tenant_id, budget_usd, spent_usd, (spent_usd / budget_usd) * 100 as pct_used
                       FROM sessions WHERE status = 'active' AND spent_usd > budget_usd * 0.8 AND spent_usd < budget_usd * 0.95
                       AND NOT (metadata ? 'budget_alert_80_sent')"""
                )
                for session in nearing_limit:
                    session_id = str(session["id"])
                    pct = session["pct_used"]
                    await self.db.execute(
                        "INSERT INTO budget_events (session_id, event_type, threshold_percent, spent_before, spent_after, budget) VALUES ($1, 'threshold_reached', $2, $3, $3, $4)",
                        session["id"], 80, session["spent_usd"], session["budget_usd"],
                    )
                    # Flag the session so the alert fires only once.
                    await self.db.execute(
                        "UPDATE sessions SET metadata = jsonb_set(metadata, '{budget_alert_80_sent}', 'true') WHERE id = $1",
                        session["id"],
                    )
                    logger.warning("Budget alert: session %s at %.1f%% of budget", session_id, pct)

                # Sessions that have spent their full budget and are not yet flagged.
                exceeded = await self.db.fetch(
                    """SELECT id, tenant_id, budget_usd, spent_usd FROM sessions
                       WHERE status = 'active' AND spent_usd >= budget_usd AND NOT (metadata ? 'budget_exceeded')"""
                )
                for session in exceeded:
                    await self.db.execute(
                        "INSERT INTO budget_events (session_id, event_type, threshold_percent, spent_before, spent_after, budget) VALUES ($1, 'exceeded', 100, $2, $2, $3)",
                        session["id"], session["spent_usd"], session["budget_usd"],
                    )
                    await self.db.execute(
                        "UPDATE sessions SET metadata = jsonb_set(metadata, '{budget_exceeded}', 'true') WHERE id = $1",
                        session["id"],
                    )
                    logger.error(
                        "Budget exceeded: session %s spent $%s of $%s",
                        session["id"], session["spent_usd"], session["budget_usd"],
                    )
            except Exception:
                logger.exception("Error checking budgets")
            await self._sleep(BUDGET_CHECK_INTERVAL_SECONDS)

    async def _aggregate_costs(self):
        """Hourly: refresh daily cost rollups via the DB-side refresh function."""
        while self.running:
            try:
                # Aggregation logic lives in the database (refresh_daily_costs()).
                await self.db.execute("SELECT refresh_daily_costs()")
                logger.info("Refreshed daily cost analytics")
            except Exception:
                logger.exception("Error aggregating costs")
            await self._sleep(3600)

    async def _refresh_analytics(self):
        """Every 5 minutes: log top models over the last 24h and the active-session count."""
        while self.running:
            try:
                top_models = await self.db.fetch(
                    """SELECT provider, model, COUNT(*) as requests, SUM(total_tokens) as tokens, AVG(latency_ms)::int as avg_latency
                       FROM requests WHERE created_at > NOW() - INTERVAL '24 hours'
                       GROUP BY provider, model ORDER BY requests DESC LIMIT 10"""
                )
                logger.info("Top models (24h):")
                for row in top_models:
                    logger.info(
                        "  %s/%s: %s requests, %s tokens, %sms avg latency",
                        row["provider"], row["model"], row["requests"], row["tokens"], row["avg_latency"],
                    )
                active = await self.db.fetchval("SELECT COUNT(*) FROM sessions WHERE status = 'active'")
                logger.info("Active sessions: %s", active)
            except Exception:
                logger.exception("Error in analytics")
            await self._sleep(300)

async def main():
    """Entry point: install shutdown signal handlers, then run the worker to completion.

    SIGTERM/SIGINT flip ``worker.running`` to False so the background loops
    exit at their next wake-up instead of being killed mid-statement.
    """
    worker = Worker()
    # get_event_loop() is deprecated inside a coroutine since Python 3.10;
    # get_running_loop() is the correct way to reach the current loop here.
    loop = asyncio.get_running_loop()
    for sig in (signal.SIGTERM, signal.SIGINT):
        loop.add_signal_handler(sig, lambda: setattr(worker, 'running', False))
    await worker.start()

if __name__ == "__main__":
    # Configure root logging before the event loop starts so every worker loop
    # inherits the same format and level.
    logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)s | %(name)s | %(message)s")
    asyncio.run(main())