vectorplasticity committed on
Commit
1741386
·
verified ·
1 Parent(s): fbcefba

Add training utilities

Browse files
Files changed (1) hide show
  1. app/utils/training_utils.py +491 -0
app/utils/training_utils.py ADDED
@@ -0,0 +1,491 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Training Utilities - Helper functions for model training
3
+ """
4
+
5
+ import logging
6
+ import os
7
+ import json
8
+ import hashlib
9
+ from typing import Dict, Any, List, Optional, Tuple
10
+ from datetime import datetime
11
+ import torch
12
+ from transformers import (
13
+ AutoTokenizer,
14
+ AutoModelForCausalLM,
15
+ AutoModelForSeq2SeqLM,
16
+ AutoModelForTokenClassification,
17
+ AutoModelForQuestionAnswering,
18
+ AutoModelForSequenceClassification,
19
+ AutoConfig,
20
+ TrainingArguments,
21
+ Trainer,
22
+ DataCollatorForLanguageModeling,
23
+ DataCollatorForSeq2Seq,
24
+ DataCollatorForTokenClassification,
25
+ )
26
+ from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
27
+ from datasets import Dataset
28
+ import numpy as np
29
+
30
+ logger = logging.getLogger(__name__)
31
+
32
+
33
def get_model_class_for_task(task_type: str):
    """Resolve a task-type string to its transformers ``Auto*`` model class.

    Args:
        task_type: One of "causal-lm", "seq2seq", "token-classification",
            "question-answering", or "text-classification".

    Returns:
        The matching ``AutoModelFor*`` class.

    Raises:
        ValueError: If ``task_type`` is not one of the supported keys.
    """
    supported = {
        "causal-lm": AutoModelForCausalLM,
        "seq2seq": AutoModelForSeq2SeqLM,
        "token-classification": AutoModelForTokenClassification,
        "question-answering": AutoModelForQuestionAnswering,
        "text-classification": AutoModelForSequenceClassification,
    }

    try:
        return supported[task_type]
    except KeyError:
        raise ValueError(f"Unknown task type: {task_type}") from None
47
+
48
+
49
def compute_model_hash(model_path: str) -> str:
    """Return a short MD5 digest of the model's ``config.json`` for tracking.

    The hash is used only as a configuration fingerprint, not for security.
    Falls back to the string "unknown" when no config file exists.
    """
    cfg_file = os.path.join(model_path, "config.json")
    if not os.path.exists(cfg_file):
        return "unknown"
    with open(cfg_file, "rb") as fh:
        digest = hashlib.md5(fh.read()).hexdigest()
    return digest[:12]
56
+
57
+
58
def estimate_memory_requirements(
    model_name: str,
    task_type: str,
    batch_size: int = 1,
    max_length: int = 512,
    use_peft: bool = False
) -> Dict[str, float]:
    """Roughly estimate training memory needs (in GB) for a model.

    All figures are coarse heuristics derived from the model config; they
    are intended for hardware recommendations, not exact budgeting. On any
    failure (e.g. the config cannot be fetched) a conservative small-model
    fallback is returned instead of raising.
    """
    try:
        cfg = AutoConfig.from_pretrained(model_name)

        # Hidden width: BERT-style attribute first, then GPT-style, then a default.
        if hasattr(cfg, "hidden_size"):
            width = cfg.hidden_size
        elif hasattr(cfg, "n_embd"):
            width = cfg.n_embd
        else:
            width = 768

        # Layer count, resolved the same way.
        if hasattr(cfg, "num_hidden_layers"):
            depth = cfg.num_hidden_layers
        elif hasattr(cfg, "n_layer"):
            depth = cfg.n_layer
        else:
            depth = 12

        # Very rough transformer parameter count: ~12 * width^2 per layer.
        params_b = (width ** 2 * depth * 12) / 1e9

        weights_gb = params_b * 2        # fp16: 2 bytes per parameter
        grads_gb = weights_gb            # gradients mirror the weights
        optim_gb = weights_gb * 2        # Adam keeps two extra states per param
        acts_gb = (batch_size * max_length * width * 4) / 1e9  # rough estimate

        if use_peft:
            # LoRA-style training only touches a small fraction of the weights.
            total_gb = (weights_gb * 0.1) + grads_gb + optim_gb * 0.1 + acts_gb
        else:
            total_gb = weights_gb + grads_gb + optim_gb + acts_gb

        return {
            "estimated_params_billion": round(params_b, 2),
            "model_memory_gb": round(weights_gb, 2),
            "optimizer_memory_gb": round(optim_gb, 2),
            "activation_memory_gb": round(acts_gb, 2),
            "total_memory_gb": round(total_gb, 2),
            "recommended_memory_gb": round(total_gb * 1.5, 2),
            "can_run_on_cpu": total_gb < 8,
            "recommended_hardware": "gpu" if total_gb > 4 else "cpu",
        }

    except Exception as exc:
        logger.warning(f"Could not estimate memory: {exc}")
        # Conservative defaults roughly matching a ~100M-parameter model.
        return {
            "estimated_params_billion": 0.1,
            "model_memory_gb": 0.5,
            "optimizer_memory_gb": 1.0,
            "activation_memory_gb": 0.5,
            "total_memory_gb": 2.0,
            "recommended_memory_gb": 4.0,
            "can_run_on_cpu": True,
            "recommended_hardware": "cpu",
        }
132
+
133
+
134
def get_available_hardware() -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
    """Return the hardware catalog plus detected runtime GPU info.

    Returns:
        A 2-tuple of:
        - a static list of selectable hardware tiers (id, name, memory,
          GPU flag, cost band), and
        - a dict describing the CUDA environment actually visible to this
          process (availability, device count, first device name/memory).

    Fix: the previous annotation claimed ``List[Dict[str, Any]]`` but the
    function returns a (catalog, info) tuple on every path; the annotation
    now matches the actual return value.
    """
    hardware = [
        {"id": "cpu-basic", "name": "CPU Basic", "memory_gb": 16, "gpu": False, "cost": "Free"},
        {"id": "cpu-upgrade", "name": "CPU Upgrade", "memory_gb": 32, "gpu": False, "cost": "Low"},
        {"id": "t4-small", "name": "T4 Small", "memory_gb": 16, "gpu": True, "gpu_memory_gb": 16, "cost": "Medium"},
        {"id": "t4-medium", "name": "T4 Medium", "memory_gb": 32, "gpu": True, "gpu_memory_gb": 16, "cost": "Medium"},
        {"id": "l4x1", "name": "L4 x1", "memory_gb": 32, "gpu": True, "gpu_memory_gb": 24, "cost": "High"},
        {"id": "l4x4", "name": "L4 x4", "memory_gb": 96, "gpu": True, "gpu_memory_gb": 96, "cost": "Very High"},
        {"id": "a10g-small", "name": "A10G Small", "memory_gb": 24, "gpu": True, "gpu_memory_gb": 24, "cost": "High"},
        {"id": "a10g-large", "name": "A10G Large", "memory_gb": 48, "gpu": True, "gpu_memory_gb": 48, "cost": "Very High"},
        {"id": "a100-large", "name": "A100 Large", "memory_gb": 80, "gpu": True, "gpu_memory_gb": 80, "cost": "Premium"},
    ]

    # Probe what is actually available in this process.
    if torch.cuda.is_available():
        gpu_count = torch.cuda.device_count()
        # Only the first device is reported; multi-GPU detail is out of scope here.
        gpu_name = torch.cuda.get_device_name(0) if gpu_count > 0 else "Unknown"
        gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9 if gpu_count > 0 else 0
        runtime_info = {
            "cuda_available": True,
            "gpu_count": gpu_count,
            "gpu_name": gpu_name,
            "gpu_memory_gb": round(gpu_memory, 1),
        }
    else:
        runtime_info = {
            "cuda_available": False,
            "gpu_count": 0,
            "gpu_name": None,
            "gpu_memory_gb": 0,
        }

    return hardware, runtime_info
167
+
168
+
169
def get_training_args(
    output_dir: str,
    config: Dict[str, Any],
    task_type: str
) -> TrainingArguments:
    """Create TrainingArguments from a plain config dict.

    Args:
        output_dir: Directory for checkpoints and logs.
        config: Flat dict of hyperparameters; missing keys fall back to
            sensible defaults.
        task_type: Task identifier used for task-specific argument tweaks.

    Returns:
        A ``TrainingArguments`` instance. For seq2seq tasks this is a
        ``Seq2SeqTrainingArguments`` (a subclass), since the generation
        options below are only accepted by that subclass.

    Fixes:
        - Previously ``config["max_length"]`` was forwarded into
          ``TrainingArguments(**args)`` for causal-lm, but ``max_length``
          is not a TrainingArguments parameter and raised a TypeError.
        - Previously ``predict_with_generate`` / ``generation_*`` were
          passed to the base TrainingArguments, which also rejects them;
          they now go to ``Seq2SeqTrainingArguments``.
    """

    # Base arguments shared by all task types.
    args = {
        "output_dir": output_dir,
        "overwrite_output_dir": True,

        # Training
        "num_train_epochs": config.get("epochs", 3),
        "per_device_train_batch_size": config.get("batch_size", 1),
        "per_device_eval_batch_size": config.get("batch_size", 1),
        "gradient_accumulation_steps": config.get("gradient_accumulation_steps", 1),

        # Learning rate
        "learning_rate": config.get("learning_rate", 5e-5),
        "weight_decay": config.get("weight_decay", 0.01),
        "warmup_steps": config.get("warmup_steps", 100),
        "lr_scheduler_type": config.get("lr_scheduler_type", "cosine"),

        # Logging / checkpointing
        "logging_dir": os.path.join(output_dir, "logs"),
        "logging_steps": config.get("logging_steps", 10),
        "save_steps": config.get("save_steps", 500),
        "save_total_limit": config.get("save_total_limit", 3),

        # Evaluation is only enabled when an eval interval is configured.
        # NOTE(review): newer transformers releases rename this key to
        # "eval_strategy" — confirm against the pinned transformers version.
        "evaluation_strategy": "steps" if config.get("eval_steps") else "no",
        "eval_steps": config.get("eval_steps", 500),

        # Mixed precision only when the hardware supports it.
        "fp16": config.get("fp16", True) and torch.cuda.is_available(),
        "bf16": config.get("bf16", False) and torch.cuda.is_bf16_supported(),

        # Data loading / memory
        "dataloader_num_workers": config.get("dataloader_num_workers", 0),
        "dataloader_pin_memory": config.get("pin_memory", True) and torch.cuda.is_available(),
        "gradient_checkpointing": config.get("gradient_checkpointing", False),

        # Reporting
        "report_to": config.get("report_to", ["none"]),

        # Reproducibility
        "seed": config.get("seed", 42),
    }

    args_cls = TrainingArguments

    # Task-specific adjustments
    if task_type == "causal-lm":
        args["max_steps"] = config.get("max_steps", -1)
        # "max_length" is intentionally NOT forwarded: it is not a
        # TrainingArguments parameter and would raise a TypeError. Sequence
        # length belongs to tokenization / the data collator instead.

    elif task_type == "seq2seq":
        # Generation options are only accepted by Seq2SeqTrainingArguments.
        from transformers import Seq2SeqTrainingArguments
        args_cls = Seq2SeqTrainingArguments
        args["predict_with_generate"] = config.get("predict_with_generate", False)
        args["generation_max_length"] = config.get("generation_max_length", 128)
        args["generation_num_beams"] = config.get("generation_num_beams", 4)

    elif task_type == "token-classification":
        args["label_names"] = config.get("label_names", [])

    # DeepSpeed config if enabled
    if config.get("deepspeed_config"):
        args["deepspeed"] = config["deepspeed_config"]

    return args_cls(**args)
238
+
239
+
240
+ def get_peft_config(config: Dict[str, Any]) -> Optional[LoraConfig]:
241
+ """Create PEFT/LoRA config if enabled."""
242
+ if not config.get("use_peft", False):
243
+ return None
244
+
245
+ peft_config = LoraConfig(
246
+ r=config.get("lora_r", 8),
247
+ lora_alpha=config.get("lora_alpha", 32),
248
+ lora_dropout=config.get("lora_dropout", 0.1),
249
+ bias=config.get("lora_bias", "none"),
250
+ task_type=config.get("peft_task_type", "CAUSAL_LM"),
251
+ target_modules=config.get("lora_target_modules", None),
252
+ )
253
+
254
+ return peft_config
255
+
256
+
257
def get_data_collator(
    tokenizer: Any,
    task_type: str,
    config: Dict[str, Any]
) -> Any:
    """Select the data collator suited to *task_type*.

    Unknown task types fall back to plain padding, with a warning logged.
    """

    pad_mult = config.get("pad_to_multiple_of", 8)
    pad_mode = config.get("padding", "max_length")

    if task_type == "causal-lm":
        # mlm=False => plain (next-token) language-modeling objective.
        return DataCollatorForLanguageModeling(
            tokenizer=tokenizer,
            mlm=False,
            pad_to_multiple_of=pad_mult,
        )

    if task_type == "seq2seq":
        return DataCollatorForSeq2Seq(
            tokenizer=tokenizer,
            model=None,
            padding=pad_mode,
            max_length=config.get("max_length", 512),
            pad_to_multiple_of=pad_mult,
        )

    if task_type == "token-classification":
        return DataCollatorForTokenClassification(
            tokenizer=tokenizer,
            padding=pad_mode,
            max_length=config.get("max_length", 512),
            pad_to_multiple_of=pad_mult,
        )

    if task_type == "question-answering":
        # NOTE(review): reuses the seq2seq collator with a shorter default
        # length (384); QA pipelines often need dedicated pre/post-processing
        # — confirm this matches the training loop.
        return DataCollatorForSeq2Seq(
            tokenizer=tokenizer,
            model=None,
            padding=pad_mode,
            max_length=config.get("max_length", 384),
        )

    if task_type == "text-classification":
        from transformers import DataCollatorWithPadding
        return DataCollatorWithPadding(
            tokenizer=tokenizer,
            padding=pad_mode,
            max_length=config.get("max_length", 512),
        )

    logger.warning(f"Unknown task type {task_type}, using default collator")
    from transformers import DataCollatorWithPadding
    return DataCollatorWithPadding(tokenizer=tokenizer)
308
+
309
+
310
def compute_metrics_factory(task_type: str, tokenizer: Any = None):
    """Build a ``compute_metrics`` callable for ``Trainer(compute_metrics=...)``.

    Args:
        task_type: Task identifier selecting which metrics to compute.
        tokenizer: Required by the seq2seq branch for decoding predictions.

    Returns:
        A function taking an ``(predictions, labels)`` pair (Trainer supplies
        these as numpy arrays) and returning a metrics dict. Unknown task
        types get a no-op that returns ``{}``.

    Fix: the causal-lm branch previously called tensor methods
    (``.contiguous()``, ``.view()``) directly on the numpy arrays Trainer
    passes in, which raised AttributeError; inputs are now converted with
    ``torch.as_tensor`` first.
    """

    if task_type == "causal-lm":
        def compute_metrics(eval_preds):
            """Compute perplexity for language modeling."""
            logits, labels = eval_preds
            # Trainer hands numpy arrays to compute_metrics; the shifted
            # cross-entropy below needs torch tensors.
            logits = torch.as_tensor(logits)
            labels = torch.as_tensor(labels)
            # Shift for causal LM: predict token t+1 from position t.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

            loss_fct = torch.nn.CrossEntropyLoss(reduction='mean')
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1)
            )
            perplexity = torch.exp(loss)

            return {
                "perplexity": perplexity.item(),
                "loss": loss.item()
            }
        return compute_metrics

    elif task_type == "seq2seq":
        def compute_metrics(eval_preds):
            """Compute ROUGE scores for summarization."""
            from evaluate import load
            rouge = load("rouge")

            predictions, labels = eval_preds
            decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
            # -100 marks ignored positions; replace so decoding works.
            labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
            decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

            result = rouge.compute(
                predictions=decoded_preds,
                references=decoded_labels,
                use_stemmer=True
            )

            # Report as percentages.
            return {k: round(v * 100, 4) for k, v in result.items()}
        return compute_metrics

    elif task_type == "token-classification":
        def compute_metrics(eval_preds):
            """Compute precision, recall, F1 for NER."""
            from evaluate import load
            seqeval = load("seqeval")

            predictions, labels = eval_preds
            predictions = np.argmax(predictions, axis=2)

            # Drop positions labeled -100 (ignored index) before scoring.
            true_predictions = [
                [p for (p, l) in zip(prediction, label) if l != -100]
                for prediction, label in zip(predictions, labels)
            ]
            true_labels = [
                [l for (p, l) in zip(prediction, label) if l != -100]
                for prediction, label in zip(predictions, labels)
            ]

            results = seqeval.compute(predictions=true_predictions, references=true_labels)

            return {
                "precision": results["overall_precision"],
                "recall": results["overall_recall"],
                "f1": results["overall_f1"],
                "accuracy": results["overall_accuracy"]
            }
        return compute_metrics

    elif task_type == "text-classification":
        def compute_metrics(eval_preds):
            """Compute accuracy and weighted F1 for classification."""
            from sklearn.metrics import accuracy_score, f1_score

            predictions, labels = eval_preds
            predictions = np.argmax(predictions, axis=1)

            return {
                "accuracy": accuracy_score(labels, predictions),
                "f1": f1_score(labels, predictions, average="weighted")
            }
        return compute_metrics

    elif task_type == "question-answering":
        def compute_metrics(eval_preds):
            """Compute SQuAD metrics.

            WARNING: placeholder only — real QA evaluation requires proper
            start/end-span post-processing before squad_v2 can be applied.
            """
            from evaluate import load
            squad_metric = load("squad_v2")

            predictions, labels = eval_preds
            # Post-processing of span predictions is not implemented yet.

            return {
                "exact_match": 0.0,
                "f1": 0.0
            }
        return compute_metrics

    else:
        def compute_metrics(eval_preds):
            # Unknown task: no metrics beyond the loss Trainer reports itself.
            return {}
        return compute_metrics
417
+
418
+
419
def save_training_artifacts(
    output_dir: str,
    model: Any,
    tokenizer: Any,
    config: Dict[str, Any],
    metrics: Dict[str, float]
) -> Dict[str, Any]:
    """Persist model, tokenizer, config, metrics, and a README to *output_dir*.

    Args:
        output_dir: Destination directory (created if missing).
        model: Object exposing ``save_pretrained(output_dir)``.
        tokenizer: Object exposing ``save_pretrained(output_dir)``.
        config: Training configuration to record alongside the model.
        metrics: Final training/eval metrics to record.

    Returns:
        A summary dict with the output directory, the list of artifacts
        saved, and the total on-disk size of top-level files in bytes.

    Fix: the return annotation previously claimed ``Dict[str, str]`` but
    the returned values include a list and an int; corrected to
    ``Dict[str, Any]``.
    """
    os.makedirs(output_dir, exist_ok=True)

    saved_files = []

    # Save model
    model.save_pretrained(output_dir)
    saved_files.append("model")

    # Save tokenizer
    tokenizer.save_pretrained(output_dir)
    saved_files.append("tokenizer")

    # Save config
    with open(os.path.join(output_dir, "training_config.json"), "w") as f:
        json.dump(config, f, indent=2)
    saved_files.append("training_config.json")

    # Save metrics
    with open(os.path.join(output_dir, "metrics.json"), "w") as f:
        json.dump(metrics, f, indent=2)
    saved_files.append("metrics.json")

    # Create README
    # NOTE(review): datetime.utcnow() is deprecated in Python 3.12+;
    # consider datetime.now(timezone.utc) when the file's import can change.
    readme_content = f"""# Model Fine-tuned with Universal Model Trainer

