| import numpy as np |
| from PIL import Image |
| from huggingface_hub import snapshot_download, login |
| from leffa.transform import LeffaTransform |
| from leffa.model import LeffaModel |
| from leffa.inference import LeffaInference |
| from utils.garment_agnostic_mask_predictor import AutoMasker |
| from utils.densepose_predictor import DensePosePredictor |
| from utils.utils import resize_and_center |
| import spaces |
| import torch |
| from diffusers import DiffusionPipeline |
| from transformers import pipeline |
| import gradio as gr |
| import os |
| import random |
| import gc |
| from contextlib import contextmanager |
|
|
| |
# Upper bound for random seeds and for the UI seed slider (2**32 - 1).
MAX_SEED = 0xFFFFFFFF

# Hugging Face Hub repo id of the base diffusion model.
BASE_MODEL = "black-forest-labs/FLUX.1-dev"

# LoRA adapters: the first is loaded in "Generate Model" mode, the second
# in every other mode.
MODEL_LORA_REPO = "Motas/Flux_Fashion_Photography_Style"
CLOTHES_LORA_REPO = "prithivMLmods/Canopus-Clothing-Flux-LoRA"
|
|
| |
def safe_model_call(func):
    """Decorate *func* so that ``clear_memory()`` runs around every call.

    ``clear_memory()`` is invoked once before *func* and once after it,
    whether the call returns or raises. Exceptions are logged to stdout
    with the wrapped function's name and then re-raised unchanged.

    Args:
        func: The callable to wrap.

    Returns:
        A wrapper with the same call signature and metadata as *func*.
    """
    # Local import so this block does not depend on the file's import list.
    from functools import wraps

    @wraps(func)  # keep __name__/__doc__ so logs and introspection stay accurate
    def wrapper(*args, **kwargs):
        clear_memory()
        try:
            return func(*args, **kwargs)
        except Exception as e:
            print(f"Error in {func.__name__}: {str(e)}")
            raise
        finally:
            # Runs on success and on failure alike.
            clear_memory()

    return wrapper
|
|
| |
@contextmanager
def torch_gc():
    """Context manager that reclaims memory when the managed block exits.

    On exit (normal or via exception) it runs the Python garbage collector
    and, when CUDA reports itself available, empties the CUDA caching
    allocator. Exceptions raised inside the block propagate unchanged.
    """
    try:
        yield
    finally:
        gc.collect()
        cuda_usable = (
            torch.cuda.is_available() and torch.cuda.current_device() >= 0
        )
        if cuda_usable:
            with torch.cuda.device('cuda'):
                torch.cuda.empty_cache()
|
|
def clear_memory():
    """Run a full Python garbage-collection pass; returns ``None``."""
    gc.collect()
    return None
|
|
def setup_environment():
    """Configure the CUDA allocator and log in to the Hugging Face Hub.

    Sets ``PYTORCH_CUDA_ALLOC_CONF`` in the process environment, then reads
    ``HF_TOKEN`` from the environment and passes it to ``login``.

    Returns:
        The token string that was used to log in.

    Raises:
        ValueError: If ``HF_TOKEN`` is unset or empty.
    """
    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128'

    token = os.environ.get("HF_TOKEN")
    if token:
        login(token=token)
        return token

    raise ValueError("HF_TOKEN not found in environment variables")
|
|
def contains_korean(text):
    """Return True if *text* has at least one character in '가'..'힣'.

    Only code points inside that inclusive range count; anything outside
    it (including an empty string) yields False.
    """
    first, last = ord('가'), ord('힣')
    for char in text:
        if first <= ord(char) <= last:
            return True
    return False
|
|
@spaces.GPU()
def get_translator():
    """Create a ``Helsinki-NLP/opus-mt-ko-en`` translation pipeline on CUDA.

    NOTE: a second ``get_translator`` defined later in this module rebinds
    the name, so this version is not the one in effect after import.
    """
    translator = pipeline(
        "translation",
        model="Helsinki-NLP/opus-mt-ko-en",
        device="cuda",
    )
    return translator
