# flux.dev.1 / flux_space_lora_manager.py
# Uploaded by annoyingpixel with huggingface_hub (commit a4355c4, verified)
#!/usr/bin/env python3
"""
LoRA Manager for FLUX.1 Space - Advanced LoRA handling and integration
"""
import torch
from safetensors.torch import load_file, save_file
import os
import json
from typing import Dict, List, Optional, Tuple
import numpy as np
class FluxLoRAManager:
    """
    Advanced LoRA manager for FLUX models.

    Keeps a registry of loaded LoRA weight sets (``self.loaded_loras``),
    keyed by name, and provides helpers to apply/remove them on a
    pipeline's state dict, blend several LoRAs into a new one, and
    persist blended results to disk.
    """
    def __init__(self):
        # name -> {'path', 'state_dict', 'metadata', 'strength', 'active'}
        self.loaded_loras = {}
        self.lora_metadata = {}
        # Preferred device; stored for callers, not used directly here.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def load_lora_file(self, lora_path: str, lora_name: Optional[str] = None) -> Dict:
        """
        Load a LoRA .safetensors file and register it under *lora_name*.

        Args:
            lora_path: Path to the .safetensors file.
            lora_name: Registry key; defaults to the file's basename stem.

        Returns:
            Dict with 'success': True plus 'name'/'metadata', or
            'success': False plus 'error' on failure.
        """
        try:
            print(f"πŸ”„ Loading LoRA file: {lora_path}")
            # Load LoRA weights
            lora_state_dict = load_file(lora_path)
            # Extract metadata (file stats + sidecar JSON + base-model guess)
            metadata = self._extract_lora_metadata(lora_path, lora_state_dict)
            # Generate name from the filename if not provided
            if lora_name is None:
                lora_name = os.path.splitext(os.path.basename(lora_path))[0]
            # Register the LoRA; a repeated name overwrites the old entry.
            self.loaded_loras[lora_name] = {
                'path': lora_path,
                'state_dict': lora_state_dict,
                'metadata': metadata,
                'strength': 1.0,
                'active': False
            }
            print(f"βœ… LoRA '{lora_name}' loaded successfully")
            print(f"πŸ“Š Metadata: {metadata}")
            return {
                'name': lora_name,
                'metadata': metadata,
                'success': True
            }
        except Exception as e:
            print(f"❌ Error loading LoRA: {e}")
            # 'name' may be None here if the failure happened before
            # the default name was derived.
            return {
                'name': lora_name,
                'error': str(e),
                'success': False
            }

    def _extract_lora_metadata(self, lora_path: str, state_dict: Dict) -> Dict:
        """
        Build a metadata dict for a loaded LoRA: file stats, tensor names,
        optional sidecar JSON contents, and a guessed base model.
        """
        metadata = {
            'filename': os.path.basename(lora_path),
            'file_size_mb': os.path.getsize(lora_path) / (1024 * 1024),
            'tensor_count': len(state_dict),
            'tensor_names': list(state_dict.keys()),
            'base_model': 'unknown',
            'training_info': {}
        }
        # Merge sidecar JSON metadata if present. splitext is safer than
        # str.replace, which could also match '.safetensors' mid-path.
        json_path = os.path.splitext(lora_path)[0] + '.json'
        if os.path.exists(json_path):
            try:
                with open(json_path, 'r') as f:
                    metadata.update(json.load(f))
            except (OSError, ValueError) as e:
                # Sidecar metadata is best-effort; keep file-level info
                # but don't fail the load. ValueError covers JSONDecodeError.
                print(f"⚠️ Could not read sidecar metadata {json_path}: {e}")
        # Heuristic base-model detection: FLUX checkpoints use
        # 'double_blocks' layer names, SD-style LoRAs reference 'unet'.
        if any('double_blocks' in key for key in state_dict.keys()):
            metadata['base_model'] = 'FLUX'
        elif any('unet' in key for key in state_dict.keys()):
            metadata['base_model'] = 'Stable Diffusion'
        return metadata

    def apply_lora_to_model(self, lora_name: str, model_pipeline, strength: float = 1.0) -> bool:
        """
        Copy the LoRA's tensors, scaled by *strength*, into every matching
        key of *model_pipeline*'s state dict.

        NOTE(review): this overwrites matching weights rather than adding a
        low-rank delta on top of the base weights; LoRA keys that don't
        exist in the pipeline are silently skipped.

        Returns:
            True on success, False if the LoRA isn't loaded or copying fails.
        """
        if lora_name not in self.loaded_loras:
            print(f"❌ LoRA '{lora_name}' not loaded")
            return False
        try:
            print(f"πŸ”„ Applying LoRA '{lora_name}' with strength {strength}")
            lora_data = self.loaded_loras[lora_name]
            state_dict = lora_data['state_dict']
            # Fetch the target state dict once; calling state_dict() per key
            # rebuilds the whole mapping on every iteration.
            model_state = model_pipeline.state_dict()
            # no_grad is required: in-place copy_ into leaf parameters that
            # require grad raises a RuntimeError otherwise.
            with torch.no_grad():
                for key, value in state_dict.items():
                    target = model_state.get(key)
                    if target is not None:
                        # Match device/dtype so CPU-loaded LoRA tensors can be
                        # copied into a GPU/half-precision model.
                        target.copy_(value.to(target.device, target.dtype) * strength)
            # Update LoRA status
            lora_data['strength'] = strength
            lora_data['active'] = True
            print(f"βœ… LoRA '{lora_name}' applied successfully")
            return True
        except Exception as e:
            print(f"❌ Error applying LoRA: {e}")
            return False

    def remove_lora_from_model(self, lora_name: str, model_pipeline) -> bool:
        """
        Zero out every pipeline tensor whose key appears in the LoRA.

        NOTE(review): zeroing destroys the original weights rather than
        restoring them; the caller is expected to reload the base model
        to fully undo a LoRA.

        Returns:
            True on success, False if the LoRA isn't loaded or zeroing fails.
        """
        if lora_name not in self.loaded_loras:
            print(f"❌ LoRA '{lora_name}' not loaded")
            return False
        try:
            print(f"πŸ”„ Removing LoRA '{lora_name}'")
            lora_data = self.loaded_loras[lora_name]
            state_dict = lora_data['state_dict']
            # Fetch once instead of per key (see apply_lora_to_model).
            model_state = model_pipeline.state_dict()
            # no_grad: zero_ on grad-requiring leaf params raises otherwise.
            with torch.no_grad():
                for key in state_dict:  # values unused; iterate keys only
                    if key in model_state:
                        model_state[key].zero_()
            # Update LoRA status
            lora_data['active'] = False
            print(f"βœ… LoRA '{lora_name}' removed successfully")
            return True
        except Exception as e:
            print(f"❌ Error removing LoRA: {e}")
            return False

    def blend_loras(self, lora_names: List[str], weights: List[float]) -> Dict:
        """
        Create a new LoRA whose tensors are the weight-normalized sum of the
        named LoRAs, registered under 'blended_<name1>_<name2>_...'.

        Args:
            lora_names: Names of previously loaded LoRAs to blend.
            weights: One blend weight per name (normalized to sum to 1).

        Returns:
            {'success': True, 'blended_name', 'tensor_count'} or
            {'success': False, 'error': <reason>}.
        """
        if len(lora_names) != len(weights):
            print("❌ Number of LoRAs and weights must match")
            return {'success': False, 'error': 'Mismatched arrays'}
        try:
            print(f"πŸ”„ Blending LoRAs: {lora_names}")
            print(f"πŸ“Š Weights: {weights}")
            # Normalize weights, guarding the division against a zero sum.
            total_weight = sum(weights)
            if total_weight == 0:
                print("❌ Blend weights sum to zero")
                return {'success': False, 'error': 'Weights sum to zero'}
            normalized_weights = [w / total_weight for w in weights]
            # Warn (but continue best-effort) about names never loaded.
            missing = [n for n in lora_names if n not in self.loaded_loras]
            if missing:
                print(f"⚠️ Skipping LoRAs that are not loaded: {missing}")
            # Union of tensor keys across all participating LoRAs.
            all_keys = set()
            for lora_name in lora_names:
                if lora_name in self.loaded_loras:
                    all_keys.update(self.loaded_loras[lora_name]['state_dict'].keys())
            # Weighted sum per key; a LoRA missing a key contributes nothing.
            blended_state_dict = {}
            for key in all_keys:
                blended_tensor = None
                for lora_name, weight in zip(lora_names, normalized_weights):
                    if lora_name in self.loaded_loras:
                        lora_state_dict = self.loaded_loras[lora_name]['state_dict']
                        if key in lora_state_dict:
                            if blended_tensor is None:
                                blended_tensor = lora_state_dict[key] * weight
                            else:
                                blended_tensor = blended_tensor + lora_state_dict[key] * weight
                if blended_tensor is not None:
                    blended_state_dict[key] = blended_tensor
            # Register the blend like any other loaded LoRA.
            blended_name = f"blended_{'_'.join(lora_names)}"
            self.loaded_loras[blended_name] = {
                'path': 'blended',
                'state_dict': blended_state_dict,
                'metadata': {
                    'blended_from': lora_names,
                    'weights': normalized_weights,
                    'base_model': 'FLUX'
                },
                'strength': 1.0,
                'active': False
            }
            print(f"βœ… Blended LoRA '{blended_name}' created successfully")
            return {
                'success': True,
                'blended_name': blended_name,
                'tensor_count': len(blended_state_dict)
            }
        except Exception as e:
            print(f"❌ Error blending LoRAs: {e}")
            return {'success': False, 'error': str(e)}

    def get_lora_info(self, lora_name: str) -> Dict:
        """
        Return a summary dict (name, path, active, strength, metadata)
        for a loaded LoRA, or {'error': ...} if it isn't registered.
        """
        if lora_name not in self.loaded_loras:
            return {'error': f"LoRA '{lora_name}' not found"}
        lora_data = self.loaded_loras[lora_name]
        return {
            'name': lora_name,
            'path': lora_data['path'],
            'active': lora_data['active'],
            'strength': lora_data['strength'],
            'metadata': lora_data['metadata']
        }

    def get_all_loras_info(self) -> List[Dict]:
        """
        Return get_lora_info() summaries for every registered LoRA.
        """
        return [self.get_lora_info(name) for name in self.loaded_loras.keys()]

    def save_blended_lora(self, blended_name: str, output_path: str) -> bool:
        """
        Persist a registered (typically blended) LoRA as a .safetensors file
        plus a sidecar .json carrying the full metadata.

        Returns:
            True on success, False if the name is unknown or saving fails.
        """
        if blended_name not in self.loaded_loras:
            print(f"❌ Blended LoRA '{blended_name}' not found")
            return False
        try:
            print(f"πŸ’Ύ Saving blended LoRA to: {output_path}")
            lora_data = self.loaded_loras[blended_name]
            state_dict = lora_data['state_dict']
            metadata = lora_data['metadata']
            # safetensors only accepts str -> str metadata; JSON-encode
            # non-string values (lists, floats) instead of letting
            # save_file raise a TypeError.
            safetensors_metadata = {
                str(k): v if isinstance(v, str) else json.dumps(v)
                for k, v in metadata.items()
            }
            save_file(state_dict, output_path, metadata=safetensors_metadata)
            # Full (typed) metadata goes in the sidecar JSON.
            json_path = os.path.splitext(output_path)[0] + '.json'
            with open(json_path, 'w') as f:
                json.dump(metadata, f, indent=2)
            print(f"βœ… Blended LoRA saved successfully")
            return True
        except Exception as e:
            print(f"❌ Error saving blended LoRA: {e}")
            return False
# Utility functions for Gradio integration
def create_lora_manager():
    """Factory for Gradio wiring: build and return a fresh FluxLoRAManager."""
    manager = FluxLoRAManager()
    return manager
def validate_lora_file(file_path: str) -> Dict:
    """
    Sanity-check a LoRA file before loading it for real.

    Checks, in order: the path exists, it carries a .safetensors
    extension, and the file can actually be parsed by safetensors.

    Returns:
        {'valid': True, 'tensor_count', 'file_size_mb'} on success,
        otherwise {'valid': False, 'error': <reason>}.
    """
    try:
        if not os.path.exists(file_path):
            return {'valid': False, 'error': 'File not found'}
        if not file_path.endswith('.safetensors'):
            return {'valid': False, 'error': 'File must be .safetensors format'}
        # Loading is the real validation step: corrupt files raise here.
        tensor_map = load_file(file_path)
        size_mb = os.path.getsize(file_path) / (1024 * 1024)
        return {
            'valid': True,
            'tensor_count': len(tensor_map),
            'file_size_mb': size_mb,
        }
    except Exception as exc:
        return {'valid': False, 'error': str(exc)}