rikunarita-2 commited on
Commit
0acd5b7
·
verified ·
1 Parent(s): 24fdb6c

Create main.py

Browse files
Files changed (1) hide show
  1. backend/main.py +61 -0
backend/main.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import asyncio
3
+ from fastapi import FastAPI, UploadFile, File, Form, Header, BackgroundTasks
4
+ from fastapi.responses import JSONResponse
5
+ from pydantic import BaseModel
6
+ from typing import Optional, List
7
+ from queue import JobQueue
8
+ from worker import run_job
9
+ import db
10
+
11
+ app = FastAPI()
12
+ queue = JobQueue()
13
+
14
+ @app.on_event("startup")
15
+ async def startup():
16
+ await db.init_db()
17
+ asyncio.create_task(queue.process_jobs())
18
+
19
+ class MergeRequest(BaseModel):
20
+ model_a_source: str # "hf" or "civitai"
21
+ model_a_id: str
22
+ model_b_source: str
23
+ model_b_id: str
24
+ method: str = "linear" # linear, evolutionary
25
+ linear_alpha: Optional[float] = 0.5
26
+ evo_params: Optional[dict] = None
27
+ dataset: Optional[str] = None # base64 encoded zip or path in job dir
28
+ output_repo_name: Optional[str] = None
29
+ hf_token_manual: Optional[str] = None
30
+ civitai_key: Optional[str] = None
31
+
32
+ @app.post("/api/submit-job")
33
+ async def submit_job(
34
+ req: MergeRequest,
35
+ x_hf_user_access_token: Optional[str] = Header(None)
36
+ ):
37
+ token = req.hf_token_manual or x_hf_user_access_token
38
+ if not token:
39
+ return JSONResponse(status_code=400, content={"error": "HF token required"})
40
+ job_id = await queue.add_job(req, token)
41
+ return {"job_id": job_id}
42
+
43
+ @app.get("/api/job-status/{job_id}")
44
+ async def job_status(job_id: str):
45
+ job = await db.get_job(job_id)
46
+ if not job:
47
+ return JSONResponse(status_code=404, content={"error": "Job not found"})
48
+ return job
49
+
50
+ @app.get("/api/queue")
51
+ async def get_queue():
52
+ return await queue.get_queue_status()
53
+
54
+ @app.post("/api/upload-dataset")
55
+ async def upload_dataset(file: UploadFile = File(...)):
56
+ import tempfile
57
+ os.makedirs("/data/datasets", exist_ok=True)
58
+ path = f"/data/datasets/{file.filename}"
59
+ with open(path, "wb") as f:
60
+ f.write(await file.read())
61
+ return {"path": path}