Rodrigo Ortega committed on
Commit ·
62ab980
1
Parent(s): 27377d7
Add AI Toolkit training setup and live job progress
Browse files- app/nemoflix_amd/api.py +196 -3
- requirements.txt +1 -0
- scripts/install-ai-toolkit.sh +189 -0
- scripts/startup-script.sh +1 -0
- studio/src/App.tsx +30 -9
- studio/src/components/JobCard.tsx +57 -0
- studio/src/types.ts +16 -0
- studio/tsconfig.json +21 -0
- studio/tsconfig.node.json +10 -0
- training/README.md +37 -0
- training/flux2_identity_template.yaml +78 -0
- training/wan22_i2v_character_template.yaml +77 -0
app/nemoflix_amd/api.py
CHANGED
|
@@ -1,9 +1,15 @@
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
|
|
|
|
|
|
|
|
|
| 3 |
import tempfile
|
|
|
|
| 4 |
from pathlib import Path
|
| 5 |
from typing import Any, Literal
|
| 6 |
-
from
|
|
|
|
|
|
|
| 7 |
|
| 8 |
from fastapi import FastAPI, File, HTTPException, UploadFile
|
| 9 |
from pydantic import BaseModel, Field
|
|
@@ -84,6 +90,143 @@ def comfy() -> ComfyClient:
|
|
| 84 |
return ComfyClient(settings.comfy_url, settings.request_timeout_seconds)
|
| 85 |
|
| 86 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
@app.get("/api/health")
|
| 88 |
async def health() -> dict[str, Any]:
|
| 89 |
client = comfy()
|
|
@@ -170,14 +313,16 @@ async def generate_video(body: VideoGenerateRequest) -> VideoGenerateResponse:
|
|
| 170 |
return VideoGenerateResponse(ok=True, mode=body.mode, workflow=workflow)
|
| 171 |
|
| 172 |
try:
|
| 173 |
-
result = await comfy().queue_prompt(workflow, client_id=
|
| 174 |
except Exception as exc: # noqa: BLE001
|
| 175 |
raise HTTPException(status_code=502, detail=f"ComfyUI prompt submission failed: {exc}") from exc
|
| 176 |
|
|
|
|
|
|
|
| 177 |
return VideoGenerateResponse(
|
| 178 |
ok="prompt_id" in result,
|
| 179 |
mode=body.mode,
|
| 180 |
-
prompt_id=
|
| 181 |
number=result.get("number"),
|
| 182 |
node_errors=result.get("node_errors"),
|
| 183 |
)
|
|
@@ -224,6 +369,54 @@ async def _queue_position(client: ComfyClient, prompt_id: str) -> int | None:
|
|
| 224 |
return None
|
| 225 |
|
| 226 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 227 |
@app.get("/api/jobs/{prompt_id}", response_model=JobStatusResponse)
|
| 228 |
async def job(prompt_id: str) -> JobStatusResponse:
|
| 229 |
client = comfy()
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
+
import asyncio
|
| 4 |
+
import contextlib
|
| 5 |
+
import json
|
| 6 |
import tempfile
|
| 7 |
+
from datetime import UTC, datetime
|
| 8 |
from pathlib import Path
|
| 9 |
from typing import Any, Literal
|
| 10 |
+
from urllib.parse import urlparse, urlunparse
|
| 11 |
+
|
| 12 |
+
import websockets
|
| 13 |
|
| 14 |
from fastapi import FastAPI, File, HTTPException, UploadFile
|
| 15 |
from pydantic import BaseModel, Field
|
|
|
|
| 90 |
return ComfyClient(settings.comfy_url, settings.request_timeout_seconds)
|
| 91 |
|
| 92 |
|
| 93 |
+
# In-memory registry of job records, keyed by ComfyUI prompt id. It is a plain
# module-level dict, so its contents live only as long as this process.
_JOBS: dict[str, dict[str, Any]] = {}
# Client id shared by the websocket URL built in _ws_url() and by prompt
# submissions made through queue_prompt().
_COMFY_CLIENT_ID = "nemoflix-amd-gallery"
# Background task running _comfy_ws_bridge(); None until the startup hook
# creates it and again after the shutdown hook has cancelled it.
_WS_TASK: asyncio.Task | None = None
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def _ws_url(base_url: str) -> str:
    """Build the ComfyUI websocket URL that matches an HTTP base URL.

    Args:
        base_url: ComfyUI base URL, for example ``http://127.0.0.1:8188`` or
            ``https://host/comfy``.

    Returns:
        The URL of the ``ws`` endpoint under ``base_url`` with
        ``_COMFY_CLIENT_ID`` passed as the ``clientId`` query parameter.
    """
    parsed = urlparse(base_url)
    # Secure inputs stay secure (including an already-websocket "wss" URL);
    # everything else maps to plain "ws".
    scheme = "wss" if parsed.scheme in {"https", "wss"} else "ws"
    # Keep any path prefix (e.g. a reverse proxy mounted at /comfy) instead of
    # always connecting to the host root.
    path = f"{parsed.path.rstrip('/')}/ws"
    return urlunparse((scheme, parsed.netloc, path, "", f"clientId={_COMFY_CLIENT_ID}", ""))
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def _register_submitted_job(prompt_id: str | None, body: VideoGenerateRequest, status: str = "pending") -> None:
|
| 105 |
+
if not prompt_id:
|
| 106 |
+
return
|
| 107 |
+
_JOBS[prompt_id] = {
|
| 108 |
+
"prompt_id": prompt_id,
|
| 109 |
+
"status": status,
|
| 110 |
+
"mode": body.mode,
|
| 111 |
+
"prompt": body.prompt,
|
| 112 |
+
"width": body.width,
|
| 113 |
+
"height": body.height,
|
| 114 |
+
"length": body.length,
|
| 115 |
+
"fps": body.fps,
|
| 116 |
+
"created_at": datetime.now(UTC).isoformat(),
|
| 117 |
+
"current_node": None,
|
| 118 |
+
"step_value": 0,
|
| 119 |
+
"step_max": 0,
|
| 120 |
+
"nodes_finished": 0,
|
| 121 |
+
"nodes_total": 0,
|
| 122 |
+
"progress_percent": None,
|
| 123 |
+
}
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def _update_job_from_progress_state(prompt_id: str, nodes: dict[str, Any]) -> None:
|
| 127 |
+
if not prompt_id:
|
| 128 |
+
return
|
| 129 |
+
job = _JOBS.setdefault(prompt_id, {"prompt_id": prompt_id, "status": "running", "created_at": None})
|
| 130 |
+
total = len(nodes)
|
| 131 |
+
finished = 0
|
| 132 |
+
running = 0
|
| 133 |
+
current_node = None
|
| 134 |
+
step_value = 0
|
| 135 |
+
step_max = 0
|
| 136 |
+
for node_id, node in nodes.items():
|
| 137 |
+
if not isinstance(node, dict):
|
| 138 |
+
continue
|
| 139 |
+
state = node.get("state")
|
| 140 |
+
if state == "finished":
|
| 141 |
+
finished += 1
|
| 142 |
+
elif state == "running":
|
| 143 |
+
running += 1
|
| 144 |
+
if current_node is None:
|
| 145 |
+
current_node = node.get("display_node_id") or node.get("node_id") or node_id
|
| 146 |
+
step_value = int(node.get("value") or 0)
|
| 147 |
+
step_max = int(node.get("max") or 0)
|
| 148 |
+
percent = round((finished / total) * 100, 1) if total else None
|
| 149 |
+
job.update({
|
| 150 |
+
"status": "running",
|
| 151 |
+
"nodes_total": total,
|
| 152 |
+
"nodes_finished": finished,
|
| 153 |
+
"nodes_running": running,
|
| 154 |
+
"current_node": current_node,
|
| 155 |
+
"step_value": step_value,
|
| 156 |
+
"step_max": step_max,
|
| 157 |
+
"progress_percent": percent,
|
| 158 |
+
"updated_at": datetime.now(UTC).isoformat(),
|
| 159 |
+
})
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
async def _comfy_ws_bridge() -> None:
|
| 163 |
+
settings = get_settings()
|
| 164 |
+
url = _ws_url(settings.comfy_url)
|
| 165 |
+
while True:
|
| 166 |
+
try:
|
| 167 |
+
async with websockets.connect(url, ping_interval=20, ping_timeout=20) as ws:
|
| 168 |
+
async for raw in ws:
|
| 169 |
+
if isinstance(raw, bytes):
|
| 170 |
+
continue
|
| 171 |
+
try:
|
| 172 |
+
msg = json.loads(raw)
|
| 173 |
+
except json.JSONDecodeError:
|
| 174 |
+
continue
|
| 175 |
+
msg_type = msg.get("type")
|
| 176 |
+
data = msg.get("data", {}) if isinstance(msg.get("data"), dict) else {}
|
| 177 |
+
prompt_id = data.get("prompt_id") or data.get("prompt")
|
| 178 |
+
if msg_type == "execution_start" and isinstance(prompt_id, str):
|
| 179 |
+
_JOBS.setdefault(prompt_id, {"prompt_id": prompt_id, "created_at": None}).update({"status": "running"})
|
| 180 |
+
elif msg_type == "progress_state" and isinstance(prompt_id, str):
|
| 181 |
+
nodes = data.get("nodes", {})
|
| 182 |
+
if isinstance(nodes, dict):
|
| 183 |
+
_update_job_from_progress_state(prompt_id, nodes)
|
| 184 |
+
elif msg_type == "progress" and isinstance(prompt_id, str):
|
| 185 |
+
job = _JOBS.setdefault(prompt_id, {"prompt_id": prompt_id, "status": "running", "created_at": None})
|
| 186 |
+
value = int(data.get("value") or 0)
|
| 187 |
+
max_value = int(data.get("max") or 0)
|
| 188 |
+
job.update({
|
| 189 |
+
"status": "running",
|
| 190 |
+
"step_value": value,
|
| 191 |
+
"step_max": max_value,
|
| 192 |
+
"progress_percent": round((value / max_value) * 100, 1) if max_value else None,
|
| 193 |
+
"updated_at": datetime.now(UTC).isoformat(),
|
| 194 |
+
})
|
| 195 |
+
elif msg_type == "execution_success" and isinstance(prompt_id, str):
|
| 196 |
+
_JOBS.setdefault(prompt_id, {"prompt_id": prompt_id, "created_at": None}).update({
|
| 197 |
+
"status": "completed",
|
| 198 |
+
"progress_percent": 100,
|
| 199 |
+
"updated_at": datetime.now(UTC).isoformat(),
|
| 200 |
+
})
|
| 201 |
+
elif msg_type in {"execution_error", "execution_interrupted"} and isinstance(prompt_id, str):
|
| 202 |
+
_JOBS.setdefault(prompt_id, {"prompt_id": prompt_id, "created_at": None}).update({
|
| 203 |
+
"status": "failed",
|
| 204 |
+
"error": data.get("exception_message") or msg_type,
|
| 205 |
+
"updated_at": datetime.now(UTC).isoformat(),
|
| 206 |
+
})
|
| 207 |
+
except asyncio.CancelledError:
|
| 208 |
+
raise
|
| 209 |
+
except Exception:
|
| 210 |
+
await asyncio.sleep(3)
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
@app.on_event("startup")
async def start_comfy_bridge() -> None:
    """Launch the websocket bridge task unless one is already running."""
    global _WS_TASK
    bridge_is_live = _WS_TASK is not None and not _WS_TASK.done()
    if bridge_is_live:
        return
    _WS_TASK = asyncio.create_task(_comfy_ws_bridge())
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
@app.on_event("shutdown")
async def stop_comfy_bridge() -> None:
    """Cancel the websocket bridge task, wait for it to unwind, and clear the handle."""
    global _WS_TASK
    task = _WS_TASK
    if not task:
        return
    task.cancel()
    # Awaiting a cancelled task raises CancelledError; that is the expected outcome here.
    with contextlib.suppress(asyncio.CancelledError):
        await task
    _WS_TASK = None
|
| 228 |
+
|
| 229 |
+
|
| 230 |
@app.get("/api/health")
|
| 231 |
async def health() -> dict[str, Any]:
|
| 232 |
client = comfy()
|
|
|
|
| 313 |
return VideoGenerateResponse(ok=True, mode=body.mode, workflow=workflow)
|
| 314 |
|
| 315 |
try:
|
| 316 |
+
result = await comfy().queue_prompt(workflow, client_id=_COMFY_CLIENT_ID)
|
| 317 |
except Exception as exc: # noqa: BLE001
|
| 318 |
raise HTTPException(status_code=502, detail=f"ComfyUI prompt submission failed: {exc}") from exc
|
| 319 |
|
| 320 |
+
prompt_id = result.get("prompt_id")
|
| 321 |
+
_register_submitted_job(prompt_id, body, "pending")
|
| 322 |
return VideoGenerateResponse(
|
| 323 |
ok="prompt_id" in result,
|
| 324 |
mode=body.mode,
|
| 325 |
+
prompt_id=prompt_id,
|
| 326 |
number=result.get("number"),
|
| 327 |
node_errors=result.get("node_errors"),
|
| 328 |
)
|
|
|
|
| 369 |
return None
|
| 370 |
|
| 371 |
|
| 372 |
+
@app.get("/api/jobs")
async def jobs() -> dict[str, Any]:
    """List the locally registered jobs, refreshed against the Comfy queue.

    ComfyUI remains the execution engine, but only jobs recorded in ``_JOBS``
    are returned; arbitrary Comfy queue entries are deliberately not listed.
    Completed media is discovered separately by /api/listing.

    Returns:
        ``{"jobs": [...], "count": n}``. When the queue cannot be fetched the
        records are returned unmodified, newest first, with an extra
        ``"error"`` string.
    """
    client = comfy()

    try:
        queue = await client.get("/queue")
    except Exception as exc:  # noqa: BLE001
        newest_first = sorted(_JOBS.values(), key=lambda j: j.get("created_at") or "", reverse=True)
        return {"jobs": newest_first, "count": len(newest_first), "error": str(exc)}

    def queued_prompt_id(entry: Any) -> str | None:
        # A usable queue entry is a list whose second element is a string id.
        if isinstance(entry, list) and len(entry) > 1 and isinstance(entry[1], str):
            return entry[1]
        return None

    queue_is_mapping = isinstance(queue, dict)

    running_entries = queue.get("queue_running", []) if queue_is_mapping else []
    running_ids: set[str] = {
        pid for entry in running_entries if (pid := queued_prompt_id(entry)) is not None
    }

    pending_entries = queue.get("queue_pending", []) if queue_is_mapping else []
    # Positions are 1-based and count every pending entry, usable or not.
    pending_positions: dict[str, int] = {
        pid: position
        for position, entry in enumerate(pending_entries, start=1)
        if (pid := queued_prompt_id(entry)) is not None
    }

    for prompt_id, record in _JOBS.items():
        status = record.get("status")
        if status in {"completed", "failed"}:
            # Terminal states are never overwritten by queue polling.
            continue
        if prompt_id in running_ids:
            record["status"] = "running"
            record["queue_position"] = None
        elif prompt_id in pending_positions:
            record["status"] = "pending"
            record["queue_position"] = pending_positions[prompt_id]
        elif status in {"pending", "running"}:
            # We believed it was active, but the queue no longer knows it.
            record["status"] = "unknown"
            record["queue_position"] = None

    def display_order(record: dict[str, Any]) -> tuple[bool, Any, Any]:
        # Running jobs first, then by queue position, then by creation time.
        return (
            record.get("status") != "running",
            record.get("queue_position") or 0,
            record.get("created_at") or "",
        )

    ordered = sorted(_JOBS.values(), key=display_order)
    return {"jobs": ordered, "count": len(ordered)}
|
| 418 |
+
|
| 419 |
+
|
| 420 |
@app.get("/api/jobs/{prompt_id}", response_model=JobStatusResponse)
|
| 421 |
async def job(prompt_id: str) -> JobStatusResponse:
|
| 422 |
client = comfy()
|
requirements.txt
CHANGED
|
@@ -3,3 +3,4 @@ uvicorn[standard]==0.34.0
|
|
| 3 |
httpx==0.28.1
|
| 4 |
python-multipart==0.0.20
|
| 5 |
pydantic-settings==2.7.1
|
|
|
|
|
|
| 3 |
httpx==0.28.1
|
| 4 |
python-multipart==0.0.20
|
| 5 |
pydantic-settings==2.7.1
|
| 6 |
+
websockets==16.0
|
scripts/install-ai-toolkit.sh
ADDED
|
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
set -Eeuo pipefail
set -x

# Install Ostris AI Toolkit on a disposable AMD MI300X ROCm droplet.
# Run after scripts/startup-script.sh so ROCm/system basics are already present.
# This is intentionally idempotent: safe to rerun on a fresh or partially initialized box.
#
# Overridable via environment: TOOLKIT_DIR, TOOLKIT_VENV, TRAINING_DIR,
# ROCM_INDEX_PRIMARY, ROCM_INDEX_FALLBACK, AI_TOOLKIT_REF, INSTALL_UI_DEPS,
# and AI_TOOLKIT_AUTH (only used when the UI service unit is written).

APT_GET="apt-get -o DPkg::Lock::Timeout=300"
TOOLKIT_DIR="${TOOLKIT_DIR:-/root/ai-toolkit}"
TOOLKIT_VENV="${TOOLKIT_VENV:-/root/ai-toolkit-venv}"
TRAINING_DIR="${TRAINING_DIR:-/root/nemoflix-training}"
ROCM_INDEX_PRIMARY="${ROCM_INDEX_PRIMARY:-https://download.pytorch.org/whl/rocm7.2}"
ROCM_INDEX_FALLBACK="${ROCM_INDEX_FALLBACK:-https://download.pytorch.org/whl/rocm7.0}"
AI_TOOLKIT_REF="${AI_TOOLKIT_REF:-main}"
# CLI training is the default. The UI pulls NodeSource apt repo + nodejs and can
# trigger service restart/deferred-restart behavior on cloud images; keep it opt-in.
INSTALL_UI_DEPS="${INSTALL_UI_DEPS:-0}"
PYTHON_BIN="$TOOLKIT_VENV/bin/python"
export DEBIAN_FRONTEND=noninteractive
export NEEDRESTART_MODE=a

trap 'echo "ERROR: AI Toolkit install failed at line $LINENO"' ERR

if [ "$(id -u)" -ne 0 ]; then
  echo "Run as root on the AMD droplet."
  exit 1
fi

echo "=== Installing AI Toolkit prerequisites ==="
$APT_GET update -y
$APT_GET install -y git git-lfs python3-pip python3.12-venv python3-dev build-essential pkg-config curl wget ffmpeg libgl1 libglib2.0-0

git lfs install --system || true

if command -v /opt/rocm/bin/rocm-smi >/dev/null 2>&1; then
  echo "=== ROCm GPU check ==="
  /opt/rocm/bin/rocm-smi || true
fi

echo "=== Creating isolated AI Toolkit venv: $TOOLKIT_VENV ==="
if [ ! -d "$TOOLKIT_VENV" ]; then
  python3 -m venv "$TOOLKIT_VENV"
fi
"$PYTHON_BIN" -m pip install --upgrade pip setuptools wheel

echo "=== Installing ROCm PyTorch into AI Toolkit venv ==="
# Remember which index actually provided torch so the requirements install
# below resolves against the same ROCm build instead of always the primary.
ROCM_INDEX_USED="$ROCM_INDEX_PRIMARY"
if ! "$PYTHON_BIN" -m pip install torch torchvision torchaudio --index-url "$ROCM_INDEX_PRIMARY"; then
  ROCM_INDEX_USED="$ROCM_INDEX_FALLBACK"
  "$PYTHON_BIN" -m pip install torch torchvision torchaudio --index-url "$ROCM_INDEX_FALLBACK"
fi

"$PYTHON_BIN" - <<'PY'
import torch
print('torch', torch.__version__)
print('cuda api available', torch.cuda.is_available())
if torch.cuda.is_available():
    print('device', torch.cuda.get_device_name(0))
PY

echo "=== Cloning/updating Ostris AI Toolkit ==="
if [ ! -d "$TOOLKIT_DIR/.git" ]; then
  git clone https://github.com/ostris/ai-toolkit.git "$TOOLKIT_DIR"
fi
git -C "$TOOLKIT_DIR" fetch --depth 1 origin "$AI_TOOLKIT_REF"
git -C "$TOOLKIT_DIR" checkout FETCH_HEAD
git -C "$TOOLKIT_DIR" submodule update --init --recursive

echo "=== Installing AI Toolkit Python requirements ==="
# Keep the ROCm torch we installed above; do not allow requirements to swap in CUDA wheels.
"$PYTHON_BIN" -m pip install -r "$TOOLKIT_DIR/requirements.txt" --extra-index-url "$ROCM_INDEX_USED"
"$PYTHON_BIN" -m pip install --upgrade accelerate huggingface_hub hf_transfer

mkdir -p \
  "$TRAINING_DIR/datasets" \
  "$TRAINING_DIR/output" \
  "$TRAINING_DIR/samples" \
  "$TRAINING_DIR/config" \
  "$TOOLKIT_DIR/config"

# If this script is run from a cloned Nemoflix repo, seed our checked-in config templates
# into the disposable training workspace.
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
REPO_DIR="$(cd "$SCRIPT_DIR/.." && pwd)"
if compgen -G "$REPO_DIR/training/*.yaml" >/dev/null; then
  cp -f "$REPO_DIR"/training/*.yaml "$TRAINING_DIR/config/"
fi

# The heredoc is quoted because the text is full of backticks; the paths it
# mentions are therefore the default locations, not any overridden ones.
cat > "$TRAINING_DIR/README.md" <<'EOF'
# Nemoflix AI Toolkit Training Workspace

Persistent-ish training layout for disposable AMD droplets.

## Paths

- AI Toolkit: `/root/ai-toolkit`
- Venv: `/root/ai-toolkit-venv`
- Datasets: `/root/nemoflix-training/datasets`
- Configs: `/root/nemoflix-training/config`
- Outputs/checkpoints: `/root/nemoflix-training/output`
- Sample control images: `/root/nemoflix-training/samples`

## Run a config

```bash
cd /root/ai-toolkit
/root/ai-toolkit-venv/bin/python run.py /root/nemoflix-training/config/<config>.yaml
```

## Hugging Face token

For gated models, create `/root/ai-toolkit/.env`:

```bash
HF_TOKEN=hf_xxx
HF_HUB_ENABLE_HF_TRANSFER=1
```

## Character dataset conventions

Image/FLUX LoRA:
- 20-40 good face/body images
- one `.txt` caption beside each image
- include trigger word, e.g. `character_trigger, person, portrait, natural lighting`

Wan I2V character LoRA:
- 10-30 short clips, ideally 3-8 seconds
- one `.txt` caption beside each clip
- include trigger word, e.g. `character_trigger, person, walking outdoors, close-up face`
- trim dead time; varied angles/lighting/backgrounds; avoid sunglasses
EOF

# Unquoted heredoc: install-time paths are baked in so the runner follows any
# TOOLKIT_DIR/TOOLKIT_VENV/TRAINING_DIR override; \$ defers expansion to run time.
cat > "$TRAINING_DIR/run-ai-toolkit.sh" <<EOF
#!/bin/bash
set -Eeuo pipefail
CONFIG_PATH="\${1:?Usage: $TRAINING_DIR/run-ai-toolkit.sh $TRAINING_DIR/config/job.yaml}"
cd "$TOOLKIT_DIR"
export HF_HUB_ENABLE_HF_TRANSFER="\${HF_HUB_ENABLE_HF_TRANSFER:-1}"
exec "$PYTHON_BIN" run.py "\$CONFIG_PATH"
EOF
chmod +x "$TRAINING_DIR/run-ai-toolkit.sh"

if [ "$INSTALL_UI_DEPS" = "1" ]; then
  if ! command -v node >/dev/null 2>&1 || ! node -e 'process.exit(Number(process.versions.node.split(".")[0]) >= 20 ? 0 : 1)' 2>/dev/null; then
    echo "=== Installing Node.js 22 for optional AI Toolkit UI ==="
    curl -fsSL https://deb.nodesource.com/setup_22.x | bash -
    $APT_GET install -y nodejs
  fi
  if [ -f "$TOOLKIT_DIR/ui/package.json" ]; then
    echo "=== Installing AI Toolkit UI dependencies ==="
    (cd "$TOOLKIT_DIR/ui" && npm install)
  fi

  # The unit is written but never enabled or started here. AI_TOOLKIT_AUTH is
  # taken from the environment when set, so a rerun does not have to reset it
  # to the placeholder. It is expanded inside the heredoc so that `set -x`
  # does not echo the value.
  cat > /etc/systemd/system/ai-toolkit-ui.service <<EOF
[Unit]
Description=Ostris AI Toolkit UI
After=network-online.target
Wants=network-online.target

[Service]
Type=simple
User=root
WorkingDirectory=$TOOLKIT_DIR/ui
Environment="PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin"
Environment="AI_TOOLKIT_AUTH=${AI_TOOLKIT_AUTH:-change-me}"
ExecStart=/usr/bin/npm run build_and_start
Restart=on-failure
RestartSec=5

[Install]
WantedBy=multi-user.target
EOF
  systemctl daemon-reload
fi

# Smoke test: every core training dependency must import from the venv.
"$PYTHON_BIN" - <<'PY'
import importlib
mods = ['torch', 'accelerate', 'diffusers', 'transformers', 'huggingface_hub']
for name in mods:
    mod = importlib.import_module(name)
    print(name, getattr(mod, '__version__', 'ok'))
PY

echo "=== AI Toolkit install complete ==="
echo "Toolkit: $TOOLKIT_DIR"
echo "Venv: $TOOLKIT_VENV"
echo "Training: $TRAINING_DIR"
echo "Run CLI: $TRAINING_DIR/run-ai-toolkit.sh $TRAINING_DIR/config/job.yaml"
echo "UI: optional; rerun with INSTALL_UI_DEPS=1, then set AI_TOOLKIT_AUTH before starting ai-toolkit-ui.service"
|
scripts/startup-script.sh
CHANGED
|
@@ -147,3 +147,4 @@ curl -sS --max-time 5 http://127.0.0.1:8190/api/health
|
|
| 147 |
|
| 148 |
echo "=== Setup Complete ==="
|
| 149 |
echo "Install Wan 2.2 video stack: $APP_DIR/scripts/install-video-stack.sh"
|
|
|
|
|
|
| 147 |
|
| 148 |
echo "=== Setup Complete ==="
|
| 149 |
echo "Install Wan 2.2 video stack: $APP_DIR/scripts/install-video-stack.sh"
|
| 150 |
+
echo "Install AI Toolkit training stack: $APP_DIR/scripts/install-ai-toolkit.sh"
|
studio/src/App.tsx
CHANGED
|
@@ -1,17 +1,24 @@
|
|
| 1 |
import { useState, useEffect, useCallback } from "react";
|
| 2 |
-
import
|
|
|
|
| 3 |
|
| 4 |
export default function App() {
|
| 5 |
const [items, setItems] = useState<MediaItem[]>([]);
|
|
|
|
| 6 |
const [loading, setLoading] = useState(true);
|
| 7 |
const [selected, setSelected] = useState<string | null>(null);
|
| 8 |
|
| 9 |
const load = useCallback(async () => {
|
| 10 |
setLoading(true);
|
| 11 |
try {
|
| 12 |
-
const
|
| 13 |
-
|
| 14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
} catch (e) {
|
| 16 |
console.error(e);
|
| 17 |
} finally {
|
|
@@ -21,25 +28,39 @@ export default function App() {
|
|
| 21 |
|
| 22 |
useEffect(() => {
|
| 23 |
load();
|
|
|
|
|
|
|
| 24 |
}, [load]);
|
| 25 |
|
|
|
|
|
|
|
| 26 |
return (
|
| 27 |
<div className="min-h-screen bg-black text-white">
|
| 28 |
<header className="border-b border-gray-800 px-6 py-4 flex items-center justify-between">
|
| 29 |
-
<
|
| 30 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
</header>
|
| 32 |
|
| 33 |
<main className="p-6">
|
| 34 |
-
{loading &&
|
| 35 |
<p className="text-gray-500">Loading...</p>
|
| 36 |
-
) :
|
| 37 |
<p className="text-gray-500">No media yet.</p>
|
| 38 |
) : (
|
| 39 |
<div className="grid grid-cols-2 md:grid-cols-3 lg:grid-cols-4 xl:grid-cols-5 gap-4">
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
{items.map((item) => (
|
| 41 |
<div
|
| 42 |
-
key={item.
|
| 43 |
onClick={() => setSelected(item.url)}
|
| 44 |
className="cursor-pointer rounded-lg overflow-hidden border border-gray-800 hover:border-rose-600 transition aspect-video bg-gray-900 relative group"
|
| 45 |
>
|
|
|
|
| 1 |
import { useState, useEffect, useCallback } from "react";
|
| 2 |
+
import { JobCard } from "./components/JobCard";
|
| 3 |
+
import type { JobItem, MediaItem } from "./types";
|
| 4 |
|
| 5 |
export default function App() {
|
| 6 |
const [items, setItems] = useState<MediaItem[]>([]);
|
| 7 |
+
const [jobs, setJobs] = useState<JobItem[]>([]);
|
| 8 |
const [loading, setLoading] = useState(true);
|
| 9 |
const [selected, setSelected] = useState<string | null>(null);
|
| 10 |
|
| 11 |
const load = useCallback(async () => {
|
| 12 |
setLoading(true);
|
| 13 |
try {
|
| 14 |
+
const [listingRes, jobsRes] = await Promise.all([
|
| 15 |
+
fetch("/api/listing"),
|
| 16 |
+
fetch("/api/jobs"),
|
| 17 |
+
]);
|
| 18 |
+
const listing = await listingRes.json();
|
| 19 |
+
const jobData = await jobsRes.json();
|
| 20 |
+
setItems(listing.images || []);
|
| 21 |
+
setJobs(jobData.jobs || []);
|
| 22 |
} catch (e) {
|
| 23 |
console.error(e);
|
| 24 |
} finally {
|
|
|
|
| 28 |
|
| 29 |
useEffect(() => {
|
| 30 |
load();
|
| 31 |
+
const id = window.setInterval(load, 3000);
|
| 32 |
+
return () => window.clearInterval(id);
|
| 33 |
}, [load]);
|
| 34 |
|
| 35 |
+
const hasContent = jobs.length > 0 || items.length > 0;
|
| 36 |
+
|
| 37 |
return (
|
| 38 |
<div className="min-h-screen bg-black text-white">
|
| 39 |
<header className="border-b border-gray-800 px-6 py-4 flex items-center justify-between">
|
| 40 |
+
<div>
|
| 41 |
+
<h1 className="text-xl font-semibold">Nemoflix AMD Gallery</h1>
|
| 42 |
+
<p className="text-xs text-gray-500 mt-1">Live from the MI300X droplet</p>
|
| 43 |
+
</div>
|
| 44 |
+
<div className="text-sm text-gray-500">
|
| 45 |
+
{jobs.length > 0 && <span className="text-amber-400 mr-3">{jobs.length} generating</span>}
|
| 46 |
+
<span>{items.length} media</span>
|
| 47 |
+
</div>
|
| 48 |
</header>
|
| 49 |
|
| 50 |
<main className="p-6">
|
| 51 |
+
{loading && !hasContent ? (
|
| 52 |
<p className="text-gray-500">Loading...</p>
|
| 53 |
+
) : !hasContent ? (
|
| 54 |
<p className="text-gray-500">No media yet.</p>
|
| 55 |
) : (
|
| 56 |
<div className="grid grid-cols-2 md:grid-cols-3 lg:grid-cols-4 xl:grid-cols-5 gap-4">
|
| 57 |
+
{jobs.map((job) => (
|
| 58 |
+
<JobCard key={job.prompt_id} job={job} />
|
| 59 |
+
))}
|
| 60 |
+
|
| 61 |
{items.map((item) => (
|
| 62 |
<div
|
| 63 |
+
key={item.url}
|
| 64 |
onClick={() => setSelected(item.url)}
|
| 65 |
className="cursor-pointer rounded-lg overflow-hidden border border-gray-800 hover:border-rose-600 transition aspect-video bg-gray-900 relative group"
|
| 66 |
>
|
studio/src/components/JobCard.tsx
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import type { JobItem } from "../types";
|
| 2 |
+
|
| 3 |
+
/** Props accepted by {@link JobCard}. */
interface JobCardProps {
  /** The job record to render. */
  job: JobItem;
}
|
| 6 |
+
|
| 7 |
+
/**
 * Resolve a job's completion percentage.
 *
 * A numeric `progress_percent` is returned as-is. Otherwise the value is
 * derived from `step_value / step_max` (rounded to a whole number) when
 * `step_max` is positive. Returns `null` when neither is available.
 */
function getProgress(job: JobItem): number | null {
  const { progress_percent: reported, step_max: stepMax, step_value: stepValue } = job;

  if (typeof reported === "number") {
    return reported;
  }

  if (!stepMax || stepMax <= 0) {
    return null;
  }

  const stepsDone = stepValue || 0;
  return Math.round((stepsDone / stepMax) * 100);
}
|
| 18 |
+
|
| 19 |
+
/**
 * Gallery tile for a single job: status, queue position, prompt text and a
 * progress bar, or the error message when the job carries one.
 */
export function JobCard({ job }: JobCardProps) {
  const progress = getProgress(job);
  // Keep a sliver visible at 0% and never let the fill exceed the track.
  const progressWidth = `${Math.min(100, Math.max(3, progress ?? 3))}%`;
  // Only jobs that are still queued or executing get the pulsing indicator.
  const isActive = job.status === "pending" || job.status === "running";
  const dotClass = isActive
    ? "bg-amber-400 animate-pulse"
    : job.status === "failed"
      ? "bg-rose-500"
      : "bg-gray-500";

  return (
    <div className="rounded-lg overflow-hidden border border-amber-500/40 aspect-video bg-gray-950 relative p-4 flex flex-col justify-between">
      <div className="absolute inset-0 bg-gradient-to-br from-amber-500/10 via-transparent to-rose-600/10" />

      <div className="relative flex items-center justify-between gap-2 text-amber-300 text-xs font-medium uppercase tracking-wide">
        <span className="flex items-center gap-2">
          <span className={`inline-block w-2 h-2 rounded-full ${dotClass}`} />
          {job.status === "running" ? "Generating" : job.status}
          {job.queue_position ? ` · Queue ${job.queue_position}` : ""}
        </span>
        {progress !== null && <span>{progress}%</span>}
      </div>

      <div className="relative space-y-2">
        <p className="text-sm font-medium line-clamp-2 text-white/90">
          {job.prompt || "Video generation job"}
        </p>

        <div className="space-y-1">
          <div className="h-1.5 rounded-full bg-gray-800 overflow-hidden">
            <div className="h-full bg-amber-400 transition-all" style={{ width: progressWidth }} />
          </div>

          {job.error ? (
            // Surface the failure reason instead of stale node/step details.
            <p className="text-[11px] text-rose-400 truncate" title={job.error}>
              {job.error}
            </p>
          ) : (
            <p className="text-[11px] text-gray-400 truncate">
              {job.current_node ? `Node ${job.current_node}` : "Waiting for Comfy progress event"}
              {job.step_max ? ` · step ${job.step_value || 0}/${job.step_max}` : ""}
              {job.nodes_total ? ` · nodes ${job.nodes_finished || 0}/${job.nodes_total}` : ""}
            </p>
          )}
        </div>
      </div>

      <p className="relative text-[10px] text-gray-500 font-mono truncate">{job.prompt_id}</p>
    </div>
  );
}
|
studio/src/types.ts
CHANGED
|
@@ -6,3 +6,19 @@ export interface MediaItem {
|
|
| 6 |
mtime: number;
|
| 7 |
url: string;
|
| 8 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
mtime: number;
|
| 7 |
url: string;
|
| 8 |
}
|
| 9 |
+
|
| 10 |
+
/**
 * One job record as returned in the `jobs` array of `/api/jobs`.
 *
 * Only `prompt_id` and `status` are always present; records created from
 * websocket events alone carry fewer fields than submitted ones.
 */
export interface JobItem {
  prompt_id: string;
  status: "pending" | "running" | "unknown" | "completed" | "failed" | string;
  mode?: string;
  prompt?: string;
  /** ISO timestamp; `null` for records first seen through a websocket event. */
  created_at?: string | null;
  /** ISO timestamp of the most recent progress or terminal event. */
  updated_at?: string;
  queue_position?: number | null;
  current_node?: string | null;
  step_value?: number;
  step_max?: number;
  nodes_finished?: number;
  /** Number of nodes reported as running in the latest progress snapshot. */
  nodes_running?: number;
  nodes_total?: number;
  progress_percent?: number | null;
  error?: string;
}
|
studio/tsconfig.json
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"compilerOptions": {
|
| 3 |
+
"target": "ES2020",
|
| 4 |
+
"useDefineForClassFields": true,
|
| 5 |
+
"lib": ["DOM", "DOM.Iterable", "ES2020"],
|
| 6 |
+
"allowJs": false,
|
| 7 |
+
"skipLibCheck": true,
|
| 8 |
+
"esModuleInterop": true,
|
| 9 |
+
"allowSyntheticDefaultImports": true,
|
| 10 |
+
"strict": true,
|
| 11 |
+
"forceConsistentCasingInFileNames": true,
|
| 12 |
+
"module": "ESNext",
|
| 13 |
+
"moduleResolution": "Bundler",
|
| 14 |
+
"resolveJsonModule": true,
|
| 15 |
+
"isolatedModules": true,
|
| 16 |
+
"noEmit": true,
|
| 17 |
+
"jsx": "react-jsx"
|
| 18 |
+
},
|
| 19 |
+
"include": ["src"],
|
| 20 |
+
"references": [{ "path": "./tsconfig.node.json" }]
|
| 21 |
+
}
|
studio/tsconfig.node.json
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"compilerOptions": {
|
| 3 |
+
"composite": true,
|
| 4 |
+
"skipLibCheck": true,
|
| 5 |
+
"module": "ESNext",
|
| 6 |
+
"moduleResolution": "Bundler",
|
| 7 |
+
"allowSyntheticDefaultImports": true
|
| 8 |
+
},
|
| 9 |
+
"include": ["vite.config.ts"]
|
| 10 |
+
}
|
training/README.md
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Nemoflix AMD Training
|
| 2 |
+
|
| 3 |
+
Scripts/config templates for disposable AMD MI300X droplets.
|
| 4 |
+
|
| 5 |
+
## Install AI Toolkit on the droplet
|
| 6 |
+
|
| 7 |
+
After `scripts/startup-script.sh` finishes on a fresh droplet:
|
| 8 |
+
|
| 9 |
+
```bash
|
| 10 |
+
cd /root/nemoflix
|
| 11 |
+
bash scripts/install-ai-toolkit.sh
|
| 12 |
+
```
|
| 13 |
+
|
| 14 |
+
The installer creates:
|
| 15 |
+
|
| 16 |
+
- `/root/ai-toolkit` — Ostris AI Toolkit checkout
|
| 17 |
+
- `/root/ai-toolkit-venv` — isolated ROCm Python venv
|
| 18 |
+
- `/root/nemoflix-training` — datasets/configs/output workspace
|
| 19 |
+
- `/root/nemoflix-training/run-ai-toolkit.sh` — CLI runner
|
| 20 |
+
|
| 21 |
+
## Train
|
| 22 |
+
|
| 23 |
+
Copy a config into `/root/nemoflix-training/config/`, put the training media in `/root/nemoflix-training/datasets/...`, then run:
|
| 24 |
+
|
| 25 |
+
```bash
|
| 26 |
+
/root/nemoflix-training/run-ai-toolkit.sh /root/nemoflix-training/config/<job>.yaml
|
| 27 |
+
```
|
| 28 |
+
|
| 29 |
+
## Optional UI
|
| 30 |
+
|
| 31 |
+
The installer creates the following service unit for the AI Toolkit UI, but does not start it:
|
| 32 |
+
|
| 33 |
+
```bash
|
| 34 |
+
ai-toolkit-ui.service
|
| 35 |
+
```
|
| 36 |
+
|
| 37 |
+
Before exposing the UI, set a real `AI_TOOLKIT_AUTH` value in the service unit or in an override for it.
|
training/flux2_identity_template.yaml
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# FLUX.2 [dev] identity/character LoRA template for AI Toolkit.
|
| 2 |
+
# Source: Ostris AI Toolkit FLUX.2 UI defaults + RunComfy FLUX.2 LoRA guide.
|
| 3 |
+
---
|
| 4 |
+
job: extension
|
| 5 |
+
config:
|
| 6 |
+
name: "identity_flux2_lora_v1"
|
| 7 |
+
process:
|
| 8 |
+
- type: "sd_trainer"
|
| 9 |
+
training_folder: "/root/nemoflix-training/output"
|
| 10 |
+
device: cuda:0
|
| 11 |
+
# Replace with the private trigger for the subject, e.g. a short unique token.
|
| 12 |
+
trigger_word: "character_trigger"
|
| 13 |
+
network:
|
| 14 |
+
type: "lora"
|
| 15 |
+
linear: 32
|
| 16 |
+
linear_alpha: 32
|
| 17 |
+
save:
|
| 18 |
+
dtype: bf16
|
| 19 |
+
save_every: 250
|
| 20 |
+
max_step_saves_to_keep: 4
|
| 21 |
+
push_to_hub: false
|
| 22 |
+
datasets:
|
| 23 |
+
- folder_path: "/root/nemoflix-training/datasets/identity-flux2"
|
| 24 |
+
caption_ext: "txt"
|
| 25 |
+
# Static captions + cached text embeddings: keep dropout at 0.
|
| 26 |
+
caption_dropout_rate: 0
|
| 27 |
+
shuffle_tokens: false
|
| 28 |
+
cache_latents_to_disk: true
|
| 29 |
+
resolution: [768, 896, 1024]
|
| 30 |
+
train:
|
| 31 |
+
batch_size: 1
|
| 32 |
+
steps: 1800
|
| 33 |
+
gradient_accumulation_steps: 1
|
| 34 |
+
train_unet: true
|
| 35 |
+
train_text_encoder: false
|
| 36 |
+
gradient_checkpointing: true
|
| 37 |
+
noise_scheduler: "flowmatch"
|
| 38 |
+
timestep_type: "weighted"
|
| 39 |
+
optimizer: "adamw8bit"
|
| 40 |
+
lr: 1e-4
|
| 41 |
+
optimizer_params:
|
| 42 |
+
weight_decay: 1e-4
|
| 43 |
+
dtype: bf16
|
| 44 |
+
# First pass: no DOP, static captions, cache embeddings for speed/VRAM.
|
| 45 |
+
unload_text_encoder: false
|
| 46 |
+
cache_text_embeddings: true
|
| 47 |
+
ema_config:
|
| 48 |
+
use_ema: false
|
| 49 |
+
model:
|
| 50 |
+
name_or_path: "black-forest-labs/FLUX.2-dev"
|
| 51 |
+
arch: "flux2"
|
| 52 |
+
quantize: true
|
| 53 |
+
qtype: "qfloat8"
|
| 54 |
+
quantize_te: true
|
| 55 |
+
qtype_te: "qfloat8"
|
| 56 |
+
low_vram: false
|
| 57 |
+
model_kwargs:
|
| 58 |
+
match_target_res: false
|
| 59 |
+
sample:
|
| 60 |
+
sampler: "flowmatch"
|
| 61 |
+
sample_every: 250
|
| 62 |
+
width: 1024
|
| 63 |
+
height: 1024
|
| 64 |
+
prompts:
|
| 65 |
+
- "character_trigger, man, realistic Instagram travel photo, standing on a cliff at golden hour, ocean in the background, natural pose, shot on a mirrorless camera"
|
| 66 |
+
- "character_trigger, man, lifestyle creator photo, sitting at an outdoor cafe with a laptop and coffee, warm afternoon light, candid social media photography"
|
| 67 |
+
- "character_trigger, man, action sports photo, snowboarding down a mountain slope, powder snow, dynamic pose, realistic telephoto shot"
|
| 68 |
+
- "character_trigger, man, fitness lifestyle photo, post-workout portrait outside a modern gym, athletic casual clothing, natural light, authentic Instagram content"
|
| 69 |
+
- "character_trigger, man, editorial streetwear photo, walking through a downtown city street at sunset, stylish outfit, shallow depth of field, realistic fashion photography"
|
| 70 |
+
- "photo of a man, realistic social media portrait, white background, medium shot, studio lighting"
|
| 71 |
+
neg: ""
|
| 72 |
+
seed: 42
|
| 73 |
+
walk_seed: false
|
| 74 |
+
guidance_scale: 1
|
| 75 |
+
sample_steps: 25
|
| 76 |
+
meta:
|
| 77 |
+
name: "[name]"
|
| 78 |
+
version: "1.0"
|
training/wan22_i2v_character_template.yaml
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Wan 2.2 I2V 14B character-likeness LoRA template for AI Toolkit.
|
| 2 |
+
# Source template: ostris/ai-toolkit config/examples/train_lora_wan22_14b_24gb.yaml
|
| 3 |
+
# I2V-specific arch/options confirmed in ui/src/app/jobs/new/options.ts.
|
| 4 |
+
---
|
| 5 |
+
job: extension
|
| 6 |
+
config:
|
| 7 |
+
name: "character_wan22_i2v_v1"
|
| 8 |
+
process:
|
| 9 |
+
- type: "sd_trainer"
|
| 10 |
+
training_folder: "/root/nemoflix-training/output"
|
| 11 |
+
device: cuda:0
|
| 12 |
+
trigger_word: "character_trigger"
|
| 13 |
+
network:
|
| 14 |
+
type: "lora"
|
| 15 |
+
linear: 32
|
| 16 |
+
linear_alpha: 32
|
| 17 |
+
save:
|
| 18 |
+
dtype: bf16
|
| 19 |
+
save_every: 250
|
| 20 |
+
max_step_saves_to_keep: 6
|
| 21 |
+
datasets:
|
| 22 |
+
- folder_path: "/root/nemoflix-training/datasets/character-wan-i2v"
|
| 23 |
+
caption_ext: "txt"
|
| 24 |
+
caption_dropout_rate: 0.05
|
| 25 |
+
num_frames: 41
|
| 26 |
+
resolution: [512, 768, 1024]
|
| 27 |
+
train:
|
| 28 |
+
batch_size: 1
|
| 29 |
+
steps: 2500
|
| 30 |
+
gradient_accumulation: 1
|
| 31 |
+
train_unet: true
|
| 32 |
+
train_text_encoder: false
|
| 33 |
+
gradient_checkpointing: true
|
| 34 |
+
noise_scheduler: "flowmatch"
|
| 35 |
+
timestep_type: "sigmoid"
|
| 36 |
+
optimizer: "adamw8bit"
|
| 37 |
+
lr: 1e-4
|
| 38 |
+
optimizer_params:
|
| 39 |
+
weight_decay: 1e-4
|
| 40 |
+
dtype: bf16
|
| 41 |
+
switch_boundary_every: 10
|
| 42 |
+
# Keep captions live for character training and DOP experiments.
|
| 43 |
+
# If memory is tight, set cache_text_embeddings: true and remove DOP.
|
| 44 |
+
cache_text_embeddings: false
|
| 45 |
+
model:
|
| 46 |
+
name_or_path: "ai-toolkit/Wan2.2-I2V-A14B-Diffusers-bf16"
|
| 47 |
+
arch: "wan22_14b_i2v"
|
| 48 |
+
quantize: true
|
| 49 |
+
qtype: "uint4|ostris/accuracy_recovery_adapters/wan22_14b_i2v_torchao_uint4.safetensors"
|
| 50 |
+
quantize_te: true
|
| 51 |
+
qtype_te: "qfloat8"
|
| 52 |
+
low_vram: false
|
| 53 |
+
model_kwargs:
|
| 54 |
+
train_high_noise: true
|
| 55 |
+
train_low_noise: true
|
| 56 |
+
sample:
|
| 57 |
+
sampler: "flowmatch"
|
| 58 |
+
sample_every: 250
|
| 59 |
+
width: 768
|
| 60 |
+
height: 768
|
| 61 |
+
num_frames: 41
|
| 62 |
+
fps: 16
|
| 63 |
+
# Wan I2V samples require a prompt + ctrl_img pair.
|
| 64 |
+
# Replace this once we have character control/reference frames.
|
| 65 |
+
ctrl_img: "/root/nemoflix-training/samples/character_control.jpg"
|
| 66 |
+
prompts:
|
| 67 |
+
- "character_trigger, person, cinematic portrait, walking in heavy rain, dramatic lighting"
|
| 68 |
+
- "character_trigger, person, wearing futuristic armor, rain-soaked city street, cinematic"
|
| 69 |
+
- "character_trigger, person, close-up face, intense expression, film still, shallow depth of field"
|
| 70 |
+
neg: ""
|
| 71 |
+
seed: 42
|
| 72 |
+
walk_seed: true
|
| 73 |
+
guidance_scale: 3.5
|
| 74 |
+
sample_steps: 25
|
| 75 |
+
meta:
|
| 76 |
+
name: "[name]"
|
| 77 |
+
version: "1.0"
|