Spaces:
Running on Zero
Running on Zero
File size: 2,004 Bytes
ba23d94 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 | # Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import os
import torch
# pyre-ignore[21]: Cannot find module `sapiens.engine.config`
from sapiens.engine.config import Config, DictAction
from sapiens.engine.runners import *
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse command-line arguments for testing (evaluating) a checkpoint.

    Args:
        argv: Argument list to parse. ``None`` (the default) lets argparse
            read ``sys.argv[1:]``.

    Returns:
        Namespace with ``config``, ``checkpoint``, ``work_dir``,
        ``cfg_options`` and ``local_rank`` attributes.

    Side effects:
        Sets the ``LOCAL_RANK`` environment variable from ``--local-rank``
        if it is not already defined.
    """
    parser = argparse.ArgumentParser(description="Test (and evaluate) a model")
    parser.add_argument("config", help="test config file path")
    parser.add_argument("checkpoint", help="checkpoint file")
    parser.add_argument("--work-dir", help="the dir to save logs and models")
    parser.add_argument(
        "--cfg-options",
        nargs="+",
        action=DictAction,
        help="override some settings in the used config, the key-value pair "
        "in xxx=yyy format will be merged into config file. If the value to "
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b. '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]". '
        "Note that the quotation marks are necessary and that no white space "
        "is allowed.",
    )
    # Both flag spellings are accepted (presumably so that different
    # distributed-launcher versions work); they map to ``args.local_rank``.
    parser.add_argument("--local_rank", "--local-rank", type=int, default=0)
    args = parser.parse_args(argv)
    # Export the rank to the environment without overriding a value that a
    # launcher has already set.
    os.environ.setdefault("LOCAL_RANK", str(args.local_rank))
    return args
def main(argv: list[str] | None = None) -> None:
    """Run the test loop of the runner configured by a config file.

    Loads the config, applies ``--cfg-options`` overrides, points the runner
    at ``checkpoint`` via ``cfg.load_from``, disables the training
    dataloader, then builds the runner and calls ``runner.test()``.

    Args:
        argv: Command-line arguments; ``None`` makes argparse read
            ``sys.argv[1:]``.
    """
    args = parse_args(argv)

    # load config and apply any --cfg-options overrides
    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    # An explicit --work-dir wins. Otherwise keep the work_dir from the
    # config file / --cfg-options instead of overwriting it with None
    # (None is still used when no work_dir is defined anywhere).
    if args.work_dir is not None:
        cfg.work_dir = args.work_dir
    else:
        cfg.work_dir = cfg.get("work_dir", None)
    cfg.load_from = args.checkpoint

    # testing only: drop the training dataloader
    cfg.train_dataloader = None

    # start testing
    # NOTE(security): eval() executes the ``runner_type`` string from the
    # config (also overridable via --cfg-options). It is meant to resolve a
    # class name brought in by ``from sapiens.engine.runners import *``;
    # only use trusted configs.
    runner_type = cfg.get("runner_type", "BaseRunner")
    runner = eval(runner_type).from_cfg(cfg)
    runner.test()


if __name__ == "__main__":
    main()
|