|
|
| |
# Import-time side effect: sets the allocator config and logs in to the
# Hugging Face Hub; raises ValueError if HF_TOKEN is missing.
setup_environment()
|
|
@spaces.GPU()
def initialize_fashion_pipe():
    """Load ``BASE_MODEL`` in float16 and move the pipeline to CUDA.

    Memory is reclaimed via ``torch_gc`` when loading finishes or fails.

    NOTE: a second ``initialize_fashion_pipe`` defined later in this module
    rebinds the name, so this version is not the one in effect after import.
    """
    with torch_gc():
        fashion_pipe = DiffusionPipeline.from_pretrained(
            BASE_MODEL,
            torch_dtype=torch.float16,
        )
        on_gpu = fashion_pipe.to("cuda")
        return on_gpu
|
|
def setup():
    """Download the ``franciszzj/Leffa`` Hub snapshot into ``./ckpts``."""
    snapshot_download(local_dir="./ckpts", repo_id="franciszzj/Leffa")
|
|
@spaces.GPU()
def get_translator():
    """Create a ``Helsinki-NLP/opus-mt-ko-en`` translation pipeline on CUDA.

    The pipeline is built inside ``torch_gc`` so memory is reclaimed once
    construction finishes or fails. A new pipeline is created on every call.
    """
    with torch_gc():
        translator = pipeline(
            "translation",
            model="Helsinki-NLP/opus-mt-ko-en",
            device="cuda",
        )
        return translator
|
|
@safe_model_call
def get_mask_predictor():
    """Return the module-wide ``AutoMasker``, creating it on first use.

    The instance is stored in the module global ``mask_predictor`` and
    reused by later calls.

    NOTE: a second ``get_mask_predictor`` defined later in this module
    rebinds the name, so this version is not the one in effect after import.
    """
    global mask_predictor
    # The global is never initialised at module level, so reading it
    # directly would raise NameError on the first call; look it up safely.
    if globals().get("mask_predictor") is None:
        mask_predictor = AutoMasker(
            densepose_path="./ckpts/densepose",
            schp_path="./ckpts/schp",
        )
    return mask_predictor
|
|
@safe_model_call
def get_densepose_predictor():
    """Return the module-wide ``DensePosePredictor``, creating it on first use.

    The instance is stored in the module global ``densepose_predictor`` and
    reused by later calls. Config and weights are read from
    ``./ckpts/densepose``.
    """
    global densepose_predictor
    # The global is never initialised at module level, so reading it
    # directly would raise NameError on the first call; look it up safely.
    if globals().get("densepose_predictor") is None:
        densepose_predictor = DensePosePredictor(
            config_path="./ckpts/densepose/densepose_rcnn_R_50_FPN_s1x.yaml",
            weights_path="./ckpts/densepose/model_final_162be9.pkl",
        )
    return densepose_predictor
|
|
@spaces.GPU()
def get_vt_model():
    """Load the Leffa virtual try-on model and its inference wrapper.

    The model is built from the checkpoints under ``./ckpts``, converted
    with ``.half()`` and moved to CUDA.

    Returns:
        A ``(model, inference)`` tuple: the result of ``model.to("cuda")``
        and a ``LeffaInference`` constructed around the half-precision model.
    """
    with torch_gc():
        half_model = LeffaModel(
            pretrained_model_name_or_path="./ckpts/stable-diffusion-inpainting",
            pretrained_model="./ckpts/virtual_tryon.pth"
        ).half()
        cuda_model = half_model.to("cuda")
        engine = LeffaInference(model=half_model)
        return cuda_model, engine
|
|
def load_lora(pipe, lora_path):
    """Replace any LoRA weights on *pipe* with those at *lora_path*.

    Both steps are best effort: a failure to unload the previous adapter or
    to load the new one is reported on stdout and does not raise.

    Args:
        pipe: Object exposing ``unload_lora_weights()`` and
            ``load_lora_weights(path)``.
        lora_path: Identifier passed straight to ``load_lora_weights``.

    Returns:
        The same *pipe* object, whether or not loading succeeded.
    """
    try:
        pipe.unload_lora_weights()
    except Exception as e:
        # Still best effort, but no longer silent, and no longer swallowing
        # KeyboardInterrupt/SystemExit as the bare ``except`` did.
        print(f"Warning: Failed to unload existing LoRA weights: {e}")

    try:
        pipe.load_lora_weights(lora_path)
    except Exception as e:
        print(f"Warning: Failed to load LoRA weights from {lora_path}: {e}")
    return pipe
