| |
| """ |
| Convert SimpleTuner LoRA weights to diffusers-compatible format for AuraFlow. |
| |
| This script converts LoRA weights saved by SimpleTuner into a format that can be |
| directly loaded by diffusers' load_lora_weights() method. |
| |
| Usage: |
| python convert_simpletuner_lora.py <input_lora.safetensors> <output_lora.safetensors> |
| |
| Example: |
| python convert_simpletuner_lora.py input_lora.safetensors diffusers_compatible_lora.safetensors |
| """ |
|
|
| import argparse |
| import sys |
| from pathlib import Path |
| from typing import Dict |
|
|
| import safetensors.torch |
| import torch |
|
|
|
|
def detect_lora_format(state_dict: Dict[str, torch.Tensor]) -> str:
    """
    Detect the format of the LoRA state dict.

    Returns:
        "peft" if already in PEFT/diffusers format
        "mixed" if mixed format (some lora_A/B, some lora.down/up)
        "simpletuner_transformer" if in SimpleTuner format with transformer prefix
        "simpletuner_auraflow" if in SimpleTuner AuraFlow format
        "kohya" if in Kohya format
        "unknown" otherwise
    """
    keys = list(state_dict.keys())

    def contains_any(*needles: str) -> bool:
        """True if any key contains any of the given substrings."""
        return any(needle in key for key in keys for needle in needles)

    def prefixed(prefix: str) -> bool:
        """True if any key starts with the given prefix."""
        return any(key.startswith(prefix) for key in keys)

    # Naming conventions for the low-rank factor pair.
    peft_naming = contains_any(".lora_A.", ".lora_B.")
    underscore_naming = contains_any(".lora_down.", ".lora_up.")
    dot_naming = contains_any(".lora.down.", ".lora.up.")
    legacy_naming = underscore_naming or dot_naming

    # Key prefixes identifying the source trainer / target module.
    diffusers_prefix = prefixed("transformer.")
    auraflow_prefix = prefixed("lora_transformer_")
    unet_prefix = prefixed("lora_unet_")

    if diffusers_prefix and peft_naming:
        # Pure PEFT naming, or a hybrid save mixing both conventions.
        return "mixed" if legacy_naming else "peft"

    if diffusers_prefix and legacy_naming:
        return "simpletuner_transformer"

    if auraflow_prefix and underscore_naming:
        return "simpletuner_auraflow"

    if unet_prefix and underscore_naming:
        return "kohya"

    return "unknown"
