| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
# Notebook shell magic (IPython/Jupyter `!` syntax; not valid in a plain .py
# script): quietly install the Hugging Face packages this notebook relies on.
!pip install -q diffusers transformers accelerate safetensors
|
|
| import torch |
| import gc |
| from huggingface_hub import hf_hub_download |
| from diffusers import UNet2DConditionModel, AutoencoderKL |
| from transformers import CLIPTextModel, CLIPTokenizer |
| from safetensors.torch import load_file |
| from PIL import Image |
| import numpy as np |
| import json |
|
|
# Free memory left over from a previous run of this notebook. Collect Python
# garbage FIRST so tensors held only by unreachable objects are released, then
# ask PyTorch to hand its now-unused cached CUDA blocks back to the driver.
# The original order (empty_cache before collect) left memory that
# gc.collect() freed sitting in PyTorch's cache.
gc.collect()
torch.cuda.empty_cache()
|
|
| |
| |
| |
# Run on the GPU when one is available; otherwise fall back to the CPU so the
# notebook still works (slowly) instead of failing on the first .to("cuda").
# Half precision is used only on CUDA; many CPU kernels have limited or slow
# float16 support, so the CPU path uses float32.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
|
|
# Hugging Face Hub repo holding the Lune UNet checkpoints.
LUNE_REPO = "AbstractPhil/sd15-flow-lune-flux"

# Directory of the specific UNet checkpoint used below; the weights and the
# config are both fetched from it.
_LUNE_UNET_DIR = "flux_t2_6_pose_t4_6_port_t1_4/checkpoint-00018765/unet"
LUNE_WEIGHTS = f"{_LUNE_UNET_DIR}/diffusion_pytorch_model.safetensors"
LUNE_CONFIG = f"{_LUNE_UNET_DIR}/config.json"
|
|
| |
| |
| |
# CLIP ViT-L/14 tokenizer + text encoder, used by encode_prompt() below.
_CLIP_MODEL_ID = "openai/clip-vit-large-patch14"

print("Loading CLIP...")
clip_tok = CLIPTokenizer.from_pretrained(_CLIP_MODEL_ID)
clip_enc = CLIPTextModel.from_pretrained(_CLIP_MODEL_ID, torch_dtype=DTYPE)
clip_enc = clip_enc.to(DEVICE).eval()
|
|
print("Loading VAE...")
# Stable Diffusion 1.5 VAE, used to decode the sampled latents to pixels.
_SD15_REPO = "stable-diffusion-v1-5/stable-diffusion-v1-5"
vae = AutoencoderKL.from_pretrained(_SD15_REPO, subfolder="vae", torch_dtype=DTYPE)
vae = vae.to(DEVICE)
vae.eval()
|
|
| |
| |
| |
print("\nLoading Lune...")

# Fetch the UNet config from the Hub and build the model architecture from it.
config_path = hf_hub_download(repo_id=LUNE_REPO, filename=LUNE_CONFIG)
with open(config_path, "r", encoding="utf-8") as f:
    lune_config = json.load(f)

# Diagnostic: show the prediction_type key if this config defines one.
print(f" prediction_type: {lune_config.get('prediction_type', 'NOT SET')}")

# Move and cast in one .to() call rather than materialising the whole
# full-precision model on the device before casting it.
unet = UNet2DConditionModel.from_config(lune_config).to(DEVICE, dtype=DTYPE).eval()

# Load the checkpoint weights into the freshly built model.
weights_path = hf_hub_download(repo_id=LUNE_REPO, filename=LUNE_WEIGHTS)
state_dict = load_file(weights_path)
# strict=False tolerates key mismatches, so keep the report it returns
# instead of silently sampling with partially (or entirely) random weights.
load_result = unet.load_state_dict(state_dict, strict=False)

# Drop the host-side copy of the weights now that they live in `unet`.
del state_dict
gc.collect()

if load_result.missing_keys:
    if len(load_result.missing_keys) == len(unet.state_dict()):
        # Nothing matched (e.g. a key-prefix mismatch): sampling would be noise.
        raise RuntimeError(
            "No UNet weights were loaded from the checkpoint; every model key "
            f"is missing (e.g. {load_result.missing_keys[:3]})."
        )
    print(
        f" WARNING: {len(load_result.missing_keys)} UNet keys missing from "
        f"checkpoint, e.g. {load_result.missing_keys[:3]}"
    )
