Spaces:
Sleeping
Sleeping
File size: 5,654 Bytes
d5461ce a0c7896 8a36f2d 643414d 8a36f2d a0c7896 d5461ce a73cf00 d5461ce dcda928 bcc2d91 8a36f2d d5461ce 0590b5e d5461ce 8a36f2d 0590b5e d5461ce 0590b5e d5461ce 0590b5e d5461ce 1a68f03 779fb9a 0c7fe7f 779fb9a 643414d 779fb9a 643414d 3f67da4 643414d 779fb9a 37d617c 4c419f4 6dab65e 4c419f4 1a68f03 dcda928 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 | from fastapi import FastAPI, Body, Request, HTTPException
from fastapi.responses import HTMLResponse, JSONResponse, StreamingResponse
from fastapi.requests import Request
from fastapi.templating import Jinja2Templates
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from loguru import logger
import aiohttp
import uvicorn
import asyncio
import os
import uuid
import toml
import time
from datetime import datetime
from json import dumps
# Resolve the OpenManus endpoint URL. The OPENMANUS_ENDPOINT_URL environment
# variable takes precedence; otherwise fall back to the key of the same name
# in config/config.toml (a path relative to the current working directory).
# Importing this module raises EnvironmentError if neither source yields a
# non-empty value.
OPENMANUS_ENDPOINT_URL = os.getenv("OPENMANUS_ENDPOINT_URL")
if not OPENMANUS_ENDPOINT_URL:
    config_path = "config/config.toml"
    if os.path.exists(config_path):
        config = toml.load(config_path)
        OPENMANUS_ENDPOINT_URL = config.get("OPENMANUS_ENDPOINT_URL")
    # Still unset (or empty) after the config-file fallback: refuse to start.
    if not OPENMANUS_ENDPOINT_URL:
        raise EnvironmentError("OPENMANUS_ENDPOINT_URL must be set in env or config/config.toml")
app = FastAPI()

# Jinja2 templates used to render HTML pages, loaded from ./templates.
templates = Jinja2Templates(directory="templates")

# Fully open CORS policy: any origin, method and header, with credentials.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive for a deployed service — confirm this is intended.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)

# Serve the contents of ./static under the /static URL prefix.
app.mount("/static", StaticFiles(directory="static"), name="static")
class Task(BaseModel):
    """A submitted prompt plus its lifecycle status and recorded steps."""

    id: str
    prompt: str
    created_at: datetime
    # Set by this module to "pending", "running", "completed" or "failed: <error>".
    status: str
    # Step dicts of the form {"step": ..., "result": ..., "type": ...}.
    steps: list = []

    def model_dump(self, *args, **kwargs):
        """Dump like ``BaseModel.model_dump`` but with ``created_at`` as ISO-8601.

        The ``created_at`` key is always written, even if the caller's
        include/exclude arguments would have omitted it.
        """
        dumped = super().model_dump(*args, **kwargs)
        dumped.update(created_at=self.created_at.isoformat())
        return dumped
class TaskManager:
    """In-memory registry of tasks, with one asyncio event queue per task.

    The async update methods mutate the stored task and push plain-dict
    events onto that task's queue for a consumer to drain. Calls naming an
    unknown task id are silently ignored.

    NOTE(review): entries are never evicted, so memory grows with the number
    of tasks created over the process lifetime.
    """

    def __init__(self):
        self.tasks = {}   # task_id -> Task
        self.queues = {}  # task_id -> asyncio.Queue of event dicts

    def create_task(self, prompt: str) -> "Task":
        """Create and register a new "pending" task for ``prompt`` and its event queue."""
        task_id = str(uuid.uuid4())
        task = Task(
            id=task_id, prompt=prompt, created_at=datetime.now(), status="pending"
        )
        self.tasks[task_id] = task
        self.queues[task_id] = asyncio.Queue()
        return task

    @staticmethod
    def _status_event(task) -> dict:
        """Build a status event holding a snapshot of ``task``'s steps."""
        # Copy the list: events are serialized later by the consumer, and
        # sharing the live list would make an already-queued event show steps
        # appended after it was queued. Step dicts are never mutated after
        # being appended, so a shallow copy is enough.
        return {"type": "status", "status": task.status, "steps": list(task.steps)}

    async def update_task_step(self, task_id: str, step: int, result: str, step_type: str = "step"):
        """Append a step to the task, then publish it followed by a status event."""
        task = self.tasks.get(task_id)
        if task is None:
            return
        task.steps.append({"step": step, "result": result, "type": step_type})
        queue = self.queues[task_id]
        await queue.put({"type": step_type, "step": step, "result": result})
        await queue.put(self._status_event(task))

    async def complete_task(self, task_id: str):
        """Mark the task "completed", then publish a status event and a "complete" event."""
        task = self.tasks.get(task_id)
        if task is None:
            return
        task.status = "completed"
        queue = self.queues[task_id]
        await queue.put(self._status_event(task))
        await queue.put({"type": "complete"})

    async def fail_task(self, task_id: str, error: str):
        """Mark the task "failed: <error>" and publish an "error" event carrying ``error``."""
        task = self.tasks.get(task_id)
        if task is None:
            return
        task.status = f"failed: {error}"
        await self.queues[task_id].put({"type": "error", "message": error})
