| """Generate images using SDXL + Hyper-SD 8-step + style LoRA from registry. |
| |
| Reads segments.json (with prompts from prompt_generator) and generates |
| one 768x1344 (9:16 vertical) image per segment. |
| |
| Pipeline: SDXL base → Hyper-SD 8-step CFG LoRA (speed) → style LoRA (aesthetics) |
| """ |
|
|
| import json |
| from pathlib import Path |
| from typing import Optional |
|
|
| import torch |
| from diffusers import AutoencoderKL, DDIMScheduler, DiffusionPipeline |
| from huggingface_hub import hf_hub_download |
|
|
| from src.styles import get_style |
|
|
| |
| |
| |
|
|
# Hugging Face model identifiers.
BASE_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"
VAE_MODEL = "madebyollin/sdxl-vae-fp16-fix"  # fp16-safe SDXL VAE (per repo name)
HYPER_SD_REPO = "ByteDance/Hyper-SD"
HYPER_SD_FILE = "Hyper-SDXL-8steps-CFG-lora.safetensors"

# Output geometry: 768x1344 is a 9:16 vertical frame (see module docstring).
WIDTH = 768
HEIGHT = 1344
NUM_STEPS = 8  # matches the 8-step Hyper-SD LoRA above
GUIDANCE_SCALE = 5.0

# Adapter weight applied to the Hyper-SD LoRA in set_adapters().
HYPER_SD_WEIGHT = 0.125
|
|
|
| def _get_device_and_dtype(): |
| """Detect best available device and matching dtype.""" |
| if torch.cuda.is_available(): |
| return "cuda", torch.float16 |
| if torch.backends.mps.is_available(): |
| return "mps", torch.float32 |
| return "cpu", torch.float32 |
|
|
|
|
def load_pipeline(style_name: str = "Warm Sunset"):
    """Build the SDXL pipeline with the Hyper-SD speed LoRA plus a style LoRA.

    Args:
        style_name: Key in STYLES registry. Use "None" for no style LoRA.

    Returns:
        Configured DiffusionPipeline ready for inference.
    """
    style = get_style(style_name)
    device, dtype = _get_device_and_dtype()
    print(f"Loading SDXL pipeline on {device} ({dtype})...")

    # Replacement VAE — the repo name suggests it is the fp16-safe variant.
    vae = AutoencoderKL.from_pretrained(VAE_MODEL, torch_dtype=dtype)

    kwargs = dict(torch_dtype=dtype, vae=vae, use_safetensors=True)
    if dtype == torch.float16:
        # Half precision gets the fp16 checkpoint variant.
        kwargs["variant"] = "fp16"
    pipe = DiffusionPipeline.from_pretrained(BASE_MODEL, **kwargs)

    # Speed LoRA first, then the aesthetic style LoRA on top of it.
    pipe.load_lora_weights(
        hf_hub_download(HYPER_SD_REPO, HYPER_SD_FILE), adapter_name="hyper-sd"
    )
    _apply_style(pipe, style)

    # DDIM with "trailing" timestep spacing — the setup Hyper-SD is tuned
    # for (see model card); confirm if the LoRA version changes.
    pipe.scheduler = DDIMScheduler.from_config(
        pipe.scheduler.config, timestep_spacing="trailing"
    )
    pipe.to(device)

    # Reduce memory pressure on Apple silicon.
    if device == "mps":
        pipe.enable_attention_slicing()
        pipe.enable_vae_slicing()

    print("Pipeline ready.")
    return pipe
|
|
|
|
def _apply_style(pipe, style: dict):
    """Attach the style LoRA described by *style* and set adapter weights."""
    source = style["source"]
    if source is None:
        # No style: keep only the Hyper-SD speed adapter active.
        pipe.set_adapters(["hyper-sd"], adapter_weights=[HYPER_SD_WEIGHT])
        print("No style LoRA — using base SDXL + Hyper-SD.")
        return

    # Resolve the source relative to the project root (one level above src/).
    root = Path(__file__).resolve().parent.parent
    local = (root / source).resolve()
    if local.is_file():
        # Local .safetensors: diffusers wants the directory plus a weight name.
        pipe.load_lora_weights(
            str(local.parent), adapter_name="style", weight_name=local.name
        )
    elif style["weight_name"]:
        # Hub repo with an explicit weight file.
        pipe.load_lora_weights(
            source, adapter_name="style", weight_name=style["weight_name"]
        )
    else:
        # Hub repo; let diffusers pick the weight file.
        pipe.load_lora_weights(source, adapter_name="style")

    pipe.set_adapters(
        ["hyper-sd", "style"],
        adapter_weights=[HYPER_SD_WEIGHT, style["weight"]],
    )
    print(f"Loaded style LoRA: {source}")
