# Uploaded as production/worker.py (revision 76b1148, verified upload).
"""Background worker for ml-intern production system."""
import asyncio
import json
import logging
import os
import signal
import time
from datetime import datetime, timedelta
import aioredis
import asyncpg
logger = logging.getLogger("ml_intern.worker")
REDIS_URL = os.environ.get("REDIS_URL", "redis://localhost:6379")
DATABASE_URL = os.environ.get("DATABASE_URL", "postgresql://localhost/ml_intern")
CLEANUP_INTERVAL_SECONDS = int(os.environ.get("CLEANUP_INTERVAL_SECONDS", "3600"))
BUDGET_CHECK_INTERVAL_SECONDS = int(os.environ.get("BUDGET_CHECK_INTERVAL_SECONDS", "300"))
class Worker:
    """Background maintenance worker for the ml-intern production system.

    Runs four independent periodic loops against Postgres and Redis:
    expired-session cleanup, budget-threshold alerting, daily cost
    aggregation, and a lightweight analytics log. Each loop runs until
    ``self.running`` is flipped to False (by the signal handlers installed
    in ``main()``) and isolates its own errors so one failing cycle does
    not kill the loop.
    """

    def __init__(self):
        # Connections are established in start(), not here, so constructing
        # a Worker never performs I/O.
        self.redis: aioredis.Redis = None
        self.db: asyncpg.Pool = None
        # Cooperative stop flag; checked at the top of every loop iteration.
        self.running = True

    async def start(self):
        """Connect to Redis and Postgres, then run all periodic loops.

        Raises whatever the clients raise if either backend is unreachable
        (the ping / pool creation fail fast on startup).
        """
        self.redis = aioredis.from_url(REDIS_URL, decode_responses=True)
        await self.redis.ping()  # fail fast if Redis is down
        self.db = await asyncpg.create_pool(DATABASE_URL, min_size=2, max_size=5, command_timeout=60)
        logger.info("Worker started")
        tasks = [
            asyncio.create_task(self._cleanup_expired_sessions()),
            asyncio.create_task(self._check_budget_alerts()),
            asyncio.create_task(self._aggregate_costs()),
            asyncio.create_task(self._refresh_analytics()),
        ]
        # return_exceptions=True keeps one crashed loop from cancelling the
        # others, but gather() then *returns* the exception objects instead
        # of raising them -- log them so a dead loop is visible.
        results = await asyncio.gather(*tasks, return_exceptions=True)
        for result in results:
            if isinstance(result, BaseException):
                logger.error("Worker task exited with exception: %r", result)

    async def _purge_session_keys(self, session_ids):
        """Delete the per-session Redis keys for each id in *session_ids*.

        Redis DEL takes literal key names only -- it does NOT expand glob
        patterns -- so matching keys must first be discovered with SCAN.
        (Pipelining DEL on the raw patterns, as this code previously did,
        deletes nothing.)
        """
        for sid in session_ids:
            for pattern in (f"ratelimit:*:{sid}", f"session:{sid}:*"):
                keys = [key async for key in self.redis.scan_iter(match=pattern)]
                if keys:
                    await self.redis.delete(*keys)

    async def _cleanup_expired_sessions(self):
        """Archive expired active sessions, mark them expired, drop Redis keys.

        Runs every CLEANUP_INTERVAL_SECONDS; processes at most 1000 sessions
        per cycle, so a large backlog drains over several cycles.
        """
        while self.running:
            try:
                expired = await self.db.fetch(
                    "SELECT id, tenant_id FROM sessions WHERE expires_at < NOW() AND status = 'active' LIMIT 1000"
                )
                if expired:
                    session_ids = [str(s["id"]) for s in expired]
                    # Snapshot into archived_sessions BEFORE flipping status,
                    # so the data survives even if the UPDATE below fails.
                    await self.db.execute(
                        "INSERT INTO archived_sessions (id, tenant_id, data, archived_at) SELECT id, tenant_id, jsonb_build_object('metadata', metadata, 'context', context_json, 'spent_usd', spent_usd, 'budget_usd', budget_usd), NOW() FROM sessions WHERE id = ANY($1)",
                        session_ids,
                    )
                    await self.db.execute("UPDATE sessions SET status = 'expired' WHERE id = ANY($1)", session_ids)
                    await self._purge_session_keys(session_ids)
                    logger.info("Cleaned up %d expired sessions", len(expired))
            except Exception:
                # Log and keep looping; a single bad cycle must not kill cleanup.
                logger.exception("Error in cleanup")
            await asyncio.sleep(CLEANUP_INTERVAL_SECONDS)

    async def _check_budget_alerts(self):
        """Emit budget_events rows and log alerts for sessions near/over budget.

        Two passes per cycle:
          * 80% threshold: sessions between 80% and 95% spend that have not
            yet been alerted (flagged via metadata key 'budget_alert_80_sent').
            NOTE(review): a session that jumps straight past 95% in one step
            never receives the 80% alert -- confirm this window is intended.
          * exceeded: sessions at or over budget, flagged via
            'budget_exceeded' so the event fires exactly once.
        """
        while self.running:
            try:
                nearing_limit = await self.db.fetch(
                    """SELECT id, tenant_id, budget_usd, spent_usd, (spent_usd / budget_usd) * 100 as pct_used
                       FROM sessions WHERE status = 'active' AND spent_usd > budget_usd * 0.8 AND spent_usd < budget_usd * 0.95
                       AND NOT (metadata ? 'budget_alert_80_sent')"""
                )
                for session in nearing_limit:
                    session_id = str(session["id"])
                    pct = session["pct_used"]
                    await self.db.execute(
                        "INSERT INTO budget_events (session_id, event_type, threshold_percent, spent_before, spent_after, budget) VALUES ($1, 'threshold_reached', $2, $3, $3, $4)",
                        session["id"], 80, session["spent_usd"], session["budget_usd"],
                    )
                    # Mark the alert as sent so the next cycle skips this session.
                    await self.db.execute(
                        "UPDATE sessions SET metadata = jsonb_set(metadata, '{budget_alert_80_sent}', 'true') WHERE id = $1",
                        session["id"],
                    )
                    logger.warning("Budget alert: session %s at %.1f%% of budget", session_id, pct)
                exceeded = await self.db.fetch(
                    """SELECT id, tenant_id, budget_usd, spent_usd FROM sessions
                       WHERE status = 'active' AND spent_usd >= budget_usd AND NOT (metadata ? 'budget_exceeded')"""
                )
                for session in exceeded:
                    await self.db.execute(
                        "INSERT INTO budget_events (session_id, event_type, threshold_percent, spent_before, spent_after, budget) VALUES ($1, 'exceeded', 100, $2, $2, $3)",
                        session["id"], session["spent_usd"], session["budget_usd"],
                    )
                    await self.db.execute(
                        "UPDATE sessions SET metadata = jsonb_set(metadata, '{budget_exceeded}', 'true') WHERE id = $1",
                        session["id"],
                    )
                    logger.error(
                        "Budget exceeded: session %s spent $%s of $%s",
                        session["id"], session["spent_usd"], session["budget_usd"],
                    )
            except Exception:
                logger.exception("Error checking budgets")
            await asyncio.sleep(BUDGET_CHECK_INTERVAL_SECONDS)

    async def _aggregate_costs(self):
        """Hourly: call the refresh_daily_costs() stored procedure in Postgres."""
        while self.running:
            try:
                self.db.execute
                await self.db.execute("SELECT refresh_daily_costs()")
                logger.info("Refreshed daily cost analytics")
            except Exception:
                logger.exception("Error aggregating costs")
            await asyncio.sleep(3600)

    async def _refresh_analytics(self):
        """Every 5 minutes: log the top-10 models over 24h and active-session count."""
        while self.running:
            try:
                top_models = await self.db.fetch(
                    """SELECT provider, model, COUNT(*) as requests, SUM(total_tokens) as tokens, AVG(latency_ms)::int as avg_latency
                       FROM requests WHERE created_at > NOW() - INTERVAL '24 hours'
                       GROUP BY provider, model ORDER BY requests DESC LIMIT 10"""
                )
                logger.info("Top models (24h):")
                for row in top_models:
                    logger.info(
                        "  %s/%s: %s requests, %s tokens, %sms avg latency",
                        row["provider"], row["model"], row["requests"], row["tokens"], row["avg_latency"],
                    )
                active = await self.db.fetchval("SELECT COUNT(*) FROM sessions WHERE status = 'active'")
                logger.info("Active sessions: %s", active)
            except Exception:
                logger.exception("Error in analytics")
            await asyncio.sleep(300)
async def main():
    """Entry point: install signal handlers, then run the worker forever.

    SIGTERM/SIGINT flip the worker's cooperative ``running`` flag; each
    periodic loop then exits after its current sleep completes.
    """
    worker = Worker()
    # Inside a coroutine the loop is already running: get_running_loop() is
    # the supported call; get_event_loop() here is deprecated (Python 3.10+).
    loop = asyncio.get_running_loop()
    for sig in (signal.SIGTERM, signal.SIGINT):
        # NOTE: add_signal_handler is Unix-only; on Windows this raises
        # NotImplementedError.
        loop.add_signal_handler(sig, lambda: setattr(worker, 'running', False))
    await worker.start()
if __name__ == "__main__":
    # Configure root logging once at process startup, then run the worker
    # until a SIGTERM/SIGINT stops it (see main()).
    logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)s | %(name)s | %(message)s")
    asyncio.run(main())