| """ |
| SD15 Flow-Matching Trainer - ControlNet Pose Edition |
| Author: AbstractPhil |
| |
| Trains Lune on controlnet pose dataset with transparent backgrounds. |
| |
| License: MIT |
| """ |
|
|
| import os |
| import json |
| import datetime |
| import random |
| from dataclasses import dataclass, asdict, field |
| from tqdm.auto import tqdm |
| import matplotlib.pyplot as plt |
|
|
| import torch |
| import torch.nn.functional as F |
| from torch.utils.tensorboard import SummaryWriter |
| from torch.utils.data import DataLoader |
|
|
| import datasets |
| from diffusers import UNet2DConditionModel, AutoencoderKL |
| from transformers import CLIPTextModel, CLIPTokenizer |
| from huggingface_hub import HfApi, create_repo, hf_hub_download |
|
|
|
|
@dataclass
class TrainConfig:
    """All knobs for the SD1.5 flow-matching ControlNet-pose training run.

    Serialized to config.json (via asdict) at the start of every run.
    """

    # --- Paths / source checkpoint ---
    output_dir: str = "./outputs"  # local root; each run writes into a timestamped subdir
    model_repo: str = "AbstractPhil/sd15-flow-lune"  # HF repo holding the starting checkpoint
    checkpoint_filename: str = "sd15_flow_pretrain_pose_controlnet_t500_700_s8312.pt"  # weights to start from
    dataset_name: str = "AbstractPhil/CN_pose3D_V7_512"  # HF dataset of pose images
    use_masks: bool = True  # weight the loss by a per-pixel mask when the column is present
    mask_column: str = "mask"  # dataset column holding the mask image

    # --- HuggingFace Hub upload ---
    hf_repo_id: str = "AbstractPhil/sd15-flow-lune"  # destination repo for checkpoints
    upload_to_hub: bool = True  # best-effort; auto-disabled if repo setup fails

    # Run label used in checkpoint filenames and hub upload paths.
    run_name: str = "pretrain_pose_controlnet_v7_v10_t400_600"

    # Restore optimizer/scheduler state from the loaded checkpoint (if present).
    continue_from_checkpoint: bool = False

    seed: int = 42
    batch_size: int = 64

    # --- Optimization ---
    base_lr: float = 2e-6  # AdamW learning rate (reached after warmup)
    shift: float = 2.5  # flow-matching timestep shift factor
    dropout: float = 0.1  # probability of zeroing a sample's text conditioning
    min_snr_gamma: float = 5.0  # SNR clamp for min-SNR loss weighting

    # Timestep sampling window (out of 1000), sampled before the shift is applied.
    min_timestep: float = 400.0
    max_timestep: float = 600.0

    # --- Schedule / checkpointing ---
    num_train_epochs: int = 1
    warmup_epochs: int = 1  # linear LR warmup length (skipped when resuming)
    checkpointing_steps: int = 2500  # save/upload every N optimizer steps
    num_workers: int = 0  # DataLoader worker processes

    # SD1.5 VAE latent scaling factor.
    vae_scale: float = 0.18215

    # --- Caption preprocessing (consumed by preprocess_caption) ---
    delimiter: str = ","  # token separator within captions
    preserved_count: int = 2  # leading tokens kept in place when shuffling
    remove_these: list[str] = field(default_factory=lambda: [
        "simple background",
        "white background"])  # tokens stripped from every caption
    prepend_prompt: str = "doll"  # token forced onto the front of every caption
    append_prompt: str = "transparent background"  # token forced onto the end
    shuffle_prompt: bool = True  # shuffle non-preserved tokens per sample
