| from dataclasses import dataclass |
|
|
| import __main__ |
|
|
| import os |
| import ujson |
| from huggingface_hub import hf_hub_download |
| import dataclasses |
| import datetime |
| from typing import Any |
| from dataclasses import dataclass, fields |
| import socket |
| import git |
| import time |
| import torch |
| import sys |
|
|
def torch_load_dnn(path):
    """Load a (deprecated) .dnn checkpoint onto CPU.

    Args:
        path: Local filesystem path, or an http(s) URL to download from.

    Returns:
        The deserialized checkpoint object (typically a dict holding
        weights and an "arguments" entry).
    """
    # startswith accepts a tuple of prefixes — one call instead of an `or` chain.
    if path.startswith(("http:", "https:")):
        dnn = torch.hub.load_state_dict_from_url(path, map_location='cpu')
    else:
        dnn = torch.load(path, map_location='cpu')

    return dnn
|
|
class dotdict(dict):
    """
    dot.notation access to dictionary attributes
    Credit: derek73 @ https://stackoverflow.com/questions/2352181
    """

    def __getattr__(self, key):
        # Fix: the previous `__getattr__ = dict.__getitem__` raised KeyError
        # for missing attributes. Attribute lookup must raise AttributeError
        # so that hasattr(), getattr(..., default), and copy/pickle (which
        # probe for dunders like __deepcopy__) behave correctly.
        try:
            return self[key]
        except KeyError:
            raise AttributeError(key) from None

    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__
|
|
def get_metadata_only():
    """Collect lightweight run metadata for provenance records.

    Returns:
        dotdict with keys: hostname, current_datetime, cmd, and — when
        executed inside a git checkout — git_branch, git_hash, and
        git_commit_datetime.
    """
    args = dotdict()

    args.hostname = socket.gethostname()
    try:
        # Construct the Repo object once instead of three times.
        repo = git.Repo(search_parent_directories=True)
        args.git_branch = repo.active_branch.name
        args.git_hash = repo.head.object.hexsha
        args.git_commit_datetime = str(repo.head.object.committed_datetime)
    except git.exc.InvalidGitRepositoryError:
        # Not running inside a git repository; skip the git_* fields.
        pass
    args.current_datetime = time.strftime('%b %d, %Y ; %l:%M%p %Z (%z)')
    args.cmd = ' '.join(sys.argv)

    return args
|
|
def timestamp(daydir=False):
    """Return the current local time as a formatted timestamp string.

    Args:
        daydir: When True, use 'YYYY-MM/DD/HH.MM.SS' (slash-separated, so the
            result can serve as nested day directories); otherwise use
            'YYYY-MM-DD_HH.MM.SS'.
    """
    if daydir:
        fmt = "%Y-%m/%d/%H.%M.%S"
    else:
        fmt = "%Y-%m-%d_%H.%M.%S"
    return datetime.datetime.now().strftime(fmt)
|
|
@dataclass
class DefaultVal:
    """Sentinel wrapper marking a config value as a soft default.

    CoreConfig.__post_init__ unwraps these and uses them to distinguish
    explicitly-assigned fields from defaulted ones.
    """
    val: Any

    def __hash__(self):
        # Hash the repr so unhashable payloads (None, lists, dicts) still work.
        return hash(repr(self.val))

    def __eq__(self, other):
        # Fix: the original body lacked `return`, so every comparison
        # evaluated to None (always falsy).
        return self.val == other.val
