| |
| """ |
| LoRA Trainer Funcional para Hugging Face |
| Baseado no kohya-ss sd-scripts |
| """ |
|
|
import json
import logging
import os
import queue
import shutil
import subprocess
import sys
import tempfile
import threading
import time
import urllib.parse
import zipfile
from collections import deque
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import gradio as gr
import toml
|
|
| |
# Make the bundled kohya-ss "sd-scripts" checkout (a sibling of this file)
# importable, taking precedence over anything else on the path.
sys.path.insert(0, str(Path(__file__).parent.joinpath("sd-scripts")))

# Root logging configuration for the whole app, plus this module's logger.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
|
|
class LoRATrainerHF:
    """Backend for the Gradio LoRA trainer, built on kohya-ss sd-scripts.

    Working files live under ``/tmp/lora_training``:

    * downloaded base models go in ``models/``;
    * each project gets a directory in ``projects/``, holding the extracted
      dataset, the generated TOML configs, the training output and the logs.

    Training runs ``sd-scripts/train_network.py`` (located next to this file)
    as a subprocess. Its output is collected for display in the UI.

    Public methods report problems through their return values (Portuguese
    status messages) rather than raising, so they can be wired directly to
    Gradio callbacks.
    """

    # Maximum number of training-log lines kept for display.
    MAX_LOG_LINES = 500

    # File extensions counted as training images.
    IMAGE_EXTENSIONS = frozenset({'.jpg', '.jpeg', '.png', '.webp', '.bmp', '.tiff'})

    def __init__(self):
        self.base_dir = Path("/tmp/lora_training")
        self.base_dir.mkdir(parents=True, exist_ok=True)

        self.models_dir = self.base_dir / "models"
        self.models_dir.mkdir(exist_ok=True)

        self.projects_dir = self.base_dir / "projects"
        self.projects_dir.mkdir(exist_ok=True)

        # kohya-ss checkout bundled next to this file.
        self.sd_scripts_dir = Path(__file__).parent / "sd-scripts"

        # Preset base models: UI display name -> direct download URL.
        self.model_urls = {
            "Anime (animefull-final-pruned)": "https://huggingface.co/hollowstrawberry/stable-diffusion-guide/resolve/main/models/animefull-final-pruned-fp16.safetensors",
            "AnyLoRA": "https://huggingface.co/Lykon/AnyLoRA/resolve/main/AnyLoRA_noVae_fp16-pruned.ckpt",
            "Stable Diffusion 1.5": "https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.safetensors",
            "Waifu Diffusion 1.4": "https://huggingface.co/hakurei/waifu-diffusion-v1-4/resolve/main/wd-1-4-anime_e1.ckpt"
        }

        # Running train_network.py subprocess, or None when idle.
        self.training_process = None
        # Output lines from the reader thread, drained by get_training_output().
        self.training_output_queue = queue.Queue()
        # Recent log lines, so every UI refresh shows the whole tail of the log
        # instead of only the lines produced since the previous refresh.
        self.training_log = deque(maxlen=self.MAX_LOG_LINES)

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _normalize_project_name(project_name) -> Tuple[Optional[str], str]:
        """Turn a user-supplied project name into a safe directory name.

        Surrounding whitespace is stripped and inner spaces become underscores.
        Names that could escape ``projects_dir`` are rejected: path
        separators, ``.``/``..`` and NUL bytes.

        Returns:
            ``(name, "")`` on success, or ``(None, error_message)``.
        """
        raw = (project_name or "").strip()
        if not raw:
            return None, "❌ Nome do projeto é obrigatório"
        name = raw.replace(" ", "_")
        if name in (".", "..") or any(ch in name for ch in ("/", "\\", "\x00")):
            return None, f"❌ Nome de projeto inválido: {raw!r}"
        return name, ""

    @staticmethod
    def _model_filename(url: str) -> str:
        """Return the local file name for a model download URL.

        The last path segment is used. Any query string or fragment is
        ignored, as is a trailing slash. For example,
        ``.../model.safetensors?download=true`` gives ``model.safetensors``.
        An empty string means no usable name could be derived.
        """
        path = urllib.parse.urlsplit((url or "").strip()).path
        return path.rstrip("/").split("/")[-1]

    def _resolve_model(self, model_choice, custom_url) -> Tuple[Optional[str], Optional[str], str]:
        """Pick the model URL and local file name from the UI selection.

        A non-empty ``custom_url`` takes precedence over the preset
        ``model_choice``.

        Returns:
            ``(url, filename, "")`` on success, or ``(None, None, error_message)``.
        """
        custom = (custom_url or "").strip()
        if custom:
            url = custom
        elif model_choice in self.model_urls:
            url = self.model_urls[model_choice]
        else:
            return None, None, f"❌ Modelo '{model_choice}' não encontrado"
        filename = self._model_filename(url)
        if not filename:
            return None, None, f"❌ Não foi possível obter o nome do arquivo a partir da URL: {url}"
        return url, filename, ""

    @classmethod
    def _find_images(cls, root: Path) -> List[Path]:
        """Return the image files under *root*, sorted and searched recursively.

        macOS archive metadata is skipped: the ``__MACOSX`` folders and the
        ``._*`` AppleDouble files. These carry image extensions but are not
        real images.
        """
        return sorted(
            path for path in root.rglob("*")
            if path.is_file()
            and path.suffix.lower() in cls.IMAGE_EXTENSIONS
            and "__MACOSX" not in path.relative_to(root).parts
            and not path.name.startswith("._")
        )

    @staticmethod
    def _image_root(images: List[Path], default: Path) -> Path:
        """Return the single directory holding all *images*, else *default*.

        Zips often wrap the images in one top-level folder. Pointing kohya's
        ``image_dir`` at that folder lets it find them.
        """
        parents = {path.parent for path in images}
        return parents.pop() if len(parents) == 1 else default

    def _is_training(self) -> bool:
        """True while the training subprocess exists and has not exited."""
        process = self.training_process  # local copy: the reader thread may reset it
        return process is not None and process.poll() is None

    def _reset_training_log(self) -> None:
        """Discard buffered and queued log lines from a previous run."""
        while True:
            try:
                self.training_output_queue.get_nowait()
            except queue.Empty:
                break
        self.training_log.clear()

    def _pump_training_output(self, process) -> None:
        """Forward *process* output to the UI queue until it exits (runs in a thread)."""
        try:
            for line in process.stdout:
                line = line.strip()
                self.training_output_queue.put(line)
                logger.info(line)

            process.wait()

            if process.returncode == 0:
                self.training_output_queue.put("✅ TREINAMENTO CONCLUÍDO COM SUCESSO!")
            else:
                self.training_output_queue.put(f"❌ TREINAMENTO FALHOU (código {process.returncode})")

        except Exception as e:
            self.training_output_queue.put(f"❌ ERRO NO TREINAMENTO: {e}")
        finally:
            if process.stdout is not None:
                process.stdout.close()
            # Only clear the handle if it still refers to this run.
            if self.training_process is process:
                self.training_process = None

    # ----------------------------------------------------------------- actions

    def install_dependencies(self) -> str:
        """Install the training dependencies with pip, one package at a time.

        A failed package is logged and skipped, so the rest still install.
        The result message lists any packages that failed, instead of
        always claiming success.
        """
        try:
            logger.info("Instalando dependências...")

            packages = [
                "torch>=2.0.0",
                "torchvision>=0.15.0",
                "diffusers>=0.21.0",
                "transformers>=4.25.0",
                "accelerate>=0.20.0",
                "safetensors>=0.3.0",
                "huggingface-hub>=0.16.0",
                "xformers>=0.0.20",
                "bitsandbytes>=0.41.0",
                "opencv-python>=4.7.0",
                "Pillow>=9.0.0",
                "numpy>=1.21.0",
                "tqdm>=4.64.0",
                "toml>=0.10.0",
                "tensorboard>=2.13.0",
                "wandb>=0.15.0",
                "scipy>=1.9.0",
                "matplotlib>=3.5.0",
                "datasets>=2.14.0",
                "peft>=0.5.0",
                "omegaconf>=2.3.0"
            ]

            failed = []
            for package in packages:
                try:
                    subprocess.run([
                        sys.executable, "-m", "pip", "install", package, "--quiet"
                    ], check=True, capture_output=True, text=True)
                    logger.info(f"✓ {package} instalado")
                except subprocess.CalledProcessError as e:
                    # The exception text only has the exit code; pip's reason is in stderr.
                    details = (e.stderr or "").strip()[-1000:]
                    logger.warning(f"⚠ Erro ao instalar {package}: {e}\n{details}")
                    failed.append(package)

            if failed:
                return f"⚠️ Dependências instaladas, exceto: {', '.join(failed)}"
            return "✅ Dependências instaladas com sucesso!"

        except Exception as e:
            logger.error(f"Erro ao instalar dependências: {e}")
            return f"❌ Erro ao instalar dependências: {e}"

    def download_model(self, model_choice: str, custom_url: str = "") -> str:
        """Download the selected base model into ``models_dir`` with ``wget``.

        Args:
            model_choice: Key of ``model_urls``. Used when ``custom_url`` is blank.
            custom_url: Direct download URL that overrides ``model_choice``.

        Returns:
            A status message. No download is attempted if the file already exists.
        """
        try:
            model_url, model_name, error = self._resolve_model(model_choice, custom_url)
            if error:
                return error

            model_path = self.models_dir / model_name

            if model_path.exists():
                return f"✅ Modelo já existe: {model_name}"

            logger.info(f"Baixando modelo: {model_url}")

            # Download under a temporary name and rename only on success.
            # ``wget -O`` creates the file even when the download fails, and
            # a truncated file would later be reported as an existing model.
            partial_path = model_path.with_name(model_path.name + ".part")
            result = subprocess.run([
                "wget", "-O", str(partial_path), model_url, "--progress=bar:force"
            ], capture_output=True, text=True)

            if result.returncode != 0:
                partial_path.unlink(missing_ok=True)
                # The forced progress bar floods stderr; keep only the tail.
                return f"❌ Erro no download: {result.stderr[-2000:]}"

            partial_path.replace(model_path)
            return f"✅ Modelo baixado: {model_name} ({model_path.stat().st_size // (1024*1024)} MB)"

        except Exception as e:
            logger.error(f"Erro ao baixar modelo: {e}")
            return f"❌ Erro ao baixar modelo: {e}"

    def process_dataset(self, dataset_zip, project_name: str) -> Tuple[str, str]:
        """Extract an uploaded dataset ZIP into ``<project>/dataset``.

        Any previous dataset of the project is replaced.

        Args:
            dataset_zip: The Gradio upload. Gradio 3 passes a tempfile
                wrapper with ``.name``; newer versions pass a path string.
                Both are accepted.
            project_name: User-supplied project name.

        Returns:
            ``(status_message, image_dir)``. ``image_dir`` is ``""`` on
            failure; otherwise it is the directory that holds the images.
        """
        try:
            if not dataset_zip:
                return "❌ Nenhum dataset foi enviado", ""

            name, error = self._normalize_project_name(project_name)
            if error:
                return error, ""

            project_dir = self.projects_dir / name
            project_dir.mkdir(exist_ok=True)

            dataset_dir = project_dir / "dataset"
            if dataset_dir.exists():
                shutil.rmtree(dataset_dir)
            dataset_dir.mkdir()

            zip_path = getattr(dataset_zip, "name", dataset_zip)
            with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                zip_ref.extractall(dataset_dir)

            images = self._find_images(dataset_dir)
            if not images:
                return "❌ Nenhuma imagem encontrada no ZIP", ""

            # Captions are same-name .txt files next to each image.
            captions = [img.with_suffix('.txt') for img in images if img.with_suffix('.txt').exists()]
            image_dir = self._image_root(images, dataset_dir)

            info = f"✅ Dataset processado!\n"
            info += f"📁 Projeto: {name}\n"
            info += f"🖼️ Imagens: {len(images)}\n"
            info += f"📝 Captions: {len(captions)}\n"
            info += f"📂 Diretório: {image_dir}"
            if len({img.parent for img in images}) > 1:
                info += "\n⚠️ Imagens em várias subpastas: coloque todas as imagens em uma única pasta do ZIP."

            return info, str(image_dir)

        except Exception as e:
            logger.error(f"Erro ao processar dataset: {e}")
            return f"❌ Erro ao processar dataset: {e}", ""

    def create_training_config(self,
                               project_name: str,
                               dataset_dir: str,
                               model_choice: str,
                               custom_model_url: str,
                               resolution: int,
                               batch_size: int,
                               epochs: int,
                               learning_rate: float,
                               text_encoder_lr: float,
                               network_dim: int,
                               network_alpha: int,
                               lora_type: str,
                               optimizer: str,
                               scheduler: str,
                               flip_aug: bool,
                               shuffle_caption: bool,
                               keep_tokens: int,
                               clip_skip: int,
                               mixed_precision: str,
                               save_every_n_epochs: int,
                               max_train_steps: int) -> str:
        """Write ``dataset_config.toml`` and ``training_config.toml`` for a project.

        The training TOML is meant for ``train_network.py --config_file`` and
        points at the dataset TOML. ``max_train_steps <= 0`` means "train
        for ``epochs`` epochs"; a positive value caps training by steps
        instead.

        Returns:
            A status message with the paths of both config files.
        """
        try:
            name, error = self._normalize_project_name(project_name)
            if error:
                return error

            project_dir = self.projects_dir / name
            project_dir.mkdir(exist_ok=True)

            output_dir = project_dir / "output"
            output_dir.mkdir(exist_ok=True)

            log_dir = project_dir / "logs"
            log_dir.mkdir(exist_ok=True)

            _, model_name, error = self._resolve_model(model_choice, custom_model_url)
            if error:
                return error

            model_path = self.models_dir / model_name

            if not model_path.exists():
                return f"❌ Modelo não encontrado: {model_name}. Faça o download primeiro."

            image_dir = str(dataset_dir or "").strip()
            if not image_dir:
                # The session state is empty if the dataset was not processed in
                # this session: fall back to the project's extracted dataset.
                default_dir = project_dir / "dataset"
                images = self._find_images(default_dir) if default_dir.is_dir() else []
                if not images:
                    return "❌ Dataset não encontrado. Processe o dataset primeiro."
                image_dir = str(self._image_root(images, default_dir))

            # Gradio sliders and number boxes may deliver floats (or None for
            # an empty Number). Write proper TOML integers for integer options.
            resolution = int(resolution)
            batch_size = int(batch_size)
            epochs = int(epochs)
            network_dim = int(network_dim)
            network_alpha = int(network_alpha)
            keep_tokens = int(keep_tokens)
            clip_skip = int(clip_skip)
            save_every_n_epochs = int(save_every_n_epochs)
            max_train_steps = int(max_train_steps or 0)
            use_steps = max_train_steps > 0

            dataset_config = {
                "general": {
                    "shuffle_caption": shuffle_caption,
                    "caption_extension": ".txt",
                    "keep_tokens": keep_tokens,
                    "flip_aug": flip_aug,
                    "color_aug": False,
                    "face_crop_aug_range": None,  # None entries are omitted by toml.dump
                    "random_crop": False,
                    "debug_dataset": False
                },
                "datasets": [{
                    "resolution": resolution,
                    "batch_size": batch_size,
                    "subsets": [{
                        "image_dir": image_dir,
                        "num_repeats": 1
                    }]
                }]
            }

            # Both LoRA and LoCon use networks.lora; any other type maps to DyLoRA.
            network_module = "networks.lora" if lora_type in ("LoRA", "LoCon") else "networks.dylora"

            training_config = {
                "model_arguments": {
                    "pretrained_model_name_or_path": str(model_path),
                    "v2": False,
                    "v_parameterization": False,
                    "clip_skip": clip_skip
                },
                "dataset_arguments": {
                    "dataset_config": str(project_dir / "dataset_config.toml")
                },
                "training_arguments": {
                    "output_dir": str(output_dir),
                    "output_name": name,
                    "save_precision": "fp16",
                    "save_every_n_epochs": save_every_n_epochs,
                    # Exactly one of these is set; the None one is omitted.
                    "max_train_epochs": None if use_steps else epochs,
                    "max_train_steps": max_train_steps if use_steps else None,
                    "train_batch_size": batch_size,
                    "gradient_accumulation_steps": 1,
                    "learning_rate": float(learning_rate),
                    "text_encoder_lr": float(text_encoder_lr),
                    "lr_scheduler": scheduler,
                    "lr_warmup_steps": 0,
                    "optimizer_type": optimizer,
                    "mixed_precision": mixed_precision,
                    "save_model_as": "safetensors",
                    "seed": 42,
                    "max_data_loader_n_workers": 2,
                    "persistent_data_loader_workers": True,
                    "gradient_checkpointing": True,
                    "xformers": True,
                    "lowram": True,
                    "cache_latents": True,
                    "cache_latents_to_disk": True,
                    "logging_dir": str(log_dir),
                    "log_with": "tensorboard"
                },
                "network_arguments": {
                    "network_module": network_module,
                    "network_dim": network_dim,
                    "network_alpha": network_alpha,
                    "network_train_unet_only": False,
                    "network_train_text_encoder_only": False
                }
            }

            if lora_type == "LoCon":
                # networks.lora only enables its conv layers through
                # ``network_args`` ("conv_dim=N"). Top-level conv_dim/conv_alpha
                # keys are not read by it.
                training_config["network_arguments"]["network_args"] = [
                    f"conv_dim={max(1, network_dim // 2)}",
                    f"conv_alpha={max(1, network_alpha // 2)}",
                ]

            dataset_config_path = project_dir / "dataset_config.toml"
            training_config_path = project_dir / "training_config.toml"

            with open(dataset_config_path, 'w', encoding='utf-8') as f:
                toml.dump(dataset_config, f)

            with open(training_config_path, 'w', encoding='utf-8') as f:
                toml.dump(training_config, f)

            return f"✅ Configuração criada!\n📁 Dataset: {dataset_config_path}\n⚙️ Treinamento: {training_config_path}"

        except Exception as e:
            logger.error(f"Erro ao criar configuração: {e}")
            return f"❌ Erro ao criar configuração: {e}"

    def start_training(self, project_name: str) -> str:
        """Launch ``train_network.py`` for a project in the background.

        The subprocess is started here, so launch errors are reported
        immediately. A daemon thread then streams its output into
        ``training_output_queue``. Only one training run is allowed at a
        time.
        """
        try:
            name, error = self._normalize_project_name(project_name)
            if error:
                return error

            if self._is_training():
                return "⚠️ Já existe um treinamento em andamento. Pare-o antes de iniciar outro."

            project_dir = self.projects_dir / name

            training_config_path = project_dir / "training_config.toml"
            if not training_config_path.exists():
                return "❌ Configuração não encontrada. Crie a configuração primeiro."

            train_script = self.sd_scripts_dir / "train_network.py"
            if not train_script.exists():
                return "❌ Script de treinamento não encontrado"

            cmd = [
                sys.executable,
                str(train_script),
                "--config_file", str(training_config_path)
            ]

            logger.info(f"Iniciando treinamento: {' '.join(cmd)}")

            self._reset_training_log()
            process = subprocess.Popen(
                cmd,
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
                text=True,
                bufsize=1,  # line-buffered so progress reaches the UI promptly
                cwd=str(self.sd_scripts_dir)
            )
            self.training_process = process

            training_thread = threading.Thread(
                target=self._pump_training_output, args=(process,), daemon=True
            )
            training_thread.start()

            return "🚀 Treinamento iniciado! Acompanhe o progresso abaixo."

        except Exception as e:
            logger.error(f"Erro ao iniciar treinamento: {e}")
            return f"❌ Erro ao iniciar treinamento: {e}"

    def get_training_output(self) -> str:
        """Return the recent training log, or a short status when it is empty.

        New lines are moved from the queue into a bounded buffer, and the
        whole buffer is returned. A UI textbox that is replaced on every
        poll therefore keeps showing the log tail, including the final
        success or failure line.
        """
        while True:
            try:
                self.training_log.append(self.training_output_queue.get_nowait())
            except queue.Empty:
                break

        if self.training_log:
            return "\n".join(self.training_log)
        if self._is_training():
            return "🔄 Treinamento em andamento..."
        return "⏸️ Nenhum treinamento ativo"

    def stop_training(self) -> str:
        """Stop the running training process.

        SIGTERM is sent first. If the process has not exited after 10
        seconds, it is killed.
        """
        process = self.training_process  # local copy: the reader thread may reset it
        if process is None or process.poll() is not None:
            return "ℹ️ Nenhum treinamento ativo para parar"
        try:
            process.terminate()
            try:
                process.wait(timeout=10)
            except subprocess.TimeoutExpired:
                process.kill()
                process.wait()
            return "⏹️ Treinamento interrompido"
        except Exception as e:
            return f"❌ Erro ao parar treinamento: {e}"

    def list_output_files(self, project_name: str) -> List[str]:
        """List the project's ``.safetensors`` outputs as ``"name (N MB)"`` strings.

        The list is sorted in reverse lexicographic order. An empty list is
        returned for invalid projects or projects without output.
        """
        try:
            name, error = self._normalize_project_name(project_name)
            if error:
                return []

            output_dir = self.projects_dir / name / "output"

            if not output_dir.exists():
                return []

            files = [
                f"{file_path.name} ({file_path.stat().st_size // (1024 * 1024)} MB)"
                for file_path in output_dir.rglob("*.safetensors")
            ]

            return sorted(files, reverse=True)

        except Exception as e:
            logger.error(f"Erro ao listar arquivos: {e}")
            return []
