Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, Body, Request, HTTPException | |
| from fastapi.responses import HTMLResponse, JSONResponse, StreamingResponse | |
| from fastapi.requests import Request | |
| from fastapi.templating import Jinja2Templates | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.staticfiles import StaticFiles | |
| from pydantic import BaseModel | |
| from loguru import logger | |
| import aiohttp | |
| import uvicorn | |
| import asyncio | |
| import os | |
| import uuid | |
| import toml | |
| import time | |
| from datetime import datetime | |
| from json import dumps | |
# Resolve the OpenManus endpoint URL at import time. The environment variable
# takes precedence; config/config.toml is read only when the variable is unset
# or empty. Import fails loudly if neither source provides a value.
OPENMANUS_ENDPOINT_URL = os.getenv("OPENMANUS_ENDPOINT_URL")

if not OPENMANUS_ENDPOINT_URL:
    config_path = "config/config.toml"
    if os.path.exists(config_path):
        config = toml.load(config_path)
        OPENMANUS_ENDPOINT_URL = config.get("OPENMANUS_ENDPOINT_URL")

if not OPENMANUS_ENDPOINT_URL:
    raise EnvironmentError(
        "OPENMANUS_ENDPOINT_URL must be set in env or config/config.toml"
    )
app = FastAPI()

# CORS: accept any origin, method and header, with credentials allowed.
# NOTE(review): a wildcard origin combined with allow_credentials=True is very
# permissive (any site may issue credentialed requests) — confirm intended.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)

# Static assets are served from ./static under /static; HTML pages are
# rendered from Jinja2 templates in ./templates.
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")
class Task(BaseModel):
    """State of one submitted prompt.

    In this module ``status`` is set to ``"pending"``, ``"running"``,
    ``"completed"`` or ``"failed: <error>"``.
    """

    id: str
    prompt: str
    created_at: datetime
    status: str
    # Entries look like {"step": int, "result": str, "type": str}, appended by
    # TaskManager.update_task_step. Pydantic copies a mutable default for each
    # instance, so this [] is not shared between tasks.
    steps: list = []

    def model_dump(self, *args, **kwargs):
        """Dump like ``BaseModel.model_dump``, with ``created_at`` as ISO-8601.

        ``created_at`` is only rewritten when it is present in the dumped
        output, so ``include``/``exclude`` filters passed by the caller are
        respected instead of the field being re-added unconditionally.
        """
        data = super().model_dump(*args, **kwargs)
        if "created_at" in data:
            data["created_at"] = self.created_at.isoformat()
        return data
class TaskManager:
    """In-memory registry of tasks plus one ``asyncio.Queue`` of events per task.

    The coroutines below push event dicts onto a task's queue; a subscriber
    drains it. Events have the shapes:

    * ``{"type": <step_type>, "step": int, "result": str}``
    * ``{"type": "status", "status": str, "steps": [...]}``
    * ``{"type": "complete"}``
    * ``{"type": "error", "message": str}``

    NOTE(review): entries are never removed from ``tasks``/``queues``, and
    ``asyncio.Queue.get`` hands each item to a single waiter, so two
    subscribers on the same task would split its events between them.
    """

    def __init__(self):
        self.tasks = {}   # task_id -> Task
        self.queues = {}  # task_id -> asyncio.Queue of event dicts

    def create_task(self, prompt: str) -> "Task":
        """Register a new ``"pending"`` task for *prompt* and return it.

        The id is a random UUID4 string; an empty event queue is created
        alongside the task.
        """
        task_id = str(uuid.uuid4())
        task = Task(
            id=task_id, prompt=prompt, created_at=datetime.now(), status="pending"
        )
        self.tasks[task_id] = task
        self.queues[task_id] = asyncio.Queue()
        return task

    async def _publish_status(self, task_id: str) -> None:
        """Queue a status event carrying a snapshot of the task's steps."""
        task = self.tasks[task_id]
        # Copy the list: events are serialised by the consumer later, and
        # sharing the live list would let steps appended afterwards leak into
        # this (earlier) status event.
        await self.queues[task_id].put(
            {"type": "status", "status": task.status, "steps": list(task.steps)}
        )

    async def update_task_step(self, task_id: str, step: int, result: str, step_type: str = "step"):
        """Record a step result, publish it, then publish a status snapshot.

        Unknown task ids are ignored.
        """
        task = self.tasks.get(task_id)
        if task is None:
            return
        task.steps.append({"step": step, "result": result, "type": step_type})
        await self.queues[task_id].put({"type": step_type, "step": step, "result": result})
        await self._publish_status(task_id)

    async def complete_task(self, task_id: str):
        """Mark the task ``"completed"`` and publish a status then a complete event.

        Unknown task ids are ignored.
        """
        task = self.tasks.get(task_id)
        if task is None:
            return
        task.status = "completed"
        await self._publish_status(task_id)
        await self.queues[task_id].put({"type": "complete"})

    async def fail_task(self, task_id: str, error: str):
        """Mark the task ``"failed: <error>"`` and publish an error event.

        Unknown task ids are ignored.
        """
        task = self.tasks.get(task_id)
        if task is None:
            return
        task.status = f"failed: {error}"
        await self.queues[task_id].put({"type": "error", "message": error})


