| import gradio as gr |
| import json |
| import base64 |
| import requests |
| import time |
| import os |
| from dotenv import load_dotenv |
| import numpy as np |
| from PIL import Image |
| import io |
|
|
| |
# Read configuration from the process environment; load_dotenv() first makes
# values defined in a local .env file available to os.getenv.
load_dotenv()
API_KEY = os.getenv('API_KEY')
CURRENT_URL = os.getenv('CURRENT_URL')

# Fail fast with an actionable message. Without this check a missing (or
# empty) variable only surfaces later as an opaque error when the endpoint
# URLs and the Authorization header are built from these values.
_missing_env = [
    name
    for name, value in (('API_KEY', API_KEY), ('CURRENT_URL', CURRENT_URL))
    if not value
]
if _missing_env:
    raise RuntimeError(
        "Missing required environment variable(s): "
        + ", ".join(_missing_env)
        + ". Set them in the environment or in a .env file."
    )
|
|
| |
# Endpoint paths are appended directly to the base URL, so make sure it ends
# with a '/'; otherwise "https://host" would yield "https://hostapi/tryon/".
_BASE_URL = CURRENT_URL if CURRENT_URL.endswith('/') else CURRENT_URL + '/'
TRYON_URL = _BASE_URL + 'api/tryon/'          # submits a new try-on job
FETCH_URL = _BASE_URL + 'api/tryon_state/'    # polled for the job's state
|
|
| |
# Headers sent with every request to the try-on API: a bearer token built
# from API_KEY and a JSON body declaration.
_auth_header = 'Bearer ' + API_KEY
headers = {'Authorization': _auth_header, 'Content-Type': 'application/json'}
|
|
| |
# Create the local example folders at import time; exist_ok makes this a
# no-op when they already exist.
for _example_dir in ("examples/garments", "examples/persons"):
    os.makedirs(_example_dir, exist_ok=True)
del _example_dir

# Sample images offered as clickable examples in the UI.
# NOTE(review): these paths live under samples/, not under the examples/
# folders created above — confirm whether those folders are still needed.
sample_garments = ["samples/garments/g1.jpg", "samples/garments/g2.jpg"]
sample_humans = ["samples/humans/h1.jpg", "samples/humans/h2.jpg"]
|
|
def preprocess_image(img, target_size=None):
    """Normalize an input image to a PIL image, optionally resizing it.

    Args:
        img: ``None``, a NumPy array (cast to ``uint8`` and wrapped with
            ``Image.fromarray``), or an object that is returned as-is
            (e.g. an existing PIL image).
        target_size: Optional ``(width, height)``. When given, the image is
            resized with Lanczos resampling; when ``None`` no resizing occurs.

    Returns:
        ``None`` if *img* is ``None``; otherwise the (possibly converted and
        resized) image.
    """
    if img is None:
        return None

    pil_img = Image.fromarray(img.astype('uint8')) if isinstance(img, np.ndarray) else img

    if target_size is None:
        return pil_img
    return pil_img.resize(target_size, Image.LANCZOS)
|
|
def _encode_jpeg_base64(img):
    """Serialize a PIL image as JPEG and return it as base64-encoded text.

    JPEG has no alpha channel or palette, so images in modes other than RGB
    or L (e.g. RGBA from a transparent PNG, or palette mode P) are converted
    to RGB first; Pillow would otherwise raise ``OSError`` on save.
    """
    if img.mode not in ("RGB", "L"):
        img = img.convert("RGB")
    buffer = io.BytesIO()
    img.save(buffer, format="JPEG")
    return base64.b64encode(buffer.getvalue()).decode('utf-8')


def virtual_tryon(garment_img, person_img, max_wait=60, poll_interval=2, request_timeout=30):
    """Submit a try-on job to the remote API and wait for the result image.

    The job is submitted to ``TRYON_URL``; ``FETCH_URL`` is then polled with
    the returned ``tryon_pk`` until the state is ``'done'``, at which point
    the image at ``s3_url`` is downloaded.

    Args:
        garment_img: Garment image (PIL image or NumPy array).
        person_img: Person image (PIL image or NumPy array).
        max_wait: Wall-clock seconds to keep polling before giving up.
        poll_interval: Seconds to sleep between state polls.
        request_timeout: Per-request timeout in seconds passed to ``requests``
            so a stalled server cannot hang the caller indefinitely.

    Returns:
        The result as a fully loaded PIL image, or ``None`` if either input is
        missing, the API reports a failure / non-200 status, or the result is
        not available within *max_wait* seconds.

    Raises:
        Exceptions from ``requests`` (connection errors, timeouts), from JSON
        decoding, a ``KeyError`` if the submit response lacks ``tryon_pk``,
        and Pillow errors if the downloaded image cannot be decoded.
    """
    if person_img is None or garment_img is None:
        return None

    data = {
        'human_image_base64': _encode_jpeg_base64(preprocess_image(person_img)),
        'garment_image_base64': _encode_jpeg_base64(preprocess_image(garment_img)),
    }

    response = requests.post(TRYON_URL, headers=headers, data=json.dumps(data), timeout=request_timeout)
    if response.status_code != 200:
        return None
    tryon_pk = response.json()['tryon_pk']

    # Use a real-time deadline: counting only sleep time would ignore the
    # duration of the HTTP calls themselves and overrun the budget.
    deadline = time.monotonic() + max_wait
    while time.monotonic() < deadline:
        fetch_response = requests.post(
            FETCH_URL,
            headers=headers,
            data=json.dumps({'tryon_pk': tryon_pk}),
            timeout=request_timeout,
        )
        if fetch_response.status_code != 200:
            return None

        state = fetch_response.json()
        if state.get('message') != 'success':
            return None

        if state.get('status') == 'done':
            img_response = requests.get(state['s3_url'], timeout=request_timeout)
            if img_response.status_code == 200:
                result = Image.open(io.BytesIO(img_response.content))
                # Image.open is lazy; decode now so corrupt data raises here
                # (inside the caller's error handling) rather than later.
                result.load()
                return result
            # A failed download falls through and is retried on the next poll.

        time.sleep(poll_interval)

    return None
