# flow-matching/test/load_model.py
# Author: sabertoaster
# Commit message: "Upload folder using huggingface_hub"
# Commit: 0254260 (verified)
import sys
import argparse
from pathlib import Path
import torch
import torch.nn as nn
from omegaconf import OmegaConf
# Directory three levels above this file's resolved location; used as the
# base for the project import paths below and for the output/ artifacts.
ROOT = Path(__file__).resolve().parent.parent.parent
# Make the root, its flow_matching directory and the Matcha-TTS checkout
# importable. They are appended, so they are searched after existing entries.
sys.path.extend(
    str(extra_dir)
    for extra_dir in (
        ROOT,
        ROOT / 'flow_matching',
        ROOT / 'flow_matching/Matcha-TTS',
    )
)
from flow_matching.src.stage1.medarc_architecture import MultiSubjectConvLinearEncoder
from flow_matching.src.stage2.CFM import CFM
def export_stage1(cfg, weights_path=None, save_path='stage1_model.onnx'):
    """Build the Stage 1 encoder from ``cfg`` and export it to ONNX.

    The encoder gets one subject slot per entry of ``cfg.subjects``
    (``[1, 2, 3, 5]`` when the key is absent) and a single 32-dim feature
    stream. Export errors are printed instead of raised, so a Stage 1
    failure does not stop a later Stage 2 export.

    Args:
        cfg: OmegaConf config. ``cfg.stage1.model`` is forwarded as keyword
            arguments to ``MultiSubjectConvLinearEncoder``.
        weights_path: State-dict checkpoint to load. Defaults to
            ``ROOT/output/debug_run/stage1_best.pt``. If the file is
            missing, the randomly initialised model is exported and a
            warning is printed.
        save_path: Destination of the ONNX file.
    """
    import traceback

    print('Exporting Stage 1...')
    subjects_list = cfg.get('subjects', [1, 2, 3, 5])
    feat_dims = (32,)
    model = MultiSubjectConvLinearEncoder(
        num_subjects=len(subjects_list),
        feat_dims=feat_dims,
        **cfg.stage1.model,
    )

    if weights_path is None:
        weights_path = ROOT / 'output' / 'debug_run' / 'stage1_best.pt'
    weights_path = Path(weights_path)
    if weights_path.exists():
        state_dict = torch.load(weights_path, map_location='cpu', weights_only=True)
        # strict=False tolerates partial checkpoints. Without a report, a
        # checkpoint whose keys match nothing would load silently.
        incompatible = model.load_state_dict(state_dict, strict=False)
        print('Loaded Stage 1 weights.')
        if incompatible.missing_keys:
            print(f'  {len(incompatible.missing_keys)} missing key(s), '
                  f'e.g. {incompatible.missing_keys[:5]}')
        if incompatible.unexpected_keys:
            print(f'  {len(incompatible.unexpected_keys)} unexpected key(s), '
                  f'e.g. {incompatible.unexpected_keys[:5]}')
    else:
        print(f'Warning: {weights_path} not found; exporting untrained Stage 1 weights.')
    model.eval()

    # Build one dummy (2, 10, dim) tensor per feature stream, derived from
    # feat_dims so the two cannot drift apart. No dynamic_axes are declared,
    # so these sizes are fixed in the exported graph.
    dummy_input = [torch.randn(2, 10, dim) for dim in feat_dims]
    try:
        # The documented `args` form is a tuple of positional arguments.
        # Wrapping the list passes it as ONE argument.
        # NOTE(review): assumes the encoder's forward takes the list of
        # per-stream tensors as its first argument. Confirm against
        # MultiSubjectConvLinearEncoder.forward.
        torch.onnx.export(
            model,
            (dummy_input,),
            save_path,
            opset_version=14,
            input_names=['features'],
            output_names=['embeddings'],
        )
    except Exception as e:
        # Best effort: report the failure (with traceback) and let the caller
        # continue with the next stage.
        print(f'Failed exporting Stage 1: {e}')
        traceback.print_exc()
    else:
        print(f'Saved {save_path}')
