| """ |
| Multimodal PC Fault Detection - Training Script v2 |
| ==================================================== |
| Changes from v1: |
| - OGM-GE gradient modulation after loss.backward(), before optimizer.step() |
| - Asymmetric learning rates: higher for visual branch, lower for audio |
| - Auxiliary loss logging (loss_fusion, loss_visual, loss_audio per epoch) |
| - OGM-GE stats logging (visual_conf, audio_conf, modulation coefficients) |
- Supports both the old proxy data (dataset_real) and the newly built dataset (dataset_v2)
| |
| Usage: |
| # With old proxy data (ToyADMOS + MVTec, default) |
| python train_v2.py --mode multimodal --finetune lora --eval_robustness |
| |
| # With new built dataset (from build_dataset.py) |
| python train_v2.py --dataset local --dataset_dir ../data/dataset_build --eval_robustness |
| python train_v2.py --dataset hub --hub_dataset Ellaft/pc-fault-real-dataset |
| |
| # Other options |
| python train_v2.py --mode visual_only --finetune lora --no_push |
| python train_v2.py --quick_test --no_push |
| |
| References: |
| OGM-GE: Peng et al., "Balanced Multimodal Learning via On-the-fly Gradient |
| Modulation", CVPR 2022 |
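
OGM-GE in brief (a sketch following the paper above; the concrete implementation
lives in models_v2.OGMGEModulator and may differ in detail). Per batch, with
s_i^m the softmax score of the ground-truth class under modality m:
    rho_v = sum_i s_i^v / sum_i s_i^a                      (discrepancy ratio)
    k_v   = 1 - tanh(alpha * rho_v) if rho_v > 1 else 1    (throttle dominant branch)
    g_v  <- k_v * g_v + N(0, Sigma(g_v))                   (GE: Gaussian noise keeps
                                                            gradient dynamics alive)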
| """ |
|
|
import os, json, argparse, time
| import numpy as np |
| import torch |
| import torch.nn as nn |
| from torch.utils.data import DataLoader |
| from torch.optim import AdamW |
| from torch.optim.lr_scheduler import OneCycleLR |
from sklearn.metrics import (accuracy_score, f1_score, confusion_matrix,
                             precision_recall_fscore_support)
|
|
| from config import ExperimentConfig, FAULT_CLASSES, NUM_CLASSES |
| from models_v2 import create_model, get_processors, OGMGEModulator |
|
|
|
|
| def compute_metrics(preds, labels, class_names=FAULT_CLASSES): |
| """Compute accuracy, F1, precision, recall, and confusion matrix.""" |
| accuracy = accuracy_score(labels, preds) |
| precision, recall, f1, support = precision_recall_fscore_support( |
| labels, preds, average=None, labels=range(len(class_names)), zero_division=0) |
| macro_f1 = f1_score(labels, preds, average="macro", zero_division=0) |
| weighted_f1 = f1_score(labels, preds, average="weighted", zero_division=0) |
| conf_matrix = confusion_matrix(labels, preds, labels=range(len(class_names))) |
|
|
    # Cast numpy scalars to Python floats so the result survives json.dump later.
    metrics = {
        "accuracy": float(accuracy),
        "macro_f1": float(macro_f1),
        "weighted_f1": float(weighted_f1),
        "confusion_matrix": conf_matrix.tolist(),
        "per_class": {},
    }
    for i, name in enumerate(class_names):
        metrics["per_class"][name] = {
            "precision": float(precision[i]), "recall": float(recall[i]),
            "f1": float(f1[i]), "support": int(support[i]),
        }
    return metrics
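
# Minimal sanity check (illustrative class names; 2 of 3 predictions correct):
#   compute_metrics(np.array([0, 1, 1]), np.array([0, 1, 2]),
#                   class_names=["cpu_ok", "fan_fault", "disk_fault"])["accuracy"]
#   -> 0.667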
|
|
|
|
| class MultimodalTrainerV2: |
| """ |
| Training loop v2 with OGM-GE gradient modulation. |
| |
| Key differences from v1: |
| 1. Three separate parameter groups with asymmetric LRs: |
| - visual_branch: higher LR (visual_lr_multiplier × base_lr) |
| - audio_branch: lower LR (audio_lr_multiplier × base_lr) |
| - fusion + auxiliary heads: base LR |
| 2. OGM-GE applied after backward(), before optimizer.step() |
| 3. Logs auxiliary losses and OGM-GE stats per epoch |
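
    Typical construction (mirrors main() below):
        trainer = MultimodalTrainerV2(model, train_ds, val_ds, config.train, device,
                                      use_ogm=True, ogm_alpha=0.3, ogm_noise_sigma=0.1,
                                      visual_lr_multiplier=3.0, audio_lr_multiplier=0.5,
                                      collate_fn=multimodal_collate_fn)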
| """ |
|
|
| def __init__(self, model, train_dataset, val_dataset, config, device, |
| use_ogm=True, ogm_alpha=0.3, ogm_noise_sigma=0.1, |
| visual_lr_multiplier=3.0, audio_lr_multiplier=0.5, |
| collate_fn=None): |
| self.model = model.to(device) |
| self.device = device |
| self.config = config |
| self.use_ogm = use_ogm and (model.mode == "multimodal") |
|
|
| |
| if self.use_ogm: |
| self.ogm = OGMGEModulator(alpha=ogm_alpha, noise_sigma=ogm_noise_sigma) |
| print(f"[Trainer v2] OGM-GE enabled: alpha={ogm_alpha}, noise_sigma={ogm_noise_sigma}") |
| else: |
| self.ogm = None |
|
|
| |
| self.train_loader = DataLoader( |
| train_dataset, |
| batch_size=config.per_device_train_batch_size, |
| shuffle=True, |
| collate_fn=collate_fn, |
| num_workers=2, |
| pin_memory=True, |
| drop_last=True) |
| self.val_loader = DataLoader( |
| val_dataset, |
| batch_size=config.per_device_eval_batch_size, |
| shuffle=False, |
| collate_fn=collate_fn, |
| num_workers=2, |
| pin_memory=True) |
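
        # Batches from these loaders are dicts (schema inferred from usage in
        # train_epoch/evaluate; the exact audio shape is set by the AST extractor):
        #   {"pixel_values": FloatTensor[B, 3, H, W],
        #    "audio_values": FloatTensor[B, ...],
        #    "labels":       LongTensor[B]}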
|
|
| |
| param_groups = self._get_param_groups(visual_lr_multiplier, audio_lr_multiplier) |
| self.optimizer = AdamW(param_groups, weight_decay=config.weight_decay) |
|
|
        # One scheduler step per optimizer step; round up so a trailing partial
        # accumulation window (flushed at epoch end) is still counted.
        steps_per_epoch = -(-len(self.train_loader) // config.gradient_accumulation_steps)
        total_steps = steps_per_epoch * config.num_epochs
| self.scheduler = OneCycleLR( |
| self.optimizer, |
| max_lr=[pg["lr"] for pg in param_groups], |
| total_steps=max(total_steps, 1), |
| pct_start=config.warmup_ratio, |
| anneal_strategy="cos") |
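
        # OneCycleLR accepts one max_lr per parameter group; the list above follows
        # the param_groups order (visual, audio, fusion, whichever are present).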
|
|
| |
| self.scaler = (torch.amp.GradScaler("cuda") |
| if config.fp16 and device.type == "cuda" else None) |
|
|
| |
| self.best_metric = 0.0 |
| self.best_epoch = 0 |
| self.history = { |
| "train_loss": [], "val_loss": [], |
| "val_accuracy": [], "val_macro_f1": [], |
| "train_loss_fusion": [], "train_loss_visual": [], "train_loss_audio": [], |
| "ogm_visual_conf": [], "ogm_audio_conf": [], |
| "ogm_coeff_visual": [], "ogm_coeff_audio": [], |
| } |
|
|
| def _get_param_groups(self, visual_lr_multiplier, audio_lr_multiplier): |
| visual_params, audio_params, fusion_params = [], [], [] |
| for name, param in self.model.named_parameters(): |
| if not param.requires_grad: |
| continue |
| if "visual_branch" in name: |
| visual_params.append(param) |
| elif "audio_branch" in name: |
| audio_params.append(param) |
| else: |
| fusion_params.append(param) |
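
        # Routing is by substring match, so this assumes models_v2 names its
        # submodules "visual_branch" / "audio_branch"; everything else (fusion
        # layers, classifier heads, adapters outside the branches) falls through
        # to the base-LR group.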
|
|
| base_lr = self.config.lora_learning_rate |
| groups = [] |
| if visual_params: |
| vlr = base_lr * visual_lr_multiplier |
| groups.append({"params": visual_params, "lr": vlr, "name": "visual_branch"}) |
| print(f"[Trainer v2] visual_branch: {len(visual_params)} tensors, lr={vlr:.2e}") |
| if audio_params: |
| alr = base_lr * audio_lr_multiplier |
| groups.append({"params": audio_params, "lr": alr, "name": "audio_branch"}) |
| print(f"[Trainer v2] audio_branch: {len(audio_params)} tensors, lr={alr:.2e}") |
| if fusion_params: |
| groups.append({"params": fusion_params, "lr": base_lr, "name": "fusion_heads"}) |
| print(f"[Trainer v2] fusion_heads: {len(fusion_params)} tensors, lr={base_lr:.2e}") |
| if not groups: |
| raise ValueError("No trainable parameters!") |
| return groups |
|
|
| def train_epoch(self, epoch): |
| self.model.train() |
| total_loss, total_loss_fusion, total_loss_visual, total_loss_audio = 0.0, 0.0, 0.0, 0.0 |
| num_batches = 0 |
| ogm_v_confs, ogm_a_confs, ogm_cv, ogm_ca = [], [], [], [] |
| self.optimizer.zero_grad() |
|
|
| for batch_idx, batch in enumerate(self.train_loader): |
| pv = batch["pixel_values"].to(self.device) |
| av = batch["audio_values"].to(self.device) |
| labels = batch["labels"].to(self.device) |
|
|
| if self.scaler: |
| with torch.amp.autocast("cuda"): |
| outputs = self.model(pixel_values=pv, audio_values=av, labels=labels) |
| loss = outputs["loss"] / self.config.gradient_accumulation_steps |
| self.scaler.scale(loss).backward() |
| else: |
| outputs = self.model(pixel_values=pv, audio_values=av, labels=labels) |
| loss = outputs["loss"] / self.config.gradient_accumulation_steps |
| loss.backward() |
|
|
| total_loss += loss.item() * self.config.gradient_accumulation_steps |
| num_batches += 1 |
            if "loss_fusion" in outputs:
                # float() detaches 0-dim tensors and also accepts plain floats.
                total_loss_fusion += float(outputs["loss_fusion"])
                total_loss_visual += float(outputs["loss_visual"])
                total_loss_audio += float(outputs["loss_audio"])
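
            # Unimodal confidences feed the OGM-GE coefficients; they are collected
            # every batch for epoch-level logging and applied at the next optimizer step.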
|
|
| if (self.use_ogm and self.ogm is not None |
| and "visual_logits" in outputs and "audio_logits" in outputs): |
| _cv, _ca, _stats = self.ogm.compute_modulation_coefficients( |
| outputs["visual_logits"], outputs["audio_logits"], labels) |
| ogm_v_confs.append(_stats["visual_conf"]) |
| ogm_a_confs.append(_stats["audio_conf"]) |
| ogm_cv.append(_stats["coeff_visual"]) |
| ogm_ca.append(_stats["coeff_audio"]) |
|
|
            # Flush at window boundaries and at epoch end so no gradients are dropped.
            is_last = (batch_idx + 1) == len(self.train_loader)
            if (batch_idx + 1) % self.config.gradient_accumulation_steps == 0 or is_last:
| if self.scaler: |
| self.scaler.unscale_(self.optimizer) |
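                # Order matters: unscale first so OGM-GE rescales true gradient
                # magnitudes in place, then clip. With accumulation, the latest
                # batch's coefficients stand in for the whole window.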
| if (self.use_ogm and self.ogm is not None and ogm_cv): |
| self.ogm.apply_gradient_modulation(self.model, ogm_cv[-1], ogm_ca[-1]) |
| torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm) |
| if self.scaler: |
| self.scaler.step(self.optimizer) |
| self.scaler.update() |
| else: |
| self.optimizer.step() |
| self.scheduler.step() |
| self.optimizer.zero_grad() |
|
|
| if (batch_idx + 1) % self.config.logging_steps == 0 or batch_idx == 0: |
| avg_loss = total_loss / num_batches |
                pg0 = self.optimizer.param_groups[0]
                msg = (f"  [Epoch {epoch+1}] Step {batch_idx+1}/{len(self.train_loader)} "
                       f"| Loss: {avg_loss:.4f} | LR[{pg0.get('name', 'pg0')}]: {pg0['lr']:.2e}")
| if "loss_fusion" in outputs: |
| msg += (f" | L_fus: {total_loss_fusion/num_batches:.4f}" |
| f" L_vis: {total_loss_visual/num_batches:.4f}" |
| f" L_aud: {total_loss_audio/num_batches:.4f}") |
| if ogm_cv: |
| msg += f" | OGM c_v: {ogm_cv[-1]:.3f} c_a: {ogm_ca[-1]:.3f}" |
| print(msg) |
|
|
| n = max(num_batches, 1) |
| epoch_stats = {"train_loss": total_loss / n, "loss_fusion": total_loss_fusion / n, |
| "loss_visual": total_loss_visual / n, "loss_audio": total_loss_audio / n} |
        if ogm_v_confs:
            # float() keeps the history JSON-serializable (np.mean returns np.float64).
            epoch_stats.update({"ogm_visual_conf": float(np.mean(ogm_v_confs)),
                                "ogm_audio_conf": float(np.mean(ogm_a_confs)),
                                "ogm_coeff_visual": float(np.mean(ogm_cv)),
                                "ogm_coeff_audio": float(np.mean(ogm_ca))})
| return epoch_stats |
|
|
| @torch.no_grad() |
| def evaluate(self, modality_mask=None): |
| self.model.eval() |
| all_preds, all_labels, total_loss, num_batches = [], [], 0.0, 0 |
| for batch in self.val_loader: |
| pv = batch["pixel_values"].to(self.device) |
| av = batch["audio_values"].to(self.device) |
| labels = batch["labels"].to(self.device) |
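            # Zeroed inputs stand in for a missing modality at eval time, matching
            # the modality-dropout scheme the model can be trained with.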
| if modality_mask: |
| if modality_mask.get("visual", 1.0) == 0.0: pv = torch.zeros_like(pv) |
| if modality_mask.get("audio", 1.0) == 0.0: av = torch.zeros_like(av) |
| outputs = self.model(pixel_values=pv, audio_values=av, labels=labels) |
| total_loss += outputs["loss"].item() |
| num_batches += 1 |
| all_preds.extend(outputs["logits"].argmax(dim=-1).cpu().numpy()) |
| all_labels.extend(labels.cpu().numpy()) |
| metrics = compute_metrics(np.array(all_preds), np.array(all_labels)) |
| metrics["val_loss"] = total_loss / max(num_batches, 1) |
| return metrics |
|
|
| def train(self): |
| print(f"\n{'='*60}") |
| print(f"Training v2: mode={self.model.mode}, epochs={self.config.num_epochs}, " |
| f"batch={self.config.per_device_train_batch_size}, device={self.device}") |
| print(f"OGM-GE: {'ENABLED' if self.use_ogm else 'DISABLED'}") |
| if self.model.mode == "multimodal": |
| print(f"Auxiliary loss weights: λ_visual={self.model.lambda_visual}, λ_audio={self.model.lambda_audio}") |
| print(f"{'='*60}\n") |
|
|
| for epoch in range(self.config.num_epochs): |
| t0 = time.time() |
| train_stats = self.train_epoch(epoch) |
| val_metrics = self.evaluate() |
| elapsed = time.time() - t0 |
|
|
| print(f"\n[Epoch {epoch+1}/{self.config.num_epochs}] ({elapsed:.1f}s)") |
| loss_msg = f" Train Loss: {train_stats['train_loss']:.4f}" |
| if train_stats.get("loss_fusion", 0) > 0: |
| loss_msg += (f" (fusion={train_stats['loss_fusion']:.4f} " |
| f"visual={train_stats['loss_visual']:.4f} audio={train_stats['loss_audio']:.4f})") |
| print(loss_msg) |
| print(f" Val Loss: {val_metrics['val_loss']:.4f} | Acc: {val_metrics['accuracy']:.4f} | F1: {val_metrics['macro_f1']:.4f}") |
| if "ogm_visual_conf" in train_stats: |
| print(f" OGM-GE: visual_conf={train_stats['ogm_visual_conf']:.4f} audio_conf={train_stats['ogm_audio_conf']:.4f} " |
| f"| coeff_v={train_stats['ogm_coeff_visual']:.4f} coeff_a={train_stats['ogm_coeff_audio']:.4f}") |
|
|
| self.history["train_loss"].append(train_stats["train_loss"]) |
| self.history["val_loss"].append(val_metrics["val_loss"]) |
| self.history["val_accuracy"].append(val_metrics["accuracy"]) |
| self.history["val_macro_f1"].append(val_metrics["macro_f1"]) |
| self.history["train_loss_fusion"].append(train_stats["loss_fusion"]) |
| self.history["train_loss_visual"].append(train_stats["loss_visual"]) |
| self.history["train_loss_audio"].append(train_stats["loss_audio"]) |
| if "ogm_visual_conf" in train_stats: |
| self.history["ogm_visual_conf"].append(train_stats["ogm_visual_conf"]) |
| self.history["ogm_audio_conf"].append(train_stats["ogm_audio_conf"]) |
| self.history["ogm_coeff_visual"].append(train_stats["ogm_coeff_visual"]) |
| self.history["ogm_coeff_audio"].append(train_stats["ogm_coeff_audio"]) |
|
|
| if val_metrics[self.config.metric_for_best_model] > self.best_metric: |
| self.best_metric = val_metrics[self.config.metric_for_best_model] |
| self.best_epoch = epoch + 1 |
| os.makedirs(self.config.output_dir, exist_ok=True) |
| torch.save({"model_state_dict": self.model.state_dict(), "epoch": epoch + 1, |
| "metrics": val_metrics}, os.path.join(self.config.output_dir, "best_model.pt")) |
                print(f"  ✓ Best model saved ({self.config.metric_for_best_model}={self.best_metric:.4f})")
|
|
        print(f"\nTraining complete. Best epoch={self.best_epoch}, "
              f"best {self.config.metric_for_best_model}={self.best_metric:.4f}")
| return self.history |
|
|
| def run_robustness_evaluation(self): |
| print("\n=== Missing Modality Robustness Evaluation ===") |
| results = {} |
| for name, mask in [("both_modalities", None), ("visual_only", {"visual": 1.0, "audio": 0.0}), |
| ("audio_only", {"visual": 0.0, "audio": 1.0})]: |
| m = self.evaluate(modality_mask=mask) |
| results[name] = {"accuracy": m["accuracy"], "macro_f1": m["macro_f1"]} |
| print(f" {name:20s}: Acc={m['accuracy']:.4f} F1={m['macro_f1']:.4f}") |
| for cls, cls_m in m["per_class"].items(): |
| print(f" {cls:25s} P:{cls_m['precision']:.3f} R:{cls_m['recall']:.3f} F1:{cls_m['f1']:.3f}") |
| print("\n [Target] Visual-only should improve from ~0.23 acc / 0.08 F1 (v1)") |
| return results |
|
|
|
|
| def main(): |
| parser = argparse.ArgumentParser(description="Multimodal PC Fault Detection Training v2") |
| parser.add_argument("--mode", default="multimodal", choices=["multimodal", "visual_only", "audio_only"]) |
| parser.add_argument("--finetune", default="lora", choices=["lora", "full", "linear_probe"]) |
| parser.add_argument("--epochs", type=int) |
| parser.add_argument("--batch_size", type=int) |
| parser.add_argument("--lr", type=float) |
| parser.add_argument("--fusion", default="concat") |
| parser.add_argument("--modality_dropout", type=float) |
| parser.add_argument("--output_dir", type=str) |
| parser.add_argument("--hub_model_id", type=str) |
| parser.add_argument("--no_push", action="store_true") |
| parser.add_argument("--eval_robustness", action="store_true") |
| parser.add_argument("--quick_test", action="store_true") |
|
|
| |
| parser.add_argument("--dataset", default="proxy", |
| choices=["proxy", "local", "hub"], |
| help="Dataset source: 'proxy' (ToyADMOS+MVTec, default), " |
| "'local' (build_dataset.py output), 'hub' (HF Hub dataset)") |
| parser.add_argument("--dataset_dir", default="./dataset_build", |
| help="Path to build_dataset.py output (for --dataset local)") |
| parser.add_argument("--hub_dataset", default="Ellaft/pc-fault-real-dataset", |
| help="HuggingFace dataset ID (for --dataset hub)") |
|
|
| |
| parser.add_argument("--no_ogm", action="store_true") |
| parser.add_argument("--ogm_alpha", type=float, default=None) |
| parser.add_argument("--ogm_noise_sigma", type=float, default=None) |
| parser.add_argument("--lambda_visual", type=float, default=None) |
| parser.add_argument("--lambda_audio", type=float, default=None) |
| parser.add_argument("--visual_lr_mult", type=float, default=None) |
| parser.add_argument("--audio_lr_mult", type=float, default=None) |
|
|
| args = parser.parse_args() |
|
|
| |
| config = ExperimentConfig() |
| config.experiment_name = "multimodal_pc_fault_v2" |
| config.train.mode = args.mode |
| config.train.finetune_method = args.finetune |
| config.model.fusion_type = args.fusion |
|
|
| if args.epochs: config.train.num_epochs = args.epochs |
| if args.batch_size: config.train.per_device_train_batch_size = args.batch_size |
| if args.lr: config.train.learning_rate = config.train.lora_learning_rate = args.lr |
| if args.modality_dropout is not None: config.model.modality_dropout_p = args.modality_dropout |
| if args.output_dir: config.train.output_dir = args.output_dir |
| if args.hub_model_id: config.train.hub_model_id = args.hub_model_id |
| if args.no_push: config.train.push_to_hub = False |
| if args.quick_test: |
| config.train.num_epochs, config.train.per_device_train_batch_size = 2, 4 |
| config.train.per_device_eval_batch_size, config.train.gradient_accumulation_steps = 4, 1 |
| config.train.logging_steps = 2 |
| if args.finetune != "lora": config.lora.enabled = False |
|
|
| ogm_alpha = args.ogm_alpha if args.ogm_alpha is not None else config.ogm_alpha |
| ogm_noise_sigma = args.ogm_noise_sigma if args.ogm_noise_sigma is not None else config.ogm_noise_sigma |
| lambda_visual = args.lambda_visual if args.lambda_visual is not None else config.lambda_visual |
| lambda_audio = args.lambda_audio if args.lambda_audio is not None else config.lambda_audio |
| visual_lr_mult = args.visual_lr_mult if args.visual_lr_mult is not None else config.visual_lr_multiplier |
| audio_lr_mult = args.audio_lr_mult if args.audio_lr_mult is not None else config.audio_lr_multiplier |
| use_ogm = not args.no_ogm |
|
|
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
| torch.manual_seed(config.train.seed) |
| np.random.seed(config.train.seed) |
| if torch.cuda.is_available(): torch.cuda.manual_seed_all(config.train.seed) |
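
    # Note: seeds alone are best-effort; bitwise GPU determinism would also need
    # torch.backends.cudnn.deterministic = True.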
|
|
| print(f"\n{'='*60}") |
| print(f"Multimodal PC Fault Detection v2") |
| print(f"{'='*60}") |
| print(f"Mode: {args.mode} | Finetune: {args.finetune} | Device: {device}") |
    if args.dataset == "local":
        src_info = f" ({args.dataset_dir})"
    elif args.dataset == "hub":
        src_info = f" ({args.hub_dataset})"
    else:
        src_info = " (ToyADMOS + MVTec proxy)"
    print(f"Dataset: {args.dataset}{src_info}")
| print(f"OGM-GE: {'ON' if use_ogm else 'OFF'} (alpha={ogm_alpha}, sigma={ogm_noise_sigma})") |
| print(f"Aux loss weights: λ_visual={lambda_visual}, λ_audio={lambda_audio}") |
| print(f"LR multipliers: visual={visual_lr_mult}x, audio={audio_lr_mult}x") |
| print(f"{'='*60}\n") |
|
|
| |
| vit_proc, ast_ext = get_processors(config.model) |
|
|
| |
| if args.dataset in ("local", "hub"): |
| from dataset_v2 import BuiltDataset as PCFaultDataset, multimodal_collate_fn |
| source = args.dataset |
| train_ds = PCFaultDataset( |
| config.data, config.model, "train", vit_proc, ast_ext, True, |
| source=source, dataset_dir=args.dataset_dir, hub_dataset=args.hub_dataset) |
| val_ds = PCFaultDataset( |
| config.data, config.model, "val", vit_proc, ast_ext, False, |
| source=source, dataset_dir=args.dataset_dir, hub_dataset=args.hub_dataset) |
| else: |
| |
| from dataset_real import RealPCFaultDataset as PCFaultDataset, multimodal_collate_fn |
| train_ds = PCFaultDataset(config.data, config.model, "train", vit_proc, ast_ext, True) |
| val_ds = PCFaultDataset(config.data, config.model, "val", vit_proc, ast_ext, False) |
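
    # Both dataset classes are expected to expose the same item schema and the
    # same collate contract (multimodal_collate_fn), keeping the trainer
    # source-agnostic across proxy, local, and hub sources.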
|
|
| |
| model = create_model(config.model, config.lora, mode=args.mode, finetune_method=args.finetune, |
| use_ogm=use_ogm, lambda_visual=lambda_visual, lambda_audio=lambda_audio) |
|
|
| |
| trainer = MultimodalTrainerV2( |
| model, train_ds, val_ds, config.train, device, |
| use_ogm=use_ogm, ogm_alpha=ogm_alpha, ogm_noise_sigma=ogm_noise_sigma, |
| visual_lr_multiplier=visual_lr_mult, audio_lr_multiplier=audio_lr_mult, |
| collate_fn=multimodal_collate_fn) |
|
|
| |
| history = trainer.train() |
|
|
| |
| final = trainer.evaluate() |
    print("\nFinal Evaluation:")
| print(f" Acc={final['accuracy']:.4f} F1={final['macro_f1']:.4f}") |
| for cls, m in final["per_class"].items(): |
| print(f" {cls:25s} P:{m['precision']:.3f} R:{m['recall']:.3f} F1:{m['f1']:.3f} N:{m['support']}") |
|
|
| robustness_results = None |
| if args.eval_robustness and config.train.mode == "multimodal": |
| robustness_results = trainer.run_robustness_evaluation() |
|
|
| |
| os.makedirs(config.train.output_dir, exist_ok=True) |
| results = { |
| "experiment": config.experiment_name, "version": "v2", |
| "mode": config.train.mode, "finetune_method": config.train.finetune_method, |
| "dataset_source": args.dataset, |
| "anti_collapse_config": { |
| "ogm_ge": use_ogm, "ogm_alpha": ogm_alpha, "ogm_noise_sigma": ogm_noise_sigma, |
| "lambda_visual": lambda_visual, "lambda_audio": lambda_audio, |
| "visual_lr_multiplier": visual_lr_mult, "audio_lr_multiplier": audio_lr_mult, |
| }, |
| "final_metrics": { |
| "accuracy": final["accuracy"], "macro_f1": final["macro_f1"], |
| "weighted_f1": final["weighted_f1"], "per_class": final["per_class"], |
| "confusion_matrix": final["confusion_matrix"], |
| }, |
| "history": history, "best_epoch": trainer.best_epoch, "best_metric": trainer.best_metric, |
| } |
| if robustness_results: results["robustness"] = robustness_results |
|
|
| with open(os.path.join(config.train.output_dir, "results_v2.json"), "w") as f: |
| json.dump(results, f, indent=2) |
| print(f"\nResults saved to {config.train.output_dir}/results_v2.json") |
|
|
| if config.train.push_to_hub: |
| try: |
| from huggingface_hub import HfApi, login |
| login(token=os.environ.get("HF_TOKEN")) |
| HfApi().upload_folder(folder_path=config.train.output_dir, repo_id=config.train.hub_model_id, |
| repo_type="model", commit_message=f"Training v2: {config.experiment_name} (OGM-GE)") |
| print(f"✓ Pushed to https://huggingface.co/{config.train.hub_model_id}") |
| except Exception as e: |
| print(f"✗ Push failed: {e}") |
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|