|
|
| |
# Single shared trainer used by every Gradio callback below. Instantiating it
# creates the working directories under /tmp/lora_training.
trainer = LoRATrainerHF()
|
|
def create_interface():
    """Build the Gradio UI for the LoRA trainer and return the Blocks app.

    The tabs follow the workflow:

    1. install the dependencies;
    2. download a base model and upload the dataset;
    3. tune the training parameters;
    4. create the config and start or stop training;
    5. download the resulting LoRA files.

    Every callback delegates to the module-level ``trainer`` instance.

    Returns:
        The ``gr.Blocks`` app, with queuing enabled. Older Gradio versions
        need the queue for the periodic training-log refresh.
    """

    with gr.Blocks(title="LoRA Trainer Funcional - Hugging Face", theme=gr.themes.Soft()) as interface:

        gr.Markdown("""
        # 🎨 LoRA Trainer Funcional para Hugging Face

        **Treine seus próprios modelos LoRA para Stable Diffusion de forma profissional!**

        Esta ferramenta é baseada no kohya-ss sd-scripts e oferece treinamento real e funcional de modelos LoRA.
        """)

        # Image directory of the processed dataset. "Processar Dataset" sets
        # it and "Criar Configuração" reads it.
        dataset_dir_state = gr.State("")

        with gr.Tab("🔧 Instalação"):
            gr.Markdown("### Primeiro, instale as dependências necessárias:")
            install_btn = gr.Button("📦 Instalar Dependências", variant="primary", size="lg")
            install_status = gr.Textbox(label="Status da Instalação", lines=3, interactive=False)

            install_btn.click(
                fn=trainer.install_dependencies,
                outputs=install_status
            )

        with gr.Tab("📁 Configuração do Projeto"):
            with gr.Row():
                project_name = gr.Textbox(
                    label="Nome do Projeto",
                    placeholder="meu_lora_anime",
                    info="Nome único para seu projeto (sem espaços especiais)"
                )

            gr.Markdown("### 📥 Download do Modelo Base")
            with gr.Row():
                model_choice = gr.Dropdown(
                    choices=list(trainer.model_urls.keys()),
                    label="Modelo Base Pré-definido",
                    value="Anime (animefull-final-pruned)",
                    info="Escolha um modelo base ou use URL personalizada"
                )
                custom_model_url = gr.Textbox(
                    label="URL Personalizada (opcional)",
                    placeholder="https://huggingface.co/...",
                    info="URL direta para download de modelo personalizado"
                )

            download_btn = gr.Button("📥 Baixar Modelo", variant="primary")
            download_status = gr.Textbox(label="Status do Download", lines=2, interactive=False)

            gr.Markdown("### 📊 Upload do Dataset")
            gr.Markdown("""
            **Formato do Dataset:**
            - Crie um arquivo ZIP contendo suas imagens
            - Para cada imagem, inclua um arquivo .txt com o mesmo nome contendo as tags/descrições
            - Exemplo: `imagem1.jpg` + `imagem1.txt`
            """)

            dataset_upload = gr.File(
                label="Upload do Dataset (ZIP)",
                file_types=[".zip"]
            )

            process_btn = gr.Button("📊 Processar Dataset", variant="primary")
            dataset_status = gr.Textbox(label="Status do Dataset", lines=4, interactive=False)

        with gr.Tab("⚙️ Parâmetros de Treinamento"):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("#### 🖼️ Configurações de Imagem")
                    resolution = gr.Slider(
                        minimum=512, maximum=1024, step=64, value=512,
                        label="Resolução",
                        info="Resolução das imagens (512 = mais rápido, 1024 = melhor qualidade)"
                    )
                    batch_size = gr.Slider(
                        minimum=1, maximum=8, step=1, value=1,
                        label="Batch Size",
                        info="Imagens por lote (aumente se tiver GPU potente)"
                    )
                    flip_aug = gr.Checkbox(
                        label="Flip Augmentation",
                        info="Espelhar imagens para aumentar dataset"
                    )
                    shuffle_caption = gr.Checkbox(
                        value=True,
                        label="Shuffle Caption",
                        info="Embaralhar ordem das tags"
                    )
                    keep_tokens = gr.Slider(
                        minimum=0, maximum=5, step=1, value=1,
                        label="Keep Tokens",
                        info="Número de tokens iniciais que não serão embaralhados"
                    )

                with gr.Column():
                    gr.Markdown("#### 🎯 Configurações de Treinamento")
                    epochs = gr.Slider(
                        minimum=1, maximum=100, step=1, value=10,
                        label="Épocas",
                        info="Número de épocas de treinamento"
                    )
                    max_train_steps = gr.Number(
                        value=0,
                        label="Max Train Steps (0 = usar épocas)",
                        info="Número máximo de steps (deixe 0 para usar épocas)"
                    )
                    save_every_n_epochs = gr.Slider(
                        minimum=1, maximum=10, step=1, value=1,
                        label="Salvar a cada N épocas",
                        info="Frequência de salvamento dos checkpoints"
                    )
                    mixed_precision = gr.Dropdown(
                        choices=["fp16", "bf16", "no"],
                        value="fp16",
                        label="Mixed Precision",
                        info="fp16 = mais rápido, bf16 = mais estável"
                    )
                    clip_skip = gr.Slider(
                        minimum=1, maximum=12, step=1, value=2,
                        label="CLIP Skip",
                        info="Camadas CLIP a pular (2 para anime, 1 para realista)"
                    )

            with gr.Row():
                with gr.Column():
                    gr.Markdown("#### 📚 Learning Rate")
                    learning_rate = gr.Number(
                        value=1e-4,
                        label="Learning Rate (UNet)",
                        info="Taxa de aprendizado principal"
                    )
                    text_encoder_lr = gr.Number(
                        value=5e-5,
                        label="Learning Rate (Text Encoder)",
                        info="Taxa de aprendizado do text encoder"
                    )
                    scheduler = gr.Dropdown(
                        choices=["cosine", "cosine_with_restarts", "constant", "constant_with_warmup", "linear"],
                        value="cosine_with_restarts",
                        label="LR Scheduler",
                        info="Algoritmo de ajuste da learning rate"
                    )
                    optimizer = gr.Dropdown(
                        choices=["AdamW8bit", "AdamW", "Lion", "SGD"],
                        value="AdamW8bit",
                        label="Otimizador",
                        info="AdamW8bit = menos memória"
                    )

                with gr.Column():
                    gr.Markdown("#### 🧠 Arquitetura LoRA")
                    lora_type = gr.Radio(
                        choices=["LoRA", "LoCon"],
                        value="LoRA",
                        label="Tipo de LoRA",
                        info="LoRA = geral, LoCon = estilos artísticos"
                    )
                    network_dim = gr.Slider(
                        minimum=4, maximum=128, step=4, value=32,
                        label="Network Dimension",
                        info="Dimensão da rede (maior = mais detalhes, mais memória)"
                    )
                    network_alpha = gr.Slider(
                        minimum=1, maximum=128, step=1, value=16,
                        label="Network Alpha",
                        info="Controla a força do LoRA (geralmente dim/2)"
                    )

        with gr.Tab("🚀 Treinamento"):
            create_config_btn = gr.Button("📝 Criar Configuração de Treinamento", variant="primary", size="lg")
            config_status = gr.Textbox(label="Status da Configuração", lines=3, interactive=False)

            with gr.Row():
                start_training_btn = gr.Button("🎯 Iniciar Treinamento", variant="primary", size="lg")
                stop_training_btn = gr.Button("⏹️ Parar Treinamento", variant="stop")

            # Start/stop messages get their own box, so the periodic log
            # refresh below does not immediately overwrite them.
            training_status = gr.Textbox(label="Status do Treinamento", lines=1, interactive=False)

            training_output = gr.Textbox(
                label="Output do Treinamento",
                lines=15,
                interactive=False,
                info="Acompanhe o progresso do treinamento em tempo real"
            )

        with gr.Tab("📥 Download dos Resultados"):
            refresh_files_btn = gr.Button("🔄 Atualizar Lista de Arquivos", variant="secondary")

            output_files = gr.Dropdown(
                label="Arquivos LoRA Gerados",
                choices=[],
                info="Selecione um arquivo para download"
            )

            # Filled with the file chosen in the dropdown above.
            output_file = gr.File(label="Arquivo LoRA para Download", interactive=False)

            download_info = gr.Markdown("ℹ️ Os arquivos LoRA estarão disponíveis após o treinamento")

        def update_output():
            """Return the current training log for the polling refresh."""
            return trainer.get_training_output()

        def refresh_output_files(name):
            """Replace the dropdown choices with the project's current output files."""
            files = trainer.list_output_files(name)
            # Returning a bare list would set the dropdown's *value*. The
            # choices must be replaced through an update.
            return gr.update(choices=files, value=files[0] if files else None)

        def resolve_output_file(name, choice):
            """Map a dropdown entry like ``'x.safetensors (144 MB)'`` to its file path.

            The file name is only compared against files actually found in
            the project's output directory, so it cannot name an arbitrary
            path.
            """
            project = (name or "").strip().replace(" ", "_")
            if (not choice or not project or project in (".", "..")
                    or "/" in project or "\\" in project):
                return None
            file_name = choice.rsplit(" (", 1)[0]
            output_dir = trainer.projects_dir / project / "output"
            if not output_dir.is_dir():
                return None
            for path in output_dir.rglob("*.safetensors"):
                if path.name == file_name:
                    return str(path)
            return None

        # Poll the training log while the page is open. Recent Gradio
        # releases provide gr.Timer; older ones use ``every=`` polling.
        if hasattr(gr, "Timer"):
            gr.Timer(2).tick(fn=update_output, outputs=training_output)
        else:
            interface.load(fn=update_output, outputs=training_output, every=2)

        download_btn.click(
            fn=trainer.download_model,
            inputs=[model_choice, custom_model_url],
            outputs=download_status
        )

        process_btn.click(
            fn=trainer.process_dataset,
            inputs=[dataset_upload, project_name],
            outputs=[dataset_status, dataset_dir_state]
        )

        create_config_btn.click(
            fn=trainer.create_training_config,
            inputs=[
                project_name, dataset_dir_state, model_choice, custom_model_url,
                resolution, batch_size, epochs, learning_rate, text_encoder_lr,
                network_dim, network_alpha, lora_type, optimizer, scheduler,
                flip_aug, shuffle_caption, keep_tokens, clip_skip, mixed_precision,
                save_every_n_epochs, max_train_steps
            ],
            outputs=config_status
        )

        start_training_btn.click(
            fn=trainer.start_training,
            inputs=project_name,
            outputs=training_status
        )

        stop_training_btn.click(
            fn=trainer.stop_training,
            outputs=training_status
        )

        refresh_files_btn.click(
            fn=refresh_output_files,
            inputs=project_name,
            outputs=output_files
        )

        output_files.change(
            fn=resolve_output_file,
            inputs=[project_name, output_files],
            outputs=output_file
        )

    # Periodic updates need the queue on older Gradio versions; on newer
    # ones the queue is already the default, so this is harmless.
    interface.queue()
    return interface
|
|
if __name__ == "__main__":
    print("🚀 Iniciando LoRA Trainer Funcional...")
    # Bind on all network interfaces, port 7860, with no public share link.
    # Errors are surfaced in the browser.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "show_error": True,
    }
    interface = create_interface()
    interface.launch(**launch_options)
|
|
|
|