# Source: flux.dev.1 / flux_space_model_manager.py
# Author: annoyingpixel
# Uploaded with huggingface_hub (commit 3ff1616, verified)
#!/usr/bin/env python3
"""
Model Manager for FLUX.1 Space - Handles multiple models and LoRA integration
"""
import torch
from diffusers import DiffusionPipeline
from safetensors.torch import load_file
import os
from typing import Dict, Optional, Tuple
class FluxModelManager:
    """
    Manages multiple FLUX models and LoRA integration.

    At most one diffusers pipeline is held at a time (``current_pipeline``,
    loaded from an entry of ``self.models``). LoRAs applied to that pipeline
    are tracked in ``loaded_loras``, keyed by the caller-facing LoRA name;
    each entry also records the PEFT adapter name and the strength used when
    the adapters are activated.
    """

    def __init__(self):
        # Selectable base models; ``model_id`` is what is passed to
        # ``DiffusionPipeline.from_pretrained`` (HF repo id or local path).
        self.models = {
            'flux1-dev': {
                'repo_id': 'black-forest-labs/FLUX.1-dev',
                'model_id': 'black-forest-labs/FLUX.1-dev',
                'description': 'Original FLUX.1-dev model'
            },
            'flux1-krea': {
                'repo_id': 'black-forest-labs/FLUX.1-Krea-dev',
                'model_id': 'black-forest-labs/FLUX.1-Krea-dev',
                'description': 'FLUX.1-Krea-dev model'
            },
            'merged': {
                'repo_id': 'local/merged_krea_55_flux_45_complete',
                'model_id': 'local/merged_krea_55_flux_45_complete',
                'description': 'Merged Krea 55% + FLUX 45% model'
            }
        }
        # Pre-loaded LoRA models from annoyingpixel. An entry may optionally
        # carry a 'weight_name' (a file inside the repo); when it is absent
        # diffusers is left to pick the weight file itself.
        self.preloaded_loras = {
            'T11-Ultra-Portrait-E04': {
                'repo_id': 'annoyingpixel/T11-Ultra-Portrait.E04.Lora.TA',
                'model_id': 'annoyingpixel/T11-Ultra-Portrait.E04.Lora.TA',
                'description': 'Ultra Portrait LoRA for FLUX.1-dev',
                'trigger_words': 'T11-Ultra-Portrait-E04',
                'base_model': 'black-forest-labs/FLUX.1-dev'
            }
        }
        self.current_model = None
        self.current_pipeline = None
        # LoRA name -> {'path', 'strength', 'adapter_name', ...}
        self.loaded_loras = {}
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    @staticmethod
    def _adapter_name_for(lora_name: str) -> str:
        """Return a PEFT-safe adapter name derived from ``lora_name``."""
        # PEFT stores adapters as ``nn.ModuleDict`` keys, and torch rejects
        # module names containing "." (e.g. "style.safetensors").
        return lora_name.replace(".", "_")

    def _apply_adapter_weights(self) -> None:
        """Activate every tracked LoRA adapter with its stored strength."""
        if self.current_pipeline is None or not self.loaded_loras:
            return
        names = [info.get('adapter_name', name) for name, info in self.loaded_loras.items()]
        weights = [info['strength'] for info in self.loaded_loras.values()]
        self.current_pipeline.set_adapters(names, adapter_weights=weights)

    def _release_pipeline(self) -> None:
        """Drop the current pipeline and its LoRA bookkeeping, then free cached GPU memory."""
        import gc

        self.current_pipeline = None
        self.current_model = None
        # LoRAs were applied to the pipeline being dropped; they no longer exist.
        self.loaded_loras.clear()
        gc.collect()
        if self.device == "cuda":
            torch.cuda.empty_cache()

    def load_model(self, model_name: str) -> bool:
        """
        Load a specific FLUX model, replacing any currently loaded one.

        The new pipeline is first instantiated from ``from_pretrained``; only
        once that succeeds is the previous pipeline released (so a failed
        download leaves the old model usable), and only then is the new one
        moved to the GPU so two models never share device memory. Switching
        models clears ``loaded_loras`` because those LoRAs belonged to the old
        pipeline.

        Args:
            model_name: Key into ``self.models``.

        Returns:
            True on success (or if ``model_name`` is already loaded), else False.
        """
        if model_name not in self.models:
            print(f"❌ Model '{model_name}' not found")
            return False
        if model_name == self.current_model and self.current_pipeline is not None:
            # Avoid re-instantiating a multi-GB model that is already active.
            print(f"✅ Model '{model_name}' already loaded")
            return True
        model_info = self.models[model_name]
        try:
            print(f"🔄 Loading model: {model_name}")
            pipeline = DiffusionPipeline.from_pretrained(
                model_info['model_id'],
                torch_dtype=torch.bfloat16,
                use_safetensors=True,
            )
        except Exception as e:
            print(f"❌ Error loading model '{model_name}': {e}")
            return False

        self._release_pipeline()
        try:
            if self.device == "cuda":
                pipeline = pipeline.to(self.device)
                pipeline.enable_attention_slicing()
                pipeline.enable_vae_slicing()
        except Exception as e:
            print(f"❌ Error loading model '{model_name}': {e}")
            return False

        self.current_pipeline = pipeline
        self.current_model = model_name
        print(f"✅ Model '{model_name}' loaded successfully")
        return True

    def load_lora(self, lora_path: str, lora_name: Optional[str] = None, strength: float = 1.0) -> bool:
        """
        Load a LoRA from a local safetensors file and apply it to the current model.

        Args:
            lora_path: Path to a ``.safetensors`` file with LoRA weights.
            lora_name: Name to track the LoRA under; defaults to the file's basename.
            strength: Adapter weight applied via ``set_adapters``.

        Returns:
            True on success, False otherwise (errors are printed, not raised).
        """
        if self.current_pipeline is None:
            print("❌ No model loaded. Load a model first.")
            return False
        # Resolve the tracking key up front so the adapter registered in the
        # pipeline and the bookkeeping entry always refer to the same LoRA.
        lora_name = lora_name or os.path.basename(lora_path)
        adapter_name = self._adapter_name_for(lora_name)
        try:
            print(f"🔄 Loading LoRA: {lora_path}")
            lora_state_dict = load_file(lora_path)
            # ``weight_name`` only applies when loading from a repo/directory,
            # so it is not passed along with an in-memory state dict.
            self.current_pipeline.load_lora_weights(
                lora_state_dict,
                adapter_name=adapter_name,
            )
            self.loaded_loras[lora_name] = {
                'path': lora_path,
                'strength': strength,
                'state_dict': lora_state_dict,
                'adapter_name': adapter_name,
            }
            # Without this the requested strength would never take effect.
            self._apply_adapter_weights()
            print(f"✅ LoRA '{lora_name}' loaded with strength {strength}")
            return True
        except Exception as e:
            print(f"❌ Error loading LoRA: {e}")
            return False

    def unload_lora(self, lora_name: str) -> bool:
        """
        Unload a specific LoRA, leaving any other loaded LoRAs active.

        Returns:
            True if the LoRA was removed, False if unknown or removal failed.
        """
        if lora_name not in self.loaded_loras:
            print(f"❌ LoRA '{lora_name}' not found")
            return False
        adapter_name = self.loaded_loras[lora_name].get('adapter_name', lora_name)
        try:
            # ``unload_lora_weights()`` removes *all* LoRAs and does not take an
            # adapter name, so delete just this adapter instead.
            if self.current_pipeline is not None:
                self.current_pipeline.delete_adapters(adapter_name)
            del self.loaded_loras[lora_name]
            # Re-activate the remaining adapters with their stored strengths.
            self._apply_adapter_weights()
            print(f"✅ LoRA '{lora_name}' unloaded")
            return True
        except Exception as e:
            print(f"❌ Error unloading LoRA: {e}")
            return False

    def unload_all_loras(self) -> bool:
        """
        Unload all LoRAs.

        Returns:
            True only if every tracked LoRA was removed successfully.
        """
        all_ok = True
        for lora_name in list(self.loaded_loras):
            all_ok = self.unload_lora(lora_name) and all_ok
        if all_ok:
            print("✅ All LoRAs unloaded")
        else:
            print("❌ Error unloading LoRAs: some LoRAs could not be unloaded")
        return all_ok

    def get_model_info(self) -> Dict:
        """
        Get information about the current model and loaded LoRAs.
        """
        info = {
            'current_model': self.current_model,
            'model_description': self.models.get(self.current_model, {}).get('description', 'Unknown'),
            'device': self.device,
            'loaded_loras': list(self.loaded_loras.keys()),
            'available_models': list(self.models.keys())
        }
        return info

    def generate_image(self, prompt: str, negative_prompt: str = "",
                       num_inference_steps: int = 50, guidance_scale: float = 7.5,
                       width: int = 1024, height: int = 1024,
                       seed: Optional[int] = None) -> Tuple[torch.Tensor, Dict]:
        """
        Generate an image with the current model and LoRAs.

        Args:
            prompt: Text prompt.
            negative_prompt: Only forwarded to the pipeline when non-empty.
            num_inference_steps, guidance_scale, width, height: Passed through.
            seed: When given, a dedicated ``torch.Generator`` seeded with it is
                used, so the global RNG state is left untouched.

        Returns:
            (first image as a tensor, since ``output_type="pt"``, generation info dict)

        Raises:
            ValueError: If no model has been loaded.
        """
        if self.current_pipeline is None:
            raise ValueError("No model loaded. Load a model first.")

        call_kwargs = {
            'prompt': prompt,
            'num_inference_steps': num_inference_steps,
            'guidance_scale': guidance_scale,
            'width': width,
            'height': height,
            'output_type': "pt",
        }
        # Not every FLUX pipeline version accepts ``negative_prompt``; only
        # forward it when the caller actually supplied one.
        if negative_prompt:
            call_kwargs['negative_prompt'] = negative_prompt
        if seed is not None:
            call_kwargs['generator'] = torch.Generator(device=self.device).manual_seed(seed)

        result = self.current_pipeline(**call_kwargs)

        generation_info = {
            'model': self.current_model,
            'loras': list(self.loaded_loras.keys()),
            'prompt': prompt,
            'negative_prompt': negative_prompt,
            'steps': num_inference_steps,
            'guidance_scale': guidance_scale,
            'width': width,
            'height': height,
            'seed': seed
        }
        return result.images[0], generation_info

    def load_preloaded_lora(self, lora_name: str, strength: float = 1.0) -> bool:
        """
        Load a pre-loaded LoRA directly from Hugging Face.

        Args:
            lora_name: Key into ``self.preloaded_loras``.
            strength: Adapter weight applied via ``set_adapters``.

        Returns:
            True on success, False otherwise (errors are printed, not raised).
        """
        if lora_name not in self.preloaded_loras:
            print(f"❌ Pre-loaded LoRA '{lora_name}' not found")
            return False
        if self.current_pipeline is None:
            print("❌ No model loaded. Load a model first.")
            return False
        lora_info = self.preloaded_loras[lora_name]
        adapter_name = self._adapter_name_for(lora_name)
        try:
            print(f"🔄 Loading pre-loaded LoRA: {lora_name}")
            load_kwargs = {'adapter_name': adapter_name}
            # ``weight_name`` is a file name inside the repo. The previous
            # literal "default" is not a weight file, so it is only passed
            # when the registry entry names one explicitly.
            if lora_info.get('weight_name'):
                load_kwargs['weight_name'] = lora_info['weight_name']
            self.current_pipeline.load_lora_weights(lora_info['model_id'], **load_kwargs)
            self.loaded_loras[lora_name] = {
                'path': lora_info['model_id'],
                'strength': strength,
                'trigger_words': lora_info.get('trigger_words', ''),
                'description': lora_info['description'],
                'adapter_name': adapter_name,
            }
            self._apply_adapter_weights()
            print(f"✅ Pre-loaded LoRA '{lora_name}' loaded with strength {strength}")
            print(f"📝 Trigger words: {lora_info.get('trigger_words', 'None')}")
            return True
        except Exception as e:
            print(f"❌ Error loading pre-loaded LoRA: {e}")
            return False

    def get_preloaded_loras(self) -> Dict:
        """
        Get information about available pre-loaded LoRAs.
        """
        return self.preloaded_loras
# Convenience helpers for Gradio integration
def create_model_manager():
    """
    Build and return a new FluxModelManager; no pipeline is loaded yet.
    """
    manager = FluxModelManager()
    return manager
def get_model_options():
    """
    Return the selectable model keys, in registry order (e.g. for a UI dropdown).
    """
    return [name for name in FluxModelManager().models]
def get_model_descriptions():
    """
    Map each model key to its human-readable description for the UI.
    """
    descriptions = {}
    for name, info in FluxModelManager().models.items():
        descriptions[name] = info['description']
    return descriptions