| import os |
| import gc |
| import glob |
| from multiprocessing import Pool |
| import time |
| from tqdm import tqdm |
| import torch |
| from safetensors.torch import load_file |
| from diffusers import FluxTransformer2DModel, FluxPipeline |
| from huggingface_hub import snapshot_download |
| from PIL import Image |
|
|
| |
# Device the assembled pipeline is moved to in process_model, before the
# optional sequential CPU offload (USE_CPU_OFFLOAD) is enabled.
DEVICE = torch.device("cpu")

USE_CPU_OFFLOAD = True  # when True, process_model calls pipeline.enable_sequential_cpu_offload()
DTYPE = torch.bfloat16  # dtype used for transformers and pipelines
NUM_WORKERS = 1  # size of the multiprocessing pool created in main()
SEED = 0  # seed used for every generated image
IMAGE_WIDTH = 880
IMAGE_HEIGHT = 656

PROMPTS = [
    "a tiny astronaut hatching from an egg on the moon",
    'photo of a man on a beach holding a sign that says "Premature optimization is the root of all evil - test your shit!"',
]
# Inference step counts rendered for every prompt and every model.
STEP_COUNTS = [4, 8, 16, 32, 50]
# (schnell_weight, dev_weight) pairs; merged weights are the weighted average
# (schnell * schnell_weight + dev * dev_weight) / (schnell_weight + dev_weight).
MERGE_RATIOS = [
    (1, 0), (12, 1), (10, 1), (7, 1), (5.5, 1), (4, 1), (3.5, 1), (3, 1), (2.5, 1), (2, 1), (1.5, 1), (0, 1),
]
# Display label for each MERGE_RATIOS entry (same order); with ":" replaced by
# "_" it is also the per-model output directory name.
MERGE_LABELS = [
    "Pure Schnell", "12:1", "10:1", "7:1", "5.5:1", "4:1", "3.5:1", "3:1", "2.5:1", "2:1", "1.5:1", "Pure Dev",
]
# Explicit check instead of `assert`, which is removed when running `python -O`.
if len(MERGE_RATIOS) != len(MERGE_LABELS):
    raise ValueError(
        f"MERGE_RATIOS ({len(MERGE_RATIOS)} entries) and MERGE_LABELS "
        f"({len(MERGE_LABELS)} entries) must have the same length"
    )

IMAGE_OUTPUT_DIR = "./outputs"
MODEL_OUTPUT_DIR = "./merged_models"
SAVE_MODELS = False  # when True, each model is saved under MODEL_OUTPUT_DIR
os.makedirs(IMAGE_OUTPUT_DIR, exist_ok=True)
|
|
|
|
| |
def cleanup():
    """Run Python garbage collection, then ask torch to drop its cached CUDA memory."""
    for release in (gc.collect, torch.cuda.empty_cache):
        release()


