#!/usr/bin/env python3
"""
Model Manager for FLUX.1 Space - Handles multiple models and LoRA integration.

Provides ``FluxModelManager`` (load/switch diffusers pipelines, attach and
detach LoRA adapters, generate images) plus small helpers for a Gradio UI.
"""

import os
from typing import Dict, Optional, Tuple

import torch
from diffusers import DiffusionPipeline
from safetensors.torch import load_file


# Static catalogue of selectable base models, keyed by UI name.
DEFAULT_MODELS: Dict[str, Dict[str, str]] = {
    'flux1-dev': {
        'repo_id': 'black-forest-labs/FLUX.1-dev',
        'model_id': 'black-forest-labs/FLUX.1-dev',
        'description': 'Original FLUX.1-dev model',
    },
    'flux1-krea': {
        'repo_id': 'black-forest-labs/FLUX.1-Krea-dev',
        'model_id': 'black-forest-labs/FLUX.1-Krea-dev',
        'description': 'FLUX.1-Krea-dev model',
    },
    'merged': {
        'repo_id': 'local/merged_krea_55_flux_45_complete',
        'model_id': 'local/merged_krea_55_flux_45_complete',
        'description': 'Merged Krea 55% + FLUX 45% model',
    },
}

# Pre-loaded LoRA models from annoyingpixel, fetched straight from the Hub.
# An entry may optionally carry a 'weight_name' (file inside the repo); when it
# is absent, diffusers picks the weights file itself.
PRELOADED_LORAS: Dict[str, Dict[str, str]] = {
    'T11-Ultra-Portrait-E04': {
        'repo_id': 'annoyingpixel/T11-Ultra-Portrait.E04.Lora.TA',
        'model_id': 'annoyingpixel/T11-Ultra-Portrait.E04.Lora.TA',
        'description': 'Ultra Portrait LoRA for FLUX.1-dev',
        'trigger_words': 'T11-Ultra-Portrait-E04',
        'base_model': 'black-forest-labs/FLUX.1-dev',
    },
}


