"""
LoRA Manager for FLUX.1 Space - Advanced LoRA handling and integration
"""
|
|
| import torch |
| from safetensors.torch import load_file, save_file |
| import os |
| import json |
| from typing import Dict, List, Optional, Tuple |
| import numpy as np |
|
|
| class FluxLoRAManager: |
| """ |
| Advanced LoRA manager for FLUX models |
| """ |
| |
| def __init__(self): |
| self.loaded_loras = {} |
| self.lora_metadata = {} |
| self.device = "cuda" if torch.cuda.is_available() else "cpu" |
| |
| def load_lora_file(self, lora_path: str, lora_name: str = None) -> Dict: |
| """ |
| Load a LoRA file and extract metadata |
| """ |
| try: |
| print(f"π Loading LoRA file: {lora_path}") |
| |
| |
| lora_state_dict = load_file(lora_path) |
| |
| |
| metadata = self._extract_lora_metadata(lora_path, lora_state_dict) |
| |
| |
| if lora_name is None: |
| lora_name = os.path.splitext(os.path.basename(lora_path))[0] |
| |
| |
| self.loaded_loras[lora_name] = { |
| 'path': lora_path, |
| 'state_dict': lora_state_dict, |
| 'metadata': metadata, |
| 'strength': 1.0, |
| 'active': False |
| } |
| |
| print(f"β
LoRA '{lora_name}' loaded successfully") |
| print(f"π Metadata: {metadata}") |
| |
| return { |
| 'name': lora_name, |
| 'metadata': metadata, |
| 'success': True |
| } |
| |
| except Exception as e: |
| print(f"β Error loading LoRA: {e}") |
| return { |
| 'name': lora_name, |
| 'error': str(e), |
| 'success': False |
| } |
| |
| def _extract_lora_metadata(self, lora_path: str, state_dict: Dict) -> Dict: |
| """ |
| Extract metadata from LoRA file |
| """ |
| metadata = { |
| 'filename': os.path.basename(lora_path), |
| 'file_size_mb': os.path.getsize(lora_path) / (1024 * 1024), |
| 'tensor_count': len(state_dict), |
| 'tensor_names': list(state_dict.keys()), |
| 'base_model': 'unknown', |
| 'training_info': {} |
| } |
| |
| |
| json_path = lora_path.replace('.safetensors', '.json') |
| if os.path.exists(json_path): |
| try: |
| with open(json_path, 'r') as f: |
| json_metadata = json.load(f) |
| metadata.update(json_metadata) |
| except: |
| pass |
| |
| |
| if any('double_blocks' in key for key in state_dict.keys()): |
| metadata['base_model'] = 'FLUX' |
| elif any('unet' in key for key in state_dict.keys()): |
| metadata['base_model'] = 'Stable Diffusion' |
| |
| return metadata |
| |
| def apply_lora_to_model(self, lora_name: str, model_pipeline, strength: float = 1.0) -> bool: |
| """ |
| Apply a LoRA to a model pipeline |
| """ |
| if lora_name not in self.loaded_loras: |
| print(f"β LoRA '{lora_name}' not loaded") |
| return False |
| |
| try: |
| print(f"π Applying LoRA '{lora_name}' with strength {strength}") |
| |
| lora_data = self.loaded_loras[lora_name] |
| state_dict = lora_data['state_dict'] |
| |
| |
| for key, value in state_dict.items(): |
| if key in model_pipeline.state_dict(): |
| |
| scaled_value = value * strength |
| model_pipeline.state_dict()[key].copy_(scaled_value) |
| |
| |
| lora_data['strength'] = strength |
| lora_data['active'] = True |
| |
| print(f"β
LoRA '{lora_name}' applied successfully") |
| return True |
| |
| except Exception as e: |
| print(f"β Error applying LoRA: {e}") |
| return False |
| |
| def remove_lora_from_model(self, lora_name: str, model_pipeline) -> bool: |
| """ |
| Remove a LoRA from a model pipeline |
| """ |
| if lora_name not in self.loaded_loras: |
| print(f"β LoRA '{lora_name}' not loaded") |
| return False |
| |
| try: |
| print(f"π Removing LoRA '{lora_name}'") |
| |
| lora_data = self.loaded_loras[lora_name] |
| state_dict = lora_data['state_dict'] |
| |
| |
| for key, value in state_dict.items(): |
| if key in model_pipeline.state_dict(): |
| model_pipeline.state_dict()[key].zero_() |
| |
| |
| lora_data['active'] = False |
| |
| print(f"β
LoRA '{lora_name}' removed successfully") |
| return True |
| |
| except Exception as e: |
| print(f"β Error removing LoRA: {e}") |
| return False |
| |
| def blend_loras(self, lora_names: List[str], weights: List[float]) -> Dict: |
| """ |
| Blend multiple LoRAs with specified weights |
| """ |
| if len(lora_names) != len(weights): |
| print("β Number of LoRAs and weights must match") |
| return {'success': False, 'error': 'Mismatched arrays'} |
| |
| try: |
| print(f"π Blending LoRAs: {lora_names}") |
| print(f"π Weights: {weights}") |
| |
| |
| total_weight = sum(weights) |
| normalized_weights = [w / total_weight for w in weights] |
| |
| |
| all_keys = set() |
| for lora_name in lora_names: |
| if lora_name in self.loaded_loras: |
| all_keys.update(self.loaded_loras[lora_name]['state_dict'].keys()) |
| |
| |
| blended_state_dict = {} |
| for key in all_keys: |
| blended_tensor = None |
| for lora_name, weight in zip(lora_names, normalized_weights): |
| if lora_name in self.loaded_loras: |
| lora_state_dict = self.loaded_loras[lora_name]['state_dict'] |
| if key in lora_state_dict: |
| if blended_tensor is None: |
| blended_tensor = lora_state_dict[key] * weight |
| else: |
| blended_tensor += lora_state_dict[key] * weight |
| |
| if blended_tensor is not None: |
| blended_state_dict[key] = blended_tensor |
| |
| |
| blended_name = f"blended_{'_'.join(lora_names)}" |
| |
| |
| self.loaded_loras[blended_name] = { |
| 'path': 'blended', |
| 'state_dict': blended_state_dict, |
| 'metadata': { |
| 'blended_from': lora_names, |
| 'weights': normalized_weights, |
| 'base_model': 'FLUX' |
| }, |
| 'strength': 1.0, |
| 'active': False |
| } |
| |
| print(f"β
Blended LoRA '{blended_name}' created successfully") |
| return { |
| 'success': True, |
| 'blended_name': blended_name, |
| 'tensor_count': len(blended_state_dict) |
| } |
| |
| except Exception as e: |
| print(f"β Error blending LoRAs: {e}") |
| return {'success': False, 'error': str(e)} |
| |
| def get_lora_info(self, lora_name: str) -> Dict: |
| """ |
| Get detailed information about a loaded LoRA |
| """ |
| if lora_name not in self.loaded_loras: |
| return {'error': f"LoRA '{lora_name}' not found"} |
| |
| lora_data = self.loaded_loras[lora_name] |
| return { |
| 'name': lora_name, |
| 'path': lora_data['path'], |
| 'active': lora_data['active'], |
| 'strength': lora_data['strength'], |
| 'metadata': lora_data['metadata'] |
| } |
| |
| def get_all_loras_info(self) -> List[Dict]: |
| """ |
| Get information about all loaded LoRAs |
| """ |
| return [self.get_lora_info(name) for name in self.loaded_loras.keys()] |
| |
| def save_blended_lora(self, blended_name: str, output_path: str) -> bool: |
| """ |
| Save a blended LoRA to file |
| """ |
| if blended_name not in self.loaded_loras: |
| print(f"β Blended LoRA '{blended_name}' not found") |
| return False |
| |
| try: |
| print(f"πΎ Saving blended LoRA to: {output_path}") |
| |
| lora_data = self.loaded_loras[blended_name] |
| state_dict = lora_data['state_dict'] |
| metadata = lora_data['metadata'] |
| |
| |
| save_file(state_dict, output_path, metadata=metadata) |
| |
| |
| json_path = output_path.replace('.safetensors', '.json') |
| with open(json_path, 'w') as f: |
| json.dump(metadata, f, indent=2) |
| |
| print(f"β
Blended LoRA saved successfully") |
| return True |
| |
| except Exception as e: |
| print(f"β Error saving blended LoRA: {e}") |
| return False |
|
|
| |
def create_lora_manager():
    """Factory: build and return a fresh FluxLoRAManager instance."""
    manager = FluxLoRAManager()
    return manager
|
|
def validate_lora_file(file_path: str) -> Dict:
    """
    Validate a LoRA file before loading.

    Checks existence and the .safetensors extension, then attempts a
    full safetensors load. Returns {'valid': True, 'tensor_count': ...,
    'file_size_mb': ...} on success, or {'valid': False, 'error': ...}.
    """
    try:
        if not os.path.exists(file_path):
            return {'valid': False, 'error': 'File not found'}

        if not file_path.endswith('.safetensors'):
            return {'valid': False, 'error': 'File must be .safetensors format'}

        tensors = load_file(file_path)
        size_mb = os.path.getsize(file_path) / (1024 * 1024)
        return {
            'valid': True,
            'tensor_count': len(tensors),
            'file_size_mb': size_mb,
        }

    except Exception as exc:
        return {'valid': False, 'error': str(exc)}
|
|