|
|
@dataclass
class RunSettings:
    """
    The defaults here have a special status in Run(), which initially calls assign_defaults(),
    so these aren't soft defaults in that specific context.
    """

    # Whether an existing run/experiment output may be overwritten.
    overwrite: bool = DefaultVal(False)

    # Root directory under which all experiment artifacts are stored.
    root: str = DefaultVal(os.path.join(os.getcwd(), 'experiments'))
    experiment: str = DefaultVal('default')

    # Optional override for where indexes live; see index_root_ below.
    index_root: str = DefaultVal(None)
    # Run name; defaults to a day-directory timestamp ('YYYY-MM/DD/HH.MM.SS').
    name: str = DefaultVal(timestamp(daydir=True))

    # Distributed settings: this process's rank and total number of ranks.
    rank: int = DefaultVal(0)
    nranks: int = DefaultVal(1)
    # Automatic mixed precision on/off.
    amp: bool = DefaultVal(True)

    # NOTE: intentionally unannotated, so this stays a plain class attribute
    # rather than becoming a dataclass field. Evaluated once at import time.
    total_visible_gpus = torch.cuda.device_count()
    gpus: int = DefaultVal(total_visible_gpus)

    avoid_fork_if_possible: bool = DefaultVal(False)

    @property
    def gpus_(self):
        # Normalize `gpus` into a sorted list of distinct device indices,
        # validated against the number of visible GPUs.
        value = self.gpus

        # An int N means "use devices 0..N-1".
        if isinstance(value, int):
            value = list(range(value))

        # A string like "0,2,3" selects those explicit device indices.
        if isinstance(value, str):
            value = value.split(',')

        value = list(map(int, value))
        value = sorted(list(set(value)))

        assert all(device_idx in range(0, self.total_visible_gpus) for device_idx in value), value

        return value

    @property
    def index_root_(self):
        # Fall back to <root>/<experiment>/indexes/ when index_root is unset.
        return self.index_root or os.path.join(self.root, self.experiment, 'indexes/')

    @property
    def script_name_(self):
        # Dotted, module-style name of the entry-point script (e.g. 'utility.train'),
        # or 'none' when __main__ has no __file__ (e.g. interactive sessions).
        if '__file__' in dir(__main__):
            cwd = os.path.abspath(os.getcwd())
            script_path = os.path.abspath(__main__.__file__)
            root_path = os.path.abspath(self.root)

            # Prefer a path relative to the cwd; otherwise strip the common
            # prefix shared with the experiment root.
            if script_path.startswith(cwd):
                script_path = script_path[len(cwd):]

            else:
                try:
                    commonpath = os.path.commonpath([script_path, root_path])
                    script_path = script_path[len(commonpath):]
                except:
                    pass

            assert script_path.endswith('.py')
            # '/a/b/c.py' -> '.a.b.c.py' -> strip leading dot -> drop '.py'.
            script_name = script_path.replace('/', '.').strip('.')[:-3]

            assert len(script_name) > 0, (script_name, script_path, cwd)

            return script_name

        return 'none'

    @property
    def path_(self):
        # Canonical run directory: <root>/<experiment>/<script_name>/<name>.
        return os.path.join(self.root, self.experiment, self.script_name_, self.name)

    @property
    def device_(self):
        # Device index assigned to this rank. NOTE(review): indexes gpus_ by
        # rank % nranks, not rank % len(gpus_) — verify behavior when
        # nranks differs from the number of selected GPUs.
        return self.gpus_[self.rank % self.nranks]
|
|
|
|
@dataclass
class TokenizerSettings:
    # Special marker tokens: the *_token_id entries are BERT-style [unusedN]
    # vocabulary placeholders; the *_token entries are the human-readable
    # forms ([Q] for queries, [D] for documents).
    query_token_id: str = DefaultVal("[unused0]")
    doc_token_id: str = DefaultVal("[unused1]")
    query_token: str = DefaultVal("[Q]")
    doc_token: str = DefaultVal("[D]")
|
|
|
|
@dataclass
class ResourceSettings:
    # Paths/identifiers for external resources. All default to None (unset);
    # consumers elsewhere decide what each one refers to.
    checkpoint: str = DefaultVal(None)
    triples: str = DefaultVal(None)
    collection: str = DefaultVal(None)
    queries: str = DefaultVal(None)
    index_name: str = DefaultVal(None)
    name_or_path: str = DefaultVal(None)
|
|
|
|
@dataclass
class DocSettings:
    # Embedding dimensionality.
    dim: int = DefaultVal(128)
    # Maximum document length (presumably in tokens — confirm in the encoder).
    doc_maxlen: int = DefaultVal(220)
    # NOTE(review): presumably masks punctuation tokens during doc encoding.
    mask_punctuation: bool = DefaultVal(True)
|
|
|
|
@dataclass
class QuerySettings:
    # Maximum query length (presumably in tokens — confirm in the encoder).
    query_maxlen: int = DefaultVal(32)
    attend_to_mask_tokens : bool = DefaultVal(False)
    # Scoring interaction mode; 'colbert' is the default.
    interaction: str = DefaultVal('colbert')
|
|
|
|
@dataclass
class TrainingSettings:
    # Similarity function name used for scoring.
    similarity: str = DefaultVal('cosine')

    # Training batch size.
    bsize: int = DefaultVal(32)

    # Gradient accumulation steps per optimizer update.
    accumsteps: int = DefaultVal(1)

    # Learning rate.
    lr: float = DefaultVal(3e-06)

    # Maximum number of training steps.
    maxsteps: int = DefaultVal(500_000)

    # Checkpoint-saving interval in steps; None means no periodic saving here.
    save_every: int = DefaultVal(None)

    # Resume training from an existing checkpoint.
    resume: bool = DefaultVal(False)

    ## NEW:
    # Warmup steps; None disables warmup.
    warmup: int = DefaultVal(None)

    warmup_bert: int = DefaultVal(None)

    relu: bool = DefaultVal(False)

    # Number of passages scored per query (e.g. 2 = one positive, one negative).
    nway: int = DefaultVal(2)

    # Use in-batch negatives.
    use_ib_negatives: bool = DefaultVal(False)

    reranker: bool = DefaultVal(False)

    # Weight applied to distillation scores.
    distillation_alpha: float = DefaultVal(1.0)

    ignore_scores: bool = DefaultVal(False)

    # Underlying pretrained model identifier; None means unset.
    model_name: str = DefaultVal(None)
