import os
import asyncio
from contextlib import asynccontextmanager
from fastapi import FastAPI, UploadFile, File, Form, Header, BackgroundTasks
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from typing import Optional, List
from job_queue import JobQueue
import db
import uvicorn
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: set up the DB and run the job worker.

    On startup, initializes the database and launches the background
    job-processing loop; on shutdown, cancels the loop cleanly.
    """
    # Startup
    await db.init_db()
    # Keep a reference to the task: asyncio stores only weak references
    # to running tasks, so an unreferenced task may be garbage-collected
    # before it finishes.
    worker = asyncio.create_task(queue.process_jobs())
    try:
        yield
    finally:
        # Shutdown: stop the worker so the event loop can exit cleanly.
        worker.cancel()
        try:
            await worker
        except asyncio.CancelledError:
            pass
# Application instance wired to the lifespan handler above. `queue` is the
# shared job queue; `lifespan` references it as a module global, which is
# resolved at startup, so defining it after `app` is safe.
app = FastAPI(lifespan=lifespan)
queue = JobQueue()
class MergeRequest(BaseModel):
    """Request body for /api/submit-job: model sources, merge settings, credentials."""
    # Opt out of pydantic v2's protected "model_" namespace so the
    # model_a_* / model_b_* field names don't trigger namespace warnings.
    model_config = {'protected_namespaces': ()}
    model_a_source: str              # source of model A — assumes e.g. "hf"/"civitai"; TODO confirm values
    model_a_id: str                  # identifier of model A at its source
    model_b_source: str              # source of model B — TODO confirm allowed values
    model_b_id: str                  # identifier of model B at its source
    method: str = "linear"           # merge algorithm name
    merge_type: Optional[str] = "linear"  # NOTE(review): appears to overlap with `method` — confirm which wins
    linear_alpha: Optional[float] = 0.5   # interpolation weight for linear merges
    evo_params: Optional[dict] = None     # parameters for evolutionary merging, if any
    dataset: Optional[str] = None         # dataset path (see /api/upload-dataset)
    output_repo_name: Optional[str] = None  # target HF repo for the merged result
    hf_token_manual: Optional[str] = None   # explicit HF token; overrides the request header in submit_job
    civitai_key: Optional[str] = None       # CivitAI API key for downloads
    franken_layers: Optional[List[str]] = None  # layer spec for frankenmerge-style merges — TODO confirm format
@app.post("/api/submit-job")
async def submit_job(
req: MergeRequest,
x_hf_user_access_token: Optional[str] = Header(None)
):
token = req.hf_token_manual or x_hf_user_access_token
if not token:
return JSONResponse(status_code=400, content={"error": "HF token required"})
job_id = await queue.add_job(req, token)
return {"job_id": job_id}
@app.get("/api/job-status/{job_id}")
async def job_status(job_id: str):
job = await db.get_job(job_id)
if not job:
return JSONResponse(status_code=404, content={"error": "Job not found"})
return job
@app.get("/api/queue")
async def get_queue():
return await queue.get_queue_status()
@app.post("/api/upload-dataset")
async def upload_dataset(file: UploadFile = File(...)):
import tempfile
os.makedirs("/app/data/datasets", exist_ok=True)
path = f"/app/data/datasets/{file.filename}"
with open(path, "wb") as f:
f.write(await file.read())
return {"path": path}
if __name__ == "__main__":
uvicorn.run("main:app", host="0.0.0.0", port=8000, log_level="info") |