| import gradio as gr
|
| import torch
|
| import matplotlib.pyplot as plt
|
| import os
|
| from PIL import Image
|
| import numpy as np
|
|
|
|
|
| from SDLens import HookedStableDiffusionXLPipeline
|
| from training.k_sparse_autoencoder import SparseAutoencoder
|
| from utils.hooks import add_feature_on_text_prompt
|
|
|
|
|
def modulate_hook_prompt(sae, steering_feature, block):
    """Build a hook that forwards to ``add_feature_on_text_prompt``.

    The returned callable passes ``sae`` and ``steering_feature`` first,
    followed by whatever positional and keyword arguments it is invoked with,
    and returns the helper's result.

    Args:
        sae: Autoencoder object handed through to the helper unchanged.
        steering_feature: Steering value handed through to the helper unchanged.
        block: Name of the hooked position. Accepted so callers can build one
            hook per block, but not used by the hook itself.

    Returns:
        A callable accepting arbitrary ``*args`` and ``**kwargs``.
    """

    def _steer(*hook_args, **hook_kwargs):
        return add_feature_on_text_prompt(
            sae, steering_feature, *hook_args, **hook_kwargs
        )

    return _steer
|
|
|
|
|
def load_models():
    """Load the hooked SDXL-Turbo pipeline and the sparse autoencoder.

    Returns:
        A ``(pipe, blocks_to_save, sae)`` tuple on success, where
        ``blocks_to_save`` lists the two text-encoder layer names used by the
        rest of this file. If anything raises while loading, the error is
        printed and ``(None, None, None)`` is returned instead, so callers
        must check for ``None``.
    """
    try:
        pipeline = HookedStableDiffusionXLPipeline.from_pretrained('stabilityai/sdxl-turbo')
        pipeline.set_progress_bar_config(disable=True)

        checkpoint_dir = os.path.join("Checkpoints/dahyecheckpoint", 'final')
        autoencoder = SparseAutoencoder.load_from_disk(checkpoint_dir)
    except Exception as e:
        # Best-effort boundary: report the problem and let the caller decide.
        print(f"Error loading models: {e}")
        return None, None, None

    hooked_layers = [
        'text_encoder.text_model.encoder.layers.10',
        'text_encoder_2.text_model.encoder.layers.28',
    ]
    return pipeline, hooked_layers, autoencoder
|
|
|
|
|
def activation_modulation_across_prompt(pipe, sae, blocks_to_save, steer_prompt, strength, prompt, guidance_scale, num_inference_steps, seed):
    """Generate an image for ``prompt`` steered by ``steer_prompt``.

    The steering prompt is run once through ``pipe.run_with_cache`` to capture
    the outputs of the first two entries of ``blocks_to_save``. Those outputs
    are concatenated along the last dimension, encoded with
    ``sae.encode_without_topk``, scaled by ``strength`` and projected through
    ``sae.decoder.weight.T``. The result is handed to one hook per block while
    ``prompt`` is generated with ``pipe.run_with_hooks``.

    Args:
        pipe: Pipeline exposing ``run_with_cache`` and ``run_with_hooks``.
        sae: Autoencoder exposing ``encode_without_topk`` and ``decoder.weight``.
        blocks_to_save: Block names to cache and hook; the first two are used
            to build the steering feature.
        steer_prompt: Text whose activations define the steering direction.
        strength: Scalar multiplier applied to the encoded activations.
        prompt: Text prompt for the image that is returned.
        guidance_scale: Forwarded to both pipeline calls.
        num_inference_steps: Step count for the hooked generation only; the
            caching pass always uses a single step.
        seed: Seed for the CPU generators used by both pipeline calls.

    Returns:
        The first image of the hooked generation.
    """
    # Caching pass: only the cache is needed, the generated output is dropped.
    _, cache = pipe.run_with_cache(
        steer_prompt,
        positions_to_cache=blocks_to_save,
        save_input=True,
        save_output=True,
        num_inference_steps=1,
        guidance_scale=guidance_scale,
        generator=torch.Generator(device="cpu").manual_seed(seed),
    )

    cached_outputs = cache['output']
    steer_activations = torch.cat(
        [cached_outputs[blocks_to_save[0]], cached_outputs[blocks_to_save[1]]],
        dim=-1,
    )
    # Drop up to two leading singleton dimensions.
    # NOTE(review): presumably batch/step axes of the cache - confirm.
    steer_activations = steer_activations.squeeze(0).squeeze(0)

    with torch.no_grad():
        encoded = sae.encode_without_topk(steer_activations)
        steering_feature = (encoded * strength) @ sae.decoder.weight.T

    hooks = {}
    for block in blocks_to_save:
        hooks[block] = modulate_hook_prompt(sae, steering_feature, block)

    steered = pipe.run_with_hooks(
        prompt,
        position_hook_dict=hooks,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=torch.Generator(device="cpu").manual_seed(seed),
    )
    return steered.images[0]