if load_result.unexpected_keys:
    print(
        f" WARNING: {len(load_result.unexpected_keys)} checkpoint keys not used "
        f"by the UNet, e.g. {load_result.unexpected_keys[:3]}"
    )

# Inference only: freeze every parameter.
unet.requires_grad_(False)

print("β Lune ready!")
|
|
| |
| |
| |
def shift_sigma(sigma: torch.Tensor, shift: float = 3.0) -> torch.Tensor:
    """Warp noise levels with the timestep shift (same as trainer).

    Computes ``shift * sigma / (1 + (shift - 1) * sigma)`` element-wise.
    The map keeps the endpoints 0 and 1 fixed. For ``shift > 1`` it pushes
    interior values toward 1, so more sampling steps are spent at high noise.
    ``shift == 1`` is the identity.
    """
    numerator = shift * sigma
    denominator = 1 + (shift - 1) * sigma
    return numerator / denominator
|
|
@torch.inference_mode()
def encode_prompt(prompt):
    """Embed *prompt* with the global CLIP text encoder.

    The prompt is padded/truncated to exactly 77 tokens. The encoder's final
    hidden states (``last_hidden_state``) are returned, cast to ``DTYPE``.
    """
    token_batch = clip_tok(
        prompt,
        padding="max_length",
        max_length=77,
        truncation=True,
        return_tensors="pt",
    )
    token_batch = token_batch.to(DEVICE)
    hidden_states = clip_enc(**token_batch).last_hidden_state
    return hidden_states.to(DTYPE)
|
|
| |
| |
| |
@torch.inference_mode()
def _latents_to_pil(latents: torch.Tensor) -> Image.Image:
    """Decode latents with the global ``vae`` and return batch item 0 as a PIL image.

    Latents are divided by the VAE config's ``scaling_factor`` (0.18215 if the
    config has none) before decoding. The decoder output is mapped from
    [-1, 1] to [0, 1] and rounded to uint8.
    """
    scaling_factor = getattr(vae.config, "scaling_factor", 0.18215)
    decoded = vae.decode(latents / scaling_factor).sample
    # [-1, 1] -> [0, 1], then take batch item 0 as an HWC float32 array on the CPU.
    pixels = (decoded / 2 + 0.5).clamp(0, 1)[0].permute(1, 2, 0).cpu().float().numpy()
    # Round rather than truncate so pixel values are not biased downward.
    return Image.fromarray((pixels * 255).round().astype(np.uint8))


