"""SDXL LoRA training script — run on Google Colab (T4 GPU).

Trains a style LoRA on SDXL using DreamBooth with 15-20 curated images.
The trained weights (.safetensors) can then be used with image_generator_hf.py / image_generator_api.py.

Setup:
1. Open Google Colab with a T4 GPU runtime
2. Upload this script, or copy each section into separate cells
3. Upload your style images to lora_training_data/
4. Add a metadata.jsonl file with one JSON caption entry per image
5. Run all cells in order
6. Download the trained .safetensors from styles/

Dataset structure:
lora_training_data/
    image_001.png
    image_002.jpg
    ...
    metadata.jsonl   # one line per image, e.g.
                     # {"file_name": "image_001.png", "prompt": "a sunset landscape with mountains, in sks style"}
"""
|
|
| import json |
| import subprocess |
| import sys |
| from pathlib import Path |
|
|
|
|
| |
| |
| |
|
|
| |
# DreamBooth trigger token baked into every training caption; inference
# prompts must include "in sks style" to activate the trained LoRA.
TRIGGER_WORD = "sks"
INSTANCE_PROMPT = f"a photo in {TRIGGER_WORD} style"


# Hyperparameters forwarded to the diffusers DreamBooth LoRA script.
# Sized for a single T4 (16 GB): batch 1 with 4-step gradient accumulation
# and fp16 mixed precision.
CONFIG = {
    "base_model": "stabilityai/stable-diffusion-xl-base-1.0",
    "vae": "madebyollin/sdxl-vae-fp16-fix",  # fp16-safe SDXL VAE
    "resolution": 1024,
    "train_batch_size": 1,
    "gradient_accumulation_steps": 4,  # effective batch size of 4
    "learning_rate": 1e-4,
    "lr_scheduler": "constant",
    "lr_warmup_steps": 0,
    "max_train_steps": 1500,  # overridden in __main__ based on dataset size
    "rank": 16,  # LoRA rank — higher captures more detail, larger file
    "snr_gamma": 5.0,  # Min-SNR weighting
    "mixed_precision": "fp16",
    "checkpointing_steps": 500,
    "seed": 42,
}


# Dataset and checkpoints live on Google Drive so they survive Colab
# runtime resets; final weights are exported to a local styles/ folder.
DATASET_DIR = "/content/drive/MyDrive/lora_training_data"
OUTPUT_DIR = "/content/drive/MyDrive/lora_output"
FINAL_WEIGHTS_DIR = "styles"
|
|
|
|
| |
| |
| |
|
|
def install_dependencies():
    """Install training dependencies (run once per Colab session)."""
    # Shallow-clone diffusers to get the DreamBooth training script.
    if not Path("diffusers").exists():
        subprocess.check_call([
            "git", "clone", "--depth", "1",
            "https://github.com/huggingface/diffusers",
        ])

    # Each entry becomes one `pip install -q ...` invocation, run in order:
    # local diffusers, its DreamBooth requirements, runtime deps, then peft.
    install_steps = [
        ["./diffusers"],
        ["-r", "diffusers/examples/dreambooth/requirements.txt"],
        [
            "transformers", "accelerate",
            "bitsandbytes", "safetensors", "Pillow",
        ],
        ["peft>=0.17.0"],
    ]
    for step in install_steps:
        subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", *step])

    print("Dependencies installed.")
|
|
|
|
| |
| |
| |
|
|
def configure_accelerate():
    """Write a single-GPU accelerate config.

    Creates a default accelerate config file so `accelerate launch` works
    without interactive `accelerate config` prompts in Colab.
    """
    # Imported lazily: accelerate is installed by install_dependencies().
    from accelerate.utils import write_basic_config


    write_basic_config()
    print("Accelerate configured for single GPU.")
