# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import os
import torch
import torchvision
# pyre-ignore[21]: Cannot find module `sapiens.engine.config`
from sapiens.engine.config import Config, DictAction
from sapiens.engine.runners import *
from sapiens.pose.runners import *
# Per the PyTorch docs, "high" lets float32 matmuls use faster reduced-precision
# internal math (e.g. TF32) on GPUs that support it, such as A100s.
torch.set_float32_matmul_precision("high") # A100 gpus
# Silence torchvision's warning about its beta transforms API.
torchvision.disable_beta_transforms_warning() # Disable the beta transforms warning
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse command-line arguments for the training entry point.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``.

    Returns:
        The parsed namespace with ``config``, ``work_dir``, ``resume``,
        ``cfg_options`` and ``local_rank`` attributes.

    Side effects:
        Sets ``os.environ["LOCAL_RANK"]`` to the ``--local-rank`` value when
        the variable is not already present in the environment.
    """
    parser = argparse.ArgumentParser(
        description="Train a model from a config file"
    )
    parser.add_argument("config", help="train config file path")
    parser.add_argument("--work-dir", help="the dir to save logs and models")
    # nargs="?" + const="auto": omitted -> None, bare "--resume" -> "auto",
    # "--resume PATH" -> PATH.
    parser.add_argument(
        "--resume",
        nargs="?",
        type=str,
        const="auto",
        help="If a checkpoint path is specified, resume from it; if the flag "
        "is given without a value, try to auto-resume from the latest "
        "checkpoint in the work directory.",
    )
    parser.add_argument(
        "--cfg-options",
        nargs="+",
        action=DictAction,
        help="override some settings in the used config, the key-value pair "
        "in xxx=yyy format will be merged into config file. If the value to "
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b. '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]". '
        "Note that the quotation marks are necessary and that no white space "
        "is allowed.",
    )
    # Both spellings are accepted, presumably because different distributed
    # launchers pass one or the other.
    parser.add_argument("--local_rank", "--local-rank", type=int, default=0)
    args = parser.parse_args(argv)
    # Keep a LOCAL_RANK already exported by the launcher; otherwise publish
    # the CLI value so downstream code reading the environment sees a rank.
    if "LOCAL_RANK" not in os.environ:
        os.environ["LOCAL_RANK"] = str(args.local_rank)
    return args
def main(argv: list[str] | None = None) -> None:
    """Load a training config, apply CLI overrides, and run training.

    Args:
        argv: Command-line arguments forwarded to ``parse_args``; ``None``
            means ``sys.argv[1:]``.

    Raises:
        NameError: If the config's ``runner_type`` does not name an object
            available in this module's namespace.
    """
    args = parse_args(argv)

    # Load the config and merge any --cfg-options overrides on top of it.
    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    # --work-dir takes priority over the config file. Only fall back to the
    # config's own value when the flag is omitted; assigning args.work_dir
    # unconditionally would clobber a config-defined work_dir with None.
    cfg.work_dir = (
        args.work_dir
        if args.work_dir is not None
        else cfg.get("work_dir", None)
    )

    # "--resume PATH" resumes from PATH; bare "--resume" passes "auto"
    # through to the runner as load_from.
    if args.resume is not None:
        cfg.resume = True
        cfg.load_from = args.resume

    # runner_type can be set from the command line via --cfg-options, so
    # resolve it as a plain name in this module's namespace (populated by the
    # runner star-imports) rather than eval()-ing an arbitrary expression.
    runner_type = cfg.get("runner_type", "BaseRunner")
    runner_cls = globals().get(runner_type)
    if runner_cls is None:
        raise NameError(
            f"Unknown runner_type {runner_type!r}: no runner class with "
            "that name is imported by this script"
        )
    runner = runner_cls.from_cfg(cfg)
    runner.train()
if __name__ == "__main__":
    # Entry point: only start training when run as a script, not on import.
    # Passing None explicitly makes argparse read sys.argv[1:].
    main(None)
|