| import sys |
| import argparse |
| from pathlib import Path |
|
|
| import torch |
| import torch.nn as nn |
| from omegaconf import OmegaConf |
|
|
# Repository root: the grandparent of the directory containing this script.
ROOT = Path(__file__).resolve().parent.parent.parent

# Put the repository checkout (and the vendored Matcha-TTS) at the FRONT of
# sys.path, keeping their relative order. Appending would place them behind
# site-packages, so an installed distribution with the same top-level name
# (e.g. a pip-installed `flow_matching`) would silently shadow the local
# sources. Entries already on sys.path are left alone so that re-importing
# this module does not accumulate duplicates.
_EXTRA_PATHS = [
    str(ROOT),
    str(ROOT / 'flow_matching'),
    str(ROOT / 'flow_matching' / 'Matcha-TTS'),
]
sys.path[:0] = [p for p in _EXTRA_PATHS if p not in sys.path]
|
|
| from flow_matching.src.stage1.medarc_architecture import MultiSubjectConvLinearEncoder |
| from flow_matching.src.stage2.CFM import CFM |
|
|
def export_stage1(cfg, weights_path=None, save_path='stage1_model.onnx',
                  opset_version=14, dynamic_axes=None):
    """Build the Stage 1 encoder from ``cfg`` and export it to ONNX.

    Args:
        cfg: OmegaConf config. ``cfg.stage1.model`` is unpacked as keyword
            arguments of ``MultiSubjectConvLinearEncoder``. The optional
            top-level ``subjects`` list only determines ``num_subjects``
            (default ``[1, 2, 3, 5]``, i.e. 4 subjects).
        weights_path: Checkpoint (a ``state_dict``) to load before export.
            Defaults to ``ROOT/output/debug_run/stage1_best.pt``. If the file
            does not exist, the freshly initialised model is exported and a
            warning is printed.
        save_path: Destination of the ONNX file.
        opset_version: ONNX opset version passed to ``torch.onnx.export``.
        dynamic_axes: Forwarded to ``torch.onnx.export``. With the default
            ``None``, the exported graph is fixed to the example input shape.

    Returns:
        ``True`` if the ONNX file was written, ``False`` if the export raised.
        Export errors are reported with a traceback rather than re-raised, so
        the caller can continue with the next stage.
    """
    import traceback

    print('Exporting Stage 1...')

    subjects_list = cfg.get('subjects', [1, 2, 3, 5])
    feat_dims = (32,)

    model = MultiSubjectConvLinearEncoder(
        num_subjects=len(subjects_list),
        feat_dims=feat_dims,
        **cfg.stage1.model,
    )

    if weights_path is None:
        weights_path = ROOT / 'output' / 'debug_run' / 'stage1_best.pt'
    weights_path = Path(weights_path)
    if weights_path.exists():
        state_dict = torch.load(weights_path, map_location='cpu', weights_only=True)
        # strict=False tolerates checkpoint/model drift, but an unnoticed
        # mismatch would export partly untrained weights, so report it.
        incompatible = model.load_state_dict(state_dict, strict=False)
        print('Loaded Stage 1 weights.')
        if incompatible.missing_keys:
            print(f'  Warning: keys missing from checkpoint: {incompatible.missing_keys}')
        if incompatible.unexpected_keys:
            print(f'  Warning: unexpected keys in checkpoint: {incompatible.unexpected_keys}')
    else:
        print(f'Warning: {weights_path} not found; exporting untrained Stage 1 weights.')

    model.eval()

    # Tracing input: a list holding one tensor of shape (2, 10, d) per entry
    # of feat_dims. The feature size is tied to feat_dims so the two cannot
    # drift apart.
    # NOTE(review): assumes the encoder consumes a list with one
    # (batch, time, feat) tensor per feature stream; confirm against its forward.
    dummy_input = [torch.randn(2, 10, dim) for dim in feat_dims]

    try:
        torch.onnx.export(
            model,
            dummy_input,
            save_path,
            opset_version=opset_version,
            input_names=['features'],
            output_names=['embeddings'],
            dynamic_axes=dynamic_axes,
        )
    except Exception as e:
        # Best-effort export: keep the original message, add the traceback.
        print(f'Failed exporting Stage 1: {e}')
        traceback.print_exc()
        return False

    print(f'Saved {save_path}')
    return True
