"""
Model Manager for FLUX.1 Space - Handles multiple models and LoRA integration.
"""
|
|
import inspect
import os
from typing import Dict, Optional, Tuple

import torch
from diffusers import DiffusionPipeline
from safetensors.torch import load_file
|
|
class FluxModelManager:
    """
    Manages multiple FLUX base models and LoRA adapters.

    Holds a registry of known base models and Hub-hosted LoRAs, loads one
    ``DiffusionPipeline`` at a time, and tracks which LoRA adapters are
    currently attached to it.
    """

    def __init__(self):
        # Registry of selectable base models (keys are the UI-facing names).
        self.models = {
            'flux1-dev': {
                'repo_id': 'black-forest-labs/FLUX.1-dev',
                'model_id': 'black-forest-labs/FLUX.1-dev',
                'description': 'Original FLUX.1-dev model'
            },
            'flux1-krea': {
                'repo_id': 'black-forest-labs/FLUX.1-Krea-dev',
                'model_id': 'black-forest-labs/FLUX.1-Krea-dev',
                'description': 'FLUX.1-Krea-dev model'
            },
            'merged': {
                'repo_id': 'local/merged_krea_55_flux_45_complete',
                'model_id': 'local/merged_krea_55_flux_45_complete',
                'description': 'Merged Krea 55% + FLUX 45% model'
            }
        }

        # LoRAs that can be fetched directly from the Hugging Face Hub.
        self.preloaded_loras = {
            'T11-Ultra-Portrait-E04': {
                'repo_id': 'annoyingpixel/T11-Ultra-Portrait.E04.Lora.TA',
                'model_id': 'annoyingpixel/T11-Ultra-Portrait.E04.Lora.TA',
                'description': 'Ultra Portrait LoRA for FLUX.1-dev',
                'trigger_words': 'T11-Ultra-Portrait-E04',
                'base_model': 'black-forest-labs/FLUX.1-dev'
            }
        }

        self.current_model: Optional[str] = None  # key into self.models, or None
        self.current_pipeline = None              # the active DiffusionPipeline, or None
        self.loaded_loras: Dict[str, Dict] = {}   # adapter name -> metadata dict
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def load_model(self, model_name: str) -> bool:
        """
        Load a specific FLUX model by registry name.

        Returns True on success, False (with a printed message) on any failure.
        """
        if model_name not in self.models:
            print(f"[ERROR] Model '{model_name}' not found")
            return False

        try:
            print(f"[LOAD] Loading model: {model_name}")
            model_info = self.models[model_name]

            self.current_pipeline = DiffusionPipeline.from_pretrained(
                model_info['model_id'],
                torch_dtype=torch.bfloat16,
                use_safetensors=True
            )

            if self.device == "cuda":
                self.current_pipeline = self.current_pipeline.to(self.device)
                # Memory-saving toggles; guard vae slicing since not every
                # pipeline class exposes it.
                self.current_pipeline.enable_attention_slicing()
                if hasattr(self.current_pipeline, "enable_vae_slicing"):
                    self.current_pipeline.enable_vae_slicing()

            self.current_model = model_name
            print(f"[OK] Model '{model_name}' loaded successfully")
            return True

        except Exception as e:
            print(f"[ERROR] Error loading model '{model_name}': {e}")
            return False

    def load_lora(self, lora_path: str, lora_name: Optional[str] = None, strength: float = 1.0) -> bool:
        """
        Load a LoRA from a local safetensors file and attach it to the
        current pipeline.

        Fixes vs. the original: the adapter name is resolved *before* it is
        used (previously a possibly-None name was passed to the pipeline and
        only defaulted afterwards), ``weight_name`` — which selects a file
        inside a repo/folder — is no longer passed alongside an in-memory
        state dict, and the requested ``strength`` is actually applied via
        ``set_adapters`` instead of merely being recorded.
        """
        if self.current_pipeline is None:
            print("[ERROR] No model loaded. Load a model first.")
            return False

        try:
            print(f"[LOAD] Loading LoRA: {lora_path}")

            # Resolve the adapter name up front so it is valid everywhere below.
            lora_name = lora_name or os.path.basename(lora_path)

            lora_state_dict = load_file(lora_path)
            self.current_pipeline.load_lora_weights(
                lora_state_dict,
                adapter_name=lora_name
            )

            # Apply the requested strength; older diffusers versions may lack
            # set_adapters, in which case the adapter runs at full strength.
            if hasattr(self.current_pipeline, "set_adapters"):
                self.current_pipeline.set_adapters([lora_name], adapter_weights=[strength])

            self.loaded_loras[lora_name] = {
                'path': lora_path,
                'strength': strength,
                'state_dict': lora_state_dict
            }

            print(f"[OK] LoRA '{lora_name}' loaded with strength {strength}")
            return True

        except Exception as e:
            print(f"[ERROR] Error loading LoRA: {e}")
            return False

    def unload_lora(self, lora_name: str) -> bool:
        """
        Unload a specific LoRA adapter from the current pipeline.

        Fix vs. the original: ``unload_lora_weights`` takes no adapter
        argument in the diffusers API, so per-adapter removal must go through
        ``delete_adapters``. If that method is unavailable we fall back to
        unloading all LoRA weights.
        """
        if lora_name not in self.loaded_loras:
            print(f"[ERROR] LoRA '{lora_name}' not found")
            return False

        try:
            if hasattr(self.current_pipeline, "delete_adapters"):
                self.current_pipeline.delete_adapters(lora_name)
            else:
                # Coarse fallback: removes every attached LoRA.
                self.current_pipeline.unload_lora_weights()
            del self.loaded_loras[lora_name]
            print(f"[OK] LoRA '{lora_name}' unloaded")
            return True
        except Exception as e:
            print(f"[ERROR] Error unloading LoRA: {e}")
            return False

    def unload_all_loras(self) -> bool:
        """
        Unload every currently loaded LoRA. Returns True when all succeed.
        """
        try:
            # Iterate over a snapshot since unload_lora mutates the dict.
            for lora_name in list(self.loaded_loras.keys()):
                self.unload_lora(lora_name)
            print("[OK] All LoRAs unloaded")
            return True
        except Exception as e:
            print(f"[ERROR] Error unloading LoRAs: {e}")
            return False

    def get_model_info(self) -> Dict:
        """
        Return a summary of the current model, device, and loaded LoRAs.
        """
        info = {
            'current_model': self.current_model,
            'model_description': self.models.get(self.current_model, {}).get('description', 'Unknown'),
            'device': self.device,
            'loaded_loras': list(self.loaded_loras.keys()),
            'available_models': list(self.models.keys())
        }
        return info

    def generate_image(self, prompt: str, negative_prompt: str = "",
                       num_inference_steps: int = 50, guidance_scale: float = 7.5,
                       width: int = 1024, height: int = 1024,
                       seed: Optional[int] = None) -> Tuple[torch.Tensor, Dict]:
        """
        Generate an image with the current model and attached LoRAs.

        Returns (image tensor, generation-metadata dict). Raises ValueError
        when no model is loaded.

        Fixes vs. the original: seeding now uses a device-aware
        ``torch.Generator`` passed to the pipeline instead of mutating the
        global RNG, and ``negative_prompt`` is forwarded only when the
        pipeline's ``__call__`` actually accepts it — FLUX pipelines are
        guidance-distilled and (depending on the diffusers version) may not
        take a negative prompt at all, which would raise a TypeError.
        """
        if self.current_pipeline is None:
            raise ValueError("No model loaded. Load a model first.")

        generator = None
        if seed is not None:
            generator = torch.Generator(device=self.device).manual_seed(seed)

        call_kwargs = {
            'prompt': prompt,
            'num_inference_steps': num_inference_steps,
            'guidance_scale': guidance_scale,
            'width': width,
            'height': height,
            'generator': generator,
            'output_type': "pt",
        }
        call_params = inspect.signature(self.current_pipeline.__call__).parameters
        if 'negative_prompt' in call_params:
            call_kwargs['negative_prompt'] = negative_prompt

        result = self.current_pipeline(**call_kwargs)

        generation_info = {
            'model': self.current_model,
            'loras': list(self.loaded_loras.keys()),
            'prompt': prompt,
            'negative_prompt': negative_prompt,
            'steps': num_inference_steps,
            'guidance_scale': guidance_scale,
            'seed': seed
        }

        return result.images[0], generation_info

    def load_preloaded_lora(self, lora_name: str, strength: float = 1.0) -> bool:
        """
        Load a pre-registered LoRA directly from the Hugging Face Hub.

        Fix vs. the original: ``weight_name="default"`` named a file that is
        very unlikely to exist in the repo; we now let diffusers locate the
        weight file itself. The requested strength is applied via
        ``set_adapters`` where available.
        """
        if lora_name not in self.preloaded_loras:
            print(f"[ERROR] Pre-loaded LoRA '{lora_name}' not found")
            return False

        if self.current_pipeline is None:
            print("[ERROR] No model loaded. Load a model first.")
            return False

        try:
            print(f"[LOAD] Loading pre-loaded LoRA: {lora_name}")
            lora_info = self.preloaded_loras[lora_name]

            self.current_pipeline.load_lora_weights(
                lora_info['model_id'],
                adapter_name=lora_name
            )
            if hasattr(self.current_pipeline, "set_adapters"):
                self.current_pipeline.set_adapters([lora_name], adapter_weights=[strength])

            self.loaded_loras[lora_name] = {
                'path': lora_info['model_id'],
                'strength': strength,
                'trigger_words': lora_info.get('trigger_words', ''),
                'description': lora_info['description']
            }

            print(f"[OK] Pre-loaded LoRA '{lora_name}' loaded with strength {strength}")
            print(f"[INFO] Trigger words: {lora_info.get('trigger_words', 'None')}")
            return True

        except Exception as e:
            print(f"[ERROR] Error loading pre-loaded LoRA: {e}")
            return False

    def get_preloaded_loras(self) -> Dict:
        """
        Return the registry of available pre-loaded LoRAs.
        """
        return self.preloaded_loras
|
|
| |
def create_model_manager():
    """Build and return a fresh :class:`FluxModelManager` instance."""
    manager = FluxModelManager()
    return manager
|
|
def get_model_options():
    """Return the available base-model names, for use in a UI dropdown."""
    return [name for name in FluxModelManager().models]
|
|
def get_model_descriptions():
    """Map each model name to its human-readable description, for the UI."""
    descriptions = {}
    for name, info in FluxModelManager().models.items():
        descriptions[name] = info['description']
    return descriptions
|
|