vish85521 committed on
Commit
dffa8c2
·
verified ·
1 Parent(s): 1f23a32

Upload 17 files

Browse files
Dockerfile ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# FastAPI backend image for Hugging Face Spaces (Docker runtime).
# Build context: this folder.

FROM python:3.11-slim

ENV PYTHONDONTWRITEBYTECODE=1 \
    PYTHONUNBUFFERED=1

WORKDIR /code

# System packages: compiler toolchain for wheels, Postgres client headers
# (psycopg2) and the minimal shared libraries OpenCV needs at runtime.
RUN apt-get update \
    && apt-get install -y --no-install-recommends \
        build-essential \
        gcc \
        libpq-dev \
        libgl1 \
        libglib2.0-0 \
    && rm -rf /var/lib/apt/lists/*

# Python dependencies. No requirements.txt/pyproject.toml ships with this
# folder, so the packages are installed directly (versions resolved at build).
RUN pip install --no-cache-dir --upgrade pip \
    && pip install --no-cache-dir \
        fastapi \
        "uvicorn[standard]" \
        sqlalchemy \
        psycopg2-binary \
        pydantic \
        pydantic-settings \
        bcrypt \
        "python-jose[cryptography]" \
        redis \
        celery \
        aiofiles \
        python-multipart \
        google-genai \
        opencv-python-headless

# The FastAPI package lives under /code/app so `app.main:app` imports resolve.
COPY . /code/app

# Hugging Face Spaces sets PORT; 7860 is the HF convention fallback.
EXPOSE 7860

CMD ["sh", "-c", "uvicorn app.main:app --host 0.0.0.0 --port ${PORT:-7860}"]
__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ """
2
+ Backend application module
3
+ """
4
+ from app.main import app
config.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Configuration settings loaded from environment variables
3
+ """
4
+ import os
5
+ from pydantic_settings import BaseSettings
6
+ from functools import lru_cache
7
+
8
+
9
+ # Determine .env path - check both current dir and parent dir
10
def find_env_file():
    """Locate a .env file, preferring the backend dir over its parent.

    Checks the directory two levels above this module (the backend folder),
    then that folder's parent (the project root). Falls back to the relative
    path ".env" when neither location has one.
    """
    backend_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    for candidate_dir in (backend_dir, os.path.dirname(backend_dir)):
        candidate = os.path.join(candidate_dir, ".env")
        if os.path.exists(candidate):
            return candidate
    return ".env"
24
+
25
+
26
class Settings(BaseSettings):
    """Typed application settings, populated from the environment / .env file."""

    # --- Database ---
    database_url: str = "postgresql://agentsociety:dev_password@localhost:5433/agentsociety_db"

    # --- Redis ---
    redis_url: str = "redis://localhost:6379/0"

    # --- MQTT broker ---
    mqtt_broker_host: str = "localhost"
    mqtt_broker_port: int = 1883
    mqtt_transport: str = "tcp"
    mqtt_path: str = ""

    # --- ChromaDB ---
    chroma_host: str = "localhost"
    chroma_port: int = 8000
    chroma_ssl: bool = False

    # --- AWS S3 (optional) ---
    aws_access_key_id: str = ""
    aws_secret_access_key: str = ""
    aws_s3_bucket: str = "agentsociety-videos-dev"
    aws_region: str = "ap-south-1"

    # --- Google Gemini API (used by VLM video analysis) ---
    gemini_api_key: str = ""
    gemini_api_keys: str = ""  # comma-separated keys for rotation

    # --- Qwen LLM (HuggingFace Space - Ollama API, used by simulation agents) ---
    qwen_api_url: str = "https://vish85521-doc.hf.space/api/generate"
    qwen_model_name: str = "qwen3.5:397b-cloud"

    # --- Security ---
    # NOTE(review): the default secret is a placeholder — must be overridden
    # via the environment in any real deployment.
    jwt_secret: str = "change_this_to_a_random_32_character_string"
    jwt_algorithm: str = "HS256"
    jwt_expiry_hours: int = 24

    # --- Simulation defaults ---
    default_num_agents: int = 10
    default_simulation_days: int = 5

    # --- File storage ---
    upload_dir: str = "uploads"

    class Config:
        # Resolved once at class-definition time.
        env_file = find_env_file()
        env_file_encoding = "utf-8"
73
+
74
+
75
@lru_cache()
def get_settings() -> Settings:
    """Return the application settings (constructed once, then cached).

    lru_cache makes this a process-wide singleton, so environment / .env
    parsing happens a single time.
    """
    return Settings()
78
+
database.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Database connection and session management
3
+ """
4
+ from sqlalchemy import create_engine
5
+ from sqlalchemy.ext.declarative import declarative_base
6
+ from sqlalchemy.orm import sessionmaker
7
+ from app.config import get_settings
8
+
9
+ settings = get_settings()
10
+
11
+ # Create engine
12
+ engine = create_engine(
13
+ settings.database_url,
14
+ pool_pre_ping=True,
15
+ pool_size=10,
16
+ max_overflow=20
17
+ )
18
+
19
+ # Create session factory
20
+ SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
21
+
22
+ # Base class for models
23
+ Base = declarative_base()
24
+
25
+
26
+ def get_db():
27
+ """
28
+ Dependency that provides database session
29
+ """
30
+ db = SessionLocal()
31
+ try:
32
+ yield db
33
+ finally:
34
+ db.close()
dependencies.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ FastAPI dependencies for authentication and database access
3
+ """
4
+ from fastapi import Depends, HTTPException, status
5
+ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
6
+ from sqlalchemy.orm import Session
7
+ from app.database import get_db
8
+ from app.models import User
9
+ from app.services.auth_service import decode_access_token
10
+
11
+ # Security scheme
12
+ security = HTTPBearer()
13
+
14
+
15
+ async def get_current_user(
16
+ credentials: HTTPAuthorizationCredentials = Depends(security),
17
+ db: Session = Depends(get_db)
18
+ ) -> User:
19
+ """
20
+ Dependency to get the current authenticated user from JWT token
21
+
22
+ Raises:
23
+ HTTPException 401: If token is invalid or user not found
24
+ """
25
+ token = credentials.credentials
26
+
27
+ # Decode token
28
+ payload = decode_access_token(token)
29
+ if payload is None:
30
+ raise HTTPException(
31
+ status_code=status.HTTP_401_UNAUTHORIZED,
32
+ detail="Invalid or expired token",
33
+ headers={"WWW-Authenticate": "Bearer"}
34
+ )
35
+
36
+ # Get user ID from token
37
+ user_id = payload.get("sub")
38
+ if user_id is None:
39
+ raise HTTPException(
40
+ status_code=status.HTTP_401_UNAUTHORIZED,
41
+ detail="Invalid token payload",
42
+ headers={"WWW-Authenticate": "Bearer"}
43
+ )
44
+
45
+ # Fetch user from database
46
+ user = db.query(User).filter(User.id == user_id).first()
47
+ if user is None:
48
+ raise HTTPException(
49
+ status_code=status.HTTP_401_UNAUTHORIZED,
50
+ detail="User not found",
51
+ headers={"WWW-Authenticate": "Bearer"}
52
+ )
53
+
54
+ return user
main.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ AgentSociety Marketing Platform - FastAPI Backend
3
+ """
4
+ import os
5
+ import logging
6
+ from fastapi import FastAPI
7
+ from fastapi.middleware.cors import CORSMiddleware
8
+ from contextlib import asynccontextmanager
9
+ from sqlalchemy import text
10
+
11
+ # Configure logging
12
+ logging.basicConfig(
13
+ level=logging.INFO,
14
+ format="%(asctime)s [%(levelname)s] %(name)s: %(message)s"
15
+ )
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Startup/shutdown hook: prepare storage and run the results listener."""
    from app.config import get_settings
    from app.results_listener import start_results_listener, stop_results_listener

    settings = get_settings()

    logger.info("Starting AgentSociety API...")

    # Make sure the video upload directory exists before any request arrives.
    os.makedirs(settings.upload_dir, exist_ok=True)

    # Background thread that consumes simulation results the Ray worker
    # publishes over Redis.
    start_results_listener(redis_url=settings.redis_url)
    logger.info("Results listener started - listening for Ray worker results")

    yield

    stop_results_listener()
    logger.info("Shutting down AgentSociety API...")
42
+
43
+
44
# FastAPI application. redirect_slashes is disabled because the implicit 307
# redirect would strip the Authorization header on trailing-slash requests.
app = FastAPI(
    title="AgentSociety Marketing Platform",
    description="AI-powered marketing simulation platform that simulates 1,000+ AI agents reacting to video advertisements",
    version="1.0.0",
    lifespan=lifespan,
    redirect_slashes=False
)

# With allow_credentials=True the CORS spec forbids a wildcard origin, so the
# local dev frontend/backend hosts are enumerated explicitly.
origins = [
    f"http://{host}:{port}"
    for port in (3000, 8000)
    for host in ("localhost", "127.0.0.1")
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "PATCH"],
    allow_headers=["*"],
    expose_headers=["*"],
)

# Mount the feature routers.
from app.routers import auth_router, projects_router, simulations_router

for feature_router in (auth_router, projects_router, simulations_router):
    app.include_router(feature_router)
75
+
76
+
77
+ @app.get("/")
78
+ async def root():
79
+ """Health check endpoint"""
80
+ return {
81
+ "status": "healthy",
82
+ "service": "AgentSociety API",
83
+ "version": "1.0.0"
84
+ }
85
+
86
+
87
+ @app.get("/health")
88
+ async def health_check():
89
+ """Detailed health check"""
90
+ from app.config import get_settings
91
+ settings = get_settings()
92
+
93
+ health = {
94
+ "api": "healthy",
95
+ "database": "unknown",
96
+ "redis": "unknown"
97
+ }
98
+
99
+ # Check database
100
+ try:
101
+ from app.database import engine
102
+ with engine.connect() as conn:
103
+ conn.execute(text("SELECT 1"))
104
+ health["database"] = "healthy"
105
+ except Exception as e:
106
+ health["database"] = f"unhealthy: {str(e)}"
107
+
108
+ # Check Redis
109
+ try:
110
+ import redis
111
+ import ssl as ssl_module
112
+ redis_kwargs = {}
113
+ if settings.redis_url.startswith("rediss://"):
114
+ redis_kwargs["ssl_cert_reqs"] = ssl_module.CERT_REQUIRED
115
+ r = redis.from_url(settings.redis_url, **redis_kwargs)
116
+ r.ping()
117
+ health["redis"] = "healthy"
118
+ except Exception as e:
119
+ health["redis"] = f"unhealthy: {str(e)}"
120
+
121
+ return health
models/__init__.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ SQLAlchemy models for the AgentSociety platform
3
+ """
4
+ import uuid
5
+ from datetime import datetime
6
+ from sqlalchemy import Column, String, Integer, Float, Text, DateTime, ForeignKey, JSON, BigInteger
7
+ from sqlalchemy.dialects.postgresql import UUID
8
+ from sqlalchemy.orm import relationship
9
+ from app.database import Base
10
+
11
+
12
class User(Base):
    """Registered account; owns zero or more projects."""

    __tablename__ = "users"

    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    email = Column(String(255), unique=True, nullable=False, index=True)
    # Hashed password only — plaintext is never stored.
    password_hash = Column(String(255), nullable=False)
    # Naive UTC timestamp set at insert time.
    created_at = Column(DateTime, default=datetime.utcnow)
    subscription_tier = Column(String(20), default="FREE")

    # Deleting a user removes their projects too.
    projects = relationship("Project", back_populates="user", cascade="all, delete-orphan")
23
+
24
+
25
class Project(Base):
    """A user's uploaded video advertisement plus its analysis state."""

    __tablename__ = "projects"

    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    user_id = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False)
    title = Column(String(255), nullable=False)
    # Filesystem path of the uploaded video.
    video_path = Column(String(500), nullable=False)
    video_duration_seconds = Column(Integer, nullable=True)
    # Text description produced by the VLM pipeline (filled asynchronously).
    vlm_generated_context = Column(Text, nullable=True)
    demographic_filter = Column(JSON, nullable=True)
    status = Column(String(20), default="PENDING")
    created_at = Column(DateTime, default=datetime.utcnow)

    user = relationship("User", back_populates="projects")
    simulation_runs = relationship("SimulationRun", back_populates="project", cascade="all, delete-orphan")
41
+
42
+
43
class SimulationRun(Base):
    """One execution of the agent simulation for a project."""

    __tablename__ = "simulation_runs"

    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    project_id = Column(UUID(as_uuid=True), ForeignKey("projects.id", ondelete="CASCADE"), nullable=False)
    status = Column(String(20), default="PENDING")
    num_agents = Column(Integer, default=1000)
    simulation_days = Column(Integer, default=5)
    # Aggregate outputs, populated when the run completes.
    virality_score = Column(Float, nullable=True)
    sentiment_breakdown = Column(JSON, nullable=True)
    map_data = Column(JSON, nullable=True)  # lightweight per-agent coords/opinion/friends
    agent_states = Column(JSON, nullable=True)  # full agent profile/emotion/reasoning
    started_at = Column(DateTime, nullable=True)
    completed_at = Column(DateTime, nullable=True)
    error_message = Column(Text, nullable=True)
    created_at = Column(DateTime, default=datetime.utcnow)

    project = relationship("Project", back_populates="simulation_runs")
    agent_logs = relationship("AgentLog", back_populates="simulation_run", cascade="all, delete-orphan")
    risk_flags = relationship("RiskFlag", back_populates="simulation_run", cascade="all, delete-orphan")
64
+
65
+
66
class AgentLog(Base):
    """Event emitted by a single agent during a simulation run."""

    __tablename__ = "agent_logs"

    # BigInteger autoincrement: these rows are high-volume.
    id = Column(BigInteger, primary_key=True, autoincrement=True)
    simulation_run_id = Column(UUID(as_uuid=True), ForeignKey("simulation_runs.id", ondelete="CASCADE"), nullable=False)
    agent_id = Column(String(50), nullable=False)
    timestamp = Column(DateTime, default=datetime.utcnow)
    event_type = Column(String(30), nullable=False)
    event_data = Column(JSON, nullable=False)

    simulation_run = relationship("SimulationRun", back_populates="agent_logs")
78
+
79
+
80
class RiskFlag(Base):
    """Potential brand-safety issue detected during a simulation run."""

    __tablename__ = "risk_flags"

    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    simulation_run_id = Column(UUID(as_uuid=True), ForeignKey("simulation_runs.id", ondelete="CASCADE"), nullable=False)
    flag_type = Column(String(50), nullable=False)
    severity = Column(String(10), nullable=False)
    description = Column(Text, nullable=False)
    affected_demographics = Column(JSON, nullable=True)
    sample_agent_reactions = Column(JSON, nullable=True)
    detected_at = Column(DateTime, default=datetime.utcnow)

    simulation_run = relationship("SimulationRun", back_populates="risk_flags")
results_listener.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Results Listener - Background thread that listens for simulation results from Ray worker
3
+
4
+ This runs as a background thread in the FastAPI application, listening to Redis
5
+ for simulation results and updating the database accordingly.
6
+ """
7
+ import os
8
+ import json
9
+ import logging
10
+ import threading
11
+ from datetime import datetime
12
+ from typing import Optional
13
+
14
+ import redis
15
+ from sqlalchemy.orm import Session
16
+
17
+ from app.database import SessionLocal
18
+ from app.models import SimulationRun, AgentLog, RiskFlag
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+
23
class ResultsListener:
    """Background listener for simulation results from the Ray worker.

    Runs a daemon thread subscribed to the Redis "simulation_results"
    channel and persists each published result into the database.
    """

    def __init__(self, redis_url: str = None):
        self.redis_url = redis_url or os.getenv("REDIS_URL", "redis://localhost:6379/0")
        self.running = False
        self.thread: Optional[threading.Thread] = None

    def start(self):
        """Start the listener in a background thread (no-op if running)."""
        if self.running:
            logger.warning("ResultsListener already running")
            return

        self.running = True
        self.thread = threading.Thread(target=self._listen_loop, daemon=True)
        self.thread.start()
        logger.info("ResultsListener started")

    def stop(self):
        """Signal the loop to exit and wait briefly for the thread to finish."""
        self.running = False
        if self.thread:
            self.thread.join(timeout=5)
        logger.info("ResultsListener stopped")

    def _listen_loop(self):
        """Main listening loop.

        BUG FIX: the previous implementation iterated pubsub.listen(), which
        blocks indefinitely until a message arrives, so `self.running` was
        only rechecked after the *next* message — stop() could not actually
        stop an idle listener. Polling get_message() with a short timeout
        lets the loop notice the stop flag within ~1s.
        """
        try:
            import ssl as ssl_module
            redis_kwargs = {}
            if self.redis_url.startswith("rediss://"):
                redis_kwargs["ssl_cert_reqs"] = ssl_module.CERT_REQUIRED
            redis_client = redis.from_url(self.redis_url, **redis_kwargs)
            pubsub = redis_client.pubsub()
            pubsub.subscribe("simulation_results")

            logger.info("Subscribed to 'simulation_results' channel")

            while self.running:
                # Returns None on timeout; subscribe-confirmations are
                # filtered out by the type check below (as listen() callers
                # also had to do).
                message = pubsub.get_message(timeout=1.0)
                if message is None:
                    continue

                if message['type'] == 'message':
                    try:
                        data = json.loads(message['data'])
                        self._handle_result(data)
                    except json.JSONDecodeError as e:
                        logger.error(f"Invalid JSON in result: {e}")
                    except Exception as e:
                        logger.error(f"Error handling result: {e}")

            pubsub.unsubscribe()

        except Exception as e:
            logger.error(f"ResultsListener error: {e}")
            self.running = False

    def _handle_result(self, data: dict):
        """Process one simulation result payload from the Ray worker.

        Expects at least 'simulation_id' and 'status'; a COMPLETED payload
        additionally carries a 'results' dict.
        """
        simulation_id = data.get('simulation_id')
        status = data.get('status', 'FAILED')

        if not simulation_id:
            logger.error("Result missing simulation_id")
            return

        logger.info(f"Received result for simulation {simulation_id}: {status}")

        db = SessionLocal()
        try:
            simulation = db.query(SimulationRun).filter(
                SimulationRun.id == simulation_id
            ).first()

            if not simulation:
                logger.error(f"Simulation {simulation_id} not found in database")
                return

            if status == 'COMPLETED':
                results = data.get('results', {})

                # Persist aggregate metrics first so the run is marked
                # COMPLETED even if the auxiliary saves below fail.
                simulation.status = "COMPLETED"
                simulation.completed_at = datetime.utcnow()
                simulation.virality_score = results.get('virality_score', 0)
                simulation.sentiment_breakdown = results.get('sentiment_breakdown', {})

                # Map data and agent states feed the map visualization.
                simulation.map_data = results.get('map_data', [])
                simulation.agent_states = results.get('agent_states', [])

                db.commit()
                logger.info(f"Simulation {simulation_id} marked as COMPLETED")

                # Agent logs and risk flags are best-effort: failures are
                # logged but never roll back the committed result above.
                try:
                    self._save_agent_logs(db, simulation.id, results.get('agent_logs', []))
                except Exception as e:
                    logger.warning(f"Failed to save agent logs: {e}")

                try:
                    self._save_risk_flags(db, simulation.id, results.get('risk_flags', []))
                except Exception as e:
                    logger.warning(f"Failed to save risk flags: {e}")

            else:
                # Anything that isn't COMPLETED is recorded as a failure.
                simulation.status = "FAILED"
                simulation.completed_at = datetime.utcnow()
                simulation.error_message = data.get('error', 'Unknown error')
                db.commit()
                logger.info(f"Simulation {simulation_id} marked as FAILED: {data.get('error')}")

        except Exception as e:
            logger.error(f"Database error handling result: {e}")
            db.rollback()
        finally:
            db.close()

    def _save_agent_logs(self, db: Session, simulation_id: str, logs: list):
        """Save agent logs to the database (capped at 50 rows per run)."""
        for log_data in logs[:50]:
            try:
                event_data = log_data.get('event_data', {})
                # Round-trip through json to guarantee serializability.
                if isinstance(event_data, dict):
                    event_data = json.loads(json.dumps(event_data, default=str))

                agent_log = AgentLog(
                    simulation_run_id=simulation_id,
                    agent_id=str(log_data.get('agent_id', 'unknown')),
                    event_type=str(log_data.get('event_type', 'UNKNOWN')),
                    event_data=event_data
                )
                db.add(agent_log)
            except Exception as e:
                logger.warning(f"Failed to add agent log: {e}")

        db.commit()
        logger.info(f"Saved {min(50, len(logs))} agent logs for simulation {simulation_id}")

    def _save_risk_flags(self, db: Session, simulation_id: str, flags: list):
        """Save risk flags to the database (bad entries skipped, not fatal)."""
        for flag_data in flags:
            try:
                risk_flag = RiskFlag(
                    simulation_run_id=simulation_id,
                    flag_type=str(flag_data.get('flag_type', 'UNKNOWN')),
                    severity=str(flag_data.get('severity', 'LOW')),
                    description=str(flag_data.get('description', '')),
                    affected_demographics=flag_data.get('affected_demographics'),
                    sample_agent_reactions=flag_data.get('sample_agent_reactions')
                )
                db.add(risk_flag)
            except Exception as e:
                logger.warning(f"Failed to add risk flag: {e}")

        db.commit()
        logger.info(f"Saved {len(flags)} risk flags for simulation {simulation_id}")
+
186
+
187
+ # Global instance
188
+ _results_listener: Optional[ResultsListener] = None
189
+
190
+
191
def get_results_listener() -> ResultsListener:
    """Return the module-wide listener singleton, creating it lazily."""
    global _results_listener
    if _results_listener is None:
        # First caller constructs with default Redis settings.
        _results_listener = ResultsListener()
    return _results_listener
197
+
198
+
199
def start_results_listener(redis_url: str = None):
    """Create (if needed) and start the global results listener.

    Note: redis_url only takes effect on first call — a pre-existing
    singleton keeps whatever URL it was created with.
    """
    global _results_listener
    if _results_listener is None:
        _results_listener = ResultsListener(redis_url=redis_url)
    _results_listener.start()
205
+
206
+
207
def stop_results_listener():
    """Stop the global results listener if one was ever created."""
    if _results_listener:
        _results_listener.stop()
routers/__init__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Router module exports
3
+ """
4
+ from app.routers.auth import router as auth_router
5
+ from app.routers.projects import router as projects_router
6
+ from app.routers.simulations import router as simulations_router
7
+
8
+ __all__ = ["auth_router", "projects_router", "simulations_router"]
routers/auth.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Authentication routes - register, login, and user info
3
+ """
4
+ from fastapi import APIRouter, Depends, HTTPException, status
5
+ from sqlalchemy.orm import Session
6
+ from app.database import get_db
7
+ from app.models import User
8
+ from app.schemas import UserCreate, UserLogin, UserResponse, TokenResponse
9
+ from app.services.auth_service import hash_password, verify_password, create_access_token
10
+ from app.dependencies import get_current_user
11
+
12
+ router = APIRouter(prefix="/auth", tags=["Authentication"])
13
+
14
+
15
+ @router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
16
+ async def register(user_data: UserCreate, db: Session = Depends(get_db)):
17
+ """
18
+ Register a new user account
19
+ """
20
+ # Check if email already exists
21
+ existing_user = db.query(User).filter(User.email == user_data.email).first()
22
+ if existing_user:
23
+ raise HTTPException(
24
+ status_code=status.HTTP_400_BAD_REQUEST,
25
+ detail="Email already registered"
26
+ )
27
+
28
+ # Create new user
29
+ hashed_pw = hash_password(user_data.password)
30
+ new_user = User(
31
+ email=user_data.email,
32
+ password_hash=hashed_pw
33
+ )
34
+
35
+ db.add(new_user)
36
+ db.commit()
37
+ db.refresh(new_user)
38
+
39
+ return new_user
40
+
41
+
42
+ @router.post("/login", response_model=TokenResponse)
43
+ async def login(credentials: UserLogin, db: Session = Depends(get_db)):
44
+ """
45
+ Login and receive JWT access token
46
+ """
47
+ # Find user by email
48
+ user = db.query(User).filter(User.email == credentials.email).first()
49
+
50
+ if not user or not verify_password(credentials.password, user.password_hash):
51
+ raise HTTPException(
52
+ status_code=status.HTTP_401_UNAUTHORIZED,
53
+ detail="Invalid email or password",
54
+ headers={"WWW-Authenticate": "Bearer"}
55
+ )
56
+
57
+ # Create access token
58
+ access_token = create_access_token(data={"sub": str(user.id)})
59
+
60
+ return {"access_token": access_token, "token_type": "bearer"}
61
+
62
+
63
+ @router.get("/me", response_model=UserResponse)
64
+ async def get_me(current_user: User = Depends(get_current_user)):
65
+ """
66
+ Get current user information
67
+ """
68
+ return current_user
routers/projects.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Project management routes - create, list, and get projects
3
+ """
4
+ import os
5
+ import uuid
6
+ import aiofiles
7
+ from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Form, status
8
+ from sqlalchemy.orm import Session
9
+ from typing import List, Optional
10
+ from app.database import get_db
11
+ from app.models import User, Project
12
+ from app.schemas import ProjectCreate, ProjectResponse, ProjectListResponse
13
+ from app.dependencies import get_current_user
14
+ from app.config import get_settings
15
+
16
+ settings = get_settings()
17
+ router = APIRouter(prefix="/projects", tags=["Projects"])
18
+
19
+ # Allowed video extensions
20
+ ALLOWED_EXTENSIONS = {".mp4", ".mov", ".avi", ".webm", ".mkv"}
21
+ MAX_FILE_SIZE = 500 * 1024 * 1024 # 500MB
22
+
23
+
24
+ @router.post("", response_model=ProjectResponse, status_code=status.HTTP_201_CREATED)
25
+ async def create_project(
26
+ title: str = Form(...),
27
+ demographic_filter: Optional[str] = Form(None),
28
+ video: UploadFile = File(...),
29
+ current_user: User = Depends(get_current_user),
30
+ db: Session = Depends(get_db)
31
+ ):
32
+ """
33
+ Create a new project with video upload
34
+ """
35
+ # Validate file extension
36
+ file_ext = os.path.splitext(video.filename)[1].lower()
37
+ if file_ext not in ALLOWED_EXTENSIONS:
38
+ raise HTTPException(
39
+ status_code=status.HTTP_400_BAD_REQUEST,
40
+ detail=f"Invalid file type. Allowed types: {', '.join(ALLOWED_EXTENSIONS)}"
41
+ )
42
+
43
+ # Create upload directory if it doesn't exist
44
+ upload_dir = os.path.join(settings.upload_dir, str(current_user.id))
45
+ os.makedirs(upload_dir, exist_ok=True)
46
+
47
+ # Generate unique filename
48
+ file_id = str(uuid.uuid4())
49
+ filename = f"{file_id}{file_ext}"
50
+ file_path = os.path.join(upload_dir, filename)
51
+
52
+ # Save file
53
+ try:
54
+ async with aiofiles.open(file_path, 'wb') as f:
55
+ # Read in chunks to handle large files
56
+ total_size = 0
57
+ while chunk := await video.read(1024 * 1024): # 1MB chunks
58
+ total_size += len(chunk)
59
+ if total_size > MAX_FILE_SIZE:
60
+ os.remove(file_path)
61
+ raise HTTPException(
62
+ status_code=status.HTTP_400_BAD_REQUEST,
63
+ detail=f"File too large. Maximum size: {MAX_FILE_SIZE // (1024*1024)}MB"
64
+ )
65
+ await f.write(chunk)
66
+ except Exception as e:
67
+ if os.path.exists(file_path):
68
+ os.remove(file_path)
69
+ raise HTTPException(
70
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
71
+ detail=f"Failed to save file: {str(e)}"
72
+ )
73
+
74
+ # Parse demographic filter
75
+ demo_filter = None
76
+ if demographic_filter:
77
+ import json
78
+ try:
79
+ demo_filter = json.loads(demographic_filter)
80
+ except json.JSONDecodeError:
81
+ pass
82
+
83
+ # Create project record
84
+ project = Project(
85
+ user_id=current_user.id,
86
+ title=title,
87
+ video_path=file_path,
88
+ demographic_filter=demo_filter,
89
+ status="PENDING"
90
+ )
91
+
92
+ db.add(project)
93
+ db.commit()
94
+ db.refresh(project)
95
+
96
+ # Queue VLM processing task (async)
97
+ from app.tasks import process_video_task
98
+ process_video_task.delay(str(project.id))
99
+
100
+ return project
101
+
102
+
103
+ @router.get("", response_model=List[ProjectListResponse])
104
+ async def list_projects(
105
+ current_user: User = Depends(get_current_user),
106
+ db: Session = Depends(get_db)
107
+ ):
108
+ """
109
+ List all projects for the current user
110
+ """
111
+ projects = db.query(Project).filter(
112
+ Project.user_id == current_user.id
113
+ ).order_by(Project.created_at.desc()).all()
114
+
115
+ return projects
116
+
117
+
118
+ @router.get("/{project_id}", response_model=ProjectResponse)
119
+ async def get_project(
120
+ project_id: str,
121
+ current_user: User = Depends(get_current_user),
122
+ db: Session = Depends(get_db)
123
+ ):
124
+ """
125
+ Get a specific project by ID
126
+ """
127
+ project = db.query(Project).filter(
128
+ Project.id == project_id,
129
+ Project.user_id == current_user.id
130
+ ).first()
131
+
132
+ if not project:
133
+ raise HTTPException(
134
+ status_code=status.HTTP_404_NOT_FOUND,
135
+ detail="Project not found"
136
+ )
137
+
138
+ return project
139
+
140
+
141
+ @router.delete("/{project_id}", status_code=status.HTTP_204_NO_CONTENT)
142
+ async def delete_project(
143
+ project_id: str,
144
+ current_user: User = Depends(get_current_user),
145
+ db: Session = Depends(get_db)
146
+ ):
147
+ """
148
+ Delete a project and its video file
149
+ """
150
+ project = db.query(Project).filter(
151
+ Project.id == project_id,
152
+ Project.user_id == current_user.id
153
+ ).first()
154
+
155
+ if not project:
156
+ raise HTTPException(
157
+ status_code=status.HTTP_404_NOT_FOUND,
158
+ detail="Project not found"
159
+ )
160
+
161
+ # Delete video file
162
+ if os.path.exists(project.video_path):
163
+ os.remove(project.video_path)
164
+
165
+ # Delete project record
166
+ db.delete(project)
167
+ db.commit()
routers/simulations.py ADDED
@@ -0,0 +1,328 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Simulation management routes - start, status, and results
3
+ """
4
+ from fastapi import APIRouter, Depends, HTTPException, status
5
+ from sqlalchemy.orm import Session
6
+ from typing import List
7
+ from datetime import datetime
8
+ from app.database import get_db
9
+ from app.models import User, Project, SimulationRun, RiskFlag, AgentLog
10
+ from app.schemas import (
11
+ SimulationCreate,
12
+ SimulationResponse,
13
+ SimulationStatusResponse,
14
+ SimulationResultsResponse,
15
+ RiskFlagResponse,
16
+ MapDataResponse,
17
+ AgentDetailResponse,
18
+ AgentProfileData,
19
+ )
20
+ from app.dependencies import get_current_user
21
+ from app.config import get_settings
22
+
23
+ settings = get_settings()
24
+ router = APIRouter(prefix="/simulations", tags=["Simulations"])
25
+
26
+
27
@router.post("/{project_id}/start", response_model=SimulationResponse)
async def start_simulation(
    project_id: str,
    config: SimulationCreate = None,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """
    Create a simulation run for a READY project and queue it for
    background execution.
    """
    # The project must exist and belong to the caller.
    project = (
        db.query(Project)
        .filter(
            Project.id == project_id,
            Project.user_id == current_user.id,
        )
        .first()
    )
    if project is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Project not found"
        )

    # Simulation requires the VLM analysis to have finished.
    if project.status != "READY":
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Project is not ready for simulation. Current status: {project.status}"
        )

    # Fall back to default agent/day counts when no body was sent.
    effective = config if config is not None else SimulationCreate()

    simulation = SimulationRun(
        project_id=project.id,
        num_agents=effective.num_agents,
        simulation_days=effective.simulation_days,
        status="PENDING"
    )
    db.add(simulation)
    db.commit()
    db.refresh(simulation)

    # Local import — presumably avoids an import cycle with app.tasks.
    from app.tasks import run_simulation_task
    run_simulation_task.delay(str(simulation.id))

    return simulation
77
+
78
+
79
@router.get("/{simulation_id}", response_model=SimulationResponse)
async def get_simulation(
    simulation_id: str,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """
    Return one simulation; ownership is enforced by joining through
    the parent project's user_id.
    """
    owned = (
        db.query(SimulationRun)
        .join(Project)
        .filter(
            SimulationRun.id == simulation_id,
            Project.user_id == current_user.id,
        )
    )
    simulation = owned.first()
    if simulation is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Simulation not found"
        )
    return simulation
100
+
101
+
102
@router.get("/{simulation_id}/status", response_model=SimulationStatusResponse)
async def get_simulation_status(
    simulation_id: str,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """
    Get current simulation status (intended for client-side polling).

    For RUNNING simulations, progress details are read from a Redis
    key maintained by the worker; when Redis is unreachable or the
    payload is malformed the endpoint degrades gracefully and returns
    the status with no progress fields.
    """
    simulation = db.query(SimulationRun).join(Project).filter(
        SimulationRun.id == simulation_id,
        Project.user_id == current_user.id
    ).first()

    if not simulation:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Simulation not found"
        )

    progress = None
    current_day = None
    active_agents = None

    if simulation.status == "RUNNING":
        # Local imports keep redis off the hot path for requests that
        # never touch a running simulation.
        import json
        import redis
        import ssl as ssl_module
        try:
            redis_kwargs = {}
            # Managed-Redis TLS URLs (e.g. Upstash) need explicit cert options.
            if settings.redis_url.startswith("rediss://"):
                redis_kwargs["ssl_cert_reqs"] = ssl_module.CERT_REQUIRED
            r = redis.from_url(settings.redis_url, **redis_kwargs)
            cached = r.get(f"sim:{simulation_id}:status")
            if cached:
                status_data = json.loads(cached)
                progress = status_data.get("progress")
                current_day = status_data.get("current_day")
                active_agents = status_data.get("active_agents")
        except Exception:
            # Progress is best-effort; never fail the poll over cache
            # issues. (Bug fix: the former bare `except:` also swallowed
            # KeyboardInterrupt/SystemExit.)
            pass
    elif simulation.status == "COMPLETED":
        progress = 100

    return SimulationStatusResponse(
        id=simulation.id,
        status=simulation.status,
        progress=progress,
        current_day=current_day,
        active_agents=active_agents
    )
155
+
156
+
157
@router.get("/{simulation_id}/results", response_model=SimulationResultsResponse)
async def get_simulation_results(
    simulation_id: str,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """
    Return the full results of a COMPLETED simulation: the run record,
    its risk flags (highest severity first), and a small sample of
    notable agent events.
    """
    simulation = (
        db.query(SimulationRun)
        .join(Project)
        .filter(
            SimulationRun.id == simulation_id,
            Project.user_id == current_user.id,
        )
        .first()
    )
    if simulation is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Simulation not found"
        )
    if simulation.status != "COMPLETED":
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Simulation has not completed yet"
        )

    risk_flags = (
        db.query(RiskFlag)
        .filter(RiskFlag.simulation_run_id == simulation.id)
        .order_by(RiskFlag.severity.desc())
        .all()
    )

    # Up to ten boycott/endorsement events serve as a qualitative sample.
    notable_events = (
        db.query(AgentLog)
        .filter(
            AgentLog.simulation_run_id == simulation.id,
            AgentLog.event_type.in_(["BOYCOTT", "ENDORSEMENT"]),
        )
        .limit(10)
        .all()
    )
    sample_data = [
        {
            "agent_id": log.agent_id,
            "event_type": log.event_type,
            "event_data": log.event_data,
            "timestamp": log.timestamp.isoformat(),
        }
        for log in notable_events
    ]

    return SimulationResultsResponse(
        simulation=simulation,
        risk_flags=risk_flags,
        agent_sample=sample_data,
    )
209
+
210
+
211
@router.get("/project/{project_id}", response_model=List[SimulationResponse])
async def list_project_simulations(
    project_id: str,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """
    List every simulation run of a project, newest first.
    """
    # Ownership check: 404 whether the project is missing or foreign.
    project = (
        db.query(Project)
        .filter(
            Project.id == project_id,
            Project.user_id == current_user.id,
        )
        .first()
    )
    if project is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Project not found"
        )

    runs = (
        db.query(SimulationRun)
        .filter(SimulationRun.project_id == project_id)
        .order_by(SimulationRun.created_at.desc())
    )
    return runs.all()
236
+
237
+
238
@router.get("/{simulation_id}/map-data", response_model=MapDataResponse)
async def get_simulation_map_data(
    simulation_id: str,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """
    Lightweight per-agent map payload (coordinates, opinion, friends)
    for a completed simulation.
    """
    simulation = (
        db.query(SimulationRun)
        .join(Project)
        .filter(
            SimulationRun.id == simulation_id,
            Project.user_id == current_user.id,
        )
        .first()
    )
    if simulation is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Simulation not found"
        )
    if simulation.status != "COMPLETED":
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Simulation has not completed yet"
        )

    # map_data may be NULL in the database; normalize to an empty list.
    if simulation.map_data:
        return MapDataResponse(map_data=simulation.map_data)
    return MapDataResponse(map_data=[])
267
+
268
+
269
@router.get("/{simulation_id}/agents/{agent_id}", response_model=AgentDetailResponse)
async def get_agent_detail(
    simulation_id: str,
    agent_id: str,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """
    Drill-down view for a single simulated agent: profile, emotion,
    opinion, reasoning, and friend list.
    """
    simulation = (
        db.query(SimulationRun)
        .join(Project)
        .filter(
            SimulationRun.id == simulation_id,
            Project.user_id == current_user.id,
        )
        .first()
    )
    if simulation is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Simulation not found"
        )
    if simulation.status != "COMPLETED":
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Simulation has not completed yet"
        )

    # Locate the agent in the persisted per-agent state list.
    states = simulation.agent_states or []
    agent_data = next(
        (state for state in states if state.get("agent_id") == agent_id),
        None,
    )
    if agent_data is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Agent {agent_id} not found in simulation"
        )

    profile_raw = agent_data.get("profile", {})
    return AgentDetailResponse(
        agent_id=agent_data["agent_id"],
        coordinates=agent_data.get("coordinates", [0, 0]),
        opinion=agent_data.get("opinion", "NEUTRAL"),
        emotion=agent_data.get("emotion", "neutral"),
        emotion_intensity=agent_data.get("emotion_intensity", 0),
        reasoning=agent_data.get("reasoning", ""),
        friends=agent_data.get("friends", []),
        profile=AgentProfileData(
            age=profile_raw.get("age"),
            gender=profile_raw.get("gender"),
            location=profile_raw.get("location"),
            occupation=profile_raw.get("occupation"),
            education=profile_raw.get("education"),
            values=profile_raw.get("values", []),
        ),
    )
schemas.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Pydantic schemas for request/response validation
3
+ """
4
+ from pydantic import BaseModel, EmailStr, Field
5
+ from typing import Optional, Dict, Any, List
6
+ from datetime import datetime
7
+ from uuid import UUID
8
+
9
+
10
# ----- User Schemas -----
class UserCreate(BaseModel):
    """Registration payload: a valid email plus an 8-100 char password."""
    email: EmailStr
    password: str = Field(min_length=8, max_length=100)


class UserLogin(BaseModel):
    """Login payload; no length constraints so legacy passwords still validate."""
    email: EmailStr
    password: str


class UserResponse(BaseModel):
    """Public representation of a user account (no password hash)."""
    id: UUID
    email: str
    subscription_tier: str
    created_at: datetime

    class Config:
        # Allow construction directly from ORM model instances.
        from_attributes = True


class TokenResponse(BaseModel):
    """JWT bearer token envelope returned by the auth endpoints."""
    access_token: str
    token_type: str = "bearer"
34
+
35
+
36
# ----- Project Schemas -----
class DemographicFilter(BaseModel):
    """Optional audience-targeting constraints applied to agent generation."""
    age_range: Optional[List[int]] = None  # presumably [min, max] — confirm with agent generator
    location: Optional[str] = None
    gender: Optional[str] = None
    values: Optional[List[str]] = None


class ProjectCreate(BaseModel):
    """Metadata supplied alongside the video upload."""
    title: str = Field(min_length=1, max_length=200)
    demographic_filter: Optional[DemographicFilter] = None


class ProjectResponse(BaseModel):
    """Full project detail, including the VLM-generated video context."""
    id: UUID
    title: str
    video_path: str
    video_duration_seconds: Optional[int]
    vlm_generated_context: Optional[str]
    demographic_filter: Optional[Dict[str, Any]]
    status: str  # e.g. UPLOADED/PROCESSING/READY/FAILED per the routers' usage
    created_at: datetime

    class Config:
        # Allow construction directly from ORM model instances.
        from_attributes = True


class ProjectListResponse(BaseModel):
    """Slim projection used by the project list endpoint."""
    id: UUID
    title: str
    status: str
    created_at: datetime

    class Config:
        # Allow construction directly from ORM model instances.
        from_attributes = True
71
+
72
+
73
# ----- Simulation Schemas -----
class SimulationCreate(BaseModel):
    """Tunable parameters for a simulation run; defaults give a quick run."""
    num_agents: int = Field(default=10, ge=1, le=10000)
    simulation_days: int = Field(default=5, ge=1, le=30)


class SentimentBreakdown(BaseModel):
    """Agent counts per sentiment bucket."""
    positive: int = 0
    neutral: int = 0
    negative: int = 0


class RiskFlagResponse(BaseModel):
    """A single detected risk, with the demographics and reactions behind it."""
    id: UUID
    flag_type: str
    severity: str
    description: str
    affected_demographics: Optional[Dict[str, Any]]
    sample_agent_reactions: Optional[List[Dict[str, Any]]]
    detected_at: datetime

    class Config:
        # Allow construction directly from ORM model instances.
        from_attributes = True


class SimulationResponse(BaseModel):
    """Core simulation run record, including aggregate outcome metrics."""
    id: UUID
    project_id: UUID
    status: str  # PENDING/RUNNING/COMPLETED/FAILED per the routers' usage
    num_agents: int
    simulation_days: int
    virality_score: Optional[float]
    sentiment_breakdown: Optional[Dict[str, int]]
    started_at: Optional[datetime]
    completed_at: Optional[datetime]
    error_message: Optional[str]
    created_at: datetime

    class Config:
        # Allow construction directly from ORM model instances.
        from_attributes = True


class SimulationStatusResponse(BaseModel):
    """Polling payload; progress fields are only populated while RUNNING."""
    id: UUID
    status: str
    progress: Optional[int] = None
    current_day: Optional[int] = None
    active_agents: Optional[int] = None


class SimulationResultsResponse(BaseModel):
    """Completed-run results: the run, its risk flags, and sample agent events."""
    simulation: SimulationResponse
    risk_flags: List[RiskFlagResponse]
    agent_sample: Optional[List[Dict[str, Any]]] = None
127
+
128
+
129
# ----- Map Visualization Schemas -----
class MapAgentData(BaseModel):
    """Lightweight per-agent record for the map view."""
    agent_id: str
    coordinates: List[float]
    opinion: str
    friends: List[str]


class MapDataResponse(BaseModel):
    """Container for all agents' map records."""
    map_data: List[MapAgentData]


class AgentProfileData(BaseModel):
    """Demographic profile of a simulated agent; every field is optional."""
    age: Optional[int] = None
    gender: Optional[str] = None
    location: Optional[str] = None
    occupation: Optional[str] = None
    education: Optional[str] = None
    # Use an explicit factory rather than a shared `[]` literal default.
    values: List[str] = Field(default_factory=list)


class AgentDetailResponse(BaseModel):
    """Full detail for a single-agent drill-down."""
    agent_id: str
    coordinates: List[float]
    opinion: str
    emotion: str
    emotion_intensity: float = 0
    reasoning: str = ""
    friends: List[str] = Field(default_factory=list)
    profile: AgentProfileData
159
+
services/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Service module exports
3
+ """
4
+ from app.services.auth_service import hash_password, verify_password, create_access_token
5
+ from app.services.vlm_service import process_video, analyze_video_with_gemini
6
+
7
+ __all__ = [
8
+ "hash_password",
9
+ "verify_password",
10
+ "create_access_token",
11
+ "process_video",
12
+ "analyze_video_with_gemini"
13
+ ]
services/auth_service.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Authentication service - password hashing and JWT token management
3
+ """
4
+ from datetime import datetime, timedelta
5
+ from typing import Optional
6
+ import bcrypt
7
+ from jose import JWTError, jwt
8
+ from app.config import get_settings
9
+
10
+ settings = get_settings()
11
+
12
+
13
def hash_password(password: str) -> str:
    """Return a bcrypt hash of *password* (uses bcrypt directly; no passlib)."""
    raw = password.encode("utf-8")
    return bcrypt.hashpw(raw, bcrypt.gensalt()).decode("utf-8")
19
+
20
+
21
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """
    Check *plain_password* against its stored bcrypt hash.

    Returns False (rather than raising) on malformed hashes so callers
    can treat any failure as an authentication miss.
    """
    try:
        candidate = plain_password.encode("utf-8")
        stored = hashed_password.encode("utf-8")
        return bcrypt.checkpw(candidate, stored)
    except Exception:
        return False
30
+
31
+
32
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """
    Create a signed JWT access token.

    Args:
        data: Payload to encode (should include 'sub' with user ID)
        expires_delta: Token expiry time; defaults to
            settings.jwt_expiry_hours when omitted.

    Returns:
        Encoded JWT token
    """
    # Local import: the module header only pulls datetime/timedelta.
    from datetime import timezone

    to_encode = data.copy()

    # Use a timezone-aware UTC timestamp — datetime.utcnow() is
    # deprecated (Python 3.12+) and produces naive datetimes.
    now = datetime.now(timezone.utc)
    if expires_delta:
        expire = now + expires_delta
    else:
        expire = now + timedelta(hours=settings.jwt_expiry_hours)

    to_encode.update({"exp": expire})

    encoded_jwt = jwt.encode(
        to_encode,
        settings.jwt_secret,
        algorithm=settings.jwt_algorithm
    )

    return encoded_jwt
59
+
60
+
61
def decode_access_token(token: str) -> Optional[dict]:
    """
    Decode and validate a JWT token.

    Args:
        token: JWT token string

    Returns:
        The payload dict when the signature and expiry check out,
        otherwise None.
    """
    try:
        return jwt.decode(
            token,
            settings.jwt_secret,
            algorithms=[settings.jwt_algorithm],
        )
    except JWTError:
        return None
services/vlm_service.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Video processing service using Google Gemini Files API
3
+ Uploads entire video instead of extracting frames for better analysis
4
+ """
5
+ import os
6
+ import time
7
+ import logging
8
+ from typing import Tuple
9
+ from google import genai
10
+ from app.config import get_settings
11
+
12
+ settings = get_settings()
13
+ logger = logging.getLogger(__name__)
14
+
15
+ # Configure Gemini client
16
+ _client = None
17
+
18
def get_client():
    """Return the shared Gemini client, creating it lazily on first use."""
    global _client
    if _client is not None:
        return _client
    _client = genai.Client(api_key=settings.gemini_api_key)
    return _client
24
+
25
+
26
def get_video_duration_cv2(video_path: str) -> int:
    """
    Get video duration in whole seconds using OpenCV.

    Returns 0 when the file cannot be opened or its metadata is
    missing/invalid, so callers can treat the duration as best-effort.
    """
    try:
        import cv2
        video = cv2.VideoCapture(video_path)
        try:
            # Bug fix: VideoCapture does not raise on a bad path; it just
            # reports zeroed metadata. Check the handle explicitly.
            if not video.isOpened():
                logger.warning(f"Could not open video: {video_path}")
                return 0
            fps = video.get(cv2.CAP_PROP_FPS)
            frame_count = video.get(cv2.CAP_PROP_FRAME_COUNT)
        finally:
            # Always release the capture handle, even on error.
            video.release()

        if fps > 0:
            return int(frame_count / fps)
    except Exception as e:
        logger.warning(f"Could not get video duration: {e}")
    return 0
40
+
41
+
42
def analyze_video_with_gemini(video_path: str) -> str:
    """
    Analyze video using Gemini Files API

    Uploads the entire video file to Gemini for comprehensive analysis,
    which provides better context than frame extraction, then prompts
    the model for a detailed advertisement breakdown.

    Args:
        video_path: Local path to video file

    Returns:
        Scene descriptions and analysis, or a human-readable error
        string on failure (this function never raises).
    """
    try:
        client = get_client()

        # Step 1: Upload video to Gemini
        logger.info(f"Uploading video to Gemini: {video_path}")
        uploaded_file = client.files.upload(file=video_path)
        logger.info(f"Uploaded file: {uploaded_file.name}")

        # Step 2: Poll until server-side processing finishes (bounded wait).
        # NOTE(review): `state` may be an enum in newer google-genai
        # releases; this string comparison relies on it coercing — confirm.
        max_wait = 120  # Maximum 2 minutes wait
        wait_time = 0

        while uploaded_file.state == "PROCESSING":
            if wait_time >= max_wait:
                logger.error("Video processing timeout")
                return "Video processing timeout. Please try with a shorter video."

            logger.info(f"Waiting for video processing... ({wait_time}s)")
            time.sleep(10)
            wait_time += 10
            uploaded_file = client.files.get(name=uploaded_file.name)

        if uploaded_file.state == "FAILED":
            logger.error("Video processing failed on Gemini side")
            return "Video processing failed. Please try a different video format."

        logger.info("Video processing complete, analyzing content...")

        # Step 3: Analyze with Gemini
        prompt = """Analyze this video advertisement in detail. Describe:

1. **Visual Elements**: People shown (ages, genders, relationships), settings, products
2. **Cultural Context**: Any cultural symbols, traditions, or practices shown
3. **Narrative**: What story is being told? What emotions are evoked?
4. **Target Audience**: Who appears to be the intended audience?
5. **Potential Sensitivities**: Any elements that might be controversial or offensive to certain groups (religious, political, cultural, age-related)
6. **Message**: What is the main message or call to action?

Be specific and detailed. Focus on elements that could affect how different demographic groups might react to this advertisement."""

        response = client.models.generate_content(
            model="gemini-3.1-flash-lite-preview",
            contents=[uploaded_file, prompt]
        )

        # Step 4: Clean up uploaded file (optional - files expire automatically)
        try:
            client.files.delete(name=uploaded_file.name)
            logger.info("Cleaned up uploaded video file")
        except Exception:
            # Bug fix: the former bare `except:` also swallowed
            # KeyboardInterrupt/SystemExit. Cleanup stays best-effort.
            pass

        if response.text:
            logger.info(f"Generated analysis: {len(response.text)} characters")
            return response.text
        else:
            return "No analysis could be generated for this video."

    except Exception as e:
        logger.error(f"Gemini video analysis failed: {e}")
        return f"Video analysis failed: {str(e)}"
116
+
117
+
118
def process_video(video_path: str) -> Tuple[str, int]:
    """
    Run the full video pipeline: duration probe plus Gemini analysis.

    Args:
        video_path: Path to video file

    Returns:
        Tuple of (scene_descriptions, duration_seconds)
    """
    logger.info(f"Processing video: {video_path}")

    duration_seconds = get_video_duration_cv2(video_path)
    analysis = analyze_video_with_gemini(video_path)

    return analysis, duration_seconds
tasks.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Celery tasks for background processing
3
+ """
4
+ import os
5
+ import sys
6
+ import logging
7
+ from celery import Celery
8
+ from datetime import datetime
9
+ from app.config import get_settings
10
+
11
# Make the sibling "simulation" package importable by putting the
# project root on sys.path (it is the parent of both backend/ and
# simulation/).
backend_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
project_root = os.path.dirname(backend_dir)
simulation_path = os.path.join(project_root, "simulation")
# Bug fix: guard on the path that is actually inserted (project_root);
# the old check tested simulation_path, so duplicate inserts were
# possible and the insert could be wrongly skipped.
if project_root not in sys.path:
    sys.path.insert(0, project_root)
17
+
18
settings = get_settings()
logger = logging.getLogger(__name__)

# Initialize Celery with Redis serving as both broker and result backend.
celery_app = Celery(
    "agentsociety",
    broker=settings.redis_url,
    backend=settings.redis_url
)

import ssl as ssl_module

celery_app.conf.update(
    task_serializer="json",
    accept_content=["json"],
    result_serializer="json",
    timezone="UTC",
    enable_utc=True,
    task_track_started=True,
    task_time_limit=3600,  # 1 hour max per task, hard limit
    broker_connection_retry_on_startup=True,  # silence CPendingDeprecationWarning in Celery 5/6
)

# Enable SSL for rediss:// connections (e.g. Upstash managed Redis);
# CERT_REQUIRED enforces certificate verification on both broker and backend.
if settings.redis_url.startswith("rediss://"):
    celery_app.conf.update(
        broker_use_ssl={"ssl_cert_reqs": ssl_module.CERT_REQUIRED},
        redis_backend_use_ssl={"ssl_cert_reqs": ssl_module.CERT_REQUIRED},
    )
47
+
48
+
49
@celery_app.task(bind=True)
def process_video_task(self, project_id: str):
    """
    Background task to process a project's video with the VLM.

    Marks the project PROCESSING, runs the analysis, stores the
    generated context and duration, and flips the status to READY —
    or FAILED on any error.
    """
    from app.database import SessionLocal
    from app.models import Project
    from app.services.vlm_service import process_video

    db = SessionLocal()

    try:
        # Get project
        project = db.query(Project).filter(Project.id == project_id).first()
        if not project:
            logger.error(f"Project not found: {project_id}")
            return {"error": "Project not found"}

        # Update status
        project.status = "PROCESSING"
        db.commit()

        logger.info(f"Processing video for project {project_id}")

        # Process video
        descriptions, duration = process_video(project.video_path)

        # Update project with results
        project.vlm_generated_context = descriptions
        project.video_duration_seconds = duration
        project.status = "READY"
        db.commit()

        logger.info(f"Video processing complete for project {project_id}")

        return {
            "project_id": project_id,
            "status": "READY",
            "duration": duration
        }

    except Exception as e:
        logger.error(f"Video processing failed for project {project_id}: {e}")
        try:
            # The session may be in a failed transaction; roll back first
            # so the FAILED marker can actually be committed.
            db.rollback()
            project = db.query(Project).filter(Project.id == project_id).first()
            if project:
                project.status = "FAILED"
                db.commit()
        except Exception as mark_err:
            # Bug fix: the former bare `except: pass` also swallowed
            # SystemExit and hid the secondary failure entirely.
            logger.error(f"Could not mark project {project_id} FAILED: {mark_err}")
        return {"error": str(e)}
    finally:
        db.close()
102
+
103
+
104
@celery_app.task(bind=True)
def run_simulation_task(self, simulation_id: str):
    """
    Background task to queue simulation for Ray worker

    This task:
    1. Validates the simulation and project
    2. Publishes request to Redis 'simulation_requests' channel
    3. Ray worker (separate process) handles the actual simulation
    4. Results listener updates the database when complete
    """
    from app.database import SessionLocal
    from app.models import SimulationRun, Project
    import json
    import redis

    db = SessionLocal()
    redis_kwargs = {}
    # TLS options for managed Redis (e.g. Upstash) rediss:// URLs.
    if settings.redis_url.startswith("rediss://"):
        redis_kwargs["ssl_cert_reqs"] = ssl_module.CERT_REQUIRED
    redis_client = redis.from_url(settings.redis_url, **redis_kwargs)

    try:
        # Get simulation
        simulation = db.query(SimulationRun).filter(SimulationRun.id == simulation_id).first()
        if not simulation:
            logger.error(f"Simulation not found: {simulation_id}")
            return {"error": "Simulation not found"}

        # Get project; it must have completed VLM analysis.
        project = db.query(Project).filter(Project.id == simulation.project_id).first()
        if not project or not project.vlm_generated_context:
            logger.error(f"Project not ready for simulation: {simulation.project_id}")
            simulation.status = "FAILED"
            simulation.error_message = "Project video analysis not complete"
            db.commit()
            return {"error": "Project not ready"}

        # Update status to RUNNING (Ray worker will process)
        simulation.status = "RUNNING"
        simulation.started_at = datetime.utcnow()
        db.commit()

        logger.info(f"Sending simulation {simulation_id} to Ray worker")

        # Publish request to Redis for Ray worker
        request = {
            "simulation_id": str(simulation.id),
            "project_id": str(project.id),
            "ad_content": project.vlm_generated_context,
            "demographic_filter": project.demographic_filter,
            "num_agents": simulation.num_agents,
            "simulation_days": simulation.simulation_days
        }

        redis_client.publish("simulation_requests", json.dumps(request))

        logger.info(f"Simulation {simulation_id} published to Ray worker queue")

        return {
            "simulation_id": simulation_id,
            "status": "RUNNING",
            "message": "Simulation sent to Ray worker for processing"
        }

    except Exception as e:
        logger.error(f"Failed to queue simulation {simulation_id}: {e}")
        try:
            # Roll back any failed transaction before recording FAILED.
            db.rollback()
            simulation = db.query(SimulationRun).filter(SimulationRun.id == simulation_id).first()
            if simulation:
                simulation.status = "FAILED"
                simulation.error_message = str(e)
                simulation.completed_at = datetime.utcnow()
                db.commit()
        except Exception as mark_err:
            # Bug fix: the former bare `except: pass` also swallowed
            # SystemExit and hid the secondary failure entirely.
            logger.error(f"Could not mark simulation {simulation_id} FAILED: {mark_err}")
        return {"error": str(e)}
    finally:
        db.close()