| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import torch |
| import torch.nn.functional as F |
| from torch.utils.data import DataLoader |
| from datasets import load_dataset |
| from transformers import T5EncoderModel, T5Tokenizer, CLIPTextModel, CLIPTokenizer |
| from huggingface_hub import HfApi, hf_hub_download |
| from safetensors.torch import save_file, load_file |
| from torch.utils.tensorboard import SummaryWriter |
| from tqdm.auto import tqdm |
| import numpy as np |
| import math |
| import os |
| import json |
| from datetime import datetime |
|
|
| |
| |
| |
| |
# --- Runtime performance flags -------------------------------------------
# Allow TF32 on Ampere+ GPUs: big matmul/conv speedup at slightly reduced
# mantissa precision (acceptable for diffusion training).
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# Let cuDNN autotune conv algorithms; wins when input shapes are static.
torch.backends.cudnn.benchmark = True
torch.set_float32_matmul_precision('high')

# Silence the TF32 advisory warnings triggered by the flags above.
import warnings
warnings.filterwarnings('ignore', message='.*TF32.*')
|
|
| |
| |
| |
# --- Training hyperparameters ---------------------------------------------
BATCH_SIZE = 128
GRAD_ACCUM = 1            # optimizer step every GRAD_ACCUM micro-batches
LR = 1e-4
EPOCHS = 10
MAX_SEQ = 128             # T5 token sequence length
MIN_SNR = 5.0             # gamma for min-SNR loss weighting
SHIFT = 3.0               # Flux timestep-shift factor
DEVICE = "cuda"
# Prefer bf16 (no loss scaling needed); fall back to fp16 on older GPUs.
DTYPE = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16

# --- Checkpoint/upload/sample/log cadence (in optimizer steps) ------------
HF_REPO = "AbstractPhil/tiny-flux"
SAVE_EVERY = 1000
UPLOAD_EVERY = 1000
SAMPLE_EVERY = 500
LOG_EVERY = 10

# Resume source: "none", "hub", "hub:step_N", or "local:<path>".
LOAD_TARGET = "hub:step_24000"
# Optional manual override of the resume step counter (None = keep inferred).
RESUME_STEP = None

# --- Output directories ---------------------------------------------------
CHECKPOINT_DIR = "./tiny_flux_checkpoints"
LOG_DIR = "./tiny_flux_logs"
SAMPLE_DIR = "./tiny_flux_samples"

os.makedirs(CHECKPOINT_DIR, exist_ok=True)
os.makedirs(LOG_DIR, exist_ok=True)
os.makedirs(SAMPLE_DIR, exist_ok=True)
|
|
| |
| |
| |
| print("Setting up HuggingFace Hub...") |
| api = HfApi() |
| try: |
| api.create_repo(repo_id=HF_REPO, exist_ok=True, repo_type="model") |
| print(f"✓ Repo ready: {HF_REPO}") |
| except Exception as e: |
| print(f"Note: {e}") |
|
|
| |
| |
| |
| run_name = datetime.now().strftime("%Y%m%d_%H%M%S") |
| writer = SummaryWriter(log_dir=os.path.join(LOG_DIR, run_name)) |
| print(f"✓ Tensorboard: {LOG_DIR}/{run_name}") |
|
|
| |
| |
| |
| print("\nLoading dataset...") |
| ds = load_dataset("AbstractPhil/flux-schnell-teacher-latents", "train_3_512", split="train") |
| print(f"Samples: {len(ds)}") |
|
|
| |
| |
| |
| print("\nLoading flan-t5-base (768 dim)...") |
| t5_tok = T5Tokenizer.from_pretrained("google/flan-t5-base") |
| t5_enc = T5EncoderModel.from_pretrained("google/flan-t5-base", torch_dtype=DTYPE).to(DEVICE).eval() |
|
|
| print("Loading CLIP-L...") |
| clip_tok = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") |
| clip_enc = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=DTYPE).to(DEVICE).eval() |
|
|
| for p in t5_enc.parameters(): p.requires_grad = False |
| for p in clip_enc.parameters(): p.requires_grad = False |
|
|
| |
| |
| |
| print("Loading Flux VAE for samples...") |
| from diffusers import AutoencoderKL |
|
|
| vae = AutoencoderKL.from_pretrained( |
| "black-forest-labs/FLUX.1-schnell", |
| subfolder="vae", |
| torch_dtype=DTYPE |
| ).to(DEVICE).eval() |
| for p in vae.parameters(): p.requires_grad = False |
|
|
| |
| |
| |
@torch.inference_mode()
def encode_prompts_batched(prompts: list) -> tuple:
    """Run both text encoders over a whole batch of prompts at once.

    Returns a tuple (t5_sequence_embeddings, clip_pooled_embeddings),
    both living on DEVICE. Batching here is far faster than per-prompt calls.
    """
    # T5 branch: full token sequence, padded/truncated to MAX_SEQ.
    t5_tokens = t5_tok(
        prompts,
        max_length=MAX_SEQ,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    ).to(DEVICE)
    seq_embeds = t5_enc(
        input_ids=t5_tokens.input_ids,
        attention_mask=t5_tokens.attention_mask,
    ).last_hidden_state

    # CLIP branch: only the pooled projection is consumed downstream
    # (77 is CLIP's fixed context length).
    clip_tokens = clip_tok(
        prompts,
        max_length=77,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    ).to(DEVICE)
    pooled = clip_enc(
        input_ids=clip_tokens.input_ids,
        attention_mask=clip_tokens.attention_mask,
    ).pooler_output

    return seq_embeds, pooled