|
|
|
|
def convert_mixed_lora_to_diffusers(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """
    Convert mixed LoRA format to pure PEFT format.

    SimpleTuner sometimes saves a hybrid format where some layers use lora_A/lora_B
    and others use .lora.down./.lora.up. This converts all to lora_A/lora_B.

    Bug fix: ``.alpha`` tensors used to be discarded outright.  PEFT-style
    checkpoints carry no alpha, so diffusers assumes an effective scale of 1;
    dropping alpha therefore silently changed the LoRA's strength whenever
    alpha != rank.  The effective scale (alpha / rank) is now folded into the
    converted down-projection weight, which is mathematically equivalent
    (output = up @ down * scale) and matches the alpha handling of the other
    converters in this file.

    Args:
        state_dict: Input LoRA state dict (not mutated).

    Returns:
        New state dict using only lora_A/lora_B naming, with alpha scales
        baked into the weights.
    """
    new_state_dict: Dict[str, torch.Tensor] = {}
    converted_count = 0
    kept_count = 0
    folded_count = 0
    renames = []

    # Collect alpha values keyed by layer stem so they can be folded into the
    # matching down weights below instead of being thrown away.
    alphas = {
        key[: -len(".alpha")]: value
        for key, value in state_dict.items()
        if key.endswith(".alpha")
    }

    print("\nProcessing keys:")
    print("-" * 80)

    for key in sorted(state_dict.keys()):
        # Alpha tensors are consumed via `alphas`; PEFT output must not
        # contain them.
        if key.endswith(".alpha"):
            continue

        # Already in PEFT naming — keep as-is.
        if ".lora_A." in key or ".lora_B." in key:
            new_state_dict[key] = state_dict[key]
            kept_count += 1

        elif ".lora.down.weight" in key:
            new_key = key.replace(".lora.down.weight", ".lora_A.weight")
            weight = state_dict[key]
            stem = key.replace(".lora.down.weight", "")
            if stem in alphas:
                # Fold alpha/rank into the down weight; the LoRA rank is the
                # down matrix's first dimension.
                weight = weight * (alphas[stem] / weight.shape[0])
                folded_count += 1
            new_state_dict[new_key] = weight
            renames.append((key, new_key))
            converted_count += 1

        elif ".lora.up.weight" in key:
            new_key = key.replace(".lora.up.weight", ".lora_B.weight")
            new_state_dict[new_key] = state_dict[key]
            renames.append((key, new_key))
            converted_count += 1

        else:
            # Pass through untouched but flag it — the caller should inspect.
            new_state_dict[key] = state_dict[key]
            print(f"⚠ Warning: Unexpected key format: {key}")

    print(f"\nSummary:")
    print(f"  ✓ Kept {kept_count} keys already in correct format (lora_A/lora_B)")
    print(f"  ✓ Converted {converted_count} keys from .lora.down/.lora.up to lora_A/lora_B")
    print(f"  ✓ Folded {folded_count} alpha values into the converted weights")

    if renames:
        print(f"\nRenames applied ({len(renames)} conversions):")
        print("-" * 80)
        for old_key, new_key in renames:
            if ".lora.down.weight" in old_key:
                layer = old_key.replace(".lora.down.weight", "")
                print(f"  {layer}")
                print(f"    .lora.down.weight → .lora_A.weight")
            elif ".lora.up.weight" in old_key:
                layer = old_key.replace(".lora.up.weight", "")
                print(f"  {layer}")
                print(f"    .lora.up.weight → .lora_B.weight")

    return new_state_dict
|
|
|
|
def convert_simpletuner_transformer_to_diffusers(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """
    Convert SimpleTuner transformer format (already has transformer. prefix but uses lora_down/lora_up)
    to diffusers PEFT format (transformer. prefix with lora_A/lora_B).

    This is a simpler conversion since the key structure is already correct.
    Note: converted entries are popped from the input dict.
    """
    converted: Dict[str, torch.Tensor] = {}
    rename_log = []

    # Each LoRA layer is identified by the stem of its .lora_down.weight key.
    layer_stems = {
        key.replace(".lora_down.weight", "")
        for key in state_dict
        if ".lora_down.weight" in key
    }

    print(f"\nFound {len(layer_stems)} LoRA layers to convert")
    print("-" * 80)

    for stem in sorted(layer_stems):
        down_name = f"{stem}.lora_down.weight"
        up_name = f"{stem}.lora_up.weight"
        alpha_name = f"{stem}.alpha"

        if not (down_name in state_dict and up_name in state_dict):
            print(f"⚠ Warning: Missing weights for {stem}")
            continue

        down = state_dict.pop(down_name)
        up = state_dict.pop(up_name)

        # Kohya-style alpha: effective scale is alpha / rank.  Split the scale
        # between down and up to keep the two matrices numerically balanced.
        scaled = alpha_name in state_dict
        if scaled:
            scale = state_dict.pop(alpha_name) / down.shape[0]
            down_scale, up_scale = scale, 1.0
            while down_scale * 2 < up_scale:
                down_scale *= 2
                up_scale /= 2
            down = down * down_scale
            up = up * up_scale

        a_name = f"{stem}.lora_A.weight"
        b_name = f"{stem}.lora_B.weight"
        converted[a_name] = down
        converted[b_name] = up
        rename_log.append((down_name, a_name, scaled))
        rename_log.append((up_name, b_name, scaled))

    # Anything left over (other than text-encoder weights) was not handled.
    leftovers = [key for key in state_dict.keys() if not key.startswith("text_encoder")]
    if leftovers:
        print(f"⚠ Warning: {len(leftovers)} keys were not converted: {leftovers[:5]}")

    print(f"\nRenames applied ({len(rename_log)} conversions):")
    print("-" * 80)

    # Group the log lines by layer for readability.
    previous_layer = None
    for old_name, new_name, scaled in rename_log:
        layer = old_name.replace(".lora_down.weight", "").replace(".lora_up.weight", "")
        if layer != previous_layer:
            alpha_str = " (alpha scaled)" if scaled else ""
            print(f"\n {layer}{alpha_str}")
            previous_layer = layer
        if ".lora_down.weight" in old_name:
            print(f" .lora_down.weight → .lora_A.weight")
        elif ".lora_up.weight" in old_name:
            print(f" .lora_up.weight → .lora_B.weight")

    return converted
|
|
|
|
def convert_simpletuner_auraflow_to_diffusers(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """
    Convert SimpleTuner AuraFlow LoRA format to diffusers PEFT format.

    SimpleTuner typically saves LoRAs in a format similar to Kohya's sd-scripts,
    but for transformer-based models like AuraFlow, the keys may differ.

    Bug fix: a down weight whose matching ``lora_up.weight`` is absent (e.g. a
    truncated or corrupt checkpoint) previously raised ``KeyError`` from an
    unguarded ``pop``; it is now skipped with a warning, consistent with the
    graceful handling used everywhere else in this converter.

    Args:
        state_dict: Kohya-style state dict (mutated: converted entries are popped).

    Returns:
        New state dict with ``transformer.``-prefixed lora_A/lora_B keys.
    """
    new_state_dict: Dict[str, torch.Tensor] = {}

    def _convert(original_key, diffusers_key, state_dict, new_state_dict):
        """Move one layer's down/up pair (and optional alpha) into PEFT naming."""
        down_key = f"{original_key}.lora_down.weight"
        up_key = f"{original_key}.lora_up.weight"
        if down_key not in state_dict:
            return False
        # Guard: don't pop the up weight blindly — a missing pair member
        # should be a warning, not a crash.
        if up_key not in state_dict:
            print(f"Warning: Missing {up_key}; skipping {original_key}")
            return False

        down_weight = state_dict.pop(down_key)
        up_weight = state_dict.pop(up_key)
        lora_rank = down_weight.shape[0]  # rank == first dim of the down matrix

        # Kohya alpha: effective scale is alpha / rank.  Split the scale
        # between down and up to keep the two matrices numerically balanced.
        alpha_key = f"{original_key}.alpha"
        if alpha_key in state_dict:
            alpha = state_dict.pop(alpha_key)
            scale = alpha / lora_rank

            scale_down = scale
            scale_up = 1.0
            while scale_down * 2 < scale_up:
                scale_down *= 2
                scale_up /= 2

            down_weight = down_weight * scale_down
            up_weight = up_weight * scale_up

        new_state_dict[f"{diffusers_key}.lora_A.weight"] = down_weight
        new_state_dict[f"{diffusers_key}.lora_B.weight"] = up_weight
        return True

    # Ordered (substring, diffusers suffix) lookup tables.  Order matters:
    # the first matching substring wins (e.g. attn_to_out_0 must be tested
    # before attn_to_q).
    single_block_map = [
        ("attn_to_q", ".attn.to_q"),
        ("attn_to_k", ".attn.to_k"),
        ("attn_to_v", ".attn.to_v"),
        ("proj_out", ".proj_out"),
        ("proj_mlp", ".proj_mlp"),
        ("norm_linear", ".norm.linear"),
    ]
    double_block_map = [
        ("attn_to_out_0", ".attn.to_out.0"),
        ("attn_to_add_out", ".attn.to_add_out"),
        ("attn_to_q", ".attn.to_q"),
        ("attn_to_k", ".attn.to_k"),
        ("attn_to_v", ".attn.to_v"),
        ("attn_add_q_proj", ".attn.add_q_proj"),
        ("attn_add_k_proj", ".attn.add_k_proj"),
        ("attn_add_v_proj", ".attn.add_v_proj"),
        ("ff_net_0_proj", ".ff.net.0.proj"),
        ("ff_net_2", ".ff.net.2"),
        ("ff_context_net_0_proj", ".ff_context.net.0.proj"),
        ("ff_context_net_2", ".ff_context.net.2"),
        ("norm1_linear", ".norm1.linear"),
        ("norm1_context_linear", ".norm1_context.linear"),
    ]
    single_prefix = "lora_transformer_single_transformer_blocks_"
    double_prefix = "lora_transformer_transformer_blocks_"

    # One stem per layer: strip the three possible suffixes to group
    # down/up/alpha entries together.
    all_unique_keys = {
        k.replace(".lora_down.weight", "").replace(".lora_up.weight", "").replace(".alpha", "")
        for k in state_dict
        if ".lora_down.weight" in k or ".lora_up.weight" in k or ".alpha" in k
    }

    for original_key in sorted(all_unique_keys):
        if original_key.startswith(single_prefix):
            # Single transformer block: lora_transformer_single_transformer_blocks_<N>_<module>
            parts = original_key.split(single_prefix)[-1].split("_")
            diffusers_key = f"single_transformer_blocks.{int(parts[0])}"
            remaining = "_".join(parts[1:])
            suffix = next((s for pattern, s in single_block_map if pattern in remaining), None)
            if suffix is None:
                print(f"Warning: Unhandled single block key pattern: {original_key}")
                continue
            diffusers_key += suffix

        elif original_key.startswith(double_prefix):
            # Joint (double) transformer block: lora_transformer_transformer_blocks_<N>_<module>
            parts = original_key.split(double_prefix)[-1].split("_")
            diffusers_key = f"transformer_blocks.{int(parts[0])}"
            remaining = "_".join(parts[1:])
            suffix = next((s for pattern, s in double_block_map if pattern in remaining), None)
            if suffix is None:
                print(f"Warning: Unhandled double block key pattern: {original_key}")
                continue
            diffusers_key += suffix

        elif original_key.startswith("lora_te1_") or original_key.startswith("lora_te_"):
            # Text encoder weights are reported but intentionally not converted.
            print(f"Found text encoder key: {original_key}")
            continue

        else:
            print(f"Warning: Unknown key pattern: {original_key}")
            continue

        _convert(original_key, diffusers_key, state_dict, new_state_dict)

    # diffusers expects the transformer LoRA under a "transformer." prefix.
    transformer_state_dict = {
        f"transformer.{k}": v for k, v in new_state_dict.items() if not k.startswith("text_model.")
    }

    # Report anything (other than text-encoder keys) that survived conversion.
    if len(state_dict) > 0:
        remaining_keys = [k for k in state_dict.keys() if not k.startswith("lora_te")]
        if remaining_keys:
            print(f"Warning: Some keys were not converted: {remaining_keys[:10]}")

    return transformer_state_dict
|
|
|
|
def convert_lora(input_path: str, output_path: str) -> None:
    """
    Main conversion function.

    Args:
        input_path: Path to input LoRA safetensors file
        output_path: Path to output diffusers-compatible safetensors file
    """
    print(f"Loading LoRA from: {input_path}")
    state_dict = safetensors.torch.load_file(input_path)

    print("Detecting LoRA format...")
    format_type = detect_lora_format(state_dict)
    print(f"Detected format: {format_type}")

    # Already-compatible files need no rewrite — just copy them.
    if format_type == "peft":
        print("LoRA is already in diffusers-compatible PEFT format!")
        print("No conversion needed. Copying file...")
        import shutil
        shutil.copy(input_path, output_path)
        return

    if format_type == "unknown":
        print("Error: Unknown LoRA format!")
        print("Sample keys from the state dict:")
        for key in list(state_dict.keys())[:20]:
            print(f" {key}")
        sys.exit(1)

    # Per-format banner lines and the converter to dispatch to.  Each
    # converter receives a copy so the original dict stays intact for the
    # key-count report below.
    dispatch = {
        "mixed": (
            [
                "Converting MIXED format LoRA to pure PEFT format...",
                "(Some layers use lora_A/B, others use .lora.down/.lora.up)",
            ],
            convert_mixed_lora_to_diffusers,
        ),
        "simpletuner_transformer": (
            [
                "Converting SimpleTuner transformer format to diffusers...",
                "(has transformer. prefix but uses lora_down/lora_up naming)",
            ],
            convert_simpletuner_transformer_to_diffusers,
        ),
        "simpletuner_auraflow": (
            ["Converting SimpleTuner AuraFlow format to diffusers..."],
            convert_simpletuner_auraflow_to_diffusers,
        ),
        "kohya": (
            [
                "Note: Detected Kohya format. This converter is optimized for AuraFlow.",
                "For other models, diffusers has built-in conversion.",
            ],
            convert_simpletuner_auraflow_to_diffusers,
        ),
    }

    banner_lines, converter = dispatch[format_type]
    for line in banner_lines:
        print(line)
    converted_state_dict = converter(state_dict.copy())

    print(f"Saving converted LoRA to: {output_path}")
    safetensors.torch.save_file(converted_state_dict, output_path)

    print("\nConversion complete!")
    print(f"Original keys: {len(state_dict)}")
    print(f"Converted keys: {len(converted_state_dict)}")
|
|
def main():
    """CLI entry point: parse arguments, optionally dry-run, then convert.

    Fix: the epilog's "Check format without converting" example previously
    omitted the --dry-run flag, so the documented command would actually
    perform a conversion.
    """
    parser = argparse.ArgumentParser(
        description="Convert SimpleTuner LoRA to diffusers-compatible format",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Convert a SimpleTuner LoRA for AuraFlow
  python convert_simpletuner_lora.py my_lora.safetensors diffusers_lora.safetensors

  # Check format without converting
  python convert_simpletuner_lora.py my_lora.safetensors /tmp/test.safetensors --dry-run
"""
    )

    parser.add_argument(
        "input",
        type=str,
        help="Input LoRA file (SimpleTuner format)"
    )

    parser.add_argument(
        "output",
        type=str,
        help="Output LoRA file (diffusers-compatible format)"
    )

    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Only detect format, don't convert"
    )

    args = parser.parse_args()

    # Fail fast with a clear message instead of a load_file traceback.
    if not Path(args.input).exists():
        print(f"Error: Input file not found: {args.input}")
        sys.exit(1)

    if args.dry_run:
        # Inspect only: load, report the detected format and a key sample.
        print(f"Loading LoRA from: {args.input}")
        state_dict = safetensors.torch.load_file(args.input)
        format_type = detect_lora_format(state_dict)
        print(f"Detected format: {format_type}")
        print(f"\nSample keys ({min(10, len(state_dict))} of {len(state_dict)}):")
        for key in list(state_dict.keys())[:10]:
            print(f" {key}")
        return

    convert_lora(args.input, args.output)
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|
|
|