"""Background worker for ml-intern production system.

Runs four periodic maintenance loops against Postgres and Redis:

* archive and expire sessions whose ``expires_at`` has passed,
* record budget threshold / exceeded events,
* refresh the daily cost aggregates,
* log a short usage summary.

The worker stops on SIGTERM / SIGINT; every loop wakes up immediately when
a stop is requested instead of finishing its current sleep.
"""

import asyncio
import json
import logging
import os
import signal
import time
from datetime import datetime, timedelta
from typing import Optional

import aioredis
import asyncpg

logger = logging.getLogger("ml_intern.worker")

REDIS_URL = os.environ.get("REDIS_URL", "redis://localhost:6379")
DATABASE_URL = os.environ.get("DATABASE_URL", "postgresql://localhost/ml_intern")
CLEANUP_INTERVAL_SECONDS = int(os.environ.get("CLEANUP_INTERVAL_SECONDS", "3600"))
BUDGET_CHECK_INTERVAL_SECONDS = int(os.environ.get("BUDGET_CHECK_INTERVAL_SECONDS", "300"))
# Previously hard-coded inside the loops; defaults are the original values.
COST_AGGREGATION_INTERVAL_SECONDS = int(os.environ.get("COST_AGGREGATION_INTERVAL_SECONDS", "3600"))
ANALYTICS_INTERVAL_SECONDS = int(os.environ.get("ANALYTICS_INTERVAL_SECONDS", "300"))

# Maximum number of expired sessions archived per transaction.
CLEANUP_BATCH_SIZE = 1000
# COUNT hint passed to Redis SCAN while looking for per-session keys.
REDIS_SCAN_COUNT = 1000
# Maximum number of keys sent in a single DEL command.
REDIS_DELETE_BATCH = 500


