| """ |
| Entrenamiento DPO para el agente desktop. |
| El agente guarda interacciones (screenshot + accion + reward). |
| Este script entrena el modelo con DPO usando las mejores acciones como 'chosen'. |
| """ |
import os
import json
from pathlib import Path
from datetime import datetime

import torch
from datasets import Dataset
from peft import LoraConfig
from transformers import (
    AutoModelForCausalLM,
    AutoProcessor,
    BitsAndBytesConfig,
)
from trl import DPOTrainer, DPOConfig
import trackio

MODEL_ID = os.getenv("TRAIN_MODEL", "huihui-ai/Huihui-Qwen3.5-35B-A3B-abliterated")
LOGS_DIR = Path("/app/agent_logs")
OUTPUT_DIR = f"/app/dpo_output/{datetime.now().strftime('%Y%m%d_%H%M%S')}"


def load_interaction_logs(logs_dir: Path, min_reward: float = 0.5):
    """
    Load the agent's logs and build a DPO dataset.

    Expected log format:
    [
        {
            "step": 1,
            "screenshot": "path/to/img.png",
            "action": "click(0.5, 0.3)",
            "reward": 1.0,  # 1 = success, 0 = failure, 0.5 = neutral
            "task": "Open Chrome..."
        }
    ]
    """
    logs = []
    for log_file in logs_dir.glob("*_log.json"):
        with open(log_file) as f:
            logs.extend(json.load(f))

    # Group interactions by task, so preference pairs compare actions aimed at the same goal.
    tasks = {}
    for entry in logs:
        task = entry.get("task", "unknown")
        if task not in tasks:
            tasks[task] = []
        tasks[task].append(entry)

    # Pair every action at or above the reward threshold ('chosen') with every
    # action below it ('rejected') within the same task.
    dpo_data = []
    for task, entries in tasks.items():
        successful = [e for e in entries if e.get("reward", 0) >= min_reward]
        failed = [e for e in entries if e.get("reward", 0) < min_reward]

        for good in successful:
            for bad in failed:
                dpo_data.append({
                    "prompt": f"Task: {task}\nScreenshot shows desktop. What action?",
                    "chosen": good["action"],
                    "rejected": bad["action"],
                })

    return Dataset.from_list(dpo_data)
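# Illustrative example (hypothetical log entries): two attempts at the same task,
# one successful and one failed, produce exactly one preference pair:
#
#   [{"task": "Open Chrome", "action": "click(0.05, 0.95)", "reward": 1.0},
#    {"task": "Open Chrome", "action": "type('chrome')", "reward": 0.0}]
#   -> {"prompt": "Task: Open Chrome\nScreenshot shows desktop. What action?",
#       "chosen": "click(0.05, 0.95)", "rejected": "type('chrome')"}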


def train_dpo(
    model_id: str = MODEL_ID,
    logs_dir: Path = LOGS_DIR,
    output_dir: str = OUTPUT_DIR,
    num_epochs: int = 3,
    batch_size: int = 1,
    gradient_accumulation: int = 4,
    learning_rate: float = 5e-7,
):
    """Train the model with DPO on the agent's interaction logs."""

    trackio.init(
        project="desktop-agent-dpo",
        name=f"dpo_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
    )

| print(f"π§ Modelo base: {model_id}") |
| print(f"π Logs: {logs_dir}") |
| print(f"π€ Output: {output_dir}") |
| |
| |
| dataset = load_interaction_logs(logs_dir) |
| print(f"π Dataset DPO: {len(dataset)} pares") |
| |
| if len(dataset) == 0: |
| print("β οΈ No hay suficientes datos. El agente necesita interactuar primero.") |
| return |
| |
| |
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
        torch_dtype="auto",
    )

    # The 4-bit base weights are frozen, so gradients cannot flow into them
    # directly; train LoRA adapters on top instead (QLoRA-style DPO). With a
    # peft_config, TRL uses the policy with adapters disabled as the implicit
    # reference model, so no second copy of the model has to be loaded.
    # (The LoRA hyperparameters below are reasonable defaults, not tuned values.)
    peft_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        target_modules="all-linear",
        task_type="CAUSAL_LM",
    )

    processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

    dpo_config = DPOConfig(
        output_dir=output_dir,
        num_train_epochs=num_epochs,
        per_device_train_batch_size=batch_size,
        gradient_accumulation_steps=gradient_accumulation,
        learning_rate=learning_rate,
        logging_steps=10,
        save_steps=100,
        warmup_ratio=0.1,
        bf16=True,
        report_to="trackio",
        remove_unused_columns=False,
    )
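    # Effective batch size = per_device_train_batch_size * gradient_accumulation_steps
    # (1 * 4 = 4 preference pairs per optimizer step with the defaults above).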
    trainer = DPOTrainer(
        model=model,
        ref_model=None,  # with peft_config, the adapter-disabled base model is the reference
        args=dpo_config,
        train_dataset=dataset,
        processing_class=processor.tokenizer,  # this argument was named `tokenizer` in TRL < 0.12
        peft_config=peft_config,
    )

    print("🚀 Starting DPO training...")
    trainer.train()

    trainer.save_model(output_dir)
    processor.save_pretrained(output_dir)

    print(f"✅ Model saved to: {output_dir}")

    # Trainer.push_to_hub takes the target repo from args.hub_model_id; its first
    # positional parameter is the commit message, not the repo id.
    hub_id = os.getenv("HF_HUB_MODEL_ID", "Matzan/desktop-agent-dpo")
    print(f"📤 Uploading to Hugging Face: {hub_id}")
    trainer.args.hub_model_id = hub_id
    trainer.push_to_hub()

    trackio.finish()


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default=MODEL_ID)
    parser.add_argument("--logs", default=str(LOGS_DIR))
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--lr", type=float, default=5e-7)

    args = parser.parse_args()

    train_dpo(
        model_id=args.model,
        logs_dir=Path(args.logs),
        num_epochs=args.epochs,
        learning_rate=args.lr,
    )