|
|
|
|
@torch.inference_mode()
def encode_prompt(prompt: str) -> tuple:
    """Encode one prompt via the batched encoder (batch of size 1).

    Kept as a thin compatibility wrapper; returns the same (t5, pooled)
    tuple shape as the batched call, with a leading batch dim of 1.
    """
    single = [prompt]
    return encode_prompts_batched(single)
|
|
|
|
| |
| |
| |
| print("\nPre-encoding prompts...") |
| PRECOMPUTE_ENCODINGS = True |
| ENCODING_CACHE_DIR = "./encoding_cache" |
| os.makedirs(ENCODING_CACHE_DIR, exist_ok=True) |
|
|
| |
| cache_file = os.path.join(ENCODING_CACHE_DIR, f"encodings_{len(ds)}_t5base_clipL.pt") |
|
|
| if PRECOMPUTE_ENCODINGS: |
| if os.path.exists(cache_file): |
| |
| print(f"Loading cached encodings from {cache_file}...") |
| cached = torch.load(cache_file, weights_only=True) |
| all_t5_embeds = cached["t5_embeds"] |
| all_clip_pooled = cached["clip_pooled"] |
| print(f"✓ Loaded cached encodings") |
| else: |
| |
| print("Encoding prompts (will cache for future runs)...") |
| all_prompts = ds["prompt"] |
| |
| encode_batch_size = 64 |
| all_t5_embeds = [] |
| all_clip_pooled = [] |
| |
| for i in tqdm(range(0, len(all_prompts), encode_batch_size), desc="Encoding"): |
| batch_prompts = all_prompts[i:i+encode_batch_size] |
| t5_out, clip_out = encode_prompts_batched(batch_prompts) |
| all_t5_embeds.append(t5_out.cpu()) |
| all_clip_pooled.append(clip_out.cpu()) |
| |
| all_t5_embeds = torch.cat(all_t5_embeds, dim=0) |
| all_clip_pooled = torch.cat(all_clip_pooled, dim=0) |
| |
| |
| torch.save({ |
| "t5_embeds": all_t5_embeds, |
| "clip_pooled": all_clip_pooled, |
| }, cache_file) |
| print(f"✓ Saved encoding cache to {cache_file}") |
| |
| print(f" T5 embeds: {all_t5_embeds.shape}") |
| print(f" CLIP pooled: {all_clip_pooled.shape}") |
|
|
|
|
| |
| |
| |
def flux_shift(t, s=SHIFT):
    """Flux timestep shift: t -> s*t / (1 + (s-1)*t).

    For s > 1 this biases sampled timesteps toward the high-signal end of
    the schedule; s = 1 is the identity.
    """
    numerator = s * t
    denominator = 1 + (s - 1) * t
    return numerator / denominator
|
|
|
|
def min_snr_weight(t, gamma=MIN_SNR):
    """Min-SNR-gamma loss weight: min(SNR(t), gamma) / SNR(t).

    Caps the contribution of high-SNR (late) timesteps so no region of the
    schedule dominates the loss. For x_t = (1-t)*noise + t*data the SNR is
    (t / (1-t))^2; both clamps guard against division blow-ups at t -> 1
    and t -> 0 respectively.
    """
    snr = (t / (1 - t).clamp(min=1e-5)).pow(2)
    capped = snr.clamp(max=gamma)
    return capped / snr.clamp(min=1e-5)
