"""Export the Stage 1 encoder and the Stage 2 CFM model to ONNX.

The OmegaConf config and checkpoints are read from a training run directory
(``output/debug_run`` under the repository root by default). Each stage is
exported independently: a failure in one is reported and the other is still
attempted. The process exits with status 1 if any export failed.
"""
import sys
import argparse
from pathlib import Path
from typing import Optional, Sequence

import torch
import torch.nn as nn
from omegaconf import OmegaConf

ROOT = Path(__file__).resolve().parent.parent.parent
# Make the repository packages (and the vendored Matcha-TTS sources)
# importable before the project imports below.
sys.path.append(str(ROOT))
sys.path.append(str(ROOT / 'flow_matching'))
sys.path.append(str(ROOT / 'flow_matching/Matcha-TTS'))

from flow_matching.src.stage1.medarc_architecture import MultiSubjectConvLinearEncoder  # noqa: E402
from flow_matching.src.stage2.CFM import CFM  # noqa: E402

# Default training run directory holding config.yaml and the checkpoints.
DEFAULT_RUN_DIR = ROOT / 'output' / 'debug_run'
ONNX_OPSET = 14


def _load_weights(model: nn.Module, weights_path: Path, label: str) -> None:
    """Load a state dict from *weights_path* into *model*, if the file exists.

    Loading is non-strict, so partially matching checkpoints still load.
    Missing and unexpected keys are reported so a mismatched checkpoint does
    not go unnoticed. If the file is absent, a warning is printed and the
    model keeps its random initialization.
    """
    if not weights_path.exists():
        print(
            f'WARNING: {weights_path} not found; exporting {label} with random weights.',
            file=sys.stderr,
        )
        return
    state_dict = torch.load(weights_path, map_location='cpu', weights_only=True)
    result = model.load_state_dict(state_dict, strict=False)
    print(f'Loaded {label} weights.')
    if result.missing_keys:
        print(f'WARNING: {label} missing keys: {result.missing_keys}', file=sys.stderr)
    if result.unexpected_keys:
        print(f'WARNING: {label} unexpected keys: {result.unexpected_keys}', file=sys.stderr)


def _export_onnx(
    module: nn.Module,
    args: tuple,
    save_path,
    input_names: Sequence[str],
    output_names: Sequence[str],
    label: str,
) -> bool:
    """Export *module* to ONNX at *save_path*; return True on success.

    The export is best-effort: any exception is reported on stderr and
    ``False`` is returned, so the caller can continue with other stages.
    """
    try:
        torch.onnx.export(
            module,
            args,
            str(save_path),
            opset_version=ONNX_OPSET,
            input_names=list(input_names),
            output_names=list(output_names),
        )
    except Exception as e:  # deliberate top-level boundary: report, don't abort
        print(f'Failed exporting {label}: {e}', file=sys.stderr)
        return False
    print(f'Saved {save_path}')
    return True


def export_stage1(cfg, save_path='stage1_model.onnx', weights_path: Optional[Path] = None) -> bool:
    """Build the Stage 1 encoder from *cfg*, load weights, and export to ONNX.

    Args:
        cfg: OmegaConf config; reads ``subjects`` (default ``[1, 2, 3, 5]``)
            and ``stage1.model`` (forwarded as encoder kwargs).
        save_path: destination ``.onnx`` file (relative to the CWD by default).
        weights_path: checkpoint to load; defaults to
            ``DEFAULT_RUN_DIR / 'stage1_best.pt'``.

    Returns:
        True if the ONNX export succeeded, False otherwise.
    """
    print('Exporting Stage 1...')
    subjects_list = cfg.get('subjects', [1, 2, 3, 5])
    feat_dims = (32,)
    model = MultiSubjectConvLinearEncoder(
        num_subjects=len(subjects_list), feat_dims=feat_dims, **cfg.stage1.model
    )
    if weights_path is None:
        weights_path = DEFAULT_RUN_DIR / 'stage1_best.pt'
    _load_weights(model, Path(weights_path), 'Stage 1')
    model.eval()

    # The original passed a bare list, which the exporter wraps as ONE
    # positional argument. Make that explicit. Assumption (TODO confirm): forward
    # takes a list with one (batch, time, feat) tensor per entry of feat_dims.
    features = [torch.randn(2, 10, feat_dims[0])]
    return _export_onnx(model, (features,), save_path, ['features'], ['embeddings'], 'Stage 1')


