#!/usr/bin/env python3
"""
LoRA Manager for FLUX.1 Space - Advanced LoRA handling and integration
"""

import json
import os
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
from safetensors.torch import load_file, save_file

# Key substrings that identify a FLUX LoRA. "double_blocks"/"single_blocks"
# cover the original/kohya naming (e.g. "lora_unet_single_blocks_0_...",
# which would otherwise be caught by the "unet" check below) and
# "single_transformer_blocks" covers the diffusers naming.
_FLUX_KEY_MARKERS = ('double_blocks', 'single_blocks', 'single_transformer_blocks')


def _json_sidecar_path(path: str) -> str:
    """
    Return the path of the JSON metadata file that accompanies ``path``.

    Only the final extension is swapped (``a/model.safetensors`` ->
    ``a/model.json``). Unlike ``str.replace`` this never touches a
    ``.safetensors`` substring earlier in the path, and it never returns
    ``path`` itself, which would let a writer clobber the weights file.
    """
    root, ext = os.path.splitext(path)
    if ext.lower() == '.json':
        return path + '.json'
    return root + '.json'


def _stringify_metadata(metadata: Dict) -> Dict[str, str]:
    """
    Convert a metadata dict into the ``Dict[str, str]`` safetensors accepts.

    ``safetensors.torch.save_file`` only takes string keys and values, so
    strings pass through and everything else (lists, numbers, nested dicts)
    is JSON-encoded.
    """
    return {
        str(key): value if isinstance(value, str) else json.dumps(value)
        for key, value in metadata.items()
    }