|
|
@spaces.GPU()
def get_mask_predictor():
    """Return the module-wide ``AutoMasker``, creating it on first use.

    The instance is stored in the module global ``mask_predictor`` and
    reused by later calls. Checkpoints are read from ``./ckpts/densepose``
    and ``./ckpts/schp``.
    """
    global mask_predictor
    # The global is never initialised at module level, so reading it
    # directly would raise NameError on the first call; look it up safely.
    if globals().get("mask_predictor") is None:
        mask_predictor = AutoMasker(
            densepose_path="./ckpts/densepose",
            schp_path="./ckpts/schp",
        )
    return mask_predictor
|
|
| |
@spaces.GPU()
def initialize_fashion_pipe():
    """Load ``BASE_MODEL`` in float16 with model CPU offload enabled.

    Returns:
        The configured diffusion pipeline.

    Raises:
        Exception: Any error from loading or configuring the pipeline is
            logged to stdout and re-raised unchanged.
    """
    try:
        pipe = DiffusionPipeline.from_pretrained(
            BASE_MODEL,
            torch_dtype=torch.float16,
            safety_checker=None,
            requires_safety_checker=False
        )
        # Do not call .to("cuda") first: that requires the whole pipeline to
        # fit on the GPU at once, which is exactly what offloading avoids.
        # enable_model_cpu_offload() handles device placement itself.
        pipe.enable_model_cpu_offload()
        return pipe
    except Exception as e:
        print(f"Error initializing fashion pipe: {e}")
        raise
|
|
@spaces.GPU()
def generate_fashion(prompt, mode, cfg_scale, steps, randomize_seed, seed, width, height, lora_scale, progress=gr.Progress(track_tqdm=True)):
    """Generate one fashion image from a text prompt.

    Korean prompts are translated to English first. The LoRA adapter and
    trigger words appended to the prompt depend on *mode*.

    Args:
        prompt: Text description; may contain Korean.
        mode: ``"Generate Model"`` selects ``MODEL_LORA_REPO``; any other
            value selects ``CLOTHES_LORA_REPO``.
        cfg_scale: Guidance scale passed to the pipeline.
        steps: Number of inference steps, capped at 30.
        randomize_seed: When truthy, *seed* is replaced by a random value
            in ``[0, MAX_SEED]``.
        seed: Seed used when *randomize_seed* is falsy.
        width: Output width, capped at 768.
        height: Output height, capped at 768.
        lora_scale: LoRA scale forwarded via ``joint_attention_kwargs``.
        progress: Gradio progress tracker (not read directly here).

    Returns:
        ``(image, seed)``: the first generated image and the seed used.

    Raises:
        gr.Error: Wraps any exception raised during generation.
    """
    pipe = None
    try:
        if contains_korean(prompt):
            with torch.inference_mode():
                translator = get_translator()
                actual_prompt = translator(prompt)[0]['translation_text']
        else:
            actual_prompt = prompt

        pipe = initialize_fashion_pipe()

        if mode == "Generate Model":
            lora_repo = MODEL_LORA_REPO
            trigger_word = "fashion photography, professional model"
        else:
            lora_repo = CLOTHES_LORA_REPO
            trigger_word = "upper clothing, fashion item"
        pipe.load_lora_weights(lora_repo)

        # Clamp to the same limits the UI sliders expose; cast because
        # sliders may deliver floats.
        width = int(min(width, 768))
        height = int(min(height, 768))
        steps = int(min(steps, 30))

        if randomize_seed:
            seed = random.randint(0, MAX_SEED)
        seed = int(seed)  # manual_seed() needs an int
        generator = torch.Generator("cuda").manual_seed(seed)

        with torch.inference_mode():
            output = pipe(
                prompt=f"{actual_prompt} {trigger_word}",
                num_inference_steps=steps,
                guidance_scale=cfg_scale,
                width=width,
                height=height,
                generator=generator,
                # Same keyword generate_image() uses for this base model.
                joint_attention_kwargs={"scale": lora_scale},
            )

        return output.images[0], seed

    except Exception as e:
        print(f"Error in generate_fashion: {str(e)}")
        raise gr.Error(f"Generation failed: {str(e)}") from e
    finally:
        # Release the pipeline on failure too, not only on success.
        del pipe
        torch.cuda.empty_cache()
        gc.collect()