|
|
def export_stage2(cfg, target_dim=1000, weights_path=None, save_path='stage2_model.onnx',
                  n_timesteps=10, opset_version=14, dynamic_axes=None):
    """Build the Stage 2 CFM model from ``cfg`` and export its sampler to ONNX.

    Args:
        cfg: OmegaConf config. ``cfg.stage2.cfm``, ``.velocity_net``,
            ``.source_ve`` and ``.transport`` are passed to ``CFM`` as its
            ``*_params`` arguments.
        target_dim: Value for ``CFM(feat_dim=...)``; also the channel size of
            the example ``mu`` used for tracing.
        weights_path: Checkpoint (a ``state_dict``) to load before export.
            Defaults to ``ROOT/output/debug_run/stage2_epoch_0.pt``. If the
            file does not exist, the freshly initialised model is exported and
            a warning is printed.
        save_path: Destination of the ONNX file.
        n_timesteps: Passed to the model as ``n_timesteps``. It is a plain
            Python int, so it is fixed at export time.
        opset_version: ONNX opset version passed to ``torch.onnx.export``.
        dynamic_axes: Forwarded to ``torch.onnx.export``. With the default
            ``None``, the exported graph is fixed to the example input shape.

    Returns:
        ``True`` if the ONNX file was written, ``False`` if the export raised.
        Export errors are reported with a traceback rather than re-raised.
    """
    import traceback

    print('\nExporting Stage 2...')

    model = CFM(
        feat_dim=target_dim,
        cfm_params=cfg.stage2.cfm,
        velocity_net_params=cfg.stage2.velocity_net,
        source_ve_params=cfg.stage2.source_ve,
        transport_params=cfg.stage2.transport,
    )

    if weights_path is None:
        weights_path = ROOT / 'output' / 'debug_run' / 'stage2_epoch_0.pt'
    weights_path = Path(weights_path)
    if weights_path.exists():
        state_dict = torch.load(weights_path, map_location='cpu', weights_only=True)
        # strict=False tolerates checkpoint/model drift, but an unnoticed
        # mismatch would export partly untrained weights, so report it.
        incompatible = model.load_state_dict(state_dict, strict=False)
        print('Loaded Stage 2 weights.')
        if incompatible.missing_keys:
            print(f'  Warning: keys missing from checkpoint: {incompatible.missing_keys}')
        if incompatible.unexpected_keys:
            print(f'  Warning: unexpected keys in checkpoint: {incompatible.unexpected_keys}')
    else:
        print(f'Warning: {weights_path} not found; exporting untrained Stage 2 weights.')

    model.eval()

    class Stage2ExportWrapper(nn.Module):
        """Bind ``n_timesteps`` so the exported graph takes ``mu`` as its only input."""

        def __init__(self, cfm_model: nn.Module, steps: int):
            super().__init__()
            self.cfm_model = cfm_model
            self.steps = steps

        def forward(self, mu):
            return self.cfm_model(mu, n_timesteps=self.steps)

    export_wrapper = Stage2ExportWrapper(model, steps=n_timesteps)
    export_wrapper.eval()

    # Example mu of shape (batch=1, target_dim, 10) used for tracing.
    batch, channels, frames = 1, target_dim, 10
    mu = torch.randn(batch, channels, frames)

    try:
        torch.onnx.export(
            export_wrapper,
            mu,
            save_path,
            opset_version=opset_version,
            input_names=['mu'],
            output_names=['output'],
            dynamic_axes=dynamic_axes,
        )
    except Exception as e:
        # Best-effort export: keep the original message, add the traceback.
        print(f'Failed exporting Stage 2: {e}')
        traceback.print_exc()
        return False

    print(f'Saved {save_path}')
    return True
|
|
if __name__ == '__main__':
    # Command-line entry point. Running with no arguments reproduces the
    # previous behaviour: use ROOT/output/debug_run/config.yaml if present,
    # otherwise the built-in defaults below, and export Stage 2 with
    # target_dim=1000.
    parser = argparse.ArgumentParser(
        description='Export the Stage 1 encoder and the Stage 2 CFM model to ONNX.'
    )
    parser.add_argument(
        '--config',
        type=Path,
        default=None,
        help='Training config (YAML). Defaults to ROOT/output/debug_run/config.yaml, '
             'falling back to built-in settings if that file does not exist.',
    )
    parser.add_argument(
        '--target-dim',
        type=int,
        default=1000,
        help='Feature dimension of the Stage 2 model (default: 1000).',
    )
    args = parser.parse_args()

    default_config = ROOT / 'output' / 'debug_run' / 'config.yaml'
    if args.config is not None:
        # An explicitly requested config must exist. Silently falling back to
        # the defaults would export a model that does not match it.
        if not args.config.exists():
            parser.error(f'config file not found: {args.config}')
        cfg = OmegaConf.load(args.config)
    elif default_config.exists():
        cfg = OmegaConf.load(default_config)
    else:
        print(f'{default_config} not found; using built-in default config.')
        cfg = OmegaConf.create(
            {
                'stage1': {'model': {}},
                'stage2': {
                    'cfm': {'solver': 'euler', 'kld_weight': 3.0, 'kld_target_std': 1.0, 'detach_ut': False},
                    'velocity_net': {'hidden_dim': 256, 'n_blocks': 2, 'n_heads': 4, 'dropout': 0.1, 'max_seq_len': 2048, 'temporal_attn_layers': 1},
                    'source_ve': {'depth': 2, 'num_heads': 4, 'num_queries': 8, 'dropout': 0.1, 'use_variational': True},
                    'transport': {'path_type': 'Linear', 'prediction': 'velocity', 'time_dist_type': 'uniform', 'time_dist_shift': 1.0},
                },
            }
        )

    export_stage1(cfg)
    export_stage2(cfg, target_dim=args.target_dim)
|
|
|
|