# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""PyTorch model optimization utilities for exporting and compiling models.
This module provides utilities for:
- Converting SyncBatchNorm layers to standard BatchNorm
- Benchmarking model performance
- Exporting models with torch.export
- Compiling models with torch.compile
"""
import argparse
from typing import Any
import numpy as np
import torch
from sapiens.dense.models import init_model
from torch import nn
# =============================================================================
# BatchNorm Conversion Utilities
# =============================================================================
class _BatchNormXd(nn.modules.batchnorm._BatchNorm):
"""A general BatchNorm layer without input dimension check.
Reproduced from @kapily's work:
(https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547)
    The only difference among BatchNorm1d, BatchNorm2d, BatchNorm3d, etc. is
    `_check_input_dim`, which performs tensor sanity checks. This class bypasses
    the check so that SyncBatchNorm layers can be converted conveniently.
"""
def _check_input_dim(self, input: torch.Tensor) -> None:
return
def revert_sync_batchnorm(module: nn.Module) -> nn.Module:
"""Convert all SyncBatchNorm layers in the model to BatchNormXd layers.
Adapted from @kapily's work:
(https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547)
Args:
module: The module containing `SyncBatchNorm` layers.
Returns:
The converted module with `BatchNormXd` layers.
"""
module_output = module
module_checklist = [torch.nn.modules.batchnorm.SyncBatchNorm]
if isinstance(module, tuple(module_checklist)):
module_output = _BatchNormXd(
module.num_features,
module.eps,
module.momentum,
module.affine,
module.track_running_stats,
)
if module.affine:
with torch.no_grad():
module_output.weight = module.weight
module_output.bias = module.bias
module_output.running_mean = module.running_mean
module_output.running_var = module.running_var
module_output.num_batches_tracked = module.num_batches_tracked
module_output.training = module.training
if hasattr(module, "qconfig"):
module_output.qconfig = module.qconfig
for name, child in module.named_children():
try:
module_output.add_module(name, revert_sync_batchnorm(child))
except Exception:
print(f"Failed to convert {child} from SyncBN to BN!")
del module
return module_output
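
# Usage sketch (mirrors main() below): revert SyncBatchNorm before exporting
# on CPU, where no distributed process group is available:
#
#   model = init_model(config_path, checkpoint_path, device="cpu")
#   model = revert_sync_batchnorm(model)
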
def convert_batchnorm(module: nn.Module) -> nn.Module:
"""Convert SyncBatchNorm to BatchNorm2d and optionally SiLU to ReLU.
Args:
module: The module to convert.
Returns:
The converted module.
"""
module_output = module
if isinstance(module, torch.nn.SyncBatchNorm):
module_output = torch.nn.BatchNorm2d(
module.num_features,
module.eps,
module.momentum,
module.affine,
module.track_running_stats,
)
if module.affine:
module_output.weight.data = module.weight.data.clone().detach()
module_output.bias.data = module.bias.data.clone().detach()
module_output.weight.requires_grad = module.weight.requires_grad
module_output.bias.requires_grad = module.bias.requires_grad
module_output.running_mean = module.running_mean
module_output.running_var = module.running_var
module_output.num_batches_tracked = module.num_batches_tracked
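    # Note: the SiLU -> ReLU swap below changes the activation function, so the
    # converted model's outputs will differ; presumably intended for deployment
    # targets that cannot run SiLU.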
if isinstance(module, torch.nn.SiLU):
module_output = torch.nn.ReLU(inplace=True)
for name, child in module.named_children():
module_output.add_module(name, convert_batchnorm(child))
del module
return module_output
# =============================================================================
# Benchmarking Utilities
# =============================================================================
def benchmark_model(
model: nn.Module,
inputs: dict[str, Any],
model_name: str = "",
num_warmup: int = 3,
num_iterations: int = 10,
) -> float:
"""Benchmark model inference time.
Args:
model: The model to benchmark.
inputs: Dictionary containing 'imgs' tensor.
model_name: Name for logging purposes.
num_warmup: Number of warmup iterations (not counted).
num_iterations: Number of timed iterations.
Returns:
Mean inference time per sample in milliseconds.
"""
imgs = (
inputs["imgs"][0, ...].unsqueeze(0)
if model_name.lower() == "original"
else inputs["imgs"]
)
if not torch.cuda.is_available():
raise RuntimeError("CUDA is required for benchmarking")
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
times = []
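    # Run on a dedicated side stream so work already queued on the default
    # stream does not skew the measurement; CUDA events time the GPU directly.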
stream = torch.cuda.Stream()
stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(stream), torch.no_grad():
# Warmup
for _ in range(num_warmup):
model(imgs)
torch.cuda.synchronize()
# Timed iterations
for _ in range(num_iterations):
torch.cuda.synchronize()
start_event.record()
model(imgs)
end_event.record()
torch.cuda.synchronize()
times.append(start_event.elapsed_time(end_event))
torch.cuda.current_stream().wait_stream(stream)
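    # Normalize per-iteration time by batch size to report per-sample latency.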
mean_time = np.mean(times) / imgs.shape[0]
print(f"Benchmark results for '{model_name}':")
print(f" Average time per sample: {mean_time:.2f} ms")
print(f" Total time ({num_iterations} iterations): {sum(times):.2f} ms")
print(f" Individual times: {[f'{t:.2f}' for t in times]}")
return mean_time
# =============================================================================
# Input Generation
# =============================================================================
def create_demo_inputs(input_shape: tuple[int, int, int, int]) -> dict[str, Any]:
"""Create demo inputs for model testing and export.
Args:
input_shape: Tuple of (N, C, H, W) for input dimensions.
Returns:
Dictionary with 'imgs' tensor and 'img_metas' list.
"""
n, c, h, w = input_shape
rng = np.random.RandomState(0)
imgs = rng.rand(*input_shape)
img_metas = [
{
"img_shape": (h, w, c),
"ori_shape": (h, w, c),
"pad_shape": (h, w, c),
"filename": "<demo>.png",
"scale_factor": 1.0,
"flip": False,
}
for _ in range(n)
]
return {
"imgs": torch.FloatTensor(imgs),
"img_metas": img_metas,
}
# =============================================================================
# Model Export and Compilation
# =============================================================================
class _ToDeviceTransformer(torch.fx.Transformer):
"""FX Transformer to move operations to a specific device."""
def __init__(self, module: nn.Module, device: str):
super().__init__(module)
self.target_device = torch.device(device)
def call_function(self, target, args, kwargs):
if "device" not in kwargs:
return super().call_function(target, args, kwargs)
kwargs = dict(kwargs)
kwargs["device"] = self.target_device
return super().call_function(target, args, kwargs)
def compile_and_export_model(
model: nn.Module,
inputs: dict[str, Any],
output_file: str = "compiled_model.pt",
max_batch_size: int = 32,
dtype: torch.dtype = torch.bfloat16,
) -> None:
"""Export model using torch.export and optionally compile with torch.compile.
Args:
model: The model to export.
inputs: Demo inputs for tracing.
output_file: Path to save the exported model.
max_batch_size: Maximum batch size for dynamic shapes.
dtype: Data type for the model.
"""
inputs["imgs"] = inputs["imgs"].to(dtype)
imgs = inputs["imgs"]
model.eval()
# Define dynamic shapes
dynamic_batch = torch.export.Dim("batch", min=1, max=max_batch_size)
dynamic_h = torch.export.Dim("h", min=1024, max=2048)
dynamic_w = torch.export.Dim("w", min=768, max=1536)
dynamic_shapes = {"inputs": {0: dynamic_batch, 2: dynamic_h, 3: dynamic_w}}
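    # Note: the top-level key ("inputs") must match the name of the tensor
    # argument in the model's forward(); dims 0/2/3 are batch/height/width.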
# Export model
exported_model = torch.export.export(
model,
args=(imgs,),
kwargs={},
dynamic_shapes=dynamic_shapes,
)
torch.export.save(exported_model, output_file)
print(f"Model exported to: {output_file}")
if not torch.cuda.is_available():
return
# Compile and benchmark
device = "cuda:0"
model = torch.export.load(output_file).module().to(device)
model = _ToDeviceTransformer(model, device).transform()
imgs = imgs.to(device)
inputs["imgs"] = inputs["imgs"].to(device)
_compile_and_benchmark(model, imgs, inputs)
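
# Usage sketch (shapes and paths are illustrative):
#
#   demo = create_demo_inputs((1, 3, 1024, 768))
#   compile_and_export_model(model, demo, output_file="model.pt2")
#
# The saved program can be reloaded with torch.export.load("model.pt2").module().
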
def _compile_and_benchmark(
model: nn.Module,
imgs: torch.Tensor,
inputs: dict[str, Any],
) -> None:
"""Compile model and benchmark different compilation modes.
Args:
model: Model to compile.
imgs: Input images tensor.
inputs: Full inputs dictionary for benchmarking.
"""
modes = {"default": "default"}
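    # Other torch.compile modes (e.g. "reduce-overhead", "max-autotune") can be
    # added here for comparison; only "default" is tried out of the box.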
best_mode = None
min_mean = float("inf")
for mode_name, mode in modes.items():
print(f"Compiling model with '{mode_name}' mode...")
compiled_model = torch.compile(model, mode=mode)
# Warmup
stream = torch.cuda.Stream()
stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(stream), torch.no_grad():
for _ in range(3):
compiled_model(imgs)
torch.cuda.synchronize()
torch.cuda.current_stream().wait_stream(stream)
mean_time = benchmark_model(compiled_model, inputs, model_name=mode_name)
if mean_time < min_mean:
min_mean = mean_time
best_mode = mode_name
print(f"Best compilation mode: {best_mode}")
# =============================================================================
# CLI Interface
# =============================================================================
def parse_args() -> argparse.Namespace:
"""Parse command line arguments.
Returns:
Parsed arguments.
"""
parser = argparse.ArgumentParser(
description="Export and optimize a model for deployment"
)
parser.add_argument("config", help="Model config file path")
parser.add_argument("--checkpoint", help="Checkpoint file path")
parser.add_argument(
"--shape",
type=int,
nargs="+",
default=[1024, 768],
help="Input image size as (height, width)",
)
parser.add_argument(
"--output-file",
"--output-dir",
type=str,
required=True,
help="Output file path for exported model",
)
parser.add_argument(
"--max-batch-size",
type=int,
default=32,
help="Maximum batch size for dynamic export",
)
parser.add_argument(
"--fp16",
action="store_true",
help="Use fp16 instead of bfloat16",
)
return parser.parse_args()
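
# Example invocation (file paths are illustrative):
#
#   python torch_optimization.py configs/model_config.py \
#       --checkpoint checkpoint.pth --shape 1024 768 --output-file model.pt2
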
def main() -> None:
"""Main entry point for model optimization CLI."""
args = parse_args()
# Determine input shape
if len(args.shape) == 1:
input_shape = (16, 3, args.shape[0], args.shape[0])
elif len(args.shape) == 2:
input_shape = (16, 3, args.shape[0], args.shape[1])
else:
raise ValueError("Shape must be 1 or 2 integers (height, width)")
# Clamp batch size
max_batch_size = args.max_batch_size
input_shape = (max(1, min(input_shape[0], max_batch_size)), *input_shape[1:])
# Initialize model
model = init_model(args.config, args.checkpoint, device="cpu")
model.eval()
model = revert_sync_batchnorm(model)
# Create demo inputs
demo_inputs = create_demo_inputs(input_shape)
# Set dtype
dtype = torch.half if args.fp16 else torch.bfloat16
model.to(dtype)
demo_inputs["imgs"] = demo_inputs["imgs"].to(dtype)
# Export and compile
compile_and_export_model(
model,
demo_inputs,
output_file=args.output_file,
max_batch_size=max_batch_size,
dtype=dtype,
)
if __name__ == "__main__":
main()