def export_stage2(cfg, target_dim=1000, *, n_timesteps=10, seq_len=10,
                  weights_path=None, save_path='stage2_model.onnx'):
    """Build the Stage 2 CFM model from ``cfg`` and export its sampler to ONNX.

    The model is wrapped so that the exported graph takes only ``mu``. The
    solver step count is stored as a plain Python int, so it is fixed at
    export time. Export errors are printed instead of raised.

    Args:
        cfg: OmegaConf config providing ``cfg.stage2.cfm``,
            ``cfg.stage2.velocity_net``, ``cfg.stage2.source_ve`` and
            ``cfg.stage2.transport``, which are passed to ``CFM``.
        target_dim: ``feat_dim`` of the CFM model, and the channel size of
            the dummy ``mu`` input.
        n_timesteps: Solver steps passed as ``n_timesteps`` to the model.
        seq_len: Length of the last axis of the dummy ``mu`` input.
        weights_path: State-dict checkpoint to load. Defaults to
            ``ROOT/output/debug_run/stage2_epoch_0.pt``. If the file is
            missing, the randomly initialised model is exported and a
            warning is printed.
        save_path: Destination of the ONNX file.
    """
    import traceback

    print('\nExporting Stage 2...')
    model = CFM(
        feat_dim=target_dim,
        cfm_params=cfg.stage2.cfm,
        velocity_net_params=cfg.stage2.velocity_net,
        source_ve_params=cfg.stage2.source_ve,
        transport_params=cfg.stage2.transport,
    )

    if weights_path is None:
        weights_path = ROOT / 'output' / 'debug_run' / 'stage2_epoch_0.pt'
    weights_path = Path(weights_path)
    if weights_path.exists():
        state_dict = torch.load(weights_path, map_location='cpu', weights_only=True)
        # strict=False tolerates partial checkpoints. Without a report, a
        # checkpoint whose keys match nothing would load silently.
        incompatible = model.load_state_dict(state_dict, strict=False)
        print('Loaded Stage 2 weights.')
        if incompatible.missing_keys:
            print(f'  {len(incompatible.missing_keys)} missing key(s), '
                  f'e.g. {incompatible.missing_keys[:5]}')
        if incompatible.unexpected_keys:
            print(f'  {len(incompatible.unexpected_keys)} unexpected key(s), '
                  f'e.g. {incompatible.unexpected_keys[:5]}')
    else:
        print(f'Warning: {weights_path} not found; exporting untrained Stage 2 weights.')
    model.eval()

    class Stage2ExportWrapper(nn.Module):
        """Bind a fixed step count so the exported graph's only input is ``mu``."""

        def __init__(self, cfm_model: nn.Module, steps: int):
            super().__init__()
            self.cfm_model = cfm_model
            self.steps = steps

        def forward(self, mu):
            return self.cfm_model(mu, n_timesteps=self.steps)

    export_wrapper = Stage2ExportWrapper(model, steps=n_timesteps)
    export_wrapper.eval()

    # Dummy input laid out as (B, C, T), with C = target_dim.
    # NOTE(review): assumes CFM expects channels on axis 1. Confirm against
    # CFM.forward. No dynamic_axes are declared, so B and T are fixed in
    # the exported graph.
    mu = torch.randn(1, target_dim, seq_len)
    try:
        torch.onnx.export(
            export_wrapper,
            (mu,),
            save_path,
            opset_version=14,
            input_names=['mu'],
            output_names=['output'],
        )
    except Exception as e:
        # Best effort: report the failure (with traceback) instead of raising.
        print(f'Failed exporting Stage 2: {e}')
        traceback.print_exc()
    else:
        print(f'Saved {save_path}')
def main(argv=None):
    """Load the run config (or a built-in fallback) and export both stages.

    Args:
        argv: Command-line arguments. ``None`` means ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(
        description='Export the Stage 1 and Stage 2 models to ONNX.')
    parser.add_argument(
        '--config', type=Path,
        default=ROOT / 'output' / 'debug_run' / 'config.yaml',
        help='OmegaConf YAML config (default: %(default)s).')
    parser.add_argument(
        '--target-dim', type=int, default=1000,
        help='feat_dim of the Stage 2 CFM model (default: %(default)s).')
    args = parser.parse_args(argv)

    if args.config.exists():
        cfg = OmegaConf.load(args.config)
    else:
        # Fallback so the export can still be exercised without a trained run.
        print(f'Config {args.config} not found; using built-in fallback config.')
        cfg = OmegaConf.create(
            {
                'stage1': {'model': {}},
                'stage2': {
                    'cfm': {'solver': 'euler', 'kld_weight': 3.0, 'kld_target_std': 1.0, 'detach_ut': False},
                    'velocity_net': {'hidden_dim': 256, 'n_blocks': 2, 'n_heads': 4, 'dropout': 0.1, 'max_seq_len': 2048, 'temporal_attn_layers': 1},
                    'source_ve': {'depth': 2, 'num_heads': 4, 'num_queries': 8, 'dropout': 0.1, 'use_variational': True},
                    'transport': {'path_type': 'Linear', 'prediction': 'velocity', 'time_dist_type': 'uniform', 'time_dist_shift': 1.0},
                },
            }
        )
    export_stage1(cfg)
    export_stage2(cfg, target_dim=args.target_dim)


if __name__ == '__main__':
    main()