|
|
@dataclass
class IndexingSettings:
    # Explicit index path; see index_path_ below for the computed fallback.
    index_path: str = DefaultVal(None)

    # Batch size used during indexing.
    index_bsize: int = DefaultVal(64)

    # Bits used for compression (residual quantization).
    nbits: int = DefaultVal(1)

    # Number of k-means iterations when training centroids.
    kmeans_niters: int = DefaultVal(4)

    # Resume a partially-built index.
    resume: bool = DefaultVal(False)

    @property
    def index_path_(self):
        # index_root_ and index_name come from sibling settings mixins
        # (RunSettings / ResourceSettings) on the combined ColBERTConfig.
        return self.index_path or os.path.join(self.index_root_, self.index_name)
|
|
@dataclass
class SearchSettings:
    # Search hyperparameters. NOTE(review): None presumably means "choose a
    # default at search time" — confirm in the searcher.
    ncells: int = DefaultVal(None)
    centroid_score_threshold: float = DefaultVal(None)
    ndocs: int = DefaultVal(None)
    # Memory-map the index instead of loading it fully into RAM.
    load_index_with_mmap: bool = DefaultVal(False)
|
|
|
|
@dataclass
class CoreConfig:
    """Base machinery for settings dataclasses: unwraps DefaultVal sentinels
    and tracks which fields were explicitly assigned."""

    def __post_init__(self):
        """
        Source: https://stackoverflow.com/a/58081120/1493011
        """

        # Maps field name -> True for fields that were explicitly set
        # (i.e., not left at their DefaultVal sentinel).
        self.assigned = {}

        for field in fields(self):
            field_val = getattr(self, field.name)

            # Unwrap the DefaultVal sentinel into its payload.
            # NOTE(review): an explicitly-passed None is also replaced by the
            # field's declared default here — and is still marked as assigned
            # by the check below.
            if isinstance(field_val, DefaultVal) or field_val is None:
                setattr(self, field.name, field.default.val)

            if not isinstance(field_val, DefaultVal):
                self.assigned[field.name] = True

    def assign_defaults(self):
        # Force every field back to its declared default AND mark it assigned.
        for field in fields(self):
            setattr(self, field.name, field.default.val)
            self.assigned[field.name] = True

    def configure(self, ignore_unrecognized=True, **kw_args):
        """Set many keys at once.

        Returns:
            The set of unrecognized keys that were ignored.
        """
        ignored = set()

        for key, value in kw_args.items():
            # set() returns True on success; falsy otherwise, recording the key.
            self.set(key, value, ignore_unrecognized) or ignored.update({key})

        return ignored

    """
    # TODO: Take a config object, not kw_args.

    for key in config.assigned:
        value = getattr(config, key)
    """

    def set(self, key, value, ignore_unrecognized=False):
        """Assign a single known attribute and mark it as assigned.

        Returns:
            True on success; None when the key is unknown and ignored.

        Raises:
            Exception: when the key is unknown and ignore_unrecognized is False.
        """
        if hasattr(self, key):
            setattr(self, key, value)
            self.assigned[key] = True
            return True

        if not ignore_unrecognized:
            raise Exception(f"Unrecognized key `{key}` for {type(self)}")

    def help(self):
        # Pretty-print the exported config for interactive inspection.
        print(ujson.dumps(self.export(), indent=4))

    def __export_value(self, v):
        # Prefer a provenance record when the value knows how to produce one.
        v = v.provenance() if hasattr(v, 'provenance') else v

        # Truncate very large containers so exported configs stay readable.
        if isinstance(v, list) and len(v) > 100:
            v = (f"list with {len(v)} elements starting with...", v[:3])

        if isinstance(v, dict) and len(v) > 100:
            v = (f"dict with {len(v)} keys starting with...", list(v.keys())[:3])

        return v

    def export(self):
        """Return a plain-dict snapshot of this config, large values truncated."""
        d = dataclasses.asdict(self)

        for k, v in d.items():
            d[k] = self.__export_value(v)

        return d