## Model Details
- Base Model: {config.get('model_name', 'Unknown')}
- Task: {config.get('task_type', 'Unknown')}
- Training Date: {datetime.utcnow().isoformat()}

## Training Configuration
- Epochs: {config.get('epochs', 'Unknown')}
- Batch Size: {config.get('batch_size', 'Unknown')}
- Learning Rate: {config.get('learning_rate', 'Unknown')}
- PEFT/LoRA: {'Yes' if config.get('use_peft') else 'No'}

## Metrics
```
{json.dumps(metrics, indent=2)}
```

## Usage
```python
from transformers import AutoModel, AutoTokenizer

model = AutoModel.from_pretrained("path/to/model")
tokenizer = AutoTokenizer.from_pretrained("path/to/model")
```
"""

    with open(os.path.join(output_dir, "README.md"), "w") as f:
        f.write(readme_content)
    saved_files.append("README.md")

    # total_size only counts top-level regular files, not subdirectories.
    return {
        "output_dir": output_dir,
        "saved_files": saved_files,
        "total_size": sum(os.path.getsize(os.path.join(output_dir, f)) for f in os.listdir(output_dir) if os.path.isfile(os.path.join(output_dir, f)))
    }
486
+
487
+
488
def generate_job_id(config: Dict[str, Any]) -> str:
    """Return a unique job ID of the form ``train_<task>_<8-hex-chars>``."""
    import uuid
    suffix = uuid.uuid4().hex[:8]
    return "train_{}_{}".format(config['task_type'], suffix)