| |
import os
import shutil
import subprocess
import sys
import time
import zipfile

import gradio as gr

from preprocess import process_dataset
|
|
# Image extensions accepted as training samples (compared case-insensitively).
_IMAGE_EXTS = ('.png', '.jpg', '.jpeg', '.bmp', '.webp')


def _upload_path(uploaded):
    """Return the filesystem path of one ``gr.File`` upload, or ``None``.

    Depending on the Gradio version/config, an upload may be a plain path
    string or a temp-file object exposing its path as ``.name``; both forms
    are accepted so neither is silently dropped.
    """
    if isinstance(uploaded, (str, os.PathLike)):
        return os.fspath(uploaded)
    name = getattr(uploaded, "name", None)
    return name if isinstance(name, str) else None


def _clear_directory(path):
    """Delete every entry inside *path*, keeping the directory itself.

    Returns a list of user-facing warnings, one per entry that could not be
    removed, so the caller can report them without aborting the run.
    """
    warnings = []
    for entry in os.listdir(path):
        entry_path = os.path.join(path, entry)
        try:
            if os.path.isfile(entry_path) or os.path.islink(entry_path):
                os.unlink(entry_path)
            elif os.path.isdir(entry_path):
                shutil.rmtree(entry_path)
        except OSError as e:
            warnings.append(f"⚠️ Erro ao limpar: {e}")
    return warnings


def train_lora_interface(
    dataset_input, input_type, model_name, lora_rank, learning_rate,
    num_epochs, hub_token, concept_name, description
):
    """Prepare a captioned image dataset and run ``train_lora.py``.

    This is a generator used as a Gradio event handler; each yielded string
    is a user-facing (Portuguese) status message for the log textbox.

    Side effects: empties and repopulates ``./processed_data``, creates
    ``./lora-output`` and launches ``train_lora.py`` as a child process.

    Args:
        dataset_input: upload(s) from ``gr.File`` -- a list or a single item,
            each either a path string or an object exposing ``.name``.
        input_type: ``"Upload de ZIP"`` extracts the first upload as a ZIP
            archive; any other value copies the uploads as individual images.
        model_name: base model id, forwarded as ``--model_name``.
        lora_rank: LoRA rank, coerced to ``int`` for ``--lora_rank``.
        learning_rate: positive learning rate for ``--learning_rate``.
        num_epochs: epoch count, coerced to ``int`` for ``--num_epochs``.
        hub_token: optional Hugging Face token. When non-blank it is given to
            the child process (only) as ``HF_TOKEN`` and ``--push_to_hub`` is
            requested with hub model id ``"<concept_name>-lora"``.
        concept_name: trigger word; surrounding blanks are stripped and inner
            spaces become underscores.
        description: base caption; ``", <concept_name>"`` is appended to it.

    Yields:
        Status / progress / error messages.
    """
    if not dataset_input:
        yield "❌ Por favor, envie um ZIP ou selecione imagens."
        return
    concept_name = (concept_name or "").strip()
    description = (description or "").strip()
    if not concept_name:
        yield "❌ Por favor, defina um nome para o conceito (ex: brenda)."
        return
    if not description:
        yield "❌ Por favor, adicione uma descrição base."
        return
    # A cleared gr.Number yields None; don't forward the string "None".
    if learning_rate is None or learning_rate <= 0:
        yield "❌ Por favor, defina uma taxa de aprendizado válida."
        return

    concept_name = concept_name.replace(" ", "_")
    full_description = f"{description}, {concept_name}"
    yield f"🏷️ Treinando: '{concept_name}'"

    # Start every run from an empty dataset directory.
    dataset_dir = "processed_data"
    os.makedirs(dataset_dir, exist_ok=True)
    for warning in _clear_directory(dataset_dir):
        yield warning

    uploads = dataset_input if isinstance(dataset_input, list) else [dataset_input]

    if input_type == "Upload de ZIP":
        zip_path = _upload_path(uploads[0])
        if zip_path is None or not zipfile.is_zipfile(zip_path):
            yield "❌ Arquivo não é um ZIP válido."
            return
        yield "📦 Descompactando..."
        with zipfile.ZipFile(zip_path, 'r') as z:
            z.extractall(dataset_dir)
            member_count = len(z.namelist())
        yield f"✅ ZIP extraído! {member_count} arquivos."
    else:
        yield f"🖼️ Copiando {len(uploads)} imagens..."
        copied = 0
        for uploaded in uploads:
            src = _upload_path(uploaded)
            if src is None:
                continue
            shutil.copy(src, os.path.join(dataset_dir, os.path.basename(src)))
            copied += 1
        # Report what was actually copied, not what was requested.
        yield f"✅ {copied} imagens copiadas."

    # NOTE: only top-level files are considered; images inside sub-folders of
    # an extracted ZIP are not picked up.
    images = [
        name for name in os.listdir(dataset_dir)
        if name.lower().endswith(_IMAGE_EXTS)
        and os.path.isfile(os.path.join(dataset_dir, name))
    ]
    if not images:
        yield "❌ Nenhuma imagem encontrada!"
        return

    yield f"📝 Aplicando legenda: '{full_description}'"
    # Write a sidecar caption per image, keeping captions shipped in the ZIP.
    for img in images:
        txt = os.path.join(dataset_dir, os.path.splitext(img)[0] + ".txt")
        if not os.path.exists(txt):
            with open(txt, "w", encoding="utf-8") as f:
                f.write(full_description)
    yield "🔍 Legendas prontas!"

    output_dir = "lora-output"
    os.makedirs(output_dir, exist_ok=True)

    cmd = [
        # Same interpreter (and virtualenv) as this app, not whatever
        # "python" happens to resolve to on PATH.
        sys.executable, "train_lora.py",
        "--dataset_dir", dataset_dir,
        "--model_name", model_name,
        # Sliders may deliver floats such as 4.0; integer flags need "4".
        "--lora_rank", str(int(lora_rank)),
        "--learning_rate", str(learning_rate),
        "--num_epochs", str(int(num_epochs)),
        "--batch_size", "1",
        "--output_dir", output_dir,
    ]

    hub_token = (hub_token or "").strip()
    env = os.environ.copy()
    if hub_token:
        # Hand the token to the child only; writing it into os.environ would
        # leak it into every later run served by this process.
        env["HF_TOKEN"] = hub_token
        cmd += ["--push_to_hub", "--hub_model_id", f"{concept_name}-lora"]

    yield "🔥 Iniciando treinamento..."

    process = None
    try:
        process = subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            bufsize=1,
            encoding="utf-8",
            errors="replace",  # never crash on undecodable child output
            env=env,
        )
        log_lines = []
        for line in process.stdout:
            log_lines.append(line)
            lowered = line.lower()
            if "loss" in lowered or "epoch" in lowered:
                yield f"📊 {line.strip()}"
        process.wait()

        if process.returncode == 0:
            summary = [
                "🎉 SUCESSO!",
                "",
                f"🔹 Use no prompt: `photo of {concept_name} in the forest`",
                f"🔹 Modelo salvo em: `{output_dir}`",
            ]
            if hub_token:
                summary.append("🔹 Publicado no Hub!")
            yield "\n".join(summary)
        else:
            log_tail = "".join(log_lines)[-1000:]
            yield (
                f"❌ Falha no treinamento. Código: {process.returncode}"
                f"\nLogs:\n{log_tail}"
            )
    except Exception as e:
        yield f"💥 Erro: {str(e)}"
    finally:
        # If the generator is closed early (e.g. the Gradio event is
        # cancelled), don't leave an orphaned training process behind.
        if process is not None:
            if process.poll() is None:
                process.kill()
                process.wait()
            if process.stdout is not None:
                process.stdout.close()