|
|
@dataclass
class BaseConfig(CoreConfig):
    """Adds load/save constructors on top of CoreConfig: merging existing
    configs, reading legacy dicts/JSON files, and checkpoint/index metadata."""

    @classmethod
    def from_existing(cls, *sources):
        """Build a new cls by merging the explicitly-assigned fields of the
        given configs; later sources win on conflicts. None sources are skipped."""
        kw_args = {}

        for source in sources:
            if source is None:
                continue

            # Only carry over fields the source explicitly assigned.
            local_kw_args = dataclasses.asdict(source)
            local_kw_args = {k: local_kw_args[k] for k in source.assigned}
            kw_args = {**kw_args, **local_kw_args}

        obj = cls(**kw_args)

        return obj

    @classmethod
    def from_deprecated_args(cls, args):
        """Build a config from a legacy flat dict.

        Returns:
            (config, ignored) where ignored is the set of unrecognized keys.
        """
        obj = cls()
        ignored = obj.configure(ignore_unrecognized=True, **args)

        return obj, ignored

    @classmethod
    def from_path(cls, name):
        """Load a config from a JSON file; unwraps a top-level "config" key
        if present. Returns (config, ignored_keys)."""
        with open(name) as f:
            args = ujson.load(f)

            if "config" in args:
                args = args["config"]

        return cls.from_deprecated_args(
            args
        )

    @classmethod
    def load_from_checkpoint(cls, checkpoint_path):
        """Load the config stored alongside a checkpoint.

        Handles three cases: the deprecated .dnn format, a HuggingFace Hub
        repo id, and a local directory containing artifact.metadata.

        Returns:
            The loaded config, or None when no config file is found.
        """
        if checkpoint_path.endswith(".dnn"):
            dnn = torch_load_dnn(checkpoint_path)
            config, _ = cls.from_deprecated_args(dnn.get("arguments", {}))

            # TODO: `Run().download(path)` with an comment that we may be downloading
            # from the internet. It's probably a good idea to support GCP and AWS. Then, raise an
            # error if it's not an NFS path.
            config.set("checkpoint", checkpoint_path)

            return config

        name_or_path = checkpoint_path
        # Best effort: resolve a HF Hub repo id to its local snapshot directory
        # by downloading artifact.metadata; fall through on any failure and
        # treat checkpoint_path as a local directory instead.
        try:
            checkpoint_path = hf_hub_download(
                repo_id=checkpoint_path, filename="artifact.metadata"
            ).split("artifact")[0]
        except Exception:
            pass
        loaded_config_path = os.path.join(checkpoint_path, "artifact.metadata")
        if os.path.exists(loaded_config_path):
            loaded_config, _ = cls.from_path(loaded_config_path)
            loaded_config.set("checkpoint", checkpoint_path)
            loaded_config.set("name_or_path", name_or_path)

            return loaded_config

        return (
            None
        )

    @classmethod
    def load_from_index(cls, index_path):
        """Load the config recorded in an index directory: metadata.json,
        falling back to the legacy plan.json."""
        # NOTE(review): the bare `except` also hides parse errors in
        # metadata.json, not just its absence — consider narrowing.
        try:
            metadata_path = os.path.join(index_path, "metadata.json")
            loaded_config, _ = cls.from_path(metadata_path)
        except:
            metadata_path = os.path.join(index_path, "plan.json")
            loaded_config, _ = cls.from_path(metadata_path)

        return loaded_config

    def save(self, path, overwrite=False):
        """Write this config as JSON to `path`, embedding run metadata under
        the "meta" key. Asserts the path does not exist unless overwrite."""
        assert overwrite or not os.path.exists(path), path

        with open(path, "w") as f:
            args = self.export()
            args["meta"] = get_metadata_only()
            args["meta"]["version"] = "colbert-v0.4"

            f.write(ujson.dumps(args, indent=4) + "\n")

    def save_for_checkpoint(self, checkpoint_path):
        """Save this config as artifact.metadata inside a checkpoint directory."""
        assert not checkpoint_path.endswith(
            ".dnn"
        ), f"{checkpoint_path}: We reserve *.dnn names for the deprecated checkpoint format."

        output_config_path = os.path.join(checkpoint_path, "artifact.metadata")
        self.save(output_config_path, overwrite=True)
|
|
|
|
@dataclass
class ColBERTConfig(RunSettings, ResourceSettings, DocSettings, QuerySettings, TrainingSettings,
                    IndexingSettings, SearchSettings, BaseConfig, TokenizerSettings):
    """Aggregate configuration combining every settings group plus
    BaseConfig's load/save machinery via dataclass multiple inheritance."""
    pass