File size: 2,184 Bytes
f6ae376
 
1a888f1
0acd5b7
 
 
 
f6ae376
0acd5b7
1a888f1
0acd5b7
1a888f1
 
 
f6ae376
 
1a888f1
 
 
 
 
f6ae376
0acd5b7
ffb3dbc
1a888f1
f6ae376
0acd5b7
 
 
f6ae376
 
0acd5b7
 
d26f9b1
0acd5b7
 
 
f6ae376
0acd5b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ffb3dbc
 
0acd5b7
 
1a888f1
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import os
import asyncio
from contextlib import asynccontextmanager
from fastapi import FastAPI, UploadFile, File, Form, Header, BackgroundTasks
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from typing import Optional, List
from job_queue import JobQueue
import db
import uvicorn

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application startup/shutdown hook passed to ``FastAPI(lifespan=...)``.

    Startup: initialise the database via ``db.init_db()`` and start
    ``queue.process_jobs()`` as a background task.
    Shutdown: cancel that background task and wait for it to finish, so the
    worker does not outlive the application.
    """
    # Startup
    await db.init_db()
    # Keep a strong reference to the task: the event loop only holds weak
    # references to tasks, so a fire-and-forget task can be garbage-collected
    # before it finishes.
    worker = asyncio.create_task(queue.process_jobs())
    try:
        yield
    finally:
        # Shutdown: stop the queue worker and wait for the cancellation to land.
        worker.cancel()
        try:
            await worker
        except asyncio.CancelledError:
            pass

# FastAPI application; `lifespan` initialises the DB and schedules
# queue.process_jobs() as a background task when the app starts.
app = FastAPI(lifespan=lifespan)
# Shared job queue used by the endpoints and by `lifespan`. Defining it after
# `lifespan` is fine: `lifespan` only looks up the global name `queue` when the
# app starts up, not when the function is defined.
queue = JobQueue()

class MergeRequest(BaseModel):
    """Request body for POST /api/submit-job (merge of model A with model B)."""

    # Pydantic v2 reserves the "model_" field-name prefix; clearing the
    # protected namespaces allows the model_a_* / model_b_* fields below
    # without a namespace-conflict warning.
    model_config = {'protected_namespaces': ()}

    # NOTE(review): the accepted values for the *_source fields (which hubs or
    # backends) are not visible in this file — confirm against JobQueue.
    model_a_source: str
    model_a_id: str
    model_b_source: str
    model_b_id: str
    # NOTE(review): `method` and `merge_type` both default to "linear"; which
    # one the job processor actually reads is not visible here.
    method: str = "linear"
    merge_type: Optional[str] = "linear"
    # Presumably the interpolation weight for linear merges — TODO confirm.
    linear_alpha: Optional[float] = 0.5
    # Free-form extra parameters, presumably for an evolutionary merge mode —
    # TODO confirm the expected keys.
    evo_params: Optional[dict] = None
    # Dataset reference; possibly a path returned by /api/upload-dataset — verify.
    dataset: Optional[str] = None
    output_repo_name: Optional[str] = None
    # When non-empty, submit_job uses this instead of the request header token.
    hf_token_manual: Optional[str] = None
    # Presumably an API key for Civitai-hosted models — TODO confirm.
    civitai_key: Optional[str] = None
    franken_layers: Optional[List[str]] = None

@app.post("/api/submit-job")
async def submit_job(
    req: MergeRequest,
    x_hf_user_access_token: Optional[str] = Header(None)
):
    """Enqueue a merge job and return ``{"job_id": ...}``.

    The Hugging Face token comes from ``req.hf_token_manual`` when that field
    is non-empty, otherwise from the ``x-hf-user-access-token`` header.
    Responds with HTTP 400 when neither source provides a token.
    """
    manual_token = req.hf_token_manual
    hf_token = manual_token if manual_token else x_hf_user_access_token

    if hf_token:
        new_job_id = await queue.add_job(req, hf_token)
        return {"job_id": new_job_id}

    return JSONResponse(status_code=400, content={"error": "HF token required"})

@app.get("/api/job-status/{job_id}")
async def job_status(job_id: str):
    """Return the job record from ``db.get_job``, or HTTP 404 if none is found."""
    record = await db.get_job(job_id)
    if record:
        return record
    return JSONResponse(status_code=404, content={"error": "Job not found"})

@app.get("/api/queue")
async def get_queue():
    """Return whatever ``JobQueue.get_queue_status()`` reports for the shared queue."""
    status = await queue.get_queue_status()
    return status

_DATASET_DIR = "/app/data/datasets"
_UPLOAD_CHUNK_SIZE = 1024 * 1024  # write uploads in 1 MiB pieces


def _safe_dataset_filename(filename: Optional[str]) -> Optional[str]:
    """Reduce a client-supplied upload name to a bare, safe file name.

    Returns None when no usable name remains (missing, empty, ".", "..",
    or containing a NUL byte).
    """
    if not filename:
        return None
    # Treat backslashes as separators too, so Windows-style client paths
    # cannot smuggle directory components past basename() on POSIX.
    name = os.path.basename(filename.replace("\\", "/"))
    if name in ("", ".", "..") or "\x00" in name:
        return None
    return name


@app.post("/api/upload-dataset")
async def upload_dataset(file: UploadFile = File(...)):
    """Store an uploaded dataset file under the datasets directory.

    Only the base name of the client-supplied filename is used, so uploads
    cannot be written outside the datasets directory. An existing file with
    the same name is overwritten. Returns ``{"path": <saved path>}``, or
    HTTP 400 when the filename is unusable.
    """
    name = _safe_dataset_filename(file.filename)
    if name is None:
        return JSONResponse(status_code=400, content={"error": "Invalid filename"})

    os.makedirs(_DATASET_DIR, exist_ok=True)
    path = os.path.join(_DATASET_DIR, name)
    # Copy in chunks so large datasets are never held in memory all at once.
    with open(path, "wb") as out:
        while True:
            chunk = await file.read(_UPLOAD_CHUNK_SIZE)
            if not chunk:
                break
            out.write(chunk)
    return {"path": path}

if __name__ == "__main__":
    # Pass the app object instead of the "main:app" import string. The string
    # form makes uvicorn import this file a second time as module "main"
    # (running FastAPI() and JobQueue() again). It also only works if the file
    # is literally named main.py. The import string is required only for
    # reload / multi-worker mode, which is not used here.
    uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info")