Spaces:
Running on Zero
Running on Zero
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| from __future__ import annotations | |
| import argparse | |
| import ast | |
| import json | |
| from importlib import util as _iu | |
| from pathlib import Path | |
| from types import ModuleType, SimpleNamespace | |
| class Config(SimpleNamespace): | |
| """Dot-dict wrapper around a Python-module config.""" | |
| DELETE_KEY = "_delete_" | |
| def __init__(self, **kwargs): | |
| super().__init__(**kwargs) # keep SimpleNamespace behaviour | |
| object.__setattr__( | |
| self, | |
| "_cfg_dict", # use object.__setattr__ to avoid | |
| json.loads(json.dumps(kwargs)), | |
| ) # recursion hook | |
| def fromfile(cls, path: str | Path) -> "Config": | |
| path = Path(path).expanduser() | |
| spec = _iu.spec_from_file_location(path.stem, path) | |
| mod = _iu.module_from_spec(spec) | |
| spec.loader.exec_module(mod) # type: ignore | |
| data = { | |
| k: v | |
| for k, v in vars(mod).items() | |
| if not k.startswith("_") and not isinstance(v, ModuleType) | |
| } | |
| data["filename"] = str(path) | |
| return cls(**data) | |
| def merge_from_dict( | |
| self, options: dict[str, object], allow_list_keys: bool = True | |
| ) -> None: | |
| """Merge a dict of dotted keys (or nested dict) into this Config.""" | |
| # Lazily ensure _cfg_dict exists if __init__ was bypassed | |
| if not hasattr(self, "_cfg_dict"): | |
| object.__setattr__(self, "_cfg_dict", self.to_dict()) | |
| # 1. expand dotted keys -> nested dict | |
| nested_patch: dict = {} | |
| for dotted_key, value in options.items(): | |
| node = nested_patch | |
| parts = dotted_key.split(".") | |
| for p in parts[:-1]: | |
| node = node.setdefault(p, {}) | |
| node[parts[-1]] = value | |
| # 2. deep-merge | |
| merged = self._merge_a_into_b( | |
| nested_patch, self._cfg_dict, allow_list_keys=allow_list_keys | |
| ) | |
| object.__setattr__(self, "_cfg_dict", merged) | |
| # 3. rebuild attributes so dot-access sees the new values | |
| self.__dict__.update(merged) | |
| # ------------------------------------------------------------------ | |
| def _merge_a_into_b(a: dict, b: dict | list, *, allow_list_keys: bool = False): | |
| """Deep-merge *a* into *b* (returns a **new** structure).""" | |
| from copy import deepcopy | |
| # If the target is a list and we allow list keys --------------- | |
| if allow_list_keys and isinstance(b, list): | |
| b = deepcopy(b) | |
| for k, v in a.items(): | |
| if not k.isdigit(): | |
| raise TypeError(f"Expected int-like key for list merge, got {k!r}") | |
| idx = int(k) | |
| if idx >= len(b): | |
| raise IndexError(f"Index {idx} out of range for list.") | |
| b[idx] = Config._merge_a_into_b( | |
| v, b[idx], allow_list_keys=allow_list_keys | |
| ) | |
| return b | |
| # Otherwise we are merging dicts ------------------------------- | |
| b = deepcopy(b) | |
| for k, v in a.items(): | |
| # Deletion flag (_delete_ = True) -------------------------- | |
| if isinstance(v, dict) and v.pop(Config.DELETE_KEY, False): | |
| b[k] = Config._merge_a_into_b(v, {}, allow_list_keys) | |
| continue | |
| if ( | |
| k in b | |
| and isinstance(b[k], (dict, list)) | |
| and isinstance(v, (dict, list)) | |
| ): | |
| b[k] = Config._merge_a_into_b(v, b[k], allow_list_keys=allow_list_keys) | |
| else: | |
| b[k] = deepcopy(v) | |
| return b | |
| def __getitem__(self, item): | |
| return self.__dict__[item] | |
| def __setitem__(self, key, value): | |
| self.__dict__[key] = value | |
| def get(self, key, default=None): | |
| """Get a value by key with an optional default if key doesn't exist.""" | |
| return self.__dict__.get(key, default) | |
| def to_dict(self) -> dict: | |
| def _rec(obj): | |
| if isinstance(obj, Config): | |
| return {k: _rec(v) for k, v in obj.__dict__.items()} | |
| if isinstance(obj, dict): | |
| return {k: _rec(v) for k, v in obj.items()} | |
| if isinstance(obj, (list, tuple)): | |
| return type(obj)(_rec(x) for x in obj) | |
| return obj | |
| return _rec(self) | |
| # --------------------------------------------------------------------------- | |
| class DictAction(argparse.Action): | |
| def __call__(self, _parser, namespace, values, _option_string=None): | |
| out: dict[str, object] = {} | |
| for kv in values: | |
| if "=" not in kv: | |
| raise ValueError(f"--cfg-options expected key=value, got {kv}") | |
| key, val = kv.split("=", 1) | |
| for parser_fn in (json.loads, ast.literal_eval, str): | |
| try: | |
| val = parser_fn(val) | |
| break | |
| except Exception: | |
| continue | |
| cur = out | |
| *parents, leaf = key.split(".") | |
| for p in parents: | |
| cur = cur.setdefault(p, {}) | |
| cur[leaf] = val | |
| setattr(namespace, self.dest, out) | |
| # --------------------------------------------------------------------------- | |
| _INDENT = 4 | |
| def _indent_block(txt: str, n: int = _INDENT) -> str: | |
| pad = " " * n | |
| return "\n".join(pad + ln if ln else ln for ln in txt.splitlines()) | |
| def _format_basic(k, v, mapping: bool) -> str: | |
| v_str = repr(v) if isinstance(v, str) else str(v) | |
| if mapping: | |
| k_str = repr(k) if isinstance(k, str) else str(k) | |
| return f"{k_str}: {v_str}" | |
| return f"{k}={v_str}" | |
| def _format_list_tuple(obj, mapping_key=None, mapping=False): | |
| left, right = ("[", "]") if isinstance(obj, list) else ("(", ")") | |
| body = [] | |
| for item in obj: | |
| if isinstance(item, dict): | |
| body.append("dict(" + _indent_block(_format_dict(item), _INDENT) + "),") | |
| elif isinstance(item, (list, tuple)): | |
| body.append(_indent_block(_format_list_tuple(item), _INDENT) + ",") | |
| else: | |
| body.append(repr(item) + ",") | |
| inner = "\n".join(body) | |
| if mapping_key is None: | |
| return _indent_block(left + "\n" + inner + "\n" + right, _INDENT) | |
| if mapping: | |
| k_str = repr(mapping_key) if isinstance(mapping_key, str) else str(mapping_key) | |
| return f"{k_str}: {left}\n{_indent_block(inner, _INDENT)}\n{right}" | |
| return f"{mapping_key}={left}\n{_indent_block(inner, _INDENT)}\n{right}" | |
| def _contains_non_identifier(keys): | |
| return any((not str(k).isidentifier()) for k in keys) | |
| def _format_dict(d: dict, outer: bool = False) -> str: | |
| lines = [] | |
| mapping = _contains_non_identifier(d) | |
| if mapping and not outer: | |
| lines.append("{") | |
| for idx, (k, v) in enumerate(sorted(d.items(), key=lambda x: str(x[0]))): | |
| is_last = idx == len(d) - 1 | |
| suffix = "" if outer or is_last else "," | |
| if isinstance(v, dict): | |
| inner = _format_dict(v) | |
| if mapping: | |
| k_str = repr(k) if isinstance(k, str) else str(k) | |
| line = f"{k_str}: dict(\n{_indent_block(inner, _INDENT)}\n){suffix}" | |
| else: | |
| line = f"{k}=dict(\n{_indent_block(inner, _INDENT)}\n){suffix}" | |
| elif isinstance(v, (list, tuple)): | |
| line = _format_list_tuple(v, mapping_key=k, mapping=mapping) + suffix | |
| else: | |
| line = _format_basic(k, v, mapping) + suffix | |
| lines.append(_indent_block(line, _INDENT)) | |
| if mapping and not outer: | |
| lines.append("}") | |
| return "\n".join(lines) | |
| # ----------------------------------------------------------------- | |
| def pretty_text(cfg_dict: dict) -> str: | |
| body = _format_dict(cfg_dict, outer=True) | |
| import textwrap | |
| body = textwrap.dedent(body) | |
| try: | |
| from yapf.yapflib.yapf_api import FormatCode | |
| yapf_style = dict( | |
| based_on_style="pep8", | |
| blank_line_before_nested_class_or_def=True, | |
| split_before_expression_after_opening_paren=True, | |
| ) | |
| body, _ = FormatCode(body, style_config=yapf_style) | |
| except ImportError: | |
| pass | |
| except Exception as exc: # keep raw body if yapf fails | |
| print(f"[pretty_text] yapf failed: {exc}\nReturning un-formatted text.") | |
| return body | |
| def print_cfg(cfg_dict): | |
| try: | |
| from rich.console import Console | |
| from rich.syntax import Syntax | |
| console = Console() | |
| code_str = pretty_text(cfg_dict) | |
| syntax_block = Syntax(code_str, "python", theme="ansi_dark", word_wrap=False) | |
| console.print(syntax_block, style="green") | |
| return "---" | |
| except ImportError: | |
| from pprint import pformat | |
| GREEN = "\033[92m" | |
| RESET = "\033[0m" | |
| print(GREEN + pformat(cfg_dict, sort_dicts=False) + RESET) | |
| return "---" | |