Rodrigo Ortega commited on
Commit
62ab980
·
1 Parent(s): 27377d7

Add AI Toolkit training setup and live job progress

Browse files
app/nemoflix_amd/api.py CHANGED
@@ -1,9 +1,15 @@
1
  from __future__ import annotations
2
 
 
 
 
3
  import tempfile
 
4
  from pathlib import Path
5
  from typing import Any, Literal
6
- from uuid import uuid4
 
 
7
 
8
  from fastapi import FastAPI, File, HTTPException, UploadFile
9
  from pydantic import BaseModel, Field
@@ -84,6 +90,143 @@ def comfy() -> ComfyClient:
84
  return ComfyClient(settings.comfy_url, settings.request_timeout_seconds)
85
 
86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  @app.get("/api/health")
88
  async def health() -> dict[str, Any]:
89
  client = comfy()
@@ -170,14 +313,16 @@ async def generate_video(body: VideoGenerateRequest) -> VideoGenerateResponse:
170
  return VideoGenerateResponse(ok=True, mode=body.mode, workflow=workflow)
171
 
172
  try:
173
- result = await comfy().queue_prompt(workflow, client_id=str(uuid4()))
174
  except Exception as exc: # noqa: BLE001
175
  raise HTTPException(status_code=502, detail=f"ComfyUI prompt submission failed: {exc}") from exc
176
 
 
 
177
  return VideoGenerateResponse(
178
  ok="prompt_id" in result,
179
  mode=body.mode,
180
- prompt_id=result.get("prompt_id"),
181
  number=result.get("number"),
182
  node_errors=result.get("node_errors"),
183
  )
@@ -224,6 +369,54 @@ async def _queue_position(client: ComfyClient, prompt_id: str) -> int | None:
224
  return None
225
 
226
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
227
  @app.get("/api/jobs/{prompt_id}", response_model=JobStatusResponse)
228
  async def job(prompt_id: str) -> JobStatusResponse:
229
  client = comfy()
 
1
  from __future__ import annotations
2
 
3
+ import asyncio
4
+ import contextlib
5
+ import json
6
  import tempfile
7
+ from datetime import UTC, datetime
8
  from pathlib import Path
9
  from typing import Any, Literal
10
+ from urllib.parse import urlparse, urlunparse
11
+
12
+ import websockets
13
 
14
  from fastapi import FastAPI, File, HTTPException, UploadFile
15
  from pydantic import BaseModel, Field
 
90
  return ComfyClient(settings.comfy_url, settings.request_timeout_seconds)
91
 
92
 
