# desktop-agent-uncensored / train_dpo.py
"""
Entrenamiento DPO para el agente desktop.
El agente guarda interacciones (screenshot + accion + reward).
Este script entrena el modelo con DPO usando las mejores acciones como 'chosen'.
"""
import os
import json
from pathlib import Path
from datetime import datetime
import torch
from datasets import Dataset
from transformers import (
AutoModelForCausalLM,
AutoProcessor,
TrainingArguments,
BitsAndBytesConfig,
)
from trl import DPOTrainer, DPOConfig
from peft import LoraConfig
import trackio
# Config
MODEL_ID = os.getenv("TRAIN_MODEL", "huihui-ai/Huihui-Qwen3.5-35B-A3B-abliterated")
LOGS_DIR = Path("/app/agent_logs")
OUTPUT_DIR = f"/app/dpo_output/{datetime.now().strftime('%Y%m%d_%H%M%S')}"
def load_interaction_logs(logs_dir: Path, min_reward: float = 0.5):
"""
Carga logs del agente y construye dataset DPO.
Formato esperado de log:
[
{
"step": 1,
"screenshot": "path/to/img.png",
"action": "click(0.5, 0.3)",
"reward": 1.0, # 1 = exito, 0 = fracaso, 0.5 = neutral
"task": "Open Chrome..."
}
]
"""
logs = []
for log_file in logs_dir.glob("*_log.json"):
with open(log_file) as f:
logs.extend(json.load(f))
    # Group entries by task
    tasks = {}
    for entry in logs:
        task = entry.get("task", "unknown")
        tasks.setdefault(task, []).append(entry)
    # Build chosen/rejected DPO pairs: full cross product of successful x failed actions per task
dpo_data = []
for task, entries in tasks.items():
        # Split entries into successful and failed by reward threshold
successful = [e for e in entries if e.get("reward", 0) >= min_reward]
failed = [e for e in entries if e.get("reward", 0) < min_reward]
for good in successful:
for bad in failed:
dpo_data.append({
"prompt": f"Task: {task}\nScreenshot shows desktop. What action?",
"chosen": good["action"],
"rejected": bad["action"],
})
return Dataset.from_list(dpo_data)
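
# Illustrative example of one pair emitted by load_interaction_logs
# (hypothetical values, in the log format documented in the docstring above):
#   {"prompt": "Task: Open Chrome...\nScreenshot shows desktop. What action?",
#    "chosen": "click(0.50, 0.30)",
#    "rejected": "click(0.91, 0.12)"}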
def train_dpo(
model_id: str = MODEL_ID,
logs_dir: Path = LOGS_DIR,
output_dir: str = OUTPUT_DIR,
num_epochs: int = 3,
batch_size: int = 1,
gradient_accumulation: int = 4,
learning_rate: float = 5e-7,
):
"""Entrena el modelo con DPO usando logs de interacciones."""
# Trackio
trackio.init(
project="desktop-agent-dpo",
run_name=f"dpo_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
)
print(f"🧠 Modelo base: {model_id}")
print(f"πŸ“ Logs: {logs_dir}")
print(f"πŸ“€ Output: {output_dir}")
    # Load the preference dataset
    dataset = load_interaction_logs(logs_dir)
    print(f"📊 DPO dataset: {len(dataset)} pairs")
    if len(dataset) == 0:
        print("⚠️ Not enough data. The agent needs to interact first.")
        return
    # 4-bit quantization to save VRAM
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",  # QLoRA-recommended quant type (bitsandbytes defaults to fp4)
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
        torch_dtype="auto",
    )
    # A purely 4-bit-quantized model cannot be fine-tuned directly, so we train
    # LoRA adapters on top of it. With a peft_config, TRL derives the reference
    # policy by disabling the adapters, so no separate ref_model load is needed
    # (halving VRAM versus loading the base model twice).
    # The LoRA hyperparameters below are reasonable defaults, not tuned values.
    peft_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        target_modules="all-linear",
        task_type="CAUSAL_LM",
    )
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
    # DPO configuration
dpo_config = DPOConfig(
output_dir=output_dir,
num_train_epochs=num_epochs,
per_device_train_batch_size=batch_size,
gradient_accumulation_steps=gradient_accumulation,
learning_rate=learning_rate,
logging_steps=10,
save_steps=100,
warmup_ratio=0.1,
bf16=True,
report_to="trackio",
remove_unused_columns=False,
)
# Trainer
    trainer = DPOTrainer(
        model=model,
        ref_model=None,  # TRL builds the reference policy by disabling the LoRA adapters
        args=dpo_config,
        train_dataset=dataset,
        processing_class=processor.tokenizer,  # TRL >= 0.12 renamed tokenizer= to processing_class=
        peft_config=peft_config,
    )
print("πŸš€ Iniciando entrenamiento DPO...")
trainer.train()
    # Save locally
    trainer.save_model(output_dir)
    processor.save_pretrained(output_dir)
    print(f"✅ Model saved to: {output_dir}")
    # Push to the Hugging Face Hub. Trainer.push_to_hub() takes no repo-id
    # argument (it reads args.hub_model_id), so push the model and processor directly.
    hub_id = os.getenv("HF_HUB_MODEL_ID", "Matzan/desktop-agent-dpo")
    print(f"📤 Uploading to Hugging Face: {hub_id}")
    trainer.model.push_to_hub(hub_id)
    processor.push_to_hub(hub_id)
trackio.finish()
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--model", default=MODEL_ID)
parser.add_argument("--logs", default=str(LOGS_DIR))
parser.add_argument("--epochs", type=int, default=3)
parser.add_argument("--lr", type=float, default=5e-7)
args = parser.parse_args()
train_dpo(
model_id=args.model,
logs_dir=Path(args.logs),
num_epochs=args.epochs,
learning_rate=args.lr,
)
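
# Example invocation (a sketch; assumes agent logs already exist under /app/agent_logs):
#   python train_dpo.py --model huihui-ai/Huihui-Qwen3.5-35B-A3B-abliterated \
#       --logs /app/agent_logs --epochs 1 --lr 5e-7
#
# Loading the trained adapters later (a sketch, assuming a PEFT adapter checkpoint):
#   from peft import PeftModel
#   base = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map="auto")
#   model = PeftModel.from_pretrained(base, "<path/to/dpo_output/...>")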