Matzan commited on
Commit
2f08430
·
verified ·
1 Parent(s): bfcb2a0

Upload train_dpo.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_dpo.py +186 -0
train_dpo.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Entrenamiento DPO para el agente desktop.
3
+ El agente guarda interacciones (screenshot + accion + reward).
4
+ Este script entrena el modelo con DPO usando las mejores acciones como 'chosen'.
5
+ """
6
+
7
+ import os
8
+ import json
9
+ from pathlib import Path
10
+ from datetime import datetime
11
+
12
+ import torch
13
+ from datasets import Dataset
14
+ from transformers import (
15
+ AutoModelForCausalLM,
16
+ AutoProcessor,
17
+ TrainingArguments,
18
+ BitsAndBytesConfig,
19
+ )
20
+ from trl import DPOTrainer, DPOConfig
21
+ import trackio
22
+
23
+ # Config
24
+ MODEL_ID = os.getenv("TRAIN_MODEL", "huihui-ai/Huihui-Qwen3.5-35B-A3B-abliterated")
25
+ LOGS_DIR = Path("/app/agent_logs")
26
+ OUTPUT_DIR = f"/app/dpo_output/{datetime.now().strftime('%Y%m%d_%H%M%S')}"
27
+
28
+
29
+ def load_interaction_logs(logs_dir: Path, min_reward: float = 0.5):
30
+ """
31
+ Carga logs del agente y construye dataset DPO.
32
+
33
+ Formato esperado de log:
34
+ [
35
+ {
36
+ "step": 1,
37
+ "screenshot": "path/to/img.png",
38
+ "action": "click(0.5, 0.3)",
39
+ "reward": 1.0, # 1 = exito, 0 = fracaso, 0.5 = neutral
40
+ "task": "Open Chrome..."
41
+ }
42
+ ]
43
+ """
44
+ logs = []
45
+ for log_file in logs_dir.glob("*_log.json"):
46
+ with open(log_file) as f:
47
+ logs.extend(json.load(f))
48
+
49
+ # Agrupar por tarea
50
+ tasks = {}
51
+ for entry in logs:
52
+ task = entry.get("task", "unknown")
53
+ if task not in tasks:
54
+ tasks[task] = []
55
+ tasks[task].append(entry)
56
+
57
+ # Construir pares chosen/rejected para DPO
58
+ dpo_data = []
59
+ for task, entries in tasks.items():
60
+ # Separar exitosos y fallidos
61
+ successful = [e for e in entries if e.get("reward", 0) >= min_reward]
62
+ failed = [e for e in entries if e.get("reward", 0) < min_reward]
63
+
64
+ for good in successful:
65
+ for bad in failed:
66
+ dpo_data.append({
67
+ "prompt": f"Task: {task}\nScreenshot shows desktop. What action?",
68
+ "chosen": good["action"],
69
+ "rejected": bad["action"],
70
+ })
71
+
72
+ return Dataset.from_list(dpo_data)
73
+
74
+
75
+ def train_dpo(
76
+ model_id: str = MODEL_ID,
77
+ logs_dir: Path = LOGS_DIR,
78
+ output_dir: str = OUTPUT_DIR,
79
+ num_epochs: int = 3,
80
+ batch_size: int = 1,
81
+ gradient_accumulation: int = 4,
82
+ learning_rate: float = 5e-7,
83
+ ):
84
+ """Entrena el modelo con DPO usando logs de interacciones."""
85
+
86
+ # Trackio
87
+ trackio.init(
88
+ project="desktop-agent-dpo",
89
+ run_name=f"dpo_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
90
+ )
91
+
92
+ print(f"🧠 Modelo base: {model_id}")
93
+ print(f"📁 Logs: {logs_dir}")
94
+ print(f"📤 Output: {output_dir}")
95
+
96
+ # Cargar dataset
97
+ dataset = load_interaction_logs(logs_dir)
98
+ print(f"📊 Dataset DPO: {len(dataset)} pares")
99
+
100
+ if len(dataset) == 0:
101
+ print("⚠️ No hay suficientes datos. El agente necesita interactuar primero.")
102
+ return
103
+
104
+ # Quantization para ahorrar VRAM
105
+ bnb_config = BitsAndBytesConfig(
106
+ load_in_4bit=True,
107
+ bnb_4bit_compute_dtype=torch.bfloat16,
108
+ bnb_4bit_use_double_quant=True,
109
+ )
110
+
111
+ model = AutoModelForCausalLM.from_pretrained(
112
+ model_id,
113
+ quantization_config=bnb_config,
114
+ device_map="auto",
115
+ trust_remote_code=True,
116
+ torch_dtype="auto",
117
+ )
118
+
119
+ ref_model = AutoModelForCausalLM.from_pretrained(
120
+ model_id,
121
+ quantization_config=bnb_config,
122
+ device_map="auto",
123
+ trust_remote_code=True,
124
+ torch_dtype="auto",
125
+ )
126
+
127
+ processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
128
+
129
+ # Config DPO
130
+ dpo_config = DPOConfig(
131
+ output_dir=output_dir,
132
+ num_train_epochs=num_epochs,
133
+ per_device_train_batch_size=batch_size,
134
+ gradient_accumulation_steps=gradient_accumulation,
135
+ learning_rate=learning_rate,
136
+ logging_steps=10,
137
+ save_steps=100,
138
+ warmup_ratio=0.1,
139
+ bf16=True,
140
+ report_to="trackio",
141
+ remove_unused_columns=False,
142
+ )
143
+
144
+ # Trainer
145
+ trainer = DPOTrainer(
146
+ model=model,
147
+ ref_model=ref_model,
148
+ args=dpo_config,
149
+ train_dataset=dataset,
150
+ tokenizer=processor.tokenizer,
151
+ )
152
+
153
+ print("🚀 Iniciando entrenamiento DPO...")
154
+ trainer.train()
155
+
156
+ # Guardar
157
+ trainer.save_model(output_dir)
158
+ processor.save_pretrained(output_dir)
159
+
160
+ print(f"✅ Modelo guardado en: {output_dir}")
161
+
162
+ # Subir a HF Hub
163
+ hub_id = os.getenv("HF_HUB_MODEL_ID", "Matzan/desktop-agent-dpo")
164
+ print(f"📤 Subiendo a Hugging Face: {hub_id}")
165
+ trainer.push_to_hub(hub_id)
166
+
167
+ trackio.finish()
168
+
169
+
170
+ if __name__ == "__main__":
171
+ import argparse
172
+
173
+ parser = argparse.ArgumentParser()
174
+ parser.add_argument("--model", default=MODEL_ID)
175
+ parser.add_argument("--logs", default=str(LOGS_DIR))
176
+ parser.add_argument("--epochs", type=int, default=3)
177
+ parser.add_argument("--lr", type=float, default=5e-7)
178
+
179
+ args = parser.parse_args()
180
+
181
+ train_dpo(
182
+ model_id=args.model,
183
+ logs_dir=Path(args.logs),
184
+ num_epochs=args.epochs,
185
+ learning_rate=args.lr,
186
+ )