|
|
|
|
| |
| |
| |
@torch.inference_mode()
def generate_samples(model, prompts, num_steps=20, guidance_scale=3.5, H=64, W=64):
    """Generate sample images by Euler-integrating the learned velocity field.

    H/W are the latent grid sizes (tokens); C=16 is the Flux latent channel
    count. Puts the model into eval mode for sampling and back to train
    before returning a (B, 3, H*8, W*8)-ish image batch in [0, 1].
    """
    model.eval()
    B = len(prompts)
    C = 16  # Flux latent channels

    # Text conditioning via the same frozen encoders used in training.
    t5_embeds, clip_pooleds = encode_prompts_batched(prompts)
    t5_embeds = t5_embeds.to(DTYPE)
    clip_pooleds = clip_pooleds.to(DTYPE)

    # Start from pure noise in token layout (B, H*W, C). Under this script's
    # convention, t=0 is noise and t=1 is data.
    x = torch.randn(B, H * W, C, device=DEVICE, dtype=DTYPE)

    # Positional ids for the image tokens.
    img_ids = TinyFlux.create_img_ids(B, H, W, DEVICE)

    # Timestep grid warped by the same flux_shift used in training, so the
    # sampler visits timesteps the model actually saw.
    t_linear = torch.linspace(0, 1, num_steps + 1, device=DEVICE, dtype=DTYPE)
    timesteps = flux_shift(t_linear, s=SHIFT)

    # Euler integration of dx/dt = v(x, t) from t=0 (noise) to t=1 (data).
    for i in range(num_steps):
        t_curr = timesteps[i]
        t_next = timesteps[i + 1]
        dt = t_next - t_curr

        t_batch = t_curr.expand(B).to(DTYPE)
        # Guidance is fed to the model as a conditioning input (distilled
        # guidance); despite the name "v_cond" there is no separate
        # unconditional pass / classifier-free-guidance mix here.
        guidance = torch.full((B,), guidance_scale, device=DEVICE, dtype=DTYPE)

        v_cond = model(
            hidden_states=x,
            encoder_hidden_states=t5_embeds,
            pooled_projections=clip_pooleds,
            timestep=t_batch,
            img_ids=img_ids,
            guidance=guidance,
        )

        x = x + v_cond * dt

    # Tokens back to (B, C, H, W) latents, then decode with the frozen VAE.
    # NOTE(review): Flux's VAE config also defines a shift_factor; if the
    # teacher latents were stored as scaling_factor * (z - shift_factor),
    # decoding should add shift_factor back after the division — confirm
    # against how the dataset latents were produced.
    latents = x.reshape(B, H, W, C).permute(0, 3, 1, 2)
    latents = latents / vae.config.scaling_factor
    images = vae.decode(latents.to(vae.dtype)).sample
    images = (images / 2 + 0.5).clamp(0, 1)

    model.train()
    return images
|
|
|
|
def save_samples(images, prompts, step, save_dir):
    """Write individual sample PNGs to disk and log a grid to tensorboard."""
    from torchvision.utils import make_grid, save_image

    # One PNG per sample, named by step, index, and a filesystem-safe slug.
    for idx, (image, text) in enumerate(zip(images, prompts)):
        slug = text[:50].replace(" ", "_").replace("/", "-")
        destination = os.path.join(save_dir, f"step{step}_{idx}_{slug}.png")
        save_image(image, destination)

    # Log a 2-wide grid plus the prompts to tensorboard for quick review.
    grid = make_grid(images, nrow=2, normalize=False)
    writer.add_image("samples", grid, step)
    writer.add_text("sample_prompts", "\n".join(prompts), step)
    print(f" ✓ Saved {len(images)} samples")
|
|
|
|
| |
| |
| |
def collate_preencoded(batch):
    """Collate a batch by looking up cached text embeddings per sample index.

    Everything stays on CPU; the training loop moves tensors to DEVICE.
    """
    sample_indices = [item["__index__"] for item in batch]
    latent_stack = torch.stack(
        [torch.tensor(np.array(item["latent"]), dtype=DTYPE) for item in batch]
    )

    # Index into the module-level embedding caches built at startup.
    return {
        "latents": latent_stack,
        "t5_embeds": all_t5_embeds[sample_indices].to(DTYPE),
        "clip_pooled": all_clip_pooled[sample_indices].to(DTYPE),
    }
