| import os |
| import asyncio |
| from contextlib import asynccontextmanager |
| from fastapi import FastAPI, UploadFile, File, Form, Header, BackgroundTasks |
| from fastapi.responses import JSONResponse |
| from pydantic import BaseModel |
| from typing import Optional, List |
| from job_queue import JobQueue |
| import db |
| import uvicorn |
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run the application's startup and shutdown work.

    Startup: initialise the database, then start the job-queue worker as a
    background task. Shutdown: cancel that worker and wait for it to finish.

    Args:
        app: The FastAPI application this lifespan is attached to (unused).
    """
    await db.init_db()
    # Keep a strong reference to the task: the event loop only holds weak
    # references, so an unreferenced task may be garbage-collected mid-run.
    worker = asyncio.create_task(queue.process_jobs())
    try:
        yield
    finally:
        # Stop the worker explicitly rather than leaving it pending when the
        # event loop is torn down.
        worker.cancel()
        try:
            await worker
        except asyncio.CancelledError:
            # Expected outcome of the cancel() above.
            pass
|
|
# ASGI application; startup/shutdown work is delegated to `lifespan`.
app: FastAPI = FastAPI(lifespan=lifespan)

# Module-wide job queue shared by the lifespan hook and the request handlers.
queue: JobQueue = JobQueue()
|
|
class MergeRequest(BaseModel):
    """Request body accepted by the job-submission endpoint.

    The fields are handed to the job queue as-is; their exact semantics are
    defined by the queue consumer, which is not part of this module.
    """

    # An empty tuple of protected namespaces lets the fields below start with
    # "model_" without triggering pydantic's protected-namespace warning.
    model_config = {"protected_namespaces": ()}

    # Source and identifier for each of the two input models.
    model_a_source: str
    model_a_id: str
    model_b_source: str
    model_b_id: str

    # Merge configuration (presumably interpreted by the worker -- confirm).
    method: str = "linear"
    merge_type: Optional[str] = "linear"
    linear_alpha: Optional[float] = 0.5
    evo_params: Optional[dict] = None
    dataset: Optional[str] = None
    output_repo_name: Optional[str] = None

    # Credentials supplied in the body; `hf_token_manual` takes precedence
    # over the request header token in the submit handler.
    hf_token_manual: Optional[str] = None
    civitai_key: Optional[str] = None

    franken_layers: Optional[List[str]] = None
|
|
@app.post("/api/submit-job")
async def submit_job(
    req: MergeRequest,
    x_hf_user_access_token: Optional[str] = Header(None),
):
    """Enqueue a merge job and return its id.

    The token in the request body wins over the one from the request header.
    When neither is present, a 400 JSON error is returned instead.
    """
    hf_token = req.hf_token_manual or x_hf_user_access_token
    if hf_token:
        new_job_id = await queue.add_job(req, hf_token)
        return {"job_id": new_job_id}

    # No usable token from either the body or the header.
    error_body = {"error": "HF token required"}
    return JSONResponse(status_code=400, content=error_body)
|
|
@app.get("/api/job-status/{job_id}")
async def job_status(job_id: str):
    """Return the stored record for ``job_id``, or a 404 JSON error."""
    record = await db.get_job(job_id)
    if record:
        return record

    # Lookup produced nothing (or a falsy value): report the job as missing.
    return JSONResponse(status_code=404, content={"error": "Job not found"})
|
|
@app.get("/api/queue")
async def get_queue():
    """Return whatever the job queue reports as its current status."""
    queue_status = await queue.get_queue_status()
    return queue_status
|
|
@app.post("/api/upload-dataset")
async def upload_dataset(file: UploadFile = File(...)):
    """Save an uploaded dataset under /app/data/datasets and return its path.

    The client-supplied filename is untrusted, so it is reduced to its final
    path component; a name such as "../../etc/passwd" therefore cannot escape
    the datasets directory.

    Args:
        file: The uploaded file from the multipart form.

    Returns:
        ``{"path": <saved path>}`` on success, or a 400 JSON error when the
        upload carries no usable filename.
    """
    datasets_dir = "/app/data/datasets"

    # Treat backslashes as separators too, then keep only the last component.
    raw_name = (file.filename or "").replace("\\", "/")
    safe_name = os.path.basename(raw_name)
    if safe_name in ("", ".", ".."):
        return JSONResponse(status_code=400, content={"error": "Invalid filename"})

    os.makedirs(datasets_dir, exist_ok=True)
    path = os.path.join(datasets_dir, safe_name)
    with open(path, "wb") as f:
        # Copy in fixed-size chunks so a large upload is never held fully in
        # memory at once.
        while chunk := await file.read(1024 * 1024):
            f.write(chunk)
    return {"path": path}
|
|
if __name__ == "__main__":
    # When run as a script, launch uvicorn on the "main:app" import string.
    uvicorn.run(
        "main:app",
        host="0.0.0.0",
        port=8000,
        log_level="info",
    )