class FluxModelManager:
    """
    Manages multiple FLUX models and LoRA integration.

    Holds at most one diffusers pipeline at a time (``current_pipeline``) and
    tracks the LoRA adapters attached to it in ``loaded_loras``, keyed by a
    display name. Each entry records the pipeline ``adapter_name`` and the
    ``strength`` used when the adapters are activated.
    """

    def __init__(self):
        # Per-instance copies, so editing one manager's catalogue does not
        # leak into other managers or into the module-level defaults.
        self.models = {name: dict(info) for name, info in DEFAULT_MODELS.items()}
        self.preloaded_loras = {name: dict(info) for name, info in PRELOADED_LORAS.items()}

        self.current_model = None
        self.current_pipeline = None
        self.loaded_loras = {}
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    @staticmethod
    def _to_adapter_name(name: str) -> str:
        """Return *name* made safe for use as a PEFT adapter name.

        PEFT keys its per-adapter layers in ``torch.nn.ModuleDict`` objects,
        and torch rejects module names that contain ".".
        """
        return name.replace(".", "_")

    def _release_pipeline(self) -> None:
        """Drop the current pipeline and forget the LoRAs attached to it."""
        self.current_pipeline = None
        self.current_model = None
        # The adapters live inside the pipeline's modules, so they go with it.
        self.loaded_loras = {}
        if self.device == "cuda":
            # Hand cached blocks back so the next model can use that memory.
            torch.cuda.empty_cache()

    def _apply_adapter_weights(self) -> None:
        """Activate every tracked LoRA on the pipeline with its stored strength."""
        if self.current_pipeline is None or not self.loaded_loras:
            return
        infos = list(self.loaded_loras.values())
        self.current_pipeline.set_adapters(
            [info['adapter_name'] for info in infos],
            adapter_weights=[info['strength'] for info in infos],
        )

    def load_model(self, model_name: str) -> bool:
        """
        Load a specific FLUX model.

        The new pipeline is read first. If that fails, the previously loaded
        model stays usable. Otherwise the old pipeline (and its LoRAs) is
        released before the new one is moved to the GPU.

        Returns:
            True on success; False if the name is unknown or loading fails.
        """
        if model_name not in self.models:
            print(f"❌ Model '{model_name}' not found")
            return False

        try:
            print(f"🔄 Loading model: {model_name}")
            model_info = self.models[model_name]

            pipeline = DiffusionPipeline.from_pretrained(
                model_info['model_id'],
                torch_dtype=torch.bfloat16,
                use_safetensors=True,
            )

            self._release_pipeline()

            if self.device == "cuda":
                pipeline = pipeline.to(self.device)
                pipeline.enable_attention_slicing()
                pipeline.enable_vae_slicing()

            # Publish model and name together so state stays consistent.
            self.current_pipeline = pipeline
            self.current_model = model_name
            print(f"✅ Model '{model_name}' loaded successfully")
            return True

        except Exception as e:
            print(f"❌ Error loading model '{model_name}': {e}")
            return False

    def load_lora(self, lora_path: str, lora_name: Optional[str] = None, strength: float = 1.0) -> bool:
        """
        Load a local safetensors LoRA and apply it to the current model.

        Args:
            lora_path: Path to the LoRA ``.safetensors`` file.
            lora_name: Display name used as the ``loaded_loras`` key; defaults
                to the file's basename.
            strength: Adapter weight passed to ``set_adapters``.

        Returns:
            True on success; False if no model is loaded or loading fails.
        """
        if self.current_pipeline is None:
            print("❌ No model loaded. Load a model first.")
            return False

        display_name = lora_name or os.path.basename(lora_path)
        # Unnamed LoRAs get their file stem as adapter name, so that several of
        # them can coexist instead of all colliding on one fixed name.
        adapter_name = self._to_adapter_name(
            lora_name or os.path.splitext(os.path.basename(lora_path))[0]
        )

        try:
            print(f"🔄 Loading LoRA: {lora_path}")

            lora_state_dict = load_file(lora_path)
            self.current_pipeline.load_lora_weights(
                lora_state_dict,
                adapter_name=adapter_name,
            )

            self.loaded_loras[display_name] = {
                'path': lora_path,
                'strength': strength,
                'state_dict': lora_state_dict,
                'adapter_name': adapter_name,
            }
            self._apply_adapter_weights()

            print(f"✅ LoRA '{display_name}' loaded with strength {strength}")
            return True

        except Exception as e:
            print(f"❌ Error loading LoRA: {e}")
            return False

    def unload_lora(self, lora_name: str) -> bool:
        """
        Unload a specific LoRA by its display name.

        Returns:
            True on success; False if the LoRA is unknown or removal fails.
        """
        if lora_name not in self.loaded_loras:
            print(f"❌ LoRA '{lora_name}' not found")
            return False

        try:
            # unload_lora_weights() takes no adapter argument and strips every
            # LoRA; delete_adapters() removes only the named one.
            adapter_name = self.loaded_loras[lora_name].get('adapter_name', lora_name)
            self.current_pipeline.delete_adapters(adapter_name)
            del self.loaded_loras[lora_name]
            # Re-activate the remaining adapters with their own strengths.
            self._apply_adapter_weights()
            print(f"✅ LoRA '{lora_name}' unloaded")
            return True
        except Exception as e:
            print(f"❌ Error unloading LoRA: {e}")
            return False

    def unload_all_loras(self) -> bool:
        """
        Unload all LoRAs.

        Every LoRA is attempted even if an earlier one fails.

        Returns:
            True only if all of them were removed.
        """
        results = [self.unload_lora(name) for name in list(self.loaded_loras)]
        if all(results):
            print("✅ All LoRAs unloaded")
            return True
        print("❌ Error unloading LoRAs: some LoRAs could not be unloaded")
        return False

    def get_model_info(self) -> Dict:
        """
        Get information about the current model and loaded LoRAs.

        Returns:
            Dict with keys 'current_model', 'model_description' ('Unknown'
            when no model is loaded), 'device', 'loaded_loras' (display names)
            and 'available_models'.
        """
        return {
            'current_model': self.current_model,
            'model_description': self.models.get(self.current_model, {}).get('description', 'Unknown'),
            'device': self.device,
            'loaded_loras': list(self.loaded_loras.keys()),
            'available_models': list(self.models.keys()),
        }

    def generate_image(self, prompt: str, negative_prompt: str = "", num_inference_steps: int = 50,
                       guidance_scale: float = 7.5, width: int = 1024, height: int = 1024,
                       seed: Optional[int] = None) -> Tuple[torch.Tensor, Dict]:
        """
        Generate an image with the current model and LoRAs.

        Args:
            prompt: Text prompt.
            negative_prompt: Forwarded to the pipeline only when non-empty.
            num_inference_steps, guidance_scale, width, height: Passed through
                to the pipeline.
            seed: If given, drives a dedicated ``torch.Generator`` for
                reproducible output.

        Returns:
            ``(image, generation_info)``. The image is the first entry of the
            pipeline's ``output_type="pt"`` result.

        Raises:
            ValueError: If no model is loaded.
        """
        if self.current_pipeline is None:
            raise ValueError("No model loaded. Load a model first.")

        pipeline_kwargs = dict(
            prompt=prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            width=width,
            height=height,
            output_type="pt",
        )
        # FluxPipeline in older diffusers releases has no negative_prompt
        # parameter, so only send it when the caller actually provided one.
        # NOTE(review): newer FluxPipeline versions apply it only with true CFG
        # (true_cfg_scale > 1) — confirm this is the intended behaviour.
        if negative_prompt:
            pipeline_kwargs['negative_prompt'] = negative_prompt
        # A local generator makes the seed reproducible without reseeding the
        # process-wide torch RNG.
        if seed is not None:
            pipeline_kwargs['generator'] = torch.Generator(device="cpu").manual_seed(seed)

        result = self.current_pipeline(**pipeline_kwargs)

        generation_info = {
            'model': self.current_model,
            'loras': list(self.loaded_loras.keys()),
            'prompt': prompt,
            'negative_prompt': negative_prompt,
            'steps': num_inference_steps,
            'guidance_scale': guidance_scale,
            'width': width,
            'height': height,
            'seed': seed,
        }

        return result.images[0], generation_info

    def load_preloaded_lora(self, lora_name: str, strength: float = 1.0) -> bool:
        """
        Load a pre-loaded LoRA directly from Hugging Face.

        Args:
            lora_name: Key in ``preloaded_loras``.
            strength: Adapter weight passed to ``set_adapters``.

        Returns:
            True on success; False if the LoRA is unknown, no model is loaded,
            or loading fails.
        """
        if lora_name not in self.preloaded_loras:
            print(f"❌ Pre-loaded LoRA '{lora_name}' not found")
            return False

        if self.current_pipeline is None:
            print("❌ No model loaded. Load a model first.")
            return False

        try:
            print(f"🔄 Loading pre-loaded LoRA: {lora_name}")
            lora_info = self.preloaded_loras[lora_name]
            adapter_name = self._to_adapter_name(lora_name)

            # weight_name is a file inside the repo. None (the default when the
            # catalogue gives none) lets diffusers locate the weights file.
            self.current_pipeline.load_lora_weights(
                lora_info['model_id'],
                weight_name=lora_info.get('weight_name'),
                adapter_name=adapter_name,
            )

            self.loaded_loras[lora_name] = {
                'path': lora_info['model_id'],
                'strength': strength,
                'trigger_words': lora_info.get('trigger_words', ''),
                'description': lora_info['description'],
                'adapter_name': adapter_name,
            }
            self._apply_adapter_weights()

            print(f"✅ Pre-loaded LoRA '{lora_name}' loaded with strength {strength}")
            print(f"📝 Trigger words: {lora_info.get('trigger_words', 'None')}")
            return True

        except Exception as e:
            print(f"❌ Error loading pre-loaded LoRA: {e}")
            return False

    def get_preloaded_loras(self) -> Dict:
        """
        Get information about available pre-loaded LoRAs.

        Returns:
            This manager's own catalogue dict (not a copy).
        """
        return self.preloaded_loras


# Example usage for Gradio integration
def create_model_manager():
    """
    Create and return a model manager instance.
    """
    return FluxModelManager()


def get_model_options():
    """
    Get list of available models for dropdown.
    """
    # The catalogue is static, so reading it needs no manager instance.
    return list(DEFAULT_MODELS)


def get_model_descriptions():
    """
    Get model descriptions for UI, as a mapping of model name -> description.
    """
    return {name: info['description'] for name, info in DEFAULT_MODELS.items()}