|
|
|
|
def preprocess_caption(text: str, config: TrainConfig) -> str:
    """Normalize a raw caption into the training prompt format.

    Steps: lowercase and convert "." into the configured delimiter, collapse
    duplicate delimiters and double spaces, trim stray edge delimiters,
    prepend config.prepend_prompt, optionally drop config.remove_these tokens
    and shuffle everything past the first config.preserved_count tokens, then
    append config.append_prompt.  Empty/missing captions reduce to just the
    append token.
    """
    if not text:
        # Nothing to work with: fall back to the append token alone.
        return config.append_prompt or ""

    sep = config.delimiter
    joiner = sep + " "

    # Lowercase, convert sentence periods into delimiters, trim.
    cleaned = text.lower().replace(".", sep).strip()

    # Collapse runs of delimiters and double spaces left by the replacement.
    while sep + sep in cleaned:
        cleaned = cleaned.replace(sep + sep, sep)
    while "  " in cleaned:
        cleaned = cleaned.replace("  ", " ")

    cleaned = cleaned.strip()

    # Drop a single stray delimiter character at either end.
    if cleaned.startswith(sep):
        cleaned = cleaned[1:].strip()
    if cleaned.endswith(sep):
        cleaned = cleaned[:-1].strip()

    # Force the prepend token onto the front.
    if config.prepend_prompt:
        if cleaned:
            cleaned = f"{config.prepend_prompt}{sep} {cleaned}"
        else:
            cleaned = config.prepend_prompt

    def tokenize(s):
        # Split on the delimiter, trimming whitespace and dropping empties.
        return [piece.strip() for piece in s.split(sep) if piece.strip()]

    if config.shuffle_prompt and cleaned:
        tokens = tokenize(cleaned)
        if config.remove_these:
            tokens = [tok for tok in tokens if tok not in config.remove_these]
        # Keep the first preserved_count tokens in place; shuffle the rest.
        fixed = tokens[:config.preserved_count]
        movable = tokens[config.preserved_count:]
        random.shuffle(movable)
        cleaned = joiner.join(fixed + movable)
    elif config.remove_these and cleaned:
        # No shuffling requested: still strip unwanted tokens.
        kept = [tok for tok in tokenize(cleaned) if tok not in config.remove_these]
        cleaned = joiner.join(kept)

    # Force the append token onto the end.
    if config.append_prompt:
        if cleaned:
            cleaned = f"{cleaned}{sep} {config.append_prompt}"
        else:
            cleaned = config.append_prompt

    return cleaned
|
|
|
|
def load_student_unet(repo_id: str, filename: str, device="cuda"):
    """Load the student UNet weights from a HuggingFace checkpoint.

    Downloads `filename` from `repo_id`, instantiates the stock SD1.5 UNet
    architecture, and loads the checkpoint's "student" state dict into it
    (stripping any "unet." key prefix the checkpoint may carry).

    Args:
        repo_id: HuggingFace model repo to download from.
        filename: Checkpoint file name inside the repo.
        device: Device the loaded UNet is moved to.

    Returns:
        (unet, checkpoint): the UNet on `device`, plus the raw checkpoint dict
        so the caller can optionally restore optimizer/scheduler state.
    """
    # BUGFIX: the original message printed a literal "(unknown)" instead of the filename.
    print(f"Downloading checkpoint from {repo_id}/{filename}...")
    checkpoint_path = hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        repo_type="model"
    )

    # NOTE(review): torch.load unpickles arbitrary objects -- only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")

    print("Loading SD1.5 UNet architecture...")
    unet = UNet2DConditionModel.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        subfolder="unet",
        torch_dtype=torch.float32
    )

    student_state_dict = checkpoint["student"]

    # Some checkpoints prefix every key with "unet." -- strip it so the keys
    # match the bare UNet2DConditionModel state dict.
    cleaned_dict = {
        (key[5:] if key.startswith("unet.") else key): value
        for key, value in student_state_dict.items()
    }

    # strict=False tolerates architecture drift; surface what didn't line up
    # instead of silently ignoring it.
    missing, unexpected = unet.load_state_dict(cleaned_dict, strict=False)
    if missing:
        print(f"⚠ {len(missing)} missing keys when loading student state dict")
    if unexpected:
        print(f"⚠ {len(unexpected)} unexpected keys when loading student state dict")

    print(f"✓ Loaded UNet from step {checkpoint.get('gstep', 'unknown')}")

    return unet.to(device), checkpoint