class Worker:
    """Owns the Redis/Postgres connections and the periodic maintenance loops.

    Attributes:
        redis: aioredis client, created by :meth:`start`.
        db: asyncpg connection pool, created by :meth:`start`.
        running: ``True`` until a stop is requested. Assigning ``False``
            (or calling :meth:`stop`) wakes every sleeping loop.
    """

    def __init__(self):
        self.redis: Optional[aioredis.Redis] = None
        self.db: Optional[asyncpg.Pool] = None
        # Created in start() so the event is bound to the running loop.
        self._stop_event: Optional[asyncio.Event] = None
        self._running = True

    @property
    def running(self):
        """Whether the maintenance loops should keep iterating."""
        return self._running

    @running.setter
    def running(self, value):
        self._running = bool(value)
        if self._stop_event is None:
            return
        # Keep the event in sync so sleeping loops notice the change at once.
        if self._running:
            self._stop_event.clear()
        else:
            self._stop_event.set()

    def stop(self):
        """Request a graceful shutdown of all loops."""
        self.running = False

    async def start(self):
        """Connect to Redis and Postgres and run all loops until stopped.

        Connections are closed before returning, whether the loops finished
        normally, a connection attempt failed, or the call was cancelled.
        """
        self._stop_event = asyncio.Event()
        if not self._running:
            self._stop_event.set()
        try:
            self.redis = aioredis.from_url(REDIS_URL, decode_responses=True)
            await self.redis.ping()
            self.db = await asyncpg.create_pool(
                DATABASE_URL, min_size=2, max_size=5, command_timeout=60
            )
            logger.info("Worker started")

            loops = [
                ("cleanup", self._cleanup_expired_sessions),
                ("budget alerts", self._check_budget_alerts),
                ("cost aggregation", self._aggregate_costs),
                ("analytics", self._refresh_analytics),
            ]
            tasks = [asyncio.create_task(loop_fn()) for _, loop_fn in loops]
            outcomes = await asyncio.gather(*tasks, return_exceptions=True)
            # return_exceptions=True hides failures unless we surface them.
            for (label, _), outcome in zip(loops, outcomes):
                if isinstance(outcome, BaseException):
                    logger.error("Worker loop %r terminated abnormally", label, exc_info=outcome)
        finally:
            await self._close()

    async def _close(self):
        """Release the Postgres pool and the Redis client, if they exist."""
        try:
            if self.db is not None:
                await self.db.close()
                self.db = None
        finally:
            if self.redis is not None:
                try:
                    await self.redis.close()
                finally:
                    await self.redis.connection_pool.disconnect()
                    self.redis = None

    async def _sleep(self, seconds):
        """Sleep for ``seconds``, returning early when a stop is requested."""
        if not self._running:
            return
        if self._stop_event is None:
            # Loop invoked without start(); fall back to a plain sleep.
            await asyncio.sleep(seconds)
            return
        try:
            await asyncio.wait_for(self._stop_event.wait(), timeout=seconds)
        except asyncio.TimeoutError:
            pass

    async def _run_periodically(self, job, interval, error_message):
        """Call ``job()`` every ``interval`` seconds until stopped.

        Any ``Exception`` raised by the job is logged with its traceback
        (prefixed by ``error_message``) and the loop carries on.
        """
        while self.running:
            try:
                await job()
            except Exception as exc:
                logger.exception("%s: %s", error_message, exc)
            await self._sleep(interval)

    # ------------------------------------------------------------------
    # Expired-session cleanup
    # ------------------------------------------------------------------

    async def _cleanup_expired_sessions(self):
        """Periodically archive expired sessions and drop their Redis keys."""
        await self._run_periodically(
            self._expire_sessions, CLEANUP_INTERVAL_SECONDS, "Error in cleanup"
        )

    async def _expire_sessions(self):
        """Run one cleanup pass, draining the backlog in batches."""
        expired_ids = []
        try:
            while self.running:
                batch = await self._archive_expired_batch()
                expired_ids.extend(batch)
                if len(batch) < CLEANUP_BATCH_SIZE:
                    break
        finally:
            # Sessions committed as expired are never selected again, so
            # purge their keys even if a later batch failed.
            if expired_ids:
                await self._purge_session_keys({str(sid) for sid in expired_ids})
                logger.info("Cleaned up %d expired sessions", len(expired_ids))

    async def _archive_expired_batch(self):
        """Archive and expire up to ``CLEANUP_BATCH_SIZE`` sessions atomically.

        Returns:
            The ids of the sessions processed, as returned by the database.
        """
        async with self.db.acquire() as conn:
            # One transaction: a session is either archived *and* marked
            # expired, or left untouched for the next pass.
            async with conn.transaction():
                rows = await conn.fetch(
                    "SELECT id, tenant_id FROM sessions "
                    "WHERE expires_at < NOW() AND status = 'active' LIMIT $1",
                    CLEANUP_BATCH_SIZE,
                )
                if not rows:
                    return []
                # Pass ids back in their native type so the array parameter
                # matches the column type whatever it is (uuid, bigint, ...).
                session_ids = [row["id"] for row in rows]
                await conn.execute(
                    "INSERT INTO archived_sessions (id, tenant_id, data, archived_at) "
                    "SELECT id, tenant_id, jsonb_build_object('metadata', metadata, "
                    "'context', context_json, 'spent_usd', spent_usd, 'budget_usd', budget_usd), "
                    "NOW() FROM sessions WHERE id = ANY($1)",
                    session_ids,
                )
                await conn.execute(
                    "UPDATE sessions SET status = 'expired' WHERE id = ANY($1)",
                    session_ids,
                )
        return session_ids

    @staticmethod
    def _key_belongs_to(key, session_ids):
        """Tell whether ``key`` is a per-session key of one of ``session_ids``.

        Mirrors the glob patterns ``ratelimit:*:<sid>`` and ``session:<sid>:*``.
        Assumes session ids contain no ``:`` characters.
        """
        if key.startswith("session:"):
            parts = key.split(":", 2)
            return len(parts) == 3 and parts[1] in session_ids
        prefix, _, owner = key.rpartition(":")
        return prefix.startswith("ratelimit:") and owner in session_ids

    async def _purge_session_keys(self, session_ids):
        """Delete the rate-limit and session keys of the given sessions.

        DEL does not expand glob patterns, so the keys are discovered with
        SCAN. The keyspace is scanned once per key family and filtered
        client-side rather than scanned once per session.
        """
        doomed = []
        for family in ("ratelimit:*", "session:*"):
            async for key in self.redis.scan_iter(match=family, count=REDIS_SCAN_COUNT):
                if self._key_belongs_to(key, session_ids):
                    doomed.append(key)
        for start in range(0, len(doomed), REDIS_DELETE_BATCH):
            await self.redis.delete(*doomed[start:start + REDIS_DELETE_BATCH])

    # ------------------------------------------------------------------
    # Budget alerts
    # ------------------------------------------------------------------

    async def _check_budget_alerts(self):
        """Periodically record budget threshold and budget exceeded events."""
        await self._run_periodically(
            self._process_budget_alerts,
            BUDGET_CHECK_INTERVAL_SECONDS,
            "Error checking budgets",
        )

    async def _process_budget_alerts(self):
        """Run one budget pass: 80% threshold alerts, then exceeded alerts."""
        await self._alert_nearing_limit()
        await self._alert_exceeded()

    async def _alert_nearing_limit(self):
        """Record a one-time 80% alert for sessions between 80% and 95%."""
        # NOTE(review): sessions that jump straight past 95% never receive
        # this alert; the upper bound is kept as originally written - confirm
        # whether another component covers the 95-100% range.
        # COALESCE: "NOT (NULL ? 'k')" is NULL, which would silently skip
        # sessions whose metadata column is NULL.
        nearing_limit = await self.db.fetch(
            """SELECT id, tenant_id, budget_usd, spent_usd,
                      (spent_usd / budget_usd) * 100 as pct_used
               FROM sessions
               WHERE status = 'active'
                 AND spent_usd > budget_usd * 0.8
                 AND spent_usd < budget_usd * 0.95
                 AND NOT (COALESCE(metadata, '{}'::jsonb) ? 'budget_alert_80_sent')"""
        )
        for session in nearing_limit:
            async with self.db.acquire() as conn:
                # Event row and "already sent" flag must land together,
                # otherwise a failed flag update repeats the event each pass.
                async with conn.transaction():
                    await conn.execute(
                        "INSERT INTO budget_events (session_id, event_type, threshold_percent, "
                        "spent_before, spent_after, budget) "
                        "VALUES ($1, 'threshold_reached', $2, $3, $3, $4)",
                        session["id"],
                        80,
                        session["spent_usd"],
                        session["budget_usd"],
                    )
                    await conn.execute(
                        "UPDATE sessions SET metadata = jsonb_set(COALESCE(metadata, '{}'::jsonb), "
                        "'{budget_alert_80_sent}', 'true') WHERE id = $1",
                        session["id"],
                    )
            logger.warning(
                "Budget alert: session %s at %s%% of budget",
                session["id"],
                format(session["pct_used"], ".1f"),
            )

    async def _alert_exceeded(self):
        """Record a one-time event for sessions at or above their budget."""
        exceeded = await self.db.fetch(
            """SELECT id, tenant_id, budget_usd, spent_usd
               FROM sessions
               WHERE status = 'active'
                 AND spent_usd >= budget_usd
                 AND NOT (COALESCE(metadata, '{}'::jsonb) ? 'budget_exceeded')"""
        )
        for session in exceeded:
            async with self.db.acquire() as conn:
                async with conn.transaction():
                    await conn.execute(
                        "INSERT INTO budget_events (session_id, event_type, threshold_percent, "
                        "spent_before, spent_after, budget) "
                        "VALUES ($1, 'exceeded', 100, $2, $2, $3)",
                        session["id"],
                        session["spent_usd"],
                        session["budget_usd"],
                    )
                    await conn.execute(
                        "UPDATE sessions SET metadata = jsonb_set(COALESCE(metadata, '{}'::jsonb), "
                        "'{budget_exceeded}', 'true') WHERE id = $1",
                        session["id"],
                    )
            logger.error(
                "Budget exceeded: session %s spent $%s of $%s",
                session["id"],
                session["spent_usd"],
                session["budget_usd"],
            )

    # ------------------------------------------------------------------
    # Cost aggregation and analytics
    # ------------------------------------------------------------------

    async def _aggregate_costs(self):
        """Periodically call the ``refresh_daily_costs()`` database function."""
        await self._run_periodically(
            self._refresh_daily_costs,
            COST_AGGREGATION_INTERVAL_SECONDS,
            "Error aggregating costs",
        )

    async def _refresh_daily_costs(self):
        """Run one cost aggregation pass."""
        await self.db.execute("SELECT refresh_daily_costs()")
        logger.info("Refreshed daily cost analytics")

    async def _refresh_analytics(self):
        """Periodically log the top models and the active session count."""
        await self._run_periodically(
            self._log_analytics, ANALYTICS_INTERVAL_SECONDS, "Error in analytics"
        )

    async def _log_analytics(self):
        """Log the ten busiest provider/model pairs of the last 24 hours."""
        top_models = await self.db.fetch(
            """SELECT provider, model, COUNT(*) as requests,
                      SUM(total_tokens) as tokens,
                      AVG(latency_ms)::int as avg_latency
               FROM requests
               WHERE created_at > NOW() - INTERVAL '24 hours'
               GROUP BY provider, model
               ORDER BY requests DESC
               LIMIT 10"""
        )
        logger.info("Top models (24h):")
        for row in top_models:
            logger.info(
                " %s/%s: %s requests, %s tokens, %sms avg latency",
                row["provider"],
                row["model"],
                row["requests"],
                row["tokens"],
                row["avg_latency"],
            )
        active = await self.db.fetchval("SELECT COUNT(*) FROM sessions WHERE status = 'active'")
        logger.info("Active sessions: %s", active)


async def main():
    """Run a :class:`Worker` until SIGTERM or SIGINT is received."""
    worker = Worker()
    # get_event_loop() is deprecated inside coroutines; we are always
    # running under asyncio.run() here.
    loop = asyncio.get_running_loop()
    for sig in (signal.SIGTERM, signal.SIGINT):
        loop.add_signal_handler(sig, worker.stop)
    await worker.start()


if __name__ == "__main__":
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    )
    asyncio.run(main())