|
|
|
|
def collate_online(batch):
    """Collate a batch, encoding prompts on the fly (no cache).

    The encoders run on GPU, so this must execute in the main process
    (num_workers=0); results are moved back to CPU for the loader.
    """
    texts = [item["prompt"] for item in batch]
    latent_stack = torch.stack(
        [torch.tensor(np.array(item["latent"]), dtype=DTYPE) for item in batch]
    )

    seq_embeds, pooled = encode_prompts_batched(texts)

    return {
        "latents": latent_stack,
        "t5_embeds": seq_embeds.cpu().to(DTYPE),
        "clip_pooled": pooled.cpu().to(DTYPE),
    }
|
|
|
|
| |
class IndexedDataset:
    """Dataset wrapper that injects each sample's position as "__index__".

    Used to look up pre-encoded embeddings by index in the collate function,
    sidestepping an expensive ds.map() over the whole dataset.
    """

    def __init__(self, ds):
        self.ds = ds

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        # Shallow-copy so the underlying record is never mutated.
        sample = dict(self.ds[idx])
        sample["__index__"] = idx
        return sample
|
|
| |
# Choose the collate path. Cached embeddings allow DataLoader worker
# processes (pure CPU indexing); online encoding touches the GPU and must
# stay in the main process (num_workers=0).
if PRECOMPUTE_ENCODINGS:
    ds = IndexedDataset(ds)
    collate_fn = collate_preencoded
    num_workers = 2
else:
    collate_fn = collate_online
    num_workers = 0
|
|
|
|
| |
| |
| |
def load_weights(path):
    """Load a state dict from a .safetensors or .pt file.

    Handles three cases: plain safetensors files, full .pt training
    checkpoints (weights nested under "model"/"state_dict"), and unknown
    extensions (tries safetensors first, then pickle). Any "_orig_mod."
    prefix added by torch.compile is stripped so the weights load into an
    uncompiled model.

    Returns a plain {name: tensor} dict.
    """
    if path.endswith(".safetensors"):
        state_dict = load_file(path)
    elif path.endswith(".pt"):
        # NOTE: weights_only=False deserializes arbitrary pickle — only
        # load checkpoints from trusted sources.
        ckpt = torch.load(path, map_location=DEVICE, weights_only=False)
        if isinstance(ckpt, dict):
            state_dict = ckpt.get("model", ckpt.get("state_dict", ckpt))
        else:
            state_dict = ckpt
    else:
        # Unknown extension: try safetensors, fall back to torch pickle.
        # (was a bare `except:`, which also swallowed KeyboardInterrupt)
        try:
            state_dict = load_file(path)
        except Exception:
            state_dict = torch.load(path, map_location=DEVICE, weights_only=False)

    # Strip the torch.compile wrapper prefix (leading occurrence only, so a
    # key that merely contains the substring elsewhere is left intact).
    if isinstance(state_dict, dict) and any(k.startswith("_orig_mod.") for k in state_dict):
        print(" Stripping torch.compile prefix...")
        state_dict = {k.replace("_orig_mod.", "", 1): v for k, v in state_dict.items()}

    return state_dict
|
|
|
|
def save_checkpoint(model, optimizer, scheduler, step, epoch, loss, path):
    """Persist model weights (.safetensors) plus trainer state (.pt).

    Weights are stored separately in safetensors format for safe sharing;
    optimizer/scheduler state and bookkeeping go into the .pt file.
    Returns the path of the safetensors weights file.
    """
    parent = os.path.dirname(path)
    os.makedirs(parent if parent else ".", exist_ok=True)

    # Normalize keys: drop the torch.compile "_orig_mod." wrapper prefix so
    # the weights load into a plain (uncompiled) model.
    state_dict = model.state_dict()
    if any(key.startswith("_orig_mod.") for key in state_dict):
        state_dict = {key.replace("_orig_mod.", ""): val for key, val in state_dict.items()}

    weights_path = path.replace(".pt", ".safetensors")
    save_file(state_dict, weights_path)

    trainer_state = {
        "step": step,
        "epoch": epoch,
        "loss": loss,
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
    }
    torch.save(trainer_state, path)

    print(f" ✓ Saved checkpoint: step {step}")
    return weights_path
