# flow-matching/test/debug_training.py
# sabertoaster's picture
# Upload folder using huggingface_hub
# 4edc9aa verified
import argparse
import sys
import unittest.mock
from pathlib import Path
import torch
import torch.nn as nn
import numpy as np
from omegaconf import OmegaConf
# Add both the bundled Matcha-TTS checkout and the repository root to the
# Python path so that their modules (e.g. ``src``) can be imported below.
ROOT = Path(__file__).resolve().parent.parent
sys.path.append(str(ROOT / "Matcha-TTS"))
sys.path.append(str(ROOT))
import src.training
from src.stage1.medarc_architecture import MultiSubjectConvLinearEncoder
from src.stage2.CFM import CFM
from torch.utils.data import DataLoader, Dataset
class MockDataset(Dataset):
    """Synthetic dataset that yields random feature/fMRI tensors for debugging.

    Every sample is freshly drawn from ``torch.randn``; nothing is read from
    disk, and the same index returns different values on each access.

    Args:
        num_samples: Number of samples reported by ``len()``.
        num_subjects: Size of the leading (subject) axis of the ``fmri`` tensor.
        time_steps: Number of time steps shared by all feature tensors and
            the ``fmri`` tensor.
        voxels: Size of the trailing (voxel) axis of the ``fmri`` tensor.
        feat_dims: Iterable of feature widths; one ``(time_steps, dim)``
            tensor is generated per entry.
    """

    def __init__(
        self, num_samples, num_subjects=4, time_steps=10, voxels=100, feat_dims=(32, 64)
    ):
        self.num_samples = num_samples
        self.num_subjects = num_subjects
        self.time_steps = time_steps
        self.voxels = voxels
        self.feat_dims = feat_dims

    def __len__(self):
        """Return the configured number of samples."""
        return self.num_samples

    def __getitem__(self, idx):
        """Return one random sample.

        Args:
            idx: Sample index. Negative indices are accepted, following the
                usual sequence convention.

        Returns:
            A dict with ``"features"`` (a list of ``(time_steps, dim)``
            tensors, one per entry of ``feat_dims``) and ``"fmri"`` (a
            ``(num_subjects, time_steps, voxels)`` tensor).

        Raises:
            IndexError: If ``idx`` is outside ``[-num_samples, num_samples)``.
        """
        # Without this check plain iteration (``for x in ds`` / ``list(ds)``)
        # never stops: Python's fallback iteration protocol keeps calling
        # ``__getitem__`` with 0, 1, 2, ... until IndexError is raised.
        if not -self.num_samples <= idx < self.num_samples:
            raise IndexError(
                f"index {idx} is out of range for MockDataset of size "
                f"{self.num_samples}"
            )
        # features list: one (T, dim) tensor per requested feature width
        features = [torch.randn(self.time_steps, dim) for dim in self.feat_dims]
        # fmri: (S, T, V)
        fmri = torch.randn(self.num_subjects, self.time_steps, self.voxels)
        return {"features": features, "fmri": fmri}
def mock_make_data_loaders(cfg):
    """Build tiny in-memory data loaders to stand in for the real ones.

    Args:
        cfg: Configuration object; only ``cfg.batch_size`` is read.

    Returns:
        A dict whose ``"train"`` and ``"val_debug"`` entries both refer to
        the same ``DataLoader`` over a 4-sample ``MockDataset``.
    """
    print("MOCKING DATA LOADERS FOR DEBUG")
    per_batch = cfg.batch_size

    # Deliberately small so a debug run finishes quickly.
    debug_dataset = MockDataset(num_samples=4, voxels=1000, feat_dims=(32, 64))
    shared_loader = DataLoader(debug_dataset, batch_size=per_batch)

    # Validation reuses the training loader; there is no separate split.
    return dict.fromkeys(("train", "val_debug"), shared_loader)
def main():
    """Run ``src.training.main`` against mock data with the debug config.

    ``src.training.make_data_loaders`` is patched so that every call is
    routed to ``mock_make_data_loaders``, and ``sys.argv`` is temporarily
    replaced to simulate ``training.py --cfg-path test/debug_config.yml``.
    Both patches are undone when the function returns, so the caller's
    ``sys.argv`` is left untouched.

    Any ``Exception`` raised by the training run is reported on stdout
    together with its traceback and is not re-raised.
    """
    # NOTE(review): the config path is relative, so this presumably has to be
    # launched from the repository root - confirm against src.training.
    debug_argv = ["training.py", "--cfg-path", "test/debug_config.yml"]

    # Patch the make_data_loaders in training.py, and swap sys.argv through
    # patch.object so the original argument list is restored afterwards
    # instead of being overwritten for the rest of the process.
    with unittest.mock.patch(
        "src.training.make_data_loaders", side_effect=mock_make_data_loaders
    ), unittest.mock.patch.object(sys, "argv", debug_argv):
        # Call original main
        try:
            src.training.main()
        except Exception as e:
            print(f"Caught exception during debug run: {e}")
            import traceback

            traceback.print_exc()
# Start the mocked debug training run only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()