Upload 17 files
Browse files- Dockerfile +46 -0
- __init__.py +4 -0
- config.py +78 -0
- database.py +34 -0
- dependencies.py +54 -0
- main.py +121 -0
- models/__init__.py +93 -0
- results_listener.py +210 -0
- routers/__init__.py +8 -0
- routers/auth.py +68 -0
- routers/projects.py +167 -0
- routers/simulations.py +328 -0
- schemas.py +159 -0
- services/__init__.py +13 -0
- services/auth_service.py +79 -0
- services/vlm_service.py +136 -0
- tasks.py +182 -0
Dockerfile
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Hugging Face Spaces (Docker) - FastAPI backend
|
| 2 |
+
# Build context: this folder
|
| 3 |
+
|
| 4 |
+
FROM python:3.11-slim
|
| 5 |
+
|
| 6 |
+
ENV PYTHONDONTWRITEBYTECODE=1 \
|
| 7 |
+
PYTHONUNBUFFERED=1
|
| 8 |
+
|
| 9 |
+
WORKDIR /code
|
| 10 |
+
|
| 11 |
+
# System deps: build tools + Postgres client libs (psycopg2) + minimal OpenCV runtime deps
|
| 12 |
+
RUN apt-get update \
|
| 13 |
+
&& apt-get install -y --no-install-recommends \
|
| 14 |
+
build-essential \
|
| 15 |
+
gcc \
|
| 16 |
+
libpq-dev \
|
| 17 |
+
libgl1 \
|
| 18 |
+
libglib2.0-0 \
|
| 19 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 20 |
+
|
| 21 |
+
# Install Python dependencies.
|
| 22 |
+
# Note: This repo folder doesn't include requirements.txt/pyproject.toml, so we install pinned-at-runtime deps directly.
|
| 23 |
+
RUN pip install --no-cache-dir --upgrade pip \
|
| 24 |
+
&& pip install --no-cache-dir \
|
| 25 |
+
fastapi \
|
| 26 |
+
"uvicorn[standard]" \
|
| 27 |
+
sqlalchemy \
|
| 28 |
+
psycopg2-binary \
|
| 29 |
+
pydantic \
|
| 30 |
+
pydantic-settings \
|
| 31 |
+
bcrypt \
|
| 32 |
+
"python-jose[cryptography]" \
|
| 33 |
+
redis \
|
| 34 |
+
celery \
|
| 35 |
+
aiofiles \
|
| 36 |
+
python-multipart \
|
| 37 |
+
google-genai \
|
| 38 |
+
opencv-python-headless
|
| 39 |
+
|
| 40 |
+
# Copy your FastAPI package into /code/app so imports like `app.main:app` work.
|
| 41 |
+
COPY . /code/app
|
| 42 |
+
|
| 43 |
+
# Hugging Face Spaces sets PORT; default to 7860 (HF convention) if not present.
|
| 44 |
+
EXPOSE 7860
|
| 45 |
+
|
| 46 |
+
CMD ["sh", "-c", "uvicorn app.main:app --host 0.0.0.0 --port ${PORT:-7860}"]
|
__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Backend application module
|
| 3 |
+
"""
|
| 4 |
+
from app.main import app
|
config.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Configuration settings loaded from environment variables
|
| 3 |
+
"""
|
| 4 |
+
import os
|
| 5 |
+
from pydantic_settings import BaseSettings
|
| 6 |
+
from functools import lru_cache
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
# Determine .env path - check both current dir and parent dir
|
| 10 |
+
def find_env_file():
|
| 11 |
+
"""Find .env file in current or parent directory"""
|
| 12 |
+
current_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
| 13 |
+
|
| 14 |
+
# Check backend directory
|
| 15 |
+
if os.path.exists(os.path.join(current_dir, ".env")):
|
| 16 |
+
return os.path.join(current_dir, ".env")
|
| 17 |
+
|
| 18 |
+
# Check parent directory (agent-society-platform)
|
| 19 |
+
parent_dir = os.path.dirname(current_dir)
|
| 20 |
+
if os.path.exists(os.path.join(parent_dir, ".env")):
|
| 21 |
+
return os.path.join(parent_dir, ".env")
|
| 22 |
+
|
| 23 |
+
return ".env"
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class Settings(BaseSettings):
|
| 27 |
+
# Database
|
| 28 |
+
database_url: str = "postgresql://agentsociety:dev_password@localhost:5433/agentsociety_db"
|
| 29 |
+
|
| 30 |
+
# Redis
|
| 31 |
+
redis_url: str = "redis://localhost:6379/0"
|
| 32 |
+
|
| 33 |
+
# MQTT
|
| 34 |
+
mqtt_broker_host: str = "localhost"
|
| 35 |
+
mqtt_broker_port: int = 1883
|
| 36 |
+
mqtt_transport: str = "tcp"
|
| 37 |
+
mqtt_path: str = ""
|
| 38 |
+
|
| 39 |
+
# ChromaDB
|
| 40 |
+
chroma_host: str = "localhost"
|
| 41 |
+
chroma_port: int = 8000
|
| 42 |
+
chroma_ssl: bool = False
|
| 43 |
+
|
| 44 |
+
# AWS S3 (optional)
|
| 45 |
+
aws_access_key_id: str = ""
|
| 46 |
+
aws_secret_access_key: str = ""
|
| 47 |
+
aws_s3_bucket: str = "agentsociety-videos-dev"
|
| 48 |
+
aws_region: str = "ap-south-1"
|
| 49 |
+
|
| 50 |
+
# Google Gemini API (used by VLM video analysis)
|
| 51 |
+
gemini_api_key: str = ""
|
| 52 |
+
gemini_api_keys: str = "" # comma-separated keys for rotation
|
| 53 |
+
|
| 54 |
+
# Qwen LLM (HuggingFace Space - Ollama API, used by simulation agents)
|
| 55 |
+
qwen_api_url: str = "https://vish85521-doc.hf.space/api/generate"
|
| 56 |
+
qwen_model_name: str = "qwen3.5:397b-cloud"
|
| 57 |
+
|
| 58 |
+
# Security
|
| 59 |
+
jwt_secret: str = "change_this_to_a_random_32_character_string"
|
| 60 |
+
jwt_algorithm: str = "HS256"
|
| 61 |
+
jwt_expiry_hours: int = 24
|
| 62 |
+
|
| 63 |
+
# Simulation
|
| 64 |
+
default_num_agents: int = 10
|
| 65 |
+
default_simulation_days: int = 5
|
| 66 |
+
|
| 67 |
+
# File Storage
|
| 68 |
+
upload_dir: str = "uploads"
|
| 69 |
+
|
| 70 |
+
class Config:
|
| 71 |
+
env_file = find_env_file()
|
| 72 |
+
env_file_encoding = "utf-8"
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
@lru_cache()
|
| 76 |
+
def get_settings() -> Settings:
|
| 77 |
+
return Settings()
|
| 78 |
+
|
database.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Database connection and session management
|
| 3 |
+
"""
|
| 4 |
+
from sqlalchemy import create_engine
|
| 5 |
+
from sqlalchemy.ext.declarative import declarative_base
|
| 6 |
+
from sqlalchemy.orm import sessionmaker
|
| 7 |
+
from app.config import get_settings
|
| 8 |
+
|
| 9 |
+
settings = get_settings()
|
| 10 |
+
|
| 11 |
+
# Create engine
|
| 12 |
+
engine = create_engine(
|
| 13 |
+
settings.database_url,
|
| 14 |
+
pool_pre_ping=True,
|
| 15 |
+
pool_size=10,
|
| 16 |
+
max_overflow=20
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
# Create session factory
|
| 20 |
+
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
| 21 |
+
|
| 22 |
+
# Base class for models
|
| 23 |
+
Base = declarative_base()
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def get_db():
|
| 27 |
+
"""
|
| 28 |
+
Dependency that provides database session
|
| 29 |
+
"""
|
| 30 |
+
db = SessionLocal()
|
| 31 |
+
try:
|
| 32 |
+
yield db
|
| 33 |
+
finally:
|
| 34 |
+
db.close()
|
dependencies.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
FastAPI dependencies for authentication and database access
|
| 3 |
+
"""
|
| 4 |
+
from fastapi import Depends, HTTPException, status
|
| 5 |
+
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
| 6 |
+
from sqlalchemy.orm import Session
|
| 7 |
+
from app.database import get_db
|
| 8 |
+
from app.models import User
|
| 9 |
+
from app.services.auth_service import decode_access_token
|
| 10 |
+
|
| 11 |
+
# Security scheme
|
| 12 |
+
security = HTTPBearer()
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
async def get_current_user(
|
| 16 |
+
credentials: HTTPAuthorizationCredentials = Depends(security),
|
| 17 |
+
db: Session = Depends(get_db)
|
| 18 |
+
) -> User:
|
| 19 |
+
"""
|
| 20 |
+
Dependency to get the current authenticated user from JWT token
|
| 21 |
+
|
| 22 |
+
Raises:
|
| 23 |
+
HTTPException 401: If token is invalid or user not found
|
| 24 |
+
"""
|
| 25 |
+
token = credentials.credentials
|
| 26 |
+
|
| 27 |
+
# Decode token
|
| 28 |
+
payload = decode_access_token(token)
|
| 29 |
+
if payload is None:
|
| 30 |
+
raise HTTPException(
|
| 31 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
| 32 |
+
detail="Invalid or expired token",
|
| 33 |
+
headers={"WWW-Authenticate": "Bearer"}
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
# Get user ID from token
|
| 37 |
+
user_id = payload.get("sub")
|
| 38 |
+
if user_id is None:
|
| 39 |
+
raise HTTPException(
|
| 40 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
| 41 |
+
detail="Invalid token payload",
|
| 42 |
+
headers={"WWW-Authenticate": "Bearer"}
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
# Fetch user from database
|
| 46 |
+
user = db.query(User).filter(User.id == user_id).first()
|
| 47 |
+
if user is None:
|
| 48 |
+
raise HTTPException(
|
| 49 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
| 50 |
+
detail="User not found",
|
| 51 |
+
headers={"WWW-Authenticate": "Bearer"}
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
return user
|
main.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
AgentSociety Marketing Platform - FastAPI Backend
|
| 3 |
+
"""
|
| 4 |
+
import os
|
| 5 |
+
import logging
|
| 6 |
+
from fastapi import FastAPI
|
| 7 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 8 |
+
from contextlib import asynccontextmanager
|
| 9 |
+
from sqlalchemy import text
|
| 10 |
+
|
| 11 |
+
# Configure logging
|
| 12 |
+
logging.basicConfig(
|
| 13 |
+
level=logging.INFO,
|
| 14 |
+
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s"
|
| 15 |
+
)
|
| 16 |
+
logger = logging.getLogger(__name__)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@asynccontextmanager
|
| 20 |
+
async def lifespan(app: FastAPI):
|
| 21 |
+
"""Application lifespan events"""
|
| 22 |
+
from app.config import get_settings
|
| 23 |
+
settings = get_settings()
|
| 24 |
+
|
| 25 |
+
# Startup
|
| 26 |
+
logger.info("Starting AgentSociety API...")
|
| 27 |
+
|
| 28 |
+
# Create upload directory
|
| 29 |
+
os.makedirs(settings.upload_dir, exist_ok=True)
|
| 30 |
+
|
| 31 |
+
# Start results listener (receives simulation results from Ray worker)
|
| 32 |
+
from app.results_listener import start_results_listener
|
| 33 |
+
start_results_listener(redis_url=settings.redis_url)
|
| 34 |
+
logger.info("Results listener started - listening for Ray worker results")
|
| 35 |
+
|
| 36 |
+
yield
|
| 37 |
+
|
| 38 |
+
# Shutdown
|
| 39 |
+
from app.results_listener import stop_results_listener
|
| 40 |
+
stop_results_listener()
|
| 41 |
+
logger.info("Shutting down AgentSociety API...")
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
# Create FastAPI app - disable redirect_slashes to prevent 307 that strips auth headers
|
| 45 |
+
app = FastAPI(
|
| 46 |
+
title="AgentSociety Marketing Platform",
|
| 47 |
+
description="AI-powered marketing simulation platform that simulates 1,000+ AI agents reacting to video advertisements",
|
| 48 |
+
version="1.0.0",
|
| 49 |
+
lifespan=lifespan,
|
| 50 |
+
redirect_slashes=False
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
# CORS configuration - explicit origins required when credentials=True
|
| 54 |
+
origins = [
|
| 55 |
+
"http://localhost:3000",
|
| 56 |
+
"http://127.0.0.1:3000",
|
| 57 |
+
"http://localhost:8000",
|
| 58 |
+
"http://127.0.0.1:8000",
|
| 59 |
+
]
|
| 60 |
+
|
| 61 |
+
app.add_middleware(
|
| 62 |
+
CORSMiddleware,
|
| 63 |
+
allow_origins=origins,
|
| 64 |
+
allow_credentials=True,
|
| 65 |
+
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "PATCH"],
|
| 66 |
+
allow_headers=["*"],
|
| 67 |
+
expose_headers=["*"],
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
# Register routers
|
| 71 |
+
from app.routers import auth_router, projects_router, simulations_router
|
| 72 |
+
app.include_router(auth_router)
|
| 73 |
+
app.include_router(projects_router)
|
| 74 |
+
app.include_router(simulations_router)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
@app.get("/")
|
| 78 |
+
async def root():
|
| 79 |
+
"""Health check endpoint"""
|
| 80 |
+
return {
|
| 81 |
+
"status": "healthy",
|
| 82 |
+
"service": "AgentSociety API",
|
| 83 |
+
"version": "1.0.0"
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
@app.get("/health")
|
| 88 |
+
async def health_check():
|
| 89 |
+
"""Detailed health check"""
|
| 90 |
+
from app.config import get_settings
|
| 91 |
+
settings = get_settings()
|
| 92 |
+
|
| 93 |
+
health = {
|
| 94 |
+
"api": "healthy",
|
| 95 |
+
"database": "unknown",
|
| 96 |
+
"redis": "unknown"
|
| 97 |
+
}
|
| 98 |
+
|
| 99 |
+
# Check database
|
| 100 |
+
try:
|
| 101 |
+
from app.database import engine
|
| 102 |
+
with engine.connect() as conn:
|
| 103 |
+
conn.execute(text("SELECT 1"))
|
| 104 |
+
health["database"] = "healthy"
|
| 105 |
+
except Exception as e:
|
| 106 |
+
health["database"] = f"unhealthy: {str(e)}"
|
| 107 |
+
|
| 108 |
+
# Check Redis
|
| 109 |
+
try:
|
| 110 |
+
import redis
|
| 111 |
+
import ssl as ssl_module
|
| 112 |
+
redis_kwargs = {}
|
| 113 |
+
if settings.redis_url.startswith("rediss://"):
|
| 114 |
+
redis_kwargs["ssl_cert_reqs"] = ssl_module.CERT_REQUIRED
|
| 115 |
+
r = redis.from_url(settings.redis_url, **redis_kwargs)
|
| 116 |
+
r.ping()
|
| 117 |
+
health["redis"] = "healthy"
|
| 118 |
+
except Exception as e:
|
| 119 |
+
health["redis"] = f"unhealthy: {str(e)}"
|
| 120 |
+
|
| 121 |
+
return health
|
models/__init__.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
SQLAlchemy models for the AgentSociety platform
|
| 3 |
+
"""
|
| 4 |
+
import uuid
|
| 5 |
+
from datetime import datetime
|
| 6 |
+
from sqlalchemy import Column, String, Integer, Float, Text, DateTime, ForeignKey, JSON, BigInteger
|
| 7 |
+
from sqlalchemy.dialects.postgresql import UUID
|
| 8 |
+
from sqlalchemy.orm import relationship
|
| 9 |
+
from app.database import Base
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class User(Base):
|
| 13 |
+
__tablename__ = "users"
|
| 14 |
+
|
| 15 |
+
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
| 16 |
+
email = Column(String(255), unique=True, nullable=False, index=True)
|
| 17 |
+
password_hash = Column(String(255), nullable=False)
|
| 18 |
+
created_at = Column(DateTime, default=datetime.utcnow)
|
| 19 |
+
subscription_tier = Column(String(20), default="FREE")
|
| 20 |
+
|
| 21 |
+
# Relationships
|
| 22 |
+
projects = relationship("Project", back_populates="user", cascade="all, delete-orphan")
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class Project(Base):
|
| 26 |
+
__tablename__ = "projects"
|
| 27 |
+
|
| 28 |
+
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
| 29 |
+
user_id = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False)
|
| 30 |
+
title = Column(String(255), nullable=False)
|
| 31 |
+
video_path = Column(String(500), nullable=False)
|
| 32 |
+
video_duration_seconds = Column(Integer, nullable=True)
|
| 33 |
+
vlm_generated_context = Column(Text, nullable=True)
|
| 34 |
+
demographic_filter = Column(JSON, nullable=True)
|
| 35 |
+
status = Column(String(20), default="PENDING")
|
| 36 |
+
created_at = Column(DateTime, default=datetime.utcnow)
|
| 37 |
+
|
| 38 |
+
# Relationships
|
| 39 |
+
user = relationship("User", back_populates="projects")
|
| 40 |
+
simulation_runs = relationship("SimulationRun", back_populates="project", cascade="all, delete-orphan")
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class SimulationRun(Base):
|
| 44 |
+
__tablename__ = "simulation_runs"
|
| 45 |
+
|
| 46 |
+
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
| 47 |
+
project_id = Column(UUID(as_uuid=True), ForeignKey("projects.id", ondelete="CASCADE"), nullable=False)
|
| 48 |
+
status = Column(String(20), default="PENDING")
|
| 49 |
+
num_agents = Column(Integer, default=1000)
|
| 50 |
+
simulation_days = Column(Integer, default=5)
|
| 51 |
+
virality_score = Column(Float, nullable=True)
|
| 52 |
+
sentiment_breakdown = Column(JSON, nullable=True)
|
| 53 |
+
map_data = Column(JSON, nullable=True) # lightweight per-agent coords/opinion/friends
|
| 54 |
+
agent_states = Column(JSON, nullable=True) # full agent profile/emotion/reasoning
|
| 55 |
+
started_at = Column(DateTime, nullable=True)
|
| 56 |
+
completed_at = Column(DateTime, nullable=True)
|
| 57 |
+
error_message = Column(Text, nullable=True)
|
| 58 |
+
created_at = Column(DateTime, default=datetime.utcnow)
|
| 59 |
+
|
| 60 |
+
# Relationships
|
| 61 |
+
project = relationship("Project", back_populates="simulation_runs")
|
| 62 |
+
agent_logs = relationship("AgentLog", back_populates="simulation_run", cascade="all, delete-orphan")
|
| 63 |
+
risk_flags = relationship("RiskFlag", back_populates="simulation_run", cascade="all, delete-orphan")
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class AgentLog(Base):
|
| 67 |
+
__tablename__ = "agent_logs"
|
| 68 |
+
|
| 69 |
+
id = Column(BigInteger, primary_key=True, autoincrement=True)
|
| 70 |
+
simulation_run_id = Column(UUID(as_uuid=True), ForeignKey("simulation_runs.id", ondelete="CASCADE"), nullable=False)
|
| 71 |
+
agent_id = Column(String(50), nullable=False)
|
| 72 |
+
timestamp = Column(DateTime, default=datetime.utcnow)
|
| 73 |
+
event_type = Column(String(30), nullable=False)
|
| 74 |
+
event_data = Column(JSON, nullable=False)
|
| 75 |
+
|
| 76 |
+
# Relationships
|
| 77 |
+
simulation_run = relationship("SimulationRun", back_populates="agent_logs")
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
class RiskFlag(Base):
|
| 81 |
+
__tablename__ = "risk_flags"
|
| 82 |
+
|
| 83 |
+
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
| 84 |
+
simulation_run_id = Column(UUID(as_uuid=True), ForeignKey("simulation_runs.id", ondelete="CASCADE"), nullable=False)
|
| 85 |
+
flag_type = Column(String(50), nullable=False)
|
| 86 |
+
severity = Column(String(10), nullable=False)
|
| 87 |
+
description = Column(Text, nullable=False)
|
| 88 |
+
affected_demographics = Column(JSON, nullable=True)
|
| 89 |
+
sample_agent_reactions = Column(JSON, nullable=True)
|
| 90 |
+
detected_at = Column(DateTime, default=datetime.utcnow)
|
| 91 |
+
|
| 92 |
+
# Relationships
|
| 93 |
+
simulation_run = relationship("SimulationRun", back_populates="risk_flags")
|
results_listener.py
ADDED
|
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Results Listener - Background thread that listens for simulation results from Ray worker
|
| 3 |
+
|
| 4 |
+
This runs as a background thread in the FastAPI application, listening to Redis
|
| 5 |
+
for simulation results and updating the database accordingly.
|
| 6 |
+
"""
|
| 7 |
+
import os
|
| 8 |
+
import json
|
| 9 |
+
import logging
|
| 10 |
+
import threading
|
| 11 |
+
from datetime import datetime
|
| 12 |
+
from typing import Optional
|
| 13 |
+
|
| 14 |
+
import redis
|
| 15 |
+
from sqlalchemy.orm import Session
|
| 16 |
+
|
| 17 |
+
from app.database import SessionLocal
|
| 18 |
+
from app.models import SimulationRun, AgentLog, RiskFlag
|
| 19 |
+
|
| 20 |
+
logger = logging.getLogger(__name__)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class ResultsListener:
|
| 24 |
+
"""Background listener for simulation results from Ray worker"""
|
| 25 |
+
|
| 26 |
+
def __init__(self, redis_url: str = None):
|
| 27 |
+
self.redis_url = redis_url or os.getenv("REDIS_URL", "redis://localhost:6379/0")
|
| 28 |
+
self.running = False
|
| 29 |
+
self.thread: Optional[threading.Thread] = None
|
| 30 |
+
|
| 31 |
+
def start(self):
|
| 32 |
+
"""Start the listener in a background thread"""
|
| 33 |
+
if self.running:
|
| 34 |
+
logger.warning("ResultsListener already running")
|
| 35 |
+
return
|
| 36 |
+
|
| 37 |
+
self.running = True
|
| 38 |
+
self.thread = threading.Thread(target=self._listen_loop, daemon=True)
|
| 39 |
+
self.thread.start()
|
| 40 |
+
logger.info("ResultsListener started")
|
| 41 |
+
|
| 42 |
+
def stop(self):
|
| 43 |
+
"""Stop the listener"""
|
| 44 |
+
self.running = False
|
| 45 |
+
if self.thread:
|
| 46 |
+
self.thread.join(timeout=5)
|
| 47 |
+
logger.info("ResultsListener stopped")
|
| 48 |
+
|
| 49 |
+
def _listen_loop(self):
|
| 50 |
+
"""Main listening loop"""
|
| 51 |
+
try:
|
| 52 |
+
import ssl as ssl_module
|
| 53 |
+
redis_kwargs = {}
|
| 54 |
+
if self.redis_url.startswith("rediss://"):
|
| 55 |
+
redis_kwargs["ssl_cert_reqs"] = ssl_module.CERT_REQUIRED
|
| 56 |
+
redis_client = redis.from_url(self.redis_url, **redis_kwargs)
|
| 57 |
+
pubsub = redis_client.pubsub()
|
| 58 |
+
pubsub.subscribe("simulation_results")
|
| 59 |
+
|
| 60 |
+
logger.info("Subscribed to 'simulation_results' channel")
|
| 61 |
+
|
| 62 |
+
for message in pubsub.listen():
|
| 63 |
+
if not self.running:
|
| 64 |
+
break
|
| 65 |
+
|
| 66 |
+
if message['type'] == 'message':
|
| 67 |
+
try:
|
| 68 |
+
data = json.loads(message['data'])
|
| 69 |
+
self._handle_result(data)
|
| 70 |
+
except json.JSONDecodeError as e:
|
| 71 |
+
logger.error(f"Invalid JSON in result: {e}")
|
| 72 |
+
except Exception as e:
|
| 73 |
+
logger.error(f"Error handling result: {e}")
|
| 74 |
+
|
| 75 |
+
pubsub.unsubscribe()
|
| 76 |
+
|
| 77 |
+
except Exception as e:
|
| 78 |
+
logger.error(f"ResultsListener error: {e}")
|
| 79 |
+
self.running = False
|
| 80 |
+
|
| 81 |
+
def _handle_result(self, data: dict):
|
| 82 |
+
"""Process a simulation result from Ray worker"""
|
| 83 |
+
simulation_id = data.get('simulation_id')
|
| 84 |
+
status = data.get('status', 'FAILED')
|
| 85 |
+
|
| 86 |
+
if not simulation_id:
|
| 87 |
+
logger.error("Result missing simulation_id")
|
| 88 |
+
return
|
| 89 |
+
|
| 90 |
+
logger.info(f"Received result for simulation {simulation_id}: {status}")
|
| 91 |
+
|
| 92 |
+
db = SessionLocal()
|
| 93 |
+
try:
|
| 94 |
+
simulation = db.query(SimulationRun).filter(
|
| 95 |
+
SimulationRun.id == simulation_id
|
| 96 |
+
).first()
|
| 97 |
+
|
| 98 |
+
if not simulation:
|
| 99 |
+
logger.error(f"Simulation {simulation_id} not found in database")
|
| 100 |
+
return
|
| 101 |
+
|
| 102 |
+
if status == 'COMPLETED':
|
| 103 |
+
results = data.get('results', {})
|
| 104 |
+
|
| 105 |
+
# Update simulation with results
|
| 106 |
+
simulation.status = "COMPLETED"
|
| 107 |
+
simulation.completed_at = datetime.utcnow()
|
| 108 |
+
simulation.virality_score = results.get('virality_score', 0)
|
| 109 |
+
simulation.sentiment_breakdown = results.get('sentiment_breakdown', {})
|
| 110 |
+
|
| 111 |
+
# Save map data and agent states for map visualization
|
| 112 |
+
simulation.map_data = results.get('map_data', [])
|
| 113 |
+
simulation.agent_states = results.get('agent_states', [])
|
| 114 |
+
|
| 115 |
+
db.commit()
|
| 116 |
+
logger.info(f"Simulation {simulation_id} marked as COMPLETED")
|
| 117 |
+
|
| 118 |
+
# Save agent logs (separate transaction to avoid rollback issues)
|
| 119 |
+
try:
|
| 120 |
+
self._save_agent_logs(db, simulation.id, results.get('agent_logs', []))
|
| 121 |
+
except Exception as e:
|
| 122 |
+
logger.warning(f"Failed to save agent logs: {e}")
|
| 123 |
+
|
| 124 |
+
# Save risk flags
|
| 125 |
+
try:
|
| 126 |
+
self._save_risk_flags(db, simulation.id, results.get('risk_flags', []))
|
| 127 |
+
except Exception as e:
|
| 128 |
+
logger.warning(f"Failed to save risk flags: {e}")
|
| 129 |
+
|
| 130 |
+
else:
|
| 131 |
+
# Failed simulation
|
| 132 |
+
simulation.status = "FAILED"
|
| 133 |
+
simulation.completed_at = datetime.utcnow()
|
| 134 |
+
simulation.error_message = data.get('error', 'Unknown error')
|
| 135 |
+
db.commit()
|
| 136 |
+
logger.info(f"Simulation {simulation_id} marked as FAILED: {data.get('error')}")
|
| 137 |
+
|
| 138 |
+
except Exception as e:
|
| 139 |
+
logger.error(f"Database error handling result: {e}")
|
| 140 |
+
db.rollback()
|
| 141 |
+
finally:
|
| 142 |
+
db.close()
|
| 143 |
+
|
| 144 |
+
def _save_agent_logs(self, db: Session, simulation_id: str, logs: list):
|
| 145 |
+
"""Save agent logs to database"""
|
| 146 |
+
# Limit to prevent overwhelming the database
|
| 147 |
+
for log_data in logs[:50]:
|
| 148 |
+
try:
|
| 149 |
+
event_data = log_data.get('event_data', {})
|
| 150 |
+
# Ensure JSON serializable
|
| 151 |
+
if isinstance(event_data, dict):
|
| 152 |
+
event_data = json.loads(json.dumps(event_data, default=str))
|
| 153 |
+
|
| 154 |
+
agent_log = AgentLog(
|
| 155 |
+
simulation_run_id=simulation_id,
|
| 156 |
+
agent_id=str(log_data.get('agent_id', 'unknown')),
|
| 157 |
+
event_type=str(log_data.get('event_type', 'UNKNOWN')),
|
| 158 |
+
event_data=event_data
|
| 159 |
+
)
|
| 160 |
+
db.add(agent_log)
|
| 161 |
+
except Exception as e:
|
| 162 |
+
logger.warning(f"Failed to add agent log: {e}")
|
| 163 |
+
|
| 164 |
+
db.commit()
|
| 165 |
+
logger.info(f"Saved {min(50, len(logs))} agent logs for simulation {simulation_id}")
|
| 166 |
+
|
| 167 |
+
def _save_risk_flags(self, db: Session, simulation_id: str, flags: list):
|
| 168 |
+
"""Save risk flags to database"""
|
| 169 |
+
for flag_data in flags:
|
| 170 |
+
try:
|
| 171 |
+
risk_flag = RiskFlag(
|
| 172 |
+
simulation_run_id=simulation_id,
|
| 173 |
+
flag_type=str(flag_data.get('flag_type', 'UNKNOWN')),
|
| 174 |
+
severity=str(flag_data.get('severity', 'LOW')),
|
| 175 |
+
description=str(flag_data.get('description', '')),
|
| 176 |
+
affected_demographics=flag_data.get('affected_demographics'),
|
| 177 |
+
sample_agent_reactions=flag_data.get('sample_agent_reactions')
|
| 178 |
+
)
|
| 179 |
+
db.add(risk_flag)
|
| 180 |
+
except Exception as e:
|
| 181 |
+
logger.warning(f"Failed to add risk flag: {e}")
|
| 182 |
+
|
| 183 |
+
db.commit()
|
| 184 |
+
logger.info(f"Saved {len(flags)} risk flags for simulation {simulation_id}")
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
# Global instance
|
| 188 |
+
_results_listener: Optional[ResultsListener] = None
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def get_results_listener() -> ResultsListener:
|
| 192 |
+
"""Get or create the global results listener"""
|
| 193 |
+
global _results_listener
|
| 194 |
+
if _results_listener is None:
|
| 195 |
+
_results_listener = ResultsListener()
|
| 196 |
+
return _results_listener
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def start_results_listener(redis_url: str = None):
|
| 200 |
+
"""Start the global results listener"""
|
| 201 |
+
global _results_listener
|
| 202 |
+
if _results_listener is None:
|
| 203 |
+
_results_listener = ResultsListener(redis_url=redis_url)
|
| 204 |
+
_results_listener.start()
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def stop_results_listener():
|
| 208 |
+
"""Stop the global results listener"""
|
| 209 |
+
if _results_listener:
|
| 210 |
+
_results_listener.stop()
|
routers/__init__.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Router module exports
|
| 3 |
+
"""
|
| 4 |
+
from app.routers.auth import router as auth_router
|
| 5 |
+
from app.routers.projects import router as projects_router
|
| 6 |
+
from app.routers.simulations import router as simulations_router
|
| 7 |
+
|
| 8 |
+
__all__ = ["auth_router", "projects_router", "simulations_router"]
|
routers/auth.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Authentication routes - register, login, and user info
|
| 3 |
+
"""
|
| 4 |
+
from fastapi import APIRouter, Depends, HTTPException, status
|
| 5 |
+
from sqlalchemy.orm import Session
|
| 6 |
+
from app.database import get_db
|
| 7 |
+
from app.models import User
|
| 8 |
+
from app.schemas import UserCreate, UserLogin, UserResponse, TokenResponse
|
| 9 |
+
from app.services.auth_service import hash_password, verify_password, create_access_token
|
| 10 |
+
from app.dependencies import get_current_user
|
| 11 |
+
|
| 12 |
+
router = APIRouter(prefix="/auth", tags=["Authentication"])
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
|
| 16 |
+
async def register(user_data: UserCreate, db: Session = Depends(get_db)):
|
| 17 |
+
"""
|
| 18 |
+
Register a new user account
|
| 19 |
+
"""
|
| 20 |
+
# Check if email already exists
|
| 21 |
+
existing_user = db.query(User).filter(User.email == user_data.email).first()
|
| 22 |
+
if existing_user:
|
| 23 |
+
raise HTTPException(
|
| 24 |
+
status_code=status.HTTP_400_BAD_REQUEST,
|
| 25 |
+
detail="Email already registered"
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
# Create new user
|
| 29 |
+
hashed_pw = hash_password(user_data.password)
|
| 30 |
+
new_user = User(
|
| 31 |
+
email=user_data.email,
|
| 32 |
+
password_hash=hashed_pw
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
db.add(new_user)
|
| 36 |
+
db.commit()
|
| 37 |
+
db.refresh(new_user)
|
| 38 |
+
|
| 39 |
+
return new_user
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
@router.post("/login", response_model=TokenResponse)
|
| 43 |
+
async def login(credentials: UserLogin, db: Session = Depends(get_db)):
|
| 44 |
+
"""
|
| 45 |
+
Login and receive JWT access token
|
| 46 |
+
"""
|
| 47 |
+
# Find user by email
|
| 48 |
+
user = db.query(User).filter(User.email == credentials.email).first()
|
| 49 |
+
|
| 50 |
+
if not user or not verify_password(credentials.password, user.password_hash):
|
| 51 |
+
raise HTTPException(
|
| 52 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
| 53 |
+
detail="Invalid email or password",
|
| 54 |
+
headers={"WWW-Authenticate": "Bearer"}
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
# Create access token
|
| 58 |
+
access_token = create_access_token(data={"sub": str(user.id)})
|
| 59 |
+
|
| 60 |
+
return {"access_token": access_token, "token_type": "bearer"}
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
@router.get("/me", response_model=UserResponse)
|
| 64 |
+
async def get_me(current_user: User = Depends(get_current_user)):
|
| 65 |
+
"""
|
| 66 |
+
Get current user information
|
| 67 |
+
"""
|
| 68 |
+
return current_user
|
routers/projects.py
ADDED
|
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Project management routes - create, list, and get projects
|
| 3 |
+
"""
|
| 4 |
+
import os
|
| 5 |
+
import uuid
|
| 6 |
+
import aiofiles
|
| 7 |
+
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Form, status
|
| 8 |
+
from sqlalchemy.orm import Session
|
| 9 |
+
from typing import List, Optional
|
| 10 |
+
from app.database import get_db
|
| 11 |
+
from app.models import User, Project
|
| 12 |
+
from app.schemas import ProjectCreate, ProjectResponse, ProjectListResponse
|
| 13 |
+
from app.dependencies import get_current_user
|
| 14 |
+
from app.config import get_settings
|
| 15 |
+
|
| 16 |
+
settings = get_settings()
|
| 17 |
+
router = APIRouter(prefix="/projects", tags=["Projects"])
|
| 18 |
+
|
| 19 |
+
# Allowed video extensions
|
| 20 |
+
ALLOWED_EXTENSIONS = {".mp4", ".mov", ".avi", ".webm", ".mkv"}
|
| 21 |
+
MAX_FILE_SIZE = 500 * 1024 * 1024 # 500MB
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@router.post("", response_model=ProjectResponse, status_code=status.HTTP_201_CREATED)
|
| 25 |
+
async def create_project(
|
| 26 |
+
title: str = Form(...),
|
| 27 |
+
demographic_filter: Optional[str] = Form(None),
|
| 28 |
+
video: UploadFile = File(...),
|
| 29 |
+
current_user: User = Depends(get_current_user),
|
| 30 |
+
db: Session = Depends(get_db)
|
| 31 |
+
):
|
| 32 |
+
"""
|
| 33 |
+
Create a new project with video upload
|
| 34 |
+
"""
|
| 35 |
+
# Validate file extension
|
| 36 |
+
file_ext = os.path.splitext(video.filename)[1].lower()
|
| 37 |
+
if file_ext not in ALLOWED_EXTENSIONS:
|
| 38 |
+
raise HTTPException(
|
| 39 |
+
status_code=status.HTTP_400_BAD_REQUEST,
|
| 40 |
+
detail=f"Invalid file type. Allowed types: {', '.join(ALLOWED_EXTENSIONS)}"
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
# Create upload directory if it doesn't exist
|
| 44 |
+
upload_dir = os.path.join(settings.upload_dir, str(current_user.id))
|
| 45 |
+
os.makedirs(upload_dir, exist_ok=True)
|
| 46 |
+
|
| 47 |
+
# Generate unique filename
|
| 48 |
+
file_id = str(uuid.uuid4())
|
| 49 |
+
filename = f"{file_id}{file_ext}"
|
| 50 |
+
file_path = os.path.join(upload_dir, filename)
|
| 51 |
+
|
| 52 |
+
# Save file
|
| 53 |
+
try:
|
| 54 |
+
async with aiofiles.open(file_path, 'wb') as f:
|
| 55 |
+
# Read in chunks to handle large files
|
| 56 |
+
total_size = 0
|
| 57 |
+
while chunk := await video.read(1024 * 1024): # 1MB chunks
|
| 58 |
+
total_size += len(chunk)
|
| 59 |
+
if total_size > MAX_FILE_SIZE:
|
| 60 |
+
os.remove(file_path)
|
| 61 |
+
raise HTTPException(
|
| 62 |
+
status_code=status.HTTP_400_BAD_REQUEST,
|
| 63 |
+
detail=f"File too large. Maximum size: {MAX_FILE_SIZE // (1024*1024)}MB"
|
| 64 |
+
)
|
| 65 |
+
await f.write(chunk)
|
| 66 |
+
except Exception as e:
|
| 67 |
+
if os.path.exists(file_path):
|
| 68 |
+
os.remove(file_path)
|
| 69 |
+
raise HTTPException(
|
| 70 |
+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
| 71 |
+
detail=f"Failed to save file: {str(e)}"
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
# Parse demographic filter
|
| 75 |
+
demo_filter = None
|
| 76 |
+
if demographic_filter:
|
| 77 |
+
import json
|
| 78 |
+
try:
|
| 79 |
+
demo_filter = json.loads(demographic_filter)
|
| 80 |
+
except json.JSONDecodeError:
|
| 81 |
+
pass
|
| 82 |
+
|
| 83 |
+
# Create project record
|
| 84 |
+
project = Project(
|
| 85 |
+
user_id=current_user.id,
|
| 86 |
+
title=title,
|
| 87 |
+
video_path=file_path,
|
| 88 |
+
demographic_filter=demo_filter,
|
| 89 |
+
status="PENDING"
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
db.add(project)
|
| 93 |
+
db.commit()
|
| 94 |
+
db.refresh(project)
|
| 95 |
+
|
| 96 |
+
# Queue VLM processing task (async)
|
| 97 |
+
from app.tasks import process_video_task
|
| 98 |
+
process_video_task.delay(str(project.id))
|
| 99 |
+
|
| 100 |
+
return project
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
@router.get("", response_model=List[ProjectListResponse])
|
| 104 |
+
async def list_projects(
|
| 105 |
+
current_user: User = Depends(get_current_user),
|
| 106 |
+
db: Session = Depends(get_db)
|
| 107 |
+
):
|
| 108 |
+
"""
|
| 109 |
+
List all projects for the current user
|
| 110 |
+
"""
|
| 111 |
+
projects = db.query(Project).filter(
|
| 112 |
+
Project.user_id == current_user.id
|
| 113 |
+
).order_by(Project.created_at.desc()).all()
|
| 114 |
+
|
| 115 |
+
return projects
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
@router.get("/{project_id}", response_model=ProjectResponse)
|
| 119 |
+
async def get_project(
|
| 120 |
+
project_id: str,
|
| 121 |
+
current_user: User = Depends(get_current_user),
|
| 122 |
+
db: Session = Depends(get_db)
|
| 123 |
+
):
|
| 124 |
+
"""
|
| 125 |
+
Get a specific project by ID
|
| 126 |
+
"""
|
| 127 |
+
project = db.query(Project).filter(
|
| 128 |
+
Project.id == project_id,
|
| 129 |
+
Project.user_id == current_user.id
|
| 130 |
+
).first()
|
| 131 |
+
|
| 132 |
+
if not project:
|
| 133 |
+
raise HTTPException(
|
| 134 |
+
status_code=status.HTTP_404_NOT_FOUND,
|
| 135 |
+
detail="Project not found"
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
return project
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
@router.delete("/{project_id}", status_code=status.HTTP_204_NO_CONTENT)
|
| 142 |
+
async def delete_project(
|
| 143 |
+
project_id: str,
|
| 144 |
+
current_user: User = Depends(get_current_user),
|
| 145 |
+
db: Session = Depends(get_db)
|
| 146 |
+
):
|
| 147 |
+
"""
|
| 148 |
+
Delete a project and its video file
|
| 149 |
+
"""
|
| 150 |
+
project = db.query(Project).filter(
|
| 151 |
+
Project.id == project_id,
|
| 152 |
+
Project.user_id == current_user.id
|
| 153 |
+
).first()
|
| 154 |
+
|
| 155 |
+
if not project:
|
| 156 |
+
raise HTTPException(
|
| 157 |
+
status_code=status.HTTP_404_NOT_FOUND,
|
| 158 |
+
detail="Project not found"
|
| 159 |
+
)
|
| 160 |
+
|
| 161 |
+
# Delete video file
|
| 162 |
+
if os.path.exists(project.video_path):
|
| 163 |
+
os.remove(project.video_path)
|
| 164 |
+
|
| 165 |
+
# Delete project record
|
| 166 |
+
db.delete(project)
|
| 167 |
+
db.commit()
|
routers/simulations.py
ADDED
|
@@ -0,0 +1,328 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Simulation management routes - start, status, and results
|
| 3 |
+
"""
|
| 4 |
+
from fastapi import APIRouter, Depends, HTTPException, status
|
| 5 |
+
from sqlalchemy.orm import Session
|
| 6 |
+
from typing import List
|
| 7 |
+
from datetime import datetime
|
| 8 |
+
from app.database import get_db
|
| 9 |
+
from app.models import User, Project, SimulationRun, RiskFlag, AgentLog
|
| 10 |
+
from app.schemas import (
|
| 11 |
+
SimulationCreate,
|
| 12 |
+
SimulationResponse,
|
| 13 |
+
SimulationStatusResponse,
|
| 14 |
+
SimulationResultsResponse,
|
| 15 |
+
RiskFlagResponse,
|
| 16 |
+
MapDataResponse,
|
| 17 |
+
AgentDetailResponse,
|
| 18 |
+
AgentProfileData,
|
| 19 |
+
)
|
| 20 |
+
from app.dependencies import get_current_user
|
| 21 |
+
from app.config import get_settings
|
| 22 |
+
|
| 23 |
+
settings = get_settings()
|
| 24 |
+
router = APIRouter(prefix="/simulations", tags=["Simulations"])
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@router.post("/{project_id}/start", response_model=SimulationResponse)
|
| 28 |
+
async def start_simulation(
|
| 29 |
+
project_id: str,
|
| 30 |
+
config: SimulationCreate = None,
|
| 31 |
+
current_user: User = Depends(get_current_user),
|
| 32 |
+
db: Session = Depends(get_db)
|
| 33 |
+
):
|
| 34 |
+
"""
|
| 35 |
+
Start a new simulation for a project
|
| 36 |
+
"""
|
| 37 |
+
# Verify project exists and belongs to user
|
| 38 |
+
project = db.query(Project).filter(
|
| 39 |
+
Project.id == project_id,
|
| 40 |
+
Project.user_id == current_user.id
|
| 41 |
+
).first()
|
| 42 |
+
|
| 43 |
+
if not project:
|
| 44 |
+
raise HTTPException(
|
| 45 |
+
status_code=status.HTTP_404_NOT_FOUND,
|
| 46 |
+
detail="Project not found"
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
# Check if project is ready (VLM processing complete)
|
| 50 |
+
if project.status != "READY":
|
| 51 |
+
raise HTTPException(
|
| 52 |
+
status_code=status.HTTP_400_BAD_REQUEST,
|
| 53 |
+
detail=f"Project is not ready for simulation. Current status: {project.status}"
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
# Use default config if not provided
|
| 57 |
+
if config is None:
|
| 58 |
+
config = SimulationCreate()
|
| 59 |
+
|
| 60 |
+
# Create simulation run record
|
| 61 |
+
simulation = SimulationRun(
|
| 62 |
+
project_id=project.id,
|
| 63 |
+
num_agents=config.num_agents,
|
| 64 |
+
simulation_days=config.simulation_days,
|
| 65 |
+
status="PENDING"
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
db.add(simulation)
|
| 69 |
+
db.commit()
|
| 70 |
+
db.refresh(simulation)
|
| 71 |
+
|
| 72 |
+
# Start simulation task (async)
|
| 73 |
+
from app.tasks import run_simulation_task
|
| 74 |
+
run_simulation_task.delay(str(simulation.id))
|
| 75 |
+
|
| 76 |
+
return simulation
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
@router.get("/{simulation_id}", response_model=SimulationResponse)
|
| 80 |
+
async def get_simulation(
|
| 81 |
+
simulation_id: str,
|
| 82 |
+
current_user: User = Depends(get_current_user),
|
| 83 |
+
db: Session = Depends(get_db)
|
| 84 |
+
):
|
| 85 |
+
"""
|
| 86 |
+
Get simulation details
|
| 87 |
+
"""
|
| 88 |
+
simulation = db.query(SimulationRun).join(Project).filter(
|
| 89 |
+
SimulationRun.id == simulation_id,
|
| 90 |
+
Project.user_id == current_user.id
|
| 91 |
+
).first()
|
| 92 |
+
|
| 93 |
+
if not simulation:
|
| 94 |
+
raise HTTPException(
|
| 95 |
+
status_code=status.HTTP_404_NOT_FOUND,
|
| 96 |
+
detail="Simulation not found"
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
return simulation
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
@router.get("/{simulation_id}/status", response_model=SimulationStatusResponse)
|
| 103 |
+
async def get_simulation_status(
|
| 104 |
+
simulation_id: str,
|
| 105 |
+
current_user: User = Depends(get_current_user),
|
| 106 |
+
db: Session = Depends(get_db)
|
| 107 |
+
):
|
| 108 |
+
"""
|
| 109 |
+
Get current simulation status (for polling)
|
| 110 |
+
"""
|
| 111 |
+
simulation = db.query(SimulationRun).join(Project).filter(
|
| 112 |
+
SimulationRun.id == simulation_id,
|
| 113 |
+
Project.user_id == current_user.id
|
| 114 |
+
).first()
|
| 115 |
+
|
| 116 |
+
if not simulation:
|
| 117 |
+
raise HTTPException(
|
| 118 |
+
status_code=status.HTTP_404_NOT_FOUND,
|
| 119 |
+
detail="Simulation not found"
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
# Calculate progress for running simulations
|
| 123 |
+
progress = None
|
| 124 |
+
current_day = None
|
| 125 |
+
active_agents = None
|
| 126 |
+
|
| 127 |
+
if simulation.status == "RUNNING":
|
| 128 |
+
# Get progress from Redis cache if available
|
| 129 |
+
import redis
|
| 130 |
+
import ssl as ssl_module
|
| 131 |
+
try:
|
| 132 |
+
redis_kwargs = {}
|
| 133 |
+
if settings.redis_url.startswith("rediss://"):
|
| 134 |
+
redis_kwargs["ssl_cert_reqs"] = ssl_module.CERT_REQUIRED
|
| 135 |
+
r = redis.from_url(settings.redis_url, **redis_kwargs)
|
| 136 |
+
cached = r.get(f"sim:{simulation_id}:status")
|
| 137 |
+
if cached:
|
| 138 |
+
import json
|
| 139 |
+
status_data = json.loads(cached)
|
| 140 |
+
progress = status_data.get("progress")
|
| 141 |
+
current_day = status_data.get("current_day")
|
| 142 |
+
active_agents = status_data.get("active_agents")
|
| 143 |
+
except:
|
| 144 |
+
pass
|
| 145 |
+
elif simulation.status == "COMPLETED":
|
| 146 |
+
progress = 100
|
| 147 |
+
|
| 148 |
+
return SimulationStatusResponse(
|
| 149 |
+
id=simulation.id,
|
| 150 |
+
status=simulation.status,
|
| 151 |
+
progress=progress,
|
| 152 |
+
current_day=current_day,
|
| 153 |
+
active_agents=active_agents
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
@router.get("/{simulation_id}/results", response_model=SimulationResultsResponse)
|
| 158 |
+
async def get_simulation_results(
|
| 159 |
+
simulation_id: str,
|
| 160 |
+
current_user: User = Depends(get_current_user),
|
| 161 |
+
db: Session = Depends(get_db)
|
| 162 |
+
):
|
| 163 |
+
"""
|
| 164 |
+
Get full simulation results with risk flags
|
| 165 |
+
"""
|
| 166 |
+
simulation = db.query(SimulationRun).join(Project).filter(
|
| 167 |
+
SimulationRun.id == simulation_id,
|
| 168 |
+
Project.user_id == current_user.id
|
| 169 |
+
).first()
|
| 170 |
+
|
| 171 |
+
if not simulation:
|
| 172 |
+
raise HTTPException(
|
| 173 |
+
status_code=status.HTTP_404_NOT_FOUND,
|
| 174 |
+
detail="Simulation not found"
|
| 175 |
+
)
|
| 176 |
+
|
| 177 |
+
if simulation.status != "COMPLETED":
|
| 178 |
+
raise HTTPException(
|
| 179 |
+
status_code=status.HTTP_400_BAD_REQUEST,
|
| 180 |
+
detail="Simulation has not completed yet"
|
| 181 |
+
)
|
| 182 |
+
|
| 183 |
+
# Get risk flags
|
| 184 |
+
risk_flags = db.query(RiskFlag).filter(
|
| 185 |
+
RiskFlag.simulation_run_id == simulation.id
|
| 186 |
+
).order_by(RiskFlag.severity.desc()).all()
|
| 187 |
+
|
| 188 |
+
# Get sample agent logs
|
| 189 |
+
agent_sample = db.query(AgentLog).filter(
|
| 190 |
+
AgentLog.simulation_run_id == simulation.id,
|
| 191 |
+
AgentLog.event_type.in_(["BOYCOTT", "ENDORSEMENT"])
|
| 192 |
+
).limit(10).all()
|
| 193 |
+
|
| 194 |
+
sample_data = [
|
| 195 |
+
{
|
| 196 |
+
"agent_id": log.agent_id,
|
| 197 |
+
"event_type": log.event_type,
|
| 198 |
+
"event_data": log.event_data,
|
| 199 |
+
"timestamp": log.timestamp.isoformat()
|
| 200 |
+
}
|
| 201 |
+
for log in agent_sample
|
| 202 |
+
]
|
| 203 |
+
|
| 204 |
+
return SimulationResultsResponse(
|
| 205 |
+
simulation=simulation,
|
| 206 |
+
risk_flags=risk_flags,
|
| 207 |
+
agent_sample=sample_data
|
| 208 |
+
)
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
@router.get("/project/{project_id}", response_model=List[SimulationResponse])
|
| 212 |
+
async def list_project_simulations(
|
| 213 |
+
project_id: str,
|
| 214 |
+
current_user: User = Depends(get_current_user),
|
| 215 |
+
db: Session = Depends(get_db)
|
| 216 |
+
):
|
| 217 |
+
"""
|
| 218 |
+
List all simulations for a project
|
| 219 |
+
"""
|
| 220 |
+
project = db.query(Project).filter(
|
| 221 |
+
Project.id == project_id,
|
| 222 |
+
Project.user_id == current_user.id
|
| 223 |
+
).first()
|
| 224 |
+
|
| 225 |
+
if not project:
|
| 226 |
+
raise HTTPException(
|
| 227 |
+
status_code=status.HTTP_404_NOT_FOUND,
|
| 228 |
+
detail="Project not found"
|
| 229 |
+
)
|
| 230 |
+
|
| 231 |
+
simulations = db.query(SimulationRun).filter(
|
| 232 |
+
SimulationRun.project_id == project_id
|
| 233 |
+
).order_by(SimulationRun.created_at.desc()).all()
|
| 234 |
+
|
| 235 |
+
return simulations
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
@router.get("/{simulation_id}/map-data", response_model=MapDataResponse)
|
| 239 |
+
async def get_simulation_map_data(
|
| 240 |
+
simulation_id: str,
|
| 241 |
+
current_user: User = Depends(get_current_user),
|
| 242 |
+
db: Session = Depends(get_db)
|
| 243 |
+
):
|
| 244 |
+
"""
|
| 245 |
+
Get lightweight map data for all agents in a completed simulation.
|
| 246 |
+
Returns coordinates, opinion and friends for each agent.
|
| 247 |
+
"""
|
| 248 |
+
simulation = db.query(SimulationRun).join(Project).filter(
|
| 249 |
+
SimulationRun.id == simulation_id,
|
| 250 |
+
Project.user_id == current_user.id
|
| 251 |
+
).first()
|
| 252 |
+
|
| 253 |
+
if not simulation:
|
| 254 |
+
raise HTTPException(
|
| 255 |
+
status_code=status.HTTP_404_NOT_FOUND,
|
| 256 |
+
detail="Simulation not found"
|
| 257 |
+
)
|
| 258 |
+
|
| 259 |
+
if simulation.status != "COMPLETED":
|
| 260 |
+
raise HTTPException(
|
| 261 |
+
status_code=status.HTTP_400_BAD_REQUEST,
|
| 262 |
+
detail="Simulation has not completed yet"
|
| 263 |
+
)
|
| 264 |
+
|
| 265 |
+
map_data = simulation.map_data or []
|
| 266 |
+
return MapDataResponse(map_data=map_data)
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
@router.get("/{simulation_id}/agents/{agent_id}", response_model=AgentDetailResponse)
|
| 270 |
+
async def get_agent_detail(
|
| 271 |
+
simulation_id: str,
|
| 272 |
+
agent_id: str,
|
| 273 |
+
current_user: User = Depends(get_current_user),
|
| 274 |
+
db: Session = Depends(get_db)
|
| 275 |
+
):
|
| 276 |
+
"""
|
| 277 |
+
Get full detail for a specific agent in a simulation.
|
| 278 |
+
Returns profile, emotion, opinion, reasoning and friends.
|
| 279 |
+
"""
|
| 280 |
+
simulation = db.query(SimulationRun).join(Project).filter(
|
| 281 |
+
SimulationRun.id == simulation_id,
|
| 282 |
+
Project.user_id == current_user.id
|
| 283 |
+
).first()
|
| 284 |
+
|
| 285 |
+
if not simulation:
|
| 286 |
+
raise HTTPException(
|
| 287 |
+
status_code=status.HTTP_404_NOT_FOUND,
|
| 288 |
+
detail="Simulation not found"
|
| 289 |
+
)
|
| 290 |
+
|
| 291 |
+
if simulation.status != "COMPLETED":
|
| 292 |
+
raise HTTPException(
|
| 293 |
+
status_code=status.HTTP_400_BAD_REQUEST,
|
| 294 |
+
detail="Simulation has not completed yet"
|
| 295 |
+
)
|
| 296 |
+
|
| 297 |
+
# Find the agent in agent_states
|
| 298 |
+
agent_states = simulation.agent_states or []
|
| 299 |
+
agent_data = None
|
| 300 |
+
for state in agent_states:
|
| 301 |
+
if state.get("agent_id") == agent_id:
|
| 302 |
+
agent_data = state
|
| 303 |
+
break
|
| 304 |
+
|
| 305 |
+
if not agent_data:
|
| 306 |
+
raise HTTPException(
|
| 307 |
+
status_code=status.HTTP_404_NOT_FOUND,
|
| 308 |
+
detail=f"Agent {agent_id} not found in simulation"
|
| 309 |
+
)
|
| 310 |
+
|
| 311 |
+
profile_raw = agent_data.get("profile", {})
|
| 312 |
+
return AgentDetailResponse(
|
| 313 |
+
agent_id=agent_data["agent_id"],
|
| 314 |
+
coordinates=agent_data.get("coordinates", [0, 0]),
|
| 315 |
+
opinion=agent_data.get("opinion", "NEUTRAL"),
|
| 316 |
+
emotion=agent_data.get("emotion", "neutral"),
|
| 317 |
+
emotion_intensity=agent_data.get("emotion_intensity", 0),
|
| 318 |
+
reasoning=agent_data.get("reasoning", ""),
|
| 319 |
+
friends=agent_data.get("friends", []),
|
| 320 |
+
profile=AgentProfileData(
|
| 321 |
+
age=profile_raw.get("age"),
|
| 322 |
+
gender=profile_raw.get("gender"),
|
| 323 |
+
location=profile_raw.get("location"),
|
| 324 |
+
occupation=profile_raw.get("occupation"),
|
| 325 |
+
education=profile_raw.get("education"),
|
| 326 |
+
values=profile_raw.get("values", []),
|
| 327 |
+
),
|
| 328 |
+
)
|
schemas.py
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Pydantic schemas for request/response validation
|
| 3 |
+
"""
|
| 4 |
+
from pydantic import BaseModel, EmailStr, Field
|
| 5 |
+
from typing import Optional, Dict, Any, List
|
| 6 |
+
from datetime import datetime
|
| 7 |
+
from uuid import UUID
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
# ----- User Schemas -----
|
| 11 |
+
class UserCreate(BaseModel):
|
| 12 |
+
email: EmailStr
|
| 13 |
+
password: str = Field(min_length=8, max_length=100)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class UserLogin(BaseModel):
|
| 17 |
+
email: EmailStr
|
| 18 |
+
password: str
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class UserResponse(BaseModel):
|
| 22 |
+
id: UUID
|
| 23 |
+
email: str
|
| 24 |
+
subscription_tier: str
|
| 25 |
+
created_at: datetime
|
| 26 |
+
|
| 27 |
+
class Config:
|
| 28 |
+
from_attributes = True
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class TokenResponse(BaseModel):
|
| 32 |
+
access_token: str
|
| 33 |
+
token_type: str = "bearer"
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
# ----- Project Schemas -----
|
| 37 |
+
class DemographicFilter(BaseModel):
|
| 38 |
+
age_range: Optional[List[int]] = None
|
| 39 |
+
location: Optional[str] = None
|
| 40 |
+
gender: Optional[str] = None
|
| 41 |
+
values: Optional[List[str]] = None
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class ProjectCreate(BaseModel):
|
| 45 |
+
title: str = Field(min_length=1, max_length=200)
|
| 46 |
+
demographic_filter: Optional[DemographicFilter] = None
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class ProjectResponse(BaseModel):
|
| 50 |
+
id: UUID
|
| 51 |
+
title: str
|
| 52 |
+
video_path: str
|
| 53 |
+
video_duration_seconds: Optional[int]
|
| 54 |
+
vlm_generated_context: Optional[str]
|
| 55 |
+
demographic_filter: Optional[Dict[str, Any]]
|
| 56 |
+
status: str
|
| 57 |
+
created_at: datetime
|
| 58 |
+
|
| 59 |
+
class Config:
|
| 60 |
+
from_attributes = True
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class ProjectListResponse(BaseModel):
|
| 64 |
+
id: UUID
|
| 65 |
+
title: str
|
| 66 |
+
status: str
|
| 67 |
+
created_at: datetime
|
| 68 |
+
|
| 69 |
+
class Config:
|
| 70 |
+
from_attributes = True
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
# ----- Simulation Schemas -----
|
| 74 |
+
class SimulationCreate(BaseModel):
|
| 75 |
+
num_agents: int = Field(default=10, ge=1, le=10000)
|
| 76 |
+
simulation_days: int = Field(default=5, ge=1, le=30)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
class SentimentBreakdown(BaseModel):
|
| 80 |
+
positive: int = 0
|
| 81 |
+
neutral: int = 0
|
| 82 |
+
negative: int = 0
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
class RiskFlagResponse(BaseModel):
|
| 86 |
+
id: UUID
|
| 87 |
+
flag_type: str
|
| 88 |
+
severity: str
|
| 89 |
+
description: str
|
| 90 |
+
affected_demographics: Optional[Dict[str, Any]]
|
| 91 |
+
sample_agent_reactions: Optional[List[Dict[str, Any]]]
|
| 92 |
+
detected_at: datetime
|
| 93 |
+
|
| 94 |
+
class Config:
|
| 95 |
+
from_attributes = True
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
class SimulationResponse(BaseModel):
|
| 99 |
+
id: UUID
|
| 100 |
+
project_id: UUID
|
| 101 |
+
status: str
|
| 102 |
+
num_agents: int
|
| 103 |
+
simulation_days: int
|
| 104 |
+
virality_score: Optional[float]
|
| 105 |
+
sentiment_breakdown: Optional[Dict[str, int]]
|
| 106 |
+
started_at: Optional[datetime]
|
| 107 |
+
completed_at: Optional[datetime]
|
| 108 |
+
error_message: Optional[str]
|
| 109 |
+
created_at: datetime
|
| 110 |
+
|
| 111 |
+
class Config:
|
| 112 |
+
from_attributes = True
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
class SimulationStatusResponse(BaseModel):
|
| 116 |
+
id: UUID
|
| 117 |
+
status: str
|
| 118 |
+
progress: Optional[int] = None
|
| 119 |
+
current_day: Optional[int] = None
|
| 120 |
+
active_agents: Optional[int] = None
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
class SimulationResultsResponse(BaseModel):
|
| 124 |
+
simulation: SimulationResponse
|
| 125 |
+
risk_flags: List[RiskFlagResponse]
|
| 126 |
+
agent_sample: Optional[List[Dict[str, Any]]] = None
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
# ----- Map Visualization Schemas -----
|
| 130 |
+
class MapAgentData(BaseModel):
|
| 131 |
+
agent_id: str
|
| 132 |
+
coordinates: List[float]
|
| 133 |
+
opinion: str
|
| 134 |
+
friends: List[str]
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
class MapDataResponse(BaseModel):
|
| 138 |
+
map_data: List[MapAgentData]
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
class AgentProfileData(BaseModel):
|
| 142 |
+
age: Optional[int] = None
|
| 143 |
+
gender: Optional[str] = None
|
| 144 |
+
location: Optional[str] = None
|
| 145 |
+
occupation: Optional[str] = None
|
| 146 |
+
education: Optional[str] = None
|
| 147 |
+
values: List[str] = []
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
class AgentDetailResponse(BaseModel):
|
| 151 |
+
agent_id: str
|
| 152 |
+
coordinates: List[float]
|
| 153 |
+
opinion: str
|
| 154 |
+
emotion: str
|
| 155 |
+
emotion_intensity: float = 0
|
| 156 |
+
reasoning: str = ""
|
| 157 |
+
friends: List[str] = []
|
| 158 |
+
profile: AgentProfileData
|
| 159 |
+
|
services/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Service module exports
|
| 3 |
+
"""
|
| 4 |
+
from app.services.auth_service import hash_password, verify_password, create_access_token
|
| 5 |
+
from app.services.vlm_service import process_video, analyze_video_with_gemini
|
| 6 |
+
|
| 7 |
+
__all__ = [
|
| 8 |
+
"hash_password",
|
| 9 |
+
"verify_password",
|
| 10 |
+
"create_access_token",
|
| 11 |
+
"process_video",
|
| 12 |
+
"analyze_video_with_gemini"
|
| 13 |
+
]
|
services/auth_service.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Authentication service - password hashing and JWT token management
|
| 3 |
+
"""
|
| 4 |
+
from datetime import datetime, timedelta
|
| 5 |
+
from typing import Optional
|
| 6 |
+
import bcrypt
|
| 7 |
+
from jose import JWTError, jwt
|
| 8 |
+
from app.config import get_settings
|
| 9 |
+
|
| 10 |
+
settings = get_settings()
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def hash_password(password: str) -> str:
|
| 14 |
+
"""Hash a password using bcrypt directly (passlib-free, bcrypt>=4.0 compatible)"""
|
| 15 |
+
password_bytes = password.encode("utf-8")
|
| 16 |
+
salt = bcrypt.gensalt()
|
| 17 |
+
hashed = bcrypt.hashpw(password_bytes, salt)
|
| 18 |
+
return hashed.decode("utf-8")
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
| 22 |
+
"""Verify a password against its bcrypt hash"""
|
| 23 |
+
try:
|
| 24 |
+
return bcrypt.checkpw(
|
| 25 |
+
plain_password.encode("utf-8"),
|
| 26 |
+
hashed_password.encode("utf-8")
|
| 27 |
+
)
|
| 28 |
+
except Exception:
|
| 29 |
+
return False
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
|
| 33 |
+
"""
|
| 34 |
+
Create a JWT access token
|
| 35 |
+
|
| 36 |
+
Args:
|
| 37 |
+
data: Payload to encode (should include 'sub' with user ID)
|
| 38 |
+
expires_delta: Token expiry time
|
| 39 |
+
|
| 40 |
+
Returns:
|
| 41 |
+
Encoded JWT token
|
| 42 |
+
"""
|
| 43 |
+
to_encode = data.copy()
|
| 44 |
+
|
| 45 |
+
if expires_delta:
|
| 46 |
+
expire = datetime.utcnow() + expires_delta
|
| 47 |
+
else:
|
| 48 |
+
expire = datetime.utcnow() + timedelta(hours=settings.jwt_expiry_hours)
|
| 49 |
+
|
| 50 |
+
to_encode.update({"exp": expire})
|
| 51 |
+
|
| 52 |
+
encoded_jwt = jwt.encode(
|
| 53 |
+
to_encode,
|
| 54 |
+
settings.jwt_secret,
|
| 55 |
+
algorithm=settings.jwt_algorithm
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
return encoded_jwt
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def decode_access_token(token: str) -> Optional[dict]:
|
| 62 |
+
"""
|
| 63 |
+
Decode and validate a JWT token
|
| 64 |
+
|
| 65 |
+
Args:
|
| 66 |
+
token: JWT token string
|
| 67 |
+
|
| 68 |
+
Returns:
|
| 69 |
+
Token payload if valid, None otherwise
|
| 70 |
+
"""
|
| 71 |
+
try:
|
| 72 |
+
payload = jwt.decode(
|
| 73 |
+
token,
|
| 74 |
+
settings.jwt_secret,
|
| 75 |
+
algorithms=[settings.jwt_algorithm]
|
| 76 |
+
)
|
| 77 |
+
return payload
|
| 78 |
+
except JWTError:
|
| 79 |
+
return None
|
services/vlm_service.py
ADDED
|
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Video processing service using Google Gemini Files API
|
| 3 |
+
Uploads entire video instead of extracting frames for better analysis
|
| 4 |
+
"""
|
| 5 |
+
import os
|
| 6 |
+
import time
|
| 7 |
+
import logging
|
| 8 |
+
from typing import Tuple
|
| 9 |
+
from google import genai
|
| 10 |
+
from app.config import get_settings
|
| 11 |
+
|
| 12 |
+
settings = get_settings()
|
| 13 |
+
logger = logging.getLogger(__name__)
|
| 14 |
+
|
| 15 |
+
# Configure Gemini client
|
| 16 |
+
_client = None
|
| 17 |
+
|
| 18 |
+
def get_client():
|
| 19 |
+
"""Get or create Gemini client"""
|
| 20 |
+
global _client
|
| 21 |
+
if _client is None:
|
| 22 |
+
_client = genai.Client(api_key=settings.gemini_api_key)
|
| 23 |
+
return _client
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def get_video_duration_cv2(video_path: str) -> int:
|
| 27 |
+
"""Get video duration in seconds using OpenCV"""
|
| 28 |
+
try:
|
| 29 |
+
import cv2
|
| 30 |
+
video = cv2.VideoCapture(video_path)
|
| 31 |
+
fps = video.get(cv2.CAP_PROP_FPS)
|
| 32 |
+
frame_count = video.get(cv2.CAP_PROP_FRAME_COUNT)
|
| 33 |
+
video.release()
|
| 34 |
+
|
| 35 |
+
if fps > 0:
|
| 36 |
+
return int(frame_count / fps)
|
| 37 |
+
except Exception as e:
|
| 38 |
+
logger.warning(f"Could not get video duration: {e}")
|
| 39 |
+
return 0
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def analyze_video_with_gemini(video_path: str) -> str:
|
| 43 |
+
"""
|
| 44 |
+
Analyze video using Gemini Files API
|
| 45 |
+
|
| 46 |
+
This uploads the entire video file to Gemini for comprehensive analysis,
|
| 47 |
+
which provides better context than frame extraction.
|
| 48 |
+
|
| 49 |
+
Args:
|
| 50 |
+
video_path: Local path to video file
|
| 51 |
+
|
| 52 |
+
Returns:
|
| 53 |
+
Scene descriptions and analysis
|
| 54 |
+
"""
|
| 55 |
+
try:
|
| 56 |
+
client = get_client()
|
| 57 |
+
|
| 58 |
+
# Step 1: Upload video to Gemini
|
| 59 |
+
logger.info(f"Uploading video to Gemini: {video_path}")
|
| 60 |
+
uploaded_file = client.files.upload(file=video_path)
|
| 61 |
+
logger.info(f"Uploaded file: {uploaded_file.name}")
|
| 62 |
+
|
| 63 |
+
# Step 2: Wait for processing (videos need processing time)
|
| 64 |
+
max_wait = 120 # Maximum 2 minutes wait
|
| 65 |
+
wait_time = 0
|
| 66 |
+
|
| 67 |
+
while uploaded_file.state == "PROCESSING":
|
| 68 |
+
if wait_time >= max_wait:
|
| 69 |
+
logger.error("Video processing timeout")
|
| 70 |
+
return "Video processing timeout. Please try with a shorter video."
|
| 71 |
+
|
| 72 |
+
logger.info(f"Waiting for video processing... ({wait_time}s)")
|
| 73 |
+
time.sleep(10)
|
| 74 |
+
wait_time += 10
|
| 75 |
+
uploaded_file = client.files.get(name=uploaded_file.name)
|
| 76 |
+
|
| 77 |
+
if uploaded_file.state == "FAILED":
|
| 78 |
+
logger.error("Video processing failed on Gemini side")
|
| 79 |
+
return "Video processing failed. Please try a different video format."
|
| 80 |
+
|
| 81 |
+
logger.info("Video processing complete, analyzing content...")
|
| 82 |
+
|
| 83 |
+
# Step 3: Analyze with Gemini
|
| 84 |
+
prompt = """Analyze this video advertisement in detail. Describe:
|
| 85 |
+
|
| 86 |
+
1. **Visual Elements**: People shown (ages, genders, relationships), settings, products
|
| 87 |
+
2. **Cultural Context**: Any cultural symbols, traditions, or practices shown
|
| 88 |
+
3. **Narrative**: What story is being told? What emotions are evoked?
|
| 89 |
+
4. **Target Audience**: Who appears to be the intended audience?
|
| 90 |
+
5. **Potential Sensitivities**: Any elements that might be controversial or offensive to certain groups (religious, political, cultural, age-related)
|
| 91 |
+
6. **Message**: What is the main message or call to action?
|
| 92 |
+
|
| 93 |
+
Be specific and detailed. Focus on elements that could affect how different demographic groups might react to this advertisement."""
|
| 94 |
+
|
| 95 |
+
response = client.models.generate_content(
|
| 96 |
+
model="gemini-3.1-flash-lite-preview",
|
| 97 |
+
contents=[uploaded_file, prompt]
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
# Step 4: Clean up uploaded file (optional - files expire automatically)
|
| 101 |
+
try:
|
| 102 |
+
client.files.delete(name=uploaded_file.name)
|
| 103 |
+
logger.info("Cleaned up uploaded video file")
|
| 104 |
+
except:
|
| 105 |
+
pass # File cleanup is optional
|
| 106 |
+
|
| 107 |
+
if response.text:
|
| 108 |
+
logger.info(f"Generated analysis: {len(response.text)} characters")
|
| 109 |
+
return response.text
|
| 110 |
+
else:
|
| 111 |
+
return "No analysis could be generated for this video."
|
| 112 |
+
|
| 113 |
+
except Exception as e:
|
| 114 |
+
logger.error(f"Gemini video analysis failed: {e}")
|
| 115 |
+
return f"Video analysis failed: {str(e)}"
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def process_video(video_path: str) -> Tuple[str, int]:
|
| 119 |
+
"""
|
| 120 |
+
Full video processing pipeline using Gemini Files API
|
| 121 |
+
|
| 122 |
+
Args:
|
| 123 |
+
video_path: Path to video file
|
| 124 |
+
|
| 125 |
+
Returns:
|
| 126 |
+
Tuple of (scene_descriptions, duration_seconds)
|
| 127 |
+
"""
|
| 128 |
+
logger.info(f"Processing video: {video_path}")
|
| 129 |
+
|
| 130 |
+
# Get duration
|
| 131 |
+
duration = get_video_duration_cv2(video_path)
|
| 132 |
+
|
| 133 |
+
# Analyze with Gemini Files API
|
| 134 |
+
descriptions = analyze_video_with_gemini(video_path)
|
| 135 |
+
|
| 136 |
+
return descriptions, duration
|
tasks.py
ADDED
|
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Celery tasks for background processing
|
| 3 |
+
"""
|
| 4 |
+
import os
|
| 5 |
+
import sys
|
| 6 |
+
import logging
|
| 7 |
+
from celery import Celery
|
| 8 |
+
from datetime import datetime
|
| 9 |
+
from app.config import get_settings
|
| 10 |
+
|
| 11 |
+
# Add simulation directory to path (it's a sibling of backend)
|
| 12 |
+
backend_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
| 13 |
+
project_root = os.path.dirname(backend_dir)
|
| 14 |
+
simulation_path = os.path.join(project_root, "simulation")
|
| 15 |
+
if simulation_path not in sys.path:
|
| 16 |
+
sys.path.insert(0, project_root)
|
| 17 |
+
|
| 18 |
+
settings = get_settings()
|
| 19 |
+
logger = logging.getLogger(__name__)
|
| 20 |
+
|
| 21 |
+
# Initialize Celery
|
| 22 |
+
celery_app = Celery(
|
| 23 |
+
"agentsociety",
|
| 24 |
+
broker=settings.redis_url,
|
| 25 |
+
backend=settings.redis_url
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
import ssl as ssl_module
|
| 29 |
+
|
| 30 |
+
celery_app.conf.update(
|
| 31 |
+
task_serializer="json",
|
| 32 |
+
accept_content=["json"],
|
| 33 |
+
result_serializer="json",
|
| 34 |
+
timezone="UTC",
|
| 35 |
+
enable_utc=True,
|
| 36 |
+
task_track_started=True,
|
| 37 |
+
task_time_limit=3600, # 1 hour max
|
| 38 |
+
broker_connection_retry_on_startup=True, # silence CPendingDeprecationWarning in Celery 5/6
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
# Enable SSL for rediss:// connections (e.g. Upstash)
|
| 42 |
+
if settings.redis_url.startswith("rediss://"):
|
| 43 |
+
celery_app.conf.update(
|
| 44 |
+
broker_use_ssl={"ssl_cert_reqs": ssl_module.CERT_REQUIRED},
|
| 45 |
+
redis_backend_use_ssl={"ssl_cert_reqs": ssl_module.CERT_REQUIRED},
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
@celery_app.task(bind=True)
|
| 50 |
+
def process_video_task(self, project_id: str):
|
| 51 |
+
"""
|
| 52 |
+
Background task to process video with VLM
|
| 53 |
+
"""
|
| 54 |
+
from app.database import SessionLocal
|
| 55 |
+
from app.models import Project
|
| 56 |
+
from app.services.vlm_service import process_video
|
| 57 |
+
|
| 58 |
+
db = SessionLocal()
|
| 59 |
+
|
| 60 |
+
try:
|
| 61 |
+
# Get project
|
| 62 |
+
project = db.query(Project).filter(Project.id == project_id).first()
|
| 63 |
+
if not project:
|
| 64 |
+
logger.error(f"Project not found: {project_id}")
|
| 65 |
+
return {"error": "Project not found"}
|
| 66 |
+
|
| 67 |
+
# Update status
|
| 68 |
+
project.status = "PROCESSING"
|
| 69 |
+
db.commit()
|
| 70 |
+
|
| 71 |
+
logger.info(f"Processing video for project {project_id}")
|
| 72 |
+
|
| 73 |
+
# Process video
|
| 74 |
+
descriptions, duration = process_video(project.video_path)
|
| 75 |
+
|
| 76 |
+
# Update project with results
|
| 77 |
+
project.vlm_generated_context = descriptions
|
| 78 |
+
project.video_duration_seconds = duration
|
| 79 |
+
project.status = "READY"
|
| 80 |
+
db.commit()
|
| 81 |
+
|
| 82 |
+
logger.info(f"Video processing complete for project {project_id}")
|
| 83 |
+
|
| 84 |
+
return {
|
| 85 |
+
"project_id": project_id,
|
| 86 |
+
"status": "READY",
|
| 87 |
+
"duration": duration
|
| 88 |
+
}
|
| 89 |
+
|
| 90 |
+
except Exception as e:
|
| 91 |
+
logger.error(f"Video processing failed for project {project_id}: {e}")
|
| 92 |
+
try:
|
| 93 |
+
project = db.query(Project).filter(Project.id == project_id).first()
|
| 94 |
+
if project:
|
| 95 |
+
project.status = "FAILED"
|
| 96 |
+
db.commit()
|
| 97 |
+
except:
|
| 98 |
+
pass
|
| 99 |
+
return {"error": str(e)}
|
| 100 |
+
finally:
|
| 101 |
+
db.close()
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
@celery_app.task(bind=True)
|
| 105 |
+
def run_simulation_task(self, simulation_id: str):
|
| 106 |
+
"""
|
| 107 |
+
Background task to queue simulation for Ray worker
|
| 108 |
+
|
| 109 |
+
This task:
|
| 110 |
+
1. Validates the simulation and project
|
| 111 |
+
2. Publishes request to Redis 'simulation_requests' channel
|
| 112 |
+
3. Ray worker (separate process) handles the actual simulation
|
| 113 |
+
4. Results listener updates the database when complete
|
| 114 |
+
"""
|
| 115 |
+
from app.database import SessionLocal
|
| 116 |
+
from app.models import SimulationRun, Project
|
| 117 |
+
import json
|
| 118 |
+
import redis
|
| 119 |
+
|
| 120 |
+
db = SessionLocal()
|
| 121 |
+
redis_kwargs = {}
|
| 122 |
+
if settings.redis_url.startswith("rediss://"):
|
| 123 |
+
redis_kwargs["ssl_cert_reqs"] = ssl_module.CERT_REQUIRED
|
| 124 |
+
redis_client = redis.from_url(settings.redis_url, **redis_kwargs)
|
| 125 |
+
|
| 126 |
+
try:
|
| 127 |
+
# Get simulation
|
| 128 |
+
simulation = db.query(SimulationRun).filter(SimulationRun.id == simulation_id).first()
|
| 129 |
+
if not simulation:
|
| 130 |
+
logger.error(f"Simulation not found: {simulation_id}")
|
| 131 |
+
return {"error": "Simulation not found"}
|
| 132 |
+
|
| 133 |
+
# Get project
|
| 134 |
+
project = db.query(Project).filter(Project.id == simulation.project_id).first()
|
| 135 |
+
if not project or not project.vlm_generated_context:
|
| 136 |
+
logger.error(f"Project not ready for simulation: {simulation.project_id}")
|
| 137 |
+
simulation.status = "FAILED"
|
| 138 |
+
simulation.error_message = "Project video analysis not complete"
|
| 139 |
+
db.commit()
|
| 140 |
+
return {"error": "Project not ready"}
|
| 141 |
+
|
| 142 |
+
# Update status to RUNNING (Ray worker will process)
|
| 143 |
+
simulation.status = "RUNNING"
|
| 144 |
+
simulation.started_at = datetime.utcnow()
|
| 145 |
+
db.commit()
|
| 146 |
+
|
| 147 |
+
logger.info(f"Sending simulation {simulation_id} to Ray worker")
|
| 148 |
+
|
| 149 |
+
# Publish request to Redis for Ray worker
|
| 150 |
+
request = {
|
| 151 |
+
"simulation_id": str(simulation.id),
|
| 152 |
+
"project_id": str(project.id),
|
| 153 |
+
"ad_content": project.vlm_generated_context,
|
| 154 |
+
"demographic_filter": project.demographic_filter,
|
| 155 |
+
"num_agents": simulation.num_agents,
|
| 156 |
+
"simulation_days": simulation.simulation_days
|
| 157 |
+
}
|
| 158 |
+
|
| 159 |
+
redis_client.publish("simulation_requests", json.dumps(request))
|
| 160 |
+
|
| 161 |
+
logger.info(f"Simulation {simulation_id} published to Ray worker queue")
|
| 162 |
+
|
| 163 |
+
return {
|
| 164 |
+
"simulation_id": simulation_id,
|
| 165 |
+
"status": "RUNNING",
|
| 166 |
+
"message": "Simulation sent to Ray worker for processing"
|
| 167 |
+
}
|
| 168 |
+
|
| 169 |
+
except Exception as e:
|
| 170 |
+
logger.error(f"Failed to queue simulation {simulation_id}: {e}")
|
| 171 |
+
try:
|
| 172 |
+
simulation = db.query(SimulationRun).filter(SimulationRun.id == simulation_id).first()
|
| 173 |
+
if simulation:
|
| 174 |
+
simulation.status = "FAILED"
|
| 175 |
+
simulation.error_message = str(e)
|
| 176 |
+
simulation.completed_at = datetime.utcnow()
|
| 177 |
+
db.commit()
|
| 178 |
+
except:
|
| 179 |
+
pass
|
| 180 |
+
return {"error": str(e)}
|
| 181 |
+
finally:
|
| 182 |
+
db.close()
|