# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

# originally copied from https://www.internalfb.com/code/fbsource/[671aa4920700]/fbcode/xrcia/projects/sapiens/experimental_ghe_import/sapiens2/sapiens/seg/tools/deployment/pytorch2torchscript.py?lines=1-204

import argparse
import os

import torch
import torch._C
import torch.serialization
from sapiens.dense.tools.deployment.pytorch2torchscript import check_torch_version
from sapiens.pose.datasets import parse_pose_metainfo, UDPHeatmap
from sapiens.pose.models import init_model

# Fix the seed so the dummy input (and thus the trace-time verification) is
# deterministic across runs.
torch.manual_seed(3)

TORCH_MINIMUM_VERSION = "1.8.0"


def pytorch2torchscript(
    model: torch.nn.Module,
    input_shape: tuple[int, int, int, int],
    device: str,
    show_graph: bool = False,
    output_file: str = "tmp.pt",
    verify: bool = False,
) -> None:
    """Export a PyTorch model to TorchScript and optionally verify that the
    outputs are the same between PyTorch and TorchScript.

    Args:
        model (nn.Module): PyTorch model to export.
        input_shape (tuple): Shape used to construct the dummy input that
            drives the trace.
        device (str): Device on which to run the trace, e.g. ``cuda:0``.
        show_graph (bool): Whether to print the computation graph.
            Default: False.
        output_file (str): Path where the TorchScript model is written.
            Default: ``tmp.pt``.
        verify (bool): Whether to compare the outputs of the PyTorch and
            TorchScript models. Default: False.
    """
    inputs = torch.rand(input_shape).to(device)

    # replace the original forward with forward_dummy
    # model.forward = model.forward_dummy
    model.eval()

    # check_trace=True re-runs the traced module on the example inputs and
    # compares its outputs against the eager model.
    traced_model = torch.jit.trace(
        model,
        example_inputs=inputs,
        check_trace=verify,
    )

    if show_graph:
        print(traced_model.graph)

    traced_model.save(output_file)
    print(f"Successfully exported TorchScript model: {output_file}")
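

# A minimal round-trip smoke test, shown as a sketch (not part of the original
# script): load the exported artifact with ``torch.jit.load`` and run it on a
# fresh dummy input of the same shape used at export time. Assumes the traced
# model returns a single tensor.
def _load_and_run(
    output_file: str, input_shape: tuple[int, int, int, int], device: str
) -> torch.Tensor:
    ts_model = torch.jit.load(output_file, map_location=device)
    ts_model.eval()
    with torch.no_grad():
        return ts_model(torch.rand(input_shape).to(device))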


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Convert .pth checkpoint to TorchScript"
    )
    parser.add_argument("config", help="test config file path")
    parser.add_argument("--checkpoint", help="checkpoint file")
    parser.add_argument(
        "--show-graph", action="store_true", help="show TorchScript graph"
    )
    parser.add_argument(
        "--verify", action="store_true", help="verify the TorchScript model"
    )
    parser.add_argument("--output-file", type=str, default="tmp.pt")
    parser.add_argument(
        "--shape",
        type=int,
        nargs="+",
        default=[1024, 768],
        help="input image size (height, width)",
    )
    parser.add_argument("--device", default="cuda:0", help="device used for inference")
    args = parser.parse_args()
    return args


def main() -> None:
    args = parse_args()
    check_torch_version()

    if len(args.shape) == 1:
        input_shape = (1, 3, args.shape[0], args.shape[0])
    elif len(args.shape) == 2:
        input_shape = (1, 3, *args.shape)
    else:
        raise ValueError("invalid input shape")

    # build the model and load the checkpoint
    model = init_model(args.config, args.checkpoint, device=args.device)

    # add pose metainfo to the model
    num_keypoints = model.cfg.num_keypoints
    if num_keypoints == 308:
        model.pose_metainfo = parse_pose_metainfo(
            dict(from_file="configs/_base_/keypoints308.py")
        )

    # add the codec to the model
    codec_type = model.cfg.codec.pop("type")
    assert codec_type == "UDPHeatmap", "Only support UDPHeatmap"
    model.codec = UDPHeatmap(**model.cfg.codec)

    # create the output directory if it does not exist; dirname is empty when
    # the output file has no directory component (e.g. the default "tmp.pt"),
    # in which case there is nothing to create
    output_dir = os.path.dirname(args.output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # convert the PyTorch model to a TorchScript model
    pytorch2torchscript(
        model,
        input_shape=input_shape,
        device=args.device,
        show_graph=args.show_graph,
        output_file=args.output_file,
        verify=args.verify,
    )


if __name__ == "__main__":
    main()
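
# Example invocation (a sketch; the config and checkpoint paths are
# illustrative, not files shipped with this script):
#
#   python pytorch2torchscript.py \
#       configs/pose/sapiens_pose.py \
#       --checkpoint checkpoints/sapiens_pose.pth \
#       --shape 1024 768 \
#       --output-file output/sapiens_pose.pt \
#       --verify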