|
|
|
|
def upload_checkpoint(weights_path, step, config):
    """Best-effort push of one checkpoint plus config.json to the HF Hub.

    Failures are logged and swallowed so a network blip never kills training.
    """
    config_path = os.path.join(CHECKPOINT_DIR, "config.json")
    try:
        # Step checkpoints live under checkpoints/, leaving the repo's
        # top-level model.safetensors reserved for the best model.
        api.upload_file(
            path_or_fileobj=weights_path,
            path_in_repo=f"checkpoints/step_{step}.safetensors",
            repo_id=HF_REPO,
            commit_message=f"Checkpoint step {step}",
        )

        # Refresh config.json alongside the weights.
        with open(config_path, "w") as f:
            json.dump(config.__dict__, f, indent=2)
        api.upload_file(path_or_fileobj=config_path, path_in_repo="config.json", repo_id=HF_REPO)

        print(f" ✓ Uploaded step {step} to {HF_REPO}")
    except Exception as e:
        print(f" ⚠ Upload failed: {e}")
|
|
|
|
def load_checkpoint(model, optimizer, scheduler, target):
    """Load model weights according to a checkpoint target specifier.

    target forms:
      None / "none"   -> fresh start
      "hub"           -> top-level model.safetensors from HF_REPO
      "hub:step_N"    -> checkpoints/step_N.{safetensors,pt} from HF_REPO
      "local:<path>"  -> local weights file

    Returns (start_step, start_epoch). NOTE(review): only model weights are
    restored — the optimizer/scheduler arguments are accepted but never
    loaded, so resumed runs restart optimizer state from scratch.
    """
    # Buffers that are recomputed at model init and intentionally absent
    # from saved weights; missing entries for these are expected.
    expected_missing = {'time_in.sin_basis', 'guidance_in.sin_basis',
                        'rope.freqs_0', 'rope.freqs_1', 'rope.freqs_2'}

    def _apply(weights):
        # strict=False so the precomputed buffers above may be missing.
        missing, unexpected = model.load_state_dict(weights, strict=False)
        if missing:
            surprise = set(missing) - expected_missing
            if surprise:
                print(f" ⚠ Unexpected missing keys: {surprise}")
            else:
                print(f" ✓ Missing only precomputed buffers (OK)")

    if target == "none" or target is None:
        print("Starting fresh (no checkpoint)")
        return 0, 0

    if target == "hub" or (isinstance(target, str) and target.startswith("hub:")):
        start_step = 0
        try:
            if target == "hub":
                weights_path = hf_hub_download(repo_id=HF_REPO, filename="model.safetensors")
            else:
                step_name = target.split(":")[1]
                try:
                    weights_path = hf_hub_download(repo_id=HF_REPO, filename=f"checkpoints/{step_name}.safetensors")
                except Exception:
                    # Older checkpoints were uploaded as .pt files.
                    weights_path = hf_hub_download(repo_id=HF_REPO, filename=f"checkpoints/{step_name}.pt")
                # "step_24000" -> resume counter 24000.
                start_step = int(step_name.split("_")[-1]) if "_" in step_name else 0

            _apply(load_weights(weights_path))
            print(f"✓ Loaded from hub: {target}")
            return start_step, 0
        except Exception as e:
            print(f"Hub load failed: {e}")
            return 0, 0

    if isinstance(target, str) and target.startswith("local:"):
        path = target.split(":", 1)[1]
        _apply(load_weights(path))
        print(f"✓ Loaded from local: {path}")
        return 0, 0

    print("No checkpoint found, starting fresh")
    return 0, 0
|
|
|
|
| |
| |
| |
# DataLoader: pinned memory enables async host->GPU copies in the train
# loop; persistent workers avoid re-spawning processes every epoch.
loader = DataLoader(
    ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=num_workers,
    pin_memory=True,
    persistent_workers=(num_workers > 0),
    prefetch_factor=2 if num_workers > 0 else None,
)
|
|
| |
| |
| |
# --- Student model (TinyFlux / TinyFluxConfig defined elsewhere) ----------
config = TinyFluxConfig()
model = TinyFlux(config).to(DEVICE).to(DTYPE)
print(f"\nParams: {sum(p.numel() for p in model.parameters()):,}")