93
+ _JOBS: dict[str, dict[str, Any]] = {}
94
+ _COMFY_CLIENT_ID = "nemoflix-amd-gallery"
95
+ _WS_TASK: asyncio.Task | None = None
96
+
97
+
98
+ def _ws_url(base_url: str) -> str:
99
+ parsed = urlparse(base_url)
100
+ scheme = "wss" if parsed.scheme == "https" else "ws"
101
+ return urlunparse((scheme, parsed.netloc, "/ws", "", f"clientId={_COMFY_CLIENT_ID}", ""))
102
+
103
+
104
+ def _register_submitted_job(prompt_id: str | None, body: VideoGenerateRequest, status: str = "pending") -> None:
105
+ if not prompt_id:
106
+ return
107
+ _JOBS[prompt_id] = {
108
+ "prompt_id": prompt_id,
109
+ "status": status,
110
+ "mode": body.mode,
111
+ "prompt": body.prompt,
112
+ "width": body.width,
113
+ "height": body.height,
114
+ "length": body.length,
115
+ "fps": body.fps,
116
+ "created_at": datetime.now(UTC).isoformat(),
117
+ "current_node": None,
118
+ "step_value": 0,
119
+ "step_max": 0,
120
+ "nodes_finished": 0,
121
+ "nodes_total": 0,
122
+ "progress_percent": None,
123
+ }
124
+
125
+
126
+ def _update_job_from_progress_state(prompt_id: str, nodes: dict[str, Any]) -> None:
127
+ if not prompt_id:
128
+ return
129
+ job = _JOBS.setdefault(prompt_id, {"prompt_id": prompt_id, "status": "running", "created_at": None})
130
+ total = len(nodes)
131
+ finished = 0
132
+ running = 0
133
+ current_node = None
134
+ step_value = 0
135
+ step_max = 0
136
+ for node_id, node in nodes.items():
137
+ if not isinstance(node, dict):
138
+ continue
139
+ state = node.get("state")
140
+ if state == "finished":
141
+ finished += 1
142
+ elif state == "running":
143
+ running += 1
144
+ if current_node is None:
145
+ current_node = node.get("display_node_id") or node.get("node_id") or node_id
146
+ step_value = int(node.get("value") or 0)
147
+ step_max = int(node.get("max") or 0)
148
+ percent = round((finished / total) * 100, 1) if total else None
149
+ job.update({
150
+ "status": "running",
151
+ "nodes_total": total,
152
+ "nodes_finished": finished,
153
+ "nodes_running": running,
154
+ "current_node": current_node,
155
+ "step_value": step_value,
156
+ "step_max": step_max,
157
+ "progress_percent": percent,
158
+ "updated_at": datetime.now(UTC).isoformat(),
159
+ })
160
+
161
+
162
+ async def _comfy_ws_bridge() -> None:
163
+ settings = get_settings()
164
+ url = _ws_url(settings.comfy_url)
165
+ while True:
166
+ try:
167
+ async with websockets.connect(url, ping_interval=20, ping_timeout=20) as ws:
168
+ async for raw in ws:
169
+ if isinstance(raw, bytes):
170
+ continue
171
+ try:
172
+ msg = json.loads(raw)
173
+ except json.JSONDecodeError:
174
+ continue
175
+ msg_type = msg.get("type")
176
+ data = msg.get("data", {}) if isinstance(msg.get("data"), dict) else {}
177
+ prompt_id = data.get("prompt_id") or data.get("prompt")
178
+ if msg_type == "execution_start" and isinstance(prompt_id, str):
179
+ _JOBS.setdefault(prompt_id, {"prompt_id": prompt_id, "created_at": None}).update({"status": "running"})
180
+ elif msg_type == "progress_state" and isinstance(prompt_id, str):
181
+ nodes = data.get("nodes", {})
182
+ if isinstance(nodes, dict):
183
+ _update_job_from_progress_state(prompt_id, nodes)
184
+ elif msg_type == "progress" and isinstance(prompt_id, str):
185
+ job = _JOBS.setdefault(prompt_id, {"prompt_id": prompt_id, "status": "running", "created_at": None})
186
+ value = int(data.get("value") or 0)
187
+ max_value = int(data.get("max") or 0)
188
+ job.update({
189
+ "status": "running",
190
+ "step_value": value,
191
+ "step_max": max_value,
192
+ "progress_percent": round((value / max_value) * 100, 1) if max_value else None,
193
+ "updated_at": datetime.now(UTC).isoformat(),
194
+ })
195
+ elif msg_type == "execution_success" and isinstance(prompt_id, str):
196
+ _JOBS.setdefault(prompt_id, {"prompt_id": prompt_id, "created_at": None}).update({
197
+ "status": "completed",
198
+ "progress_percent": 100,
199
+ "updated_at": datetime.now(UTC).isoformat(),
200
+ })
201
+ elif msg_type in {"execution_error", "execution_interrupted"} and isinstance(prompt_id, str):
202
+ _JOBS.setdefault(prompt_id, {"prompt_id": prompt_id, "created_at": None}).update({
203
+ "status": "failed",
204
+ "error": data.get("exception_message") or msg_type,
205
+ "updated_at": datetime.now(UTC).isoformat(),
206
+ })
207
+ except asyncio.CancelledError:
208
+ raise
209
+ except Exception:
210
+ await asyncio.sleep(3)
211
+
212
+
213
+ @app.on_event("startup")
214
+ async def start_comfy_bridge() -> None:
215
+ global _WS_TASK
216
+ if _WS_TASK is None or _WS_TASK.done():
217
+ _WS_TASK = asyncio.create_task(_comfy_ws_bridge())
218
+
219
+
220
+ @app.on_event("shutdown")
221
+ async def stop_comfy_bridge() -> None:
222
+ global _WS_TASK
223
+ if _WS_TASK:
224
+ _WS_TASK.cancel()
225
+ with contextlib.suppress(asyncio.CancelledError):
226
+ await _WS_TASK
227
+ _WS_TASK = None
228
+
229
+
230
  @app.get("/api/health")
231
  async def health() -> dict[str, Any]:
232
  client = comfy()
 
313
  return VideoGenerateResponse(ok=True, mode=body.mode, workflow=workflow)
314
 
315
  try:
316
+ result = await comfy().queue_prompt(workflow, client_id=_COMFY_CLIENT_ID)
317
  except Exception as exc: # noqa: BLE001
318
  raise HTTPException(status_code=502, detail=f"ComfyUI prompt submission failed: {exc}") from exc
319
 
320
+ prompt_id = result.get("prompt_id")
321
+ _register_submitted_job(prompt_id, body, "pending")
322
  return VideoGenerateResponse(
323
  ok="prompt_id" in result,
324
  mode=body.mode,
325
+ prompt_id=prompt_id,
326
  number=result.get("number"),
327
  node_errors=result.get("node_errors"),
328
  )
 
369
  return None
370
 
371
 
