File size: 5,654 Bytes
d5461ce
 
a0c7896
 
8a36f2d
 
 
 
 
 
 
 
 
 
643414d
8a36f2d
 
a0c7896
 
d5461ce
 
 
 
 
 
 
a73cf00
d5461ce
 
dcda928
bcc2d91
 
8a36f2d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d5461ce
 
0590b5e
d5461ce
8a36f2d
0590b5e
 
 
 
 
d5461ce
0590b5e
d5461ce
0590b5e
d5461ce
1a68f03
779fb9a
 
0c7fe7f
 
779fb9a
 
 
 
 
 
 
 
 
 
643414d
 
779fb9a
 
643414d
 
 
 
 
 
 
 
3f67da4
643414d
 
 
 
779fb9a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37d617c
4c419f4
6dab65e
 
4c419f4
1a68f03
dcda928
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
from fastapi import FastAPI, Body, Request, HTTPException
from fastapi.responses import HTMLResponse, JSONResponse, StreamingResponse
from fastapi.requests import Request
from fastapi.templating import Jinja2Templates
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from loguru import logger
import aiohttp
import uvicorn
import asyncio
import os
import uuid
import toml
import time
from datetime import datetime
from json import dumps


# Resolve the OpenManus endpoint at import time. A non-empty environment
# variable wins; otherwise the same key is read from config/config.toml when
# that file exists. Importing this module fails with EnvironmentError if
# neither source yields a non-empty value.
OPENMANUS_ENDPOINT_URL = os.getenv("OPENMANUS_ENDPOINT_URL")
if not OPENMANUS_ENDPOINT_URL and os.path.exists(config_path := "config/config.toml"):
    config = toml.load(config_path)
    OPENMANUS_ENDPOINT_URL = config.get("OPENMANUS_ENDPOINT_URL")

if not OPENMANUS_ENDPOINT_URL:
    raise EnvironmentError("OPENMANUS_ENDPOINT_URL must be set in env or config/config.toml")

# ASGI application and its cross-cutting setup: CORS, static assets served
# under /static, and a Jinja2 template loader for HTML pages.
app = FastAPI()

# NOTE(review): per Starlette's CORSMiddleware docs, allow_origins=["*"]
# together with allow_credentials=True makes the middleware echo the caller's
# Origin on credentialed requests, i.e. any site may send cookie-bearing
# requests. Pin explicit origins before adding any session/cookie handling.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)

# Both directory paths are relative, so they resolve against the process's
# current working directory.
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")

class Task(BaseModel):
    """In-memory record of a submitted prompt and its progress.

    ``status`` is a free-form string; the values written in this module are
    "pending", "running", "completed" and "failed: <message>". ``steps``
    collects the ``{"step", "result", "type"}`` dicts appended by the task
    manager.
    """

    id: str
    prompt: str
    created_at: datetime
    status: str
    # Pydantic copies this default for each instance, so the list is not shared.
    steps: list = []

    def model_dump(self, *args, **kwargs) -> dict:
        """Dump the model, rendering ``created_at`` as an ISO-8601 string.

        All arguments are forwarded to ``BaseModel.model_dump``. The
        timestamp is converted only when it is actually part of the output,
        so ``include=`` / ``exclude=`` filters are honoured instead of the
        field being silently re-added.
        """
        data = super().model_dump(*args, **kwargs)
        if "created_at" in data:
            data["created_at"] = self.created_at.isoformat()
        return data

class TaskManager:
    """Registry of in-memory tasks and their per-task event queues.

    Each task gets one ``asyncio.Queue``; the async mutators below push event
    dicts onto it for the SSE endpoint to relay. Nothing in this class removes
    entries, so tasks and queues accumulate for the life of the instance.
    """

    def __init__(self):
        self.tasks: "dict[str, Task]" = {}  # task_id -> Task
        self.queues: "dict[str, asyncio.Queue]" = {}  # task_id -> event queue

    def create_task(self, prompt: str) -> "Task":
        """Register a new ``pending`` task for *prompt* and return it.

        The id is a fresh UUID4 string; an empty event queue is created
        alongside the task.
        """
        task_id = str(uuid.uuid4())
        task = Task(
            id=task_id, prompt=prompt, created_at=datetime.now(), status="pending"
        )
        self.tasks[task_id] = task
        self.queues[task_id] = asyncio.Queue()
        return task

    @staticmethod
    def _status_event(task) -> dict:
        """Build a ``status`` event carrying a snapshot of the task's steps."""
        # Copy the list: a queued event must not change if more steps are
        # appended before a subscriber gets around to serializing it.
        return {"type": "status", "status": task.status, "steps": list(task.steps)}

    async def update_task_step(self, task_id: str, step: int, result: str, step_type: str = "step"):
        """Record a step for *task_id* and notify subscribers.

        Appends ``{"step", "result", "type"}`` to ``task.steps``, then enqueues
        the step event (typed *step_type*) followed by a status snapshot.
        Unknown task ids are ignored.
        """
        task = self.tasks.get(task_id)
        if task is None:
            return
        task.steps.append({"step": step, "result": result, "type": step_type})
        queue = self.queues[task_id]
        await queue.put({"type": step_type, "step": step, "result": result})
        await queue.put(self._status_event(task))

    async def complete_task(self, task_id: str):
        """Mark *task_id* completed and enqueue the final status and ``complete``.

        Unknown task ids are ignored.
        """
        task = self.tasks.get(task_id)
        if task is None:
            return
        task.status = "completed"
        queue = self.queues[task_id]
        await queue.put(self._status_event(task))
        await queue.put({"type": "complete"})

    async def fail_task(self, task_id: str, error: str):
        """Mark *task_id* as ``failed: <error>`` and enqueue an ``error`` event.

        Unknown task ids are ignored.
        """
        task = self.tasks.get(task_id)
        if task is None:
            return
        task.status = f"failed: {error}"
        queue = self.queues[task_id]
        # Mirror complete_task: publish the terminal status before the
        # stream-ending event so status-driven clients also see the failure.
        await queue.put(self._status_event(task))
        await queue.put({"type": "error", "message": error})