class FluxLoRAManager:
    """
    Advanced LoRA manager for FLUX models.

    Holds loaded LoRA state dicts in memory (``self.loaded_loras``, keyed by
    name) and copies them into, or restores them from, a model object that
    exposes ``state_dict()``.
    """

    def __init__(self):
        # name -> {'path', 'state_dict', 'metadata', 'strength', 'active'}
        self.loaded_loras = {}
        self.lora_metadata = {}
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # name -> {model_key: CPU copy of the value apply_lora_to_model
        # overwrote}, so remove_lora_from_model can restore it.
        self._original_weights = {}

    def load_lora_file(self, lora_path: str, lora_name: Optional[str] = None) -> Dict:
        """
        Load a LoRA ``.safetensors`` file and extract its metadata.

        Args:
            lora_path: Path to the safetensors file.
            lora_name: Registry name; defaults to the file name without its
                extension.

        Returns:
            ``{'name', 'metadata', 'success': True}`` on success, otherwise
            ``{'name', 'error', 'success': False}``.
        """
        try:
            # Resolve the name before loading so a failure result reports it.
            if lora_name is None:
                lora_name = os.path.splitext(os.path.basename(lora_path))[0]

            print(f"🔄 Loading LoRA file: {lora_path}")

            lora_state_dict = load_file(lora_path)
            metadata = self._extract_lora_metadata(lora_path, lora_state_dict)

            self.loaded_loras[lora_name] = {
                'path': lora_path,
                'state_dict': lora_state_dict,
                'metadata': metadata,
                'strength': 1.0,
                'active': False
            }

            print(f"✅ LoRA '{lora_name}' loaded successfully")
            print(f"📊 Metadata: {metadata}")

            return {
                'name': lora_name,
                'metadata': metadata,
                'success': True
            }

        except Exception as e:
            print(f"❌ Error loading LoRA: {e}")
            return {
                'name': lora_name,
                'error': str(e),
                'success': False
            }

    def _extract_lora_metadata(self, lora_path: str, state_dict: Dict) -> Dict:
        """
        Build a metadata dict for a LoRA file.

        Starts from file statistics and the tensor key list, merges a JSON
        sidecar (``<stem>.json``) if one exists and holds an object, then
        guesses ``base_model`` from the tensor key names. The key-based guess
        runs last, so it overrides any ``base_model`` from the sidecar.
        """
        metadata = {
            'filename': os.path.basename(lora_path),
            'file_size_mb': os.path.getsize(lora_path) / (1024 * 1024),
            'tensor_count': len(state_dict),
            'tensor_names': list(state_dict.keys()),
            'base_model': 'unknown',
            'training_info': {}
        }

        json_path = _json_sidecar_path(lora_path)
        if os.path.exists(json_path):
            try:
                with open(json_path, 'r', encoding='utf-8') as f:
                    json_metadata = json.load(f)
            except (OSError, ValueError) as e:
                # Sidecar metadata is optional: report it and keep going.
                print(f"⚠️ Ignoring unreadable metadata file {json_path}: {e}")
            else:
                if isinstance(json_metadata, dict):
                    metadata.update(json_metadata)
                else:
                    print(f"⚠️ Ignoring metadata file {json_path}: expected a JSON object")

        # Analyze tensor key names to guess the base model.
        if any(marker in key for key in state_dict for marker in _FLUX_KEY_MARKERS):
            metadata['base_model'] = 'FLUX'
        elif any('unet' in key for key in state_dict):
            metadata['base_model'] = 'Stable Diffusion'

        return metadata

    def apply_lora_to_model(self, lora_name: str, model_pipeline, strength: float = 1.0) -> bool:
        """
        Copy a loaded LoRA's tensors, scaled by ``strength``, into a model.

        A tensor is written only if its key is in ``model_pipeline.state_dict()``
        and the shapes match. Other tensors are skipped. Before a value is
        overwritten for the first time it is copied to CPU, so
        ``remove_lora_from_model`` can restore it.

        NOTE(review): this is a direct key-for-key overwrite, not the usual
        ``W + scale * B @ A`` LoRA merge. It only makes sense if the model's
        state dict already contains parameters named like the LoRA's keys
        (e.g. injected adapter layers). Confirm against the callers.

        Returns:
            True if at least one tensor was applied. False if the LoRA is not
            loaded, nothing matched, or an error occurred.
        """
        if lora_name not in self.loaded_loras:
            print(f"❌ LoRA '{lora_name}' not loaded")
            return False

        try:
            print(f"🔄 Applying LoRA '{lora_name}' with strength {strength}")

            lora_data = self.loaded_loras[lora_name]
            # state_dict() rebuilds its mapping on every call; fetch it once.
            model_state = model_pipeline.state_dict()
            backup = self._original_weights.setdefault(lora_name, {})

            applied = 0
            with torch.no_grad():
                for key, value in lora_data['state_dict'].items():
                    target = model_state.get(key)
                    if target is None:
                        continue
                    if target.shape != value.shape:
                        # Check up front so a mismatch can't abort the loop
                        # with the model already partially overwritten.
                        print(f"⚠️ Skipping '{key}': LoRA shape {tuple(value.shape)} "
                              f"!= model shape {tuple(target.shape)}")
                        continue
                    if key not in backup:
                        # Keep only the value from before the first apply.
                        backup[key] = target.detach().to('cpu', copy=True)
                    target.copy_(value * strength)
                    applied += 1

            if applied == 0:
                print(f"⚠️ LoRA '{lora_name}' has no tensors matching the model; nothing applied")
                return False

            lora_data['strength'] = strength
            lora_data['active'] = True

            print(f"✅ LoRA '{lora_name}' applied successfully ({applied} tensors)")
            return True

        except Exception as e:
            print(f"❌ Error applying LoRA: {e}")
            return False

    def remove_lora_from_model(self, lora_name: str, model_pipeline) -> bool:
        """
        Undo ``apply_lora_to_model`` for a LoRA.

        If this manager applied the LoRA, every tensor it overwrote is restored
        from the saved copy. If no copy exists (the LoRA was never applied
        through this manager), the old behaviour is kept: each model tensor
        whose key matches a LoRA key is zeroed.

        Returns:
            True on success, False if the LoRA is not loaded or an error occurred.
        """
        if lora_name not in self.loaded_loras:
            print(f"❌ LoRA '{lora_name}' not loaded")
            return False

        try:
            print(f"🔄 Removing LoRA '{lora_name}'")

            lora_data = self.loaded_loras[lora_name]
            model_state = model_pipeline.state_dict()
            backup = self._original_weights.get(lora_name)

            with torch.no_grad():
                if backup is not None:
                    for key, original in backup.items():
                        target = model_state.get(key)
                        if target is not None:
                            target.copy_(original)
                else:
                    # Legacy fallback: zeroing destroys real weights if these
                    # keys are base-model parameters rather than adapters.
                    for key in lora_data['state_dict']:
                        target = model_state.get(key)
                        if target is not None:
                            target.zero_()

            # Drop the snapshot only after a successful restore.
            self._original_weights.pop(lora_name, None)
            lora_data['active'] = False

            print(f"✅ LoRA '{lora_name}' removed successfully")
            return True

        except Exception as e:
            print(f"❌ Error removing LoRA: {e}")
            return False

    def blend_loras(self, lora_names: List[str], weights: List[float]) -> Dict:
        """
        Blend several loaded LoRAs into a new LoRA using normalized weights.

        The weights are divided by their sum. For every tensor key present in
        any of the LoRAs, the result is the weighted sum over the LoRAs that
        contain that key. The blend is stored as ``blended_<name1>_<name2>...``.

        NOTE(review): if the tensors are low-rank factor pairs (e.g.
        lora_A/lora_B), blending each factor separately does not equal blending
        their products, so treat the result as an approximation. Tensors with
        mismatched shapes (e.g. different ranks) raise and are reported as
        errors.

        Returns:
            ``{'success': True, 'blended_name', 'tensor_count'}`` or
            ``{'success': False, 'error'}``.
        """
        if len(lora_names) != len(weights):
            print("❌ Number of LoRAs and weights must match")
            return {'success': False, 'error': 'Mismatched arrays'}

        if not lora_names:
            print("❌ At least one LoRA is required for blending")
            return {'success': False, 'error': 'No LoRAs to blend'}

        # Reject unknown names. Skipping them would still count their weights
        # in the normalization and under-weight the LoRAs that do exist.
        missing = [name for name in lora_names if name not in self.loaded_loras]
        if missing:
            print(f"❌ LoRAs not loaded: {missing}")
            return {'success': False, 'error': f"LoRAs not loaded: {missing}"}

        total_weight = sum(weights)
        if total_weight == 0:
            print("❌ Weights must not sum to zero")
            return {'success': False, 'error': 'Weights sum to zero'}

        try:
            print(f"🔄 Blending LoRAs: {lora_names}")
            print(f"📊 Weights: {weights}")

            normalized_weights = [w / total_weight for w in weights]

            blended_state_dict = {}
            for lora_name, weight in zip(lora_names, normalized_weights):
                for key, tensor in self.loaded_loras[lora_name]['state_dict'].items():
                    contribution = tensor * weight
                    if key in blended_state_dict:
                        # The stored tensor is a fresh product, never a source
                        # tensor, so accumulating in place is safe.
                        blended_state_dict[key] += contribution
                    else:
                        blended_state_dict[key] = contribution

            blended_name = f"blended_{'_'.join(lora_names)}"

            self.loaded_loras[blended_name] = {
                'path': 'blended',
                'state_dict': blended_state_dict,
                'metadata': {
                    'blended_from': lora_names,
                    'weights': normalized_weights,
                    'base_model': 'FLUX'
                },
                'strength': 1.0,
                'active': False
            }

            print(f"✅ Blended LoRA '{blended_name}' created successfully")
            return {
                'success': True,
                'blended_name': blended_name,
                'tensor_count': len(blended_state_dict)
            }

        except Exception as e:
            print(f"❌ Error blending LoRAs: {e}")
            return {'success': False, 'error': str(e)}

    def get_lora_info(self, lora_name: str) -> Dict:
        """
        Return name, path, active flag, strength and metadata for a loaded LoRA.

        Returns ``{'error': ...}`` if no LoRA with that name is loaded.
        """
        if lora_name not in self.loaded_loras:
            return {'error': f"LoRA '{lora_name}' not found"}

        lora_data = self.loaded_loras[lora_name]
        return {
            'name': lora_name,
            'path': lora_data['path'],
            'active': lora_data['active'],
            'strength': lora_data['strength'],
            'metadata': lora_data['metadata']
        }

    def get_all_loras_info(self) -> List[Dict]:
        """Return ``get_lora_info`` for every loaded LoRA, in load order."""
        return [self.get_lora_info(name) for name in self.loaded_loras]

    def save_blended_lora(self, blended_name: str, output_path: str) -> bool:
        """
        Save a registered LoRA (typically a blend) plus a JSON metadata sidecar.

        The tensors go to ``output_path``. The metadata is written both into
        the safetensors header (values stringified, as safetensors requires)
        and, unmodified, to ``<stem>.json`` next to it.

        Returns:
            True on success, False if the name is unknown or writing fails.
        """
        if blended_name not in self.loaded_loras:
            print(f"❌ Blended LoRA '{blended_name}' not found")
            return False

        try:
            print(f"💾 Saving blended LoRA to: {output_path}")

            lora_data = self.loaded_loras[blended_name]
            state_dict = lora_data['state_dict']
            metadata = lora_data['metadata']

            # Header metadata must be Dict[str, str]. The raw blend metadata
            # holds lists and floats, which save_file rejects.
            save_file(state_dict, output_path, metadata=_stringify_metadata(metadata))

            # Swap only the extension, so the sidecar can never overwrite the
            # weights file that was just written.
            json_path = _json_sidecar_path(output_path)
            with open(json_path, 'w', encoding='utf-8') as f:
                json.dump(metadata, f, indent=2)

            print(f"✅ Blended LoRA saved successfully")
            return True

        except Exception as e:
            print(f"❌ Error saving blended LoRA: {e}")
            return False


# Utility functions for Gradio integration
def create_lora_manager():
    """Create and return a new ``FluxLoRAManager``."""
    return FluxLoRAManager()


def validate_lora_file(file_path: str) -> Dict:
    """
    Check that ``file_path`` exists, ends in ``.safetensors`` and loads.

    Returns:
        ``{'valid': True, 'tensor_count', 'file_size_mb'}`` on success,
        otherwise ``{'valid': False, 'error'}``.
    """
    try:
        if not os.path.exists(file_path):
            return {'valid': False, 'error': 'File not found'}

        if not file_path.endswith('.safetensors'):
            return {'valid': False, 'error': 'File must be .safetensors format'}

        # Fully loading the file is the validity check.
        state_dict = load_file(file_path)

        return {
            'valid': True,
            'tensor_count': len(state_dict),
            'file_size_mb': os.path.getsize(file_path) / (1024 * 1024)
        }

    except Exception as e:
        return {'valid': False, 'error': str(e)}