Rawal Khirodkar
Pin Python 3.10 + torch 2.1.2; vendor sapiens2 to bypass requires-python
5f5f544
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
import argparse
import ast
import json
from importlib import util as _iu
from pathlib import Path
from types import ModuleType, SimpleNamespace
class Config(SimpleNamespace):
"""Dot-dict wrapper around a Python-module config."""
DELETE_KEY = "_delete_"
def __init__(self, **kwargs):
super().__init__(**kwargs) # keep SimpleNamespace behaviour
object.__setattr__(
self,
"_cfg_dict", # use object.__setattr__ to avoid
json.loads(json.dumps(kwargs)),
) # recursion hook
@classmethod
def fromfile(cls, path: str | Path) -> "Config":
path = Path(path).expanduser()
spec = _iu.spec_from_file_location(path.stem, path)
mod = _iu.module_from_spec(spec)
spec.loader.exec_module(mod) # type: ignore
data = {
k: v
for k, v in vars(mod).items()
if not k.startswith("_") and not isinstance(v, ModuleType)
}
data["filename"] = str(path)
return cls(**data)
def merge_from_dict(
self, options: dict[str, object], allow_list_keys: bool = True
) -> None:
"""Merge a dict of dotted keys (or nested dict) into this Config."""
# Lazily ensure _cfg_dict exists if __init__ was bypassed
if not hasattr(self, "_cfg_dict"):
object.__setattr__(self, "_cfg_dict", self.to_dict())
# 1. expand dotted keys -> nested dict
nested_patch: dict = {}
for dotted_key, value in options.items():
node = nested_patch
parts = dotted_key.split(".")
for p in parts[:-1]:
node = node.setdefault(p, {})
node[parts[-1]] = value
# 2. deep-merge
merged = self._merge_a_into_b(
nested_patch, self._cfg_dict, allow_list_keys=allow_list_keys
)
object.__setattr__(self, "_cfg_dict", merged)
# 3. rebuild attributes so dot-access sees the new values
self.__dict__.update(merged)
# ------------------------------------------------------------------
@staticmethod
def _merge_a_into_b(a: dict, b: dict | list, *, allow_list_keys: bool = False):
"""Deep-merge *a* into *b* (returns a **new** structure)."""
from copy import deepcopy
# If the target is a list and we allow list keys ---------------
if allow_list_keys and isinstance(b, list):
b = deepcopy(b)
for k, v in a.items():
if not k.isdigit():
raise TypeError(f"Expected int-like key for list merge, got {k!r}")
idx = int(k)
if idx >= len(b):
raise IndexError(f"Index {idx} out of range for list.")
b[idx] = Config._merge_a_into_b(
v, b[idx], allow_list_keys=allow_list_keys
)
return b
# Otherwise we are merging dicts -------------------------------
b = deepcopy(b)
for k, v in a.items():
# Deletion flag (_delete_ = True) --------------------------
if isinstance(v, dict) and v.pop(Config.DELETE_KEY, False):
b[k] = Config._merge_a_into_b(v, {}, allow_list_keys)
continue
if (
k in b
and isinstance(b[k], (dict, list))
and isinstance(v, (dict, list))
):
b[k] = Config._merge_a_into_b(v, b[k], allow_list_keys=allow_list_keys)
else:
b[k] = deepcopy(v)
return b
def __getitem__(self, item):
return self.__dict__[item]
def __setitem__(self, key, value):
self.__dict__[key] = value
def get(self, key, default=None):
"""Get a value by key with an optional default if key doesn't exist."""
return self.__dict__.get(key, default)
def to_dict(self) -> dict:
def _rec(obj):
if isinstance(obj, Config):
return {k: _rec(v) for k, v in obj.__dict__.items()}
if isinstance(obj, dict):
return {k: _rec(v) for k, v in obj.items()}
if isinstance(obj, (list, tuple)):
return type(obj)(_rec(x) for x in obj)
return obj
return _rec(self)
# ---------------------------------------------------------------------------
class DictAction(argparse.Action):
def __call__(self, _parser, namespace, values, _option_string=None):
out: dict[str, object] = {}
for kv in values:
if "=" not in kv:
raise ValueError(f"--cfg-options expected key=value, got {kv}")
key, val = kv.split("=", 1)
for parser_fn in (json.loads, ast.literal_eval, str):
try:
val = parser_fn(val)
break
except Exception:
continue
cur = out
*parents, leaf = key.split(".")
for p in parents:
cur = cur.setdefault(p, {})
cur[leaf] = val
setattr(namespace, self.dest, out)
# ---------------------------------------------------------------------------
_INDENT = 4
def _indent_block(txt: str, n: int = _INDENT) -> str:
pad = " " * n
return "\n".join(pad + ln if ln else ln for ln in txt.splitlines())
def _format_basic(k, v, mapping: bool) -> str:
v_str = repr(v) if isinstance(v, str) else str(v)
if mapping:
k_str = repr(k) if isinstance(k, str) else str(k)
return f"{k_str}: {v_str}"
return f"{k}={v_str}"
def _format_list_tuple(obj, mapping_key=None, mapping=False):
left, right = ("[", "]") if isinstance(obj, list) else ("(", ")")
body = []
for item in obj:
if isinstance(item, dict):
body.append("dict(" + _indent_block(_format_dict(item), _INDENT) + "),")
elif isinstance(item, (list, tuple)):
body.append(_indent_block(_format_list_tuple(item), _INDENT) + ",")
else:
body.append(repr(item) + ",")
inner = "\n".join(body)
if mapping_key is None:
return _indent_block(left + "\n" + inner + "\n" + right, _INDENT)
if mapping:
k_str = repr(mapping_key) if isinstance(mapping_key, str) else str(mapping_key)
return f"{k_str}: {left}\n{_indent_block(inner, _INDENT)}\n{right}"
return f"{mapping_key}={left}\n{_indent_block(inner, _INDENT)}\n{right}"
def _contains_non_identifier(keys):
return any((not str(k).isidentifier()) for k in keys)
def _format_dict(d: dict, outer: bool = False) -> str:
lines = []
mapping = _contains_non_identifier(d)
if mapping and not outer:
lines.append("{")
for idx, (k, v) in enumerate(sorted(d.items(), key=lambda x: str(x[0]))):
is_last = idx == len(d) - 1
suffix = "" if outer or is_last else ","
if isinstance(v, dict):
inner = _format_dict(v)
if mapping:
k_str = repr(k) if isinstance(k, str) else str(k)
line = f"{k_str}: dict(\n{_indent_block(inner, _INDENT)}\n){suffix}"
else:
line = f"{k}=dict(\n{_indent_block(inner, _INDENT)}\n){suffix}"
elif isinstance(v, (list, tuple)):
line = _format_list_tuple(v, mapping_key=k, mapping=mapping) + suffix
else:
line = _format_basic(k, v, mapping) + suffix
lines.append(_indent_block(line, _INDENT))
if mapping and not outer:
lines.append("}")
return "\n".join(lines)
# -----------------------------------------------------------------
def pretty_text(cfg_dict: dict) -> str:
body = _format_dict(cfg_dict, outer=True)
import textwrap
body = textwrap.dedent(body)
try:
from yapf.yapflib.yapf_api import FormatCode
yapf_style = dict(
based_on_style="pep8",
blank_line_before_nested_class_or_def=True,
split_before_expression_after_opening_paren=True,
)
body, _ = FormatCode(body, style_config=yapf_style)
except ImportError:
pass
except Exception as exc: # keep raw body if yapf fails
print(f"[pretty_text] yapf failed: {exc}\nReturning un-formatted text.")
return body
def print_cfg(cfg_dict):
try:
from rich.console import Console
from rich.syntax import Syntax
console = Console()
code_str = pretty_text(cfg_dict)
syntax_block = Syntax(code_str, "python", theme="ansi_dark", word_wrap=False)
console.print(syntax_block, style="green")
return "---"
except ImportError:
from pprint import pformat
GREEN = "\033[92m"
RESET = "\033[0m"
print(GREEN + pformat(cfg_dict, sort_dicts=False) + RESET)
return "---"