# Wall-clock timestamp taken when the module runs; the summary at the end of
# the script subtracts it from the finish time to report elapsed time.
start_time = time.time()
|
|
|
|
def merge_models(dev_shards, schnell_shards, ratio):
    """Blend FLUX.1-dev and FLUX.1-schnell transformer shards into one state dict.

    Shards are processed pairwise: ``dev_shards[i]`` is matched with
    ``schnell_shards[i]``. This assumes both lists are ordered the same way and
    that paired shards hold the same keys; confirm this for new checkpoints.
    Every key whose name does not contain ``"guidance"`` becomes
    ``(schnell * schnell_weight + dev * dev_weight) / (schnell_weight + dev_weight)``.
    Guidance keys are copied from the dev shard unchanged.

    Args:
        dev_shards: Paths of the dev ``.safetensors`` shards.
        schnell_shards: Paths of the schnell ``.safetensors`` shards.
        ratio: ``(schnell_weight, dev_weight)`` pair.

    Returns:
        dict mapping parameter name to tensor. Each merged tensor is returned
        in the promoted dtype of its two inputs (e.g. bfloat16 for bfloat16
        checkpoints). Guidance tensors are the dev tensors themselves.

    Raises:
        ValueError: if the weights sum to zero, no shards are given, or the two
            shard lists differ in length.
        KeyError: if a non-guidance dev key is absent from the paired schnell shard.
    """
    schnell_weight, dev_weight = ratio
    total_weight = schnell_weight + dev_weight
    if total_weight == 0:
        raise ValueError(f"Merge weights must not sum to zero, got {ratio!r}")
    if not dev_shards:
        raise ValueError("No dev shards found to merge")
    if len(dev_shards) != len(schnell_shards):
        raise ValueError(
            f"Shard count mismatch: {len(dev_shards)} dev vs "
            f"{len(schnell_shards)} schnell shards"
        )

    merged_state_dict = {}
    shard_pairs = list(zip(dev_shards, schnell_shards))
    for i, (dev_path, schnell_path) in enumerate(
        tqdm(shard_pairs, "Processing shards...", dynamic_ncols=True)
    ):
        state_dict_dev = load_file(dev_path)
        state_dict_schnell = load_file(schnell_path)

        for k in tqdm(list(state_dict_dev), f"\tProcessing keys of shard {i}...", dynamic_ncols=True):
            dev_tensor = state_dict_dev[k]
            if "guidance" in k:
                # Guidance tensors are taken from dev as-is (presumably schnell
                # has no counterpart; confirm against the checkpoints).
                merged_state_dict[k] = dev_tensor
                continue

            try:
                schnell_tensor = state_dict_schnell[k]
            except KeyError:
                raise KeyError(f"{k!r} from {dev_path} is missing in {schnell_path}") from None

            # Compute in at least float32 and round once at the end. Chaining
            # *, + and / directly in bfloat16 rounds after every operation.
            out_dtype = torch.promote_types(schnell_tensor.dtype, dev_tensor.dtype)
            compute_dtype = torch.promote_types(out_dtype, torch.float32)
            blended = (
                schnell_tensor.to(compute_dtype) * schnell_weight
                + dev_tensor.to(compute_dtype) * dev_weight
            ) / total_weight
            merged_state_dict[k] = blended.to(out_dtype)

        # Release both shard dicts before loading the next pair to lower peak memory.
        del state_dict_dev, state_dict_schnell

    return merged_state_dict
|
|
|
|
| |
def create_merged_model(dev_ckpt, schnell_ckpt, ratio):
    """Build a FLUX transformer whose weights blend the dev and schnell checkpoints.

    Args:
        dev_ckpt: Local FLUX.1-dev snapshot directory containing ``transformer/``.
        schnell_ckpt: Local FLUX.1-schnell snapshot directory containing ``transformer/``.
        ratio: ``(schnell_weight, dev_weight)`` pair forwarded to ``merge_models``.

    Returns:
        A FluxTransformer2DModel holding the merged weights, in ``DTYPE``.
    """
    # Read the architecture from the local dev snapshot the weights come from,
    # rather than re-resolving the hub repo id.
    config = FluxTransformer2DModel.load_config(dev_ckpt, subfolder="transformer")
    # Cast before loading so the freshly initialised model is held in DTYPE,
    # not float32, while the merged state dict is also in memory. Copying the
    # merged tensors into DTYPE parameters gives the same result as loading into
    # float32 and casting afterwards.
    model = FluxTransformer2DModel.from_config(config).to(DTYPE)

    # glob.escape keeps special characters in the checkpoint paths literal.
    dev_shards = sorted(glob.glob(os.path.join(glob.escape(dev_ckpt), "transformer", "*.safetensors")))
    schnell_shards = sorted(glob.glob(os.path.join(glob.escape(schnell_ckpt), "transformer", "*.safetensors")))

    merged_state_dict = merge_models(dev_shards, schnell_shards, ratio)
    model.load_state_dict(merged_state_dict)
    del merged_state_dict
    cleanup()

    return model
|
|
|
|
def generate_image(pipeline, prompt, num_steps, output_path, guidance_scale=3.5):
    """Render one image with ``pipeline`` and save it to ``output_path``.

    If ``output_path`` already exists, prints a message and does nothing, so an
    interrupted sweep can be resumed.

    Args:
        pipeline: Loaded FluxPipeline; its call result must expose ``.images``.
        prompt: Text prompt.
        num_steps: Number of inference steps.
        output_path: Destination image file.
        guidance_scale: Guidance scale passed to the pipeline (default 3.5).
    """
    if os.path.exists(output_path):
        print("Image already exists, skipping...")
        return

    # A dedicated CPU generator seeded with SEED yields the same random stream
    # as torch.manual_seed(SEED), without reseeding torch's global RNG.
    generator = torch.Generator("cpu").manual_seed(SEED)
    image = pipeline(
        prompt=prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_steps,
        height=IMAGE_HEIGHT,
        width=IMAGE_WIDTH,
        max_sequence_length=512,
        generator=generator,
    ).images[0]
    image.save(output_path)
