| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
# Notebook shell magic (IPython / Colab only; not valid syntax in a plain .py
# file): quietly installs the listed packages. diffusers and transformers are
# imported below; accelerate and safetensors are not imported directly by the
# visible code.
!pip install -q diffusers transformers accelerate safetensors
|
|
| import torch |
| import gc |
| from huggingface_hub import hf_hub_download |
| from diffusers import UNet2DConditionModel, AutoencoderKL, DDPMScheduler |
| from transformers import CLIPTextModel, CLIPTokenizer |
| from PIL import Image |
| import numpy as np |
|
|
# Release memory held over from earlier runs of this notebook. Collect garbage
# first so tensors only kept alive by reference cycles are actually freed
# before the CUDA caching allocator is asked to hand its cached blocks back.
gc.collect()
torch.cuda.empty_cache()


# ---------------------------------------------------------------------------
# Runtime configuration
# ---------------------------------------------------------------------------
# Prefer the GPU in half precision. Without CUDA, fall back to the CPU in
# float32 (fp16 support on CPU is limited/slow) so the notebook still runs
# instead of failing at the first `.to("cuda")`.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32


# Hugging Face Hub location of the Sol checkpoint (downloaded further below).
SOL_REPO = "AbstractPhil/sd15-flow-matching"
SOL_FILENAME = "sd15_flowmatch_david_weighted_efinal.pt"
|
|
| |
| |
| |
# ---------------------------------------------------------------------------
# Base Stable Diffusion 1.5 components: each is created in DTYPE, moved to
# DEVICE and switched to eval mode.
# ---------------------------------------------------------------------------

# Text side: CLIP ViT-L/14 tokenizer and text encoder.
print("Loading CLIP...")
clip_tok = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
clip_enc = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=DTYPE)
clip_enc.to(DEVICE)  # nn.Module.to works in place and returns self
clip_enc.eval()

# Latent autoencoder from the SD1.5 repository.
print("Loading VAE...")
vae = AutoencoderKL.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    subfolder="vae",
    torch_dtype=DTYPE,
)
vae.to(DEVICE)
vae.eval()

# Base SD1.5 UNet; the Sol student weights are loaded into it afterwards.
print("Loading UNet...")
unet = UNet2DConditionModel.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    subfolder="unet",
    torch_dtype=DTYPE,
)
unet.to(DEVICE)
unet.eval()

# NOTE(review): only num_train_timesteps is set, so every other scheduler
# setting (beta schedule, beta range, prediction type) is diffusers' default,
# which as far as I know differs from SD1.5's own scaled_linear scheduler
# config. Confirm this matches what the Sol trainer used.
print("Loading DDPM Scheduler...")
sched = DDPMScheduler(num_train_timesteps=1000)
|
|
| |
| |
| |
# ---------------------------------------------------------------------------
# Load the Sol student weights into the base SD1.5 UNet.
# ---------------------------------------------------------------------------
print(f"\nLoading Sol from {SOL_REPO}...")
weights_path = hf_hub_download(repo_id=SOL_REPO, filename=SOL_FILENAME)
# SECURITY NOTE(review): torch.load unpickles the file, and this one comes from
# a remote Hub repo -- only run it if you trust that repo. If the checkpoint
# holds only tensors and plain containers, passing weights_only=True would be
# safer; TODO confirm against the actual file before switching.
checkpoint = torch.load(weights_path, map_location="cpu")

# The student weights live under the "student" key; "gstep" is optional.
state_dict = checkpoint["student"]
print(f" gstep: {checkpoint.get('gstep', 'unknown')}")

# If the keys are namespaced under "unet.", keep only those entries and strip
# the prefix. Slice it off rather than using str.replace, which would also
# rewrite any "unet." occurring later inside a key.
unet_prefix = "unet."
if any(k.startswith(unet_prefix) for k in state_dict):
    state_dict = {
        k[len(unet_prefix):]: v
        for k, v in state_dict.items()
        if k.startswith(unet_prefix)
    }

# Drop keys under hooks./local_heads. (presumably trainer-side auxiliary
# modules with no counterpart in the UNet -- TODO confirm).
state_dict = {k: v for k, v in state_dict.items() if not k.startswith(("hooks.", "local_heads."))}

# strict=False tolerates partial overlap. Shape mismatches still raise inside
# load_state_dict.
missing, unexpected = unet.load_state_dict(state_dict, strict=False)
print(f" Loaded: {len(state_dict)} keys, missing: {len(missing)}, unexpected: {len(unexpected)}")
if missing:
    print(f"  e.g. missing: {list(missing)[:5]}")
if unexpected:
    print(f"  e.g. unexpected: {list(unexpected)[:5]}")
# If no checkpoint key matched, the UNet still holds the stock SD1.5 weights.
# Fail loudly rather than announcing that Sol is ready.
if len(unexpected) == len(state_dict):
    raise RuntimeError(
        f"No keys from {SOL_FILENAME} matched the UNet; the Sol weights were not loaded."
    )

# Drop the CPU-side checkpoint now that its tensors were copied into the UNet.
del checkpoint, state_dict
gc.collect()

# Inference only: freeze every UNet parameter.
unet.requires_grad_(False)

print("✓ Sol ready!")
|
|
| |
| |
| |
def alpha_sigma(t: torch.LongTensor):
    """Look up the DDPM signal and noise scales for timestep indices ``t``.

    With ``abar = sched.alphas_cumprod[t]`` (moved to ``DEVICE``), returns
    ``(sqrt(abar), sqrt(1 - abar))``. Both are reshaped to ``(-1, 1, 1, 1)``
    so they broadcast against a batch of latents, and cast to float32.
    Intended to match the trainer's alpha/sigma computation.
    """
    abar = sched.alphas_cumprod.to(DEVICE)[t]
    bshape = (-1, 1, 1, 1)
    alpha = torch.sqrt(abar).view(bshape).float()
    sigma = torch.sqrt(1.0 - abar).view(bshape).float()
    return alpha, sigma
