""" Entrenamiento DPO para el agente desktop. El agente guarda interacciones (screenshot + accion + reward). Este script entrena el modelo con DPO usando las mejores acciones como 'chosen'. """ import os import json from pathlib import Path from datetime import datetime import torch from datasets import Dataset from transformers import ( AutoModelForCausalLM, AutoProcessor, TrainingArguments, BitsAndBytesConfig, ) from trl import DPOTrainer, DPOConfig import trackio # Config MODEL_ID = os.getenv("TRAIN_MODEL", "huihui-ai/Huihui-Qwen3.5-35B-A3B-abliterated") LOGS_DIR = Path("/app/agent_logs") OUTPUT_DIR = f"/app/dpo_output/{datetime.now().strftime('%Y%m%d_%H%M%S')}" def load_interaction_logs(logs_dir: Path, min_reward: float = 0.5): """ Carga logs del agente y construye dataset DPO. Formato esperado de log: [ { "step": 1, "screenshot": "path/to/img.png", "action": "click(0.5, 0.3)", "reward": 1.0, # 1 = exito, 0 = fracaso, 0.5 = neutral "task": "Open Chrome..." } ] """ logs = [] for log_file in logs_dir.glob("*_log.json"): with open(log_file) as f: logs.extend(json.load(f)) # Agrupar por tarea tasks = {} for entry in logs: task = entry.get("task", "unknown") if task not in tasks: tasks[task] = [] tasks[task].append(entry) # Construir pares chosen/rejected para DPO dpo_data = [] for task, entries in tasks.items(): # Separar exitosos y fallidos successful = [e for e in entries if e.get("reward", 0) >= min_reward] failed = [e for e in entries if e.get("reward", 0) < min_reward] for good in successful: for bad in failed: dpo_data.append({ "prompt": f"Task: {task}\nScreenshot shows desktop. What action?", "chosen": good["action"], "rejected": bad["action"], }) return Dataset.from_list(dpo_data) def train_dpo( model_id: str = MODEL_ID, logs_dir: Path = LOGS_DIR, output_dir: str = OUTPUT_DIR, num_epochs: int = 3, batch_size: int = 1, gradient_accumulation: int = 4, learning_rate: float = 5e-7, ): """Entrena el modelo con DPO usando logs de interacciones.""" # Trackio trackio.init( project="desktop-agent-dpo", run_name=f"dpo_{datetime.now().strftime('%Y%m%d_%H%M%S')}", ) print(f"🧠 Modelo base: {model_id}") print(f"📁 Logs: {logs_dir}") print(f"📤 Output: {output_dir}") # Cargar dataset dataset = load_interaction_logs(logs_dir) print(f"📊 Dataset DPO: {len(dataset)} pares") if len(dataset) == 0: print("⚠️ No hay suficientes datos. 
El agente necesita interactuar primero.") return # Quantization para ahorrar VRAM bnb_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True, ) model = AutoModelForCausalLM.from_pretrained( model_id, quantization_config=bnb_config, device_map="auto", trust_remote_code=True, torch_dtype="auto", ) ref_model = AutoModelForCausalLM.from_pretrained( model_id, quantization_config=bnb_config, device_map="auto", trust_remote_code=True, torch_dtype="auto", ) processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True) # Config DPO dpo_config = DPOConfig( output_dir=output_dir, num_train_epochs=num_epochs, per_device_train_batch_size=batch_size, gradient_accumulation_steps=gradient_accumulation, learning_rate=learning_rate, logging_steps=10, save_steps=100, warmup_ratio=0.1, bf16=True, report_to="trackio", remove_unused_columns=False, ) # Trainer trainer = DPOTrainer( model=model, ref_model=ref_model, args=dpo_config, train_dataset=dataset, tokenizer=processor.tokenizer, ) print("🚀 Iniciando entrenamiento DPO...") trainer.train() # Guardar trainer.save_model(output_dir) processor.save_pretrained(output_dir) print(f"✅ Modelo guardado en: {output_dir}") # Subir a HF Hub hub_id = os.getenv("HF_HUB_MODEL_ID", "Matzan/desktop-agent-dpo") print(f"📤 Subiendo a Hugging Face: {hub_id}") trainer.push_to_hub(hub_id) trackio.finish() if __name__ == "__main__": import argparse parser = argparse.ArgumentParser() parser.add_argument("--model", default=MODEL_ID) parser.add_argument("--logs", default=str(LOGS_DIR)) parser.add_argument("--epochs", type=int, default=3) parser.add_argument("--lr", type=float, default=5e-7) args = parser.parse_args() train_dpo( model_id=args.model, logs_dir=Path(args.logs), num_epochs=args.epochs, learning_rate=args.lr, )
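
# ---------------------------------------------------------------------------
# Quick sanity check for the pairing logic, as a minimal sketch. This is a
# hypothetical snippet (the temp path and log entries below are made up, not
# real agent output) showing how load_interaction_logs() turns raw logs into
# chosen/rejected pairs:
#
#   import json, tempfile
#   from pathlib import Path
#
#   tmp = Path(tempfile.mkdtemp())
#   (tmp / "demo_log.json").write_text(json.dumps([
#       {"step": 1, "task": "Open Chrome", "action": "click(0.1, 0.9)", "reward": 1.0},
#       {"step": 2, "task": "Open Chrome", "action": "click(0.8, 0.2)", "reward": 0.0},
#   ]))
#   ds = load_interaction_logs(tmp)
#   print(ds[0])
#   # {'prompt': 'Task: Open Chrome\nScreenshot shows desktop. What action?',
#   #  'chosen': 'click(0.1, 0.9)', 'rejected': 'click(0.8, 0.2)'}
# ---------------------------------------------------------------------------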