| |
| import time |
| import argparse |
| import torch |
| from diffusers import FluxPipeline |
|
|
def benchmark_load_lora(
    base_model: str,
    lora_source: str,
    weight_name: str = None,
    adapter_name: str = None,
    dtype=torch.bfloat16,
    runs: int = 3,
):
    """Benchmark how long ``FluxPipeline.load_lora_weights`` takes.

    Loads the base pipeline once, moves it to the selected device, then
    loads the LoRA ``runs`` times, unloading it between consecutive runs
    (the adapter stays loaded after the final run). Progress and timings
    are printed to stdout.

    Args:
        base_model: Model identifier or path passed to
            ``FluxPipeline.from_pretrained``.
        lora_source: LoRA identifier or path passed to
            ``load_lora_weights``.
        weight_name: Optional weight file name forwarded to
            ``load_lora_weights``; omitted from the call when ``None``.
        adapter_name: Adapter name to register the LoRA under; defaults
            to ``"lora"`` when ``None``.
        dtype: ``torch_dtype`` used when loading the base pipeline.
        runs: Number of timed ``load_lora_weights`` calls (must be >= 1).

    Returns:
        A list with the wall-clock duration in seconds of each timed run.

    Raises:
        ValueError: If ``runs`` is smaller than 1.
    """
    # Fail before the expensive model load instead of crashing afterwards
    # on an empty measurement set.
    if runs < 1:
        raise ValueError(f"runs must be at least 1, got {runs}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    on_cuda = device.type == "cuda"
    print(f"Benchmarking on device {device}, torch.cuda.device_count()={torch.cuda.device_count()}.")

    def _sync():
        # Wait for pending CUDA work so it is included in the measured
        # interval; skipped on CPU, where there is no CUDA device to wait on.
        if on_cuda:
            torch.cuda.synchronize(device)

    # Respect the caller's adapter name; "lora" is only the fallback.
    if adapter_name is None:
        adapter_name = "lora"

    # Only forward weight_name when the caller actually supplied one.
    load_kwargs = {"adapter_name": adapter_name}
    if weight_name is not None:
        load_kwargs["weight_name"] = weight_name

    print("1/4. Loading base Flux.1-dev model …")
    t0 = time.perf_counter()
    pipe = FluxPipeline.from_pretrained(base_model, torch_dtype=dtype, use_safetensors=True)
    base_load_s = time.perf_counter() - t0
    print(f" Base model loaded in {base_load_s:.3f} s")

    print(f"2/4. Moving pipeline to {device} …")
    t1 = time.perf_counter()
    pipe = pipe.to(device)
    _sync()
    move_s = time.perf_counter() - t1
    print(f" to('{device.type}') took {move_s:.3f} s")

    durations = []
    for i in range(runs):
        print(f"3.{i+1}/4. Running load_lora_weights (run {i+1}/{runs}) …")
        start = time.perf_counter()
        pipe.load_lora_weights(lora_source, **load_kwargs)
        _sync()
        duration = time.perf_counter() - start
        durations.append(duration)
        print(f" → run {i+1}: load_lora_weights took {duration:.3f} s")

        # Unload between runs so every iteration measures a fresh load;
        # the adapter is left in place after the last run.
        if i < runs - 1:
            print(" Unloading LoRA …")
            pipe.unload_lora_weights(reset_to_overwritten_params=True)
            _sync()

    print("All runs complete.")
    avg = sum(durations) / len(durations)
    print(f"☆ Final run time: {durations[-1]:.3f} s")
    print(f"― average over {runs} runs ≈ {avg:.3f} s")
    return durations
|
|
if __name__ == "__main__":
    # Script entry point: read the benchmark settings from the command
    # line and hand them to benchmark_load_lora.
    cli = argparse.ArgumentParser(description="Benchmark Flux.1‑dev load_lora_weights timing")
    cli.add_argument("--model", default="black-forest-labs/FLUX.1-dev")
    cli.add_argument(
        "--lora",
        required=True,
        help="LoRA adapter repo ID or local folder / file path",
    )
    cli.add_argument("--runs", type=int, default=3)
    opts = cli.parse_args()

    benchmark_load_lora(base_model=opts.model, lora_source=opts.lora, runs=opts.runs)
|
|