File size: 1,976 Bytes
5f5f544
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from pathlib import Path
from typing import Optional, Union

import torch
from safetensors.torch import load_file
from sapiens.engine.config import Config
from sapiens.engine.datasets import Compose
from sapiens.registry import MODELS


def _extract_state_dict(checkpoint_data):
    """Return the weights mapping stored inside a ``torch.load`` result.

    Wrapped checkpoints are looked up under ``"state_dict"`` first and then
    ``"model"``. A dict with neither key is treated as a bare state dict,
    e.g. one written by ``torch.save(model.state_dict())``.

    Raises:
        TypeError: if the loaded object is not a dict and therefore cannot be
            a (wrapped) state dict.
    """
    if not isinstance(checkpoint_data, dict):
        raise TypeError(
            "Unsupported checkpoint content of type "
            f"{type(checkpoint_data).__name__}; expected a dict."
        )
    for key in ("state_dict", "model"):
        if key in checkpoint_data:
            return checkpoint_data[key]
    # Neither wrapper key present: assume the dict itself is the state dict.
    return checkpoint_data


def init_model(
    config: Union[str, Path],
    checkpoint: Optional[Union[str, Path]] = None,
    device: str = "cuda:0",
):
    """Build a model from a config file for inference, optionally loading weights.

    Args:
        config: Path to the config file, parsed with ``Config.fromfile``.
        checkpoint: Optional weights file. Paths ending in ``.safetensors`` are
            read with ``safetensors``; anything else goes through
            ``torch.load`` and may store weights under ``"state_dict"``,
            under ``"model"``, or as a bare state dict.
        device: Device the model is moved to (``model.to(device)``).

    Returns:
        The built model in eval mode on ``device``, with ``cfg`` (the parsed
        config), ``data_preprocessor`` (built from
        ``config.data_preprocessor``) and ``pipeline`` (a ``Compose`` of
        ``config.test_pipeline``) attached as attributes.

    Raises:
        TypeError: if ``config`` or ``checkpoint`` has an unsupported type, or
            a non-safetensors checkpoint does not contain a dict.
    """
    # Explicit checks instead of ``assert`` so validation survives ``python -O``.
    if not isinstance(config, (str, Path)):
        raise TypeError(
            f"config must be a str or Path, got {type(config).__name__}"
        )
    if checkpoint is not None and not isinstance(checkpoint, (str, Path)):
        raise TypeError(
            "checkpoint must be a str, Path or None, "
            f"got {type(checkpoint).__name__}"
        )

    config = Config.fromfile(config)

    ## avoid loading the pretrained backbone weights; configs without a
    ## backbone entry are left as-is instead of raising KeyError
    backbone_cfg = config.model["backbone"] if "backbone" in config.model else None
    if backbone_cfg is not None and "init_cfg" in backbone_cfg:
        backbone_cfg.pop("init_cfg")

    model = MODELS.build(config.model)
    data_preprocessor = MODELS.build(config.data_preprocessor)

    if checkpoint is not None:
        if str(checkpoint).endswith(".safetensors"):
            # str() keeps older safetensors releases (str-only paths) working.
            state_dict = load_file(str(checkpoint), device="cpu")
        else:  # .pth / .bin and other torch.save formats
            # SECURITY: weights_only=False unpickles arbitrary objects and can
            # execute code embedded in the file; only load trusted checkpoints.
            checkpoint_data = torch.load(
                checkpoint, map_location="cpu", weights_only=False
            )
            state_dict = _extract_state_dict(checkpoint_data)

        # strict=False: key mismatches are reported below rather than raised.
        incompat = model.load_state_dict(state_dict, strict=False)

        if incompat.missing_keys:
            print(f"Missing keys: {incompat.missing_keys}")

        if incompat.unexpected_keys:
            print(f"Unexpected keys: {incompat.unexpected_keys}")

        print(f"\033[96mModel loaded from {checkpoint}\033[0m")

    model.cfg = config
    model.data_preprocessor = data_preprocessor
    model.pipeline = Compose(config.test_pipeline)

    model.to(device)
    model.eval()

    return model