raazkumar commited on
Commit
76b1148
·
verified ·
1 Parent(s): 1279000

Upload production/worker.py

Browse files
Files changed (1) hide show
  1. production/worker.py +135 -0
production/worker.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""Background worker for ml-intern production system."""
import asyncio
import json
import logging
import os
import signal
import time
from datetime import datetime, timedelta

import aioredis
import asyncpg

# Module logger; handlers/level are configured by the __main__ block below.
logger = logging.getLogger("ml_intern.worker")
# Connection targets, overridable via environment variables.
REDIS_URL = os.environ.get("REDIS_URL", "redis://localhost:6379")
DATABASE_URL = os.environ.get("DATABASE_URL", "postgresql://localhost/ml_intern")
# Seconds between expired-session cleanup sweeps (default: hourly).
CLEANUP_INTERVAL_SECONDS = int(os.environ.get("CLEANUP_INTERVAL_SECONDS", "3600"))
# Seconds between budget-threshold checks (default: every 5 minutes).
BUDGET_CHECK_INTERVAL_SECONDS = int(os.environ.get("BUDGET_CHECK_INTERVAL_SECONDS", "300"))
# NOTE(review): json, datetime and timedelta appear unreferenced in this file —
# confirm against other modules before removing.
class Worker:
    """Background maintenance worker.

    ``start()`` opens a Redis connection and a Postgres pool, then runs four
    periodic loops (session cleanup, budget alerts, cost aggregation,
    analytics logging) until the ``running`` flag is cleared — see the signal
    handlers installed in ``main``.
    """

    def __init__(self):
        # Connections are created in start(); left un-annotated so nothing is
        # evaluated before the client libraries are actually needed.
        self.redis = None  # aioredis.Redis once start() has run
        self.db = None  # asyncpg.Pool once start() has run
        self.running = True  # cleared by main()'s signal handler to stop all loops

    async def start(self):
        """Open connections and run all periodic loops until shutdown."""
        self.redis = aioredis.from_url(REDIS_URL, decode_responses=True)
        await self.redis.ping()  # fail fast if Redis is unreachable
        self.db = await asyncpg.create_pool(DATABASE_URL, min_size=2, max_size=5, command_timeout=60)
        logger.info("Worker started")
        tasks = [
            asyncio.create_task(self._cleanup_expired_sessions()),
            asyncio.create_task(self._check_budget_alerts()),
            asyncio.create_task(self._aggregate_costs()),
            asyncio.create_task(self._refresh_analytics()),
        ]
        # return_exceptions=True keeps one failed loop from tearing down the rest.
        await asyncio.gather(*tasks, return_exceptions=True)

    async def _interruptible_sleep(self, seconds):
        """Sleep up to *seconds*, waking within ~1s of ``running`` being cleared.

        FIX: the loops previously blocked in a plain asyncio.sleep(...), so a
        SIGTERM could take up to CLEANUP_INTERVAL_SECONDS to be honoured.
        """
        deadline = time.monotonic() + seconds
        while self.running:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                break
            await asyncio.sleep(min(1.0, remaining))

    async def _cleanup_expired_sessions(self):
        """Archive expired active sessions and drop their Redis state."""
        while self.running:
            try:
                expired = await self.db.fetch(
                    "SELECT id, tenant_id FROM sessions WHERE expires_at < NOW() AND status = 'active' LIMIT 1000"
                )
                if expired:
                    # Keep ids in their native DB type for the SQL parameters;
                    # they are stringified only when built into Redis key names.
                    session_ids = [s["id"] for s in expired]
                    await self.db.execute(
                        "INSERT INTO archived_sessions (id, tenant_id, data, archived_at) SELECT id, tenant_id, jsonb_build_object('metadata', metadata, 'context', context_json, 'spent_usd', spent_usd, 'budget_usd', budget_usd), NOW() FROM sessions WHERE id = ANY($1)",
                        session_ids,
                    )
                    await self.db.execute("UPDATE sessions SET status = 'expired' WHERE id = ANY($1)", session_ids)
                    # BUG FIX: Redis DEL takes literal key names, never glob
                    # patterns — pipe.delete("ratelimit:*:<sid>") deleted
                    # nothing. SCAN for matching keys first, then delete them
                    # in a single pipeline round-trip.
                    pipe = self.redis.pipeline()
                    for sid in session_ids:
                        for pattern in (f"ratelimit:*:{sid}", f"session:{sid}:*"):
                            async for key in self.redis.scan_iter(match=pattern):
                                pipe.delete(key)
                    await pipe.execute()
                    logger.info("Cleaned up %d expired sessions", len(expired))
            except Exception:
                logger.exception("Error in cleanup")
            await self._interruptible_sleep(CLEANUP_INTERVAL_SECONDS)

    async def _check_budget_alerts(self):
        """Record one-shot budget events at the 80% threshold and on overrun.

        Both queries filter on metadata flags, and the flag is set in the same
        pass the event row is written, so each alert fires at most once per
        session. Sessions that jump straight past 95% skip the 80% alert and
        are caught by the exceeded check.
        """
        while self.running:
            try:
                nearing_limit = await self.db.fetch(
                    """SELECT id, tenant_id, budget_usd, spent_usd, (spent_usd / budget_usd) * 100 as pct_used
                    FROM sessions WHERE status = 'active' AND spent_usd > budget_usd * 0.8 AND spent_usd < budget_usd * 0.95
                    AND NOT (metadata ? 'budget_alert_80_sent')"""
                )
                for session in nearing_limit:
                    await self.db.execute(
                        "INSERT INTO budget_events (session_id, event_type, threshold_percent, spent_before, spent_after, budget) VALUES ($1, 'threshold_reached', $2, $3, $3, $4)",
                        session["id"], 80, session["spent_usd"], session["budget_usd"],
                    )
                    await self.db.execute(
                        "UPDATE sessions SET metadata = jsonb_set(metadata, '{budget_alert_80_sent}', 'true') WHERE id = $1",
                        session["id"],
                    )
                    logger.warning("Budget alert: session %s at %.1f%% of budget", session["id"], session["pct_used"])

                exceeded = await self.db.fetch(
                    """SELECT id, tenant_id, budget_usd, spent_usd FROM sessions
                    WHERE status = 'active' AND spent_usd >= budget_usd AND NOT (metadata ? 'budget_exceeded')"""
                )
                for session in exceeded:
                    await self.db.execute(
                        "INSERT INTO budget_events (session_id, event_type, threshold_percent, spent_before, spent_after, budget) VALUES ($1, 'exceeded', 100, $2, $2, $3)",
                        session["id"], session["spent_usd"], session["budget_usd"],
                    )
                    await self.db.execute(
                        "UPDATE sessions SET metadata = jsonb_set(metadata, '{budget_exceeded}', 'true') WHERE id = $1",
                        session["id"],
                    )
                    logger.error(
                        "Budget exceeded: session %s spent $%s of $%s",
                        session["id"], session["spent_usd"], session["budget_usd"],
                    )
            except Exception:
                logger.exception("Error checking budgets")
            await self._interruptible_sleep(BUDGET_CHECK_INTERVAL_SECONDS)

    async def _aggregate_costs(self):
        """Hourly: refresh the daily cost rollup via a server-side function."""
        while self.running:
            try:
                await self.db.execute("SELECT refresh_daily_costs()")
                logger.info("Refreshed daily cost analytics")
            except Exception:
                logger.exception("Error aggregating costs")
            await self._interruptible_sleep(3600)

    async def _refresh_analytics(self):
        """Every 5 minutes: log top models over 24h and the active-session count."""
        while self.running:
            try:
                top_models = await self.db.fetch(
                    """SELECT provider, model, COUNT(*) as requests, SUM(total_tokens) as tokens, AVG(latency_ms)::int as avg_latency
                    FROM requests WHERE created_at > NOW() - INTERVAL '24 hours'
                    GROUP BY provider, model ORDER BY requests DESC LIMIT 10"""
                )
                logger.info("Top models (24h):")
                for row in top_models:
                    logger.info(
                        "  %s/%s: %s requests, %s tokens, %sms avg latency",
                        row["provider"], row["model"], row["requests"], row["tokens"], row["avg_latency"],
                    )
                active = await self.db.fetchval("SELECT COUNT(*) FROM sessions WHERE status = 'active'")
                logger.info("Active sessions: %s", active)
            except Exception:
                logger.exception("Error in analytics")
            await self._interruptible_sleep(300)
async def main():
    """Create a Worker, install shutdown signal handlers, and run it.

    The handlers flip ``worker.running`` so the periodic loops drain and
    ``start()`` returns instead of the process being killed mid-query.
    """
    worker = Worker()
    # FIX: get_event_loop() is deprecated inside a coroutine (3.10+);
    # get_running_loop() is the supported way to reach the active loop.
    loop = asyncio.get_running_loop()
    for sig in (signal.SIGTERM, signal.SIGINT):
        try:
            loop.add_signal_handler(sig, lambda: setattr(worker, "running", False))
        except NotImplementedError:
            # add_signal_handler is unavailable on Windows event loops;
            # fall back to default KeyboardInterrupt behaviour there.
            pass
    await worker.start()
if __name__ == "__main__":
    # Configure root logging only when executed as a script, then hand
    # control to the asyncio runner for the worker's lifetime.
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    )
    asyncio.run(main())