# Fused AdamW executes the parameter update in a single CUDA kernel.
opt = torch.optim.AdamW(
    model.parameters(),
    lr=LR,
    betas=(0.9, 0.99),
    weight_decay=0.01,
    fused=True,
)
|
|
total_steps = len(loader) * EPOCHS // GRAD_ACCUM
# Linear warmup for up to 500 steps, but never more than 10% of training.
warmup = min(500, total_steps // 10)


def lr_fn(step):
    """LambdaLR multiplier: linear warmup, then cosine decay to zero.

    Progress through the cosine phase is clamped to [0, 1] so the multiplier
    can never go negative or cycle back up if the scheduler is stepped past
    total_steps (possible after resuming a partially-trained run). The
    max(..., 1) guards avoid division by zero for degenerate tiny runs.
    """
    if step < warmup:
        return step / max(warmup, 1)
    decay_span = max(total_steps - warmup, 1)
    progress = min((step - warmup) / decay_span, 1.0)
    return 0.5 * (1 + math.cos(math.pi * progress))


sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_fn)
|
|
| |
| |
| |
| print(f"\nLoad target: {LOAD_TARGET}") |
| start_step, start_epoch = load_checkpoint(model, opt, sched, LOAD_TARGET) |
|
|
| if RESUME_STEP is not None: |
| print(f"Overriding start_step: {start_step} -> {RESUME_STEP}") |
| start_step = RESUME_STEP |
|
|
| |
| |
| |
| model = torch.compile(model, mode="default") |
|
|
| |
| writer.add_text("config", json.dumps(config.__dict__, indent=2), 0) |
| writer.add_text("training_config", json.dumps({ |
| "batch_size": BATCH_SIZE, |
| "grad_accum": GRAD_ACCUM, |
| "lr": LR, |
| "epochs": EPOCHS, |
| "min_snr": MIN_SNR, |
| "shift": SHIFT, |
| "optimizations": ["TF32", "fused_adamw", "precomputed_encodings", "flash_attention", "torch.compile"] |
| }, indent=2), 0) |
|
|
| |
| SAMPLE_PROMPTS = [ |
| "a photo of a cat sitting on a windowsill", |
| "a beautiful sunset over mountains", |
| "a portrait of a woman with red hair", |
| "a futuristic cityscape at night", |
| ] |
|
|
| |
| |
| |
| print(f"\nTraining {EPOCHS} epochs, {total_steps} total steps") |
| print(f"Resuming from step {start_step}, epoch {start_epoch}") |
| print(f"Save: {SAVE_EVERY}, Upload: {UPLOAD_EVERY}, Sample: {SAMPLE_EVERY}, Log: {LOG_EVERY}") |
| print("Optimizations: TF32, fused AdamW, pre-encoded prompts, Flash Attention, torch.compile") |
|
|
| model.train() |
| step = start_step |
| best = float("inf") |
|
|
| |
| _cached_img_ids = None |
|
|
# --- Main training loop ---------------------------------------------------
for ep in range(start_epoch, EPOCHS):
    ep_loss = 0
    ep_batches = 0
    pbar = tqdm(loader, desc=f"E{ep + 1}")

    for i, batch in enumerate(pbar):
        # Async host->GPU copies (tensors were pinned by the DataLoader).
        latents = batch["latents"].to(DEVICE, non_blocking=True)
        t5 = batch["t5_embeds"].to(DEVICE, non_blocking=True)
        clip = batch["clip_pooled"].to(DEVICE, non_blocking=True)

        B, C, H, W = latents.shape

        # (B, C, H, W) latents -> (B, H*W, C) token-sequence layout.
        data = latents.permute(0, 2, 3, 1).reshape(B, H * W, C)
        noise = torch.randn_like(data)

        # Logit-normal timestep sampling, warped by the Flux shift. Under
        # this script's convention t=0 is pure noise, t=1 is pure data;
        # the clamp keeps t strictly inside (0, 1).
        t = torch.sigmoid(torch.randn(B, device=DEVICE))
        t = flux_shift(t, s=SHIFT).to(DTYPE).clamp(1e-4, 1 - 1e-4)

        # Rectified-flow interpolation and its (constant) velocity target.
        t_expanded = t.view(B, 1, 1)
        x_t = (1 - t_expanded) * noise + t_expanded * data

        v_target = data - noise

        # Positional ids for image tokens (rebuilt every step; see the
        # unused _cached_img_ids above).
        img_ids = TinyFlux.create_img_ids(B, H, W, DEVICE)

        # Random guidance value in [1, 5): trains the model's distilled
        # guidance conditioning across a range of strengths.
        guidance = torch.rand(B, device=DEVICE, dtype=DTYPE) * 4 + 1

        with torch.autocast("cuda", dtype=DTYPE):
            v_pred = model(
                hidden_states=x_t,
                encoder_hidden_states=t5,
                pooled_projections=clip,
                timestep=t,
                img_ids=img_ids,
                guidance=guidance,
            )

        # Per-sample MSE over tokens/channels, weighted by min-SNR, then
        # averaged and scaled for gradient accumulation.
        loss_raw = F.mse_loss(v_pred, v_target, reduction="none").mean(dim=[1, 2])
        snr_weights = min_snr_weight(t)
        loss = (loss_raw * snr_weights).mean() / GRAD_ACCUM
        loss.backward()

        # Optimizer step once every GRAD_ACCUM micro-batches.
        if (i + 1) % GRAD_ACCUM == 0:
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            sched.step()
            opt.zero_grad(set_to_none=True)
            step += 1

            if step % LOG_EVERY == 0:
                writer.add_scalar("train/loss", loss.item() * GRAD_ACCUM, step)
                writer.add_scalar("train/lr", sched.get_last_lr()[0], step)
                writer.add_scalar("train/grad_norm", grad_norm.item(), step)
                writer.add_scalar("train/t_mean", t.mean().item(), step)

            if step % SAMPLE_EVERY == 0:
                print(f"\n Generating samples at step {step}...")
                images = generate_samples(model, SAMPLE_PROMPTS, num_steps=20)
                save_samples(images, SAMPLE_PROMPTS, step, SAMPLE_DIR)

            if step % SAVE_EVERY == 0:
                ckpt_path = os.path.join(CHECKPOINT_DIR, f"step_{step}.pt")
                weights_path = save_checkpoint(model, opt, sched, step, ep, loss.item(), ckpt_path)

            # NOTE(review): relies on the SAVE_EVERY branch having just set
            # weights_path; if UPLOAD_EVERY is ever not a multiple of
            # SAVE_EVERY this raises NameError (or uploads a stale file).
            if step % UPLOAD_EVERY == 0:
                upload_checkpoint(weights_path, step, config)

        ep_loss += loss.item() * GRAD_ACCUM
        ep_batches += 1
        pbar.set_postfix(loss=f"{loss.item() * GRAD_ACCUM:.4f}", lr=f"{sched.get_last_lr()[0]:.1e}", step=step)

    # Epoch summary; max(..., 1) guards an empty loader.
    avg = ep_loss / max(ep_batches, 1)
    print(f"Epoch {ep + 1} loss: {avg:.4f}")
    writer.add_scalar("train/epoch_loss", avg, ep + 1)

    # Track best epoch-average loss; push it as the repo's main model.
    if avg < best:
        best = avg
        best_path = os.path.join(CHECKPOINT_DIR, "best.pt")
        weights_path = save_checkpoint(model, opt, sched, step, ep, avg, best_path)

        try:
            api.upload_file(
                path_or_fileobj=weights_path,
                path_in_repo="model.safetensors",
                repo_id=HF_REPO,
                commit_message=f"Best model (epoch {ep + 1}, loss {avg:.4f})",
            )
            print(f" ✓ Uploaded best to {HF_REPO}")
        except Exception as e:
            print(f" ⚠ Upload failed: {e}")
|
|
| |
| |
| |
| print("\nSaving final model...") |
| final_path = os.path.join(CHECKPOINT_DIR, "final.pt") |
| weights_path = save_checkpoint(model, opt, sched, step, EPOCHS, best, final_path) |
|
|
| print("Generating final samples...") |
| images = generate_samples(model, SAMPLE_PROMPTS, num_steps=20) |
| save_samples(images, SAMPLE_PROMPTS, step, SAMPLE_DIR) |
|
|
| try: |
| api.upload_file(path_or_fileobj=weights_path, path_in_repo="model.safetensors", repo_id=HF_REPO) |
| config_path = os.path.join(CHECKPOINT_DIR, "config.json") |
| with open(config_path, "w") as f: |
| json.dump(config.__dict__, f, indent=2) |
| api.upload_file(path_or_fileobj=config_path, path_in_repo="config.json", repo_id=HF_REPO) |
| print(f"\n✓ Training complete! https://huggingface.co/{HF_REPO}") |
| except Exception as e: |
| print(f"\n⚠ Final upload failed: {e}") |
|
|
| writer.close() |
| print(f"Best loss: {best:.4f}") |