|
|
|
|
def _expected_image_jobs(image_output_dir):
    """Return the (prompt, step_count, output_path) triples that make up one model's image set."""
    return [
        (
            prompt,
            step_count,
            os.path.join(image_output_dir, f"prompt{prompt_idx + 1}_steps{step_count}.png"),
        )
        for prompt_idx, prompt in enumerate(PROMPTS)
        for step_count in STEP_COUNTS
    ]


def _load_transformer(ratio, dev_ckpt, schnell_ckpt):
    """Load a pure checkpoint when one weight is zero; otherwise build a merged model."""
    schnell_weight, dev_weight = ratio
    if schnell_weight == 0 and dev_weight == 0:
        raise ValueError(f"Merge ratio must have a non-zero weight, got {ratio!r}")
    if dev_weight == 0:
        return FluxTransformer2DModel.from_pretrained(schnell_ckpt, subfolder="transformer", torch_dtype=DTYPE)
    if schnell_weight == 0:
        return FluxTransformer2DModel.from_pretrained(dev_ckpt, subfolder="transformer", torch_dtype=DTYPE)
    return create_merged_model(dev_ckpt, schnell_ckpt, ratio)


def process_model(ratio, label, dev_ckpt, schnell_ckpt):
    """Generate every prompt/step-count image for one merge ratio.

    Skips the model entirely when all of its expected images already exist in
    ``IMAGE_OUTPUT_DIR/<label with ':' replaced by '_'>``. Otherwise it loads
    or merges the transformer and optionally saves it. It then wraps the
    transformer in a FluxPipeline built from ``dev_ckpt`` and renders any
    images that are missing.

    Args:
        ratio: ``(schnell_weight, dev_weight)``. A zero dev weight loads the pure
            schnell transformer; a zero schnell weight loads the pure dev
            transformer; anything else is merged.
        label: Display label; also names the output directories.
        dev_ckpt: Local FLUX.1-dev snapshot directory.
        schnell_ckpt: Local FLUX.1-schnell snapshot directory.

    Raises:
        ValueError: if both weights in ``ratio`` are zero.
    """
    image_output_dir = os.path.join(IMAGE_OUTPUT_DIR, label.replace(":", "_"))
    os.makedirs(image_output_dir, exist_ok=True)
    jobs = _expected_image_jobs(image_output_dir)

    # Check the exact expected files. Counting arbitrary files in the directory
    # lets stray files cause a wrong skip/no-skip decision.
    if all(os.path.exists(output_path) for _, _, output_path in jobs):
        print(f"\nModel {label} already complete, skipping...")
        return
    print(f"\nProcessing {label} model...")

    model = _load_transformer(ratio, dev_ckpt, schnell_ckpt)

    if SAVE_MODELS:
        model_output_dir = os.path.join(MODEL_OUTPUT_DIR, label.replace(":", "_"))
        print(f"Saving model to {model_output_dir}...")
        # The diffusers keyword is `max_shard_size` ("shard", not "shared").
        model.save_pretrained(model_output_dir, max_shard_size="50GB", safe_serialization=True)

    pipeline = FluxPipeline.from_pretrained(
        dev_ckpt,
        transformer=model,
        torch_dtype=DTYPE,
    ).to(DEVICE)
    if USE_CPU_OFFLOAD:
        pipeline.enable_sequential_cpu_offload()

    for prompt, step_count, output_path in jobs:
        generate_image(pipeline, prompt, step_count, output_path)

    # Drop both references before collecting so the weights can actually be freed.
    del pipeline
    del model
    cleanup()