# Module-wide manager shared by the request handlers and the background job.
task_manager = TaskManager()
async def create_task(prompt: str = Body(..., embed=True)):
    """Create a task for *prompt*, start it in the background, return its id.

    Expects a JSON body ``{"prompt": "..."}`` (``embed=True``) and responds
    with ``{"task_id": "<id>"}`` immediately; progress is delivered through the
    task's event queue.
    """
    task = task_manager.create_task(prompt)
    job = asyncio.create_task(run_task(task.id, prompt))
    # The event loop only holds weak references to tasks; keep a strong one
    # until the job finishes so it cannot be garbage-collected mid-run.
    _background_tasks.add(job)
    job.add_done_callback(_background_tasks.discard)
    return {"task_id": task.id}


# Strong references to in-flight background jobs started by create_task.
_background_tasks = set()
async def run_task(task_id: str, prompt: str):
    """Simulated background job for task *task_id*.

    Marks the task ``"running"``, sleeps 2 s, publishes a single ``"result"``
    step (step 0) echoing *prompt*, then completes the task. Any exception is
    logged together with its traceback and reported via
    ``task_manager.fail_task``.
    """
    try:
        logger.info("Simulating task: {} with prompt: {}", task_id, prompt)
        # KeyError here (unknown id) is handled by the except clause below.
        task_manager.tasks[task_id].status = "running"

        # Simulated processing
        await asyncio.sleep(2)  # simulate delay
        result_text = f"Simulated response for prompt: '{prompt}'"
        await task_manager.update_task_step(task_id, 0, result_text, "result")
        await task_manager.complete_task(task_id)
    except Exception as e:
        # Unlike logger.error, logger.exception also records the traceback.
        logger.exception("Simulated task failed: {}", e)
        await task_manager.fail_task(task_id, str(e))
async def task_events(task_id: str):
    """Stream the events of task *task_id* to the client as Server-Sent Events.

    Frames emitted, in order:

    * ``event: error`` with ``{"message": "Task not found"}`` (then the stream
      ends) when no queue exists for *task_id*;
    * one ``event: status`` snapshot of the task's current status and steps;
    * every event dict queued by ``task_manager`` for this task, using the
      event's ``"type"`` as the SSE event name, until a ``complete`` or
      ``error`` event has been sent;
    * a ``: heartbeat`` comment after each 5 s without a new event.

    If the task has already finished and its queue is empty (its terminal
    events were consumed by an earlier subscriber), a matching ``complete`` /
    ``error`` frame is synthesised so the stream ends instead of sending
    heartbeats forever.
    """
    logger.info("Client subscribed to events for task: {}", task_id)

    def format_sse(event_type, payload):
        # One SSE frame: event name, JSON payload, blank-line terminator.
        return f"event: {event_type}\ndata: {dumps(payload)}\n\n"

    def terminal_event(task):
        # Rebuild the final event for a finished task; None while it is live.
        if task.status == "completed":
            return {"type": "complete"}
        if task.status.startswith("failed: "):
            return {"type": "error", "message": task.status[len("failed: "):]}
        return None

    async def event_generator():
        queue = task_manager.queues.get(task_id)
        if queue is None:
            yield format_sse("error", {"message": "Task not found"})
            return

        task = task_manager.tasks.get(task_id)
        if task is not None:
            yield format_sse(
                "status",
                {"type": "status", "status": task.status, "steps": task.steps},
            )

        while True:
            try:
                try:
                    # Wait up to 5 seconds for the next queued event.
                    event = await asyncio.wait_for(queue.get(), timeout=5.0)
                except asyncio.TimeoutError:
                    final = None
                    if task is not None and queue.empty():
                        final = terminal_event(task)
                    if final is not None:
                        yield format_sse(final["type"], final)
                        break
                    # SSE comment frame to keep the idle connection alive.
                    yield ": heartbeat\n\n"
                    continue

                yield format_sse(event["type"], event)
                if event["type"] in ("complete", "error"):
                    break
            except asyncio.CancelledError:
                # Client disconnected / request cancelled: propagate the
                # cancellation rather than swallowing it.
                raise
            except Exception as e:
                yield format_sse("error", {"message": str(e)})
                break

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            # Asks nginx-style reverse proxies not to buffer the stream.
            "X-Accel-Buffering": "no",
        },
    )
async def homepage(request: Request):
    """Render the ``index.html`` template with the incoming request in context."""
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)