|
|
| |
| |
| |
@torch.inference_mode()
def _encode_text(text):
    """Encode ``text`` with CLIP and return its last hidden state as ``DTYPE``.

    The prompt is padded/truncated to 77 tokens. Everything the tokenizer
    returns is forwarded to the encoder.
    """
    # NOTE(review): `**tokens` forwards the tokenizer's attention mask too
    # (presumably input_ids + attention_mask). Stock SD1.5 pipelines pass only
    # input_ids. Kept as-is because this sampler is meant to match the Sol
    # trainer; confirm the trainer encodes prompts the same way.
    tokens = clip_tok(
        text,
        return_tensors="pt",
        padding="max_length",
        max_length=77,
        truncation=True,
    ).to(DEVICE)
    return clip_enc(**tokens).last_hidden_state.to(DTYPE)


@torch.inference_mode()
def generate_sol(prompt, negative_prompt="", seed=42, steps=30, cfg=7.5, height=512, width=512):
    """Sample one image from the Sol UNet with classifier-free guidance.

    Intended to match the trainer's ``sample()`` method:
      1. iterate over the DDPM scheduler's inference timesteps,
      2. the UNet predicts a velocity ``v`` for each conditioning,
      3. the guided ``v`` is converted to ``x0_hat`` and then ``eps_hat``,
      4. the latent is advanced with ``sched.step(eps_hat, t, x_t)``.

    Args:
        prompt: Text prompt for the conditional branch.
        negative_prompt: Prompt for the unconditional branch ("" by default).
        seed: If not None, passed to ``torch.manual_seed`` (global RNG) before
            sampling.
        steps: Number of scheduler inference steps.
        cfg: Classifier-free guidance scale.
        height, width: Output size in pixels. Both must be multiples of the
            VAE downsampling factor. The defaults give the same 512x512 image
            (64x64 latent) the sampler always produced.

    Returns:
        A ``PIL.Image.Image`` built from the decoded RGB array.

    Raises:
        ValueError: If ``height`` or ``width`` is not a multiple of the VAE
            downsampling factor.
    """
    # Latent geometry comes from the loaded models rather than being hard-coded.
    # diffusers' SD pipelines derive the VAE downsampling factor the same way
    # (8 for the SD1.5 VAE).
    vae_scale = 2 ** (len(vae.config.block_out_channels) - 1)
    if height % vae_scale or width % vae_scale:
        raise ValueError(
            f"height and width must be multiples of {vae_scale}, got {height}x{width}"
        )

    if seed is not None:
        torch.manual_seed(seed)

    cond = _encode_text(prompt)
    uncond = _encode_text(negative_prompt)

    sched.set_timesteps(steps, device=DEVICE)

    x_t = torch.randn(
        1,
        unet.config.in_channels,
        height // vae_scale,
        width // vae_scale,
        device=DEVICE,
        dtype=DTYPE,
    )

    print(f"Sampling '{prompt[:40]}' | {steps} steps, cfg={cfg}")

    for i, t_scalar in enumerate(sched.timesteps):
        t = torch.full((1,), t_scalar, device=DEVICE, dtype=torch.long)

        # Guided velocity: v_uncond + cfg * (v_cond - v_uncond).
        v_cond = unet(x_t, t, encoder_hidden_states=cond).sample
        v_uncond = unet(x_t, t, encoder_hidden_states=uncond).sample
        v_hat = v_uncond + cfg * (v_cond - v_uncond)

        # Convert v -> x0 -> eps in float32. These formulas invert
        # v = alpha*eps - sigma*x0 with x_t = alpha*x0 + sigma*eps. Since
        # alpha**2 = abar and sigma**2 = 1 - abar, denom is ~1; it is kept
        # for numerical safety.
        alpha, sigma = alpha_sigma(t)
        x_t32 = x_t.float()
        denom = alpha**2 + sigma**2
        x0_hat = (alpha * x_t32 - sigma * v_hat.float()) / (denom + 1e-8)
        eps_hat = (x_t32 - alpha * x0_hat) / (sigma + 1e-8)

        # Let the scheduler take its epsilon-parameterised step.
        x_t = sched.step(eps_hat, t_scalar, x_t32).prev_sample.to(DTYPE)

        if (i + 1) % max(1, steps // 5) == 0:
            print(f" Step {i+1}/{steps}, t={t_scalar}")

    # Undo the latent scaling using the VAE's own factor (the 0.18215 that
    # was previously hard-coded), then decode.
    img = vae.decode(x_t / vae.config.scaling_factor).sample
    img = (img / 2 + 0.5).clamp(0, 1)[0].permute(1, 2, 0).cpu().float().numpy()

    # Round to the nearest 8-bit level. Plain astype truncates, which biases
    # every channel downward.
    return Image.fromarray(np.round(img * 255).astype(np.uint8))
|
|
| |
| |
| |
# ---------------------------------------------------------------------------
# Demo: render a few prompts with a fixed seed and show them inline.
# ---------------------------------------------------------------------------
from IPython.display import display

# Banner: blank line, rule, title, rule -- one print with newline separators.
print("\n" + "=" * 60, "Generating test images with Sol (correct sampler)", "=" * 60, sep="\n")

prompts = [
    "a castle at sunset",
    "a portrait of a woman",
    "a city street at night",
]

for prompt in prompts:
    print()  # blank line between each prompt's sampling log
    img = generate_sol(prompt, negative_prompt="", seed=42, steps=4, cfg=5.0)
    display(img)

print("\n✓ Bask in the beauty of the geometric expert!")