"""Debug runner: execute ``src.training.main`` on random mock data.

``src.training.make_data_loaders`` is patched to return tiny DataLoaders over
``MockDataset``, and ``sys.argv`` is overridden so the training entry point
loads ``test/debug_config.yml``.
"""
import argparse
import sys
import traceback
import unittest.mock
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
from omegaconf import OmegaConf
from torch.utils.data import DataLoader, Dataset

# Add Matcha-TTS (and the repo root) to the python path so their modules,
# including the ``src`` package imported below, can be resolved.
ROOT = Path(__file__).resolve().parent.parent
sys.path.append(str(ROOT / "Matcha-TTS"))
sys.path.append(str(ROOT))

import src.training  # noqa: E402  (needs the sys.path setup above)
from src.stage1.medarc_architecture import MultiSubjectConvLinearEncoder  # noqa: E402,F401
from src.stage2.CFM import CFM  # noqa: E402,F401


class MockDataset(Dataset):
    """Dataset of random tensors standing in for the real training data.

    Each item is a dict with:
      * ``"features"``: a list with one tensor of shape ``(time_steps, dim)``
        per ``dim`` in ``feat_dims``;
      * ``"fmri"``: a tensor of shape ``(num_subjects, time_steps, voxels)``.
    Values are freshly drawn from ``torch.randn`` on every access, so items
    are not reproducible across calls.
    """

    def __init__(
        self, num_samples, num_subjects=4, time_steps=10, voxels=100, feat_dims=(32, 64)
    ):
        self.num_samples = num_samples
        self.num_subjects = num_subjects
        self.time_steps = time_steps
        self.voxels = voxels
        self.feat_dims = feat_dims

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # idx is ignored: every sample is independent random noise.
        features = [torch.randn(self.time_steps, dim) for dim in self.feat_dims]
        fmri = torch.randn(self.num_subjects, self.time_steps, self.voxels)
        return {"features": features, "fmri": fmri}


def mock_make_data_loaders(cfg):
    """Drop-in replacement for ``src.training.make_data_loaders``.

    Args:
        cfg: config object; only ``cfg.batch_size`` is read here
            (presumably the OmegaConf config built by training.main — verify).

    Returns:
        dict with keys ``"train"`` and ``"val_debug"``, both mapped to the
        same DataLoader over a 4-sample ``MockDataset`` (1000 voxels,
        feature dims 32 and 64).
    """
    print("MOCKING DATA LOADERS FOR DEBUG")
    # Keep everything tiny so a debug run finishes quickly.
    num_samples = 4
    batch_size = cfg.batch_size

    voxels = 1000
    feat_dims = (32, 64)

    ds = MockDataset(num_samples=num_samples, voxels=voxels, feat_dims=feat_dims)
    loader = DataLoader(ds, batch_size=batch_size)
    # The same loader is reused for validation.
    return {"train": loader, "val_debug": loader}


def main():
    """Run ``src.training.main()`` with mocked data loaders and the debug config.

    Any exception raised by the training run is printed together with its
    traceback, and the process then exits with status 1 so that callers
    (shell scripts, CI) can detect the failure. ``sys.argv`` is restored
    afterwards.
    """
    saved_argv = sys.argv
    # training.main() parses its own command line, so the debug config is
    # injected through argv.
    # NOTE(review): the config path is relative to the current working
    # directory, so this script is expected to be run from the repo root.
    sys.argv = ["training.py", "--cfg-path", "test/debug_config.yml"]
    try:
        with unittest.mock.patch(
            "src.training.make_data_loaders", side_effect=mock_make_data_loaders
        ):
            src.training.main()
    except Exception as e:
        print(f"Caught exception during debug run: {e}")
        traceback.print_exc()
        # Previously the error was swallowed and the script exited 0.
        sys.exit(1)
    finally:
        sys.argv = saved_argv


if __name__ == "__main__":
    main()