| import argparse |
| import json |
| import os |
| import subprocess |
| import sys |
| import webbrowser |
| from datetime import datetime |
| from threading import Lock |
|
|
| import uvicorn |
| from fastapi import BackgroundTasks, FastAPI, Request |
| from fastapi.responses import FileResponse |
| from fastapi.staticfiles import StaticFiles |
|
|
| import toml |
|
|
# ASGI application object; routes and middleware below are registered on it.
app = FastAPI()

# Held from the moment a run is accepted until run_train() finishes, so that
# only one training job is in flight at a time.
lock = Lock()
|
|
| |
# Static file handler for the built frontend. Its file_response method is
# wrapped below so .js/.css assets are always served with the right MIME type.
sf = StaticFiles(directory="frontend/dist")
_o_fr = sf.file_response


def _hooked_file_response(full_path, *args, **kwargs):
    """Build the static-file response, forcing the MIME type of JS/CSS assets.

    Args:
        full_path: Path of the file being served; a str or path-like object.
        *args: Remaining positional arguments, forwarded unchanged.
        **kwargs: Keyword arguments, forwarded unchanged.

    Returns:
        The response produced by the original ``file_response`` with its
        media type (and Content-Type header, when present) overridden for
        ``.js`` and ``.css`` files.
    """
    r = _o_fr(full_path, *args, **kwargs)

    # Accept path-like objects as well as plain strings.
    path_text = str(full_path)
    if path_text.endswith(".js"):
        forced_type = "application/javascript"
    elif path_text.endswith(".css"):
        forced_type = "text/css"
    else:
        return r

    r.media_type = forced_type
    # The Content-Type header is built when the response is constructed, so
    # assigning media_type afterwards does not change what is sent; patch the
    # header too. Responses without one (e.g. 304 Not Modified) are left alone.
    if "content-type" in r.headers:
        r.headers["content-type"] = forced_type
    return r


sf.file_response = _hooked_file_response
|
|
# Command-line options for the launcher; only the listening port is configurable.
parser = argparse.ArgumentParser(description="GUI for training network")
parser.add_argument(
    "--port",
    type=int,
    default=28000,
    help="Port to run the server on",
)
|
|
def run_train(toml_path: str):
    """Run ``train_network.py`` through ``accelerate launch`` and wait for it.

    Progress and the outcome are reported with ``print``. The module-level
    ``lock`` is released when this function exits, whatever the outcome, so
    the caller must already hold it.

    Args:
        toml_path: Path of the TOML file passed to ``--config_file``.
    """
    print(f"Training started with config file / 训练开始,使用配置文件: {toml_path}")
    args = [
        "accelerate", "launch", "--num_cpu_threads_per_process", "8",
        "./sd-scripts/train_network.py",
        "--config_file", toml_path,
    ]
    try:
        # The command is an argument list, so it must not go through a shell:
        # with shell=True on POSIX only "accelerate" would be executed and the
        # remaining items would be handed to the shell itself.
        result = subprocess.run(args, shell=False, env=os.environ)
        if result.returncode != 0:
            print("Training failed / 训练失败")
        else:
            print("Training finished / 训练完成")
    except Exception as e:
        # Typically a failure to spawn the process (e.g. accelerate not found).
        print(f"An error occurred when training / 创建训练进程时出现致命错误: {e}")
    finally:
        # Allow the next training request to be accepted.
        lock.release()
|
|
|
|
@app.post("/api/run")
async def create_toml_file(request: Request, background_tasks: BackgroundTasks):
    """Save the posted training config as TOML and queue a training run.

    The request body is parsed as JSON, written as TOML to
    ``toml/<timestamp>.toml``, and ``run_train`` is scheduled as a background
    task with that path. The module-level ``lock`` is acquired here without
    blocking; ``run_train`` releases it when training ends.

    Args:
        request: Incoming request whose body is a JSON document.
        background_tasks: Task queue the training job is added to.

    Returns:
        ``{"status": "success"}`` once the run is queued, or a
        ``{"status": "fail", ...}`` dict when the lock is already held.

    Raises:
        Any error from decoding the body or writing the file propagates to
        the framework after the lock has been released.
    """
    if not lock.acquire(blocking=False):
        print("Training is already running / 已有正在进行的训练")
        return {"status": "fail", "detail": "Training is already running"}

    try:
        timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
        toml_file = f"toml/{timestamp}.toml"
        toml_data = await request.body()
        j = json.loads(toml_data.decode("utf-8"))
        # The output directory may not exist on a fresh checkout.
        os.makedirs("toml", exist_ok=True)
        # Write UTF-8 explicitly so non-ASCII values do not depend on the
        # platform's default encoding.
        with open(toml_file, "w", encoding="utf-8") as f:
            f.write(toml.dumps(j))
        background_tasks.add_task(run_train, toml_file)
    except BaseException:
        # Nothing was queued, so run_train will never release the lock for
        # us. BaseException also covers request cancellation while awaiting.
        lock.release()
        raise
    return {"status": "success"}
|
|
@app.middleware("http")
async def add_cache_control_header(request, call_next):
    """Set ``Cache-Control: max-age=0`` on every HTTP response.

    Args:
        request: The incoming request, passed on unchanged.
        call_next: Callable that produces the downstream response.

    Returns:
        The downstream response with the Cache-Control header set.
    """
    outgoing = await call_next(request)
    header_name, header_value = "Cache-Control", "max-age=0"
    outgoing.headers[header_name] = header_value
    return outgoing
|
|
@app.get("/")
async def index():
    """Serve the frontend's ``index.html`` for the site root."""
    entry_page = "./frontend/dist/index.html"
    return FileResponse(entry_page)
|
|
|
|
# Mounted after the routes above are declared; serves everything else from the
# static file handler.
app.mount(
    "/",
    sf,
    name="static",
)
|
|
if __name__ == "__main__":
    # parse_known_args: unrecognized command-line options are ignored.
    args, _ = parser.parse_known_args()
    url = f"http://127.0.0.1:{args.port}"
    print(f"Server started at {url}")
    if sys.platform == "win32":
        os.environ["XFORMERS_FORCE_DISABLE_TRITON"] = "1"

    # NOTE(review): the source's indentation was lost; opening the browser is
    # assumed to be unconditional rather than Windows-only — confirm.
    webbrowser.open(url)
    # Bind to the requested port; it was previously hard-coded to 28000, so
    # --port changed the printed/opened URL but not the listening socket.
    uvicorn.run(app, host="127.0.0.1", port=args.port, log_level="error")
|
|