sapiens2-pose / sapiens /pose /src /models /init_model.py
Rawal Khirodkar
Pin Python 3.10 + torch 2.1.2; vendor sapiens2 to bypass requires-python
5f5f544
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from pathlib import Path
from typing import Optional, Union
import torch
from safetensors.torch import load_file
from sapiens.engine.config import Config
from sapiens.engine.datasets import Compose
from sapiens.registry import MODELS
def init_model(
config: Union[str, Path],
checkpoint: Optional[Union[str, Path]] = None,
device: str = "cuda:0",
):
assert isinstance(config, (str, Path))
assert checkpoint is None or isinstance(checkpoint, (str, Path))
config = Config.fromfile(config)
## avoid loading the pretrained backbone weights
if "init_cfg" in config.model["backbone"]:
config.model["backbone"].pop("init_cfg")
model = MODELS.build(config.model)
data_preprocessor = MODELS.build(config.data_preprocessor)
if checkpoint is not None:
if str(checkpoint).endswith(".safetensors"):
state_dict = load_file(checkpoint, device="cpu")
else: # Handle .pth and .bin files
checkpoint_data = torch.load(
checkpoint, map_location="cpu", weights_only=False
)
state_dict = (
checkpoint_data["state_dict"]
if "state_dict" in checkpoint_data
else checkpoint_data["model"]
)
incompat = model.load_state_dict(state_dict, strict=False)
if incompat.missing_keys:
print(f"Missing keys: {incompat.missing_keys}")
if incompat.unexpected_keys:
print(f"Unexpected keys: {incompat.unexpected_keys}")
print(f"\033[96mModel loaded from {checkpoint}\033[0m")
model.cfg = config
model.data_preprocessor = data_preprocessor
model.pipeline = Compose(config.test_pipeline)
model.to(device)
model.eval()
return model