|
|
|
|
def train(config: TrainConfig):
    """Run one flow-matching fine-tuning pass over the pose dataset.

    Loads a frozen SD1.5 VAE + CLIP text encoder, streams the dataset through
    a collate_fn that encodes images/masks/prompts on the fly, and trains the
    student UNet on a velocity (noise - latents) objective restricted to the
    configured timestep window.  Logs to TensorBoard, saves checkpoints every
    config.checkpointing_steps, and optionally uploads them to the HF Hub.
    """
    device = "cuda"
    # TF32 matmuls trade a little precision for speed on Ampere+ GPUs.
    torch.backends.cuda.matmul.allow_tf32 = True

    torch.manual_seed(config.seed)
    torch.cuda.manual_seed(config.seed)

    # Per-run timestamped output dir doubles as the TensorBoard log dir.
    date_time = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    real_output_dir = os.path.join(config.output_dir, date_time)
    os.makedirs(real_output_dir, exist_ok=True)
    t_writer = SummaryWriter(log_dir=real_output_dir, flush_secs=60)

    # Hub setup is best-effort: any failure simply disables uploading.
    hf_api = None
    if config.upload_to_hub:
        try:
            hf_api = HfApi()
            create_repo(
                repo_id=config.hf_repo_id,
                repo_type="model",
                exist_ok=True,
                private=False
            )
            print(f"✓ HuggingFace repo ready: {config.hf_repo_id}")
        except Exception as e:
            print(f"⚠ Hub upload disabled: {e}")
            config.upload_to_hub = False

    # Persist the exact config used for this run.
    config_path = os.path.join(real_output_dir, "config.json")
    with open(config_path, "w") as f:
        json.dump(asdict(config), f, indent=2)

    if config.upload_to_hub:
        hf_api.upload_file(
            path_or_fileobj=config_path,
            path_in_repo="config.json",
            repo_id=config.hf_repo_id,
            repo_type="model"
        )

    # Frozen SD1.5 VAE (image <-> latent) and CLIP text encoder; never trained here.
    print("\nLoading SD1.5 VAE and CLIP...")
    vae = AutoencoderKL.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        subfolder="vae",
        torch_dtype=torch.float32
    ).to(device)
    vae.requires_grad_(False)
    vae.eval()

    tokenizer = CLIPTokenizer.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        subfolder="tokenizer"
    )
    text_encoder = CLIPTextModel.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        subfolder="text_encoder",
        torch_dtype=torch.float32
    ).to(device)
    text_encoder.requires_grad_(False)
    text_encoder.eval()

    print("✓ VAE and CLIP loaded")

    print(f"\nLoading dataset: {config.dataset_name}")
    train_dataset = datasets.load_dataset(
        config.dataset_name,
        split="train"
    )

    print(f"✓ Loaded {len(train_dataset):,} images")
    print(f" Columns: {train_dataset.column_names}")

    # Schedule bookkeeping.  Floor division drops the trailing partial batch
    # from the count; the step guard in the loop below is what ends training.
    steps_per_epoch = len(train_dataset) // config.batch_size
    total_steps = steps_per_epoch * config.num_train_epochs
    warmup_steps = steps_per_epoch * config.warmup_epochs

    print(f"\nTraining schedule:")
    print(f" Total images: {len(train_dataset):,}")
    print(f" Batch size: {config.batch_size}")
    print(f" Steps per epoch: {steps_per_epoch:,}")
    print(f" Total epochs: {config.num_train_epochs}")
    print(f" Total steps: {total_steps:,}")
    print(f" Warmup steps: {warmup_steps:,}")
    print(f"\nTimestep range:")
    print(f" Min timestep: {config.min_timestep}")
    print(f" Max timestep: {config.max_timestep}")
    print(f" Training on: {config.max_timestep - config.min_timestep} timestep range")
    print(f"\nPrompt preprocessing:")
    print(f" Shuffle: {config.shuffle_prompt}")
    print(f" Preserved tokens: {config.preserved_count}")
    print(f" Prepend: '{config.prepend_prompt}'")
    print(f" Append: '{config.append_prompt}'")
    print(f" Remove: {config.remove_these}")

    @torch.no_grad()
    def collate_fn(examples):
        """Encode images, masks (optional), and prompts at runtime"""
        import numpy as np

        images = []
        masks = []
        prompts = []
        image_ids = []

        for idx, ex in enumerate(examples):
            # Image -> channels-first float tensor scaled to [-1, 1] for the VAE.
            img = ex['image'].convert('RGB')
            img = torch.tensor(np.array(img)).permute(2, 0, 1).float() / 255.0
            img = img * 2.0 - 1.0
            images.append(img)

            # Optional grayscale mask scaled to [0, 1].
            if config.use_masks and config.mask_column in ex:
                mask = ex[config.mask_column].convert('L')
                mask = torch.tensor(np.array(mask)).float() / 255.0
                masks.append(mask)

            raw_text = ex['text']
            processed_prompt = preprocess_caption(raw_text, config)
            prompts.append(processed_prompt)
            # NOTE(review): idx is the position within this batch, not a stable dataset id.
            image_ids.append(idx)

        images = torch.stack(images).to(device)

        # VAE-encode and apply the SD latent scaling factor.
        latents = vae.encode(images).latent_dist.sample()
        latents = latents * config.vae_scale

        if config.use_masks and masks:
            masks = torch.stack(masks).to(device)
            # Downsample pixel-space masks to the latent resolution.
            masks_downsampled = F.interpolate(
                masks.unsqueeze(1),
                size=latents.shape[-2:],
                mode='nearest'
            ).squeeze(1)
        else:
            # No masks: weight every latent position equally.
            masks_downsampled = torch.ones(
                (latents.shape[0], latents.shape[2], latents.shape[3]),
                dtype=torch.float32
            )

        # Tokenize and encode prompts with the frozen CLIP text encoder.
        text_inputs = tokenizer(
            prompts,
            padding="max_length",
            max_length=tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt"
        ).to(device)

        encoder_hidden_states = text_encoder(text_inputs.input_ids)[0]

        # Return tensors on CPU; the training step moves them back to the GPU.
        return (
            latents.cpu(),
            masks_downsampled.cpu(),
            encoder_hidden_states.cpu(),
            image_ids,
            prompts
        )

    train_dataloader = DataLoader(
        dataset=train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=config.num_workers,
        pin_memory=True
    )

    # Student UNet: loaded from the checkpoint and fully trainable.
    print(f"\nLoading model from HuggingFace...")
    unet, checkpoint = load_student_unet(config.model_repo, config.checkpoint_filename, device=device)
    unet.requires_grad_(True)
    unet.train()

    optimizer = torch.optim.AdamW(
        unet.parameters(),
        lr=config.base_lr,
        betas=(0.9, 0.999),
        weight_decay=0.01,
        eps=1e-8
    )

    # Resumed runs use a constant LR (no warmup); fresh runs warm up linearly.
    if config.continue_from_checkpoint:
        scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer,
            lr_lambda=lambda step: 1.0
        )
    else:
        def get_lr_scale(step):
            # Linear warmup from 0 to base_lr over warmup_steps, then flat.
            if step < warmup_steps:
                return step / warmup_steps
            return 1.0

        scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer,
            lr_lambda=get_lr_scale
        )

    # Optionally restore optimizer/scheduler state and the global step counter.
    # NOTE(review): LambdaLR.load_state_dict restores counters, not the lambda itself.
    start_step = 0

    if config.continue_from_checkpoint:
        if "opt" in checkpoint and "scheduler" in checkpoint:
            optimizer.load_state_dict(checkpoint["opt"])
            scheduler.load_state_dict(checkpoint["scheduler"])
            start_step = checkpoint.get("gstep", 0)
            print(f"✓ Resumed optimizer and scheduler from step {start_step}")
            print(f" Will train for {config.num_train_epochs} more epoch(s) = {total_steps:,} additional steps")
        else:
            print("⚠ No optimizer/scheduler state in checkpoint, starting fresh")
    else:
        print("✓ Starting with fresh optimizer (no state loaded)")

    global_step = start_step
    end_step = start_step + total_steps
    # Per-sample history feeding the TensorBoard scatter plot and the
    # trained-images manifest saved with each checkpoint.
    train_logs = {
        "train_step": [],
        "train_loss": [],
        "train_timestep": [],
        "trained_images": []
    }

    def get_prediction(batch, log_to=None):
        """One training forward pass: build flow-matching targets, run the UNet, return the mean loss."""
        latents, masks, encoder_hidden_states, ids, prompts = batch

        latents = latents.to(dtype=torch.float32, device=device)
        if config.use_masks:
            masks = masks.to(dtype=torch.float32, device=device)
        encoder_hidden_states = encoder_hidden_states.to(dtype=torch.float32, device=device)

        batch_size = latents.shape[0]

        # Conditioning dropout: zero the text embeddings for a random subset of
        # the batch (presumably to enable classifier-free guidance at inference).
        dropout_mask = torch.rand(batch_size, device=device) < config.dropout
        encoder_hidden_states = encoder_hidden_states.clone()
        encoder_hidden_states[dropout_mask] = 0

        # Sample sigma uniformly inside the configured window (timestep / 1000).
        min_sigma = config.min_timestep / 1000.0
        max_sigma = config.max_timestep / 1000.0

        sigmas = torch.rand(batch_size, device=device)
        sigmas = min_sigma + sigmas * (max_sigma - min_sigma)

        # Timestep shift; note the shifted values land outside the raw
        # [min_timestep, max_timestep] window printed at startup.
        sigmas = (config.shift * sigmas) / (1 + (config.shift - 1) * sigmas)
        timesteps = sigmas * 1000
        sigmas = sigmas[:, None, None, None]  # broadcast over (C, H, W)

        # Flow-matching interpolation x_t = sigma*noise + (1-sigma)*x0,
        # with the velocity target noise - x0.
        noise = torch.randn_like(latents)
        noisy_latents = noise * sigmas + latents * (1 - sigmas)
        target = noise - latents

        pred = unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]

        # Per-pixel MSE averaged over channels -> shape (B, H, W).
        loss = F.mse_loss(pred, target, reduction="none")
        loss = loss.mean(dim=1)

        # Min-SNR weighting: clamp SNR at min_snr_gamma, normalize by SNR.
        # NOTE(review): .squeeze() collapses (B,1,1,1) to (B), but would yield
        # a 0-dim tensor if batch_size were 1.
        snr = ((1 - sigmas.squeeze()) ** 2) / (sigmas.squeeze() ** 2 + 1e-8)
        snr_weight = torch.minimum(snr, torch.ones_like(snr) * config.min_snr_gamma) / snr

        # Extra 1/(snr + 1) factor further downweights low-noise samples.
        snr_weight = snr_weight / (snr + 1)
        snr_weight = snr_weight[:, None, None]

        loss = loss * snr_weight

        if config.use_masks:
            # Masked mean: only masked-in latent positions contribute.
            masked_loss = loss * masks
            loss_per_sample = masked_loss.sum(dim=[1, 2]) / (masks.sum(dim=[1, 2]) + 1e-8)
        else:
            loss_per_sample = loss.mean(dim=[1, 2])

        # Optional per-sample logging into train_logs.
        if log_to is not None:
            for i in range(batch_size):
                log_to["train_step"].append(global_step)
                log_to["train_loss"].append(loss_per_sample[i].item())
                log_to["train_timestep"].append(timesteps[i].item())
                log_to["trained_images"].append({
                    "step": global_step,
                    "id": ids[i],
                    "prompt": prompts[i]
                })

        return loss_per_sample.mean()

    def plot_logs(log_dict):
        """Scatter per-sample loss vs. timestep, colored by training step (log-scale y)."""
        plt.figure(figsize=(10, 6))
        plt.scatter(
            log_dict["train_timestep"],
            log_dict["train_loss"],
            s=3,
            c=log_dict["train_step"],
            marker=".",
            cmap='cool'
        )
        plt.xlabel("timestep")
        plt.ylabel("loss")
        plt.yscale("log")
        plt.colorbar(label="step")

    def save_checkpoint(step, relative_epoch):
        """Save the UNet (diffusers dir + resumable .pt) plus metadata; optionally upload to the hub."""
        checkpoint_path = os.path.join(real_output_dir, f"{config.run_name}_checkpoint-{step:08}")
        os.makedirs(checkpoint_path, exist_ok=True)

        # Diffusers-format copy for easy from_pretrained loading.
        unet.save_pretrained(
            os.path.join(checkpoint_path, "unet"),
            safe_serialization=True
        )

        # Monolithic .pt with everything needed to resume training.
        pt_filename = f"sd15_flow_{config.run_name}_s{step}.pt"
        pt_path = os.path.join(checkpoint_path, pt_filename)

        torch.save({
            "cfg": asdict(config),
            "student": unet.state_dict(),
            "opt": optimizer.state_dict(),
            "scheduler": scheduler.state_dict(),
            "gstep": step,
            "relative_epoch": relative_epoch
        }, pt_path)

        # Manifest of every (step, id, prompt) trained so far this run.
        metadata = {
            "step": step,
            "relative_epoch": relative_epoch,
            "trained_images": train_logs["trained_images"]
        }
        metadata_path = os.path.join(checkpoint_path, "trained_images.json")
        with open(metadata_path, "w") as f:
            json.dump(metadata, f, indent=2)

        print(f"✓ Checkpoint saved at step {step} (relative epoch {relative_epoch})")

        # Upload is best-effort; training continues on failure.
        if config.upload_to_hub and hf_api is not None:
            try:
                hf_api.upload_file(
                    path_or_fileobj=pt_path,
                    path_in_repo=pt_filename,
                    repo_id=config.hf_repo_id,
                    repo_type="model"
                )
                hf_api.upload_folder(
                    folder_path=os.path.join(checkpoint_path, "unet"),
                    path_in_repo=f"{config.run_name}/checkpoint-{step:08}/unet",
                    repo_id=config.hf_repo_id,
                    repo_type="model"
                )
                hf_api.upload_file(
                    path_or_fileobj=metadata_path,
                    path_in_repo=f"{config.run_name}/checkpoint-{step:08}/trained_images.json",
                    repo_id=config.hf_repo_id,
                    repo_type="model"
                )
                print(f"✓ Uploaded to hub: {config.hf_repo_id}")
            except Exception as e:
                print(f"⚠ Upload failed: {e}")

    print("\nStarting training...")
    progress_bar = tqdm(total=total_steps, initial=0)

    epoch = 0
    # Outer loop re-iterates the dataloader until the step budget is spent.
    while global_step < end_step:
        epoch += 1
        for batch in train_dataloader:
            if global_step >= end_step:
                break

            loss = get_prediction(batch, log_to=train_logs)
            t_writer.add_scalar("train/loss", loss.item(), global_step)
            t_writer.add_scalar("train/lr", scheduler.get_last_lr()[0], global_step)

            # Timestep statistics for the most recent batch.
            if len(train_logs["train_timestep"]) > 0:
                recent_timesteps = train_logs["train_timestep"][-config.batch_size:]
                t_writer.add_scalar("train/mean_timestep", sum(recent_timesteps) / len(recent_timesteps), global_step)
                t_writer.add_scalar("train/min_timestep", min(recent_timesteps), global_step)
                t_writer.add_scalar("train/max_timestep", max(recent_timesteps), global_step)

            loss.backward()

            # Clip to unit grad norm and log the pre-clip norm.
            grad_norm = torch.nn.utils.clip_grad_norm_(unet.parameters(), 1.0)
            t_writer.add_scalar("train/grad_norm", grad_norm.item(), global_step)

            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

            progress_bar.update(1)
            progress_bar.set_postfix({
                "epoch": epoch,
                "loss": f"{loss.item():.4f}",
                "lr": f"{scheduler.get_last_lr()[0]:.2e}",
                "gstep": global_step
            })
            global_step += 1

            # Refresh the loss-vs-timestep scatter in TensorBoard every 100 steps.
            if global_step % 100 == 0:
                plot_logs(train_logs)
                t_writer.add_figure("train_loss", plt.gcf(), global_step)
                plt.close()

            if global_step % config.checkpointing_steps == 0:
                save_checkpoint(global_step, epoch)

    # Final checkpoint after the step budget is exhausted.
    save_checkpoint(global_step, epoch)

    print("\n✅ Training complete!")
|
|
|
|
if __name__ == "__main__":
    # Run a training pass with the default configuration.
    train(TrainConfig())