# Module-wide TaskManager used by the /tasks routes and run_task; all task
# state and event queues live in this single in-memory instance.
task_manager = TaskManager()

# Strong references to in-flight background tasks. The event loop keeps only
# weak references to tasks, so an unreferenced fire-and-forget task could be
# garbage-collected before it finishes.
_background_tasks: set = set()


@app.post("/tasks")
async def create_task(prompt: str = Body(..., embed=True)):
    """Create a task for *prompt* and start processing it in the background.

    The prompt is read from a JSON body of the form ``{"prompt": "..."}``.
    Returns ``{"task_id": ...}`` immediately; progress is published on the
    task's event queue and can be followed via ``GET /tasks/{task_id}/events``.
    """
    task = task_manager.create_task(prompt)
    job = asyncio.create_task(run_task(task.id, prompt))
    _background_tasks.add(job)
    # Drop the reference once the job finishes so the set does not grow.
    job.add_done_callback(_background_tasks.discard)
    return {"task_id": task.id}

async def run_task(task_id: str, prompt: str):
    """Simulate processing of *prompt* and publish the outcome for *task_id*.

    Marks the task ``running``, sleeps two seconds, then records one
    ``result`` step and completes the task. Any exception is logged with its
    traceback and reported to subscribers through ``fail_task``.
    """
    try:
        logger.info(f"Simulating task: {task_id} with prompt: {prompt}")
        task_manager.tasks[task_id].status = "running"

        # Simulated processing: no external service is called here.
        await asyncio.sleep(2)  # simulate delay
        result_text = f"Simulated response for prompt: '{prompt}'"

        await task_manager.update_task_step(task_id, 0, result_text, "result")
        await task_manager.complete_task(task_id)

    except Exception as e:
        # logger.exception keeps the traceback that logger.error discarded.
        logger.exception(f"Simulated task {task_id} failed: {e}")
        await task_manager.fail_task(task_id, str(e))

def _format_sse(event_type: str, payload: dict) -> str:
    """Render one Server-Sent Events frame whose data line is *payload* as JSON."""
    return f"event: {event_type}\ndata: {dumps(payload)}\n\n"


@app.get("/tasks/{task_id}/events")
async def task_events(task_id: str):
    """Stream a task's events to the client as Server-Sent Events.

    The stream opens with a ``status`` snapshot, then relays every event from
    the task's queue and ends after a ``complete`` or ``error`` event. After
    5 seconds without events a ``: heartbeat`` comment line is sent to keep the
    connection alive. An unknown task id yields a single ``error`` event.
    """
    logger.info(f"Client subscribed to events for task: {task_id}")

    async def event_generator():
        queue = task_manager.queues.get(task_id)
        if queue is None:
            yield _format_sse("error", {"message": "Task not found"})
            return

        task = task_manager.tasks.get(task_id)
        if task:
            yield _format_sse(
                "status",
                {"type": "status", "status": task.status, "steps": task.steps},
            )

            # The queue is drained by whichever subscriber reads it, so a client
            # that (re)connects after the terminal event was consumed would
            # otherwise heartbeat forever. Terminal status + empty queue means
            # that event is gone: close the stream right away instead.
            if queue.empty():
                if task.status == "completed":
                    yield _format_sse("complete", {"type": "complete"})
                    return
                if task.status.startswith("failed"):
                    message = task.status.partition(": ")[2]
                    yield _format_sse("error", {"type": "error", "message": message})
                    return

        while True:
            try:
                event = await asyncio.wait_for(queue.get(), timeout=5.0)
            except asyncio.TimeoutError:
                # Idle for 5 s: an SSE comment line keeps the connection open.
                yield ": heartbeat\n\n"
                continue
            except asyncio.CancelledError:
                # Typically the client went away. Propagate the cancellation
                # rather than swallowing it so the server can clean up.
                logger.info(f"Event stream for task {task_id} cancelled")
                raise

            try:
                frame = _format_sse(event["type"], event)
            except Exception as e:
                # Malformed or unserializable event: report it and end the stream.
                yield _format_sse("error", {"message": str(e)})
                break

            yield frame
            if event["type"] in ("complete", "error"):
                break

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
        },
    )

@app.get("/", response_class=HTMLResponse)
async def homepage(request: Request):
    """Render templates/index.html for the root URL."""
    # Jinja2Templates expects the request object inside the template context.
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)