Upload production/worker.py
Browse files- production/worker.py +135 -0
production/worker.py
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Background worker for ml-intern production system."""
|
| 2 |
+
import asyncio
|
| 3 |
+
import json
|
| 4 |
+
import logging
|
| 5 |
+
import os
|
| 6 |
+
import signal
|
| 7 |
+
import time
|
| 8 |
+
from datetime import datetime, timedelta
|
| 9 |
+
|
| 10 |
+
import aioredis
|
| 11 |
+
import asyncpg
|
| 12 |
+
|
| 13 |
+
logger = logging.getLogger("ml_intern.worker")
|
| 14 |
+
REDIS_URL = os.environ.get("REDIS_URL", "redis://localhost:6379")
|
| 15 |
+
DATABASE_URL = os.environ.get("DATABASE_URL", "postgresql://localhost/ml_intern")
|
| 16 |
+
CLEANUP_INTERVAL_SECONDS = int(os.environ.get("CLEANUP_INTERVAL_SECONDS", "3600"))
|
| 17 |
+
BUDGET_CHECK_INTERVAL_SECONDS = int(os.environ.get("BUDGET_CHECK_INTERVAL_SECONDS", "300"))
|
| 18 |
+
|
| 19 |
+
class Worker:
    """Background maintenance daemon for the ml-intern production system.

    Runs four periodic loops (session cleanup, budget alerts, cost
    aggregation, analytics logging) until ``running`` is set to ``False``,
    e.g. from a signal handler. Setting ``running = False`` also wakes any
    loop that is waiting out its interval, so shutdown is prompt instead of
    blocking for up to an hour inside a sleep.
    """

    def __init__(self):
        # Connections are created lazily in start(). The annotations are
        # strings so constructing a Worker never evaluates aioredis/asyncpg
        # attributes at runtime, and they honestly admit the pre-start None.
        self.redis: "aioredis.Redis | None" = None
        self.db: "asyncpg.Pool | None" = None
        # Cooperative stop flag, exposed through the `running` property so
        # the original `worker.running = False` contract is preserved.
        self._stopped = asyncio.Event()

    @property
    def running(self) -> bool:
        """True while the maintenance loops should keep iterating."""
        return not self._stopped.is_set()

    @running.setter
    def running(self, value: bool) -> None:
        if value:
            self._stopped.clear()
        else:
            # Wakes every pending _interruptible_sleep() immediately.
            self._stopped.set()

    async def _interruptible_sleep(self, seconds: float) -> None:
        """Sleep up to `seconds`, returning early once the worker is stopped."""
        try:
            await asyncio.wait_for(self._stopped.wait(), timeout=seconds)
        except asyncio.TimeoutError:
            pass

    async def start(self):
        """Connect to Redis and Postgres, then run all loops until stopped."""
        self.redis = aioredis.from_url(REDIS_URL, decode_responses=True)
        await self.redis.ping()  # fail fast if Redis is unreachable
        self.db = await asyncpg.create_pool(DATABASE_URL, min_size=2, max_size=5, command_timeout=60)
        logger.info("Worker started")
        tasks = [
            asyncio.create_task(self._cleanup_expired_sessions()),
            asyncio.create_task(self._check_budget_alerts()),
            asyncio.create_task(self._aggregate_costs()),
            asyncio.create_task(self._refresh_analytics()),
        ]
        # return_exceptions=True: one crashed loop must not cancel the rest.
        await asyncio.gather(*tasks, return_exceptions=True)

    async def _cleanup_expired_sessions(self):
        """Archive sessions past their deadline and purge their Redis keys."""
        while self.running:
            try:
                expired = await self.db.fetch(
                    "SELECT id, tenant_id FROM sessions WHERE expires_at < NOW() AND status = 'active' LIMIT 1000"
                )
                if expired:
                    # NOTE(review): assumes sessions.id values bind correctly
                    # as strings through `id = ANY($1)` — confirm column type.
                    session_ids = [str(s["id"]) for s in expired]
                    # Archive + status flip in one transaction: a crash between
                    # the two statements would otherwise re-archive the same
                    # rows next cycle (primary-key conflict in archived_sessions).
                    async with self.db.acquire() as conn:
                        async with conn.transaction():
                            await conn.execute(
                                "INSERT INTO archived_sessions (id, tenant_id, data, archived_at) SELECT id, tenant_id, jsonb_build_object('metadata', metadata, 'context', context_json, 'spent_usd', spent_usd, 'budget_usd', budget_usd), NOW() FROM sessions WHERE id = ANY($1)",
                                session_ids,
                            )
                            await conn.execute("UPDATE sessions SET status = 'expired' WHERE id = ANY($1)", session_ids)
                    # Redis DEL only accepts literal key names — a glob such as
                    # "ratelimit:*:<sid>" never matches anything, so the
                    # patterns must be expanded with SCAN before deleting.
                    stale_keys = []
                    for sid in session_ids:
                        for pattern in (f"ratelimit:*:{sid}", f"session:{sid}:*"):
                            async for key in self.redis.scan_iter(match=pattern):
                                stale_keys.append(key)
                    if stale_keys:
                        await self.redis.delete(*stale_keys)
                    logger.info("Cleaned up %d expired sessions", len(expired))
            except Exception:
                # Transient DB/Redis failures: log and retry next interval.
                logger.exception("Error in cleanup")
            await self._interruptible_sleep(CLEANUP_INTERVAL_SECONDS)

    async def _check_budget_alerts(self):
        """Emit budget_events rows for sessions nearing or over their budget."""
        while self.running:
            try:
                # NOTE(review): a session that jumps from <80% straight past
                # 95% between checks never receives the 80% alert (upper bound
                # below excludes it) — confirm that is intended.
                nearing_limit = await self.db.fetch(
                    """SELECT id, tenant_id, budget_usd, spent_usd, (spent_usd / budget_usd) * 100 as pct_used
                    FROM sessions WHERE status = 'active' AND spent_usd > budget_usd * 0.8 AND spent_usd < budget_usd * 0.95
                    AND NOT (metadata ? 'budget_alert_80_sent')"""
                )
                for session in nearing_limit:
                    session_id = str(session["id"])
                    pct = session["pct_used"]
                    await self.db.execute(
                        "INSERT INTO budget_events (session_id, event_type, threshold_percent, spent_before, spent_after, budget) VALUES ($1, 'threshold_reached', $2, $3, $3, $4)",
                        session["id"], 80, session["spent_usd"], session["budget_usd"],
                    )
                    # Metadata marker keeps the alert idempotent across cycles.
                    await self.db.execute(
                        "UPDATE sessions SET metadata = jsonb_set(metadata, '{budget_alert_80_sent}', 'true') WHERE id = $1",
                        session["id"],
                    )
                    logger.warning("Budget alert: session %s at %.1f%% of budget", session_id, pct)

                exceeded = await self.db.fetch(
                    """SELECT id, tenant_id, budget_usd, spent_usd FROM sessions
                    WHERE status = 'active' AND spent_usd >= budget_usd AND NOT (metadata ? 'budget_exceeded')"""
                )
                for session in exceeded:
                    await self.db.execute(
                        "INSERT INTO budget_events (session_id, event_type, threshold_percent, spent_before, spent_after, budget) VALUES ($1, 'exceeded', 100, $2, $2, $3)",
                        session["id"], session["spent_usd"], session["budget_usd"],
                    )
                    await self.db.execute(
                        "UPDATE sessions SET metadata = jsonb_set(metadata, '{budget_exceeded}', 'true') WHERE id = $1",
                        session["id"],
                    )
                    logger.error("Budget exceeded: session %s spent $%s of $%s", session["id"], session["spent_usd"], session["budget_usd"])
            except Exception:
                logger.exception("Error checking budgets")
            await self._interruptible_sleep(BUDGET_CHECK_INTERVAL_SECONDS)

    async def _aggregate_costs(self):
        """Hourly: refresh the daily cost rollup via a server-side function."""
        while self.running:
            try:
                await self.db.execute("SELECT refresh_daily_costs()")
                logger.info("Refreshed daily cost analytics")
            except Exception:
                logger.exception("Error aggregating costs")
            await self._interruptible_sleep(3600)

    async def _refresh_analytics(self):
        """Every 5 minutes: log top models (24h) and the active-session count."""
        while self.running:
            try:
                top_models = await self.db.fetch(
                    """SELECT provider, model, COUNT(*) as requests, SUM(total_tokens) as tokens, AVG(latency_ms)::int as avg_latency
                    FROM requests WHERE created_at > NOW() - INTERVAL '24 hours'
                    GROUP BY provider, model ORDER BY requests DESC LIMIT 10"""
                )
                logger.info("Top models (24h):")
                for row in top_models:
                    logger.info("  %s/%s: %s requests, %s tokens, %sms avg latency", row["provider"], row["model"], row["requests"], row["tokens"], row["avg_latency"])
                active = await self.db.fetchval("SELECT COUNT(*) FROM sessions WHERE status = 'active'")
                logger.info("Active sessions: %s", active)
            except Exception:
                logger.exception("Error in analytics")
            await self._interruptible_sleep(300)
| 126 |
+
async def main():
    """Install signal handlers and run the worker until it is stopped."""
    worker = Worker()
    # get_event_loop() inside a coroutine is deprecated since 3.10;
    # get_running_loop() is the correct call here.
    loop = asyncio.get_running_loop()
    for sig in (signal.SIGTERM, signal.SIGINT):
        # Cooperative shutdown: flip the worker's stop flag; its loops
        # observe it and exit on their own.
        loop.add_signal_handler(sig, lambda: setattr(worker, 'running', False))
    await worker.start()
| 133 |
+
if __name__ == "__main__":
    # Configure root logging before any worker output; asyncio.run creates
    # the event loop, runs main() to completion, and closes the loop.
    logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)s | %(name)s | %(message)s")
    asyncio.run(main())