@torch.inference_mode()
def generate_lune(
    prompt: str,
    negative_prompt: str = "",
    seed: int = 42,
    steps: int = 30,
    cfg: float = 7.5,
    shift: float = 3.0,
    height: int = 512,
    width: int = 512,
) -> Image.Image:
    """Sample one image from the Lune flow UNet with an Euler sampler.

    Flow convention, matching the trainer:
        x_t    = sigma * noise + (1 - sigma) * data
        target = noise - data            (the velocity v the UNet predicts)

    Sampling walks sigma from 1 (pure noise) down to 0 (clean data):
        x_{sigma - dt} = x_sigma - v * dt
    v is SUBTRACTED because it points toward noise.

    Args:
        prompt: Text for the conditional branch.
        negative_prompt: Text for the unconditional branch. Empty (or None)
            uses the empty-string embedding.
        seed: Passed to ``torch.manual_seed`` before the initial noise is drawn.
        steps: Number of Euler steps; must be >= 1.
        cfg: Classifier-free guidance scale.
        shift: Timestep shift applied to the linear sigma schedule via
            ``shift_sigma``. 1.0 leaves the schedule linear.
        height, width: Output size in pixels. Must be positive multiples of the
            VAE downsampling factor. The 512x512 default gives a 64x64 latent
            for an 8x VAE.

    Returns:
        The decoded image as a PIL ``Image``.

    Raises:
        ValueError: If ``steps`` < 1, or if ``height``/``width`` is not a
            positive multiple of the VAE downsampling factor.
    """
    if steps < 1:
        raise ValueError(f"steps must be >= 1, got {steps}")

    # Pixel -> latent downsampling factor, derived from the VAE's block count.
    vae_scale = 2 ** (len(vae.config.block_out_channels) - 1)
    if height <= 0 or width <= 0 or height % vae_scale or width % vae_scale:
        raise ValueError(
            f"height and width must be positive multiples of {vae_scale}, "
            f"got {height}x{width}"
        )

    torch.manual_seed(seed)

    cond = encode_prompt(prompt)
    uncond = encode_prompt(negative_prompt or "")

    # Initial latent: pure Gaussian noise (sigma = 1).
    x = torch.randn(
        1,
        unet.config.in_channels,
        height // vae_scale,
        width // vae_scale,
        device=DEVICE,
        dtype=DTYPE,
    )

    # steps + 1 sigma boundaries from 1 to 0, warped by the timestep shift.
    sigmas = shift_sigma(torch.linspace(1, 0, steps + 1, device=DEVICE), shift=shift)

    print(f"Lune: '{prompt[:30]}' | {steps} steps, cfg={cfg}, shift={shift}")
    print(f" sigma range: {sigmas[0].item():.3f} β {sigmas[-1].item():.3f}")

    # Log about five times per run. max(1, ...) avoids a ZeroDivisionError in
    # the modulo below when steps < 5.
    log_every = max(1, steps // 5)

    for i in range(steps):
        sigma = sigmas[i]
        sigma_next = sigmas[i + 1]
        dt = sigma - sigma_next

        # The UNet is conditioned on sigma scaled to the 0-1000 timestep range.
        t_input = (sigma * 1000).view(1).to(DEVICE)

        # Classifier-free guidance on the predicted velocity.
        v_cond = unet(x, t_input, encoder_hidden_states=cond).sample
        v_uncond = unet(x, t_input, encoder_hidden_states=uncond).sample
        v = v_uncond + cfg * (v_cond - v_uncond)

        # Euler step toward sigma_next (v points toward noise, so subtract).
        x = x - v * dt

        if (i + 1) % log_every == 0:
            print(f" Step {i+1}/{steps}, sigma={sigma.item():.3f} β {sigma_next.item():.3f}")

    return _latents_to_pil(x)
|
|
| |
| |
| |
# Header describing the flow convention the sampler above assumes.
_banner_rule = "=" * 60
print("\n" + _banner_rule)
for _banner_line in (
    "Testing Lune with CORRECT flow convention",
    " x_t = sigma*noise + (1-sigma)*data",
    " v = noise - data",
    " Sample by SUBTRACTING v",
):
    print(_banner_line)
print(_banner_rule)
|
|
| from IPython.display import display |
|
|
prompt = "a castle at sunset"

# One run per timestep shift, all with the same seed/steps/cfg so only the
# shift differs. Each entry is (section header, shift value).
_shift_runs = (
    ("\n--- shift=3.0 (default) ---", 3.0),
    ("\n--- shift=2.5 (trainer default) ---", 2.5),
    ("\n--- shift=1.0 (no shift) ---", 1.0),
)

_run_images = []
for _header, _shift_value in _shift_runs:
    print(_header)
    _image = generate_lune(prompt, seed=42, steps=30, cfg=7.5, shift=_shift_value)
    display(_image)
    _run_images.append(_image)

# Results in run order: shift 3.0, 2.5, 1.0.
img, img2, img3 = _run_images
|
|
| |
| import matplotlib.pyplot as plt |
# Side-by-side comparison of the three runs, keyed by shift value
# (insertion order gives the left-to-right panel order).
_comparison = {3.0: img, 2.5: img2, 1.0: img3}
fig, axes = plt.subplots(1, len(_comparison), figsize=(15, 5))
for _panel, _shift_value in zip(axes, _comparison):
    _panel.imshow(_comparison[_shift_value])
    _panel.set_title(f"shift={_shift_value}")
    _panel.axis('off')
plt.tight_layout()
plt.show()

print("\nβ If images look correct, the output should be beautiful.")