| |
| """ |
| FLUX.1 Space App Template - Enhanced with Model and LoRA Management |
| """ |
|
|
import json
import logging
import os
from typing import Dict, List, Optional

import gradio as gr
import numpy as np
import torch
from PIL import Image

from flux_space_model_manager import FluxModelManager
from flux_space_lora_manager import FluxLoRAManager
|
|
class FluxSpaceApp:
    """
    Enhanced FLUX.1 Space application with model and LoRA management.

    Combines a ``FluxModelManager`` (base models, pre-loaded LoRAs, image
    generation) and a ``FluxLoRAManager`` (user-uploaded LoRA files) behind a
    Gradio Blocks UI built by :meth:`create_interface`.
    """

    # Preferred initial entry of the base-model dropdown; if the model manager
    # does not offer it, the first available model is used instead.
    DEFAULT_MODEL = "flux1-dev"

    # Pre-loaded LoRAs offered in the UI (loaded via the model manager).
    PRELOADED_LORAS = ("T11-Ultra-Portrait-E04",)

    _logger = logging.getLogger(__name__)

    def __init__(self):
        self.model_manager = FluxModelManager()
        self.lora_manager = FluxLoRAManager()
        # Name of the base model most recently loaded successfully (None until then).
        self.current_model = None
        # Pre-loaded LoRAs applied through the model manager, in load order.
        # Custom LoRAs are tracked by ``self.lora_manager.loaded_loras``.
        self.preloaded_lora_names: List[str] = []

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _resolve_seed(seed) -> int:
        """Return an int seed; ``None`` or a negative value (e.g. -1) draws a random one.

        ``gr.Number`` may deliver floats (``42.0``) or ``None`` when cleared, while
        generator seeding generally expects a plain int.
        """
        if seed is None or seed < 0:
            return int(torch.randint(0, 2**32, (1,)).item())
        return int(seed)

    @staticmethod
    def _tensor_to_uint8_array(tensor: torch.Tensor) -> np.ndarray:
        """Convert an image tensor to a uint8 array suitable for ``Image.fromarray``.

        Assumes float values in [0, 1]. Accepts (C, H, W) with C in {1, 3},
        (H, W, C), (H, W), or a batch of one (1, C, H, W). Values are clipped
        before scaling so out-of-range pixels saturate instead of wrapping.
        """
        # .float() first: numpy has no bfloat16, so .numpy() would fail on it.
        array = tensor.detach().float().cpu().numpy()
        if array.ndim == 4 and array.shape[0] == 1:
            array = array[0]  # drop a batch dimension of one
        if array.ndim == 3 and array.shape[0] in (1, 3):
            array = np.transpose(array, (1, 2, 0))  # CHW -> HWC
        if array.ndim == 3 and array.shape[-1] == 1:
            array = array[..., 0]  # PIL wants (H, W) for single-channel images
        return (np.clip(array, 0.0, 1.0) * 255).round().astype(np.uint8)

    @staticmethod
    def _uploaded_file_path(file) -> str:
        """Return the filesystem path of a ``gr.File`` value.

        Depending on the Gradio version / ``type`` setting, the component yields
        either a path string or a tempfile-like object exposing ``.name``.
        """
        if isinstance(file, (str, os.PathLike)):
            return os.fspath(file)
        return file.name

    @staticmethod
    def _format_lora_status(message: str, names: List[str]) -> str:
        """Build the LoRA status Markdown: an optional action message plus the loaded list."""
        loras = ", ".join(names) if names else "None loaded"
        summary = f"**LoRAs:** {loras}"
        return f"{message}\n\n{summary}" if message else summary

    def _loaded_lora_names(self) -> List[str]:
        """Return every loaded LoRA name (pre-loaded first, then custom), without duplicates."""
        names = list(self.preloaded_lora_names)
        names.extend(self.lora_manager.loaded_loras.keys())
        return list(dict.fromkeys(names))

    # ---------------------------------------------------------------- interface

    def create_interface(self):
        """
        Create the Gradio interface and wire its event handlers.

        Returns:
            The assembled (not yet launched) ``gr.Blocks`` demo.
        """
        model_choices = list(self.model_manager.models.keys())
        # Never start the dropdown on a value that is not among its choices.
        if self.DEFAULT_MODEL in model_choices:
            default_model = self.DEFAULT_MODEL
        else:
            default_model = model_choices[0] if model_choices else None

        with gr.Blocks(title="FLUX.1 Enhanced Space", theme=gr.themes.Default()) as demo:
            gr.Markdown("""
            # FLUX.1 Enhanced Space
            **Multiple Models + LoRA Support**

            Choose your base model and load custom LoRAs for enhanced image generation.
            """)

            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("### Model Selection")
                    model_selector = gr.Dropdown(
                        choices=model_choices,
                        value=default_model,
                        label="Base Model",
                        info="Select the base model for generation",
                    )
                    model_info = gr.Markdown("**Model Info:** Select a model to see details")
                    load_model_btn = gr.Button("Load Model", variant="primary")
                    model_status = gr.Markdown("**Status:** No model loaded")

                with gr.Column(scale=1):
                    gr.Markdown("### LoRA Management")

                    gr.Markdown("#### Pre-loaded LoRAs")
                    preloaded_lora_selector = gr.Dropdown(
                        choices=list(self.PRELOADED_LORAS),
                        value=None,
                        label="Select Pre-loaded LoRA",
                        info="Load LoRAs directly from Hugging Face",
                    )
                    preloaded_lora_strength = gr.Slider(
                        minimum=0.0,
                        maximum=2.0,
                        value=1.0,
                        step=0.1,
                        label="Pre-loaded LoRA Strength",
                    )
                    load_preloaded_lora_btn = gr.Button("Load Pre-loaded LoRA", variant="secondary")

                    gr.Markdown("#### Custom LoRA Upload")
                    lora_upload = gr.File(
                        label="Upload LoRA (.safetensors)",
                        file_types=[".safetensors"],
                        file_count="single",
                    )
                    lora_name = gr.Textbox(
                        label="LoRA Name (optional)",
                        placeholder="Custom name for the LoRA",
                    )
                    lora_strength = gr.Slider(
                        minimum=0.0,
                        maximum=2.0,
                        value=1.0,
                        step=0.1,
                        label="Custom LoRA Strength",
                    )

                    with gr.Row():
                        load_lora_btn = gr.Button("Load Custom LoRA", variant="secondary")
                        # TODO(review): no click handler is wired for this button yet.
                        unload_lora_btn = gr.Button("Unload LoRA", variant="stop")

                    lora_status = gr.Markdown("**LoRAs:** None loaded")

            with gr.Row():
                with gr.Column(scale=2):
                    gr.Markdown("### Generation")
                    prompt = gr.Textbox(
                        label="Prompt",
                        placeholder="Enter your prompt here...",
                        lines=3,
                    )
                    negative_prompt = gr.Textbox(
                        label="Negative Prompt",
                        placeholder="Enter negative prompt...",
                        lines=2,
                    )

                    with gr.Row():
                        with gr.Column():
                            steps = gr.Slider(
                                minimum=10,
                                maximum=100,
                                value=50,
                                step=1,
                                label="Inference Steps",
                            )
                            guidance_scale = gr.Slider(
                                minimum=1.0,
                                maximum=20.0,
                                value=7.5,
                                step=0.1,
                                label="Guidance Scale",
                            )
                        with gr.Column():
                            width = gr.Slider(
                                minimum=512,
                                maximum=2048,
                                value=1024,
                                step=64,
                                label="Width",
                            )
                            height = gr.Slider(
                                minimum=512,
                                maximum=2048,
                                value=1024,
                                step=64,
                                label="Height",
                            )

                    seed = gr.Number(
                        label="Seed",
                        value=-1,
                        precision=0,
                        info="Use -1 for random seed",
                    )
                    generate_btn = gr.Button("Generate Image", variant="primary", size="lg")

                with gr.Column(scale=1):
                    gr.Markdown("### Advanced")

                    gr.Markdown("#### LoRA Blending")
                    # Choices are filled in by the LoRA load handlers.
                    lora_list = gr.Dropdown(
                        choices=[],
                        label="Select LoRAs to Blend",
                        multiselect=True,
                    )
                    blend_weights = gr.Textbox(
                        label="Blend Weights (comma-separated)",
                        placeholder="1.0, 0.5, 0.3",
                        info="Weights for each LoRA in order",
                    )
                    # TODO(review): no click handler is wired for this button yet.
                    blend_btn = gr.Button("Blend LoRAs", variant="secondary")

                    gr.Markdown("#### Generation Info")
                    generation_info = gr.JSON(label="Last Generation Details")

            with gr.Row():
                output_image = gr.Image(label="Generated Image", type="pil")
                with gr.Column():
                    gr.Markdown("### Generation Log")
                    generation_log = gr.Textbox(
                        label="Log",
                        lines=10,
                        max_lines=20,
                        interactive=False,
                    )

            def load_model_handler(model_name):
                """Load ``model_name``; return (status Markdown, model-info Markdown)."""
                try:
                    if not model_name:
                        return (
                            "**Status:** No model selected",
                            "**Model Info:** Select a model to see details",
                        )
                    if not self.model_manager.load_model(model_name):
                        return f"**Status:** Failed to load: {model_name}", "Error: Model loading failed"

                    details = self.model_manager.get_model_info()
                    self.current_model = model_name
                    # Blank lines between entries: single newlines collapse in Markdown.
                    info_text = "\n\n".join([
                        f"**Current Model:** {details.get('current_model', model_name)}",
                        f"**Description:** {details.get('model_description', 'n/a')}",
                        f"**Device:** {details.get('device', 'n/a')}",
                    ])
                    return f"**Status:** Model Loaded: {model_name}", info_text
                except Exception as e:
                    self._logger.exception("Failed to load model %r", model_name)
                    return f"**Status:** Error: {e}", "Error: Model loading failed"

            def load_preloaded_lora_handler(selected_lora, strength):
                """Load a pre-loaded LoRA; return (LoRA status Markdown, dropdown update)."""
                try:
                    if not selected_lora:
                        return (
                            self._format_lora_status("Error: No LoRA selected", self._loaded_lora_names()),
                            gr.update(),
                        )
                    if not self.model_manager.load_preloaded_lora(selected_lora, strength):
                        return (
                            self._format_lora_status(
                                "Error: Failed to load pre-loaded LoRA", self._loaded_lora_names()
                            ),
                            gr.update(),
                        )

                    if selected_lora not in self.preloaded_lora_names:
                        self.preloaded_lora_names.append(selected_lora)

                    lora_details = self.model_manager.get_preloaded_loras().get(selected_lora, {})
                    trigger_words = lora_details.get("trigger_words", "")
                    message = f"Pre-loaded LoRA Loaded: {selected_lora}"
                    if trigger_words:
                        message += f" (Trigger: {trigger_words})"

                    names = self._loaded_lora_names()
                    # Update the dropdown's *choices*; returning a bare list would set its value.
                    return self._format_lora_status(message, names), gr.update(choices=names)
                except Exception as e:
                    self._logger.exception("Failed to load pre-loaded LoRA %r", selected_lora)
                    return self._format_lora_status(f"Error: {e}", self._loaded_lora_names()), gr.update()

            def load_lora_handler(file, name, strength):
                """Load an uploaded LoRA file; return (LoRA status Markdown, dropdown update)."""
                try:
                    if file is None:
                        return (
                            self._format_lora_status("Error: No file uploaded", self._loaded_lora_names()),
                            gr.update(),
                        )

                    file_path = self._uploaded_file_path(file)
                    custom_name = (name or "").strip() or os.path.splitext(os.path.basename(file_path))[0]

                    result = self.lora_manager.load_lora_file(file_path, custom_name)
                    if not result.get("success"):
                        return (
                            self._format_lora_status(
                                f"Error: {result.get('error', 'Unknown error')}", self._loaded_lora_names()
                            ),
                            gr.update(),
                        )

                    message = f"Custom LoRA Loaded: {custom_name}"
                    pipeline = self.model_manager.current_pipeline
                    if pipeline is not None:
                        self.lora_manager.apply_lora_to_model(custom_name, pipeline, strength)
                    else:
                        # Tell the user instead of silently skipping the apply step.
                        message += " (not applied: no model loaded)"

                    names = self._loaded_lora_names()
                    return self._format_lora_status(message, names), gr.update(choices=names)
                except Exception as e:
                    self._logger.exception("Failed to load custom LoRA")
                    return self._format_lora_status(f"Error: {e}", self._loaded_lora_names()), gr.update()

            def generate_handler(prompt, negative_prompt, steps, guidance_scale, width, height, seed):
                """Generate an image; return (image, log text, generation-info dict)."""
                try:
                    if self.model_manager.current_pipeline is None:
                        return None, "Error: No model loaded", {}

                    seed = self._resolve_seed(seed)
                    image, gen_info = self.model_manager.generate_image(
                        prompt=prompt,
                        negative_prompt=negative_prompt,
                        # UI components may deliver floats; the pipeline expects ints here.
                        num_inference_steps=int(steps),
                        guidance_scale=float(guidance_scale),
                        width=int(width),
                        height=int(height),
                        seed=seed,
                    )

                    if isinstance(image, torch.Tensor):
                        image = Image.fromarray(self._tensor_to_uint8_array(image))

                    gen_info = gen_info or {}
                    loras = gen_info.get("loras") or []
                    log_entry = "\n".join([
                        "Generation Complete",
                        f"Prompt: {prompt}",
                        f"Negative: {negative_prompt}",
                        f"Steps: {steps}, Guidance: {guidance_scale}",
                        f"Size: {width}x{height}",
                        f"Seed: {seed}",
                        f"Model: {gen_info.get('model', self.current_model)}",
                        f"LoRAs: {', '.join(map(str, loras)) if loras else 'None'}",
                    ])
                    return image, log_entry, gen_info
                except Exception as e:
                    self._logger.exception("Image generation failed")
                    return None, f"Error: {e}", {}

            load_model_btn.click(
                fn=load_model_handler,
                inputs=[model_selector],
                outputs=[model_status, model_info],
            )
            load_preloaded_lora_btn.click(
                fn=load_preloaded_lora_handler,
                inputs=[preloaded_lora_selector, preloaded_lora_strength],
                outputs=[lora_status, lora_list],
            )
            load_lora_btn.click(
                fn=load_lora_handler,
                inputs=[lora_upload, lora_name, lora_strength],
                outputs=[lora_status, lora_list],
            )
            generate_btn.click(
                fn=generate_handler,
                inputs=[prompt, negative_prompt, steps, guidance_scale, width, height, seed],
                outputs=[output_image, generation_log, generation_info],
            )
            # NOTE(review): this loads the model on every dropdown change, which makes
            # the "Load Model" button redundant and may be expensive — confirm intent.
            model_selector.change(
                fn=load_model_handler,
                inputs=[model_selector],
                outputs=[model_status, model_info],
            )

        return demo
|
|
| |
if __name__ == "__main__":
    app = FluxSpaceApp()
    demo = app.create_interface()
    # Hugging Face Spaces sets SPACE_ID, and Gradio does not support share links
    # there (it only warns). Request a public share link only when running elsewhere.
    demo.launch(share=os.getenv("SPACE_ID") is None, debug=True)
|
|