|
|
|
|
def switch_style(pipe, style_name: str):
    """Swap the active style LoRA at runtime.

    Unloads all LoRAs then reloads Hyper-SD + new style.
    """
    new_style = get_style(style_name)

    # unload removes every adapter, so Hyper-SD must be re-attached too.
    pipe.unload_lora_weights()
    pipe.load_lora_weights(
        hf_hub_download(HYPER_SD_REPO, HYPER_SD_FILE), adapter_name="hyper-sd"
    )
    _apply_style(pipe, new_style)
    print(f"Switched to style: {style_name}")
|
|
|
|
def generate_image(
    pipe,
    prompt: str,
    negative_prompt: str = "",
    seed: Optional[int] = None,
) -> "PIL.Image.Image":
    """Generate a single 768x1344 vertical image."""
    # A CPU generator keeps a given seed reproducible across devices.
    gen = (
        torch.Generator(device="cpu").manual_seed(seed)
        if seed is not None
        else None
    )
    result = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=NUM_STEPS,
        guidance_scale=GUIDANCE_SCALE,
        height=HEIGHT,
        width=WIDTH,
        generator=gen,
    )
    return result.images[0]
|
|
|
|
| def generate_all( |
| segments: list[dict], |
| pipe, |
| output_dir: str | Path, |
| trigger_word: str = "", |
| seed: int = 42, |
| progress_callback=None, |
| ) -> list[Path]: |
| """Generate images for all segments. |
| |
| Args: |
| segments: List of segment dicts (with 'prompt' and 'negative_prompt'). |
| pipe: Loaded DiffusionPipeline. |
| output_dir: Directory to save images. |
| trigger_word: LoRA trigger word appended to prompts. |
| seed: Base seed (incremented per segment for variety). |
| |
| Returns: |
| List of saved image paths. |
| """ |
| output_dir = Path(output_dir) |
| output_dir.mkdir(parents=True, exist_ok=True) |
|
|
| paths = [] |
| for seg in segments: |
| idx = seg["segment"] |
| path = output_dir / f"segment_{idx:03d}.png" |
|
|
| if path.exists(): |
| print(f" Segment {idx}/{len(segments)}: already exists, skipping") |
| paths.append(path) |
| continue |
|
|
| prompt = seg["prompt"] |
| if trigger_word: |
| prompt = f"{trigger_word} style, {prompt}" |
| neg = seg.get("negative_prompt", "") |
|
|
| print(f" Segment {idx}/{len(segments)}: generating...") |
| image = generate_image(pipe, prompt, neg, seed=seed + idx) |
|
|
| path = output_dir / f"segment_{idx:03d}.png" |
| image.save(path) |
| paths.append(path) |
| print(f" Saved {path.name}") |
| if progress_callback: |
| progress_callback(idx, len(segments)) |
|
|
| return paths |
|
|
|
|
def run(
    data_dir: str | Path,
    style_name: str = "Warm Sunset",
    seed: int = 42,
    progress_callback=None,
) -> list[Path]:
    """End-to-end image generation: load model, read segments, generate, save.

    Args:
        data_dir: Run directory containing segments.json (e.g. data/Gone/run_001/).
        style_name: Style from the registry (see src/styles.py).
        seed: Base random seed.

    Returns:
        List of saved image paths.
    """
    run_dir = Path(data_dir)
    style = get_style(style_name)

    segments = json.loads((run_dir / "segments.json").read_text())

    pipe = load_pipeline(style_name)
    images_dir = run_dir / "images"
    paths = generate_all(
        segments, pipe, images_dir, style["trigger"], seed, progress_callback
    )

    print(f"\nGenerated {len(paths)} images in {images_dir}")
    return paths
|
|
|
|
if __name__ == "__main__":
    import sys

    # Require at least a data directory; style name is optional.
    if len(sys.argv) < 2:
        print("Usage: python -m src.image_generator_hf <data_dir> [style_name]")
        print(' e.g. python -m src.image_generator_hf data/Gone/run_001 "Warm Sunset"')
        sys.exit(1)

    chosen_style = "Warm Sunset" if len(sys.argv) <= 2 else sys.argv[2]
    run(sys.argv[1], style_name=chosen_style)
|
|