|
|
|
|
| |
| |
| |
|
|
def verify_dataset(dataset_dir: str = DATASET_DIR) -> int:
    """Verify dataset folder has images + metadata.jsonl (no .txt files).

    Also parses every metadata.jsonl line so malformed captions fail here,
    fast, instead of partway into training.

    Args:
        dataset_dir: Path to folder on Google Drive.

    Returns:
        Number of images found.

    Raises:
        FileNotFoundError: If the folder, images, or metadata.jsonl are missing.
        RuntimeError: If stray .txt caption files are present.
        ValueError: If metadata.jsonl contains a malformed JSON line.
    """
    dataset_path = Path(dataset_dir)
    # Explicit existence check — Path.iterdir on a missing folder raises a
    # less helpful error.
    if not dataset_path.is_dir():
        raise FileNotFoundError(f"Dataset folder not found: {dataset_dir}/")

    image_extensions = {".png", ".jpg", ".jpeg", ".webp", ".bmp"}
    images = [f for f in dataset_path.iterdir() if f.suffix.lower() in image_extensions]
    metadata = dataset_path / "metadata.jsonl"

    if not images:
        raise FileNotFoundError(f"No images found in {dataset_dir}/.")
    if not metadata.exists():
        raise FileNotFoundError(f"metadata.jsonl not found in {dataset_dir}/.")

    # Stray per-image .txt captions mean the dataset was prepared for a
    # different loader; the training script here reads metadata.jsonl only.
    txt_files = list(dataset_path.glob("*.txt"))
    if txt_files:
        raise RuntimeError(
            f"Found .txt files in dataset folder: {[f.name for f in txt_files]}. "
            f"Remove them — only images + metadata.jsonl should be present."
        )

    # Validate that each non-blank metadata line is well-formed JSON.
    with metadata.open(encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            if not line.strip():
                continue  # tolerate trailing/blank lines
            try:
                json.loads(line)
            except json.JSONDecodeError as err:
                raise ValueError(
                    f"metadata.jsonl line {line_no} is not valid JSON: {err}"
                ) from err

    print(f"Dataset OK: {len(images)} images + metadata.jsonl")
    return len(images)
|
|
|
|
| |
| |
| |
|
|
def train(
    dataset_dir: str = DATASET_DIR,
    output_dir: str = OUTPUT_DIR,
    resume: bool = False,
):
    """Launch DreamBooth LoRA training on SDXL.

    Streams the training script's output live into the Colab cell and
    raises if the subprocess exits non-zero.

    Args:
        dataset_dir: Path to prepared dataset.
        output_dir: Where to save checkpoints and final weights.
        resume: If True, resume from the latest checkpoint.

    Raises:
        RuntimeError: If the training subprocess exits with a non-zero code.
    """
    cfg = CONFIG

    # `python -m accelerate.commands.launch` is equivalent to `accelerate launch`
    # but avoids PATH issues in Colab.
    cmd = [
        sys.executable, "-m", "accelerate.commands.launch",
        "diffusers/examples/dreambooth/train_dreambooth_lora_sdxl.py",
        f"--pretrained_model_name_or_path={cfg['base_model']}",
        f"--pretrained_vae_model_name_or_path={cfg['vae']}",
        f"--dataset_name={dataset_dir}",
        "--image_column=image",
        "--caption_column=prompt",
        f"--output_dir={output_dir}",
        f"--resolution={cfg['resolution']}",
        f"--train_batch_size={cfg['train_batch_size']}",
        f"--gradient_accumulation_steps={cfg['gradient_accumulation_steps']}",
        "--gradient_checkpointing",
        "--use_8bit_adam",
        f"--mixed_precision={cfg['mixed_precision']}",
        f"--learning_rate={cfg['learning_rate']}",
        f"--lr_scheduler={cfg['lr_scheduler']}",
        f"--lr_warmup_steps={cfg['lr_warmup_steps']}",
        f"--max_train_steps={cfg['max_train_steps']}",
        f"--rank={cfg['rank']}",
        f"--snr_gamma={cfg['snr_gamma']}",
        f"--instance_prompt={INSTANCE_PROMPT}",
        f"--checkpointing_steps={cfg['checkpointing_steps']}",
        f"--seed={cfg['seed']}",
    ]

    if resume:
        cmd.append("--resume_from_checkpoint=latest")

    print("Starting training...")
    print(f"  Model: {cfg['base_model']}")
    print(f"  Steps: {cfg['max_train_steps']}")
    print(f"  Rank: {cfg['rank']}")
    print(f"  LR: {cfg['learning_rate']}")
    print(f"  Resume: {resume}")
    print()

    # Stream stdout line-by-line for live progress. The `with` block ensures
    # the pipe is closed even if an exception occurs while printing (the
    # original leaked the pipe in that case). stderr is merged into stdout
    # so nothing is lost.
    with subprocess.Popen(
        cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
        bufsize=1, text=True,
    ) as process:
        for line in process.stdout:
            print(line, end="", flush=True)
        process.wait()
    if process.returncode != 0:
        raise RuntimeError(f"Training failed with exit code {process.returncode}")

    print(f"\nTraining complete! Weights saved to {output_dir}/")
|
|
|
|
| |
| |
| |
|
|
def export_weights(
    output_dir: str = OUTPUT_DIR,
    styles_dir: str = FINAL_WEIGHTS_DIR,
    style_name: str = "custom-style",
):
    """Copy trained LoRA weights to the styles directory.

    Looks for final weights first, falls back to latest checkpoint.
    """
    import shutil

    out_path = Path(output_dir)

    # Preferred location: final weights written at the end of training.
    src = out_path / "pytorch_lora_weights.safetensors"

    if not src.exists():
        # Fall back to the newest checkpoint, sorted by step number.
        ckpts = sorted(
            out_path.glob("checkpoint-*"),
            key=lambda p: int(p.name.split("-")[1]),
        )
        if ckpts:
            newest = ckpts[-1]
            # Checkpoints may store weights under either of these names.
            candidates = (
                newest / "pytorch_lora_weights.safetensors",
                newest / "unet" / "adapter_model.safetensors",
            )
            found = next((c for c in candidates if c.exists()), None)
            if found is not None:
                src = found
                print(f"Using checkpoint: {newest.name}")

    if not src.exists():
        raise FileNotFoundError(
            f"No weights found in {output_dir}/. "
            f"Check that training completed or a checkpoint was saved."
        )

    target_dir = Path(styles_dir)
    target_dir.mkdir(parents=True, exist_ok=True)
    target = target_dir / f"{style_name}.safetensors"

    shutil.copy2(src, target)

    size_mb = target.stat().st_size / (1024 * 1024)
    print(f"Exported weights: {target} ({size_mb:.1f} MB)")
    print(f"Download this file and place it in your project's styles/ folder.")
|
|
|
|
| |
| |
| |
|
|
def backup_to_drive(output_dir: str = OUTPUT_DIR):
    """Copy training output to Google Drive for safety.

    Note: If OUTPUT_DIR already points to Drive, this is a no-op.
    """
    import shutil

    drive_path = Path("/content/drive/MyDrive/lora_output")

    # Nothing to copy when the output already lives on Drive.
    if drive_path.resolve() == Path(output_dir).resolve():
        print("Output already on Google Drive — no backup needed.")
        return

    # Mount Drive on demand if it is not yet available in this session.
    if not Path("/content/drive/MyDrive").exists():
        from google.colab import drive
        drive.mount("/content/drive")

    shutil.copytree(output_dir, str(drive_path), dirs_exist_ok=True)
    print(f"Backed up to {drive_path}")
|
|
|
|
| |
| |
| |
|
|
def test_inference(
    output_dir: str = OUTPUT_DIR,
    prompt: str | None = None,
):
    """Generate a test image with the trained LoRA + Hyper-SD to verify quality.

    Uses the same setup as image_generator_hf.py for accurate results.

    Args:
        output_dir: Training output directory containing the LoRA weights.
        prompt: Test prompt; defaults to a landscape prompt using TRIGGER_WORD.

    Returns:
        The generated PIL image (also saved to test_output.png).

    Raises:
        FileNotFoundError: If no trained LoRA weights are found in output_dir.
    """
    import torch
    from diffusers import AutoencoderKL, DDIMScheduler, DiffusionPipeline
    from huggingface_hub import hf_hub_download

    if prompt is None:
        prompt = f"a serene mountain landscape at golden hour, in {TRIGGER_WORD} style"

    # Locate the trained weights up front so a missing file fails with a
    # clear error BEFORE the slow model download/load. (Previously a missing
    # file was handed to load_lora_weights, producing an opaque failure.)
    output_path = Path(output_dir)
    weights_file = output_path / "pytorch_lora_weights.safetensors"
    if not weights_file.exists():
        checkpoints = sorted(
            output_path.glob("checkpoint-*"),
            key=lambda p: int(p.name.split("-")[1]),
        )
        if checkpoints:
            weights_file = checkpoints[-1] / "pytorch_lora_weights.safetensors"
    if not weights_file.exists():
        raise FileNotFoundError(
            f"No trained LoRA weights found in {output_dir}/ — run train() first."
        )

    print("Loading model + LoRA for test inference...")

    # Same fp16-safe VAE used for training.
    vae = AutoencoderKL.from_pretrained(
        CONFIG["vae"], torch_dtype=torch.float16,
    )

    pipe = DiffusionPipeline.from_pretrained(
        CONFIG["base_model"],
        vae=vae,
        torch_dtype=torch.float16,
        variant="fp16",
    ).to("cuda")

    # Hyper-SD acceleration LoRA: enables decent quality at just 8 steps.
    hyper_path = hf_hub_download(
        "ByteDance/Hyper-SD", "Hyper-SDXL-8steps-CFG-lora.safetensors",
    )
    pipe.load_lora_weights(hyper_path, adapter_name="hyper-sd")

    pipe.load_lora_weights(
        str(weights_file.parent),
        weight_name=weights_file.name,
        adapter_name="style",
    )

    # Hyper-SD at its low recommended scale, style LoRA at full strength.
    pipe.set_adapters(
        ["hyper-sd", "style"],
        adapter_weights=[0.125, 1.0],
    )

    # Trailing timestep spacing, per the Hyper-SD few-step recipe.
    pipe.scheduler = DDIMScheduler.from_config(
        pipe.scheduler.config, timestep_spacing="trailing",
    )

    image = pipe(
        prompt=prompt,
        negative_prompt="blurry, low quality, deformed, ugly, text, watermark",
        num_inference_steps=8,
        guidance_scale=5.0,
        height=1344,
        width=768,
    ).images[0]

    image.save("test_output.png")
    print(f"Test image saved to test_output.png")
    print(f"Prompt: {prompt}")

    return image
|
|
|
|
| |
| |
| |
|
|
if __name__ == "__main__":
    # Full pipeline, in order:
    # install → configure → verify dataset → train → backup → export → smoke-test.
    print("=" * 60)
    print("SDXL LoRA Training Pipeline")
    print("=" * 60)

    install_dependencies()

    configure_accelerate()

    # Scale training length with dataset size (~100 steps per image),
    # never dropping below the 1500-step default in CONFIG.
    num_images = verify_dataset()
    steps = max(1500, num_images * 100)
    CONFIG["max_train_steps"] = steps
    print(f"Adjusted training steps to {steps} ({num_images} images × 100)")

    train()

    # Back up checkpoints to Drive before export in case the session drops.
    backup_to_drive()

    export_weights(style_name="custom-style")

    # Quick generation pass with the freshly trained LoRA to verify quality.
    test_inference()

    print("\nDone! Download styles/custom-style.safetensors")
|