372
+ @app.get("/api/jobs")
373
+ async def jobs() -> dict[str, Any]:
374
+ """Return jobs submitted through this API.
375
+
376
+ ComfyUI is still the execution engine, but this endpoint intentionally does
377
+ not list arbitrary Comfy queue entries. The gallery should only show jobs we
378
+ submitted and registered locally; completed media is discovered separately by
379
+ /api/listing.
380
+ """
381
+ client = comfy()
382
+
383
+ try:
384
+ queue = await client.get("/queue")
385
+ except Exception as exc: # noqa: BLE001
386
+ jobs_list = sorted(_JOBS.values(), key=lambda j: j.get("created_at") or "", reverse=True)
387
+ return {"jobs": jobs_list, "count": len(jobs_list), "error": str(exc)}
388
+
389
+ running_ids: set[str] = set()
390
+ pending_positions: dict[str, int] = {}
391
+
392
+ for item in queue.get("queue_running", []) if isinstance(queue, dict) else []:
393
+ if isinstance(item, list) and len(item) > 1 and isinstance(item[1], str):
394
+ running_ids.add(item[1])
395
+
396
+ for position, item in enumerate(queue.get("queue_pending", []) if isinstance(queue, dict) else [], start=1):
397
+ if isinstance(item, list) and len(item) > 1 and isinstance(item[1], str):
398
+ pending_positions[item[1]] = position
399
+
400
+ for prompt_id, job in _JOBS.items():
401
+ if job.get("status") in {"completed", "failed"}:
402
+ continue
403
+ if prompt_id in running_ids:
404
+ job["status"] = "running"
405
+ job["queue_position"] = None
406
+ elif prompt_id in pending_positions:
407
+ job["status"] = "pending"
408
+ job["queue_position"] = pending_positions[prompt_id]
409
+ elif job.get("status") in {"pending", "running"}:
410
+ job["status"] = "unknown"
411
+ job["queue_position"] = None
412
+
413
+ jobs_list = sorted(
414
+ _JOBS.values(),
415
+ key=lambda j: (j.get("status") != "running", j.get("queue_position") or 0, j.get("created_at") or ""),
416
+ )
417
+ return {"jobs": jobs_list, "count": len(jobs_list)}
418
+
419
+
420
  @app.get("/api/jobs/{prompt_id}", response_model=JobStatusResponse)
421
  async def job(prompt_id: str) -> JobStatusResponse:
422
  client = comfy()
requirements.txt CHANGED
@@ -3,3 +3,4 @@ uvicorn[standard]==0.34.0
3
  httpx==0.28.1
4
  python-multipart==0.0.20
5
  pydantic-settings==2.7.1
 
 
3
  httpx==0.28.1
4
  python-multipart==0.0.20
5
  pydantic-settings==2.7.1
