# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# originally copied from https://www.internalfb.com/code/fbsource/[671aa4920700]/fbcode/xrcia/projects/sapiens/experimental_ghe_import/sapiens2/sapiens/seg/tools/deployment/pytorch2torchscript.py?lines=1-204
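#
# Example invocation (illustrative; the config and checkpoint paths below are
# placeholders for files specific to your setup):
#
#   python pytorch2torchscript.py path/to/config.py \
#       --checkpoint path/to/checkpoint.pth \
#       --output-file output/model.pt \
#       --shape 1024 768 \
#       --device cuda:0 \
#       --verify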

import argparse
import os

import torch
import torch._C
import torch.serialization
from sapiens.dense.models import init_model

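# Fix the random seed so the randomly generated dummy input is reproducible across runs.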
torch.manual_seed(3)
TORCH_MINIMUM_VERSION = "1.8.0"


def digit_version(version_str: str) -> list[int]:
    """Convert a version string into a list of integers for comparison.

    This function parses version strings with complex formats and converts them into
    comparable numeric arrays. It handles standard version numbers (like '1.2.3')
    as well as release candidates (containing 'rc').

    For standard version components, each number is directly converted to an integer.
    For release candidates (e.g., '2rc1'), the function treats them as slightly
    earlier than the final release by:
    - Converting the number before 'rc' to (number - 1)
    - Appending the rc number as an additional version component

    Examples:
        '1.2.3' -> [1, 2, 3]
        '0.1.2rc1' -> [0, 1, 1, 1]  # 2rc1 becomes [1, 1]
        '2.0rc1' -> [2, -1, 1]  # 0rc1 becomes [-1, 1]

    Args:
        version_str (str): The version string to convert.

    Returns:
        list[int]: A list of integers representing the version for comparison.
    """
    version_digits = []
    for x in version_str.split("."):  # Split the version string by '.'
        if x.isdigit():  # The component is a plain number
            version_digits.append(int(x))
        elif x.find("rc") != -1:  # The component is a release candidate, e.g. '2rc1'
            patch_version = x.split("rc")  # Split into the base number and rc number
            # Append the number before 'rc' minus 1, then the rc number itself
            version_digits.append(int(patch_version[0]) - 1)
            version_digits.append(int(patch_version[1]))
    return version_digits


def check_torch_version() -> None:
    """Validate that the installed PyTorch version meets the minimum requirement.

    Raises:
        RuntimeError: If the installed PyTorch version is below TORCH_MINIMUM_VERSION.
    """
    torch_version = digit_version(torch.__version__)
    if torch_version < digit_version(TORCH_MINIMUM_VERSION):
        raise RuntimeError(
            f"Torch=={torch.__version__} is not supported for converting to "
            f"TorchScript. Please install PyTorch>={TORCH_MINIMUM_VERSION}."
        )


def pytorch2torchscript(
    model: torch.nn.Module,
    input_shape: tuple[int, int, int, int],
    device: str,
    show_graph: bool = False,
    output_file: str = "tmp.pt",
    verify: bool = False,
) -> None:
    """Export Pytorch model to TorchScript model and verify the outputs are
    same between Pytorch and TorchScript.

    Args:
        model (nn.Module): Pytorch model we want to export.
        input_shape (tuple): Use this input shape to construct
            the corresponding dummy input and execute the model.
        show_graph (bool): Whether print the computation graph. Default: False.
        output_file (string): The path to where we store the
            output TorchScript model. Default: `tmp.pt`.
        verify (bool): Whether compare the outputs between
            Pytorch and TorchScript. Default: False.
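
    Example (illustrative; assumes a model already built via init_model):
        pytorch2torchscript(
            model,
            input_shape=(1, 3, 1024, 768),
            device="cuda:0",
            output_file="output/model.pt",
            verify=True,
        )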
    """

    # Clear CUDA cache before starting conversion
    if device == "cuda" or device.startswith("cuda:"):
        torch.cuda.empty_cache()
        print("Cleared CUDA cache before conversion")

    # optionally swap in a dummy forward for tracing (left disabled here)
    # model.forward = model.forward_dummy
    model.eval()

    # Use no_grad context to avoid storing gradients during tracing
    # Create inputs inside the context to minimize memory footprint
    with torch.no_grad():
        inputs = torch.rand(input_shape).to(device)
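        # torch.jit.trace records the operations executed on the example input;
        # with check_trace=True it also re-runs the traced module on the same
        # input and checks that its outputs match the original model's outputs.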
        traced_model = torch.jit.trace(
            model,
            example_inputs=inputs,
            check_trace=verify,
        )
        # Explicitly delete inputs and clear cache to free memory
        del inputs
        if device == "cuda" or device.startswith("cuda:"):
            torch.cuda.empty_cache()

    if show_graph:
        print(traced_model.graph)

    # Clear CUDA cache before saving to free up memory
    if device == "cuda" or device.startswith("cuda:"):
        torch.cuda.empty_cache()
        print("Cleared CUDA cache before saving")

    traced_model.save(output_file)
    print(f"Successfully exported TorchScript model: {output_file}")


def parse_args() -> argparse.Namespace:
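    """Parse command-line arguments for the checkpoint-to-TorchScript conversion."""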
    parser = argparse.ArgumentParser(
        description="Convert .pth checkpoint to TorchScript"
    )
    parser.add_argument("config", help="test config file path")
    parser.add_argument("--checkpoint", help="Checkpoint file")
    parser.add_argument(
        "--show-graph", action="store_true", help="show TorchScript graph"
    )
    parser.add_argument(
        "--verify", action="store_true", help="verify the TorchScript model"
    )
    parser.add_argument(
        "--output-file", type=str, default="tmp.pt", help="output TorchScript path"
    )
    parser.add_argument(
        "--shape",
        type=int,
        nargs="+",
        default=[1024, 768],
        help="input image size (height, width)",
    )
    parser.add_argument("--device", default="cuda:0", help="Device used for inference")
    args = parser.parse_args()
    return args


def main() -> None:
    args = parse_args()
    check_torch_version()

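    # Build an NCHW dummy-input shape from --shape: a single value H yields
    # (1, 3, H, H); two values yield (1, 3, H, W).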
    if len(args.shape) == 1:
        input_shape = (1, 3, args.shape[0], args.shape[0])
    elif len(args.shape) == 2:
        input_shape = (
            1,
            3,
        ) + tuple(args.shape)
    else:
        raise ValueError(
            "invalid --shape, expected one value (H) or two values (H W)"
        )

    # build the model, load checkpoint
    model = init_model(args.config, args.checkpoint, device=args.device)

    # create the output directory if it does not exist
    # (dirname is empty when the output path has no directory component)
    output_dir = os.path.dirname(args.output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # convert the PyTorch model to TorchScript model
    pytorch2torchscript(
        model,
        input_shape=input_shape,
        device=args.device,
        show_graph=args.show_graph,
        output_file=args.output_file,
        verify=args.verify,
    )


if __name__ == "__main__":
    main()