|
|
| |
# ---------------------------------------------------------------------------
# Gradio UI: dataset upload, concept identity, hyper-parameters and a button
# whose click streams train_lora_interface's status messages into "Logs".
# ---------------------------------------------------------------------------
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🎨 Treinador de LoRA - Hugging Face")
    gr.Markdown("Treine personagens, estilos ou objetos personalizados.")

    # Dataset source: how the upload is interpreted, then the upload itself.
    with gr.Row():
        input_type = gr.Radio(
            choices=["Upload de ZIP", "Selecionar várias imagens"],
            value="Upload de ZIP",
            label="Tipo de Entrada",
        )
    with gr.Row():
        dataset_input = gr.File(
            file_count="multiple",
            file_types=[".zip", ".jpg", ".jpeg", ".png", ".bmp", ".webp"],
            label="📤 Envie seu ZIP ou imagens",
        )

    # Concept identity: trigger word plus the base caption.
    gr.Markdown("### 🔖 Identidade do Personagem")
    with gr.Row():
        concept_name = gr.Textbox(
            value="",
            placeholder="Ex: brenda, cyborg_x",
            label="Nome do Conceito (ex: brenda)",
        )
    with gr.Row():
        description = gr.Textbox(
            lines=2,
            placeholder="Ex: young black woman, realistic style",
            label="Descrição Base (ex: woman, curly hair)",
        )

    # Training hyper-parameters, laid out on a single row.
    gr.Markdown("### ⚙️ Configurações")
    with gr.Row():
        model_name = gr.Dropdown(
            choices=["runwayml/stable-diffusion-v1-5"],
            value="runwayml/stable-diffusion-v1-5",
            label="Modelo Base",
        )
        lora_rank = gr.Slider(
            minimum=4, maximum=64, step=4, value=4, label="LoRA Rank"
        )
        learning_rate = gr.Number(label="Taxa de Aprendizado", value=1e-4)
        num_epochs = gr.Slider(
            minimum=1, maximum=30, step=1, value=10, label="Épocas"
        )

    hub_token = gr.Textbox(type="password", label="🔐 Token do HF (opcional)")

    btn = gr.Button("🚀 Iniciar Treinamento", variant="primary")
    output = gr.Textbox(lines=12, label="📦 Logs")

    # Input order must match train_lora_interface's positional parameters.
    btn.click(
        fn=train_lora_interface,
        inputs=[
            dataset_input, input_type, model_name, lora_rank,
            learning_rate, num_epochs, hub_token, concept_name, description,
        ],
        outputs=output,
    )

# Enable the request queue; generator handlers stream their yields through it.
demo.queue()

if __name__ == "__main__":
    demo.launch()