6
+ websockets==16.0
scripts/install-ai-toolkit.sh ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ set -Eeuo pipefail
3
+ set -x
4
+
5
+ # Install Ostris AI Toolkit on a disposable AMD MI300X ROCm droplet.
6
+ # Run after scripts/startup-script.sh so ROCm/system basics are already present.
7
+ # This is intentionally idempotent: safe to rerun on a fresh or partially initialized box.
8
+
9
+ APT_GET="apt-get -o DPkg::Lock::Timeout=300"
10
+ TOOLKIT_DIR="${TOOLKIT_DIR:-/root/ai-toolkit}"
11
+ TOOLKIT_VENV="${TOOLKIT_VENV:-/root/ai-toolkit-venv}"
12
+ TRAINING_DIR="${TRAINING_DIR:-/root/nemoflix-training}"
13
+ ROCM_INDEX_PRIMARY="${ROCM_INDEX_PRIMARY:-https://download.pytorch.org/whl/rocm7.2}"
14
+ ROCM_INDEX_FALLBACK="${ROCM_INDEX_FALLBACK:-https://download.pytorch.org/whl/rocm7.0}"
15
+ AI_TOOLKIT_REF="${AI_TOOLKIT_REF:-main}"
16
+ # CLI training is the default. The UI pulls NodeSource apt repo + nodejs and can
17
+ # trigger service restart/deferred-restart behavior on cloud images; keep it opt-in.
18
+ INSTALL_UI_DEPS="${INSTALL_UI_DEPS:-0}"
19
+ PYTHON_BIN="$TOOLKIT_VENV/bin/python"
20
+ export DEBIAN_FRONTEND=noninteractive
21
+ export NEEDRESTART_MODE=a
22
+
23
+ trap 'echo "ERROR: AI Toolkit install failed at line $LINENO"' ERR
24
+
25
+ if [ "$(id -u)" -ne 0 ]; then
26
+ echo "Run as root on the AMD droplet."
27
+ exit 1
28
+ fi
29
+
30
+ echo "=== Installing AI Toolkit prerequisites ==="
31
+ $APT_GET update -y
32
+ $APT_GET install -y git git-lfs python3-pip python3.12-venv python3-dev build-essential pkg-config curl wget ffmpeg libgl1 libglib2.0-0
33
+
34
+ git lfs install --system || true
35
+
36
+ if command -v /opt/rocm/bin/rocm-smi >/dev/null 2>&1; then
37
+ echo "=== ROCm GPU check ==="
38
+ /opt/rocm/bin/rocm-smi || true
39
+ fi
40
+
41
+ echo "=== Creating isolated AI Toolkit venv: $TOOLKIT_VENV ==="
42
+ if [ ! -d "$TOOLKIT_VENV" ]; then
43
+ python3 -m venv "$TOOLKIT_VENV"
44
+ fi
45
+ "$PYTHON_BIN" -m pip install --upgrade pip setuptools wheel
46
+
47
+ echo "=== Installing ROCm PyTorch into AI Toolkit venv ==="
48
+ "$PYTHON_BIN" -m pip install torch torchvision torchaudio --index-url "$ROCM_INDEX_PRIMARY" || \
49
+ "$PYTHON_BIN" -m pip install torch torchvision torchaudio --index-url "$ROCM_INDEX_FALLBACK"
50
+
51
+ "$PYTHON_BIN" - <<'PY'
52
+ import torch
53
+ print('torch', torch.__version__)
54
+ print('cuda api available', torch.cuda.is_available())
55
+ if torch.cuda.is_available():
56
+ print('device', torch.cuda.get_device_name(0))
57
+ PY
58
+
59
+ echo "=== Cloning/updating Ostris AI Toolkit ==="
60
+ if [ ! -d "$TOOLKIT_DIR/.git" ]; then
61
+ git clone https://github.com/ostris/ai-toolkit.git "$TOOLKIT_DIR"
62
+ fi
63
+ git -C "$TOOLKIT_DIR" fetch --depth 1 origin "$AI_TOOLKIT_REF"
64
+ git -C "$TOOLKIT_DIR" checkout FETCH_HEAD
65
+ git -C "$TOOLKIT_DIR" submodule update --init --recursive
66
+
67
+ echo "=== Installing AI Toolkit Python requirements ==="
68
+ # Keep the ROCm torch we installed above; do not allow requirements to swap in CUDA wheels.
69
+ "$PYTHON_BIN" -m pip install -r "$TOOLKIT_DIR/requirements.txt" --extra-index-url "$ROCM_INDEX_PRIMARY"
70
+ "$PYTHON_BIN" -m pip install --upgrade accelerate huggingface_hub hf_transfer
71
+
72
+ mkdir -p \
73
+ "$TRAINING_DIR/datasets" \
74
+ "$TRAINING_DIR/output" \
75
+ "$TRAINING_DIR/samples" \
76
+ "$TRAINING_DIR/config" \
77
+ "$TOOLKIT_DIR/config"
78
+
79
+ # If this script is run from a cloned Nemoflix repo, seed our checked-in config templates
80
+ # into the disposable training workspace.
81
+ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
82
+ REPO_DIR="$(cd "$SCRIPT_DIR/.." && pwd)"
83
+ if compgen -G "$REPO_DIR/training/*.yaml" >/dev/null; then
84
+ cp -f "$REPO_DIR"/training/*.yaml "$TRAINING_DIR/config/"
85
+ fi
86
+
87
+ cat > "$TRAINING_DIR/README.md" <<'EOF'
88
+ # Nemoflix AI Toolkit Training Workspace
89
+
90
+ Persistent-ish training layout for disposable AMD droplets.
91
+
92
+ ## Paths
93
+
94
+ - AI Toolkit: `/root/ai-toolkit`
95
+ - Venv: `/root/ai-toolkit-venv`
96
+ - Datasets: `/root/nemoflix-training/datasets`
97
+ - Configs: `/root/nemoflix-training/config`
98
+ - Outputs/checkpoints: `/root/nemoflix-training/output`
99
+ - Sample control images: `/root/nemoflix-training/samples`
100
+
101
+ ## Run a config
102
+
103
+ ```bash
104
+ cd /root/ai-toolkit
105
+ /root/ai-toolkit-venv/bin/python run.py /root/nemoflix-training/config/<config>.yaml
106
+ ```
107
+
108
+ ## Hugging Face token
109
+
110
+ For gated models, create `/root/ai-toolkit/.env`:
111
+
112
+ ```bash
113
+ HF_TOKEN=hf_xxx
114
+ HF_HUB_ENABLE_HF_TRANSFER=1
115
+ ```
116
+
117
+ ## Character dataset conventions
118
+
119
+ Image/FLUX LoRA:
120
+ - 20-40 good face/body images
121
+ - one `.txt` caption beside each image
122
+ - include trigger word, e.g. `character_trigger, person, portrait, natural lighting`
123
+
124
+ Wan I2V character LoRA:
125
+ - 10-30 short clips, ideally 3-8 seconds
126
+ - one `.txt` caption beside each clip
127
+ - include trigger word, e.g. `character_trigger, person, walking outdoors, close-up face`
128
+ - trim dead time; varied angles/lighting/backgrounds; avoid sunglasses
129
+ EOF
130
+
131
+ cat > "$TRAINING_DIR/run-ai-toolkit.sh" <<'EOF'
132
+ #!/bin/bash
133
+ set -Eeuo pipefail
134
+ CONFIG_PATH="${1:?Usage: /root/nemoflix-training/run-ai-toolkit.sh /root/nemoflix-training/config/job.yaml}"
135
+ cd /root/ai-toolkit
136
+ export HF_HUB_ENABLE_HF_TRANSFER="${HF_HUB_ENABLE_HF_TRANSFER:-1}"
137
+ exec /root/ai-toolkit-venv/bin/python run.py "$CONFIG_PATH"
138
+ EOF
139
+ chmod +x "$TRAINING_DIR/run-ai-toolkit.sh"
140
+
141
+ if [ "$INSTALL_UI_DEPS" = "1" ]; then
142
+ if ! command -v node >/dev/null 2>&1 || ! node -e 'process.exit(Number(process.versions.node.split(".")[0]) >= 20 ? 0 : 1)' 2>/dev/null; then
143
+ echo "=== Installing Node.js 22 for optional AI Toolkit UI ==="
144
+ curl -fsSL https://deb.nodesource.com/setup_22.x | bash -
145
+ $APT_GET install -y nodejs
146
+ fi
147
+ if [ -f "$TOOLKIT_DIR/ui/package.json" ]; then
148
+ echo "=== Installing AI Toolkit UI dependencies ==="
149
+ (cd "$TOOLKIT_DIR/ui" && npm install)
150
+ fi
151
+ fi
152
+
153
+ if [ "$INSTALL_UI_DEPS" = "1" ]; then
154
+ cat > /etc/systemd/system/ai-toolkit-ui.service <<EOF
155
+ [Unit]
156
+ Description=Ostris AI Toolkit UI
157
+ After=network-online.target
158
+ Wants=network-online.target
159
+
160
+ [Service]
161
+ Type=simple
162
+ User=root
163
+ WorkingDirectory=$TOOLKIT_DIR/ui
164
+ Environment="PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin"
165
+ Environment="AI_TOOLKIT_AUTH=change-me"
166
+ ExecStart=/usr/bin/npm run build_and_start
167
+ Restart=on-failure
168
+ RestartSec=5
169
+
170
+ [Install]
171
+ WantedBy=multi-user.target
172
+ EOF
173
+ systemctl daemon-reload
174
+ fi
175
+
176
+ "$PYTHON_BIN" - <<'PY'
177
+ import importlib
178
+ mods = ['torch', 'accelerate', 'diffusers', 'transformers', 'huggingface_hub']
179
+ for name in mods:
180
+ mod = importlib.import_module(name)
181
+ print(name, getattr(mod, '__version__', 'ok'))
182
+ PY
183
+
184
+ echo "=== AI Toolkit install complete ==="
185
+ echo "Toolkit: $TOOLKIT_DIR"
186
+ echo "Venv: $TOOLKIT_VENV"
187
+ echo "Training: $TRAINING_DIR"
188
+ echo "Run CLI: $TRAINING_DIR/run-ai-toolkit.sh /root/nemoflix-training/config/job.yaml"
189
+ echo "UI: optional; rerun with INSTALL_UI_DEPS=1, then set AI_TOOLKIT_AUTH before starting ai-toolkit-ui.service"
scripts/startup-script.sh CHANGED
@@ -147,3 +147,4 @@ curl -sS --max-time 5 http://127.0.0.1:8190/api/health
147
 
148
  echo "=== Setup Complete ==="
149
  echo "Install Wan 2.2 video stack: $APP_DIR/scripts/install-video-stack.sh"
 
 
147
 
148
  echo "=== Setup Complete ==="
149
  echo "Install Wan 2.2 video stack: $APP_DIR/scripts/install-video-stack.sh"
150
+ echo "Install AI Toolkit training stack: $APP_DIR/scripts/install-ai-toolkit.sh"
studio/src/App.tsx CHANGED
@@ -1,17 +1,24 @@
1
  import { useState, useEffect, useCallback } from "react";
2
- import type { MediaItem } from "./types";
 
3
 
4
  export default function App() {
5
  const [items, setItems] = useState<MediaItem[]>([]);
 
6
  const [loading, setLoading] = useState(true);
7
  const [selected, setSelected] = useState<string | null>(null);
8
 
9
  const load = useCallback(async () => {
10
  setLoading(true);
11
  try {
12
- const res = await fetch("/api/listing");
13
- const data = await res.json();
14
- setItems(data.images || []);
 
 
 
 
 
15
  } catch (e) {
16
  console.error(e);
17
  } finally {
@@ -21,25 +28,39 @@ export default function App() {
21
 
22
  useEffect(() => {
23
  load();
 
 
24
  }, [load]);
25
 
 
 
26
  return (
27
  <div className="min-h-screen bg-black text-white">
28
  <header className="border-b border-gray-800 px-6 py-4 flex items-center justify-between">
29
- <h1 className="text-xl font-semibold">Nemoflix AMD Gallery</h1>
30
- <span className="text-sm text-gray-500">{items.length} items</span>
 
 
 
 
 
 
31
  </header>
32
 
33
  <main className="p-6">
34
- {loading && items.length === 0 ? (
35
  <p className="text-gray-500">Loading...</p>
36
- ) : items.length === 0 ? (
37
  <p className="text-gray-500">No media yet.</p>
38
  ) : (
39
  <div className="grid grid-cols-2 md:grid-cols-3 lg:grid-cols-4 xl:grid-cols-5 gap-4">
 
 
 
 
40
  {items.map((item) => (
41
  <div
42
- key={item.name}
43
  onClick={() => setSelected(item.url)}
44
  className="cursor-pointer rounded-lg overflow-hidden border border-gray-800 hover:border-rose-600 transition aspect-video bg-gray-900 relative group"
45
  >
 
1
  import { useState, useEffect, useCallback } from "react";
2
+ import { JobCard } from "./components/JobCard";
3
+ import type { JobItem, MediaItem } from "./types";
4
 
5
  export default function App() {
6
  const [items, setItems] = useState<MediaItem[]>([]);
7
+ const [jobs, setJobs] = useState<JobItem[]>([]);
8
  const [loading, setLoading] = useState(true);
9
  const [selected, setSelected] = useState<string | null>(null);
10
 
11
  const load = useCallback(async () => {
12
  setLoading(true);
13
  try {
14
+ const [listingRes, jobsRes] = await Promise.all([
15
+ fetch("/api/listing"),
16
+ fetch("/api/jobs"),
17
+ ]);
18
+ const listing = await listingRes.json();
19
+ const jobData = await jobsRes.json();
20
+ setItems(listing.images || []);
21
+ setJobs(jobData.jobs || []);
22
  } catch (e) {
23
  console.error(e);
24
  } finally {
 
28
 
29
  useEffect(() => {
30
  load();
31
+ const id = window.setInterval(load, 3000);
32
+ return () => window.clearInterval(id);
33
  }, [load]);
34
 
35
+ const hasContent = jobs.length > 0 || items.length > 0;
36
+
37
  return (
38
  <div className="min-h-screen bg-black text-white">
39
  <header className="border-b border-gray-800 px-6 py-4 flex items-center justify-between">
40
+ <div>
41
+ <h1 className="text-xl font-semibold">Nemoflix AMD Gallery</h1>
42
+ <p className="text-xs text-gray-500 mt-1">Live from the MI300X droplet</p>
43
+ </div>
44
+ <div className="text-sm text-gray-500">
45
+ {jobs.length > 0 && <span className="text-amber-400 mr-3">{jobs.length} generating</span>}
46
+ <span>{items.length} media</span>
47
+ </div>
48
  </header>
49
 
50
  <main className="p-6">
51
+ {loading && !hasContent ? (
52
  <p className="text-gray-500">Loading...</p>
53
+ ) : !hasContent ? (
54
  <p className="text-gray-500">No media yet.</p>
55
  ) : (
56
  <div className="grid grid-cols-2 md:grid-cols-3 lg:grid-cols-4 xl:grid-cols-5 gap-4">
57
+ {jobs.map((job) => (
58
+ <JobCard key={job.prompt_id} job={job} />
59
+ ))}
60
+
61
  {items.map((item) => (
62
  <div
63
+ key={item.url}
64
  onClick={() => setSelected(item.url)}
65
  className="cursor-pointer rounded-lg overflow-hidden border border-gray-800 hover:border-rose-600 transition aspect-video bg-gray-900 relative group"
66
  >
studio/src/components/JobCard.tsx ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import type { JobItem } from "../types";
2
+
3
+ interface JobCardProps {
4
+ job: JobItem;
5
+ }
6
+
7
+ function getProgress(job: JobItem): number | null {
8
+ if (typeof job.progress_percent === "number") {
9
+ return job.progress_percent;
10
+ }
11
+
12
+ if (job.step_max && job.step_max > 0) {
13
+ return Math.round(((job.step_value || 0) / job.step_max) * 100);
14
+ }
15
+
16
+ return null;
17
+ }
18
+
19
+ export function JobCard({ job }: JobCardProps) {
20
+ const progress = getProgress(job);
21
+ const progressWidth = `${Math.max(3, progress ?? 3)}%`;
22
+
23
+ return (
24
+ <div className="rounded-lg overflow-hidden border border-amber-500/40 aspect-video bg-gray-950 relative p-4 flex flex-col justify-between">
25
+ <div className="absolute inset-0 bg-gradient-to-br from-amber-500/10 via-transparent to-rose-600/10" />
26
+
27
+ <div className="relative flex items-center justify-between gap-2 text-amber-300 text-xs font-medium uppercase tracking-wide">
28
+ <span className="flex items-center gap-2">
29
+ <span className="inline-block w-2 h-2 rounded-full bg-amber-400 animate-pulse" />
30
+ {job.status === "running" ? "Generating" : job.status}
31
+ {job.queue_position ? ` · Queue ${job.queue_position}` : ""}
32
+ </span>
33
+ {progress !== null && <span>{progress}%</span>}
34
+ </div>
35
+
36
+ <div className="relative space-y-2">
37
+ <p className="text-sm font-medium line-clamp-2 text-white/90">
38
+ {job.prompt || "Video generation job"}
39
+ </p>
40
+
41
+ <div className="space-y-1">
42
+ <div className="h-1.5 rounded-full bg-gray-800 overflow-hidden">
43
+ <div className="h-full bg-amber-400 transition-all" style={{ width: progressWidth }} />
44
+ </div>
45
+
46
+ <p className="text-[11px] text-gray-400 truncate">
47
+ {job.current_node ? `Node ${job.current_node}` : "Waiting for Comfy progress event"}
48
+ {job.step_max ? ` · step ${job.step_value || 0}/${job.step_max}` : ""}
49
+ {job.nodes_total ? ` · nodes ${job.nodes_finished || 0}/${job.nodes_total}` : ""}
50
+ </p>
51
+ </div>
52
+ </div>
53
+
54
+ <p className="relative text-[10px] text-gray-500 font-mono truncate">{job.prompt_id}</p>
55
+ </div>
56
+ );
57
+ }
studio/src/types.ts CHANGED
@@ -6,3 +6,19 @@ export interface MediaItem {
6
  mtime: number;
7
  url: string;
8
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  mtime: number;
7
  url: string;
8
  }
9
+
10
+ export interface JobItem {
11
+ prompt_id: string;
12
+ status: "pending" | "running" | "unknown" | "completed" | "failed" | string;
13
+ mode?: string;
14
+ prompt?: string;
15
+ created_at?: string;
16
+ queue_position?: number | null;
17
+ current_node?: string | null;
18
+ step_value?: number;
19
+ step_max?: number;
20
+ nodes_finished?: number;
21
+ nodes_total?: number;
22
+ progress_percent?: number | null;
23
+ error?: string;
24
+ }
studio/tsconfig.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "compilerOptions": {
3
+ "target": "ES2020",
4
+ "useDefineForClassFields": true,
5
+ "lib": ["DOM", "DOM.Iterable", "ES2020"],
6
+ "allowJs": false,
7
+ "skipLibCheck": true,
8
+ "esModuleInterop": true,
9
+ "allowSyntheticDefaultImports": true,
10
+ "strict": true,
11
+ "forceConsistentCasingInFileNames": true,
12
+ "module": "ESNext",
13
+ "moduleResolution": "Bundler",
14
+ "resolveJsonModule": true,
15
+ "isolatedModules": true,
16
+ "noEmit": true,
17
+ "jsx": "react-jsx"
18
+ },
19
+ "include": ["src"],
20
+ "references": [{ "path": "./tsconfig.node.json" }]
21
+ }
studio/tsconfig.node.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "compilerOptions": {
3
+ "composite": true,
4
+ "skipLibCheck": true,
5
+ "module": "ESNext",
6
+ "moduleResolution": "Bundler",
7
+ "allowSyntheticDefaultImports": true
8
+ },
9
+ "include": ["vite.config.ts"]
10
+ }
training/README.md ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Nemoflix AMD Training
2
+
3
+ Scripts/config templates for disposable AMD MI300X droplets.
4
+
5
+ ## Install AI Toolkit on the droplet
6
+
7
+ After `scripts/startup-script.sh` finishes on a fresh droplet:
8
+
9
+ ```bash
10
+ cd /root/nemoflix
11
+ bash scripts/install-ai-toolkit.sh
12
+ ```
13
+
14
+ The installer creates:
15
+
16
+ - `/root/ai-toolkit` — Ostris AI Toolkit checkout
17
+ - `/root/ai-toolkit-venv` — isolated ROCm Python venv
18
+ - `/root/nemoflix-training` — datasets/configs/output workspace
19
+ - `/root/nemoflix-training/run-ai-toolkit.sh` — CLI runner
20
+
21
+ ## Train
22
+
23
+ Copy a config into `/root/nemoflix-training/config/`, put media in `/root/nemoflix-training/datasets/...`, then:
24
+
25
+ ```bash
26
+ /root/nemoflix-training/run-ai-toolkit.sh /root/nemoflix-training/config/<job>.yaml
27
+ ```
28
+
29
+ ## Optional UI
30
+
31
+ The installer creates but does not start:
32
+
33
+ ```bash
34
+ ai-toolkit-ui.service
35
+ ```
36
+
37
+ Before exposing it, set a real `AI_TOOLKIT_AUTH` value in the service or an override.
training/flux2_identity_template.yaml ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # FLUX.2 [dev] identity/character LoRA template for AI Toolkit.
2
+ # Source: Ostris AI Toolkit FLUX.2 UI defaults + RunComfy FLUX.2 LoRA guide.
3
+ ---
4
+ job: extension
5
+ config:
6
+ name: "identity_flux2_lora_v1"
7
+ process:
8
+ - type: "sd_trainer"
9
+ training_folder: "/root/nemoflix-training/output"
10
+ device: cuda:0
11
+ # Replace with the private trigger for the subject, e.g. a short unique token.
12
+ trigger_word: "character_trigger"
13
+ network:
14
+ type: "lora"
15
+ linear: 32
16
+ linear_alpha: 32
17
+ save:
18
+ dtype: bf16
19
+ save_every: 250
20
+ max_step_saves_to_keep: 4
21
+ push_to_hub: false
22
+ datasets:
23
+ - folder_path: "/root/nemoflix-training/datasets/identity-flux2"
24
+ caption_ext: "txt"
25
+ # Static captions + cached text embeddings: keep dropout at 0.
26
+ caption_dropout_rate: 0
27
+ shuffle_tokens: false
28
+ cache_latents_to_disk: true
29
+ resolution: [768, 896, 1024]
30
+ train:
31
+ batch_size: 1
32
+ steps: 1800
33
+ gradient_accumulation_steps: 1
34
+ train_unet: true
35
+ train_text_encoder: false
36
+ gradient_checkpointing: true
37
+ noise_scheduler: "flowmatch"
38
+ timestep_type: "weighted"
39
+ optimizer: "adamw8bit"
40
+ lr: 1e-4
41
+ optimizer_params:
42
+ weight_decay: 1e-4
43
+ dtype: bf16
44
+ # First pass: no DOP, static captions, cache embeddings for speed/VRAM.
45
+ unload_text_encoder: false
46
+ cache_text_embeddings: true
47
+ ema_config:
48
+ use_ema: false
49
+ model:
50
+ name_or_path: "black-forest-labs/FLUX.2-dev"
51
+ arch: "flux2"
52
+ quantize: true
53
+ qtype: "qfloat8"
54
+ quantize_te: true
55
+ qtype_te: "qfloat8"
56
+ low_vram: false
57
+ model_kwargs:
58
+ match_target_res: false
59
+ sample:
60
+ sampler: "flowmatch"
61
+ sample_every: 250
62
+ width: 1024
63
+ height: 1024
64
+ prompts:
65
+ - "character_trigger, man, realistic Instagram travel photo, standing on a cliff at golden hour, ocean in the background, natural pose, shot on a mirrorless camera"
66
+ - "character_trigger, man, lifestyle creator photo, sitting at an outdoor cafe with a laptop and coffee, warm afternoon light, candid social media photography"
67
+ - "character_trigger, man, action sports photo, snowboarding down a mountain slope, powder snow, dynamic pose, realistic telephoto shot"
68
+ - "character_trigger, man, fitness lifestyle photo, post-workout portrait outside a modern gym, athletic casual clothing, natural light, authentic Instagram content"
69
+ - "character_trigger, man, editorial streetwear photo, walking through a downtown city street at sunset, stylish outfit, shallow depth of field, realistic fashion photography"
70
+ - "photo of a man, realistic social media portrait, white background, medium shot, studio lighting"
71
+ neg: ""
72
+ seed: 42
73
+ walk_seed: false
74
+ guidance_scale: 1
75
+ sample_steps: 25
76
+ meta:
77
+ name: "[name]"
78
+ version: "1.0"
training/wan22_i2v_character_template.yaml ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Wan 2.2 I2V 14B character LoRA for a character likeness.
2
+ # Source template: ostris/ai-toolkit config/examples/train_lora_wan22_14b_24gb.yaml
3
+ # I2V-specific arch/options confirmed in ui/src/app/jobs/new/options.ts.
4
+ ---
5
+ job: extension
6
+ config:
7
+ name: "character_wan22_i2v_v1"
8
+ process:
9
+ - type: "sd_trainer"
10
+ training_folder: "/root/nemoflix-training/output"
11
+ device: cuda:0
12
+ trigger_word: "character_trigger"
13
+ network:
14
+ type: "lora"
15
+ linear: 32
16
+ linear_alpha: 32
17
+ save:
18
+ dtype: bf16
19
+ save_every: 250
20
+ max_step_saves_to_keep: 6
21
+ datasets:
22
+ - folder_path: "/root/nemoflix-training/datasets/character-wan-i2v"
23
+ caption_ext: "txt"
24
+ caption_dropout_rate: 0.05
25
+ num_frames: 41
26
+ resolution: [512, 768, 1024]
27
+ train:
28
+ batch_size: 1
29
+ steps: 2500
30
+ gradient_accumulation: 1
31
+ train_unet: true
32
+ train_text_encoder: false
33
+ gradient_checkpointing: true
34
+ noise_scheduler: "flowmatch"
35
+ timestep_type: "sigmoid"
36
+ optimizer: "adamw8bit"
37
+ lr: 1e-4
38
+ optimizer_params:
39
+ weight_decay: 1e-4
40
+ dtype: bf16
41
+ switch_boundary_every: 10
42
+ # Keep captions live for character training and DOP experiments.
43
+ # If memory is tight, switch to cache_text_embeddings and remove DOP.
44
+ cache_text_embeddings: false
45
+ model:
46
+ name_or_path: "ai-toolkit/Wan2.2-I2V-A14B-Diffusers-bf16"
47
+ arch: "wan22_14b_i2v"
48
+ quantize: true
49
+ qtype: "uint4|ostris/accuracy_recovery_adapters/wan22_14b_i2v_torchao_uint4.safetensors"
50
+ quantize_te: true
51
+ qtype_te: "qfloat8"
52
+ low_vram: false
53
+ model_kwargs:
54
+ train_high_noise: true
55
+ train_low_noise: true
56
+ sample:
57
+ sampler: "flowmatch"
58
+ sample_every: 250
59
+ width: 768
60
+ height: 768
61
+ num_frames: 41
62
+ fps: 16
63
+ # Wan I2V samples require a prompt + ctrl_img pair.
64
+ # Replace this once we have character control/reference frames.
65
+ ctrl_img: "/root/nemoflix-training/samples/character_control.jpg"
66
+ prompts:
67
+ - "character_trigger, person, cinematic portrait, walking in heavy rain, dramatic lighting"
68
+ - "character_trigger, person, wearing futuristic armor, rain-soaked city street, cinematic"
69
+ - "character_trigger, person, close-up face, intense expression, film still, shallow depth of field"
70
+ neg: ""
71
+ seed: 42
72
+ walk_seed: true
73
+ guidance_scale: 3.5
74
+ sample_steps: 25
75
+ meta:
76
+ name: "[name]"
77
+ version: "1.0"