| import os
|
| import torch
|
| from model import AutoModel, Config
|
|
|
def load_model(model_path, config_path):
    """Load model weights and configuration, returning the model in eval mode.

    Args:
        model_path: Path to the saved PyTorch state dict (.pth file).
        config_path: Path to the configuration file.

    Returns:
        Tuple ``(model, config)`` where ``model`` is an ``AutoModel`` set to
        evaluation mode and ``config`` is the ``Config`` used to build it.

    Raises:
        FileNotFoundError: If ``config_path`` or ``model_path`` does not exist.
    """
    if not os.path.exists(config_path):
        raise FileNotFoundError(f"配置文件未找到: {config_path}")
    print(f"加载配置文件: {config_path}")
    # NOTE(review): config_path is only checked for existence — Config() is
    # constructed with defaults and the file's contents are never read. This
    # looks like a bug; confirm whether Config exposes a loader (e.g. a
    # from_json classmethod) that should be called with config_path.
    config = Config()

    model = AutoModel(config)

    if not os.path.exists(model_path):
        raise FileNotFoundError(f"模型文件未找到: {model_path}")
    print(f"加载模型权重: {model_path}")
    # weights_only=True restricts unpickling to tensors and primitive
    # containers, preventing arbitrary code execution from an untrusted
    # checkpoint file (recommended when loading a plain state dict).
    state_dict = torch.load(
        model_path, map_location=torch.device("cpu"), weights_only=True
    )
    model.load_state_dict(state_dict)
    model.eval()
    print("模型加载成功并设置为评估模式。")

    return model, config
|
|
|
|
|
def run_inference(model, config):
    """Run one forward pass with random dummy inputs and print output shapes.

    Args:
        model: Model whose call signature is ``model(image, text, audio)``
            returning five output tensors (VQA, caption, retrieval, ASR,
            realtime ASR).
        config: Configuration providing ``max_position_embeddings``,
            ``hidden_size``, and ``audio_sample_rate``.
    """
    # Random dummy batch of size 1 for each modality; sizes come from config
    # except the fixed 3x224x224 image.
    image = torch.randn(1, 3, 224, 224)
    text = torch.randn(1, config.max_position_embeddings, config.hidden_size)
    audio = torch.randn(1, config.audio_sample_rate)

    # Unpack the five heads directly from the forward call.
    vqa_output, caption_output, retrieval_output, asr_output, realtime_asr_output = model(
        image, text, audio
    )

    print("\n推理结果:")
    for label, tensor in (
        ("VQA", vqa_output),
        ("Caption", caption_output),
        ("Retrieval", retrieval_output),
        ("ASR", asr_output),
        ("Realtime ASR", realtime_asr_output),
    ):
        print(f"{label} output shape: {tensor.shape}")
|
|
|
if __name__ == "__main__":
    # Default artifact locations, resolved relative to the working directory.
    model_path = "AutoModel.pth"
    config_path = "config.json"

    try:
        model, config = load_model(model_path, config_path)
        run_inference(model, config)
    except Exception as e:
        # Top-level boundary: report the failure, then exit with a non-zero
        # status so shells/CI can detect it — previously the error was
        # swallowed and the script exited 0 even on failure.
        print(f"运行失败: {e}")
        raise SystemExit(1) from e