|
|
|
|
|
def generate_comparison(prompt, steer_prompt, strength, seed, guidance_scale, steps):
    """Generate the unmodified and the steered image for the UI.

    Reads the module-level ``pipe``, ``sae`` and ``blocks_to_save``.

    Args:
        prompt: Main text prompt.
        steer_prompt: Prompt whose activations steer the second image.
        strength: Modulation strength. Any non-zero value (positive or
            negative) triggers the steered generation; exactly zero reuses the
            standard image.
        seed: Seed for the CPU random generator; converted to ``int``.
        guidance_scale: Forwarded to the pipeline calls.
        steps: Number of inference steps; converted to ``int``.

    Returns:
        ``(standard_image, modified_image, status_message)``. If the models
        are not loaded or generation raises, both images are a red 512x512
        placeholder and the message describes the error.
    """

    def _failure(message):
        # One red placeholder shown in both image slots.
        placeholder = Image.new('RGB', (512, 512), color='red')
        return placeholder, placeholder, message

    if pipe is None or sae is None or blocks_to_save is None:
        return _failure("Error: Models failed to load")

    try:
        # UI sliders may deliver floats; manual_seed and the step count need
        # ints.
        seed = int(seed)
        steps = int(steps)

        standard_image = pipe(
            prompt,
            num_inference_steps=steps,
            guidance_scale=guidance_scale,
            generator=torch.Generator(device="cpu").manual_seed(seed),
        ).images[0]

        # The strength slider spans negative values as well, so only an exact
        # zero (no modulation at all) may skip the steered generation.
        if strength != 0:
            modified_image = activation_modulation_across_prompt(
                pipe, sae, blocks_to_save,
                steer_prompt, strength, prompt,
                guidance_scale, steps, seed,
            )
        else:
            modified_image = standard_image

        comparison_message = f"Generated images with modulation strength: {strength}"
        return standard_image, modified_image, comparison_message
    except Exception as e:
        # UI boundary: surface the failure in the status box instead of crashing.
        return _failure(f"Error during generation: {str(e)}")
|
|
|
|
|
# Load the models once at module level; generate_comparison reads these
# globals on every request and reports an error if they are None.
print("Loading models...")
pipe, blocks_to_save, sae = load_models()
print("Models loaded successfully!" if pipe is not None else "Failed to load models")
|
|
|
|
|
# Gradio interface: input controls, the two result images, preset buttons and
# usage notes, in that order.
with gr.Blocks(title="SDXL Activation Modulation") as app:
    gr.Markdown("# SDXL Activation Modulation Comparison")
    gr.Markdown("""
    This app demonstrates activation modulation in Stable Diffusion XL using sparse autoencoders.
    It compares standard SDXL-Turbo outputs with modulated outputs that can steer the generation based on a separate concept.
    """)

    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="Enter your main image prompt here...",
                value="A photo of a tree",
            )
            steer_prompt = gr.Textbox(
                label="Steering Prompt",
                placeholder="Enter concept to steer with...",
                value="tree with autumn leaves",
            )
            strength = gr.Slider(
                minimum=-2.5,
                maximum=2.5,
                value=0.8,
                step=0.05,
                label="Modulation Strength (λ)",
            )

            with gr.Accordion("Advanced Settings", open=False):
                seed = gr.Slider(minimum=0, maximum=2147483647, step=1, value=61730, label="Seed")
                guidance_scale = gr.Slider(minimum=0.0, maximum=10.0, value=0.0, step=0.5, label="Guidance Scale")
                steps = gr.Slider(minimum=1, maximum=50, value=3, step=1, label="Inference Steps")

            generate_btn = gr.Button("Generate Comparison", variant="primary")
            status = gr.Textbox(label="Status", interactive=False)

    with gr.Row():
        standard_output = gr.Image(label="Standard SDXL-Turbo")
        modified_output = gr.Image(label="Modulated Output")

    gr.Markdown("""
    ## Examples from the notebook:
    - Main prompt: "A photo of a tree" with steering prompt: "tree with autumn leaves"
    - Main prompt: "A dog" with steering prompt: "full shot"
    - Main prompt: "A car" with steering prompt: "A blue car"
    """)

    with gr.Row():
        example1 = gr.Button("Example 1: Tree with autumn leaves")
        example2 = gr.Button("Example 2: Dog with full shot")
        example3 = gr.Button("Example 3: Blue car")

    generate_btn.click(
        fn=generate_comparison,
        inputs=[prompt, steer_prompt, strength, seed, guidance_scale, steps],
        outputs=[standard_output, modified_output, status],
    )

    def _wire_preset(button, *values):
        """Make ``button`` fill the six input controls with ``values``."""
        button.click(
            fn=lambda: list(values),
            inputs=None,
            outputs=[prompt, steer_prompt, strength, seed, guidance_scale, steps],
        )

    # Values are ordered as: prompt, steering prompt, strength, seed,
    # guidance scale, steps.
    _wire_preset(example1, "A photo of a tree", "tree with autumn leaves", 0.5, 61730, 0.0, 3)
    _wire_preset(example2, "A dog", "full shot", 0.4, 61730, 0.0, 3)
    _wire_preset(example3, "A car", "A blue car", 0.3, 61730, 0.0, 3)

    gr.Markdown("""
    ## How to Use
    1. Enter your main prompt (what you want to generate)
    2. Enter a steering prompt (concept to influence the generation)
    3. Adjust the modulation strength slider (λ) - higher values mean stronger influence
    4. Click "Generate Comparison" to see the results side by side
    5. Use advanced settings if needed to adjust seed, guidance scale, or steps

    ## About
    This app demonstrates activation modulation using a sparse autoencoder trained on SDXL text encoder layers.
    The modulation allows steering the generation toward specific concepts without changing the main prompt.
    """)


# Start the web server only when this file is run as a script.
if __name__ == "__main__":
    app.launch()