|
|
class ModelManager:
    """Lazily creates and caches the helper models used by the app.

    Each getter builds its model on first call, stores it on the instance
    and returns the stored object on later calls.
    """

    def __init__(self):
        # All slots start empty; the getters fill them on demand.
        self.mask_predictor = None
        self.densepose_predictor = None
        self.translator = None

    @spaces.GPU()
    def get_mask_predictor(self):
        """Return the cached ``AutoMasker``, constructing it if needed."""
        if self.mask_predictor is not None:
            return self.mask_predictor
        self.mask_predictor = AutoMasker(
            densepose_path="./ckpts/densepose",
            schp_path="./ckpts/schp",
        )
        return self.mask_predictor

    @spaces.GPU()
    def get_densepose_predictor(self):
        """Return the cached ``DensePosePredictor``, constructing it if needed."""
        if self.densepose_predictor is not None:
            return self.densepose_predictor
        self.densepose_predictor = DensePosePredictor(
            config_path="./ckpts/densepose/densepose_rcnn_R_50_FPN_s1x.yaml",
            weights_path="./ckpts/densepose/model_final_162be9.pkl",
        )
        return self.densepose_predictor

    @spaces.GPU()
    def get_translator(self):
        """Return the cached translation pipeline, constructing it if needed."""
        if self.translator is not None:
            return self.translator
        self.translator = pipeline(
            "translation",
            model="Helsinki-NLP/opus-mt-ko-en",
            device="cuda",
        )
        return self.translator


# Shared instance used by the prediction and generation functions.
model_manager = ModelManager()
|
|
@spaces.GPU()
def leffa_predict(src_image_path, ref_image_path, control_type):
    """Run Leffa virtual try-on for a person image and a garment image.

    Both images are resized/centred to 768x1024. An "upper" garment mask
    and a DensePose segmentation are computed from the person image, then
    everything is passed through ``LeffaTransform`` and the inference
    wrapper.

    Args:
        src_image_path: Path of the person image.
        ref_image_path: Path of the garment image.
        control_type: Accepted for interface compatibility; this function
            does not read it.

    Returns:
        The first generated image as a NumPy array.

    Raises:
        Exception: Any failure is logged to stdout and re-raised unchanged.
    """
    try:
        with torch_gc():
            model, inference = get_vt_model()

            src_image = resize_and_center(Image.open(src_image_path), 768, 1024)
            ref_image = resize_and_center(Image.open(ref_image_path), 768, 1024)

            # Taken before the RGB conversion below, as in the original flow.
            src_image_array = np.array(src_image)

            with torch.inference_mode():
                src_image = src_image.convert("RGB")
                mask_pred = model_manager.get_mask_predictor()
                mask = mask_pred(src_image, "upper")["mask"]

                dense_pred = model_manager.get_densepose_predictor()
                src_image_seg_array = dense_pred.predict_seg(src_image_array)
                densepose = Image.fromarray(src_image_seg_array)

            transform = LeffaTransform()
            data = transform({
                "src_image": [src_image],
                "ref_image": [ref_image],
                "mask": [mask],
                "densepose": [densepose],
            })

            with torch.inference_mode():
                output = inference(data)

            # Drop the model references before emptying the cache so the
            # GPU memory they hold can actually be released.
            del model
            del inference
            torch.cuda.empty_cache()
            gc.collect()

            return np.array(output["generated_image"][0])

    except Exception as e:
        print(f"Error in leffa_predict: {str(e)}")
        raise
|
|
@spaces.GPU()
def leffa_predict_vt(src_image_path, ref_image_path):
    """Run ``leffa_predict`` with ``control_type`` fixed to ``"virtual_tryon"``.

    Any exception is logged to stdout and re-raised unchanged.
    """
    try:
        generated = leffa_predict(src_image_path, ref_image_path, "virtual_tryon")
    except Exception as e:
        print(f"Error in leffa_predict_vt: {str(e)}")
        raise
    return generated