|
|
# Stylesheet passed to gr.Blocks(css=custom_css) when the UI is built.
# Selector notes, based on the components defined in this file:
#   .image-container img  -> images in the garment/person upload widgets,
#                            which are given elem_classes=["image-container"].
#   .examples-container   -> NOTE(review): no component in this file is given
#                            this class, so these rules may be unused; confirm.
#   button#try-on-button  -> the "Try On" button (elem_id="try-on-button").
#   footer                -> visibility: hidden hides <footer> elements but
#                            they still occupy layout space.
# The output image's "result-image" class has no rule here.
custom_css = """
body, .gradio-container {
    font-family: 'Inter', -apple-system, BlinkMacSystemFont, sans-serif;
    background-color: #121212;
    color: white;
}

h1, h2, h3 {
    color: white !important;
}

.container {
    max-width: 1200px;
    margin: 0 auto;
}

.image-container img {
    object-fit: contain;
    max-height: 450px;
    width: auto;
    margin: 0 auto;
    display: block;
    border-radius: 8px;
}

.examples-container {
    display: grid;
    grid-template-columns: repeat(auto-fill, minmax(150px, 1fr));
    gap: 10px;
    margin-top: 10px;
}

.examples-container img {
    height: 120px;
    object-fit: cover;
    border-radius: 8px;
    cursor: pointer;
    transition: transform 0.2s;
}

.examples-container img:hover {
    transform: scale(1.05);
}

button#try-on-button {
    background-color: #FF6B00 !important;
    color: white !important;
    border: none !important;
    padding: 12px 20px !important;
    font-weight: 600 !important;
    border-radius: 8px !important;
    cursor: pointer !important;
    transition: background-color 0.3s !important;
}

button#try-on-button:hover {
    background-color: #FF8C33 !important;
}

footer {visibility: hidden}
"""
|
|
| |
def validate_inputs(garment_img, person_img, garment_type=None, sleeve_length=None, garment_length=None):
    """Validate the uploaded images and run the virtual try-on.

    Bound to the "Try On" button, whose ``inputs`` list supplies only the
    garment and person images. The remaining parameters are unused and
    default to ``None``; when they were required, Gradio's two-argument call
    failed with a ``TypeError`` on every click.

    Returns:
        The PIL result image produced by ``virtual_tryon``.

    Raises:
        gr.Error: if either image is missing, the try-on call raises, or no
            result image was produced (API failure or timeout).
    """
    if garment_img is None:
        raise gr.Error("Please upload a garment image")
    if person_img is None:
        raise gr.Error("Please upload a person image")

    try:
        result = virtual_tryon(garment_img, person_img)
    except Exception as e:
        # Surface any backend/network failure to the user as a UI error,
        # keeping the original exception chained for server-side logs.
        raise gr.Error(f"Error: {str(e)}") from e

    # virtual_tryon signals API failures and timeouts by returning None;
    # report that instead of silently clearing the output image.
    if result is None:
        raise gr.Error("Try-on failed or timed out. Please try again.")
    return result


with gr.Blocks(theme=gr.themes.Base(), css=custom_css) as demo:
    gr.HTML("<h1 style='text-align: center; margin-bottom: 20px;'>AlphaBake Virtual Try-On</h1>")

    with gr.Row():
        # Left column: garment upload plus clickable garment samples.
        with gr.Column(scale=1):
            gr.Markdown("### Garment Image")
            garment_input = gr.Image(
                label="Upload a garment image",
                type="pil",
                elem_id="garment-image",
                elem_classes=["image-container"],
                height=350,
            )
            gr.Examples(
                examples=sample_garments,
                inputs=garment_input,
                label="Garment Examples",
                examples_per_page=4,
            )

        # Middle column: person upload plus clickable person samples.
        with gr.Column(scale=1):
            gr.Markdown("### Person Image")
            person_input = gr.Image(
                label="Upload a person image",
                type="pil",
                elem_id="person-image",
                elem_classes=["image-container"],
                height=350,
            )
            gr.Examples(
                examples=sample_humans,
                inputs=person_input,
                label="Person Examples",
                examples_per_page=4,
            )

        # Right column: trigger button and the generated result.
        with gr.Column(scale=1):
            try_on_button = gr.Button("Try On", elem_id="try-on-button", variant="primary", size="lg")
            output_image = gr.Image(
                label="Result",
                type="pil",
                elem_classes=["result-image"],
                height=400,
            )

    try_on_button.click(
        fn=validate_inputs,
        inputs=[garment_input, person_input],
        outputs=output_image,
    )
|
|
def _main():
    """Start the Gradio server hosting the try-on demo."""
    demo.launch()


if __name__ == "__main__":
    _main()
|
|