# Module-level, in-memory task registry used by the task HTTP handlers and
# the background task runner in this module.
task_manager = TaskManager()
# Strong references to in-flight background jobs. The event loop keeps only
# weak references to tasks, so a job nobody holds can be garbage-collected
# before it finishes; each job removes itself from this set once done.
_background_jobs = set()


@app.post("/tasks")
async def create_task(prompt: str = Body(..., embed=True)):
    """Register a task for ``prompt`` and start processing it in the background.

    The request body is a JSON object of the form ``{"prompt": "..."}``
    (``embed=True``). Returns ``{"task_id": ...}`` immediately; progress is
    published to the task's event queue by ``run_task``.
    """
    task = task_manager.create_task(prompt)
    job = asyncio.create_task(run_task(task.id, prompt))
    _background_jobs.add(job)
    job.add_done_callback(_background_jobs.discard)
    return {"task_id": task.id}
async def run_task(task_id: str, prompt: str):
    """Simulate processing of a registered task.

    Marks the task "running", waits 2 seconds, records a single "result" step
    echoing the prompt, then completes the task. Any exception is logged
    with its traceback and reported to subscribers through ``fail_task``.
    """
    try:
        logger.info(f"Simulating task: {task_id} with prompt: {prompt}")
        task_manager.tasks[task_id].status = "running"
        # Simulated processing: stand-in delay for real work.
        await asyncio.sleep(2)
        result_text = f"Simulated response for prompt: '{prompt}'"
        await task_manager.update_task_step(task_id, 0, result_text, "result")
        await task_manager.complete_task(task_id)
    except Exception as e:
        # Top-level boundary of a background job: nothing awaits it, so log
        # the full traceback (logger.exception) rather than just the message.
        logger.exception(f"Simulated task failed: {e}")
        await task_manager.fail_task(task_id, str(e))
# fail_task stores a failed task's status as "failed: <error>".
_FAILED_STATUS_PREFIX = "failed: "


def _format_sse(event_name: str, payload: dict) -> str:
    """Return one Server-Sent Events message whose data line is ``payload`` as JSON."""
    return f"event: {event_name}\ndata: {dumps(payload)}\n\n"


@app.get("/tasks/{task_id}/events")
async def task_events(task_id: str):
    """Stream a task's events to the client as Server-Sent Events.

    The stream opens with a ``status`` snapshot of the task, then relays every
    event queued for it. After 5 seconds without an event it sends a
    ``: heartbeat`` comment line. The stream ends after a ``complete`` or
    ``error`` event. An unknown ``task_id`` produces a single ``error`` event.
    """
    logger.info(f"Client subscribed to events for task: {task_id}")

    async def event_generator():
        queue = task_manager.queues.get(task_id)
        if queue is None:
            yield _format_sse("error", {"message": "Task not found"})
            return

        task = task_manager.tasks.get(task_id)
        if task:
            yield _format_sse(
                "status",
                {"type": "status", "status": task.status, "steps": task.steps},
            )
            # A finished task never enqueues anything again. If its terminal
            # event was already drained by an earlier subscriber (e.g. before
            # an EventSource auto-reconnect), end the stream now instead of
            # sending heartbeats forever.
            if queue.empty():
                if task.status == "completed":
                    yield _format_sse("complete", {"type": "complete"})
                    return
                if task.status.startswith(_FAILED_STATUS_PREFIX):
                    error = task.status[len(_FAILED_STATUS_PREFIX):]
                    yield _format_sse("error", {"type": "error", "message": error})
                    return

        while True:
            try:
                # Wait up to 5 seconds for the next event.
                event = await asyncio.wait_for(queue.get(), timeout=5.0)
                message = _format_sse(event["type"], event)
            except asyncio.TimeoutError:
                # SSE comment line to keep an idle connection alive.
                yield ": heartbeat\n\n"
                continue
            except asyncio.CancelledError:
                # Cancellation (e.g. client disconnect or shutdown) must
                # propagate; swallowing it is discouraged by the asyncio docs.
                raise
            except Exception as e:
                yield _format_sse("error", {"message": str(e)})
                break
            yield message
            if event["type"] in ("complete", "error"):
                break

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
        },
    )
@app.get("/", response_class=HTMLResponse)
async def homepage(request: Request):
    """Render templates/index.html for the root URL."""
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
|