|
|
@spaces.GPU()
def generate_image(prompt, mode, cfg_scale=7.0, steps=30, seed=None, width=512, height=768, lora_scale=0.85):
    """Generate one fashion image from a text prompt.

    Korean prompts are translated to English first. A fresh pipeline is
    loaded from ``BASE_MODEL`` on every call, and the LoRA adapter and
    trigger words depend on *mode*.

    Args:
        prompt: Text description; may contain Korean.
        mode: ``"Generate Model"`` selects ``MODEL_LORA_REPO``; any other
            value selects ``CLOTHES_LORA_REPO``.
        cfg_scale: Guidance scale passed to the pipeline.
        steps: Number of inference steps.
        seed: Generator seed; ``None`` picks a random seed in
            ``[0, MAX_SEED]``.
        width: Output width.
        height: Output height.
        lora_scale: LoRA scale forwarded via ``joint_attention_kwargs``.

    Returns:
        ``(image, seed)``: the first generated image and the seed actually
        used, so a randomly chosen seed can be shown and reused.

    Raises:
        gr.Error: Wraps any exception raised during generation.
    """
    try:
        with torch_gc():
            if contains_korean(prompt):
                translator = model_manager.get_translator()
                with torch.inference_mode():
                    actual_prompt = translator(prompt)[0]['translation_text']
            else:
                actual_prompt = prompt

            # Resolve the seed up front so the value returned is the value
            # used; previously a random seed was used but None was returned.
            if seed is None:
                seed = random.randint(0, MAX_SEED)
            seed = int(seed)  # manual_seed() needs an int

            pipe = DiffusionPipeline.from_pretrained(
                BASE_MODEL,
                torch_dtype=torch.float16,
            )
            pipe = pipe.to("cuda")

            if mode == "Generate Model":
                lora_repo = MODEL_LORA_REPO
                trigger_word = "fashion photography, professional model"
            else:
                lora_repo = CLOTHES_LORA_REPO
                trigger_word = "upper clothing, fashion item"
            pipe.load_lora_weights(lora_repo)

            with torch.inference_mode():
                result = pipe(
                    prompt=f"{actual_prompt} {trigger_word}",
                    num_inference_steps=steps,
                    guidance_scale=cfg_scale,
                    width=width,
                    height=height,
                    generator=torch.Generator("cuda").manual_seed(seed),
                    joint_attention_kwargs={"scale": lora_scale},
                ).images[0]

            # Drop the pipeline before torch_gc's cleanup runs on exit.
            del pipe
            return result, seed

    except Exception as e:
        raise gr.Error(f"Generation failed: {str(e)}") from e
