| """Background worker for ml-intern production system.""" |
| import asyncio |
| import json |
| import logging |
| import os |
| import signal |
| import time |
| from datetime import datetime, timedelta |
|
|
| import aioredis |
| import asyncpg |
|
|
| logger = logging.getLogger("ml_intern.worker") |
| REDIS_URL = os.environ.get("REDIS_URL", "redis://localhost:6379") |
| DATABASE_URL = os.environ.get("DATABASE_URL", "postgresql://localhost/ml_intern") |
| CLEANUP_INTERVAL_SECONDS = int(os.environ.get("CLEANUP_INTERVAL_SECONDS", "3600")) |
| BUDGET_CHECK_INTERVAL_SECONDS = int(os.environ.get("BUDGET_CHECK_INTERVAL_SECONDS", "300")) |
|
|
class Worker:
    """Background maintenance worker for the ml-intern production system.

    Runs four independent periodic loops (session cleanup, budget alerts,
    cost aggregation, analytics logging) until ``running`` is set to False,
    e.g. by a signal handler. Connections are opened in start() and closed
    once every loop has exited.
    """

    # Seconds between shutdown-flag checks while sleeping between iterations;
    # keeps shutdown latency bounded even with hour-long intervals.
    _POLL_SECONDS = 1.0

    def __init__(self):
        self.redis: aioredis.Redis = None  # connected in start()
        self.db: asyncpg.Pool = None       # connected in start()
        self.running = True                # cleared externally to request shutdown

    async def _sleep(self, seconds: float) -> None:
        """Sleep up to *seconds*, waking early if shutdown was requested."""
        deadline = time.monotonic() + seconds
        while self.running and time.monotonic() < deadline:
            await asyncio.sleep(min(self._POLL_SECONDS, deadline - time.monotonic()))

    async def start(self):
        """Connect to Redis and Postgres, then run all periodic loops to completion."""
        self.redis = aioredis.from_url(REDIS_URL, decode_responses=True)
        await self.redis.ping()  # fail fast if Redis is unreachable
        self.db = await asyncpg.create_pool(DATABASE_URL, min_size=2, max_size=5, command_timeout=60)
        logger.info("Worker started")
        tasks = [
            asyncio.create_task(self._cleanup_expired_sessions()),
            asyncio.create_task(self._check_budget_alerts()),
            asyncio.create_task(self._aggregate_costs()),
            asyncio.create_task(self._refresh_analytics()),
        ]
        # return_exceptions=True keeps one failing loop from cancelling the rest;
        # surface any failures in the logs instead of dropping them silently.
        results = await asyncio.gather(*tasks, return_exceptions=True)
        for task, result in zip(tasks, results):
            if isinstance(result, BaseException):
                logger.error("Background task %s failed: %r", task.get_name(), result)
        # Release connections once every loop has exited.
        await self.redis.close()
        await self.db.close()
        logger.info("Worker stopped")

    async def _cleanup_expired_sessions(self):
        """Archive sessions past their expiry, mark them expired, and purge their Redis keys."""
        while self.running:
            try:
                expired = await self.db.fetch(
                    "SELECT id, tenant_id FROM sessions WHERE expires_at < NOW() AND status = 'active' LIMIT 1000"
                )
                if expired:
                    # Pass native id values to Postgres (str() would break integer/uuid
                    # comparison in id = ANY($1)); stringify only inside Redis key names.
                    session_ids = [s["id"] for s in expired]
                    # Archive + status flip in one transaction so a crash between the
                    # two statements cannot leave archived-but-still-active rows.
                    async with self.db.acquire() as conn:
                        async with conn.transaction():
                            await conn.execute(
                                "INSERT INTO archived_sessions (id, tenant_id, data, archived_at) SELECT id, tenant_id, jsonb_build_object('metadata', metadata, 'context', context_json, 'spent_usd', spent_usd, 'budget_usd', budget_usd), NOW() FROM sessions WHERE id = ANY($1)",
                                session_ids,
                            )
                            await conn.execute(
                                "UPDATE sessions SET status = 'expired' WHERE id = ANY($1)", session_ids
                            )
                    # Redis DEL takes literal keys, not glob patterns: the original
                    # pipe.delete("ratelimit:*:<sid>") deleted nothing. Expand each
                    # pattern with SCAN, then delete the concrete keys.
                    for sid in session_ids:
                        for pattern in (f"ratelimit:*:{sid}", f"session:{sid}:*"):
                            keys = [key async for key in self.redis.scan_iter(match=pattern)]
                            if keys:
                                await self.redis.delete(*keys)
                    logger.info("Cleaned up %d expired sessions", len(expired))
            except Exception:
                logger.exception("Error in cleanup")
            await self._sleep(CLEANUP_INTERVAL_SECONDS)

    async def _check_budget_alerts(self):
        """Emit one-shot alert events for sessions near (80%) or over (100%) budget.

        COALESCE(metadata, '{}') matters twice: for NULL metadata,
        ``metadata ? 'key'`` yields NULL (row never selected) and
        ``jsonb_set(NULL, ...)`` yields NULL (flag never recorded), so sessions
        with NULL metadata would alert never or forever.
        """
        while self.running:
            try:
                nearing_limit = await self.db.fetch(
                    """SELECT id, tenant_id, budget_usd, spent_usd,
                              (spent_usd / NULLIF(budget_usd, 0)) * 100 as pct_used
                       FROM sessions
                       WHERE status = 'active'
                         AND spent_usd > budget_usd * 0.8
                         AND spent_usd < budget_usd * 0.95
                         AND NOT (COALESCE(metadata, '{}'::jsonb) ? 'budget_alert_80_sent')"""
                )
                for session in nearing_limit:
                    # spent_before == spent_after: this is a threshold event, not a charge.
                    await self.db.execute(
                        "INSERT INTO budget_events (session_id, event_type, threshold_percent, spent_before, spent_after, budget) VALUES ($1, 'threshold_reached', $2, $3, $3, $4)",
                        session["id"], 80, session["spent_usd"], session["budget_usd"],
                    )
                    await self.db.execute(
                        "UPDATE sessions SET metadata = jsonb_set(COALESCE(metadata, '{}'::jsonb), '{budget_alert_80_sent}', 'true') WHERE id = $1",
                        session["id"],
                    )
                    logger.warning(
                        "Budget alert: session %s at %.1f%% of budget", session["id"], session["pct_used"]
                    )

                exceeded = await self.db.fetch(
                    """SELECT id, tenant_id, budget_usd, spent_usd FROM sessions
                       WHERE status = 'active' AND spent_usd >= budget_usd
                         AND NOT (COALESCE(metadata, '{}'::jsonb) ? 'budget_exceeded')"""
                )
                for session in exceeded:
                    await self.db.execute(
                        "INSERT INTO budget_events (session_id, event_type, threshold_percent, spent_before, spent_after, budget) VALUES ($1, 'exceeded', 100, $2, $2, $3)",
                        session["id"], session["spent_usd"], session["budget_usd"],
                    )
                    await self.db.execute(
                        "UPDATE sessions SET metadata = jsonb_set(COALESCE(metadata, '{}'::jsonb), '{budget_exceeded}', 'true') WHERE id = $1",
                        session["id"],
                    )
                    logger.error(
                        "Budget exceeded: session %s spent $%s of $%s",
                        session["id"], session["spent_usd"], session["budget_usd"],
                    )
            except Exception:
                logger.exception("Error checking budgets")
            await self._sleep(BUDGET_CHECK_INTERVAL_SECONDS)

    async def _aggregate_costs(self):
        """Refresh the daily cost rollup via the refresh_daily_costs() DB function, hourly."""
        while self.running:
            try:
                await self.db.execute("SELECT refresh_daily_costs()")
                logger.info("Refreshed daily cost analytics")
            except Exception:
                logger.exception("Error aggregating costs")
            await self._sleep(3600)

    async def _refresh_analytics(self):
        """Log top models by 24h request volume and the active-session count, every 5 minutes."""
        while self.running:
            try:
                top_models = await self.db.fetch(
                    """SELECT provider, model, COUNT(*) as requests, SUM(total_tokens) as tokens, AVG(latency_ms)::int as avg_latency
                       FROM requests WHERE created_at > NOW() - INTERVAL '24 hours'
                       GROUP BY provider, model ORDER BY requests DESC LIMIT 10"""
                )
                logger.info("Top models (24h):")
                for row in top_models:
                    logger.info(
                        "  %s/%s: %s requests, %s tokens, %sms avg latency",
                        row["provider"], row["model"], row["requests"], row["tokens"], row["avg_latency"],
                    )
                active = await self.db.fetchval("SELECT COUNT(*) FROM sessions WHERE status = 'active'")
                logger.info("Active sessions: %s", active)
            except Exception:
                logger.exception("Error in analytics")
            await self._sleep(300)
|
|
async def main():
    """Create the worker, install SIGTERM/SIGINT handlers for graceful shutdown, and run it."""
    worker = Worker()
    # get_running_loop() is the supported way to reach the loop from inside a
    # coroutine; get_event_loop() here is deprecated (warns on 3.10+, and its
    # implicit-loop behavior was removed in newer Pythons).
    loop = asyncio.get_running_loop()
    for sig in (signal.SIGTERM, signal.SIGINT):
        # Clearing the flag lets each periodic loop exit at its next check.
        loop.add_signal_handler(sig, lambda: setattr(worker, 'running', False))
    await worker.start()
|
|
if __name__ == "__main__":
    # Configure root logging only when run as a script (importers keep their own
    # config), then drive the async entry point until shutdown.
    logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)s | %(name)s | %(message)s")
    asyncio.run(main())
|
|