"""Background worker for ml-intern production system."""
import asyncio
import json
import logging
import os
import signal
import time
from datetime import datetime, timedelta
import aioredis
import asyncpg
logger = logging.getLogger("ml_intern.worker")
REDIS_URL = os.environ.get("REDIS_URL", "redis://localhost:6379")
DATABASE_URL = os.environ.get("DATABASE_URL", "postgresql://localhost/ml_intern")
CLEANUP_INTERVAL_SECONDS = int(os.environ.get("CLEANUP_INTERVAL_SECONDS", "3600"))
BUDGET_CHECK_INTERVAL_SECONDS = int(os.environ.get("BUDGET_CHECK_INTERVAL_SECONDS", "300"))
class Worker:
    """Long-running background worker for the ml-intern production system.

    Runs four periodic loops against Postgres (asyncpg pool) and Redis:
    expired-session cleanup, budget alerting, cost aggregation, and an
    analytics summary. Stop it by setting ``worker.running = False``.
    """

    def __init__(self):
        # Connections are created lazily in start(); None until then.
        self.redis = None
        self.db = None
        # Internal stop flag. Exposed via the `running` property so existing
        # callers that do `worker.running = False` keep working, while loops
        # can wait on the event and wake immediately on shutdown.
        self._stop = asyncio.Event()

    @property
    def running(self) -> bool:
        """True until a shutdown has been requested."""
        return not self._stop.is_set()

    @running.setter
    def running(self, value: bool) -> None:
        if value:
            self._stop.clear()
        else:
            self._stop.set()

    async def _sleep(self, seconds: float) -> None:
        """Sleep up to `seconds`, returning early if shutdown is requested.

        A plain asyncio.sleep() here could delay shutdown by up to an hour
        (CLEANUP_INTERVAL_SECONDS); waiting on the stop event instead keeps
        every loop responsive to SIGTERM/SIGINT.
        """
        try:
            await asyncio.wait_for(self._stop.wait(), timeout=seconds)
        except asyncio.TimeoutError:
            pass

    async def start(self):
        """Connect to Redis and Postgres, then run all loops until stopped."""
        self.redis = aioredis.from_url(REDIS_URL, decode_responses=True)
        await self.redis.ping()
        self.db = await asyncpg.create_pool(DATABASE_URL, min_size=2, max_size=5, command_timeout=60)
        logger.info("Worker started")
        tasks = [
            asyncio.create_task(self._cleanup_expired_sessions()),
            asyncio.create_task(self._check_budget_alerts()),
            asyncio.create_task(self._aggregate_costs()),
            asyncio.create_task(self._refresh_analytics()),
        ]
        try:
            # return_exceptions=True: a crashed loop must not cancel the others.
            await asyncio.gather(*tasks, return_exceptions=True)
        finally:
            # Release connections on shutdown; the original leaked both.
            await self.db.close()
            await self.redis.close()

    async def _cleanup_expired_sessions(self):
        """Archive expired sessions, mark them expired, and purge Redis keys."""
        while self.running:
            try:
                expired = await self.db.fetch(
                    "SELECT id, tenant_id FROM sessions WHERE expires_at < NOW() AND status = 'active' LIMIT 1000"
                )
                if expired:
                    # Bind native ids: str(id) fails to encode against a
                    # uuid[] parameter in asyncpg.
                    ids = [s["id"] for s in expired]
                    await self.db.execute(
                        "INSERT INTO archived_sessions (id, tenant_id, data, archived_at) SELECT id, tenant_id, jsonb_build_object('metadata', metadata, 'context', context_json, 'spent_usd', spent_usd, 'budget_usd', budget_usd), NOW() FROM sessions WHERE id = ANY($1)",
                        ids,
                    )
                    await self.db.execute("UPDATE sessions SET status = 'expired' WHERE id = ANY($1)", ids)
                    # Redis DEL does not glob: delete("ratelimit:*:<sid>")
                    # removed only a literal key containing '*' and left the
                    # real keys behind. SCAN for actual matches instead.
                    for sid in (str(s["id"]) for s in expired):
                        for pattern in (f"ratelimit:*:{sid}", f"session:{sid}:*"):
                            async for key in self.redis.scan_iter(match=pattern):
                                await self.redis.delete(key)
                    logger.info("Cleaned up %d expired sessions", len(expired))
            except Exception:
                logger.exception("Error in cleanup")
            await self._sleep(CLEANUP_INTERVAL_SECONDS)

    async def _check_budget_alerts(self):
        """Record budget_events and log alerts at the 80% threshold and on overrun."""
        while self.running:
            try:
                # budget_usd > 0 guards the division. Upper bound is the full
                # budget: the original `< budget_usd * 0.95` cutoff meant a
                # session jumping from <80% straight past 95% never got the
                # 80% alert (the metadata flag already prevents duplicates).
                nearing_limit = await self.db.fetch(
                    """SELECT id, tenant_id, budget_usd, spent_usd, (spent_usd / budget_usd) * 100 as pct_used
                    FROM sessions WHERE status = 'active' AND budget_usd > 0
                    AND spent_usd > budget_usd * 0.8 AND spent_usd < budget_usd
                    AND NOT (metadata ? 'budget_alert_80_sent')"""
                )
                for session in nearing_limit:
                    await self.db.execute(
                        "INSERT INTO budget_events (session_id, event_type, threshold_percent, spent_before, spent_after, budget) VALUES ($1, 'threshold_reached', $2, $3, $3, $4)",
                        session["id"], 80, session["spent_usd"], session["budget_usd"],
                    )
                    # Flag the session so the alert fires at most once.
                    await self.db.execute(
                        "UPDATE sessions SET metadata = jsonb_set(metadata, '{budget_alert_80_sent}', 'true') WHERE id = $1",
                        session["id"],
                    )
                    logger.warning("Budget alert: session %s at %.1f%% of budget", session["id"], session["pct_used"])
                exceeded = await self.db.fetch(
                    """SELECT id, tenant_id, budget_usd, spent_usd FROM sessions
                    WHERE status = 'active' AND spent_usd >= budget_usd AND NOT (metadata ? 'budget_exceeded')"""
                )
                for session in exceeded:
                    await self.db.execute(
                        "INSERT INTO budget_events (session_id, event_type, threshold_percent, spent_before, spent_after, budget) VALUES ($1, 'exceeded', 100, $2, $2, $3)",
                        session["id"], session["spent_usd"], session["budget_usd"],
                    )
                    await self.db.execute(
                        "UPDATE sessions SET metadata = jsonb_set(metadata, '{budget_exceeded}', 'true') WHERE id = $1",
                        session["id"],
                    )
                    logger.error(
                        "Budget exceeded: session %s spent $%s of $%s",
                        session["id"], session["spent_usd"], session["budget_usd"],
                    )
            except Exception:
                logger.exception("Error checking budgets")
            await self._sleep(BUDGET_CHECK_INTERVAL_SECONDS)

    async def _aggregate_costs(self):
        """Hourly: refresh the daily cost rollup via a server-side function."""
        while self.running:
            try:
                await self.db.execute("SELECT refresh_daily_costs()")
                logger.info("Refreshed daily cost analytics")
            except Exception:
                logger.exception("Error aggregating costs")
            await self._sleep(3600)

    async def _refresh_analytics(self):
        """Every 5 minutes: log top models over 24h and the active-session count."""
        while self.running:
            try:
                top_models = await self.db.fetch(
                    """SELECT provider, model, COUNT(*) as requests, SUM(total_tokens) as tokens, AVG(latency_ms)::int as avg_latency
                    FROM requests WHERE created_at > NOW() - INTERVAL '24 hours'
                    GROUP BY provider, model ORDER BY requests DESC LIMIT 10"""
                )
                logger.info("Top models (24h):")
                for row in top_models:
                    logger.info(
                        "  %s/%s: %s requests, %s tokens, %sms avg latency",
                        row["provider"], row["model"], row["requests"], row["tokens"], row["avg_latency"],
                    )
                active = await self.db.fetchval("SELECT COUNT(*) FROM sessions WHERE status = 'active'")
                logger.info("Active sessions: %s", active)
            except Exception:
                logger.exception("Error in analytics")
            await self._sleep(300)
async def main():
    """Entry point: install signal handlers and run the worker until stopped.

    SIGTERM/SIGINT flip `worker.running` to False, letting each loop finish
    its current iteration and exit gracefully.
    """
    worker = Worker()
    # get_running_loop() is the correct call inside a coroutine; the original
    # get_event_loop() is deprecated here and can bind the wrong loop.
    loop = asyncio.get_running_loop()
    for sig in (signal.SIGTERM, signal.SIGINT):
        loop.add_signal_handler(sig, lambda: setattr(worker, 'running', False))
    await worker.start()
if __name__ == "__main__":
    # Configure root logging before the event loop starts so the worker's
    # module loggers inherit this handler and format.
    log_format = "%(asctime)s | %(levelname)s | %(name)s | %(message)s"
    logging.basicConfig(level=logging.INFO, format=log_format)
    asyncio.run(main())
|