|
|
| |
# Import-time side effect: download the Leffa checkpoints into ./ckpts,
# which the predictors and the try-on model read from.
setup()
|
|
def create_interface():
    """Build the Gradio UI and return it without launching.

    The interface has two tabs:

    * "Fashion Generation": prompt, mode selector and advanced options,
      wired to ``generate_image``.
    * "Virtual Try-on": person and garment image inputs, wired to
      ``leffa_predict_vt``.

    Returns:
        The configured ``gr.Blocks`` instance.
    """

    def _generate_with_seed_choice(prompt, mode, cfg_scale, steps,
                                   randomize_seed, seed, width, height,
                                   lora_scale):
        """Apply the "Randomize seed" checkbox, then call generate_image."""
        # The checkbox used to be created but never passed to the handler,
        # so every run silently reused the slider's seed.
        if randomize_seed:
            seed = random.randint(0, MAX_SEED)
        return generate_image(prompt, mode, cfg_scale, steps, int(seed),
                              width, height, lora_scale)

    example_model_prompts = [
        "professional fashion model, full body shot, standing pose, natural lighting, studio background, high fashion, elegant pose",
        "fashion model portrait, upper body, confident pose, fashion photography, neutral background, professional lighting",
        "stylish fashion model, three-quarter view, editorial pose, high-end fashion magazine style, minimal background",
    ]
    example_clothes_prompts = [
        "luxury designer sweater, cashmere material, cream color, cable knit pattern, high-end fashion, product photography",
        "elegant business blazer, tailored fit, charcoal grey, premium wool fabric, professional wear",
        "modern streetwear hoodie, oversized fit, minimalist design, premium cotton, urban style",
    ]

    # Example image files, in the same order as the original lists.
    person_examples = [f"a{i}.webp" for i in range(1, 6)]
    garment_examples = (
        [f"b{i}.webp" for i in range(1, 5)]
        + [f"c{i}.png" for i in range(1, 17)]
        + ["b5.webp"]
    )

    with gr.Blocks(theme="soft") as demo:
        gr.Markdown("# 🎭 FitGen:Fashion Studio & Virtual Try-on")

        with gr.Tabs():
            with gr.Tab("Fashion Generation"):
                with gr.Column():
                    mode = gr.Radio(
                        choices=["Generate Model", "Generate Clothes"],
                        label="Generation Mode",
                        value="Generate Model",
                    )

                    prompt = gr.TextArea(
                        label="Fashion Description (한글 또는 영어)",
                        placeholder="패션 모델이나 의류를 설명하세요...",
                    )

                    gr.Examples(
                        examples=example_model_prompts + example_clothes_prompts,
                        inputs=prompt,
                        label="Example Prompts",
                    )

                    with gr.Row():
                        with gr.Column():
                            result = gr.Image(label="Generated Result")
                            generate_button = gr.Button("Generate Fashion")

                    with gr.Accordion("Advanced Options", open=False):
                        with gr.Group():
                            with gr.Row():
                                with gr.Column():
                                    cfg_scale = gr.Slider(
                                        label="CFG Scale",
                                        minimum=1,
                                        maximum=20,
                                        step=0.5,
                                        value=7.0,
                                    )
                                    steps = gr.Slider(
                                        label="Steps",
                                        minimum=1,
                                        maximum=30,
                                        step=1,
                                        value=30,
                                    )
                                    lora_scale = gr.Slider(
                                        label="LoRA Scale",
                                        minimum=0,
                                        maximum=1,
                                        step=0.01,
                                        value=0.85,
                                    )

                            with gr.Row():
                                width = gr.Slider(
                                    label="Width",
                                    minimum=256,
                                    maximum=768,
                                    step=64,
                                    value=512,
                                )
                                height = gr.Slider(
                                    label="Height",
                                    minimum=256,
                                    maximum=768,
                                    step=64,
                                    value=768,
                                )

                            with gr.Row():
                                randomize_seed = gr.Checkbox(
                                    True,
                                    label="Randomize seed",
                                )
                                seed = gr.Slider(
                                    label="Seed",
                                    minimum=0,
                                    maximum=MAX_SEED,
                                    step=1,
                                    value=42,
                                )

            with gr.Tab("Virtual Try-on"):
                with gr.Row():
                    with gr.Column():
                        gr.Markdown("#### Person Image")
                        vt_src_image = gr.Image(
                            sources=["upload"],
                            type="filepath",
                            label="Person Image",
                            width=512,
                            height=512,
                        )
                        gr.Examples(
                            inputs=vt_src_image,
                            examples_per_page=5,
                            examples=person_examples,
                        )

                    with gr.Column():
                        gr.Markdown("#### Garment Image")
                        vt_ref_image = gr.Image(
                            sources=["upload"],
                            type="filepath",
                            label="Garment Image",
                            width=512,
                            height=512,
                        )
                        gr.Examples(
                            inputs=vt_ref_image,
                            examples_per_page=5,
                            examples=garment_examples,
                        )

                    with gr.Column():
                        gr.Markdown("#### Generated Image")
                        vt_gen_image = gr.Image(
                            label="Generated Image",
                            width=512,
                            height=512,
                        )
                        vt_gen_button = gr.Button("Try-on")

                vt_gen_button.click(
                    fn=leffa_predict_vt,
                    inputs=[vt_src_image, vt_ref_image],
                    outputs=[vt_gen_image],
                )

        generate_button.click(
            fn=_generate_with_seed_choice,
            inputs=[prompt, mode, cfg_scale, steps, randomize_seed, seed,
                    width, height, lora_scale],
            outputs=[result, seed],
        ).success(
            # clear_memory() returns None, which matches the empty outputs;
            # the former lambda returned gc.collect()'s count.
            fn=clear_memory,
            inputs=None,
            outputs=None,
        )

    return demo
|
|
if __name__ == "__main__":
    # Script entry point: configure the environment, build the UI, enable
    # request queuing, then serve on all interfaces at port 7860 without a
    # public share link.
    setup_environment()

    demo = create_interface()
    demo.queue()

    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
    }
    demo.launch(**launch_options)