class Stage2ExportWrapper(nn.Module):
    """Bind a fixed number of sampling steps so the CFM is traceable as ``f(mu)``."""

    def __init__(self, cfm_model: nn.Module, steps: int):
        super().__init__()
        self.cfm_model = cfm_model
        self.steps = steps

    def forward(self, mu):
        return self.cfm_model(mu, n_timesteps=self.steps)


def export_stage2(
    cfg,
    target_dim=1000,
    steps: int = 10,
    save_path='stage2_model.onnx',
    weights_path: Optional[Path] = None,
) -> bool:
    """Build the Stage 2 CFM from *cfg*, load weights, and export to ONNX.

    Args:
        cfg: OmegaConf config; reads ``stage2.cfm``, ``stage2.velocity_net``,
            ``stage2.source_ve`` and ``stage2.transport``.
        target_dim: feature dimension ``C`` of the ``(B, C, T)`` input ``mu``.
        steps: number of sampling steps baked into the exported graph.
        save_path: destination ``.onnx`` file (relative to the CWD by default).
        weights_path: checkpoint to load; defaults to
            ``DEFAULT_RUN_DIR / 'stage2_epoch_0.pt'``.

    Returns:
        True if the ONNX export succeeded, False otherwise.
    """
    print('\nExporting Stage 2...')
    model = CFM(
        feat_dim=target_dim,
        cfm_params=cfg.stage2.cfm,
        velocity_net_params=cfg.stage2.velocity_net,
        source_ve_params=cfg.stage2.source_ve,
        transport_params=cfg.stage2.transport,
    )
    if weights_path is None:
        weights_path = DEFAULT_RUN_DIR / 'stage2_epoch_0.pt'
    _load_weights(model, Path(weights_path), 'Stage 2')
    model.eval()

    export_wrapper = Stage2ExportWrapper(model, steps=steps)
    export_wrapper.eval()

    # Dummy input: (B, C, T). Batch and time are fixed at trace time
    # (no dynamic_axes are declared).
    B, C, T = 1, target_dim, 10
    mu = torch.randn(B, C, T)
    return _export_onnx(export_wrapper, (mu,), save_path, ['mu'], ['output'], 'Stage 2')


def _default_config():
    """Return the fallback config used when the run directory has no config.yaml."""
    return OmegaConf.create(
        {
            'stage1': {'model': {}},
            'stage2': {
                'cfm': {'solver': 'euler', 'kld_weight': 3.0, 'kld_target_std': 1.0, 'detach_ut': False},
                'velocity_net': {'hidden_dim': 256, 'n_blocks': 2, 'n_heads': 4, 'dropout': 0.1, 'max_seq_len': 2048, 'temporal_attn_layers': 1},
                'source_ve': {'depth': 2, 'num_heads': 4, 'num_queries': 8, 'dropout': 0.1, 'use_variational': True},
                'transport': {'path_type': 'Linear', 'prediction': 'velocity', 'time_dist_type': 'uniform', 'time_dist_shift': 1.0},
            },
        }
    )


def _parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Parse CLI options; defaults reproduce the script's original hard-coded values."""
    parser = argparse.ArgumentParser(description='Export Stage 1 and Stage 2 models to ONNX.')
    parser.add_argument('--run-dir', type=Path, default=DEFAULT_RUN_DIR,
                        help='directory containing config.yaml and the checkpoints')
    parser.add_argument('--target-dim', type=int, default=1000,
                        help='Stage 2 feature dimension')
    parser.add_argument('--steps', type=int, default=10,
                        help='Stage 2 sampling steps baked into the graph')
    return parser.parse_args(argv)


def main(argv: Optional[Sequence[str]] = None) -> int:
    """Run both exports; return 0 if both succeeded, 1 otherwise."""
    args = _parse_args(argv)
    config_path = args.run_dir / 'config.yaml'
    if config_path.exists():
        cfg = OmegaConf.load(config_path)
    else:
        print(f'WARNING: {config_path} not found; using built-in default config.', file=sys.stderr)
        cfg = _default_config()

    # Evaluate both exports before combining, so a Stage 1 failure does not
    # short-circuit Stage 2.
    stage1_ok = export_stage1(cfg, weights_path=args.run_dir / 'stage1_best.pt')
    stage2_ok = export_stage2(
        cfg,
        target_dim=args.target_dim,
        steps=args.steps,
        weights_path=args.run_dir / 'stage2_epoch_0.pt',
    )
    return 0 if (stage1_ok and stage2_ok) else 1


if __name__ == '__main__':
    sys.exit(main())