Ken Sang Tang commited on
Commit
8a36f2d
·
verified ·
1 Parent(s): 0f097d9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +77 -0
app.py CHANGED
@@ -3,6 +3,18 @@ from fastapi.responses import HTMLResponse
3
  from fastapi.requests import Request
4
  from fastapi.templating import Jinja2Templates
5
  from fastapi import Body, FastAPI
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
 
8
  templates = Jinja2Templates(directory="templates")
@@ -12,6 +24,71 @@ import os
12
 
13
  app = FastAPI()
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  @app.get("/", response_class=HTMLResponse)
16
  async def chat(request: Request):
17
  return templates.TemplateResponse("index.html", {"request": request})
 
3
  from fastapi.requests import Request
4
  from fastapi.templating import Jinja2Templates
5
  from fastapi import Body, FastAPI
6
+ from fastapi.middleware.cors import CORSMiddleware
7
+ from fastapi.staticfiles import StaticFiles
8
+ from pydantic import BaseModel
9
+ from loguru import logger
10
+ import aiohttp
11
+ import uvicorn
12
+ import asyncio
13
+ import os
14
+ import uuid
15
+ import toml
16
+ from datetime import datetime
17
+ from json import dumps
18
 
19
 
20
  templates = Jinja2Templates(directory="templates")
 
24
 
25
  app = FastAPI()
26
 
27
+ app.add_middleware(
28
+ CORSMiddleware,
29
+ allow_origins=["*"],
30
+ allow_credentials=True,
31
+ allow_methods=["*"],
32
+ allow_headers=["*"],
33
+ )
34
+
35
+ app.mount("/static", StaticFiles(directory="static"), name="static")
36
+ templates = Jinja2Templates(directory="templates")
37
+
38
+ class Task(BaseModel):
39
+ id: str
40
+ prompt: str
41
+ created_at: datetime
42
+ status: str
43
+ steps: list = []
44
+
45
+ def model_dump(self, *args, **kwargs):
46
+ data = super().model_dump(*args, **kwargs)
47
+ data["created_at"] = self.created_at.isoformat()
48
+ return data
49
+
50
+ class TaskManager:
51
+ def __init__(self):
52
+ self.tasks = {}
53
+ self.queues = {}
54
+
55
+ def create_task(self, prompt: str) -> Task:
56
+ task_id = str(uuid.uuid4())
57
+ task = Task(
58
+ id=task_id, prompt=prompt, created_at=datetime.now(), status="pending"
59
+ )
60
+ self.tasks[task_id] = task
61
+ self.queues[task_id] = asyncio.Queue()
62
+ return task
63
+
64
+ async def update_task_step(self, task_id: str, step: int, result: str, step_type: str = "step"):
65
+ if task_id in self.tasks:
66
+ task = self.tasks[task_id]
67
+ task.steps.append({"step": step, "result": result, "type": step_type})
68
+ await self.queues[task_id].put({"type": step_type, "step": step, "result": result})
69
+ await self.queues[task_id].put({"type": "status", "status": task.status, "steps": task.steps})
70
+
71
+ async def complete_task(self, task_id: str):
72
+ if task_id in self.tasks:
73
+ task = self.tasks[task_id]
74
+ task.status = "completed"
75
+ await self.queues[task_id].put({"type": "status", "status": task.status, "steps": task.steps})
76
+ await self.queues[task_id].put({"type": "complete"})
77
+
78
+ async def fail_task(self, task_id: str, error: str):
79
+ if task_id in self.tasks:
80
+ self.tasks[task_id].status = f"failed: {error}"
81
+ await self.queues[task_id].put({"type": "error", "message": error})
82
+
83
+ task_manager = TaskManager()
84
+
85
+ @app.post("/tasks")
86
+ async def create_task(prompt: str = Body(..., embed=True)):
87
+ task = task_manager.create_task(prompt)
88
+ asyncio.create_task(run_task(task.id, prompt))
89
+ return {"task_id": task.id}
90
+
91
+
92
  @app.get("/", response_class=HTMLResponse)
93
  async def chat(request: Request):
94
  return templates.TemplateResponse("index.html", {"request": request})