|
|
|
|
def main():
    """Download both FLUX checkpoints, then process every merge ratio in a worker pool."""
    dev_dir = snapshot_download(
        repo_id="black-forest-labs/FLUX.1-dev",
        ignore_patterns=["flux1-dev.sft", "flux1-dev.safetensors"],
        local_dir="./models/dev/",
    )
    schnell_dir = snapshot_download(
        repo_id="black-forest-labs/FLUX.1-schnell",
        allow_patterns="transformer/*",
        local_dir="./models/schnell/",
    )

    with Pool(NUM_WORKERS) as workers:
        pending = []
        for ratio, label in zip(MERGE_RATIOS, MERGE_LABELS):
            task_args = (ratio, label, dev_dir, schnell_dir)
            pending.append(workers.apply_async(process_model, task_args))

        # Wait in submission order; .get() re-raises any worker exception here.
        for task in tqdm(pending):
            task.get()

        workers.close()
        workers.join()
|
|
|
|
def create_image_grid(image_paths, output_path, padding=10):
    """Tile images into a grid with one column per merge ratio and one row per step count.

    Each image is resized to half of IMAGE_WIDTH x IMAGE_HEIGHT and placed in
    row-major order: index ``i`` goes to row ``i // len(MERGE_RATIOS)``, column
    ``i % len(MERGE_RATIOS)``. A path that does not exist leaves its cell blank
    (white), so the remaining images keep their positions.

    Args:
        image_paths: Image paths in row-major order.
        output_path: Where to save the grid image.
        padding: Pixel gap between cells and around the grid.
    """
    width = IMAGE_WIDTH // 2
    height = IMAGE_HEIGHT // 2
    grid_cols = len(MERGE_RATIOS)
    grid_rows = len(STEP_COUNTS)
    # NOTE(review): these margins are reserved (presumably for row/column
    # labels), but nothing is drawn in them here.
    top_pad = 250
    left_pad = 200
    grid_width = (width * grid_cols) + (padding * (grid_cols + 1)) + left_pad
    grid_height = (height * grid_rows) + (padding * (grid_rows + 1)) + top_pad

    grid_image = Image.new("RGB", (grid_width, grid_height), color=(255, 255, 255))

    for idx, path in enumerate(image_paths):
        if not os.path.exists(path):
            continue
        row, col = divmod(idx, grid_cols)
        x_position = (col * width) + (padding * (col + 1)) + left_pad
        y_position = (row * height) + (padding * (row + 1)) + top_pad
        # The context manager closes each source file as soon as it has been
        # pasted, instead of holding every handle open until garbage collection.
        with Image.open(path) as img:
            grid_image.paste(img.resize((width, height)), (x_position, y_position))

    grid_image.save(output_path)
|
|
|
|
| |
if __name__ == "__main__":
    # Guard required by multiprocessing: under the "spawn" start method the
    # Pool workers re-import this module, and an unguarded main() would run
    # again inside every worker.
    main()

    print("Creating image comparison grid...")

    # One list per prompt, ordered step count -> merge label, i.e. row-major for
    # create_image_grid (rows = step counts, columns = merge labels).
    paths_by_prompt = [
        [
            os.path.join(
                IMAGE_OUTPUT_DIR,
                label.replace(":", "_"),
                f"prompt{prompt_idx + 1}_steps{step_count}.png",
            )
            for step_count in STEP_COUNTS
            for label in MERGE_LABELS
        ]
        for prompt_idx in range(len(PROMPTS))
    ]
    all_image_paths = [path for prompt_paths in paths_by_prompt for path in prompt_paths]

    missing_images = [path for path in all_image_paths if not os.path.exists(path)]
    if missing_images:
        print(f"Warning: {len(missing_images)} images were not generated:")
        for path in missing_images[:5]:
            print(f" • {path}")
        if len(missing_images) > 5:
            print(f" (and {len(missing_images) - 5} more...)")

    # Use the per-prompt lists directly; selecting by the substring "promptN"
    # would also match "promptN0", "promptN1", ... once there are 10+ prompts.
    for prompt_idx, prompt_images in enumerate(paths_by_prompt):
        grid_output_path = os.path.join(IMAGE_OUTPUT_DIR, f"grid_prompt{prompt_idx + 1}.png")
        create_image_grid(prompt_images, grid_output_path)

    end_time = time.time()
    total_time = end_time - start_time
    # Count only images that exist on disk, not every expected path.
    num_images = len(all_image_paths) - len(missing_images)

    print("\nProcessing complete!")
    print(f"Total time: {total_time:.2f} seconds")
    print(f"Total images generated: {num_images}")
    if num_images:
        print(f"Average time per image: {total_time / num_images:.2f} seconds")
    print(f"Output directory: {IMAGE_OUTPUT_DIR}")
|
|