Add files using upload-large-folder tool
Browse files. This view is limited to 50 files because it contains too many changes. See the raw diff.
- Prism/LLaDA/LLaDA_Baseline/dllm_eval/api/__init__.py +0 -0
- Prism/LLaDA/LLaDA_Baseline/dllm_eval/api/filter.py +56 -0
- Prism/LLaDA/LLaDA_Baseline/dllm_eval/api/group.py +115 -0
- Prism/LLaDA/LLaDA_Baseline/dllm_eval/api/instance.py +38 -0
- Prism/LLaDA/LLaDA_Baseline/dllm_eval/api/metrics.py +578 -0
- Prism/LLaDA/LLaDA_Baseline/dllm_eval/api/model.py +493 -0
- Prism/LLaDA/LLaDA_Baseline/dllm_eval/api/registry.py +196 -0
- Prism/LLaDA/LLaDA_Baseline/dllm_eval/api/samplers.py +232 -0
- Prism/LLaDA/LLaDA_Baseline/dllm_eval/api/task.py +1881 -0
- Prism/LLaDA/LLaDA_Baseline/dllm_eval/caching/__init__.py +0 -0
- Prism/LLaDA/LLaDA_Baseline/dllm_eval/caching/cache.py +59 -0
- Prism/LLaDA/LLaDA_Baseline/dllm_eval/decontamination/__init__.py +0 -0
- Prism/LLaDA/LLaDA_Baseline/dllm_eval/decontamination/janitor.py +328 -0
- Prism/LLaDA/LLaDA_Baseline/dllm_eval/loggers/__init__.py +2 -0
- Prism/LLaDA/LLaDA_Baseline/dllm_eval/loggers/evaluation_tracker.py +530 -0
- Prism/LLaDA/LLaDA_Baseline/dllm_eval/loggers/utils.py +149 -0
- Prism/LLaDA/LLaDA_Baseline/dllm_eval/loggers/wandb_logger.py +358 -0
- Prism/LLaDA/LLaDA_Baseline/dllm_eval/models/LLaDA.py +786 -0
- Prism/LLaDA/LLaDA_Baseline/dllm_eval/models/__init__.py +19 -0
- Prism/LLaDA/LLaDA_Baseline/dllm_eval/models/dummy.py +41 -0
- Prism/LLaDA/LLaDA_Baseline/dllm_eval/models/hts_sampler.py +315 -0
- Prism/LLaDA/LLaDA_Baseline/dllm_eval/models/huggingface.py +1489 -0
- Prism/LLaDA/LLaDA_Baseline/dllm_eval/models/utils.py +854 -0
- Prism/LLaDA/LLaDA_Baseline/dllm_eval/models/verifier.py +154 -0
- Prism/LLaDA/LLaDA_Baseline/dllm_eval/prompts/__init__.py +128 -0
- Prism/LLaDA/LLaDA_Baseline/dllm_eval/tasks/__init__.py +670 -0
- Prism/LLaDA/LLaDA_Baseline/dllm_eval/tasks/gsm8k/gsm8k.yaml +15 -0
- Prism/LLaDA/LLaDA_Baseline/dllm_eval/tasks/gsm8k/utils.py +13 -0
- Prism/LLaDA/LLaDA_Baseline/dllm_eval/tasks/humaneval/humaneval.yaml +13 -0
- Prism/LLaDA/LLaDA_Baseline/dllm_eval/tasks/humaneval/utils.py +43 -0
- Prism/LLaDA/LLaDA_Baseline/dllm_eval/tasks/mbpp/mbpp.yaml +14 -0
- Prism/LLaDA/LLaDA_Baseline/dllm_eval/tasks/mbpp/utils.py +79 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/certifi/__main__.py +12 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/frozenlist-1.5.0.dist-info/INSTALLER +1 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/frozenlist-1.5.0.dist-info/LICENSE +201 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/frozenlist-1.5.0.dist-info/METADATA +477 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/frozenlist-1.5.0.dist-info/RECORD +12 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/frozenlist-1.5.0.dist-info/WHEEL +8 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/frozenlist-1.5.0.dist-info/top_level.txt +1 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/h11-0.14.0.dist-info/INSTALLER +1 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/h11-0.14.0.dist-info/LICENSE.txt +22 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/h11-0.14.0.dist-info/METADATA +193 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/h11-0.14.0.dist-info/RECORD +52 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/h11-0.14.0.dist-info/WHEEL +5 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/h11-0.14.0.dist-info/top_level.txt +1 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/multidict-6.1.0.dist-info/INSTALLER +1 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/multidict-6.1.0.dist-info/LICENSE +13 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/multidict-6.1.0.dist-info/METADATA +140 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/multidict-6.1.0.dist-info/RECORD +19 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/multidict-6.1.0.dist-info/WHEEL +6 -0
Prism/LLaDA/LLaDA_Baseline/dllm_eval/api/__init__.py
ADDED
|
File without changes
|
Prism/LLaDA/LLaDA_Baseline/dllm_eval/api/filter.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import ABC, abstractmethod
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
from typing import Callable, Iterable, List, Union
|
| 4 |
+
|
| 5 |
+
from dllm_eval.api.instance import Instance
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class Filter(ABC):
    """
    Filter classes operate on a per-task level.
    They take all model outputs (`instance.resps` for all `task.instances`)
    across all instances of a task, and perform operations.
    In a single run, one can configure any number of separate filters or lists of filters.

    """

    def __init__(self, **kwargs) -> None:
        """
        Can define custom behavior here, if an individual instantiation of a Filter class should have state.
        """
        # Intentionally empty: stateless filters need no setup.

    @abstractmethod
    def apply(self, resps: Union[List, Iterable], docs: List[dict]) -> Iterable:
        """
        Defines the operation to perform on a list of the `inst.resps` properties of `Instance` objects.
        Should return the list of (filtered) response lists *in the same order as they were input*, e.g.
        if pass in [<inst.resps for instance 0>, <inst.resps for instance 1>] should return
        [<filtered resps for instance 0>, <filtered resps for instance 1>]
        """
        # Default body is an identity passthrough for subclasses that call super().
        return resps
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
@dataclass
class FilterEnsemble:
    """
    FilterEnsemble creates a pipeline applying multiple filters.
    Its intended usage is to stack multiple post-processing steps in order.
    `task.apply_filters` should use a list of FilterEnsemble classes that it stores, to apply each
    pipeline separately.
    """

    name: str
    filters: List[Callable[[], Filter]]

    def apply(self, instances: List[Instance]) -> None:
        # Collect (resps, doc) columns in instance order.
        pairs = ((inst.resps, inst.doc) for inst in instances)
        resps, docs = (list(column) for column in zip(*pairs))

        # Run every filter of the pipeline in sequence over the response lists.
        for make_filter in self.filters:
            resps = make_filter().apply(resps, docs)

        # Store the fully-filtered results on their source instances, keyed by
        # this ensemble's name (each ensemble in a run must use a distinct name).
        for inst, filtered in zip(instances, resps):
            inst.filtered_resps[self.name] = filtered
|
Prism/LLaDA/LLaDA_Baseline/dllm_eval/api/group.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import abc
|
| 2 |
+
from dataclasses import asdict, dataclass
|
| 3 |
+
from inspect import getsource
|
| 4 |
+
from typing import Any, Callable, List, Optional, Union
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
@dataclass
class AggMetricConfig(dict):
    """Configuration describing how one metric is aggregated across a group's subtasks.

    Attributes:
        metric: name of the metric to aggregate.
        aggregation: cross-subtask aggregation; only "mean" (or a callable) is accepted.
        weight_by_size: whether each subtask's score is weighted by its sample count.
            (Annotation fixed from ``Optional[str]`` to ``Optional[bool]`` — the
            default and intended use are boolean.)
        filter_list: filter name(s) whose results feed the aggregated metric;
            normalized to a list in ``__post_init__``.
    """

    metric: Optional[str] = None
    aggregation: Optional[str] = "mean"
    weight_by_size: Optional[bool] = False
    # list of filter names which should be incorporated into the aggregated metric.
    filter_list: Optional[Union[str, list]] = "none"

    def __post_init__(self):
        # Reject unsupported aggregations early; callables are allowed through.
        if self.aggregation != "mean" and not callable(self.aggregation):
            raise ValueError(
                f"Currently, 'mean' is the only pre-defined aggregation across groups' subtasks. Got '{self.aggregation}'."
            )

        # Normalize a bare filter name into a one-element list.
        if isinstance(self.filter_list, str):
            self.filter_list = [self.filter_list]
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@dataclass
class GroupConfig(dict):
    """Configuration for a named group of tasks, including optional cross-task
    metric aggregation. Supports both attribute and dict-style access."""

    group: Optional[str] = None
    group_alias: Optional[str] = None
    # Subtask name(s) that belong to this group.
    task: Optional[Union[str, list]] = None
    # Normalized to List[AggMetricConfig] in __post_init__.
    aggregate_metric_list: Optional[
        Union[List[AggMetricConfig], AggMetricConfig, dict]
    ] = None
    metadata: Optional[dict] = (
        None  # by default, not used in the code. allows for users to pass arbitrary info to tasks
    )

    # Mirror attribute access through item access so the object behaves like a dict.
    def __getitem__(self, item):
        return getattr(self, item)

    def __setitem__(self, item, value):
        return setattr(self, item, value)

    def __post_init__(self):
        # Normalize aggregate_metric_list to a list of AggMetricConfig objects:
        # a single dict becomes a one-element list; dict items are upgraded.
        if self.aggregate_metric_list is not None:
            if isinstance(self.aggregate_metric_list, dict):
                self.aggregate_metric_list = [self.aggregate_metric_list]

            self.aggregate_metric_list = [
                AggMetricConfig(**item) if isinstance(item, dict) else item
                for item in self.aggregate_metric_list
            ]

    def to_dict(self, keep_callable: bool = False) -> dict:
        """dumps the current config as a dictionary object, as a printable format.
        Used for dumping results alongside full task configuration

        :return: dict
            A printable dictionary version of the TaskConfig object.

        # TODO: should any default value in the TaskConfig not be printed?
        """
        cfg_dict = asdict(self)
        # Serialize callables to readable source text. NOTE(review): despite what
        # an earlier comment claimed, `None`-valued fields are NOT removed here —
        # only callables are transformed.
        for k, v in list(cfg_dict.items()):
            if callable(v):
                cfg_dict[k] = self.serialize_function(v, keep_callable=keep_callable)
        return cfg_dict

    def serialize_function(
        self, value: Union[Callable, str], keep_callable=False
    ) -> Union[Callable, str]:
        """Serializes a given function or string.

        If 'keep_callable' is True, the original callable is returned.
        Otherwise, attempts to return the source code of the callable using 'getsource'.
        """
        if keep_callable:
            return value
        else:
            try:
                return getsource(value)
            except (TypeError, OSError):
                # Builtins / dynamically-created callables have no retrievable source.
                return str(value)
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
class ConfigurableGroup(abc.ABC):
    """Thin wrapper exposing a GroupConfig through read-only properties."""

    def __init__(
        self,
        config: Optional[dict] = None,
    ) -> None:
        # Honor the Optional default: an omitted config becomes an empty one
        # instead of crashing on `GroupConfig(**None)`.
        self._config = GroupConfig(**(config or {}))

    @property
    def group(self):
        return self._config.group

    @property
    def group_alias(self):
        return self._config.group_alias

    @property
    def version(self):
        # GroupConfig declares no `version` field, so the previous direct
        # attribute access always raised AttributeError; fall back to None
        # when no version information was attached to the config.
        return getattr(self._config, "version", None)

    @property
    def config(self):
        # A printable dict view of the underlying GroupConfig.
        return self._config.to_dict()

    @property
    def group_name(self) -> Any:
        return self._config.group

    def __repr__(self):
        return f"ConfigurableGroup(group={self.group},group_alias={self.group_alias})"
|
Prism/LLaDA/LLaDA_Baseline/dllm_eval/api/instance.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass, field
|
| 2 |
+
from typing import Literal, Optional, Tuple
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
OutputType = Literal[
    "loglikelihood", "loglikelihood_rolling", "generate_until", "multiple_choice"
]


@dataclass
class Instance:
    """One evaluation request: a document plus the arguments sent to the model."""

    request_type: OutputType
    doc: dict
    arguments: tuple
    idx: int
    # Packed (task_name, doc_id, repeats); unpacked into named fields below.
    metadata: Tuple[Optional[str], Optional[int], Optional[int]] = field(
        default_factory=lambda: (None, None, None)
    )
    resps: list = field(default_factory=list)
    filtered_resps: dict = field(default_factory=dict)

    # Populated from `metadata` once construction finishes.
    task_name: Optional[str] = None
    doc_id: Optional[int] = None
    repeats: Optional[int] = None

    def __post_init__(self) -> None:
        # Spread the packed metadata triple into its named fields.
        (self.task_name, self.doc_id, self.repeats) = self.metadata

    @property
    def args(self):
        """
        Returns (string,) where `string` is the string to calculate loglikelihood over
        """
        if isinstance(self.arguments, tuple):
            return self.arguments
        return (self.arguments,)
|
Prism/LLaDA/LLaDA_Baseline/dllm_eval/api/metrics.py
ADDED
|
@@ -0,0 +1,578 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import math
|
| 3 |
+
import random
|
| 4 |
+
import re
|
| 5 |
+
import string
|
| 6 |
+
from collections.abc import Iterable
|
| 7 |
+
from typing import List
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
import sacrebleu
|
| 11 |
+
|
| 12 |
+
from dllm_eval.api.registry import register_aggregation, register_metric
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
eval_logger = logging.getLogger(__name__)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# Register Aggregations First
@register_aggregation("bypass")
def bypass_agg(arr):
    # Sentinel aggregation: always returns the constant 999 (placeholder score).
    return 999


@register_aggregation("nanmean")
def nanmean(arr):
    # Guard empty / all-NaN input so numpy does not emit a RuntimeWarning.
    if len(arr) == 0 or all(np.isnan(arr)):
        return np.nan
    return np.nanmean(arr)


@register_aggregation("mean")
def mean(arr):
    # Plain arithmetic mean; raises ZeroDivisionError on empty input.
    return sum(arr) / len(arr)


@register_aggregation("median")
def median(arr):
    # NOTE(review): picks the middle element WITHOUT sorting, so this is only a
    # true median if `arr` is already sorted; even lengths yield the upper-middle.
    return arr[len(arr) // 2]
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
# Certain metrics must be calculated across all documents in a benchmark.
# We use them as aggregation metrics, paired with no-op passthrough metric fns.
@register_aggregation("perplexity")
def perplexity(items):
    # items: log-likelihoods; perplexity = exp(-mean(loglik)).
    return math.exp(-mean(items))


@register_aggregation("weighted_perplexity")
def weighted_perplexity(items):
    # items: (log-likelihood, weight) pairs — see weighted_mean below.
    return math.exp(-weighted_mean(items))


@register_aggregation("bits_per_byte")
def bits_per_byte(items):
    # Convert natural-log likelihood per weight unit to bits (divide by ln 2).
    return -weighted_mean(items) / math.log(2)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
@register_aggregation("f1")
def f1_score(items):
    # items: (gold, prediction) pairs; binary F1 via scikit-learn.
    from sklearn.metrics import f1_score

    unzipped_list = list(zip(*items))
    golds = unzipped_list[0]
    preds = unzipped_list[1]
    fscore = f1_score(golds, preds)

    # NOTE(review): with sklearn's default (binary) averaging `fscore` is a
    # scalar, so np.max is effectively a no-op kept for safety.
    return np.max(fscore)


@register_aggregation("matthews_corrcoef")
def matthews_corrcoef(items):
    # items: (gold, prediction) pairs; Matthews correlation coefficient.
    from sklearn.metrics import matthews_corrcoef

    unzipped_list = list(zip(*items))
    golds = unzipped_list[0]
    preds = unzipped_list[1]
    return matthews_corrcoef(golds, preds)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
@register_aggregation("bleu")
def bleu(items):
    """The Bilingual Evaluation Understudy Score, or BLEU for short, is a metric
    for evaluating a generated sentence to a reference sentence. It counts matching
    n-grams in the candidate translation to n-grams in the reference text, where
    1-gram or unigram would be each token and a bigram comparison would be each
    word pair. The comparison is made regardless of word order
    Source: https://machinelearningmastery.com/calculate-bleu-score-for-text-python/
    Paper: https://www.aclweb.org/anthology/P02-1040/

    Higher is better
    """
    # items: (reference, prediction) pairs; _sacreformat reshapes both sides
    # into the (List[str], List[List[str]]) layout sacrebleu expects.
    refs = list(zip(*items))[0]
    preds = list(zip(*items))[1]
    refs, preds = _sacreformat(refs, preds)
    return sacrebleu.corpus_bleu(preds, refs).score


@register_aggregation("chrf")
def chrf(items):
    """chrF++ is a tool for automatic evaluation of machine translation output
    based on character n-gram precision and recall enhanced with word n-grams.
    Source: https://github.com/m-popovic/chrF
    Paper: https://www.aclweb.org/anthology/W15-3049.pdf

    Higher is better # TODO I think
    """
    # items: (reference, prediction) pairs, same layout as bleu above.
    refs = list(zip(*items))[0]
    preds = list(zip(*items))[1]
    refs, preds = _sacreformat(refs, preds)
    return sacrebleu.corpus_chrf(preds, refs).score


@register_aggregation("ter")
def ter(items):
    """Translation Error Rate is an error metric for machine translation that
    measures the number of edits required to change a system output into one
    of the references
    Source: http://www.cs.umd.edu/~snover/tercom/
    Paper: http://mt-archive.info/AMTA-2006-Snover.pdf

    Lower is better
    """
    # items: (reference, prediction) pairs, same layout as bleu above.
    refs = list(zip(*items))[0]
    preds = list(zip(*items))[1]
    refs, preds = _sacreformat(refs, preds)
    return sacrebleu.corpus_ter(preds, refs).score
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
@register_aggregation("brier_score")
def brier_score(items):  # aggregation: mean squared error vs. one-hot gold labels
    # items: (gold_class_index, class_probability_vector) pairs.
    gold, predictions = list(zip(*items))
    bs, num_class = np.array(predictions).shape

    # One-hot encode the gold indices, then average the per-sample squared error.
    gold = list(gold)
    gold_one_hot = np.eye(num_class)[gold]
    return np.mean(np.sum((predictions - gold_one_hot) ** 2, axis=1))
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
# Registered passthrough metrics: each returns its items unchanged and defers
# all real computation to the aggregation named in its registration.
@register_metric(
    metric="brier_score",
    higher_is_better=False,
    output_type=["multiple_choice"],
    aggregation="brier_score",
)
def brier_score_fn(items):  # This is a passthrough function
    return items


@register_metric(
    metric="acc",
    higher_is_better=True,
    output_type=["loglikelihood", "multiple_choice"],
    aggregation="mean",
)
def acc_fn(items):  # This is a passthrough function
    return items


@register_metric(
    metric="acc_norm",
    higher_is_better=True,
    output_type=["loglikelihood", "multiple_choice"],
    aggregation="mean",
)
def acc_norm_fn(items):  # This is a passthrough function
    return items


@register_metric(
    metric="acc_mutual_info",
    higher_is_better=True,
    output_type="multiple_choice",
    aggregation="mean",
)
def acc_mutual_info_fn(items):  # This is a passthrough function
    return items
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
### the code used in the `exact_match_hf_evaluate` function is ported from
|
| 180 |
+
### https://github.com/huggingface/evaluate/blob/main/metrics/exact_match/exact_match.py
|
| 181 |
+
### which is under the apache license.
|
| 182 |
+
|
| 183 |
+
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
|
| 184 |
+
|
| 185 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 186 |
+
# you may not use this file except in compliance with the License.
|
| 187 |
+
# You may obtain a copy of the License at
|
| 188 |
+
|
| 189 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 193 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 194 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 195 |
+
# See the License for the specific language governing permissions and
|
| 196 |
+
# limitations under the License.
|
| 197 |
+
def exact_match_hf_evaluate(
    predictions,
    references,
    regexes_to_ignore=None,
    ignore_case=False,
    ignore_punctuation=False,
    ignore_numbers=False,
):
    """Fraction of predictions that exactly match their references.

    Optional normalizations (regex removal, case folding, punctuation and
    digit stripping) are applied to both sides before comparison.
    Ported from HuggingFace `evaluate`'s exact_match metric (Apache-2.0).
    """
    if regexes_to_ignore is not None:
        # Strip every ignored pattern from both sides before comparing.
        for pattern in regexes_to_ignore:
            predictions = np.array([re.sub(pattern, "", p) for p in predictions])
            references = np.array([re.sub(pattern, "", r) for r in references])
    else:
        predictions = np.asarray(predictions)
        references = np.asarray(references)

    if ignore_case:
        predictions = np.char.lower(predictions)
        references = np.char.lower(references)

    if ignore_punctuation:
        table = str.maketrans("", "", string.punctuation)
        predictions = np.char.translate(predictions, table=table)
        references = np.char.translate(references, table=table)

    if ignore_numbers:
        table = str.maketrans("", "", string.digits)
        predictions = np.char.translate(predictions, table=table)
        references = np.char.translate(references, table=table)

    # Element-wise equality over the normalized arrays, averaged to a ratio.
    return {"exact_match": np.mean(predictions == references)}
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
###
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
@register_metric(
    metric="exact_match",
    higher_is_better=True,
    output_type="generate_until",
    aggregation="mean",
)
def exact_match_fn(**kwargs):
    # Thin registered wrapper around the HF-ported implementation above.
    return exact_match_hf_evaluate(**kwargs)


# Passthrough metrics: the paired aggregation does the actual computation.
@register_metric(
    metric="perplexity",
    higher_is_better=False,
    output_type="loglikelihood",
    aggregation="perplexity",
)
def perplexity_fn(items):  # This is a passthrough function
    return items


@register_metric(
    metric="word_perplexity",
    higher_is_better=False,
    output_type="loglikelihood_rolling",
    aggregation="weighted_perplexity",
)
def word_perplexity_fn(items):  # This is a passthrough function
    return items


@register_metric(
    metric="byte_perplexity",
    higher_is_better=False,
    output_type="loglikelihood_rolling",
    aggregation="weighted_perplexity",
)
def byte_perplexity_fn(items):  # This is a passthrough function
    return items


@register_metric(
    metric="bits_per_byte",
    higher_is_better=False,
    output_type="loglikelihood_rolling",
    aggregation="bits_per_byte",
)
def bits_per_byte_fn(items):  # This is a passthrough function
    return items
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
def pop_stddev(arr):
    """Population standard deviation (divides by N)."""
    mu = sum(arr) / len(arr)
    return math.sqrt(sum((x - mu) ** 2 for x in arr) / len(arr))


def sample_stddev(arr):
    """Sample (Bessel-corrected) standard deviation (divides by N - 1)."""
    mu = sum(arr) / len(arr)
    return math.sqrt(sum((x - mu) ** 2 for x in arr) / (len(arr) - 1))


def mean_stderr(arr):
    """Standard error of the mean: sample stddev scaled down by sqrt(N)."""
    return sample_stddev(arr) / math.sqrt(len(arr))
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
@register_metric(
    metric="bypass",
    higher_is_better=True,
    output_type=["loglikelihood", "multiple_choice", "generate_until"],
    aggregation="bypass",
)
def bypass(items):
    # Deliberately discards items; paired with the constant "bypass" aggregation.
    return None
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
# Passthrough metrics: each defers computation to its corpus-level aggregation.
@register_metric(
    metric="mcc",
    higher_is_better=True,
    output_type="multiple_choice",
    aggregation="matthews_corrcoef",
)
def mcc_fn(items):  # This is a passthrough function
    return items


@register_metric(
    metric="f1",
    higher_is_better=True,
    output_type="multiple_choice",
    aggregation="f1",
)
def f1_fn(items):  # This is a passthrough function
    return items


@register_metric(
    metric="bleu",
    higher_is_better=True,
    output_type="generate_until",
    aggregation="bleu",
)
def bleu_fn(items):  # This is a passthrough function
    return items


@register_metric(
    metric="chrf",
    higher_is_better=True,
    output_type="generate_until",
    aggregation="chrf",
)
def chrf_fn(items):  # This is a passthrough function
    return items


# NOTE(review): TER counts edits and the "ter" aggregation's own docstring says
# "Lower is better" — higher_is_better=True here looks inverted; confirm intent.
@register_metric(
    metric="ter",
    higher_is_better=True,
    output_type="generate_until",
    aggregation="ter",
)
def ter_fn(items):  # This is a passthrough function
    return items
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
@register_metric(
    metric="acc_all",
    higher_is_better=True,
    output_type="loglikelihood",
    aggregation="mean",
)
def acc_all(items):
    # Only count as correct if all answers are labeled correctly for each question
    # items: (pred, doc) pairs; docs carry MultiRC-style doc["idx"]["paragraph"]
    # and doc["idx"]["question"] identifiers plus a 0/1 doc["label"].
    question_scoring_dict = {}
    preds = list(zip(*items))[0]
    docs = list(zip(*items))[1]

    for doc, pred in zip(docs, preds):
        # Group answers by their (paragraph, question) key.
        paragraph_id = doc["idx"]["paragraph"]
        question_id = doc["idx"]["question"]
        if (paragraph_id, question_id) not in question_scoring_dict:
            question_scoring_dict[(paragraph_id, question_id)] = []

        gold_label = doc["label"] == 1

        question_scoring_dict[(paragraph_id, question_id)].append(gold_label == pred)
    # A question scores 1 only if every one of its answers was predicted correctly.
    acc = np.mean([int(all(x)) for x in question_scoring_dict.values()])
    return acc
|
| 382 |
+
|
| 383 |
+
|
| 384 |
+
def acc_all_stderr(items):
    # Only count as correct if all answers are labeled correctly for each question
    # NOTE(review): unlike acc_all, questions are keyed by question_id alone
    # (no paragraph id), so questions from different paragraphs can collide.
    question_scoring_dict = {}
    preds = list(zip(*items))[0]
    docs = list(zip(*items))[1]

    for doc, pred in zip(docs, preds):
        question_id = doc["idx"]["question"]
        if question_id not in question_scoring_dict:
            question_scoring_dict[question_id] = []

        gold_label = doc["label"] == 1
        question_scoring_dict[question_id].append(gold_label == pred)

    # Standard error of the per-question all-correct indicator.
    acc = mean_stderr([int(all(x)) for x in question_scoring_dict.values()])
    return acc
|
| 400 |
+
|
| 401 |
+
|
| 402 |
+
def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    """Compute max metric between prediction and each ground truth."""
    # One score per ground truth; keep the best.
    return max(metric_fn(prediction, gt) for gt in ground_truths)
|
| 409 |
+
|
| 410 |
+
|
| 411 |
+
def weighted_mean(items):
    """Given (numerator, denominator) pairs, return sum(nums) / sum(dens)."""
    numerators, denominators = zip(*items)
    return sum(numerators) / sum(denominators)
|
| 414 |
+
|
| 415 |
+
|
| 416 |
+
def is_non_str_iterable(obj):
    """True for iterables that are not plain strings (a str iterates chars)."""
    return not isinstance(obj, str) and isinstance(obj, Iterable)
|
| 418 |
+
|
| 419 |
+
|
| 420 |
+
def _sacreformat(refs, preds):
    """Format refs and preds for sacrebleu corpus calculation. It is very particular"""
    # Sacrebleu expects (List[str], List[List[str])
    # e.g. sacrebleu.corpus_bleu([pred_t], [[ref1_stream], [ref2_stream], ...])

    # Note [ref1_stream] is the first reference for each pred.
    # So lists are size N and (M, N) for N preds and M possible refs for each pred
    # This is a different order of dimensions that I would expect

    # We expect refs to be List[str] or List[List[str]], the outer list corresponding to preds
    # Must become List[List[str]] with the inner list corresponding to preds
    if not is_non_str_iterable(refs):
        # Fix: a bare string is a single reference. The previous
        # `refs = list(refs)` exploded a string into a list of characters
        # (and raised TypeError for a true non-iterable).
        refs = [refs]
    if not is_non_str_iterable(refs[0]):
        refs = [[ref] for ref in refs]
    refs = list(zip(*refs))
    # Note the number of refs in each ref list much match the number of preds

    # We expect preds to be List[str] or List[List[str]]. Must become List[str]
    if not is_non_str_iterable(preds):
        # Fix: same single-string handling as for refs above.
        preds = [preds]
    if is_non_str_iterable(preds[0]):
        assert len(preds[0]) == 1, f"Pred must be a str, was {preds[0]}"
        preds = [pred[0] for pred in preds]

    return refs, preds
|
| 446 |
+
|
| 447 |
+
|
| 448 |
+
# stderr stuff
|
| 449 |
+
|
| 450 |
+
|
| 451 |
+
class _bootstrap_internal:
|
| 452 |
+
def __init__(self, f, n) -> None:
|
| 453 |
+
self.f = f
|
| 454 |
+
self.n = n
|
| 455 |
+
|
| 456 |
+
def __call__(self, v):
|
| 457 |
+
i, xs = v
|
| 458 |
+
rnd = random.Random()
|
| 459 |
+
rnd.seed(i)
|
| 460 |
+
res = []
|
| 461 |
+
for _ in range(self.n):
|
| 462 |
+
res.append(self.f(rnd.choices(xs, k=len(xs))))
|
| 463 |
+
return res
|
| 464 |
+
|
| 465 |
+
|
| 466 |
+
def bootstrap_stderr(f, xs, iters):
    """Bootstrap estimate of the standard error of statistic `f` over `xs`.

    Fans the resampling out over a process pool in chunks of up to 1000
    iterations, then takes the sample stddev of all replicate values.

    :param f: statistic, called as f(resampled_list)
    :param xs: the sample to resample from
    :param iters: total number of bootstrap iterations
    """
    import multiprocessing as mp

    pool = mp.Pool(mp.cpu_count())
    # this gives a biased estimate of the stderr (i.e w/ the mean, it gives something
    # equivalent to stderr calculated without Bessel's correction in the stddev.
    # Unfortunately, I haven't been able to figure out what the right correction is
    # to make the bootstrap unbiased - i considered multiplying by sqrt(n/(n-1)) but
    # that would be ad-hoc and I can't prove that that would actually be an unbiased estimator)
    # Thankfully, shouldn't matter because our samples are pretty big usually anyways
    res = []
    chunk_size = min(1000, iters)
    from tqdm import tqdm

    print("bootstrapping for stddev:", f.__name__)
    try:
        for bootstrap in tqdm(
            pool.imap(
                _bootstrap_internal(f, chunk_size),
                [(i, xs) for i in range(iters // chunk_size)],
            ),
            total=iters // chunk_size,
        ):
            # sample w replacement
            res.extend(bootstrap)
    finally:
        # Previously the pool was only close()d on the happy path, leaking
        # worker processes if f (or iteration) raised; also join() so the
        # workers are fully reaped before returning.
        pool.close()
        pool.join()

    return sample_stddev(res)
|
| 493 |
+
|
| 494 |
+
|
| 495 |
+
def stderr_for_metric(metric, bootstrap_iters: int):
    """Return a stderr function appropriate for `metric`, or None.

    Bootstrappable metrics get a bootstrap-based stderr closure; `mean` and
    `acc_all` have dedicated closed-form companions; anything else gets None.
    """
    # bootstrap_iters <= 0 disables stderr computation entirely.
    if bootstrap_iters <= 0:
        return None

    # Metrics whose stderr we estimate by bootstrap resampling.
    bootstrappable = (
        median,
        matthews_corrcoef,
        f1_score,
        perplexity,
        bleu,
        chrf,
        ter,
        nanmean,
    )

    if metric in bootstrappable:

        def _bootstrapped_stderr(x):
            return bootstrap_stderr(metric, x, iters=bootstrap_iters)

        return _bootstrapped_stderr

    # Metrics with a dedicated closed-form stderr companion.
    return {mean: mean_stderr, acc_all: acc_all_stderr}.get(metric, None)
|
| 517 |
+
|
| 518 |
+
|
| 519 |
+
def pooled_sample_stderr(stderrs: List[float], sizes: List[int]):
    """Aggregate bootstrapped stderrs across subtasks in a group, weighting
    by the size of each subtask.

    formula source: https://en.wikipedia.org/wiki/Pooled_variance
    and: https://stats.stackexchange.com/a/4841331
    this empirically seems to match running `stderr_for_metric` on all
    instances from the subtasks concatenated with each other.
    """
    assert len(stderrs) == len(sizes)

    total = sum(sizes)
    # Each subtask contributes (n_i - 1) * se_i^2 * n_i to the pooled variance.
    var_terms = ((n - 1) * se**2 * n for n, se in zip(sizes, stderrs))
    pooled_sample_var = sum(var_terms) / (total - len(sizes))

    return np.sqrt(pooled_sample_var / total)
|
| 535 |
+
|
| 536 |
+
|
| 537 |
+
def combined_sample_stderr(stderrs: List[float], sizes: List[int], metrics=None):
    """Pairwise-combined stderr across subtasks, using the subtask means.

    :param stderrs: per-subtask stderrs
    :param sizes: per-subtask sample sizes
    :param metrics: per-subtask metric values (required; asserted below)
    :return: np.sqrt of the pairwise-combined variance
    """
    assert metrics is not None, (
        "Need to pass a list of each subtask's metric for this stderr aggregation"
    )
    assert len(stderrs) == len(sizes) and len(sizes) == len(metrics)

    # See https://github.com/EleutherAI/lm-evaluation-harness/pull/1390 for more documentation.
    # This formula depends on sample means.
    # removed because it seems to give erroneously huge stderrs for groupings of tasks
    # and does not seem to match up with bootstrap-calculated stderrs for groups.

    ### don't use this unless a statistician has told you it's the right thing to do ###

    # accumulators: we'll aggregate pairwise N - 1 times
    variance = stderrs[0] ** 2
    curr_size = sizes[0]
    curr_score = metrics[0]

    for stderr, size, score in zip(stderrs[1:], sizes[1:], metrics[1:]):
        # Running size-weighted mean of the scores merged so far.
        curr_score = ((curr_score * curr_size) + (score * size)) / (
            curr_size + size
        )  # NOTE: this assumes our aggregation fn is "mean"

        # Two-sample pooled variance plus a between-means correction term.
        # NOTE(review): `curr_size` is never updated to `curr_size + size`
        # after a merge, so with more than two subtasks each later merge is
        # still weighted against only the first subtask's size — confirm
        # whether this is intentional.
        variance = ((curr_size - 1) * variance + (size - 1) * (stderr**2)) / (
            curr_size + size - 1
        ) + curr_size * size / ((curr_size + size) * (curr_size + size - 1)) * (
            curr_score - score
        ) ** 2

    return np.sqrt(variance)
|
| 567 |
+
|
| 568 |
+
|
| 569 |
+
def aggregate_subtask_metrics(metrics, sizes, weight_by_size=True):
    """Aggregate subtask scores cross-task as a (size-)weighted mean.

    TODO: does not hold for non-mean aggregations
    """
    if not weight_by_size:
        # Equal weighting: pretend every subtask has size 1.
        sizes = [1] * len(sizes)

    assert len(metrics) == len(sizes)

    weighted_total = sum(score * n for score, n in zip(metrics, sizes))
    return weighted_total / sum(sizes)
|
Prism/LLaDA/LLaDA_Baseline/dllm_eval/api/model.py
ADDED
|
@@ -0,0 +1,493 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import abc
|
| 2 |
+
import hashlib
|
| 3 |
+
import json
|
| 4 |
+
import logging
|
| 5 |
+
import os
|
| 6 |
+
from typing import Dict, List, Optional, Tuple, Type, TypeVar, Union
|
| 7 |
+
|
| 8 |
+
import transformers
|
| 9 |
+
from sqlitedict import SqliteDict
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
|
| 12 |
+
from dllm_eval import utils
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
eval_logger = logging.getLogger(__name__)
|
| 16 |
+
|
| 17 |
+
T = TypeVar("T", bound="LM")
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class LM(abc.ABC):
    def __init__(self) -> None:
        """Defines the interface that should be implemented by all LM subclasses.
        LMs are assumed to take text (strings) as input and yield strings as output
        (inputs/outputs should be tokenization-agnostic.)

        """
        # set rank and world size to a single process, by default.
        self._rank = 0
        self._world_size = 1
        # No-op hook until a CachingLM attaches a real one via set_cache_hook;
        # CacheHook(None) early-returns in add_partial.
        self.cache_hook = CacheHook(None)

    @abc.abstractmethod
    def loglikelihood(self, requests) -> List[Tuple[float, bool]]:
        """Compute log-likelihood of generating a continuation from a context.
        Downstream tasks should attempt to use loglikelihood instead of other
        LM calls whenever possible.

        :param requests: list[Instance]
            A list of Instance objects, with property `args` which returns a tuple (context, continuation).
            `context: str`
                Context string. Implementations of LM must be able to handle an
                empty context string.
            `continuation: str`
                The continuation over which log likelihood will be calculated. If
                there is a word boundary, the space should be in the continuation.
                For example, context="hello" continuation=" world" is correct.

        :return: list[tuple[float, bool]]
            A list of pairs (logprob, isgreedy)
            `logprob: float`
                The log probability of `continuation`.
            `isgreedy`:
                Whether `continuation` would be generated by greedy sampling from `context`.
        """
        pass

    @abc.abstractmethod
    def loglikelihood_rolling(self, requests) -> List[float]:
        """Compute full log-likelihood of a string, with no truncation, for perplexity computation
        - We will use the full max context length of the model.
        - For inputs that exceed the max context length, we divide the tokenized string into chunks of up to
        the max context length.
        - IMPORTANT: Each document's loglikelihood/perplexity is computed *separately*, unlike other implementations
        which may simply concatenate multiple documents together.
        - IMPORTANT: We maximize the amount of context for each prediction. Specifically, for inputs that we break into
        multiple chunks, the last input will still a full-sized context.
        Example:
            Input tokens: [ 0 1 2 3 4 5 6 7 8 9 ]
            Prefix: BOS/EOS
            Max context length: 4
            Resulting input/prediction pairs:

                INPUT:  BOS   0   1   2
                PRED:     0   1   2   3

                INPUT:    3   4   5   6
                PRED:     4   5   6   7

                INPUT:    5   6   7   8
                PRED:             8   9

        Observe that:
            1. Each token is predicted exactly once
            2. For the last pair, we provide the full context, but only score the last two tokens

        :param requests: list[Instance]
            A list of Instance objects with property `args` which returns a tuple (context,).
            string: str
                String for which we are computing overall loglikelihood
        :return: list[tuple[float]]
            A list of tuples (logprob,)
            logprob: float
                The log probability of `context` conditioned on the BOS/EOS token.
                Can also be overridden for custom cases by `prefix_token_id`.
        """
        pass

    # TODO: Add an optional max length
    @abc.abstractmethod
    def generate_until(self, requests) -> List[str]:
        """Generate greedily until a stopping sequence

        :param requests: list[Instance]
            A list of Instance objects with property `args` which returns a tuple (context, gen_kwargs).
            context: str
                Context string
            gen_kwargs: dict
                A dictionary of keyword arguments to pass to the generation function e.g. top_k, until, etc.
        :return: list[str]
            A list of model generated continuations.
            continuation: str
                The generated continuation.
        """
        pass

    def apply_chat_template(
        self, chat_history: List[Dict[str, str]], add_generation_prompt=True
    ) -> str:
        """
        Defines how to transform few-shot examples provided as chat history into a format that can be used as input to the LM.

        :param chat_history: list[dict[str, str]]
            A list of dictionaries with keys 'role' and 'content'.
            Values are strings representing the role name and the content of the message, respectively.
        :param add_generation_prompt: bool
            Whether to append an assistant gen prefix (for e.g. <|assistant|>) to the assistant messages in the chat history. False if prefilling an assistant message.
        :return: str
            A string representing the chat history in a format that can be used as input to the LM.
        """
        # Base class has no tokenizer/template knowledge; subclasses must override.
        raise NotImplementedError(
            "To use this model with chat templates, please implement the 'apply_chat_template' method for your model type."
        )

    @classmethod
    def create_from_arg_string(
        cls: Type[T], arg_string: str, additional_config: Optional[dict] = None
    ) -> T:
        """
        Creates an instance of the LM class using the given argument string and additional config.

        Parameters:
        - arg_string: A string containing arguments in the format key1=value1,key2=value2.
        - additional_config: Optional dictionary containing additional configuration parameters.

        Returns:
        - Instance of the LM class.
        """
        additional_config = {} if additional_config is None else additional_config
        args = utils.simple_parse_args_string(arg_string)
        # Drop None-valued overrides so they don't clobber parsed args.
        args2 = {k: v for k, v in additional_config.items() if v is not None}
        return cls(**args, **args2)

    @classmethod
    def create_from_arg_obj(
        cls: Type[T], arg_dict: dict, additional_config: Optional[dict] = None
    ) -> T:
        """
        Creates an instance of the LM class using the given arg_obj

        Parameters:
        - arg_obj: A dict containing arguments in the format key1=value1,key2=value2.
        - additional_config: Optional dictionary containing additional configuration parameters.

        Returns:
        - Instance of the LM class.
        """

        additional_config = {} if additional_config is None else additional_config
        additional_config = {
            k: v for k, v in additional_config.items() if v is not None
        }

        return cls(**arg_dict, **additional_config)

    @property
    def rank(self):
        # used in the case of parallelism. Hardcoded to
        # ensure no errors arise using API models which do
        # not support multi-device parallelism nor expect it.
        return self._rank

    @property
    def world_size(self):
        # used in the case of parallelism. Hardcoded to
        # ensure no errors arise using API models which do
        # not support multi-device parallelism nor expect it.
        return self._world_size

    @property
    def tokenizer_name(self) -> str:
        """Must be defined for LM subclasses which implement Chat Templating.
        Should return the name of the tokenizer or chat template used.
        Used only to properly fingerprint caches when requests are being cached with `--cache_requests`, otherwise not used.
        """
        raise NotImplementedError(
            "To use this model with chat templates, please implement the 'tokenizer_name' property."
        )

    def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
        """Returns the chat template structure for user/assistant messages if a template is provided.
        This method is intended to be overridden in a subclass to define a specific chat template format.
        For models that do not support chat templates, this method returns None by default.
        """
        # NOTE(review): the docstring says "returns None by default" but the
        # base implementation returns "" — confirm which is intended.
        return ""

    def set_cache_hook(self, cache_hook) -> None:
        # Installed by CachingLM so results can be written back to its db.
        self.cache_hook = cache_hook
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
### SQLite-based caching of LM responses
|
| 212 |
+
def hash_args(attr, args):
    """Deterministic SHA-256 fingerprint of a request type plus its arguments.

    The attr name and the args are JSON-serialized together, so equal
    requests always map to the same cache key.
    """
    payload = json.dumps([attr, *args])
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
class CacheHook:
    """Hook handed to an LM so it can write results into a CachingLM's db.

    Constructed with None, the hook is a no-op (used as the default hook
    in the base LM class).
    """

    def __init__(self, cachinglm) -> None:
        # Share the CachingLM's SqliteDict, or disable with None.
        self.dbdict = None if cachinglm is None else cachinglm.dbdict

    def add_partial(self, attr, req, res) -> None:
        """Store `res` under the hash of (`attr`, `req`) if a cache is attached."""
        if self.dbdict is None:
            return
        self.dbdict[hash_args(attr, req)] = res
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
class CachingLM:
    def __init__(self, lm, cache_db) -> None:
        """LM wrapper that returns cached results if they exist, and uses the underlying LM if not.

        :param lm: LM
            Underlying LM
        :param cache_db: str
            Path to cache db
        """
        self.lm = lm
        self.cache_db = cache_db
        # Create the parent directory for the db file if one was given.
        if os.path.dirname(cache_db):
            os.makedirs(os.path.dirname(cache_db), exist_ok=True)
        self.dbdict = SqliteDict(cache_db, autocommit=True)

        # add hook to lm
        lm.set_cache_hook(self.get_cache_hook())

    def __getattr__(self, attr: str):
        # Intercept only the three cacheable request types; any other
        # attribute access is passed straight through to the wrapped LM.
        lm_attr = getattr(self.lm, attr)
        if attr not in ["loglikelihood", "loglikelihood_rolling", "generate_until"]:
            eval_logger.debug(f"Passing through attribute '{attr}' to underlying LM")
            return lm_attr

        def fn(requests):
            # Results aligned with `requests`; None marks a cache miss
            # to be filled in after running the underlying LM.
            res = []
            remaining_reqs = []
            warned = False
            # figure out which ones are cached and which ones are new
            eval_logger.info(
                f"Loading '{attr}' responses from cache '{self.cache_db}' where possible..."
            )
            for req in tqdm(requests, desc="Checking cached requests"):
                hsh = hash_args(attr, req.args)
                if attr == "generate_until" and req.args[1].get("do_sample", False):
                    # when we are doing non-greedy generation, don't use the cache
                    # (else every "randomly sampled" generation would be identical for repeats > 1).
                    if not warned:
                        eval_logger.warning(
                            f"Arguments to lm.generate_until() '{req.args[1]}' include non-deterministic sampling. Caching will not be performed for such requests."
                        )
                        warned = True
                    res.append(None)
                    remaining_reqs.append(req)
                elif hsh in self.dbdict:
                    ob = self.dbdict[hsh]

                    # A stored None would be indistinguishable from a miss.
                    assert ob is not None

                    res.append(ob)
                else:
                    res.append(None)
                    remaining_reqs.append(req)
            eval_logger.info(
                f"Cached requests: {len(requests) - len(remaining_reqs)}, Requests remaining: {len(remaining_reqs)}"
            )
            if remaining_reqs:
                # actually run the LM on the requests that do not have cached results
                rem_res = getattr(self.lm, attr)(remaining_reqs)
            else:
                rem_res = []

            # stick the new ones back into the list and also cache any of the new ones
            # resptr walks forward to each None slot (misses appear in order).
            resptr = 0
            for req, r in zip(remaining_reqs, rem_res):
                while res[resptr] is not None:
                    resptr += 1

                res[resptr] = r

                # caching
                hsh = hash_args(attr, req.args)
                self.dbdict[hsh] = r
            self.dbdict.commit()

            return res

        return fn

    def get_cache_hook(self):
        # Hook that lets the wrapped LM write partial results into our db.
        return CacheHook(self)
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
class TemplateLM(LM):
    """
    A class acting as intermediary between the LM base class
    and boilerplate often included in other LM subclasses.
    """

    # Subclasses that support chat templating set this to a real tokenizer.
    tokenizer = None

    @property
    @abc.abstractmethod
    def eot_token_id(self):
        pass

    @property
    def prefix_token_id(self):
        # it is used as prefix for loglikelihood
        return self.eot_token_id

    @abc.abstractmethod
    def tok_encode(self, string: str, **kwargs) -> List[int]:
        """
        Tokenize a string using the model's tokenizer and return a list of token IDs.
        """
        pass

    @abc.abstractmethod
    def _loglikelihood_tokens(self, requests, **kwargs) -> List[Tuple[float, bool]]:
        pass

    def _encode_pair(
        self, context: str, continuation: str
    ) -> Tuple[List[int], List[int]]:
        """Encode (context, continuation) so the continuation tokens are a
        suffix of the whole-string encoding (trailing context whitespace is
        moved onto the continuation first)."""
        n_spaces = len(context) - len(context.rstrip())
        if n_spaces > 0:
            continuation = context[-n_spaces:] + continuation
            context = context[:-n_spaces]

        model_class = getattr(self, "AUTO_MODEL_CLASS", None)

        if model_class == transformers.AutoModelForSeq2SeqLM:
            # Seq2seq: encoder/decoder halves are tokenized independently.
            context_enc = self.tok_encode(context)
            continuation_enc = self.tok_encode(continuation, add_special_tokens=False)
        else:
            # Causal LM: take the continuation as the suffix of the joint
            # encoding so tokenization at the boundary stays consistent.
            whole_enc = self.tok_encode(context + continuation)
            context_enc = self.tok_encode(context)

            context_enc_len = len(context_enc)
            continuation_enc = whole_enc[context_enc_len:]

        return context_enc, continuation_enc

    def loglikelihood(
        self, requests, disable_tqdm: bool = False
    ) -> List[Tuple[float, bool]]:
        """Encode each (context, continuation) request and delegate scoring
        to `_loglikelihood_tokens`."""
        new_reqs = []
        for context, continuation in [req.args for req in requests]:
            if context == "":
                # BOS or EOS as context
                context_enc, continuation_enc = (
                    [self.prefix_token_id],
                    self.tok_encode(continuation),
                )
            else:
                context_enc, continuation_enc = self._encode_pair(context, continuation)

            new_reqs.append(((context, continuation), context_enc, continuation_enc))

        return self._loglikelihood_tokens(new_reqs, disable_tqdm=disable_tqdm)

    @abc.abstractmethod
    def loglikelihood_rolling(
        self, requests, disable_tqdm: bool = False
    ) -> List[float]:
        pass

    @abc.abstractmethod
    def generate_until(self, requests, disable_tqdm: bool = False) -> List[str]:
        pass

    def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
        """
        Set and get the appropriate chat template for the model.
        This method sets the tokenizer's chat_template and returns the template string for reproducibility.

        The template selection logic is adapted from the Transformers library's `apply_chat_template`
        method in the Tokenizer class. The original implementation can be found at:
        https://github.com/huggingface/transformers/blob/fc35907f95459d7a6c5281dfadd680b6f7b620e3/src/transformers/tokenization_utils_base.py#L1687

        This method ensures that the right template is chosen based on the following:
        0. If the model has no 'tokenizer' attribute: assumes that there is only a single possible chat template, handled on the model provider side internally. Returns the empty string.
        1. If the model's tokenizer has multiple templates:
            a. Use the specified template if it exists in the dictionary.
            b. Use the default template from the list if no specific template is provided.
            c. Raise an error if no default template exists and no specific template is provided.
        2. If the model's tokenizer has a single template or no template:
            a. Use the tokenizer's chat template if available.
            b. Fall back to the default chat template if no tokenizer chat template exists.

        Args:
            chat_template (Union[bool, str]): Specifies the chat template to use.
                - If False or None, no template is applied.
                - If True, the default or only available template is used.
                - If a string, the template with the matching name is used.

        Returns:
            Optional[str]: The selected chat template, or None if no template is applied.
        """
        if self.tokenizer is None:
            return ""

        if chat_template is False or chat_template is None:
            eval_logger.warning(
                "model.chat_template was called with the chat_template set to False or None. "
                "Therefore no chat template will be applied. Make sure this is an intended behavior."
            )
            return None

        # Convert boolean chat_template to None to ensure compatibility with the adapted logic
        if isinstance(chat_template, bool):
            chat_template = None
        using_default_template = False

        # First, handle the cases when the model has a dict of multiple templates
        try:
            template = (
                self.tokenizer.chat_template or self.tokenizer.default_chat_template
            )
        except AttributeError:
            # Tokenizer exposes neither attribute: no template support.
            return None

        if isinstance(template, dict):
            # The dict came from the class-level default if chat_template is unset.
            using_default_dict = self.tokenizer.chat_template is None

            if chat_template is not None:
                if chat_template in template:
                    selected_template = template[chat_template]
                    if using_default_dict:
                        using_default_template = True
                else:
                    raise ValueError(
                        f"The specified chat template '{chat_template}' is not available. "
                        f"Available template names are {sorted(template.keys())}."
                    )
            else:
                # If user didn't pass a chat template, use the default template from the dict
                if "default" in template:
                    selected_template = template["default"]
                    using_default_template = True
                else:
                    raise ValueError(
                        "This model has multiple chat templates with no default specified! Please either pass a chat "
                        "template or the name of the template you wish to use to the `chat_template` argument. Available "
                        f"template names are {sorted(template.keys())}."
                    )

        # Cases when the model has a single template or no template
        else:
            # priority: `chat_template` argument > `tokenizer.chat_template` > `tokenizer.default_chat_template
            if isinstance(chat_template, str):
                eval_logger.warning(
                    "Chat template name provided, but the tokenizer's chat template is not a dictionary. "
                    "Using the tokenizer's chat template or the default template instead."
                )
            if self.tokenizer.chat_template is not None:
                selected_template = self.tokenizer.chat_template
            else:
                selected_template = self.tokenizer.default_chat_template
                using_default_template = True

        if using_default_template:
            eval_logger.warning(
                "No chat template is set for this tokenizer, falling back to a default class-level template. This is "
                "very error-prone, because models are often trained with templates different from the class default! "
                "Default chat templates are a legacy feature and will be removed in Transformers v4.43, at which "
                "point any code depending on them will stop working. We recommend setting a valid chat template before "
                "then to ensure that this model continues working without issues."
            )

        return selected_template
|
Prism/LLaDA/LLaDA_Baseline/dllm_eval/api/registry.py
ADDED
|
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from typing import Callable, Dict, Union
|
| 3 |
+
|
| 4 |
+
import evaluate as hf_evaluate
|
| 5 |
+
|
| 6 |
+
from dllm_eval.api.model import LM
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
# Module-level logger for registry diagnostics.
eval_logger = logging.getLogger(__name__)

# Maps model alias (str) -> LM subclass.
MODEL_REGISTRY = {}


def register_model(*names):
    """Class decorator: register an ``LM`` subclass under one or more aliases.

    Usage: ``@register_model("alias1", "alias2")``. Each alias becomes a key
    in ``MODEL_REGISTRY``; duplicate aliases are rejected.
    """

    def decorate(cls):
        for name in names:
            # Only LM subclasses are valid model implementations.
            assert issubclass(cls, LM), (
                f"Model '{name}' ({cls.__name__}) must extend LM class"
            )
            # Never silently overwrite a previously registered alias.
            assert name not in MODEL_REGISTRY, (
                f"Model named '{name}' conflicts with existing model! Please register with a non-conflicting alias instead."
            )
            MODEL_REGISTRY[name] = cls
        return cls

    return decorate


def get_model(model_name):
    """Return the registered model class for ``model_name``.

    Raises:
        ValueError: if no model was registered under that alias.
    """
    try:
        return MODEL_REGISTRY[model_name]
    except KeyError:
        raise ValueError(
            f"Attempted to load model '{model_name}', but no model for this name found! Supported model names: {', '.join(MODEL_REGISTRY.keys())}"
        )
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
# Task name -> registered task-constructor function.
TASK_REGISTRY = {}
# Group name -> list of member task names.
GROUP_REGISTRY = {}
# Union of all known task and group names.
ALL_TASKS = set()
# Maps a registered function's __name__ back to its task name,
# so register_group can resolve previously registered functions.
func2task_index = {}


def register_task(name):
    """Decorator: register ``fn`` as the task called ``name``."""

    def decorate(fn):
        # Duplicate task names are a hard error.
        assert name not in TASK_REGISTRY, (
            f"task named '{name}' conflicts with existing registered task!"
        )
        TASK_REGISTRY[name] = fn
        ALL_TASKS.add(name)
        func2task_index[fn.__name__] = name
        return fn

    return decorate


def register_group(name):
    """Decorator: add an already-registered task function to group ``name``.

    The function must have gone through ``register_task`` first, since the
    task name is looked up via ``func2task_index``.
    """

    def decorate(fn):
        task_name = func2task_index[fn.__name__]
        # Create the group on first use, then append the member task.
        GROUP_REGISTRY.setdefault(name, []).append(task_name)
        ALL_TASKS.add(name)
        return fn

    return decorate
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
# Registries keyed by metric / aggregation / filter name.
OUTPUT_TYPE_REGISTRY = {}
METRIC_REGISTRY = {}
METRIC_AGGREGATION_REGISTRY = {}
AGGREGATION_REGISTRY: Dict[str, Callable[[], Dict[str, Callable]]] = {}
HIGHER_IS_BETTER_REGISTRY = {}
FILTER_REGISTRY = {}

# Default metrics reported for each request output type.
DEFAULT_METRIC_REGISTRY = {
    "loglikelihood": [
        "perplexity",
        "acc",
    ],
    "loglikelihood_rolling": ["word_perplexity", "byte_perplexity", "bits_per_byte"],
    "multiple_choice": ["acc", "acc_norm"],
    "generate_until": ["exact_match"],
}


def register_metric(**args):
    """Decorator factory registering a metric function and its metadata.

    Recognized keyword args: ``metric`` (required name), ``higher_is_better``
    (bool), and ``aggregation`` (name of a previously registered aggregation).
    """
    # TODO: do we want to enforce a certain interface to registered metrics?

    def decorate(fn):
        assert "metric" in args
        name = args["metric"]

        for key, registry in (
            ("metric", METRIC_REGISTRY),
            ("higher_is_better", HIGHER_IS_BETTER_REGISTRY),
            ("aggregation", METRIC_AGGREGATION_REGISTRY),
        ):
            if key not in args:
                continue
            value = args[key]
            # Reject conflicting registrations for this key.
            assert value not in registry, (
                f"{key} named '{value}' conflicts with existing registered {key}!"
            )
            if key == "metric":
                registry[name] = fn
            elif key == "aggregation":
                # Resolve the aggregation name to its registered callable.
                registry[name] = AGGREGATION_REGISTRY[value]
            else:
                registry[name] = value

        return fn

    return decorate


def get_metric(name: str, hf_evaluate_metric=False) -> Callable:
    """Look up a metric by name, falling back to the HF Evaluate library.

    Returns the registered function, or ``metric.compute`` from HF Evaluate;
    returns None (after logging an error) when neither source has it.
    """
    if not hf_evaluate_metric:
        if name in METRIC_REGISTRY:
            return METRIC_REGISTRY[name]
        eval_logger.warning(
            f"Could not find registered metric '{name}' in lm-eval, searching in HF Evaluate library..."
        )

    try:
        return hf_evaluate.load(name).compute
    except Exception:
        eval_logger.error(
            f"{name} not found in the evaluate library! Please check https://huggingface.co/evaluate-metric",
        )


def register_aggregation(name: str):
    """Decorator factory registering an aggregation callable under ``name``."""

    def decorate(fn):
        assert name not in AGGREGATION_REGISTRY, (
            f"aggregation named '{name}' conflicts with existing registered aggregation!"
        )
        AGGREGATION_REGISTRY[name] = fn
        return fn

    return decorate


def get_aggregation(name: str) -> Callable[[], Dict[str, Callable]]:
    """Return the aggregation registered under ``name`` (None + warning if absent)."""
    try:
        return AGGREGATION_REGISTRY[name]
    except KeyError:
        eval_logger.warning(f"{name} not a registered aggregation metric!")


def get_metric_aggregation(name: str) -> Callable[[], Dict[str, Callable]]:
    """Return the default aggregation assigned to metric ``name`` (None + warning if absent)."""
    try:
        return METRIC_AGGREGATION_REGISTRY[name]
    except KeyError:
        eval_logger.warning(f"{name} metric is not assigned a default aggregation!")


def is_higher_better(metric_name) -> bool:
    """Return whether larger values of ``metric_name`` are better (None + warning if unknown)."""
    try:
        return HIGHER_IS_BETTER_REGISTRY[metric_name]
    except KeyError:
        eval_logger.warning(
            f"higher_is_better not specified for metric '{metric_name}'!"
        )


def register_filter(name):
    """Decorator factory registering a filter class; re-registration only logs."""

    def decorate(cls):
        if name in FILTER_REGISTRY:
            # Unlike tasks/metrics, overwriting a filter is tolerated.
            eval_logger.info(
                f"Registering filter `{name}` that is already in Registry {FILTER_REGISTRY}"
            )
        FILTER_REGISTRY[name] = cls
        return cls

    return decorate


def get_filter(filter_name: Union[str, Callable]) -> Callable:
    """Resolve a filter by name; a callable passed directly is returned as-is."""
    try:
        return FILTER_REGISTRY[filter_name]
    except KeyError as err:
        if callable(filter_name):
            return filter_name
        eval_logger.warning(f"filter `{filter_name}` is not registered!")
        raise err
|
Prism/LLaDA/LLaDA_Baseline/dllm_eval/api/samplers.py
ADDED
|
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import warnings
|
| 3 |
+
from functools import partial
|
| 4 |
+
from typing import TYPE_CHECKING, Iterable, Optional, Union
|
| 5 |
+
|
| 6 |
+
import datasets
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
if TYPE_CHECKING:
|
| 10 |
+
from random import Random
|
| 11 |
+
|
| 12 |
+
from dllm_eval.api.task import ConfigurableTask, Task
|
| 13 |
+
|
| 14 |
+
eval_logger = logging.getLogger("lm-eval")
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class ContextSampler:
    """Draws and renders few-shot example documents for a task.

    Samples docs from the task's few-shot split and formats them with the
    task's doc_to_text / doc_to_target / doc_to_choice callables, joined by
    the task-configured target and few-shot delimiters.
    """

    def __init__(
        self,
        docs: list[dict],
        task: Union["Task", "ConfigurableTask"],
        fewshot_indices: Optional[Iterable] = None,
        rnd: Optional["Random"] = None,
    ) -> None:
        # A caller-supplied random.Random is mandatory so few-shot draws
        # are reproducible across runs.
        self.rnd = rnd
        if not self.rnd:
            raise ValueError(
                "A `random.Random` generator argument must be provided to `rnd` of FewShotSampler!"
            )

        self.task = task
        self.config = task._config

        # Delimiters from the task config: between doc text and its target,
        # and between consecutive few-shot examples.
        self.target_delimiter = self.config.target_delimiter
        self.fewshot_delimiter = self.config.fewshot_delimiter

        # fewshot_config may override the task-level rendering callables;
        # otherwise fall back to the task's own implementations.
        if (
            self.config.fewshot_config is not None
            and self.config.fewshot_config.get("doc_to_text", None) is not None
        ):
            self.doc_to_text = partial(
                self.task.doc_to_text,
                doc_to_text=self.config.fewshot_config.get("doc_to_text", None),
            )
        else:
            self.doc_to_text = self.task.doc_to_text

        if (
            self.config.fewshot_config is not None
            and self.config.fewshot_config.get("doc_to_target", None) is not None
        ):
            self.doc_to_target = partial(
                self.task.doc_to_target,
                doc_to_target=self.config.fewshot_config.get("doc_to_target", None),
            )
        else:
            self.doc_to_target = self.task.doc_to_target

        if (
            self.config.fewshot_config is not None
            and self.config.fewshot_config.get("doc_to_choice", None) is not None
        ):
            self.doc_to_choice = partial(
                self.task.doc_to_choice,
                doc_to_choice=self.config.fewshot_config.get("doc_to_choice", None),
            )
        else:
            self.doc_to_choice = self.task.doc_to_choice

        self.docs = docs  # HF dataset split, provided by task._fewshot_docs()
        if fewshot_indices:  # subset few-shot docs to a fixed index list
            # fewshot_indices only makes sense on an HF dataset (uses .select).
            if not isinstance(self.docs, datasets.Dataset):
                raise ValueError(
                    "Got `fewshot_indices` but fewshot_docs are not a HF dataset. Don't use both `fewshot_indices` and a user-defined few-shot sample list simultaneously"
                )
            self.docs = self.docs.select(fewshot_indices)

    def get_context(self, doc: dict, num_fewshot: int, gen_prefix: Optional[str] = None):
        """Render ``num_fewshot`` labeled examples as a single context string.

        ``gen_prefix``, when given, is prepended (with a trailing space) to
        every example's target.
        """
        # draw an extra fewshot sample if using same split as evaluating on
        prefix = gen_prefix + " " if gen_prefix else ""
        n_samples = (
            num_fewshot + 1
            if self.config.fewshot_split == self.config.test_split
            else num_fewshot
        )

        # draw `n_samples` docs from fewshot_docs
        fewshotex = self.sample(n_samples)

        # get rid of the doc that's the one we're evaluating, if it's in the fewshot
        # TODO: should we just stop people from using fewshot from same split as evaluating?
        selected_docs = [x for x in fewshotex if x != doc][:num_fewshot]

        labeled_examples = ""
        # NOTE: the loop variable shadows the `doc` parameter from here on.
        for doc in selected_docs:
            doc_content = self.doc_to_text(doc)
            doc_target = self.doc_to_target(doc)
            # doc_to_text may return either rendered text (str) or a choice index.
            if self.config.doc_to_choice is None or isinstance(doc_content, str):
                labeled_examples += doc_content
            else:
                labeled_examples += self.doc_to_choice(doc)[doc_content]

            if doc_target != "":
                if self.target_delimiter.isspace() and str(doc_target)[0].isspace():
                    # TODO: add logger warn once here.
                    warnings.warn(
                        "Both target_delimiter and target start with a space. This may cause issues.",
                        Warning,
                        stacklevel=2,
                    )
                labeled_examples += self.target_delimiter
                labeled_examples += prefix
                # Target may be a list (take first), a plain string, or a
                # choice index resolved through doc_to_choice.
                labeled_examples += (
                    str(doc_target[0])
                    if isinstance(doc_target, list)
                    else doc_target
                    if self.config.doc_to_choice is None or isinstance(doc_target, str)
                    else str(self.doc_to_choice(doc)[doc_target])
                )
            labeled_examples += self.fewshot_delimiter

        return labeled_examples

    def get_chat_context(
        self,
        doc: dict,
        num_fewshot: int,
        fewshot_as_multiturn: bool = False,
        gen_prefix: Optional[str] = None,
    ):
        """Render few-shot examples as a chat ``[{"role", "content"}, ...]`` history.

        With ``fewshot_as_multiturn`` each example becomes a user/assistant
        turn pair; otherwise the whole few-shot context is a single user turn.
        """
        # TODO: Do we need any other delimiter
        prefix = gen_prefix + " " if gen_prefix else ""
        chat_history = []
        # draw an extra fewshot sample if using same split as evaluating on
        n_samples = (
            num_fewshot + 1
            if self.config.fewshot_split == self.config.test_split
            else num_fewshot
        )
        # draw `n_samples` docs from fewshot_docs
        fewshotex = self.sample(n_samples)

        # get rid of the doc that's the one we're evaluating, if it's in the fewshot
        # TODO: should we just stop people from using fewshot from same split as evaluating?
        selected_docs = [x for x in fewshotex if x != doc][:num_fewshot]

        if fewshot_as_multiturn:
            # NOTE: the loop variable shadows the `doc` parameter from here on.
            for doc in selected_docs:
                doc_content = self.doc_to_text(doc)
                doc_target = self.doc_to_target(doc)
                chat_history.append(
                    {
                        "role": "user",
                        "content": doc_content
                        if self.config.doc_to_choice is None
                        or isinstance(doc_content, str)
                        else self.doc_to_choice(doc)[doc_content],
                    }
                )
                chat_history.append(
                    {
                        "role": "assistant",
                        "content": prefix + str(doc_target[0])
                        if isinstance(doc_target, list)
                        else prefix + doc_target
                        if self.config.doc_to_choice is None
                        or isinstance(doc_target, str)
                        else prefix + str(self.doc_to_choice(doc)[doc_target]),
                    }
                )
        else:
            # get fewshot context as one user turn
            chat_history.append(
                {
                    "role": "user",
                    "content": self.get_context(
                        doc, num_fewshot, gen_prefix=gen_prefix
                    ),
                }
            )

        return chat_history

    def sample(self, n: int):
        """
        Draw `n` samples from our fewshot docs. This method should be overridden by subclasses.
        """

        return self.rnd.sample(self.docs, n)
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
class FirstNSampler(ContextSampler):
    def sample(self, n: int):
        """
        Draw the first `n` samples in order from the specified split.
        Used for tasks with "canonical" ordered fewshot examples, such as MMLU and CMMLU.

        Returns the leading slice of ``self.docs``. (The previous ``-> None``
        return annotation was incorrect — this method returns the slice.)
        """
        # Guard: cannot hand out more examples than the split contains.
        assert n <= len(self.docs), (
            f"Error: number of fewshot samples requested exceeds the {len(self.docs)} that are available."
        )
        return self.docs[:n]
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
class BalancedSampler(ContextSampler):
    # Placeholder subclass: sample() is not implemented yet and returns None,
    # so selecting this sampler would break get_context downstream.
    def sample(self, n: int) -> None:
        """
        TODO: this should return approximately class-balanced samples from our fewshot examples.
        TODO: what order should they be in? maybe random?
        """

        pass
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
class ManualSampler(ContextSampler):
    # Placeholder subclass: sample() is not implemented yet and returns None,
    # so selecting this sampler would break get_context downstream.
    def sample(self, n: int) -> None:
        """TODO: intended for manually specified few-shot selection; currently unimplemented."""
        pass
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
# Maps sampler-strategy name -> ContextSampler subclass.
SAMPLER_REGISTRY = {
    "default": ContextSampler,
    "first_n": FirstNSampler,
}


def get_sampler(name: str):
    """Resolve a few-shot context-sampler class by its registry name.

    Raises:
        ValueError: if ``name`` is not a key of ``SAMPLER_REGISTRY``.
    """
    try:
        return SAMPLER_REGISTRY[name]
    except KeyError:
        # Fixed copy-paste in the message: these are sampler names, not model names.
        raise ValueError(
            f"Attempted to use contextsampler '{name}', but no sampling strategy for this name found! Supported sampler names: {', '.join(SAMPLER_REGISTRY.keys())}"
        )
|
Prism/LLaDA/LLaDA_Baseline/dllm_eval/api/task.py
ADDED
|
@@ -0,0 +1,1881 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import abc
|
| 2 |
+
import ast
|
| 3 |
+
import logging
|
| 4 |
+
import random
|
| 5 |
+
import re
|
| 6 |
+
from collections.abc import Callable
|
| 7 |
+
from copy import deepcopy
|
| 8 |
+
from dataclasses import asdict, dataclass
|
| 9 |
+
from inspect import getsource
|
| 10 |
+
from typing import (
|
| 11 |
+
Any,
|
| 12 |
+
Dict,
|
| 13 |
+
Iterable,
|
| 14 |
+
Iterator,
|
| 15 |
+
List,
|
| 16 |
+
Literal,
|
| 17 |
+
Mapping,
|
| 18 |
+
Optional,
|
| 19 |
+
Tuple,
|
| 20 |
+
Union,
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
import datasets
|
| 24 |
+
import numpy as np
|
| 25 |
+
from tqdm import tqdm
|
| 26 |
+
|
| 27 |
+
from dllm_eval import utils
|
| 28 |
+
from dllm_eval.api import samplers
|
| 29 |
+
from dllm_eval.api.instance import Instance, OutputType
|
| 30 |
+
from dllm_eval.api.metrics import bits_per_byte, mean, weighted_perplexity
|
| 31 |
+
from dllm_eval.api.registry import (
|
| 32 |
+
AGGREGATION_REGISTRY,
|
| 33 |
+
DEFAULT_METRIC_REGISTRY,
|
| 34 |
+
get_aggregation,
|
| 35 |
+
get_metric,
|
| 36 |
+
get_metric_aggregation,
|
| 37 |
+
is_higher_better,
|
| 38 |
+
)
|
| 39 |
+
from dllm_eval.caching.cache import load_from_cache, save_to_cache
|
| 40 |
+
from dllm_eval.filters import build_filter_ensemble
|
| 41 |
+
from dllm_eval.prompts import get_prompt
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
ALL_OUTPUT_TYPES = [
|
| 45 |
+
"loglikelihood",
|
| 46 |
+
"multiple_choice",
|
| 47 |
+
"loglikelihood_rolling",
|
| 48 |
+
"generate_until",
|
| 49 |
+
]
|
| 50 |
+
|
| 51 |
+
eval_logger = logging.getLogger(__name__)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
@dataclass
class TaskConfig(dict):
    """Declarative configuration for a single task, usually parsed from YAML.

    Subclasses ``dict`` so that instances support both attribute access
    (``cfg.task``) and item access (``cfg["task"]``) via the ``__getitem__``
    / ``__setitem__`` overrides below. Fields left as ``None`` are treated
    as "unset" and are omitted by :meth:`to_dict`.
    """

    # task naming/registry
    task: Optional[str] = None  # registry name of the task
    task_alias: Optional[str] = None  # display name used in results tables
    tag: Optional[Union[str, list]] = None  # group tag(s) for bulk selection
    # HF dataset options.
    # which dataset to use,
    # and what splits for what purpose
    custom_dataset: Optional[Callable] = None  # callable overriding HF loading
    dataset_path: Optional[str] = None  # HF hub path or local loading script
    dataset_name: Optional[str] = None  # HF dataset config/subset name
    dataset_kwargs: Optional[dict] = None  # extra kwargs for load_dataset
    training_split: Optional[str] = None
    validation_split: Optional[str] = None
    test_split: Optional[str] = None
    fewshot_split: Optional[str] = (
        None  # TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaluating (?)
    )
    # formatting / prompting options.
    # see docs/advanced_task_guide.md for more info
    process_docs: Optional[Callable] = None  # per-split doc preprocessing hook
    doc_to_text: Optional[Union[Callable, str]] = None  # doc -> prompt text
    doc_to_target: Optional[Union[Callable, str]] = None  # doc -> gold target
    doc_to_image: Optional[Union[Callable, str]] = None  # doc -> image(s); marks task multimodal
    doc_to_audio: Optional[Union[Callable, str]] = None  # doc -> audio; marks task multimodal
    unsafe_code: bool = False  # True when scoring executes model-written code
    doc_to_choice: Optional[Union[Callable, str, dict, list]] = None  # answer choices for multiple_choice
    process_results: Optional[Union[Callable, str]] = None  # custom results hook
    use_prompt: Optional[str] = None  # name of a registered prompt template
    description: str = ""  # prepended once before fewshot examples
    target_delimiter: str = " "  # separator between prompt and target
    fewshot_delimiter: str = "\n\n"  # separator between fewshot examples
    fewshot_config: Optional[dict] = None  # sampler configuration for fewshots
    # runtime configuration options
    num_fewshot: Optional[int] = None
    # scoring options
    metric_list: Optional[list] = None  # list of {"metric": ..., ...} dicts
    output_type: OutputType = "generate_until"
    generation_kwargs: Optional[dict] = None  # kwargs forwarded to model.generate
    repeats: int = 1  # inferences per instance (e.g. for majority voting)
    filter_list: Optional[Union[str, list]] = None  # post-hoc filter pipelines
    should_decontaminate: bool = False
    doc_to_decontamination_query: Optional[str] = None
    gen_prefix: Optional[str] = None  # text force-prepended to the generation
    metadata: Optional[dict] = (
        None  # by default, not used in the code. allows for users to pass arbitrary info to tasks
    )

    def __post_init__(self) -> None:
        """Normalize generation settings after dataclass construction."""
        if self.generation_kwargs is not None:
            if self.output_type != "generate_until":
                eval_logger.warning(
                    f"[{self.task}] passed `generation_kwargs`, but not using `output_type: generate_until`!"
                )

            if "temperature" in self.generation_kwargs:
                # YAML may parse temperature as an int or string; coerce to float.
                self.generation_kwargs["temperature"] = float(
                    self.generation_kwargs["temperature"]
                )

            if "until" not in self.generation_kwargs:
                # Fall back to the fewshot delimiter as the stop sequence so
                # generation halts at the next example boundary.
                eval_logger.warning(
                    f"{self.task}: No `until` specified in `generation_kwargs`! Defaulting to the fewshot_delimiter={repr(self.fewshot_delimiter)}"
                )
                self.generation_kwargs["until"] = [self.fewshot_delimiter]
        else:
            if self.output_type == "generate_until":
                # ensure that we greedily generate in absence of explicit arguments otherwise
                self.generation_kwargs = {
                    "until": (
                        None
                        if self.fewshot_delimiter is None
                        else [self.fewshot_delimiter]
                    ),
                    "do_sample": False,
                    "temperature": 0,
                }
                eval_logger.warning(
                    f"{self.task}: No `generation_kwargs` specified in task config, defaulting to {self.generation_kwargs}"
                )

    def __getitem__(self, item):
        """Dict-style read access delegating to dataclass attributes."""
        return getattr(self, item)

    def __setitem__(self, item, value):
        """Dict-style write access delegating to dataclass attributes."""
        return setattr(self, item, value)

    def to_dict(self, keep_callable: bool = False) -> dict:
        """dumps the current config as a dictionary object, as a printable format.
        null fields will not be printed.
        Used for dumping results alongside full task configuration

        :return: dict
            A printable dictionary version of the TaskConfig object.

        # TODO: should any default value in the TaskConfig not be printed?
        """
        # asdict() returns a recursive copy, so the in-place edits below do
        # not mutate this config's own metric_list entries.
        cfg_dict = asdict(self)
        # remove values that are `None`
        for k, v in list(cfg_dict.items()):
            if v is None:
                cfg_dict.pop(k)
            elif k == "metric_list":
                # metric entries may carry callables (custom metric fns);
                # replace them with their source text for printability.
                for metric_dict in v:
                    for metric_key, metric_value in metric_dict.items():
                        if callable(metric_value):
                            metric_dict[metric_key] = self.serialize_function(
                                metric_value, keep_callable=keep_callable
                            )
                cfg_dict[k] = v
            elif callable(v):
                cfg_dict[k] = self.serialize_function(v, keep_callable=keep_callable)
        return cfg_dict

    def serialize_function(
        self, value: Union[Callable, str], keep_callable=False
    ) -> Union[Callable, str]:
        """Serializes a given function or string.

        If 'keep_callable' is True, the original callable is returned.
        Otherwise, attempts to return the source code of the callable using 'getsource'.
        """
        if keep_callable:
            return value
        else:
            try:
                return getsource(value)
            except (TypeError, OSError):
                # builtins/lambdas defined interactively have no source file;
                # fall back to repr-like text.
                return str(value)
| 185 |
+
|
| 186 |
+
class Task(abc.ABC):
    """A task represents an entire benchmark including its dataset, problems,
    answers, and evaluation methods. See BoolQ for a simple example implementation

    A `doc` can be any python object which represents one instance of evaluation.
    This is usually a dictionary, e.g.
        {"question": ..., "answer": ...} or
        {"question": ..., "choices": ..., "answer": ...}
    """

    # Task version, bumped when scoring-relevant behavior changes.
    VERSION: Optional[Union[int, str]] = None

    # The name of the `Task` benchmark as denoted in the HuggingFace datasets Hub
    # or a path to a custom `datasets` loading script.
    DATASET_PATH: Optional[str] = None

    # The name of a subset within `DATASET_PATH`.
    DATASET_NAME: Optional[str] = None

    # One of ALL_OUTPUT_TYPES; determines what kind of requests are built.
    OUTPUT_TYPE: Optional[OutputType] = None

    def __init__(
        self,
        data_dir: Optional[str] = None,
        cache_dir: Optional[str] = None,
        download_mode: Optional[datasets.DownloadMode] = None,
        config: Optional[Mapping] = None,  # Union[dict, TaskConfig]
    ) -> None:
        """
        :param data_dir: str
            Stores the path to a local folder containing the `Task`'s data files.
            Use this to specify the path to manually downloaded data (usually when
            the dataset is not publicly accessible).
        :param cache_dir: str
            The directory to read/write the `Task` dataset. This follows the
            HuggingFace `datasets` API with the default cache directory located at:
                `~/.cache/huggingface/datasets`
            NOTE: You can change the cache location globally for a given process
            to another directory:
                `export HF_DATASETS_CACHE="/path/to/another/directory"`
        :param download_mode: datasets.DownloadMode
            How to treat pre-existing `Task` downloads and data.
            - `datasets.DownloadMode.REUSE_DATASET_IF_EXISTS`
                Reuse download and reuse dataset.
            - `datasets.DownloadMode.REUSE_CACHE_IF_EXISTS`
                Reuse download with fresh dataset.
            - `datasets.DownloadMode.FORCE_REDOWNLOAD`
                Fresh download and fresh dataset.
        """
        self.download(data_dir, cache_dir, download_mode)
        # Lazily-populated caches of docs and built request Instances.
        self._training_docs: Optional[list] = None
        self._fewshot_docs: Optional[list] = None
        self._instances: Optional[List[Instance]] = None

        # NOTE(review): `TaskConfig({**config})` passes the merged dict as the
        # first positional field (`task`) rather than as kwargs — presumably
        # relied on only by subclasses that override config handling; confirm.
        self._config: TaskConfig = TaskConfig({**config}) if config else TaskConfig()

        # Default filter pipeline: keep only the first model response.
        self._filters = [build_filter_ensemble("none", [["take_first", None]])]
        self.fewshot_rnd: Optional[random.Random] = (
            None  # purposely induce errors in case of improper usage
        )

    def download(
        self,
        data_dir: Optional[str] = None,
        cache_dir: Optional[str] = None,
        download_mode=None,
    ) -> None:
        """Downloads and returns the task dataset.
        Override this method to download the dataset from a custom API.

        :param data_dir: str
            Stores the path to a local folder containing the `Task`'s data files.
            Use this to specify the path to manually downloaded data (usually when
            the dataset is not publicly accessible).
        :param cache_dir: str
            The directory to read/write the `Task` dataset. This follows the
            HuggingFace `datasets` API with the default cache directory located at:
                `~/.cache/huggingface/datasets`
            NOTE: You can change the cache location globally for a given process
            by setting the shell environment variable, `HF_DATASETS_CACHE`,
            to another directory:
                `export HF_DATASETS_CACHE="/path/to/another/directory"`
        :param download_mode: datasets.DownloadMode
            How to treat pre-existing `Task` downloads and data.
            - `datasets.DownloadMode.REUSE_DATASET_IF_EXISTS`
                Reuse download and reuse dataset.
            - `datasets.DownloadMode.REUSE_CACHE_IF_EXISTS`
                Reuse download with fresh dataset.
            - `datasets.DownloadMode.FORCE_REDOWNLOAD`
                Fresh download and fresh dataset.
        """
        self.dataset = datasets.load_dataset(
            path=self.DATASET_PATH,
            name=self.DATASET_NAME,
            data_dir=data_dir,
            cache_dir=cache_dir,
            download_mode=download_mode,
        )

    @property
    def config(self) -> TaskConfig:
        """Returns the TaskConfig associated with this class."""
        return self._config

    @abc.abstractmethod
    def has_training_docs(self):
        """Whether the task has a training set"""
        pass

    @abc.abstractmethod
    def has_validation_docs(self):
        """Whether the task has a validation set"""
        pass

    @abc.abstractmethod
    def has_test_docs(self):
        """Whether the task has a test set"""
        pass

    def training_docs(self) -> Iterable:
        """
        :return: Iterable[obj]
            A iterable of any object, that doc_to_text can handle
        """
        return []

    def validation_docs(self) -> Iterable:
        """
        :return: Iterable[obj]
            A iterable of any object, that doc_to_text can handle
        """
        return []

    def test_docs(self) -> Iterable:
        """
        :return: Iterable[obj]
            A iterable of any object, that doc_to_text can handle
        """
        return []

    def fewshot_docs(self) -> Iterable:
        """Return the doc pool used for drawing fewshot examples.

        Preference order: training split, then validation split, then (with a
        warning when fewshots are requested) the test split itself.

        :return: Iterable[obj]
            A iterable of any object, that doc_to_text can handle
        """
        if self.has_training_docs():
            return self.training_docs()
        elif self.has_validation_docs():
            return self.validation_docs()
        else:
            if self.config.get("num_fewshot", 0) > 0:
                eval_logger.warning(
                    f"[Task: {self.config.task}] has_training_docs and has_validation_docs are False"
                    ", using test_docs as fewshot_docs but this is not recommended."
                )
            return self.test_docs()

    def _process_doc(self, doc: dict) -> dict:
        """
        Override this to process (detokenize, strip, replace, etc.) individual
        documents. This can be used in a map over documents of a data split.
        E.g. `map(self._process_doc, self.dataset["validation"])`

        :return: dict
            The processed version of the specified `doc`.
        """
        return doc

    @property
    def instances(self) -> List[Instance]:
        """After calling `task.build_all_requests()`, tasks
        maintain a list of the dataset instances which will be evaluated.
        """
        return self._instances

    def fewshot_examples(self, k, rnd):
        """Sample `k` fewshot examples from the (cached) training docs using `rnd`."""
        if self._training_docs is None:
            self._training_docs = list(self.training_docs())

        return rnd.sample(self._training_docs, k)

    def doc_to_decontamination_query(self, doc):
        """Map a doc to the query string used for decontamination checks."""
        raise NotImplementedError(
            "Override doc_to_decontamination_query with document specific decontamination query."
        )

    @abc.abstractmethod
    def doc_to_text(self, doc):
        """Map a doc to its prompt text."""
        pass

    @abc.abstractmethod
    def doc_to_target(self, doc):
        """Map a doc to its gold target."""
        pass

    # not an abstractmethod because not every language-only task has to implement this
    def doc_to_image(self, doc):
        raise NotImplementedError

    def doc_to_audio(self, doc):
        raise NotImplementedError

    def doc_to_prefix(self, doc):
        """Text force-prepended to the model's generation; empty by default."""
        return ""

    def build_all_requests(
        self,
        *,
        limit: Union[int, None] = None,
        samples: Optional[List[int]] = None,
        rank: int = 0,
        world_size: int = 1,
        cache_requests: bool = False,
        rewrite_requests_cache: bool = False,
        system_instruction: Optional[str] = None,
        apply_chat_template: bool = False,
        fewshot_as_multiturn: bool = False,
        chat_template: Optional[Callable] = None,
        tokenizer_name: str = "",
    ) -> None:
        """Build a set of Instances for a task, and store them in task.instances"""

        # used with caching
        og_limit = limit

        # The cache key encodes every option that changes the built contexts,
        # so stale caches are never reused across differing settings.
        cache_key = f"requests-{self._config.task}-{self.config.num_fewshot}shot-rank{rank}-world_size{world_size}"
        cache_key += "-chat_template" if apply_chat_template else ""
        cache_key += "-fewshot_as_multiturn" if fewshot_as_multiturn else ""
        cache_key += (
            f"-system_prompt_hash{utils.hash_string(system_instruction)}"
            if system_instruction is not None
            else ""
        )
        cache_key += f"-tokenizer{tokenizer_name}"

        cached_instances = load_from_cache(file_name=cache_key, cache=cache_requests)

        if cache_requests and cached_instances and not rewrite_requests_cache:
            # Cache hit: slice to the requested limit and flatten the
            # per-doc instance groups into one list.
            cached_instances = cached_instances[:limit]

            flattened_instances = [
                instance
                for instance_group in cached_instances
                for instance in instance_group
            ]

            self._instances = flattened_instances
            return

        eval_logger.info(f"Building contexts for {self.config.task} on rank {rank}...")

        instances = []

        # process all documents when caching is specified for simplicity
        if (
            cache_requests
            and (not cached_instances or rewrite_requests_cache)
            and limit is not None
        ):
            limit = None

        doc_id_docs = list(
            self.doc_iterator(
                rank=rank, limit=limit, samples=samples, world_size=world_size
            )
        )

        num_docs = len(doc_id_docs)

        for doc_id, doc in tqdm(
            doc_id_docs,
            total=num_docs,
        ):
            # sample fewshot context #TODO: need to offset doc_id by rank now!
            fewshot_ctx = self.fewshot_context(
                doc,
                num_fewshot=0
                if self.config.num_fewshot is None
                else self.config.num_fewshot,
                system_instruction=system_instruction,
                apply_chat_template=apply_chat_template,
                fewshot_as_multiturn=fewshot_as_multiturn,
                chat_template=chat_template,
                gen_prefix=self.doc_to_prefix(doc),
            )

            # TODO: we should override self.config.repeats if doing greedy gen so users don't waste time+compute
            inst = self.construct_requests(
                doc=doc,
                ctx=fewshot_ctx,
                metadata=(self.config["task"], doc_id, self.config.repeats),
                apply_chat_template=apply_chat_template,
                chat_template=chat_template,
            )

            # construct_requests may return one Instance or a list of them;
            # normalize to a list so caching/flattening is uniform.
            if not isinstance(inst, list):
                inst = [inst]

            instances.append(inst)

        # now flatten, this is to allow slicing to work with pickles

        sliced_instances = instances[:og_limit]

        flattened_instances = [
            instance
            for instance_group in sliced_instances
            for instance in instance_group
        ]

        self._instances = flattened_instances

        if len(self._instances) == 0:
            raise ValueError("task.build_requests() did not find any docs!")

        if cache_requests and (not cached_instances or rewrite_requests_cache):
            # Save the *unsliced* per-doc groups so future runs can apply
            # any limit themselves.
            save_to_cache(file_name=cache_key, obj=instances)

    @abc.abstractmethod
    def construct_requests(self, doc, ctx, **kwargs):
        """Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
        :param doc_idx: int
            The index of a document within `self.test_docs()` or `self.validation_docs()`,
            whichever is the main split used.
        :param repeats: int
            TODO: update this docstring
            The number of times each instance in a dataset is inferred on. Defaults to 1,
            can be increased for techniques like majority voting.
        """
        pass

    @abc.abstractmethod
    def process_results(self, doc, results):
        """Take a single document and the LM results and evaluates, returning a
        dict where keys are the names of submetrics and values are the values of
        the metric for that one document

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param results:
            The results of the requests created in construct_requests.
        """
        pass

    @abc.abstractmethod
    def aggregation(self):
        """
        :returns: {str: [metric_score] -> float}
            A dictionary where keys are the names of submetrics and values are
            functions that aggregate a list of metric scores
        """
        pass

    @abc.abstractmethod
    def higher_is_better(self):
        """
        :returns: {str: bool}
            A dictionary where keys are the names of submetrics and values are
            whether a higher value of the submetric is better
        """
        pass

    def get_config(self, key: str) -> Any:
        """Read a single config attribute, returning None when unset."""
        return getattr(self._config, key, None)

    @classmethod
    def count_bytes(cls, doc):
        """Used for byte-level perplexity metrics in rolling loglikelihood"""
        return len(doc.encode("utf-8"))

    @classmethod
    def count_words(cls, doc):
        """Downstream loglikelihood_rolling perplexity tasks with custom word boundaries should override this!"""
        return len(re.split(r"\s+", doc))

    @utils.positional_deprecated
    def fewshot_context(self, doc, num_fewshot, rnd=None, description=None, **kwargs):
        """Returns a fewshot context string that is made up of a prepended description
        (if provided), the `num_fewshot` number of examples, and an appended prompt example.

        :param doc: str
            The document as returned from training_docs, validation_docs, or test_docs.
        :param num_fewshot: int
            The number of fewshot examples to provide in the returned context string.
        :param rnd: random.Random
            The pseudo-random number generator used to randomly sample examples.
            WARNING: This is currently a required arg although it's optionalized with a default `None`.
        :param description: str
            The task's description that will be prepended to the fewshot examples.
        :returns: str
            The fewshot context.
        """
        if rnd is None:
            # Fall back to the seed set via set_fewshot_seed(); otherwise fail
            # loudly rather than silently using an unseeded generator.
            if self.fewshot_rnd is not None:
                rnd = self.fewshot_rnd
            else:
                raise ValueError(
                    "A `random.Random` generator argument must be provided to `rnd`"
                )

        description = description if description else ""

        if num_fewshot == 0:
            labeled_examples = ""
        else:
            # for sets with no training docs, draw from other set *but ensure no overlap with current doc*
            if self.has_training_docs():
                fewshotex = self.fewshot_examples(k=num_fewshot, rnd=rnd)
            else:
                if self._fewshot_docs is None:
                    self._fewshot_docs = list(
                        self.validation_docs()
                        if self.has_validation_docs()
                        else self.test_docs()
                    )

                # Oversample by one so we can drop the evaluated doc if drawn.
                fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1)

                # get rid of the doc that's the one we're evaluating, if it's in the fewshot
                fewshotex = [x for x in fewshotex if x != doc][:num_fewshot]

            labeled_examples = (
                "\n\n".join(
                    [
                        self.doc_to_text(doc) + self.doc_to_target(doc)
                        for doc in fewshotex
                    ]
                )
                + "\n\n"
            )

        example = self.doc_to_text(doc)
        return description + labeled_examples + example

    def apply_filters(self) -> Optional[List[Instance]]:
        """Iterates over FilterEnsembles and applies them to instances"""
        if hasattr(self, "_filters"):
            # Filters mutate instances in place (attaching filtered responses).
            for f in self._filters:
                f.apply(self._instances)
        else:
            eval_logger.warning("No filter defined, passing through instances")
            return self._instances

    def dump_config(self) -> dict:
        """Returns the config as a dictionary."""
        # TODO: this should only return the overrides applied to a non-YAML task's configuration.
        # (num_fewshot)
        return self.config.to_dict()

    def set_config(self, key: str, value: Any, update: bool = False) -> None:
        """Set or update the configuration for a given key."""
        if key is None:
            raise ValueError("Key must be provided.")

        if update:
            # Merge into an existing dict-valued config entry.
            current_value = getattr(self._config, key, {})
            if not isinstance(current_value, dict):
                raise TypeError(
                    f"Expected a dict for key '{key}', got {type(current_value).__name__} instead."
                )
            current_value.update(value)
        else:
            setattr(self._config, key, value)

    def override_metric(self, metric_name: str) -> None:
        """
        Override the default metrics used for evaluation with custom metrics.

        Parameters:
        - metric_name (str): The name of the custom metric to override. Should be registered in api.metrics.
        """
        (
            self._metric_fn_list,
            self._aggregation_list,
            self._metric_fn_kwargs,
            self._higher_is_better,
        ) = ({}, {}, {}, {})
        self._metric_fn_list[metric_name] = get_metric(metric_name)
        self._aggregation_list[metric_name] = get_metric_aggregation(metric_name)
        self._higher_is_better[metric_name] = is_higher_better(metric_name)
        self._metric_fn_kwargs[metric_name] = {}
        if not isinstance(self, ConfigurableTask):
            # NOTE(review): this lambda returns the metric *function* rather
            # than applying it to (x, y) — looks suspicious; confirm intended.
            self.process_results = lambda x, y: {metric_name: get_metric(metric_name)}
            self.aggregation = lambda: {
                metric_name: get_metric_aggregation(metric_name)
            }
        setattr(self._config, "metric_list", [{"metric": metric_name}])
        setattr(self._config, "process_results", None)

    def set_fewshot_seed(self, seed: Optional[int] = None) -> None:
        """Seed the fewshot sampler (and the sampler object, if one exists)."""
        self.fewshot_rnd = random.Random(seed)
        if hasattr(self, "sampler"):
            self.sampler.rnd = self.fewshot_rnd

    @property
    def eval_docs(self) -> Union[datasets.Dataset, List[dict]]:
        """The split actually evaluated: test if available, else validation."""
        if self.has_test_docs():
            return self.test_docs()
        elif self.has_validation_docs():
            return self.validation_docs()
        else:
            raise ValueError(
                f"Task dataset (path={self.DATASET_PATH}, name={self.DATASET_NAME}) must have valid or test docs!"
            )

    def doc_iterator(
        self,
        *,
        rank: int = 0,
        limit: Union[int, None] = None,
        world_size: int = 1,
        samples: Optional[List[int]] = None,
    ) -> Iterator[Tuple[int, Any]]:
        """Yield (doc_id, doc) pairs for this rank's shard of eval_docs.

        When `samples` is given, only those indices are evaluated and `limit`
        is ignored; otherwise docs are sharded across ranks up to `limit`.
        """
        if samples:
            n = len(self.eval_docs)
            assert all([e < n for e in samples]), (
                f"Elements of --samples should be in the interval [0,k-1] where k is the number of total examples. In this case, k={n}."
            )
            eval_logger.info(
                f"{self.config.task}: Evaluating on {len(samples)} examples"
            )
            doc_iterator = utils.create_iterator(
                enumerate(x for i, x in enumerate(self.eval_docs) if i in samples),
                rank=int(rank),
                limit=None,  # limit does not matter here since we are selecting samples directly
                world_size=int(world_size),
            )
        else:
            limit = int(limit) if limit else None
            doc_iterator = utils.create_iterator(
                enumerate(self.eval_docs),
                rank=int(rank),
                limit=limit,
                world_size=int(world_size),
            )
        return doc_iterator
| 730 |
+
|
| 731 |
+
class ConfigurableTask(Task):
|
| 732 |
+
VERSION = "Yaml"
|
| 733 |
+
OUTPUT_TYPE = None
|
| 734 |
+
CONFIG = None
|
| 735 |
+
|
| 736 |
+
    def __init__(
        self,
        data_dir=None,
        cache_dir=None,
        download_mode=None,
        config: Optional[dict] = None,
    ) -> None:  # TODO no super() call here
        """Build a task from a config dict (or the class-level CONFIG).

        Resolves output type, metrics/aggregations, filters, prompt, few-shot
        sampler, then downloads the dataset and sanity-checks one document.
        """
        # Get pre-configured attributes
        self._config = self.CONFIG

        # Use new configurations if there was no preconfiguration
        if self.config is None:
            self._config = TaskConfig(**config)
        # Overwrite configs
        else:
            if config is not None:
                self._config.__dict__.update(config)

        if self.config is None:
            raise ValueError(
                "Must pass a config to ConfigurableTask, either in cls.CONFIG or `config` kwarg"
            )

        # Allow config metadata to pin a task version.
        if isinstance(self.config.metadata, dict):
            if "version" in self.config.metadata:
                self.VERSION = self.config.metadata["version"]

        if self.config.output_type is not None:
            if self.config.output_type not in ALL_OUTPUT_TYPES:
                raise ValueError(
                    f"Got invalid output_type '{self.config.output_type}', must be in '{','.join(ALL_OUTPUT_TYPES)}'"
                )
            self.OUTPUT_TYPE = self.config.output_type

        if self.config.doc_to_image is not None:
            # mark the task as requiring multimodality.
            self.MULTIMODAL = True

        if self.config.doc_to_audio:
            # mark the task as requiring multimodality.
            self.MULTIMODAL = True

        if self.config.unsafe_code is not False:
            self.UNSAFE_CODE = True

        if self.config.dataset_path is not None:
            self.DATASET_PATH = self.config.dataset_path

        if self.config.dataset_name is not None:
            self.DATASET_NAME = self.config.dataset_name

        # Per-metric lookup tables filled below.
        self._metric_fn_list = {}
        self._metric_fn_kwargs = {}
        self._aggregation_list = {}
        self._higher_is_better = {}

        if self.config.metric_list is None:
            # No metrics configured: fall back to the defaults for this output type.
            # TODO: handle this in TaskConfig.__post_init__ ?
            _metric_list = DEFAULT_METRIC_REGISTRY[self.config.output_type]

            for metric_name in _metric_list:
                self._metric_fn_list[metric_name] = get_metric(metric_name)
                self._metric_fn_kwargs[metric_name] = {}
                self._aggregation_list[metric_name] = get_metric_aggregation(
                    metric_name
                )
                self._higher_is_better[metric_name] = is_higher_better(metric_name)
        else:
            for metric_config in self.config.metric_list:
                if "metric" not in metric_config:
                    raise ValueError(
                        "'metric' key not provided for an entry in 'metric_list', must be specified!"
                    )
                metric_name = metric_config["metric"]
                # Everything except the reserved keys is forwarded to the metric fn.
                kwargs = {
                    key: metric_config[key]
                    for key in metric_config
                    if key
                    not in ["metric", "aggregation", "higher_is_better", "hf_evaluate"]
                }
                hf_evaluate_metric = (
                    "hf_evaluate" in metric_config
                    and metric_config["hf_evaluate"] is True
                )

                if self.config.process_results is not None:
                    # Task computes results itself; no metric fn needed.
                    self._metric_fn_list[metric_name] = None
                    self._metric_fn_kwargs[metric_name] = {}
                elif callable(metric_name):
                    metric_fn = metric_name.__call__
                    metric_name = metric_name.__name__
                    self._metric_fn_list[metric_name] = metric_fn
                    self._metric_fn_kwargs[metric_name] = kwargs
                else:
                    self._metric_fn_list[metric_name] = get_metric(
                        metric_name, hf_evaluate_metric
                    )
                    self._metric_fn_kwargs[metric_name] = kwargs

                if "aggregation" in metric_config:
                    agg_name = metric_config["aggregation"]
                    if isinstance(agg_name, str):
                        self._aggregation_list[metric_name] = get_aggregation(agg_name)
                    elif callable(agg_name):  # noqa: E721
                        self._aggregation_list[metric_name] = metric_config[
                            "aggregation"
                        ]
                else:
                    # No aggregation specified: warn and use the metric's default.
                    INV_AGG_REGISTRY = {v: k for k, v in AGGREGATION_REGISTRY.items()}
                    metric_agg = get_metric_aggregation(metric_name)
                    eval_logger.warning(
                        f"[Task: {self.config.task}] metric {metric_name} is defined, but aggregation is not. "
                        f"using default "
                        f"aggregation={INV_AGG_REGISTRY[metric_agg]}"
                    )
                    self._aggregation_list[metric_name] = metric_agg

                if "higher_is_better" in metric_config:
                    self._higher_is_better[metric_name] = metric_config[
                        "higher_is_better"
                    ]
                else:
                    eval_logger.warning(
                        f"[Task: {self.config.task}] metric {metric_name} is defined, but higher_is_better is not. "
                        f"using default "
                        f"higher_is_better={is_higher_better(metric_name)}"
                    )
                    self._higher_is_better[metric_name] = is_higher_better(metric_name)

        self.download(self.config.dataset_kwargs)
        self._training_docs = None
        self._fewshot_docs = None

        # Build the response filter pipelines (defaults to take_first on repeats).
        if self.config.filter_list is not None:
            self._filters = []
            for filter_config in self.config.filter_list:
                filter_name = filter_config["name"]
                filter_functions = filter_config["filter"]
                components = []
                for function in filter_functions:
                    kwargs = {
                        key: function[key] for key in function if key != "function"
                    }
                    components.append([function["function"], kwargs])
                filter_pipeline = build_filter_ensemble(filter_name, components)
                self._filters.append(filter_pipeline)
        else:
            # TODO: handle repeats in a more general way rather than just discarding
            eval_logger.debug(
                "No custom filters defined. Using default 'take_first' filter for handling repeats."
            )
            self._filters = [build_filter_ensemble("none", [["take_first", None]])]

        if self.config.use_prompt is not None:
            eval_logger.info(f"loading prompt {self.config.use_prompt}")
            self.prompt = get_prompt(
                self.config.use_prompt, self.DATASET_PATH, self.DATASET_NAME
            )
        else:
            self.prompt = None

        # Set up the few-shot sampler when few-shot docs are available.
        if self.fewshot_docs() is not None:
            self.fewshot_rnd = (
                random.Random()
            )  # setting with no seed, to be overridden at a later time
            config_sampler: Union[str, Callable] = (
                self.config.fewshot_config.get("sampler", "default")
                if self.config.fewshot_config
                else "default"
            )
            if isinstance(config_sampler, str):
                self.sampler = samplers.get_sampler(config_sampler)(
                    list(self.fewshot_docs()), self, rnd=self.fewshot_rnd
                )
            elif callable(config_sampler) and issubclass(
                config_sampler, samplers.ContextSampler
            ):
                self.sampler = config_sampler(
                    docs=list(self.fewshot_docs()), task=self, rnd=self.fewshot_rnd
                )
            else:
                raise TypeError(
                    f"fewshot_config.sampler should be a string or callable of ContextSampler type, "
                    f"not {type(config_sampler)}"
                )

        self.task_docs = self.eval_docs

        # Test One Doc
        # Probe the first document to infer multiple-input / multiple-target
        # behavior and to sanity-check the configured templates.
        self.features = list(self.task_docs.features.keys())
        self.multiple_input = 0
        self.multiple_target = 0
        test_doc = self.task_docs[0]
        test_text = self.doc_to_text(test_doc)
        test_target = self.doc_to_target(test_doc)

        if self.config.doc_to_choice is not None:
            test_choice = self.doc_to_choice(test_doc)
            if not isinstance(test_choice, list):
                eval_logger.error("doc_to_choice must return list")
            else:
                num_choice = len(test_choice)

            # NOTE(review): if doc_to_choice did not return a list, num_choice
            # is unbound here and this raises NameError — confirm intended.
            if isinstance(test_text, int):
                eval_logger.debug(
                    "doc_to_text returned an int. Assuming multiple inputs."
                )
                self.multiple_input = num_choice
        else:
            test_choice = None

        if isinstance(test_target, list):
            eval_logger.debug(
                "doc_to_target returned a list. Assuming multiple targets."
            )
            self.multiple_target = len(test_target)
        else:
            if (isinstance(test_target, int)) and (test_choice is not None):
                test_target = test_choice[test_target]
            else:
                test_target = str(test_target)

        if test_choice is not None:
            check_choices = test_choice
        else:
            check_choices = [test_target]
        if self.config.doc_to_choice is not None:
            # Warn about whitespace duplication/absence between delimiter and choices.
            for choice in check_choices:
                choice_has_whitespace = True if choice[0].isspace() else False
                delimiter_has_whitespace = (
                    True
                    if self.config.target_delimiter.rstrip()
                    != self.config.target_delimiter
                    else False
                )

                if delimiter_has_whitespace and choice_has_whitespace:
                    eval_logger.debug(
                        f'Both target_delimiter "{self.config.target_delimiter}" and target choice: "{choice}" have whitespace'
                    )
                elif (not delimiter_has_whitespace) and (not choice_has_whitespace):
                    eval_logger.debug(
                        f'Both target_delimiter "{self.config.target_delimiter}" and target choice: "{choice}" do not have whitespace, ignore if the language you are evaluating on does not require/use whitespace'
                    )
|
| 980 |
+
|
| 981 |
+
def download(
|
| 982 |
+
self, dataset_kwargs: Optional[Dict[str, Any]] = None, **kwargs
|
| 983 |
+
) -> None:
|
| 984 |
+
if isinstance(self.config.custom_dataset, Callable):
|
| 985 |
+
eval_logger.warning(
|
| 986 |
+
f"{self.config.task}: Custom kwargs can be passed to `--metadata` in console (as json string) or to the TaskManager."
|
| 987 |
+
+ "\nFor example --metadata='{\"max_seq_lengths\":[4096, 8192]}'. For details see task Readme."
|
| 988 |
+
)
|
| 989 |
+
self.dataset = self.config.custom_dataset(
|
| 990 |
+
**(self.config.metadata or {}), **(self.config.dataset_kwargs or {})
|
| 991 |
+
)
|
| 992 |
+
else:
|
| 993 |
+
self.dataset = datasets.load_dataset(
|
| 994 |
+
path=self.DATASET_PATH,
|
| 995 |
+
name=self.DATASET_NAME,
|
| 996 |
+
**dataset_kwargs if dataset_kwargs is not None else {},
|
| 997 |
+
)
|
| 998 |
+
|
| 999 |
+
def has_training_docs(self) -> bool:
|
| 1000 |
+
if self.config.training_split is not None:
|
| 1001 |
+
return True
|
| 1002 |
+
else:
|
| 1003 |
+
return False
|
| 1004 |
+
|
| 1005 |
+
def has_validation_docs(self) -> bool:
|
| 1006 |
+
if self.config.validation_split is not None:
|
| 1007 |
+
return True
|
| 1008 |
+
else:
|
| 1009 |
+
return False
|
| 1010 |
+
|
| 1011 |
+
def has_test_docs(self) -> bool:
|
| 1012 |
+
if self.config.test_split is not None:
|
| 1013 |
+
return True
|
| 1014 |
+
else:
|
| 1015 |
+
return False
|
| 1016 |
+
|
| 1017 |
+
def training_docs(self) -> datasets.Dataset:
|
| 1018 |
+
if self.has_training_docs():
|
| 1019 |
+
if self.config.process_docs is not None:
|
| 1020 |
+
return self.config.process_docs(
|
| 1021 |
+
self.dataset[self.config.training_split]
|
| 1022 |
+
)
|
| 1023 |
+
return self.dataset[self.config.training_split]
|
| 1024 |
+
|
| 1025 |
+
def validation_docs(self) -> datasets.Dataset:
|
| 1026 |
+
if self.has_validation_docs():
|
| 1027 |
+
if self.config.process_docs is not None:
|
| 1028 |
+
return self.config.process_docs(
|
| 1029 |
+
self.dataset[self.config.validation_split]
|
| 1030 |
+
)
|
| 1031 |
+
return self.dataset[self.config.validation_split]
|
| 1032 |
+
|
| 1033 |
+
def test_docs(self) -> datasets.Dataset:
|
| 1034 |
+
if self.has_test_docs():
|
| 1035 |
+
if self.config.process_docs is not None:
|
| 1036 |
+
return self.config.process_docs(self.dataset[self.config.test_split])
|
| 1037 |
+
return self.dataset[self.config.test_split]
|
| 1038 |
+
|
| 1039 |
+
def fewshot_docs(self):
|
| 1040 |
+
if self.config.fewshot_split is not None:
|
| 1041 |
+
if self.config.process_docs is not None:
|
| 1042 |
+
return self.config.process_docs(self.dataset[self.config.fewshot_split])
|
| 1043 |
+
return self.dataset[self.config.fewshot_split]
|
| 1044 |
+
elif (
|
| 1045 |
+
self.config.fewshot_config is not None
|
| 1046 |
+
and self.config.fewshot_config.get("samples", None) is not None
|
| 1047 |
+
):
|
| 1048 |
+
if isinstance(self.config.fewshot_config["samples"], list):
|
| 1049 |
+
return self.config.fewshot_config["samples"]
|
| 1050 |
+
elif callable(self.config.fewshot_config["samples"]):
|
| 1051 |
+
return self.config.fewshot_config["samples"]()
|
| 1052 |
+
else:
|
| 1053 |
+
raise Exception(
|
| 1054 |
+
"`fewshot_config['samples']` was incorrectly defined in the configuration. It should be either a list of samples as a dict, or function returning this list."
|
| 1055 |
+
)
|
| 1056 |
+
else:
|
| 1057 |
+
if (self.config.num_fewshot is not None) and (self.config.num_fewshot > 0):
|
| 1058 |
+
eval_logger.warning(
|
| 1059 |
+
f"[Task: {self.config.task}] "
|
| 1060 |
+
"num_fewshot > 0 but fewshot_split is None. "
|
| 1061 |
+
"using preconfigured rule."
|
| 1062 |
+
)
|
| 1063 |
+
return super().fewshot_docs()
|
| 1064 |
+
|
| 1065 |
+
@staticmethod
|
| 1066 |
+
def append_target_question(
|
| 1067 |
+
labeled_examples: List[Dict[str, str]],
|
| 1068 |
+
question: str,
|
| 1069 |
+
fewshot_as_multiturn: bool = False,
|
| 1070 |
+
gen_prefix: Optional[str] = None,
|
| 1071 |
+
) -> None:
|
| 1072 |
+
"""Adds a target question to the labeled examples list.
|
| 1073 |
+
If fewshot_as_multiturn is True, or labeled_examples is empty, or the last entry is a system turn, appends the question as a new user entry.
|
| 1074 |
+
Otherwise, it is appended to the last user entry, ensuring that the conversation alternates between the user and the assistant.
|
| 1075 |
+
"""
|
| 1076 |
+
if not fewshot_as_multiturn:
|
| 1077 |
+
# if no messages or last message is system, append as new user entry
|
| 1078 |
+
if len(labeled_examples) == 0 or labeled_examples[-1]["role"] == "system":
|
| 1079 |
+
labeled_examples.append({"role": "user", "content": question})
|
| 1080 |
+
# if last message is user, append to it to avoid two user messages in a row
|
| 1081 |
+
else:
|
| 1082 |
+
labeled_examples[-1]["content"] += question
|
| 1083 |
+
else:
|
| 1084 |
+
# if fewshot_as_multiturn is True, append as next user entry (last is always assistant)
|
| 1085 |
+
labeled_examples.append({"role": "user", "content": question})
|
| 1086 |
+
if gen_prefix:
|
| 1087 |
+
labeled_examples.append({"role": "assistant", "content": gen_prefix})
|
| 1088 |
+
|
| 1089 |
+
    @utils.positional_deprecated
    def fewshot_context(
        self,
        doc: dict,
        num_fewshot: int,
        system_instruction: Optional[str] = None,
        apply_chat_template: bool = False,
        fewshot_as_multiturn: bool = False,
        chat_template: Optional[Callable] = None,
        gen_prefix: Optional[str] = None,
    ) -> Union[str, List[str]]:
        """Returns a fewshot context string that is made up of a prepended description
        (if provided), the `num_fewshot` number of examples, and an appended prompt example.

        :param doc: str
            The document as returned from training_docs, validation_docs, or test_docs.
        :param num_fewshot: int
            The number of fewshot examples to provide in the returned context string.
        :param system_instruction: str
            System instruction to be applied to the prompt.
        :param apply_chat_template: bool
            Whether to apply the chat template to the fewshot context.
        :param fewshot_as_multiturn: bool
            Whether to provide the fewshot examples as a multiturn conversation or a single user turn.
        :param chat_template:
            callable (from lm.apply_chat_template) that takes in a list[Dict] chat transcript and renders it into a string.
        :param gen_prefix:
            String to append after the <|assistant|> token.
        :returns: str
            The fewshot context.
        """
        # Chat mode accumulates a list of {role, content} messages; plain mode
        # accumulates a single string.
        if apply_chat_template:
            labeled_examples = []
        else:
            labeled_examples = ""

        # get task description
        if description := self.config.description:
            description = utils.apply_template(self.config.description, doc)

        # create system prompt based on the provided system instruction and description
        if system_instruction is not None and description:
            system_prompt = (
                f"{system_instruction}{self.sampler.fewshot_delimiter}{description}"
            )
        elif system_instruction is not None:
            system_prompt = system_instruction
        elif description:
            system_prompt = description
        else:
            system_prompt = ""

        # add system prompt if specified
        if system_prompt:
            if apply_chat_template:
                labeled_examples.append({"role": "system", "content": system_prompt})
            else:
                labeled_examples = system_prompt
        # if few-shot - append examples after the system prompt
        if num_fewshot > 0:
            if apply_chat_template:
                labeled_examples.extend(
                    self.sampler.get_chat_context(
                        doc,
                        num_fewshot,
                        fewshot_as_multiturn,
                        gen_prefix=gen_prefix,
                    )
                )
            else:
                labeled_examples += self.sampler.get_context(
                    doc, num_fewshot, gen_prefix=gen_prefix
                )

        example = self.doc_to_text(doc)
        if apply_chat_template:
            if self.multiple_input:
                # Multiple-input tasks put the choices in ctx; render history as-is.
                # TODO: append prefill?
                if not labeled_examples:
                    return ""
                return chat_template(labeled_examples)
            if isinstance(example, str):
                self.append_target_question(
                    labeled_examples,
                    example,
                    fewshot_as_multiturn,
                    gen_prefix=gen_prefix,
                )
            # for loglikelihood create a list of questions with appended choices
            elif isinstance(example, list):
                labeled_examples_list = []
                # copy chat history for each example and append the answer
                for ex in example:
                    chat = deepcopy(labeled_examples)
                    self.append_target_question(
                        chat,
                        ex,
                        fewshot_as_multiturn,
                        gen_prefix=gen_prefix,
                    )
                    # TODO: append prefill?
                    labeled_examples_list.append(
                        chat_template(
                            chat,
                            add_generation_prompt=False if gen_prefix else True,
                        )
                    )
                return labeled_examples_list
            # if example is an integer, append the choice or convert to string
            elif isinstance(example, int):
                if self.config.doc_to_choice is not None:
                    choices = self.doc_to_choice(doc)
                    self.append_target_question(
                        labeled_examples,
                        choices[example],
                        fewshot_as_multiturn,
                        gen_prefix=gen_prefix,
                    )
                else:
                    self.append_target_question(
                        labeled_examples,
                        str(example),
                        fewshot_as_multiturn,
                        gen_prefix=gen_prefix,
                    )
            # return lm.apply_chat_template(labeled_examples)
            return chat_template(
                labeled_examples,
                add_generation_prompt=False if gen_prefix else True,
            )
        else:
            # Plain-string mode: splice the rendered example (plus optional
            # generation prefix) onto the accumulated fewshot context.
            prefix = (
                self.config.target_delimiter + gen_prefix
                if gen_prefix is not None
                else ""
            )
            if self.multiple_input:
                return labeled_examples
            if isinstance(example, str):
                return labeled_examples + example + prefix
            elif isinstance(example, list):
                return [labeled_examples + ex + prefix for ex in example]
            elif isinstance(example, int):
                if self.config.doc_to_choice is not None:
                    choices = self.doc_to_choice(doc)
                    return labeled_examples + choices[example] + prefix
                else:
                    return labeled_examples + str(example) + prefix
|
| 1237 |
+
|
| 1238 |
+
    def apply_filters(self) -> Optional[List[Instance]]:
        """Iterates over FilterEnsembles and applies them to instances"""
        # NOTE(review): the instance list is only *returned* on the no-filter
        # fallback path; when filters exist they are applied to
        # self._instances and this method returns None — confirm callers do
        # not rely on the return value in the filtered case.
        if hasattr(self, "_filters"):
            for f in self._filters:
                f.apply(self._instances)
        else:
            eval_logger.warning("No filter defined, passing through instances")
            return self._instances
|
| 1246 |
+
|
| 1247 |
+
def should_decontaminate(self):
|
| 1248 |
+
return self.config.should_decontaminate
|
| 1249 |
+
|
| 1250 |
+
def doc_to_decontamination_query(self, doc: dict):
|
| 1251 |
+
if self.config.should_decontaminate:
|
| 1252 |
+
if self.config.doc_to_decontamination_query is None:
|
| 1253 |
+
return self.doc_to_text(doc)
|
| 1254 |
+
else:
|
| 1255 |
+
doc_to_decontamination_query = self.config.doc_to_decontamination_query
|
| 1256 |
+
if doc_to_decontamination_query in self.features:
|
| 1257 |
+
return doc[doc_to_decontamination_query]
|
| 1258 |
+
elif callable(doc_to_decontamination_query):
|
| 1259 |
+
return doc_to_decontamination_query(doc)
|
| 1260 |
+
else:
|
| 1261 |
+
return ast.literal_eval(
|
| 1262 |
+
utils.apply_template(
|
| 1263 |
+
self.config.doc_to_decontamination_query, doc
|
| 1264 |
+
)
|
| 1265 |
+
)
|
| 1266 |
+
|
| 1267 |
+
def _process_doc(self, doc: dict) -> dict:
|
| 1268 |
+
"""
|
| 1269 |
+
Override this to process (detokenize, strip, replace, etc.) individual
|
| 1270 |
+
documents. This can be used in a map over documents of a data split.
|
| 1271 |
+
E.g. `map(self._process_doc, self.dataset["validation"])`
|
| 1272 |
+
|
| 1273 |
+
:return: dict
|
| 1274 |
+
The processed version of the specified `doc`.
|
| 1275 |
+
"""
|
| 1276 |
+
return doc
|
| 1277 |
+
|
| 1278 |
+
def doc_to_text(self, doc, doc_to_text=None):
|
| 1279 |
+
if self.prompt is not None:
|
| 1280 |
+
doc_to_text = self.prompt
|
| 1281 |
+
elif doc_to_text is not None:
|
| 1282 |
+
doc_to_text = doc_to_text
|
| 1283 |
+
else:
|
| 1284 |
+
doc_to_text = self.config.doc_to_text
|
| 1285 |
+
|
| 1286 |
+
if isinstance(doc_to_text, int):
|
| 1287 |
+
return doc_to_text
|
| 1288 |
+
elif isinstance(doc_to_text, str):
|
| 1289 |
+
if doc_to_text in self.features:
|
| 1290 |
+
# if self.config.doc_to_choice is not None:
|
| 1291 |
+
# return self.doc_to_choice(doc)[doc[doc_to_text]]
|
| 1292 |
+
# else:
|
| 1293 |
+
return doc[doc_to_text]
|
| 1294 |
+
else:
|
| 1295 |
+
text_string = utils.apply_template(doc_to_text, doc)
|
| 1296 |
+
if text_string.isdigit() and self._config.doc_to_choice is not None:
|
| 1297 |
+
return ast.literal_eval(text_string)
|
| 1298 |
+
else:
|
| 1299 |
+
return text_string
|
| 1300 |
+
elif callable(doc_to_text):
|
| 1301 |
+
return doc_to_text(doc)
|
| 1302 |
+
# Used when applying a Promptsource template
|
| 1303 |
+
elif hasattr(doc_to_text, "apply"):
|
| 1304 |
+
applied_prompt = doc_to_text.apply(doc)
|
| 1305 |
+
if len(applied_prompt) == 2:
|
| 1306 |
+
return applied_prompt[0]
|
| 1307 |
+
else:
|
| 1308 |
+
eval_logger.warning("Applied prompt returns empty string")
|
| 1309 |
+
return self.config.fewshot_delimiter
|
| 1310 |
+
else:
|
| 1311 |
+
print(type(doc_to_text))
|
| 1312 |
+
raise TypeError
|
| 1313 |
+
|
| 1314 |
+
def doc_to_target(self, doc: Mapping, doc_to_target=None) -> Union[int, str, list]:
|
| 1315 |
+
if self.prompt is not None:
|
| 1316 |
+
doc_to_target = self.prompt
|
| 1317 |
+
elif doc_to_target is not None:
|
| 1318 |
+
doc_to_target = doc_to_target
|
| 1319 |
+
else:
|
| 1320 |
+
doc_to_target = self.config.doc_to_target
|
| 1321 |
+
|
| 1322 |
+
if isinstance(doc_to_target, int):
|
| 1323 |
+
return doc_to_target
|
| 1324 |
+
elif isinstance(doc_to_target, str):
|
| 1325 |
+
if doc_to_target in self.features:
|
| 1326 |
+
# if self.config.doc_to_choice is not None:
|
| 1327 |
+
# return self.doc_to_choice(doc)[doc[doc_to_target]]
|
| 1328 |
+
# else:
|
| 1329 |
+
return doc[doc_to_target]
|
| 1330 |
+
else:
|
| 1331 |
+
target_string = utils.apply_template(doc_to_target, doc)
|
| 1332 |
+
if target_string.isdigit() and self._config.doc_to_choice is not None:
|
| 1333 |
+
return ast.literal_eval(target_string)
|
| 1334 |
+
elif (
|
| 1335 |
+
len(target_string) >= 2
|
| 1336 |
+
and (target_string[0] == "[")
|
| 1337 |
+
and (target_string[-1] == "]")
|
| 1338 |
+
):
|
| 1339 |
+
try:
|
| 1340 |
+
return ast.literal_eval(target_string)
|
| 1341 |
+
except (SyntaxError, ValueError):
|
| 1342 |
+
return target_string
|
| 1343 |
+
else:
|
| 1344 |
+
return target_string
|
| 1345 |
+
elif isinstance(doc_to_target, list):
|
| 1346 |
+
return doc_to_target
|
| 1347 |
+
elif callable(doc_to_target):
|
| 1348 |
+
return doc_to_target(doc)
|
| 1349 |
+
# Used when applying a Promptsource template
|
| 1350 |
+
elif hasattr(doc_to_target, "apply"):
|
| 1351 |
+
applied_prompt = doc_to_target.apply(doc)
|
| 1352 |
+
if len(applied_prompt) == 2:
|
| 1353 |
+
return applied_prompt[1]
|
| 1354 |
+
else:
|
| 1355 |
+
eval_logger.warning("Applied prompt returns empty string")
|
| 1356 |
+
return self.config.fewshot_delimiter
|
| 1357 |
+
else:
|
| 1358 |
+
raise TypeError
|
| 1359 |
+
|
| 1360 |
+
def doc_to_choice(self, doc: Any, doc_to_choice=None) -> List[str]:
|
| 1361 |
+
if self.prompt is not None:
|
| 1362 |
+
doc_to_choice = self.prompt
|
| 1363 |
+
elif doc_to_choice is not None:
|
| 1364 |
+
doc_to_choice = doc_to_choice
|
| 1365 |
+
elif self.config.doc_to_choice is None:
|
| 1366 |
+
eval_logger.error("doc_to_choice was called but not set in config")
|
| 1367 |
+
else:
|
| 1368 |
+
doc_to_choice = self.config.doc_to_choice
|
| 1369 |
+
|
| 1370 |
+
if isinstance(doc_to_choice, str):
|
| 1371 |
+
if doc_to_choice in self.features:
|
| 1372 |
+
return doc[doc_to_choice]
|
| 1373 |
+
else:
|
| 1374 |
+
return ast.literal_eval(utils.apply_template(doc_to_choice, doc))
|
| 1375 |
+
elif isinstance(doc_to_choice, list):
|
| 1376 |
+
return doc_to_choice
|
| 1377 |
+
elif isinstance(doc_to_choice, dict):
|
| 1378 |
+
return list(doc_to_choice.values())
|
| 1379 |
+
elif callable(doc_to_choice):
|
| 1380 |
+
return doc_to_choice(doc)
|
| 1381 |
+
elif hasattr(doc_to_choice, "get_answer_choices_list"):
|
| 1382 |
+
return doc_to_choice.get_answer_choices_list(doc)
|
| 1383 |
+
else:
|
| 1384 |
+
raise TypeError
|
| 1385 |
+
|
| 1386 |
+
def doc_to_image(self, doc: Any, doc_to_image=None) -> Union[int, str, list]:
|
| 1387 |
+
if doc_to_image is not None:
|
| 1388 |
+
doc_to_image = doc_to_image
|
| 1389 |
+
elif self.config.doc_to_image is not None:
|
| 1390 |
+
doc_to_image = self.config.doc_to_image
|
| 1391 |
+
else:
|
| 1392 |
+
return None
|
| 1393 |
+
|
| 1394 |
+
if isinstance(doc_to_image, list):
|
| 1395 |
+
image_feature = [
|
| 1396 |
+
self.doc_to_image(doc, feature) for feature in doc_to_image
|
| 1397 |
+
]
|
| 1398 |
+
return [feature for feature in image_feature if feature is not None]
|
| 1399 |
+
elif isinstance(doc_to_image, str):
|
| 1400 |
+
if doc_to_image in self.features:
|
| 1401 |
+
return doc[doc_to_image]
|
| 1402 |
+
else:
|
| 1403 |
+
return ast.literal_eval(utils.apply_template(doc_to_image, doc))
|
| 1404 |
+
elif callable(doc_to_image):
|
| 1405 |
+
return doc_to_image(doc)
|
| 1406 |
+
else:
|
| 1407 |
+
return None
|
| 1408 |
+
|
| 1409 |
+
def doc_to_audio(self, doc: Any, doc_to_audio=None) -> Union[int, str, list]:
|
| 1410 |
+
if doc_to_audio is not None:
|
| 1411 |
+
doc_to_audio = doc_to_audio
|
| 1412 |
+
elif self.config.doc_to_audio is not None:
|
| 1413 |
+
doc_to_audio = self.config.doc_to_audio
|
| 1414 |
+
else:
|
| 1415 |
+
return None
|
| 1416 |
+
|
| 1417 |
+
if isinstance(doc_to_audio, list):
|
| 1418 |
+
audio_feature = [
|
| 1419 |
+
self.doc_to_audio(doc, feature) for feature in doc_to_audio
|
| 1420 |
+
]
|
| 1421 |
+
return [feature for feature in audio_feature if feature is not None]
|
| 1422 |
+
elif isinstance(doc_to_audio, str):
|
| 1423 |
+
if doc_to_audio in self.features:
|
| 1424 |
+
return doc[doc_to_audio]
|
| 1425 |
+
else:
|
| 1426 |
+
return ast.literal_eval(utils.apply_template(doc_to_audio, doc))
|
| 1427 |
+
elif callable(doc_to_audio):
|
| 1428 |
+
return doc_to_audio(doc)
|
| 1429 |
+
else:
|
| 1430 |
+
return None
|
| 1431 |
+
|
| 1432 |
+
def doc_to_prefix(self, doc):
|
| 1433 |
+
if (gen_prefix := self.config.gen_prefix) is not None:
|
| 1434 |
+
if gen_prefix in self.features:
|
| 1435 |
+
return doc[gen_prefix]
|
| 1436 |
+
else:
|
| 1437 |
+
return utils.apply_template(gen_prefix, doc)
|
| 1438 |
+
return None
|
| 1439 |
+
|
| 1440 |
+
    def construct_requests(
        self, doc: dict, ctx: str, **kwargs
    ) -> Union[List[Instance], Instance]:
        """Build the LM request(s) for one document.

        Dispatches on ``self.OUTPUT_TYPE``: multiple_choice tasks return a
        list of loglikelihood Instances (one per choice, plus unconditional
        pairs when acc_mutual_info is requested); every other output type
        returns a single Instance. Multimodal args (visual/audio) are
        appended to the argument tuple(s) when configured.
        """
        apply_chat_template = kwargs.pop("apply_chat_template", False)
        chat_template: Callable | None = kwargs.pop("chat_template", None)

        aux_arguments = None

        if self.OUTPUT_TYPE == "loglikelihood":
            arguments = (ctx, self.doc_to_target(doc))
        elif self.OUTPUT_TYPE == "loglikelihood_rolling":
            arguments = (self.doc_to_target(doc),)
        elif self.OUTPUT_TYPE == "multiple_choice":
            choices = self.doc_to_choice(doc)
            target_delimiter = self.config.target_delimiter
            if apply_chat_template:
                target_delimiter = ""
            if self.multiple_input:
                # If there are multiple inputs, choices are placed in the ctx
                # apply chat_template to choices if apply_chat_template
                cont = self.doc_to_target(doc)

                arguments = [
                    (
                        ctx
                        + (
                            chat_template([{"role": "user", "content": choice}])
                            if apply_chat_template
                            else choice
                        ),
                        f"{target_delimiter}{cont}",
                    )
                    for choice in choices
                ]
            else:
                # Otherwise they are placed in the continuation
                arguments = [(ctx, f"{target_delimiter}{cont}") for cont in choices]

            # TODO: we should raise a warning telling users this will at most ~2x runtime.
            if "acc_mutual_info" in self._metric_fn_list.keys():
                # if we are calculating multiple choice accuracy
                # using mutual information instead of raw loglikelihood as metric, need unconditional lls.

                # here mutual info refers to calculating
                # log(P(choice|ctx) / P(choice)) = log(P(choice|ctx)) - log(P(choice))
                # in other words normalizing by subtracting the unconditional logprob of each choice.
                # TODO: should these be strided? will have to modify the processing in process_results if so
                aux_arguments = [
                    ("", f"{target_delimiter}{choice}") for choice in choices
                ]

                arguments.extend(aux_arguments)

        elif self.OUTPUT_TYPE == "generate_until":
            arguments = (ctx, deepcopy(self.config.generation_kwargs))

        # Collect optional multimodal payloads; only attached when configured.
        multimodal_arg = {}
        if (
            self.config.doc_to_image
        ):  # TODO: ensure that non-multimodal tasks aren't getting visual args
            multimodal_arg = {
                **multimodal_arg,
                **{"visual": self.doc_to_image(doc)},
            }

        if (
            self.config.doc_to_audio
        ):  # TODO: ensure that non-multimodal tasks aren't getting audio args
            multimodal_arg = {
                **multimodal_arg,
                **{"audio": self.doc_to_audio(doc)},
            }

        if bool(multimodal_arg):
            # Append the multimodal dict as an extra positional argument on
            # each request tuple (list for multiple_choice, tuple otherwise).
            if isinstance(arguments, list):
                arguments = [arg + (multimodal_arg,) for arg in arguments]
            else:
                arguments = arguments + (multimodal_arg,)

        if self.OUTPUT_TYPE == "multiple_choice":
            request_list = [
                Instance(
                    request_type="loglikelihood",
                    doc=doc,
                    arguments=arg,
                    idx=i,
                    **kwargs,
                )
                for i, arg in enumerate(arguments)
            ]

            return request_list

        return Instance(
            request_type=self.OUTPUT_TYPE,
            doc=doc,
            arguments=arguments,
            idx=0,
            **kwargs,
        )
|
| 1540 |
+
|
| 1541 |
+
    def process_results(self, doc, results):
        """Convert raw model ``results`` for ``doc`` into per-metric values.

        A user-supplied ``config.process_results`` callable overrides all
        built-in handling. Otherwise dispatches on ``self.OUTPUT_TYPE`` and
        returns a dict keyed by metric name, containing either final scores
        or the (prediction, reference) inputs each aggregator expects.
        """
        if callable(self.config.process_results):
            return self.config.process_results(doc, results)

        result_dict = {}
        use_metric = list(self._metric_fn_list.keys())
        if self.OUTPUT_TYPE == "loglikelihood":
            results = results[0]
            ll, is_greedy = results
            return {
                **({"perplexity": ll} if "perplexity" in use_metric else {}),
                **({"acc": int(is_greedy)} if "acc" in use_metric else {}),
            }
        elif self.OUTPUT_TYPE == "loglikelihood_rolling":
            (loglikelihood,) = results
            _words = self.count_words(self.doc_to_target(doc))
            _bytes = self.count_bytes(self.doc_to_target(doc))
            return {
                **(
                    {"word_perplexity": (loglikelihood, _words)}
                    if "word_perplexity" in use_metric
                    else {}
                ),
                **(
                    {"byte_perplexity": (loglikelihood, _bytes)}
                    if "byte_perplexity" in use_metric
                    else {}
                ),
                **(
                    {"bits_per_byte": (loglikelihood, _bytes)}
                    if "bits_per_byte" in use_metric
                    else {}
                ),
            }
        elif self.OUTPUT_TYPE == "multiple_choice":
            lls, is_greedy = zip(*results)

            # retrieve choices in List[str] form, to compute choice lengths, etc.
            choices = self.doc_to_choice(doc)
            completion_len = np.array([float(len(i)) for i in choices])

            if (
                2 * len(choices) == len(lls)
                and "acc_mutual_info" in self._metric_fn_list.keys()
            ):
                # then we are doing mutual info.
                # this stores the "dryrun" / unconditional answer loglikelihoods
                # as we extend the args list with unconditional ("", continuation) pairs
                lls_unconditional = lls[len(choices) :]
                if len(lls_unconditional) != len(choices):
                    raise ValueError
                # and this stores our "regular" conditional loglikelihoods
                lls = lls[: len(choices)]

            # argmax over raw lls vs. lls normalized by choice character length
            pred = np.argmax(lls)
            pred_norm = np.argmax(lls / completion_len)

            if self.multiple_input:
                gold = self.doc_to_text(doc)
            else:
                gold = self.doc_to_target(doc)

            # Normalize gold to index form; out-of-range / unknown labels
            # become the sentinel -100 and trigger a warning below.
            gold_index_error = False
            if isinstance(gold, list):
                gold = [i if i < len(choices) else -100 for i in gold]
                if -100 in gold:
                    gold_index_error = True
            else:
                if isinstance(gold, int):
                    gold = gold if gold < len(choices) else -100
                elif isinstance(gold, str):
                    gold = choices.index(gold) if gold in choices else -100

                if gold == -100:
                    gold_index_error = True

            if gold_index_error:
                eval_logger.warning(
                    f"Label index was not in within range of available choices,"
                    f"Sample:\n\n{doc}\n\n"
                )

            if self.multiple_target:
                acc = 1.0 if pred in gold else 0.0
                acc_norm = 1.0 if pred_norm in gold else 0.0
                exact_match = int(any([is_greedy[i] if i != -100 else 0 for i in gold]))
            else:
                acc = 1.0 if pred == gold else 0.0
                acc_norm = 1.0 if pred_norm == gold else 0.0
                # TODO: this gets score of 0 on arc_challenge for pythia-70m. need to test that this works properly
                exact_match = int(is_greedy[gold]) if gold != -100 else 0

            prob_norm = utils.softmax(lls)

            # TODO use keyword arguments to the metric?
            # gold, pred, norm stuff, the original lls,
            result_dict = {
                **({"acc": acc} if "acc" in use_metric else {}),
                **({"f1": (gold, pred)} if "f1" in use_metric else {}),
                **({"mcc": (gold, pred)} if "mcc" in use_metric else {}),
                **({"acc_norm": acc_norm} if "acc_norm" in use_metric else {}),
                **({"exact_match": exact_match} if "exact_match" in use_metric else {}),
                **(
                    {"brier_score": (gold, prob_norm)}
                    if "brier_score" in use_metric
                    else {}
                ),
            }

            if "acc_mutual_info" in use_metric:
                lls_mutual_info = [
                    ll_c - ll_u for ll_c, ll_u in zip(lls, lls_unconditional)
                ]
                acc_mutual_info = 1.0 if np.argmax(lls_mutual_info) == gold else 0.0
                result_dict["acc_mutual_info"] = acc_mutual_info

        elif self.OUTPUT_TYPE == "generate_until":
            gold = self.doc_to_target(doc)
            result = results[0]
            if self.config.doc_to_choice is not None:
                # If you set doc_to_choice,
                # it assumes that doc_to_target returns a number.
                choices = self.doc_to_choice(doc)
                gold = choices[gold]
            # we expect multiple_targets to be a list.
            elif self.multiple_target:
                gold = list(gold)
            # TODO: handle this better
            elif type(gold) is not type(result) and not (
                "bypass" in self._metric_fn_list.keys() or isinstance(result, list)
            ):
                # cast gold to the same type as result
                gold = type(result)(gold)

            for metric in self._metric_fn_list.keys():
                if self.multiple_target:
                    # in the case where we have multiple targets,
                    # return true if any are true
                    # TODO: this may break for multipLe_target, non zero-or-1 metrics
                    scores = []
                    if not isinstance(gold, list):
                        # sometimes, a multiple_target dataset has exceptions where one doc has only one string answer
                        # print(gold)
                        gold = [gold]
                    if metric == "exact_match":
                        result = [result for _ in range(len(gold))]
                        scores = self._metric_fn_list[metric](
                            references=gold,
                            predictions=result,
                            **self._metric_fn_kwargs[metric],
                        )[metric]
                        result_score = 1.0 if scores > 0.0 else 0.0
                    else:
                        for gold_option in gold:
                            try:
                                result_score = self._metric_fn_list[metric](
                                    references=[gold_option],
                                    predictions=[result],
                                    **self._metric_fn_kwargs[metric],
                                )
                            except (
                                TypeError
                            ):  # TODO: this is hacky and I don't want to do it
                                result_score = self._metric_fn_list[metric](
                                    [gold_option, result]
                                )
                            if isinstance(result_score, dict):
                                # TODO: this handles the case where HF evaluate returns a dict.
                                result_score = result_score[metric]
                            scores.append(result_score)
                        if any(scores):
                            result_score = 1.0
                        else:
                            result_score = 0.0
                else:
                    try:
                        result_score = self._metric_fn_list[metric](
                            references=[gold],
                            predictions=[result],
                            **self._metric_fn_kwargs[metric],
                        )
                    except TypeError:  # needed for now in order to use a different interface between our own metrics and HF Evaluate metrics
                        result_score = self._metric_fn_list[metric]([gold, result])
                if isinstance(result_score, dict):
                    # TODO: this handles the case where HF evaluate returns a dict.
                    # This allows for multiple metrics to be returned from the same function
                    for k, v in result_score.items():
                        result_dict[k] = v
                else:
                    result_dict[metric] = result_score
        else:
            raise ValueError(
                f"Passed invalid output_type '{self.OUTPUT_TYPE}' ! Please use one of ",
                "'loglikelihood', 'loglikelihood_rolling', 'generate_until' or 'multiple_choice'",
            )

        return result_dict
|
| 1738 |
+
|
| 1739 |
+
    def aggregation(self) -> dict:
        """Return the mapping from metric name to its aggregation function."""
        return self._aggregation_list
|
| 1741 |
+
|
| 1742 |
+
    def higher_is_better(self) -> dict:
        """Return the mapping from metric name to whether larger is better."""
        return self._higher_is_better
|
| 1744 |
+
|
| 1745 |
+
    def get_config(self, key: str) -> Any:
        """Return the config attribute *key*, or None when it is absent."""
        return getattr(self._config, key, None)
|
| 1747 |
+
|
| 1748 |
+
    @property
    def task_name(self) -> Any:
        """The configured task name, or None when unset."""
        return getattr(self.config, "task", None)
|
| 1751 |
+
|
| 1752 |
+
    def __repr__(self):
        """Concise summary for logs: task name, output type, fewshot and doc counts."""
        return (
            f"ConfigurableTask(task_name={getattr(self.config, 'task', None)},"
            f"output_type={self.OUTPUT_TYPE},"
            f"num_fewshot={getattr(self.config, 'num_fewshot', None)},"
            f"num_samples={len(self.eval_docs)})"
        )
|
| 1759 |
+
|
| 1760 |
+
|
| 1761 |
+
class MultipleChoiceTask(Task):
    """Task scored by comparing per-choice loglikelihoods against the gold index."""

    OUTPUT_TYPE = "loglikelihood"

    def doc_to_target(self, doc: dict) -> str:
        """Return the gold choice text, space-prefixed for tokenization."""
        gold_idx = doc["gold"]
        return " " + doc["choices"][gold_idx]

    def construct_requests(self, doc: dict, ctx: str, **kwargs) -> List[Instance]:
        """Emit one loglikelihood request per answer choice."""
        # TODO: add mutual info here?
        requests = []
        for idx, option in enumerate(doc["choices"]):
            requests.append(
                Instance(
                    request_type="loglikelihood",
                    doc=doc,
                    arguments=(ctx, " {}".format(option)),
                    idx=idx,
                    **kwargs,
                )
            )
        return requests

    def process_results(self, doc: dict, results: Iterable[Tuple[float, bool]]) -> dict:
        """Score one document from its per-choice (loglikelihood, is_greedy) pairs."""
        # only retain loglikelihoods, discard is_greedy TODO: do we need is_greedy anywhere?
        lls = [pair[0] for pair in results]
        gold = doc["gold"]

        # raw accuracy: highest-likelihood choice vs. gold index
        acc = 1.0 if np.argmax(lls) == gold else 0.0
        # length-normalized accuracy: divide each ll by the choice's char length
        char_lens = np.array([float(len(option)) for option in doc["choices"]])
        acc_norm = 1.0 if np.argmax(lls / char_lens) == gold else 0.0

        return {"acc": acc, "acc_norm": acc_norm}

    def higher_is_better(self) -> dict:
        return {"acc": True, "acc_norm": True}

    def aggregation(self) -> dict:
        return {"acc": mean, "acc_norm": mean}
|
| 1806 |
+
|
| 1807 |
+
|
| 1808 |
+
class PerplexityTask(Task):
    """Rolling-loglikelihood task scored as perplexity over whole documents.

    Fewshot examples are meaningless here, so any nonzero fewshot request
    raises; the document itself is the target and the context is empty.
    """

    OUTPUT_TYPE = "loglikelihood_rolling"

    def has_training_docs(self) -> bool:
        return False

    def fewshot_examples(self, k: int, rnd) -> List:
        if k != 0:
            raise ValueError(
                "The number of fewshot examples must be 0 for perplexity tasks."
            )
        return []

    def fewshot_context(self, doc: dict, num_fewshot: int) -> Literal[""]:
        if num_fewshot != 0:
            raise ValueError(
                "The number of fewshot examples must be 0 for perplexity tasks."
            )
        return ""

    def higher_is_better(self) -> dict:
        # lower is better for every perplexity-style metric
        return {
            "word_perplexity": False,
            "byte_perplexity": False,
            "bits_per_byte": False,
        }

    def doc_to_decontamination_query(self, doc):
        return doc

    def doc_to_text(self, doc) -> str:
        return ""

    def doc_to_target(self, doc):
        return doc

    def construct_requests(self, doc: dict, ctx: Optional[str], **kwargs):
        """Emit a single rolling-loglikelihood request; ctx must be empty."""
        if bool(ctx):
            raise ValueError
        return Instance(
            request_type=self.OUTPUT_TYPE,
            doc=doc,
            arguments=(self.doc_to_target(doc),),
            idx=0,
            **kwargs,
        )

    def process_results(self, doc: dict, results: Tuple[float]) -> dict:
        (loglikelihood,) = results
        target = self.doc_to_target(doc)
        n_words = self.count_words(target)
        n_bytes = self.count_bytes(target)
        return {
            "word_perplexity": (loglikelihood, n_words),
            "byte_perplexity": (loglikelihood, n_bytes),
            "bits_per_byte": (loglikelihood, n_bytes),
        }

    def aggregation(self) -> dict:
        return {
            "word_perplexity": weighted_perplexity,
            "byte_perplexity": weighted_perplexity,
            "bits_per_byte": bits_per_byte,
        }

    @classmethod
    def count_bytes(cls, doc) -> int:
        """Length of *doc* in UTF-8 bytes."""
        return len(doc.encode("utf-8"))

    @classmethod
    def count_words(cls, doc) -> int:
        """Downstream tasks with custom word boundaries should override this!"""
        return len(re.split(r"\s+", doc))
|
Prism/LLaDA/LLaDA_Baseline/dllm_eval/caching/__init__.py
ADDED
|
File without changes
|
Prism/LLaDA/LLaDA_Baseline/dllm_eval/caching/cache.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import hashlib
import logging
import os

import dill


eval_logger = logging.getLogger(__name__)


# Directory containing this module; the default cache directory lives beside it.
MODULE_DIR = os.path.dirname(os.path.realpath(__file__))

# Users may redirect the cache with the LM_HARNESS_CACHE_PATH env var.
OVERRIDE_PATH = os.getenv("LM_HARNESS_CACHE_PATH")


PATH = OVERRIDE_PATH if OVERRIDE_PATH else f"{MODULE_DIR}/.cache"

# This should be sufficient for uniqueness
HASH_INPUT = "EleutherAI-lm-evaluation-harness"

HASH_PREFIX = hashlib.sha256(HASH_INPUT.encode("utf-8")).hexdigest()

# All cache files end with this suffix so delete_cache can identify ours.
FILE_SUFFIX = f".{HASH_PREFIX}.pickle"
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def load_from_cache(file_name: str, cache: bool = False):
    """Return the dill-cached object for *file_name*.

    Returns None when caching is disabled or no usable cache file exists.
    """
    if not cache:
        return
    cache_path = f"{PATH}/{file_name}{FILE_SUFFIX}"
    try:
        with open(cache_path, "rb") as fh:
            payload = fh.read()
        return dill.loads(payload)
    except Exception:
        # Any failure (missing file, stale/corrupt pickle, ...) is treated
        # as a cache miss; the caller regenerates.
        eval_logger.debug(f"{file_name} is not cached, generating...")
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def save_to_cache(file_name, obj):
    """Serialize *obj* with dill and write it into the cache directory.

    Creates the cache directory (and any missing parents) on first use.
    """
    # makedirs(..., exist_ok=True) replaces the old exists()+mkdir() pair:
    # it avoids the TOCTOU race between check and creation and also handles
    # a missing parent directory, which bare os.mkdir cannot.
    os.makedirs(PATH, exist_ok=True)

    file_path = f"{PATH}/{file_name}{FILE_SUFFIX}"

    eval_logger.debug(f"Saving {file_path} to cache...")
    with open(file_path, "wb") as file:
        file.write(dill.dumps(obj))
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
# NOTE the "key" param is to allow for flexibility
def delete_cache(key: str = ""):
    """Remove every cache file whose name starts with *key*."""
    for entry in os.listdir(PATH):
        # only touch files this module created (matching hash suffix)
        if entry.startswith(key) and entry.endswith(FILE_SUFFIX):
            os.unlink(f"{PATH}/{entry}")
|
Prism/LLaDA/LLaDA_Baseline/dllm_eval/decontamination/__init__.py
ADDED
|
File without changes
|
Prism/LLaDA/LLaDA_Baseline/dllm_eval/decontamination/janitor.py
ADDED
|
@@ -0,0 +1,328 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pickle
import re
import string
import traceback
from typing import Iterator, List, Sequence, Tuple, TypeVar


# This is a cpp module. Compile janitor_util.cpp with:
# c++ -O3 -Wall -shared -std=c++11 -fPIC $(python3 -m pybind11 --includes) janitor_util.cpp -o janitor_util$(python3-config --extension-suffix) -undefined dynamic_lookup
try:
    import janitor_util

    # True when the compiled C++ ngram helper imported successfully.
    JANITOR_CPP = True
except Exception:
    print("WARNING: C++ module could not be loaded. Janitor running in python mode")
    traceback.print_exc()
    JANITOR_CPP = False

# Generic element type for the ngram helpers below.
T = TypeVar("T")
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
# Implementation from nltk source
|
| 23 |
+
# https://www.nltk.org/_modules/nltk/util.html
|
| 24 |
+
def form_ngrams(sequence: Iterator[T], n: int) -> Iterator[Tuple[T, ...]]:
    """Yield successive n-tuples from *sequence* (a sliding window of width n)."""
    window = []
    # Prime the window with the first n-1 items. PEP 479: StopIteration must
    # not escape a generator, so a too-short sequence simply yields nothing.
    for _ in range(n - 1):
        try:
            window.append(next(sequence))
        except StopIteration:
            return
    # Slide: each new item completes one n-gram, then the oldest item drops.
    for item in sequence:
        window.append(item)
        yield tuple(window)
        del window[0]
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def word_ngrams(s: str, n: int) -> Iterator[str]:
    """Yield every n-word ngram of *s* as a space-joined string."""
    words = s.split()  # not a generator :(
    for gram in form_ngrams(iter(words), n):
        yield " ".join(gram)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
# Does character sequences only - combined faster function to play around with later
|
| 49 |
+
# def word_ngrams_indices_combined(sequence, n):
|
| 50 |
+
# current_word = ""
|
| 51 |
+
# history = []
|
| 52 |
+
# gap = False;
|
| 53 |
+
# start = 0
|
| 54 |
+
# end = 0
|
| 55 |
+
# for character in sequence:
|
| 56 |
+
# if character == " ":
|
| 57 |
+
# if not gap:
|
| 58 |
+
# gap = True
|
| 59 |
+
# history.append(current_word)
|
| 60 |
+
# end += len(current_word) - 1
|
| 61 |
+
# current_word = ""
|
| 62 |
+
# if len(history) == n:
|
| 63 |
+
# yield (tuple(history), start, end)
|
| 64 |
+
# del history[0]
|
| 65 |
+
# start = end + 1
|
| 66 |
+
# end = start
|
| 67 |
+
# else:
|
| 68 |
+
# gap = False
|
| 69 |
+
# current_word += character
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
# https://stackoverflow.com/questions/13734451/string-split-with-indices-in-python
|
| 73 |
+
def split_indices(s: str) -> Iterator[Tuple[str, Tuple[int, int]]]:
    """Splits a string on whitespaces and records the indices of each in the original string.
    @:return generator((word, (start_idx, end_idx)), ...)
    """
    # end index is inclusive, hence match.end() - 1
    for match in re.finditer(r"\S+", s):
        yield match.group(0), (match.start(), match.end() - 1)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def word_ngrams_indices(s: str, n: int) -> Iterator[Tuple[str, Tuple[int, int]]]:
    """Splits a string into pairs of (ngram words, their start/end indices)"""
    # Slide a width-n window over (word, (start, end)) pairs; each window
    # becomes one space-joined ngram whose span runs from the first word's
    # start index to the last word's end index.
    for window in form_ngrams(split_indices(s), n):
        words, spans = zip(*window)
        yield " ".join(words), (spans[0][0], spans[-1][1])
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
class Janitor:
    """Decontaminates text by excising windows around registered ngrams.

    Typical use: ``register_contaminant`` with test-set text, then ``clean``
    each training document; the surviving clean chunks are returned. Uses the
    compiled ``janitor_util`` C++ module when available, falling back to the
    pure-python implementation otherwise.
    """

    # FIXME delete_chars: Should anything else go here? Special chars?
    def __init__(
        self,
        ngram_n: int = 13,
        window_to_remove: int = 200,
        too_dirty_cutoff: int = 10,
        minimum_slice_length: int = 200,
        delete_chars: str = string.punctuation,
    ) -> None:
        self.ngram_n = ngram_n
        self.window_to_remove = window_to_remove
        self.too_dirty_cutoff = too_dirty_cutoff
        self.minimum_slice_length = minimum_slice_length
        self.delete_chars = delete_chars

        # Set of normalized ngrams considered contamination.
        self.dirt_ngrams = set()

        # If in python, we'll translate uppercase to lowercase and delete naughty characters.
        # This is fast by python standards
        # https://stackoverflow.com/questions/638893/what-is-the-most-efficient-way-in-python-to-convert-a-string-to-all-lowercase-st
        self.translation_table = str.maketrans(
            string.ascii_lowercase + string.ascii_uppercase,  # These characters
            string.ascii_lowercase * 2,  # Become these characters
            self.delete_chars,  # These are deleted
        )

    ##############
    # I/O for saving contamination ngrams
    ##############

    def save_contamination_ngrams(self, filename: str) -> None:
        """Pickle the registered contamination ngrams to *filename*."""
        with open(filename, "wb") as fp:
            # BUG FIX: this previously pickled the *filename* string itself
            # instead of the ngram set, so a saved file could never actually
            # restore the contaminants via load_contamination_ngrams.
            pickle.dump(self.dirt_ngrams, fp)

    def load_contamination_ngrams(self, filename: str) -> None:
        """Restore ngrams previously written by save_contamination_ngrams."""
        with open(filename, "rb") as fp:
            self.dirt_ngrams = pickle.load(fp)

    ##############
    # Call these :)
    ##############

    def register_contaminant(self, dirt_string: str) -> None:
        """Register a string as contamination to be removed, e.g. a test set
        This breaks the dirt_string into ngrams to store for future cleaning"""
        if JANITOR_CPP:
            return self.register_contaminant_cpp(dirt_string)
        else:
            print("WARNING: Janitor running in python mode")
            return self.register_contaminant_python(dirt_string)

    def clean(self, dirty_string: str) -> List[str]:
        """Clean a string (e.g. a training set) by removing all ngrams previously
        registered as contaminants. Returns a list of clean chunks, or empty if
        the string was too dirty"""
        if JANITOR_CPP:
            return self.clean_cpp(dirty_string)
        else:
            print("WARNING: Janitor running in python mode")
            return self.clean_python(dirty_string)

    def _split_chunks(
        self, dirty_string: str, dirty_parts: Sequence[Tuple]
    ) -> List[str]:
        """Cut *dirty_string* around each (ngram, start, end) in *dirty_parts*.

        Removes ``window_to_remove`` characters on each side of every match,
        keeps only slices longer than ``minimum_slice_length``, and gives up
        (returns []) once ``too_dirty_cutoff`` matches have been seen.
        """
        clean_chunks = []
        splice_idx = 0
        end = -1
        for i, (ngram, start, end) in enumerate(dirty_parts):
            if i >= self.too_dirty_cutoff:
                return []
            start = max(0, start - self.window_to_remove)
            end = min(len(dirty_string), end + self.window_to_remove)

            if start - splice_idx > self.minimum_slice_length:
                clean_chunks.append(dirty_string[splice_idx:start])
            splice_idx = end

        # keep the tail after the final match, if it is long enough
        if end < len(dirty_string) - self.minimum_slice_length:
            clean_chunks.append(dirty_string[end + 1 :])

        return clean_chunks

    ##############
    # Fast C++
    ##############

    def register_contaminant_cpp(self, dirt_string) -> None:
        self.dirt_ngrams.update(
            janitor_util.clean_ngram(dirt_string, self.delete_chars, self.ngram_n)
        )

    def clean_cpp(self, dirty_string: str) -> List[str]:
        contamination_indices = janitor_util.clean_ngram_with_indices(
            dirty_string, self.delete_chars, self.ngram_n
        )
        return self._split_chunks(dirty_string, contamination_indices)

    ##############
    # Slow python
    ##############

    def normalize_string(self, s: str) -> str:
        """Lowercase *s* and strip ``delete_chars`` via the translation table."""
        return s.translate(self.translation_table)

    def register_contaminant_python(self, dirt_string: str) -> None:
        self.dirt_ngrams.update(
            word_ngrams(self.normalize_string(dirt_string), self.ngram_n)
        )

    def clean_python(self, dirty_string: str) -> List[str]:
        contamination_indices = (
            (None, *idx_pair)
            for dirty_ngram, idx_pair in word_ngrams_indices(dirty_string, self.ngram_n)
            if self.normalize_string(dirty_ngram) in self.dirt_ngrams
        )
        return self._split_chunks(dirty_string, contamination_indices)
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
##################################################################
|
| 228 |
+
# Tests
|
| 229 |
+
#################################################################
|
| 230 |
+
|
| 231 |
+
# def print_cpp():
|
| 232 |
+
# source = """ ,, I'm a very !dirty,, ,, dirty boy. Clean me daddy. \n\nhe he he hehe heh. lastword """ * 2
|
| 233 |
+
|
| 234 |
+
# for i in range(1, 10, 2):
|
| 235 |
+
# pprint(janitor_util.clean_ngram(source, string.punctuation, i))
|
| 236 |
+
# for ngram, start, end in \
|
| 237 |
+
# janitor_util.clean_ngram_with_indices(source, string.punctuation, i):
|
| 238 |
+
# print(ngram, "\t", start, end, source[start:end].replace("\n", "\\n"))
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
# def test_cpp():
|
| 242 |
+
# source = """ ,, I'm a very !dirty,, ,, dirty boy. Clean me daddy. \n\nhe he he hehe heh. lastword """ * 2
|
| 243 |
+
# contaminant = "dirty boy. Clean he he"
|
| 244 |
+
|
| 245 |
+
# jan_python = Janitor()
|
| 246 |
+
# jan_cpp = Janitor()
|
| 247 |
+
|
| 248 |
+
# jan_python.register_contaminant_python(contaminant)
|
| 249 |
+
# jan_cpp.register_contaminant(contaminant)
|
| 250 |
+
|
| 251 |
+
# assert jan_python.dirt_ngrams == jan_cpp.dirt_ngrams, (jan_python.dirt_ngrams, jan_cpp.dirt_ngrams)
|
| 252 |
+
|
| 253 |
+
# assert jan_python.clean_python(source) == jan_cpp.clean(source), \
|
| 254 |
+
# (jan_python.clean_python(source), jan_cpp.clean(source))
|
| 255 |
+
|
| 256 |
+
# print("Passed test, python==cpp")
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
# def benchmark():
|
| 260 |
+
# # Download and put in data folder: enwik8 (100 MB) from https://cs.fit.edu/~mmahoney/compression/textdata.html
|
| 261 |
+
# setup = \
|
| 262 |
+
# """
|
| 263 |
+
# with open("data/enwik8", "r") as f:
|
| 264 |
+
# data = f.read()
|
| 265 |
+
# jan = Janitor(too_dirty_cutoff=1000)
|
| 266 |
+
# jan.register_contaminant('''
|
| 267 |
+
# theories is that there is a connection between "geekdom" and autism.
|
| 268 |
+
# This is hinted, for instance, by a ''Wired Magazine'' article in 2001 entitled "
|
| 269 |
+
# The [[Geek]] Syndrome", which is a point argued by many in the autism rights
|
| 270 |
+
# movement{{ref|Wired}}. This article, many professionals assert, is just one example of
|
| 271 |
+
# the media's application of mental disease labels to what is actually variant normal behavior
|
| 272 |
+
# &mdash;they argue that shyness, lack of athletic ability or social skills, and intellectual
|
| 273 |
+
# interests, even when they seem unusual to others, are not in themselves signs of autism or
|
| 274 |
+
# Asperger's syndrome. Others assert that it is actually the medical profession which is applying
|
| 275 |
+
# mental disease labels to children who in the past would have simply been accepted as a little
|
| 276 |
+
# different or even labeled 'gifted'. See [[clinomorphism]] for further discussion of this issue.
|
| 277 |
+
# Due to the recent publicity surrounding autism and autis
|
| 278 |
+
# ultan Al Nahyan]] granted [[Petroleum]] concessions, and oil was first found in 1958. At first,
|
| 279 |
+
# oil money had a marginal impact. A few lowrise concete buildings were erected, and the first
|
| 280 |
+
# paved road was completed in 1961, but Sheikh Shakbut, uncertain whether the new oil royalties
|
| 281 |
+
# would last, took a cautious approach, preferring to save the revenue rather than investing it in
|
| 282 |
+
# development. His brother, [[Zayed bin Sultan Al Nahayan]], saw that oil wealth had the potential
|
| 283 |
+
# to transform Abu Dhabi. The ruling Al Nahayan family decided that Sheikh Zayed should replace his
|
| 284 |
+
# brother as Ruler and carry out his vision of developing the country. On [[August 6]], [[1966]],
|
| 285 |
+
# with the assistance of the British, Sheikh Zayed became the new ruler. See generally, Al-Fahim, M,
|
| 286 |
+
# ''From Rags to Riches: A Story of Abu Dhabi'', Chapter Six (London Centre of Arab Studies, 1995),
|
| 287 |
+
# ISBN 1 900404 00 1. With the announcement by Britain in 1968 that it would withdraw from the
|
| 288 |
+
# Gulf area by 1971, Sheikh Zayed became the main driving force behind the formation of the
|
| 289 |
+
# [[United Arab Emirates]]. After the Emirates gained independence in 1971,
|
| 290 |
+
# ''')
|
| 291 |
+
# """
|
| 292 |
+
|
| 293 |
+
# n = 1
|
| 294 |
+
# print(f"Timing {n} run on 100 MB")
|
| 295 |
+
# print("Register contaminant")
|
| 296 |
+
# # print("\tPython", timeit.timeit("jan.register_contaminant_python(data)", setup=setup, globals=globals(), number=n))
|
| 297 |
+
# print("\tCpp", timeit.timeit("jan.register_contaminant(data)", setup=setup, globals=globals(), number=n))
|
| 298 |
+
|
| 299 |
+
# print("Clean")
|
| 300 |
+
# # print("\tPython", timeit.timeit("jan.clean_python(data)", setup=setup, globals=globals(), number=n))
|
| 301 |
+
# print("\tCpp", timeit.timeit("jan.clean(data)", setup=setup, globals=globals(), number=n))
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
# def test_janitor_general():
|
| 305 |
+
# source = """ ,, I'm a very !dirty,, ,, dirty boy. Clean me daddy. \n\nhe he he hehe heh. lastword """ * 2
|
| 306 |
+
# contaminant = "dirty boy. Clean he he"
|
| 307 |
+
|
| 308 |
+
# jan = Janitor(ngram_n=3)
|
| 309 |
+
# jan.register_contaminant(contaminant)
|
| 310 |
+
# cleaned = " ".join(jan.clean(source))
|
| 311 |
+
# for contam in jan.dirt_ngrams:
|
| 312 |
+
# assert contam not in cleaned, contam
|
| 313 |
+
|
| 314 |
+
# filename = "data/saved_contam"
|
| 315 |
+
# jan.save_contamination_ngrams(filename)
|
| 316 |
+
|
| 317 |
+
# jan = Janitor(ngram_n=3)
|
| 318 |
+
# jan.load_contamination_ngrams(filename)
|
| 319 |
+
# cleaned = " ".join(jan.clean(source))
|
| 320 |
+
# for contam in jan.dirt_ngrams:
|
| 321 |
+
# assert contam not in cleaned, contam
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
# if __name__ == "__main__":
|
| 325 |
+
# test()
|
| 326 |
+
# # print_cpp()
|
| 327 |
+
# # test_cpp()
|
| 328 |
+
# # benchmark()
|
Prism/LLaDA/LLaDA_Baseline/dllm_eval/loggers/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .evaluation_tracker import EvaluationTracker
|
| 2 |
+
from .wandb_logger import WandbLogger
|
Prism/LLaDA/LLaDA_Baseline/dllm_eval/loggers/evaluation_tracker.py
ADDED
|
@@ -0,0 +1,530 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import logging
|
| 3 |
+
import os
|
| 4 |
+
import re
|
| 5 |
+
import time
|
| 6 |
+
from collections import defaultdict
|
| 7 |
+
from dataclasses import asdict, dataclass
|
| 8 |
+
from datetime import datetime
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
|
| 11 |
+
from datasets import load_dataset
|
| 12 |
+
from datasets.utils.metadata import MetadataConfigs
|
| 13 |
+
from huggingface_hub import (
|
| 14 |
+
DatasetCard,
|
| 15 |
+
DatasetCardData,
|
| 16 |
+
HfApi,
|
| 17 |
+
hf_hub_url,
|
| 18 |
+
)
|
| 19 |
+
from huggingface_hub.utils import build_hf_headers, get_session, hf_raise_for_status
|
| 20 |
+
|
| 21 |
+
from dllm_eval.utils import (
|
| 22 |
+
get_file_datetime,
|
| 23 |
+
get_file_task_name,
|
| 24 |
+
get_results_filenames,
|
| 25 |
+
get_sample_results_filenames,
|
| 26 |
+
handle_non_serializable,
|
| 27 |
+
hash_string,
|
| 28 |
+
sanitize_list,
|
| 29 |
+
sanitize_model_name,
|
| 30 |
+
sanitize_task_name,
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
eval_logger = logging.getLogger(__name__)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
@dataclass(init=False)
class GeneralConfigTracker:
    """
    Tracker for the evaluation parameters.

    Attributes:
        model_source (str): Source of the model (e.g. Hugging Face, GGUF, etc.)
        model_name (str): Name of the model.
        model_name_sanitized (str): Sanitized model name for directory creation.
        start_time (float): Start time of the experiment. Logged at class init.
        end_time (float): End time of the experiment. Logged when calling [`GeneralConfigTracker.log_end_time`]
        total_evaluation_time_seconds (str): Inferred total evaluation time in seconds (from the start and end times).
    """

    model_source: str = None
    model_name: str = None
    model_name_sanitized: str = None
    system_instruction: str = None
    system_instruction_sha: str = None
    fewshot_as_multiturn: bool = None
    chat_template: str = None
    chat_template_sha: str = None
    start_time: float = None
    end_time: float = None
    total_evaluation_time_seconds: str = None

    def __init__(self) -> None:
        """Start the evaluation timer."""
        self.start_time = time.perf_counter()

    @staticmethod
    def _get_model_name(model_args: str) -> str:
        """Extract the model name from a comma-separated model-args string.

        Returns an empty string when no known key is present.
        """
        # Order matters: peft/delta adapters are provided together with
        # `pretrained=...`, and the adapter name is the one to report.
        for key in ("peft=", "delta=", "pretrained=", "model=", "path=", "engine="):
            if key in model_args:
                # The value is everything between the key and the next comma.
                return model_args.split(key)[1].split(",")[0]
        return ""

    def log_experiment_args(
        self,
        model_source: str,
        model_args: str,
        system_instruction: str,
        chat_template: str,
        fewshot_as_multiturn: bool,
    ) -> None:
        """Logs model parameters and job ID."""
        self.model_source = model_source
        self.model_name = GeneralConfigTracker._get_model_name(model_args)
        self.model_name_sanitized = sanitize_model_name(self.model_name)
        self.system_instruction = system_instruction
        # Hash only when an instruction/template was actually supplied.
        if system_instruction:
            self.system_instruction_sha = hash_string(system_instruction)
        else:
            self.system_instruction_sha = None
        self.chat_template = chat_template
        if chat_template:
            self.chat_template_sha = hash_string(chat_template)
        else:
            self.chat_template_sha = None
        self.fewshot_as_multiturn = fewshot_as_multiturn

    def log_end_time(self) -> None:
        """Logs the end time of the evaluation and calculates the total evaluation time."""
        self.end_time = time.perf_counter()
        self.total_evaluation_time_seconds = str(self.end_time - self.start_time)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
class EvaluationTracker:
|
| 110 |
+
"""
|
| 111 |
+
Keeps track and saves relevant information of the evaluation process.
|
| 112 |
+
Compiles the data from trackers and writes it to files, which can be published to the Hugging Face hub if requested.
|
| 113 |
+
"""
|
| 114 |
+
|
| 115 |
+
    def __init__(
        self,
        output_path: str = None,
        hub_results_org: str = "",
        hub_repo_name: str = "",
        details_repo_name: str = "",
        results_repo_name: str = "",
        push_results_to_hub: bool = False,
        push_samples_to_hub: bool = False,
        public_repo: bool = False,
        token: str = "",
        leaderboard_url: str = "",
        point_of_contact: str = "",
        gated: bool = False,
    ) -> None:
        """
        Creates all the necessary loggers for evaluation tracking.

        Args:
            output_path (str): Path to save the results. If not provided, the results won't be saved.
            hub_results_org (str): The Hugging Face organization to push the results to. If not provided, the results will be pushed to the owner of the Hugging Face token.
            hub_repo_name (str): The name of the Hugging Face repository to push the results to. If not provided, the results will be pushed to `lm-eval-results`.
            details_repo_name (str): The name of the Hugging Face repository to push the details to. If not provided, the results will be pushed to `lm-eval-results`.
            results_repo_name (str): The name of the Hugging Face repository to push the results to. If not provided, the results will not be pushed and will be found in the details_hub_repo.
            push_results_to_hub (bool): Whether to push the results to the Hugging Face hub.
            push_samples_to_hub (bool): Whether to push the samples to the Hugging Face hub.
            public_repo (bool): Whether to push the results to a public or private repository.
            token (str): Token to use when pushing to the Hugging Face hub. This token should have write access to `hub_results_org`.
            leaderboard_url (str): URL to the leaderboard on the Hugging Face hub on the dataset card.
            point_of_contact (str): Contact information on the Hugging Face hub dataset card.
            gated (bool): Whether to gate the repository.
        """
        # The config tracker starts its wall-clock timer as soon as it is built.
        self.general_config_tracker = GeneralConfigTracker()

        self.output_path = output_path
        self.push_results_to_hub = push_results_to_hub
        self.push_samples_to_hub = push_samples_to_hub
        self.public_repo = public_repo
        self.leaderboard_url = leaderboard_url
        self.point_of_contact = point_of_contact
        # No token means no hub API: pushing is then impossible (checked below).
        self.api = HfApi(token=token) if token else None
        self.gated_repo = gated

        if not self.api and (push_results_to_hub or push_samples_to_hub):
            raise ValueError(
                "Hugging Face token is not defined, but 'push_results_to_hub' or 'push_samples_to_hub' is set to True. "
                "Please provide a valid Hugging Face token by setting the HF_TOKEN environment variable."
            )

        # Default the target org to the token owner when pushing was requested
        # but no org was given (requires a network call to whoami()).
        if (
            self.api
            and hub_results_org == ""
            and (push_results_to_hub or push_samples_to_hub)
        ):
            hub_results_org = self.api.whoami()["name"]
            eval_logger.warning(
                f"hub_results_org was not specified. Results will be pushed to '{hub_results_org}'."
            )

        # Legacy `hub_repo_name` overrides both repo names; otherwise fill in
        # defaults, with results falling back to the details repo.
        if hub_repo_name == "":
            details_repo_name = (
                details_repo_name if details_repo_name != "" else "lm-eval-results"
            )
            results_repo_name = (
                results_repo_name if results_repo_name != "" else details_repo_name
            )
        else:
            details_repo_name = hub_repo_name
            results_repo_name = hub_repo_name
            eval_logger.warning(
                "hub_repo_name was specified. Both details and results will be pushed to the same repository. Using hub_repo_name is no longer recommended, details_repo_name and results_repo_name should be used instead."
            )

        # Public and private variants of each target repo id; which one is
        # used at push time depends on `self.public_repo`.
        self.details_repo = f"{hub_results_org}/{details_repo_name}"
        self.details_repo_private = f"{hub_results_org}/{details_repo_name}-private"
        self.results_repo = f"{hub_results_org}/{results_repo_name}"
        self.results_repo_private = f"{hub_results_org}/{results_repo_name}-private"
|
| 192 |
+
|
| 193 |
+
    def save_results_aggregated(
        self,
        results: dict,
        samples: dict,
    ) -> None:
        """
        Saves the aggregated results and samples to the output path and pushes them to the Hugging Face hub if requested.

        Note: mutates `results` in place (adds task hashes and the general
        config fields) and sets `self.date_id`, which `save_results_samples`
        later reads — call this method first.

        Args:
            results (dict): The aggregated results to save.
            samples (dict): The samples results to save.
        """
        # Stop the evaluation timer so the total duration lands in `results`.
        self.general_config_tracker.log_end_time()

        if self.output_path:
            try:
                eval_logger.info("Saving results aggregated")

                # calculate cumulative hash for each task - only if samples are provided
                task_hashes = {}
                if samples:
                    for task_name, task_samples in samples.items():
                        sample_hashes = [
                            s["doc_hash"] + s["prompt_hash"] + s["target_hash"]
                            for s in task_samples
                        ]
                        task_hashes[task_name] = hash_string("".join(sample_hashes))

                # update initial results dict
                results.update({"task_hashes": task_hashes})
                results.update(asdict(self.general_config_tracker))
                dumped = json.dumps(
                    results,
                    indent=2,
                    default=handle_non_serializable,
                    ensure_ascii=False,
                )

                path = Path(self.output_path if self.output_path else Path.cwd())
                # Timestamp used for this run's filenames; ':' is not valid in
                # filenames on all platforms, hence the replacement.
                self.date_id = datetime.now().isoformat().replace(":", "-")
                if path.suffix == ".json":
                    # output_path names a file: write alongside it with a
                    # timestamped variant of its stem.
                    path.parent.mkdir(parents=True, exist_ok=True)
                    file_results_aggregated = path.with_name(
                        f"{path.stem}_{self.date_id}.json"
                    )
                else:
                    # output_path names a directory: write results_<ts>.json in it.
                    path.mkdir(parents=True, exist_ok=True)
                    file_results_aggregated = path.joinpath(
                        f"results_{self.date_id}.json"
                    )

                # NOTE(review): the file handle is not explicitly closed here
                # (relies on garbage collection); consider a `with` block.
                file_results_aggregated.open("w", encoding="utf-8").write(dumped)

                if self.api and self.push_results_to_hub:
                    repo_id = (
                        self.results_repo
                        if self.public_repo
                        else self.results_repo_private
                    )
                    self.api.create_repo(
                        repo_id=repo_id,
                        repo_type="dataset",
                        private=not self.public_repo,
                        exist_ok=True,
                    )
                    self.api.upload_file(
                        repo_id=repo_id,
                        path_or_fileobj=str(file_results_aggregated),
                        path_in_repo=os.path.join(
                            self.general_config_tracker.model_name,
                            file_results_aggregated.name,
                        ),
                        repo_type="dataset",
                        commit_message=f"Adding aggregated results for {self.general_config_tracker.model_name}",
                    )
                    eval_logger.info(
                        "Successfully pushed aggregated results to the Hugging Face Hub. "
                        f"You can find them at: {repo_id}"
                    )

            except Exception as e:
                # Best-effort: saving/pushing failures are logged, not fatal.
                eval_logger.warning("Could not save results aggregated")
                eval_logger.info(repr(e))
        else:
            eval_logger.info(
                "Output path not provided, skipping saving results aggregated"
            )
|
| 280 |
+
|
| 281 |
+
    def save_results_samples(
        self,
        task_name: str,
        samples: dict,
    ) -> None:
        """
        Saves the samples results to the output path and pushes them to the Hugging Face hub if requested.

        Note: reads `self.date_id`, which is set by `save_results_aggregated`
        — that method must run first in the same tracker lifetime. Mutates
        each sample dict in place while sanitizing it.

        Args:
            task_name (str): The task name to save the samples for.
            samples (dict): The samples results to save.
        """
        if self.output_path:
            try:
                eval_logger.info(f"Saving per-sample results for: {task_name}")

                path = Path(self.output_path if self.output_path else Path.cwd())
                # If output_path names a .json file, samples go in its directory.
                if path.suffix == ".json":
                    path = path.parent
                path.mkdir(parents=True, exist_ok=True)

                file_results_samples = path.joinpath(
                    f"samples_{task_name}_{self.date_id}.jsonl"
                )

                for sample in samples:
                    # we first need to sanitize arguments and resps
                    # otherwise we won't be able to load the dataset
                    # using the datasets library
                    # Flatten the nested argument tuples into string-keyed
                    # dicts: {"gen_args_0": {"arg_0": ..., "arg_1": ...}, ...}
                    arguments = {}
                    for i, arg in enumerate(sample["arguments"]):
                        arguments[f"gen_args_{i}"] = {}
                        for j, tmp in enumerate(arg):
                            arguments[f"gen_args_{i}"][f"arg_{j}"] = tmp

                    sample["resps"] = sanitize_list(sample["resps"])
                    sample["filtered_resps"] = sanitize_list(sample["filtered_resps"])
                    sample["arguments"] = arguments
                    sample["target"] = str(sample["target"])

                    sample_dump = (
                        json.dumps(
                            sample,
                            default=handle_non_serializable,
                            ensure_ascii=False,
                        )
                        + "\n"
                    )

                    # Append mode: one JSON line per sample (JSONL format).
                    with open(file_results_samples, "a", encoding="utf-8") as f:
                        f.write(sample_dump)

                if self.api and self.push_samples_to_hub:
                    repo_id = (
                        self.details_repo
                        if self.public_repo
                        else self.details_repo_private
                    )
                    self.api.create_repo(
                        repo_id=repo_id,
                        repo_type="dataset",
                        private=not self.public_repo,
                        exist_ok=True,
                    )
                    try:
                        # Optionally switch the dataset repo to auto-gated
                        # access via the raw hub settings endpoint.
                        if self.gated_repo:
                            headers = build_hf_headers()
                            r = get_session().put(
                                url=f"https://huggingface.co/api/datasets/{repo_id}/settings",
                                headers=headers,
                                json={"gated": "auto"},
                            )
                            hf_raise_for_status(r)
                    except Exception as e:
                        eval_logger.warning("Could not gate the repository")
                        eval_logger.info(repr(e))
                    # Upload the whole output directory under the sanitized
                    # model name so successive runs accumulate in one place.
                    self.api.upload_folder(
                        repo_id=repo_id,
                        folder_path=str(path),
                        path_in_repo=self.general_config_tracker.model_name_sanitized,
                        repo_type="dataset",
                        commit_message=f"Adding samples results for {task_name} to {self.general_config_tracker.model_name}",
                    )
                    eval_logger.info(
                        f"Successfully pushed sample results for task: {task_name} to the Hugging Face Hub. "
                        f"You can find them at: {repo_id}"
                    )

            except Exception as e:
                # Best-effort: failures are logged, not raised.
                eval_logger.warning("Could not save sample results")
                eval_logger.info(repr(e))
        else:
            eval_logger.info("Output path not provided, skipping saving sample results")
|
| 374 |
+
|
| 375 |
+
    def recreate_metadata_card(self) -> None:
        """
        Creates a metadata card for the evaluation results dataset and pushes it to the Hugging Face hub.

        Requires `self.api` (a valid token was supplied at construction) and
        an existing details repo populated by previous pushes.
        """

        eval_logger.info("Recreating metadata card")
        repo_id = self.details_repo if self.public_repo else self.details_repo_private

        files_in_repo = self.api.list_repo_files(repo_id=repo_id, repo_type="dataset")
        results_files = get_results_filenames(files_in_repo)
        sample_files = get_sample_results_filenames(files_in_repo)

        # Build a dictionary to store the latest evaluation datetime for:
        # - Each tested model and its aggregated results
        # - Each task and sample results, if existing
        # i.e. {
        #     "org__model_name__gsm8k": "2021-09-01T12:00:00",
        #     "org__model_name__ifeval": "2021-09-01T12:00:00",
        #     "org__model_name__results": "2021-09-01T12:00:00"
        # }
        # datetime.min is the default so any real timestamp supersedes it;
        # ISO-8601 strings compare correctly as plain strings.
        latest_task_results_datetime = defaultdict(lambda: datetime.min.isoformat())

        for file_path in sample_files:
            file_path = Path(file_path)
            filename = file_path.name
            model_name = file_path.parent
            task_name = get_file_task_name(filename)
            results_datetime = get_file_datetime(filename)
            task_name_sanitized = sanitize_task_name(task_name)
            # Results and sample results for the same model and task will have the same datetime
            samples_key = f"{model_name}__{task_name_sanitized}"
            results_key = f"{model_name}__results"
            latest_datetime = max(
                latest_task_results_datetime[samples_key],
                results_datetime,
            )
            latest_task_results_datetime[samples_key] = latest_datetime
            latest_task_results_datetime[results_key] = max(
                latest_task_results_datetime[results_key],
                latest_datetime,
            )

        # Create metadata card
        card_metadata = MetadataConfigs()

        # Add the latest aggregated results to the metadata card for easy access
        for file_path in results_files:
            file_path = Path(file_path)
            results_filename = file_path.name
            model_name = file_path.parent
            eval_date = get_file_datetime(results_filename)
            # Split names must be valid identifiers: replace anything that is
            # not a word character or a dot.
            eval_date_sanitized = re.sub(r"[^\w\.]", "_", eval_date)
            # Glob pattern so the config matches the file in any subfolder.
            results_filename = Path("**") / Path(results_filename).name
            config_name = f"{model_name}__results"
            sanitized_last_eval_date_results = re.sub(
                r"[^\w\.]", "_", latest_task_results_datetime[config_name]
            )

            if eval_date_sanitized == sanitized_last_eval_date_results:
                # Ensure that all results files are listed in the metadata card
                current_results = card_metadata.get(config_name, {"data_files": []})
                current_results["data_files"].append(
                    {"split": eval_date_sanitized, "path": [str(results_filename)]}
                )
                card_metadata[config_name] = current_results
                # If the results file is the newest, update the "latest" field in the metadata card
                card_metadata[config_name]["data_files"].append(
                    {"split": "latest", "path": [str(results_filename)]}
                )

        # Add the tasks details configs
        for file_path in sample_files:
            file_path = Path(file_path)
            filename = file_path.name
            model_name = file_path.parent
            task_name = get_file_task_name(filename)
            eval_date = get_file_datetime(filename)
            task_name_sanitized = sanitize_task_name(task_name)
            eval_date_sanitized = re.sub(r"[^\w\.]", "_", eval_date)
            results_filename = Path("**") / Path(filename).name
            config_name = f"{model_name}__{task_name_sanitized}"
            sanitized_last_eval_date_results = re.sub(
                r"[^\w\.]", "_", latest_task_results_datetime[config_name]
            )
            if eval_date_sanitized == sanitized_last_eval_date_results:
                # Ensure that all sample results files are listed in the metadata card
                current_details_for_task = card_metadata.get(
                    config_name, {"data_files": []}
                )
                current_details_for_task["data_files"].append(
                    {"split": eval_date_sanitized, "path": [str(results_filename)]}
                )
                card_metadata[config_name] = current_details_for_task
                # If the samples results file is the newest, update the "latest" field in the metadata card
                card_metadata[config_name]["data_files"].append(
                    {"split": "latest", "path": [str(results_filename)]}
                )

        # Get latest results and extract info to update metadata card examples
        latest_datetime = max(latest_task_results_datetime.values())
        latest_model_name = max(
            latest_task_results_datetime, key=lambda k: latest_task_results_datetime[k]
        )
        # Filenames store the timestamp with ':' replaced by '-' (see
        # save_results_aggregated), hence the replace before matching.
        last_results_file = [
            f for f in results_files if latest_datetime.replace(":", "-") in f
        ][0]
        last_results_file_path = hf_hub_url(
            repo_id=repo_id, filename=last_results_file, repo_type="dataset"
        )
        latest_results_file = load_dataset(
            "json", data_files=last_results_file_path, split="train"
        )
        results_dict = latest_results_file["results"][0]
        # Expose the aggregated metrics both under "all" and flattened per task.
        new_dictionary = {"all": results_dict}
        new_dictionary.update(results_dict)
        results_string = json.dumps(new_dictionary, indent=4)

        dataset_summary = (
            "Dataset automatically created during the evaluation run of model "
        )
        # HF-hosted models get a markdown link to their model page.
        if self.general_config_tracker.model_source == "hf":
            dataset_summary += f"[{self.general_config_tracker.model_name}](https://huggingface.co/{self.general_config_tracker.model_name})\n"
        else:
            dataset_summary += f"{self.general_config_tracker.model_name}\n"
        dataset_summary += (
            f"The dataset is composed of {len(card_metadata) - 1} configuration(s), each one corresponding to one of the evaluated task.\n\n"
            f"The dataset has been created from {len(results_files)} run(s). Each run can be found as a specific split in each "
            'configuration, the split being named using the timestamp of the run.The "train" split is always pointing to the latest results.\n\n'
            'An additional configuration "results" store all the aggregated results of the run.\n\n'
            "To load the details from a run, you can for instance do the following:\n"
        )
        if self.general_config_tracker.model_source == "hf":
            dataset_summary += (
                "```python\nfrom datasets import load_dataset\n"
                f'data = load_dataset(\n\t"{repo_id}",\n\tname="{latest_model_name}",\n\tsplit="latest"\n)\n```\n\n'
            )
        dataset_summary += (
            "## Latest results\n\n"
            f"These are the [latest results from run {latest_datetime}]({last_results_file_path.replace('/resolve/', '/blob/')}) "
            "(note that there might be results for other tasks in the repos if successive evals didn't cover the same tasks. "
            'You find each in the results and the "latest" split for each eval):\n\n'
            f"```python\n{results_string}\n```"
        )
        card_data = DatasetCardData(
            dataset_summary=dataset_summary,
            repo_url=f"https://huggingface.co/{self.general_config_tracker.model_name}",
            pretty_name=f"Evaluation run of {self.general_config_tracker.model_name}",
            leaderboard_url=self.leaderboard_url,
            point_of_contact=self.point_of_contact,
        )
        card_metadata.to_dataset_card_data(card_data)
        card = DatasetCard.from_template(
            card_data,
            pretty_name=card_data.pretty_name,
        )
        card.push_to_hub(repo_id, repo_type="dataset")
|
Prism/LLaDA/LLaDA_Baseline/dllm_eval/loggers/utils.py
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import os
|
| 3 |
+
import re
|
| 4 |
+
import subprocess
|
| 5 |
+
from importlib.metadata import version
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Any, Dict, Optional, Tuple, Union
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
from torch.utils.collect_env import get_pretty_env_info
|
| 11 |
+
from transformers import __version__ as trans_version
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
logger = logging.getLogger(__name__)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def remove_none_pattern(input_string: str) -> Tuple[str, bool]:
    """Strip a trailing ',none' suffix from *input_string*.

    Metric keys produced by the harness may carry a ',none' filter suffix;
    this removes it when (and only when) it appears at the very end.

    Args:
        input_string: The metric name that may end with ',none'.

    Returns:
        Tuple[str, bool]: ``(cleaned, removed)`` where ``cleaned`` is the
        string with any trailing ',none' removed and ``removed`` indicates
        whether a substitution actually happened.
    """
    # Anchor to the end of the string so interior ',none' is left alone.
    trailing_none = re.compile(r",none$")

    cleaned = trailing_none.sub("", input_string)

    # A textual change means the suffix was present and got stripped.
    return cleaned, cleaned != input_string
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def _handle_non_serializable(o: Any) -> Union[int, str, list]:
|
| 40 |
+
"""Handle non-serializable objects by converting them to serializable types.
|
| 41 |
+
|
| 42 |
+
Args:
|
| 43 |
+
o (Any): The object to be handled.
|
| 44 |
+
|
| 45 |
+
Returns:
|
| 46 |
+
Union[int, str, list]: The converted object. If the object is of type np.int64 or np.int32,
|
| 47 |
+
it will be converted to int. If the object is of type set, it will be converted
|
| 48 |
+
to a list. Otherwise, it will be converted to str.
|
| 49 |
+
"""
|
| 50 |
+
if isinstance(o, np.int64) or isinstance(o, np.int32):
|
| 51 |
+
return int(o)
|
| 52 |
+
elif isinstance(o, set):
|
| 53 |
+
return list(o)
|
| 54 |
+
else:
|
| 55 |
+
return str(o)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def get_commit_from_path(repo_path: Union[Path, str]) -> Optional[str]:
    """Read the current commit hash directly from a repo's ``.git`` metadata.

    Handles both a regular ``.git`` directory and a ``.git`` *file* (as used
    by worktrees/submodules, where the file points at the real git dir).

    Args:
        repo_path: Filesystem path of the repository root.

    Returns:
        The commit hash string, or ``None`` when it cannot be determined.
    """
    try:
        git_dir = Path(repo_path, ".git")
        # A ".git" *file* contains a single "gitdir: <path>" line; follow
        # the last space-separated token to the actual git directory.
        if git_dir.is_file():
            redirect = git_dir.read_text(encoding="utf-8").split("\n")[0].split(" ")[-1]
            git_dir = Path(git_dir.parent, redirect)
        head_file = Path(git_dir, "HEAD")
        if not head_file.exists():
            return None
        # HEAD holds either "ref: refs/heads/<branch>" or a detached hash;
        # taking the last token covers the symbolic-ref case. A detached
        # hash makes the ref-file read below fail and fall into the except.
        ref_name = (
            head_file.read_text(encoding="utf-8").split("\n")[0].split(" ")[-1]
        )
        return Path(git_dir, ref_name).read_text(encoding="utf-8").replace("\n", "")
    except Exception as err:
        logger.debug(
            f"Failed to retrieve a Git commit hash from path: {str(repo_path)}. Error: {err}"
        )
        return None
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def get_git_commit_hash():
    """Return the current repo's git commit hash, or ``None`` if unavailable.

    Source: https://github.com/EleutherAI/gpt-neox/blob/b608043be541602170bfcfb8ec9bf85e8a0799e0/megatron/neox_arguments/neox_args.py#L42
    """
    try:
        described = subprocess.check_output(["git", "describe", "--always"])
        return described.strip().decode()
    except (subprocess.CalledProcessError, FileNotFoundError):
        # `git` may be missing (FileNotFoundError occurs when git is not
        # installed on the system) or cwd may not be a repo; fall back to
        # reading the .git metadata directly.
        return get_commit_from_path(os.getcwd())  # git hash of repo if exists
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def add_env_info(storage: Dict[str, Any]):
    """Record environment and version metadata into *storage* (mutated in place).

    Captures the torch environment report, the transformers and dllm_eval
    package versions, and the parent directory's git hash (useful when this
    repo is checked out as a submodule). Failures are stored as the error
    string rather than raised.
    """
    try:
        env_description = get_pretty_env_info()
    except Exception as err:
        env_description = str(err)
    try:
        harness_version = version("dllm_eval")
    except Exception as err:
        harness_version = str(err)
    # Hash of the repo one level up, in case this repo is a submodule.
    parent_commit = get_commit_from_path(Path(os.getcwd(), ".."))
    storage.update(
        {
            "pretty_env_info": env_description,
            "transformers_version": trans_version,
            "dllm_eval_version": harness_version,
            "upper_git_hash": parent_commit,  # in case this repo is submodule
        }
    )
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def add_tokenizer_info(storage: Dict[str, Any], lm):
    """Record tokenizer special-token metadata from *lm* into *storage*.

    Models without a ``tokenizer`` attribute (e.g. gguf or textsynth
    backends) are skipped with a debug log; any failure while reading the
    tokenizer is likewise logged at debug level and swallowed.
    """
    tokenizer = getattr(lm, "tokenizer", False)
    if not tokenizer:
        # seems gguf and textsynth do not have tokenizer
        logger.debug(
            "LM does not have a 'tokenizer' attribute, not logging tokenizer metadata to results."
        )
        return
    try:
        tokenizer_info = {
            "tokenizer_pad_token": [
                tokenizer.pad_token,
                str(tokenizer.pad_token_id),
            ],
            "tokenizer_eos_token": [
                tokenizer.eos_token,
                str(tokenizer.eos_token_id),
            ],
            "tokenizer_bos_token": [
                tokenizer.bos_token,
                str(tokenizer.bos_token_id),
            ],
            "eot_token_id": getattr(lm, "eot_token_id", None),
            "max_length": getattr(lm, "max_length", None),
        }
        storage.update(tokenizer_info)
    except Exception as err:
        logger.debug(
            f"Logging detailed tokenizer info failed with {err}, skipping..."
        )
|
Prism/LLaDA/LLaDA_Baseline/dllm_eval/loggers/wandb_logger.py
ADDED
|
@@ -0,0 +1,358 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import json
|
| 3 |
+
import logging
|
| 4 |
+
from typing import Any, Dict, List, Literal, Tuple
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import pandas as pd
|
| 8 |
+
from packaging.version import Version
|
| 9 |
+
|
| 10 |
+
from dllm_eval.loggers.utils import _handle_non_serializable, remove_none_pattern
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
logger = logging.getLogger(__name__)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def get_wandb_printer() -> Literal["Printer"]:
    """Return a wandb printer instance used for pretty stdout output."""
    # Imported lazily so wandb stays an optional dependency of this module.
    from wandb.sdk.lib.printer import new_printer

    return new_printer()
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class WandbLogger:
    """Pushes evaluation results, tables, and sample-level artifacts to Weights & Biases."""

    def __init__(self, init_args=None, config_args=None) -> None:
        """Attaches to wandb logger if already initialized. Otherwise, passes init_args to wandb.init() and config_args to wandb.config.update()

        Args:
            init_args Optional[Dict]: Arguments for init configuration.
            config_args Optional[Dict]: Arguments for config

        Parse and log the results returned from evaluator.simple_evaluate() with:
            wandb_logger.post_init(results)
            wandb_logger.log_eval_result()
            wandb_logger.log_eval_samples(results["samples"])
        """
        try:
            import wandb

            assert Version(wandb.__version__) >= Version("0.13.6")
            # NOTE(review): this branch is unreachable — the assert above
            # already guarantees wandb >= 0.13.6, so `wandb.require` is
            # never called. Kept as-is (matches upstream harness code).
            if Version(wandb.__version__) < Version("0.13.6"):
                wandb.require("report-editing:v0")
        except Exception as e:
            logger.warning(
                "To use the wandb reporting functionality please install wandb>=0.13.6.\n"
                "To install the latest version of wandb run `pip install wandb --upgrade`\n"
                f"{e}"
            )

        self.wandb_args: Dict[str, Any] = init_args or {}
        self.wandb_config_args: Dict[str, Any] = config_args or {}

        # pop the step key from the args to save for all logging calls
        self.step = self.wandb_args.pop("step", None)

        # initialize a W&B run; reuse an existing run if one is already active
        # NOTE(review): if the `import wandb` above failed, the reference to
        # `wandb.run` here raises NameError — the warning path assumes wandb
        # is importable by this point.
        if wandb.run is None:
            self.run = wandb.init(**self.wandb_args)
            if self.wandb_config_args:
                self.run.config.update(self.wandb_config_args)
        else:
            self.run = wandb.run

        self.printer = get_wandb_printer()

    def post_init(self, results: Dict[str, Any]) -> None:
        # Deep-copy so later sanitization never mutates the caller's results.
        self.results: Dict[str, Any] = copy.deepcopy(results)
        self.task_names: List[str] = list(results.get("results", {}).keys())
        self.group_names: List[str] = list(results.get("groups", {}).keys())

    def _get_config(self) -> Dict[str, Any]:
        """Get configuration parameters."""
        # Task-level configs are cached on self for later use by
        # log_eval_samples; CLI configs are only forwarded to wandb.
        self.task_configs = self.results.get("configs", {})
        cli_configs = self.results.get("config", {})
        configs = {
            "task_configs": self.task_configs,
            "cli_configs": cli_configs,
        }

        return configs

    def _sanitize_results_dict(self) -> Tuple[Dict[str, str], Dict[str, Any]]:
        """Sanitize the results dictionary.

        Returns a ``(wandb_summary, flat_results)`` pair: string-valued
        metrics go to the run summary, numeric metrics are flattened into
        "task/metric" keys suitable for ``wandb.log``.
        """
        _results = copy.deepcopy(self.results.get("results", dict()))

        # Remove None from the metric string name
        # (iterate over a snapshot since _results is mutated in the loop)
        tmp_results = copy.deepcopy(_results)
        for task_name in self.task_names:
            task_result = tmp_results.get(task_name, dict())
            for metric_name, metric_value in task_result.items():
                _metric_name, removed = remove_none_pattern(metric_name)
                if removed:
                    _results[task_name][_metric_name] = metric_value
                    _results[task_name].pop(metric_name)

        # remove string valued keys from the results dict
        wandb_summary = {}
        for task in self.task_names:
            task_result = _results.get(task, dict())
            for metric_name, metric_value in task_result.items():
                if isinstance(metric_value, str):
                    wandb_summary[f"{task}/{metric_name}"] = metric_value

        # NOTE(review): split("/") assumes neither task nor metric names
        # contain "/" themselves — a slash in either would mis-split here.
        for summary_metric, summary_value in wandb_summary.items():
            _task, _summary_metric = summary_metric.split("/")
            _results[_task].pop(_summary_metric)

        # Flatten {task: {metric: value}} into {"task/metric": value}.
        tmp_results = copy.deepcopy(_results)
        for task_name, task_results in tmp_results.items():
            for metric_name, metric_value in task_results.items():
                _results[f"{task_name}/{metric_name}"] = metric_value
                _results[task_name].pop(metric_name)
        for task in self.task_names:
            _results.pop(task)

        return wandb_summary, _results

    def _log_results_as_table(self) -> None:
        """Generate and log evaluation results as a table to W&B."""
        columns = [
            "Version",
            "Filter",
            "num_fewshot",
            "Metric",
            "Value",
            "Stderr",
        ]

        def make_table(columns: List[str], key: str = "results"):
            # Build one wandb.Table row per (task, metric, filter) triple.
            import wandb

            table = wandb.Table(columns=columns)
            results = copy.deepcopy(self.results)

            for k, dic in results.get(key).items():
                # Group aggregates only belong in the "groups" table.
                if k in self.group_names and not key == "groups":
                    continue
                version = results.get("versions").get(k)
                if version == "N/A":
                    version = None
                n = results.get("n-shot").get(k)

                for (mf), v in dic.items():
                    # Metric keys look like "<metric>,<filter>".
                    m, _, f = mf.partition(",")
                    if m.endswith("_stderr"):
                        continue
                    if m == "alias":
                        continue

                    # Pair the metric with its stderr entry when present.
                    if m + "_stderr" + "," + f in dic:
                        se = dic[m + "_stderr" + "," + f]
                        if se != "N/A":
                            se = "%.4f" % se
                        table.add_data(*[k, version, f, n, m, str(v), str(se)])
                    else:
                        table.add_data(*[k, version, f, n, m, str(v), ""])

            return table

        # log the complete eval result to W&B Table
        table = make_table(["Tasks"] + columns, "results")
        self.run.log({"evaluation/eval_results": table}, step=self.step)

        if "groups" in self.results.keys():
            table = make_table(["Groups"] + columns, "groups")
            self.run.log({"evaluation/group_eval_results": table}, step=self.step)

    def _log_results_as_artifact(self) -> None:
        """Log results as JSON artifact to W&B."""
        import wandb

        dumped = json.dumps(
            self.results, indent=2, default=_handle_non_serializable, ensure_ascii=False
        )
        artifact = wandb.Artifact("results", type="eval_results")
        with artifact.new_file("results.json", mode="w", encoding="utf-8") as f:
            f.write(dumped)
        self.run.log_artifact(artifact)

    def log_eval_result(self) -> None:
        """Log evaluation results to W&B."""
        # Log configs to wandb
        configs = self._get_config()
        self.run.config.update(configs, allow_val_change=self.step is not None)

        wandb_summary, self.wandb_results = self._sanitize_results_dict()
        # update wandb.run.summary with items that were removed
        self.run.summary.update(wandb_summary)
        # Log the evaluation metrics to wandb
        self.run.log(self.wandb_results, step=self.step)
        # Log the evaluation metrics as W&B Table
        self._log_results_as_table()
        # Log the results dict as json to W&B Artifacts
        self._log_results_as_artifact()

    def _generate_dataset(
        self, data: List[Dict[str, Any]], config: Dict[str, Any]
    ) -> pd.DataFrame:
        """Generate a dataset from evaluation data.

        Args:
            data (List[Dict[str, Any]]): The data to generate a dataset for.
            config (Dict[str, Any]): The configuration of the task.

        Returns:
            pd.DataFrame: A dataframe that is ready to be uploaded to W&B.
        """
        ids = [x["doc_id"] for x in data]
        labels = [x["target"] for x in data]
        # Defaults, overwritten per output_type below.
        instance = [""] * len(ids)
        resps = [""] * len(ids)
        filtered_resps = [""] * len(ids)
        model_outputs = {}

        metrics_list = config["metric_list"]
        metrics = {}
        for metric in metrics_list:
            metric = metric.get("metric")
            # Perplexity-style metrics are (loglikelihood, length) pairs;
            # split them into separate columns.
            if metric in ["word_perplexity", "byte_perplexity", "bits_per_byte"]:
                metrics[f"{metric}_loglikelihood"] = [x[metric][0] for x in data]
                if metric in ["byte_perplexity", "bits_per_byte"]:
                    metrics[f"{metric}_bytes"] = [x[metric][1] for x in data]
                else:
                    metrics[f"{metric}_words"] = [x[metric][1] for x in data]
            else:
                metrics[metric] = [x[metric] for x in data]

        # Per-output-type extraction of prompt, raw, and filtered responses.
        if config["output_type"] == "loglikelihood":
            instance = [x["arguments"][0][0] for x in data]
            labels = [x["arguments"][0][1] for x in data]
            resps = [
                f"log probability of continuation is {x['resps'][0][0][0]} "
                + "\n\n"
                + "continuation will {} generated with greedy sampling".format(
                    "not be" if not x["resps"][0][0][1] else "be"
                )
                for x in data
            ]
            filtered_resps = [
                f"log probability of continuation is {x['filtered_resps'][0][0]} "
                + "\n\n"
                + "continuation will {} generated with greedy sampling".format(
                    "not be" if not x["filtered_resps"][0][1] else "be"
                )
                for x in data
            ]
        elif config["output_type"] == "multiple_choice":
            instance = [x["arguments"][0][0] for x in data]
            choices = [
                "\n".join([f"{idx}. {y[1]}" for idx, y in enumerate(x["arguments"])])
                for x in data
            ]
            # The predicted choice is the argmax over per-choice loglikelihoods.
            resps = [np.argmax([n[0][0] for n in x["resps"]]) for x in data]
            filtered_resps = [
                np.argmax([n[0] for n in x["filtered_resps"]]) for x in data
            ]
        elif config["output_type"] == "loglikelihood_rolling":
            instance = [x["arguments"][0][0] for x in data]
            resps = [x["resps"][0][0] for x in data]
            filtered_resps = [x["filtered_resps"][0] for x in data]
        elif config["output_type"] == "generate_until":
            instance = [x["arguments"][0][0] for x in data]
            resps = [x["resps"][0][0] for x in data]
            filtered_resps = [x["filtered_resps"][0] for x in data]

        model_outputs["raw_predictions"] = resps
        model_outputs["filtered_predictions"] = filtered_resps

        df_data = {
            "id": ids,
            "data": instance,
        }
        if config["output_type"] == "multiple_choice":
            df_data["choices"] = choices

        tmp_data = {
            "input_len": [len(x) for x in instance],
            "labels": labels,
            "output_type": config["output_type"],
        }
        df_data.update(tmp_data)
        df_data.update(model_outputs)
        df_data.update(metrics)

        return pd.DataFrame(df_data)

    def _log_samples_as_artifact(
        self, data: List[Dict[str, Any]], task_name: str
    ) -> None:
        """Upload one task's raw samples as a JSON artifact."""
        import wandb

        # log the samples as an artifact
        dumped = json.dumps(
            data,
            indent=2,
            default=_handle_non_serializable,
            ensure_ascii=False,
        )
        artifact = wandb.Artifact(f"{task_name}", type="samples_by_task")
        with artifact.new_file(
            f"{task_name}_eval_samples.json", mode="w", encoding="utf-8"
        ) as f:
            f.write(dumped)
        self.run.log_artifact(artifact)
        # artifact.wait()

    def log_eval_samples(self, samples: Dict[str, List[Dict[str, Any]]]) -> None:
        """Log evaluation samples to W&B.

        Args:
            samples (Dict[str, List[Dict[str, Any]]]): Evaluation samples for each task.
        """
        task_names: List[str] = [
            x for x in self.task_names if x not in self.group_names
        ]

        # Partition tasks into grouped vs. ungrouped based on task configs.
        ungrouped_tasks = []
        tasks_by_groups = {}

        for task_name in task_names:
            group_names = self.task_configs[task_name].get("group", None)
            if group_names:
                if isinstance(group_names, str):
                    group_names = [group_names]

                for group_name in group_names:
                    if not tasks_by_groups.get(group_name):
                        tasks_by_groups[group_name] = [task_name]
                    else:
                        tasks_by_groups[group_name].append(task_name)
            else:
                ungrouped_tasks.append(task_name)

        for task_name in ungrouped_tasks:
            eval_preds = samples[task_name]

            # log the samples as a W&B Table
            df = self._generate_dataset(eval_preds, self.task_configs.get(task_name))
            self.run.log({f"{task_name}_eval_results": df}, step=self.step)

            # log the samples as a json file as W&B Artifact
            self._log_samples_as_artifact(eval_preds, task_name)

        for group, grouped_tasks in tasks_by_groups.items():
            grouped_df = pd.DataFrame()
            for task_name in grouped_tasks:
                eval_preds = samples[task_name]
                df = self._generate_dataset(
                    eval_preds, self.task_configs.get(task_name)
                )
                df["group"] = group
                df["task"] = task_name
                grouped_df = pd.concat([grouped_df, df], ignore_index=True)

                # log the samples as a json file as W&B Artifact
                self._log_samples_as_artifact(eval_preds, task_name)

            self.run.log({f"{group}_eval_results": grouped_df}, step=self.step)
|
Prism/LLaDA/LLaDA_Baseline/dllm_eval/models/LLaDA.py
ADDED
|
@@ -0,0 +1,786 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import os
|
| 3 |
+
from datetime import timedelta
|
| 4 |
+
from typing import Dict, List, Literal, Optional, Tuple, Union, TypeVar
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
import numpy as np
|
| 8 |
+
import transformers
|
| 9 |
+
import json
|
| 10 |
+
from accelerate import (
|
| 11 |
+
Accelerator,
|
| 12 |
+
InitProcessGroupKwargs,
|
| 13 |
+
)
|
| 14 |
+
from datasets import Dataset
|
| 15 |
+
from accelerate.utils import get_max_memory
|
| 16 |
+
from packaging import version
|
| 17 |
+
from tqdm import tqdm
|
| 18 |
+
import torch.distributed as dist
|
| 19 |
+
from transformers.models.auto.modeling_auto import (
|
| 20 |
+
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
|
| 21 |
+
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
|
| 22 |
+
)
|
| 23 |
+
from dllm_eval.api.instance import Instance
|
| 24 |
+
from dllm_eval.api.model import LM, TemplateLM
|
| 25 |
+
from dllm_eval.api.registry import register_model
|
| 26 |
+
from dllm_eval.models.utils import get_dtype, configure_pad_token
|
| 27 |
+
|
| 28 |
+
try:
|
| 29 |
+
from .hts_sampler import HTSSampler
|
| 30 |
+
except ImportError:
|
| 31 |
+
HTSSampler = None
|
| 32 |
+
|
| 33 |
+
eval_logger = logging.getLogger(__name__)
|
| 34 |
+
T = TypeVar("T", bound="LM")
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def add_gumbel_noise(logits, temperature):
    """Perturb `logits` with Gumbel-style noise for stochastic sampling.

    With ``temperature == 0.0`` the logits are returned untouched (greedy
    decoding). Otherwise the result is ``exp(logits) / (-log(u))**temperature``
    with ``u ~ Uniform(0, 1)``, so an argmax over it draws a sample whose
    randomness grows with temperature.
    """
    # Greedy path: no noise, no dtype change.
    if temperature == 0.0:
        return logits
    # Work in float32 for numerical stability of exp/log.
    scores = logits.to(torch.float32)
    uniform = torch.rand_like(scores, dtype=torch.float32)
    perturbation = (-torch.log(uniform)) ** temperature
    return scores.exp() / perturbation
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def get_num_transfer_tokens(mask_index, steps):
    """Split each sequence's masked-token budget evenly across `steps`.

    Args:
        mask_index: Bool tensor of shape (batch, seq_len); True marks a
            still-masked position.
        steps: Number of denoising steps to spread the unmasking over.

    Returns:
        Int64 tensor of shape (batch, steps) whose row i sums to the number
        of masked tokens in sequence i; the remainder is handed out one
        token at a time to the earliest steps, so counts differ by at most 1.
    """
    total_masked = mask_index.sum(dim=1, keepdim=True)
    per_step = total_masked // steps
    leftover = total_masked % steps
    schedule = per_step.expand(-1, steps).clone()
    # Give the first `leftover` steps one extra token each.
    if leftover.sum() > 0:
        step_ids = torch.arange(steps, device=mask_index.device)
        schedule[step_ids.unsqueeze(0) < leftover] += 1
    return schedule.to(torch.int64)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
@torch.no_grad()
def generate_llada_v1(model, prompt, attention_mask=None, steps=128, gen_length=128,
                      block_length=128, temperature=0., cfg_scale=0.,
                      remasking='low_confidence', mask_id=126336,
                      logits_eos_inf=False, confidence_eos_eot_inf=False):
    """
    LLaDA v1 generation function

    Semi-autoregressive block diffusion decoding: the response region is
    split into ``gen_length // block_length`` blocks; within each block,
    ``steps_per_block`` denoising iterations progressively replace mask
    tokens with the most confident model predictions.

    Args:
        model: LLaDA model; must expose ``.device`` and return ``.logits``.
        prompt: (batch, prompt_len) token-id tensor.
        attention_mask: optional (batch, prompt_len) mask; extended with
            ones over the generated region.
        steps: total denoising steps (must divide evenly across blocks).
        gen_length: tokens to generate (must be a multiple of block_length).
        block_length: size of each semi-autoregressive block.
        temperature: Gumbel-noise temperature; 0 means greedy.
        cfg_scale: classifier-free guidance scale (> 0 enables CFG).
        remasking: 'low_confidence' or 'random' token-commit policy.
        mask_id: id of the [MASK] token (126336 for LLaDA v1).
        logits_eos_inf: if True, forbid sampling EOS (id 126081) entirely.
        confidence_eos_eot_inf: if True, give EOS/EOT minimal confidence so
            they are committed last (see BUGFIX note below).

    Returns:
        (batch, prompt_len + gen_length) tensor of token ids.
    """
    x = torch.full((prompt.shape[0], prompt.shape[1] + gen_length), mask_id,
                   dtype=torch.long).to(model.device)
    x[:, :prompt.shape[1]] = prompt.clone()

    if attention_mask is not None:
        # Generated positions are always attended to.
        attention_mask = torch.cat([
            attention_mask,
            torch.ones((prompt.shape[0], gen_length), dtype=attention_mask.dtype,
                       device=model.device)
        ], dim=-1)

    prompt_index = (x != mask_id)

    assert gen_length % block_length == 0
    num_blocks = gen_length // block_length

    assert steps % num_blocks == 0
    steps_per_block = steps // num_blocks

    for num_block in range(num_blocks):
        block_mask_index = (x[:, prompt.shape[1] + num_block * block_length:
                              prompt.shape[1] + (num_block + 1) * block_length] == mask_id)
        num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps_per_block)

        for i in range(steps_per_block):
            mask_index = (x == mask_id)

            if cfg_scale > 0.:
                # Classifier-free guidance: score a second copy with the
                # prompt masked out and extrapolate away from it.
                un_x = x.clone()
                un_x[prompt_index] = mask_id
                x_ = torch.cat([x, un_x], dim=0)
                if attention_mask is not None:
                    attention_mask_ = torch.cat([attention_mask, attention_mask], dim=0)
                logits = model(x_, attention_mask=attention_mask_).logits
                logits, un_logits = torch.chunk(logits, 2, dim=0)
                logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
            else:
                logits = model(x, attention_mask=attention_mask).logits

            if logits_eos_inf:
                logits[:, :, 126081] = -torch.inf

            logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
            x0 = torch.argmax(logits_with_noise, dim=-1)

            if confidence_eos_eot_inf:
                # BUGFIX: the original chained assignment
                #   logits_with_noise[:, :, 126081] = logits[:, :, 126348] = -torch.inf
                # mutated `logits_with_noise` AFTER it was consumed by the
                # argmax above (a dead write) and suppressed only 126348 in
                # `logits`. Suppress both EOS (126081) and EOT (126348) in
                # `logits`, which feeds the confidence softmax below, so
                # those tokens are committed last rather than never sampled.
                logits[:, :, 126081] = -torch.inf
                logits[:, :, 126348] = -torch.inf

            if remasking == 'low_confidence':
                p = F.softmax(logits, dim=-1)
                x0_p = torch.squeeze(
                    torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1)
            elif remasking == 'random':
                x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
            else:
                raise NotImplementedError(remasking)

            # Never commit tokens beyond the current block.
            x0_p[:, prompt.shape[1] + (num_block + 1) * block_length:] = -np.inf

            x0 = torch.where(mask_index, x0, x)
            confidence = torch.where(mask_index, x0_p, -np.inf)

            # Commit the scheduled number of highest-confidence tokens.
            transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
            for j in range(confidence.shape[0]):
                _, select_index = torch.topk(confidence[j], k=num_transfer_tokens[j, i])
                transfer_index[j, select_index] = True

            x[transfer_index] = x0[transfer_index]

    return x
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
@register_model("LLaDA")
class LLaDA(TemplateLM):
    """lm-eval harness wrapper for the LLaDA v1 masked-diffusion LM.

    Loads model/tokenizer from Hugging Face (optionally sharded through
    Accelerate), resolves device placement and batch sizing, and generates
    via an optional HTSSampler (see `generate_until`).
    """

    # Diffusion LM: loaded through the AutoModel machinery.
    AUTO_MODEL_CLASS = transformers.AutoModel
    # Fallback context length when neither config nor tokenizer reports one.
    _DEFAULT_MAX_LENGTH = 20480

    def __init__(
        self,
        pretrained: Union[str, transformers.PreTrainedModel],
        backend: Literal["default", "causal", "seq2seq"] = "causal",
        revision: Optional[str] = "main",
        subfolder: Optional[str] = None,
        tokenizer: Optional[
            Union[
                str,
                transformers.PreTrainedTokenizer,
                transformers.PreTrainedTokenizerFast,
            ]
        ] = None,
        truncation: Optional[bool] = False,
        logits_cache: bool = True,
        max_length: Optional[int] = None,
        device: Optional[str] = "cuda",
        dtype: Optional[Union[str, torch.dtype]] = "auto",
        # NOTE(review): annotation likely meant Union[int, str] — the
        # "auto[:N]" string form is handled below. Left as-is.
        batch_size: Optional[Union[int]] = 1,
        max_batch_size: Optional[int] = 64,
        trust_remote_code: Optional[bool] = True,
        use_fast_tokenizer: Optional[bool] = True,
        add_bos_token: Optional[bool] = False,
        escape_until: Optional[bool] = False,
        prefix_token_id: Optional[int] = None,
        parallelize: Optional[bool] = False,
        max_memory_per_gpu: Optional[Union[int, str]] = None,
        max_cpu_memory: Optional[Union[int, str]] = None,
        offload_folder: Optional[Union[str, os.PathLike]] = "./offload",
        peft: Optional[str] = None,
        delta: Optional[str] = None,
        autogptq: Optional[Union[bool, str]] = False,
        gptqmodel: Optional[bool] = False,
        gguf_file: Optional[str] = None,
        mc_num: int = 1024,
        remasking: str = "low_confidence",
        mask_id: int = 126336,  # LLaDA v1 default mask_id
        is_check_greedy: bool = True,
        assistant_prefix: Optional[str] = None,
        **kwargs,
    ) -> None:
        """Build the wrapper: resolve device, then load config, tokenizer
        and model.

        `pretrained` may be a hub id / local path, or an already-built
        model instance; in the latter case device and parallelism options
        are ignored and single-process evaluation is assumed.
        """
        super().__init__()
        # Sampling / evaluation knobs consumed elsewhere in the class.
        self.mc_num = mc_num
        self.mask_id = mask_id
        self.remasking = remasking
        self.pretrained = pretrained
        self.is_check_greedy = is_check_greedy
        self.assistant_prefix = assistant_prefix
        self.add_bos_token = add_bos_token
        self.escape_until = escape_until

        if not isinstance(pretrained, str):
            # Caller handed us a live model: adopt its device and config.
            eval_logger.warning(
                "`pretrained` model kwarg is not of type `str`. Many other model arguments may be ignored."
            )
            assert not parallelize, (
                "`parallelize=True` is not compatible with passing pre-initialized model to `pretrained`"
            )
            self._model = pretrained
            self._device = self._model.device
            self._config = self._model.config
            gpus = 0
        else:
            assert isinstance(device, str)
            assert isinstance(pretrained, str)
            assert isinstance(batch_size, (int, str))
            gpus = torch.cuda.device_count()
            # Very long timeout so multi-process evaluation does not time
            # out while ranks wait on each other.
            accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52))
            accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs])
            if accelerator.num_processes > 1:
                # Only keep the accelerator handle when actually distributed.
                self.accelerator = accelerator
            if "npu" in accelerator.device.type:
                gpus = torch.npu.device_count()
            if not (parallelize or accelerator.num_processes > 1):
                # Single-process path: honor the requested device string.
                device_list = set(
                    ["cuda", "cpu"]
                    + [f"cuda:{i}" for i in range(gpus)]
                    + ["mps", "mps:0"]
                    + [f"npu:{i}" for i in range(gpus)]
                )
                if device and device in device_list:
                    self._device = torch.device(device)
                    eval_logger.info(f"Using device '{device}'")
                    if device in ("mps", "mps:0") and version.parse(
                        torch.__version__
                    ) < version.parse("2.1"):
                        raise RuntimeError(
                            f"mps requires torch >= 2.1. You have {torch.__version__}"
                        )
                else:
                    eval_logger.info("Device not specified")
                    eval_logger.info(f"Cuda Available? {torch.cuda.is_available()}")
                    self._device = (
                        torch.device("cuda")
                        if torch.cuda.is_available()
                        else torch.device("cpu")
                    )
            else:
                # Distributed / device_map path: Accelerate decides placement.
                if device != "cuda":
                    eval_logger.info(
                        f"Using `accelerate launch` or `parallelize=True`, device '{device}' will be overridden when placing model."
                    )
                self._device = (
                    self.accelerator.device
                    if hasattr(self, "accelerator")
                    else torch.device(device)
                )
            revision = str(revision)
            # Subfolder checkpoints are addressed as "<revision>/<subfolder>".
            revision = revision + ("/" + subfolder if subfolder is not None else "")
            self._get_config(
                pretrained,
                revision=revision,
                trust_remote_code=trust_remote_code,
                gguf_file=gguf_file,
            )

        self._get_backend(
            config=self.config, backend=backend, trust_remote_code=trust_remote_code
        )
        self._create_tokenizer(
            pretrained,
            tokenizer,
            revision=revision,
            trust_remote_code=trust_remote_code,
            use_fast_tokenizer=use_fast_tokenizer,
            gguf_file=gguf_file,
            add_bos_token=add_bos_token,
        )

        if isinstance(pretrained, str):
            self._create_model(
                pretrained=pretrained,
                revision=revision,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                parallelize=parallelize,
                gpus=gpus,
                max_memory_per_gpu=max_memory_per_gpu,
                max_cpu_memory=max_cpu_memory,
                offload_folder=offload_folder,
                peft=peft,
                delta=delta,
                autogptq=autogptq,
                gptqmodel=gptqmodel,
                gguf_file=gguf_file,
                **kwargs,
            )

        if isinstance(self.model, torch.nn.Module):
            self.model.eval()
            self.model.tie_weights()

        self.truncation = truncation
        self.logits_cache = logits_cache
        # NOTE(review): vocab_size is read before configure_pad_token may
        # alter the tokenizer — confirm pad-token handling never grows it.
        self.vocab_size = self.tokenizer.vocab_size
        self.tokenizer = configure_pad_token(self.tokenizer, model_config=self.config)
        self.add_bos_token = add_bos_token

        if "gemma" in getattr(self.config, "model_type", ""):
            # Gemma models require a BOS token regardless of the flag.
            self.add_bos_token = True
            eval_logger.info(
                f"Model type is '{self.config.model_type}', part of the Gemma family--a BOS token will be used."
            )

        self._max_length = max_length
        self.pretrained = pretrained
        self.delta = delta
        self.peft = peft
        self.revision = revision
        self.batch_schedule = 1
        self.batch_sizes = {}
        self.max_batch_size = max_batch_size

        if str(batch_size).startswith("auto"):
            # "auto[:schedule]" — keeps the literal "auto" string as the
            # per-GPU batch size for later automatic detection.
            batch_size = batch_size.split(":")
            self.batch_size_per_gpu = batch_size[0]
            self.batch_schedule = float(batch_size[1]) if len(batch_size) > 1 else 1
        else:
            self.batch_size_per_gpu = int(batch_size)

        if isinstance(pretrained, str):
            if gpus >= 1 or str(self.device) == "mps":
                if not (parallelize or autogptq or hasattr(self, "accelerator")):
                    try:
                        self.model.to(self.device)
                    except ValueError:
                        eval_logger.debug(
                            "Failed to place model onto specified device. This may be because the model is quantized via `bitsandbytes` or `device_map` is provided."
                        )
            if gpus > 1:
                if hasattr(self, "accelerator") and self.accelerator.num_processes > 1:
                    if parallelize:
                        eval_logger.warning(
                            "You are both using a HF Accelerate `device_map` and launching via `accelerate launch`."
                        )
                    elif gpus > self.accelerator.num_processes:
                        eval_logger.warning(
                            "WARNING: The number of total system GPUs does not match the number of spawned processes."
                        )
                    self._device = torch.device(f"{self.accelerator.device}")
                    self._rank = self.accelerator.local_process_index
                    self._world_size = self.accelerator.num_processes
                else:
                    self._rank = 0
                    self._world_size = 1
            else:
                self._rank = 0
                self._world_size = 1
        else:
            eval_logger.warning(
                "Passed an already-initialized model through `pretrained`, assuming single-process call."
            )
            self._rank = 0
            self._world_size = 1

        self.custom_prefix_token_id = prefix_token_id
        if prefix_token_id is not None:
            eval_logger.info(
                f"Loglikelihood prefix token id used in evaluation: {self.prefix_token_id}"
            )
        self.is_first_inference = True

        # HTSSampler is optional (import guarded at module top); when
        # absent, generate_until will fail — presumably only HTS-capable
        # installs call generation. TODO confirm.
        if HTSSampler is not None:
            self.hts_sampler = HTSSampler(self.model, self.tokenizer, device=self.device)
            eval_logger.info("HTSSampler initialized successfully.")
|
| 371 |
+
|
| 372 |
+
# Copy all the property and helper methods from LLaDA2
|
| 373 |
+
@property
|
| 374 |
+
def rank(self):
|
| 375 |
+
if hasattr(self, "_rank"):
|
| 376 |
+
return self._rank
|
| 377 |
+
if hasattr(self, "accelerator"):
|
| 378 |
+
return self.accelerator.local_process_index
|
| 379 |
+
return int(os.environ.get("LOCAL_RANK", 0))
|
| 380 |
+
|
| 381 |
+
@property
|
| 382 |
+
def world_size(self):
|
| 383 |
+
if hasattr(self, "_world_size"):
|
| 384 |
+
return self._world_size
|
| 385 |
+
if hasattr(self, "accelerator"):
|
| 386 |
+
return self.accelerator.num_processes
|
| 387 |
+
return int(os.environ.get("WORLD_SIZE", 1))
|
| 388 |
+
|
| 389 |
+
    def _get_accelerate_args(
        self,
        parallelize: Optional[bool] = None,
        # NOTE(review): `device_map` is accepted but never read; the map is
        # hard-coded below. Confirm before removing.
        device_map: Optional[str] = "auto",
        max_memory_per_gpu: Optional[Union[int, str]] = None,
        max_cpu_memory: Optional[Union[int, str]] = None,
        offload_folder: Optional[str] = "./offload",
        gpus: Optional[int] = None,
    ) -> dict:
        """Build `from_pretrained` placement kwargs (device_map/max_memory).

        With `parallelize` (implied when more than one GPU is visible) the
        model is sharded via `device_map="auto"` with a per-GPU memory cap;
        otherwise the whole model is pinned to `self.device`.
        """
        num_local_processes = int(os.environ.get("LOCAL_WORLD_SIZE", 1))
        if parallelize is None and gpus is not None and gpus > 1:
            parallelize = True
        args = {}
        if parallelize:
            max_memory_all_gpus = get_max_memory()
            # Only budget GPU memory here; CPU is added separately below.
            if "cpu" in max_memory_all_gpus:
                del max_memory_all_gpus["cpu"]
            max_memory_per_gpu_map = {
                device_idx: max_memory_per_gpu for device_idx in range(len(max_memory_all_gpus))
            } if max_memory_per_gpu is not None else {k: v for k, v in max_memory_all_gpus.items()}
            if hasattr(self, "accelerator"):
                # Distributed launch: give each local process a disjoint
                # slice of the visible GPUs.
                # NOTE(review): this branch discards any explicit
                # `max_memory_per_gpu` override — confirm intended.
                max_memory_per_gpu_map = {
                    k: v for k, v in max_memory_all_gpus.items()
                    if k % num_local_processes == self.accelerator.process_index % num_local_processes
                }
            args["max_memory"] = max_memory_per_gpu_map
            args["device_map"] = "auto"
            args["offload_folder"] = offload_folder
            if max_cpu_memory is not None:
                args["max_memory"]["cpu"] = max_cpu_memory
        else:
            # Single-device placement: map everything to self.device.
            args["device_map"] = {"": str(self.device)}
        return args
|
| 423 |
+
|
| 424 |
+
@property
|
| 425 |
+
def config(self):
|
| 426 |
+
return self._config
|
| 427 |
+
|
| 428 |
+
@property
|
| 429 |
+
def model(self):
|
| 430 |
+
if hasattr(self, "accelerator"):
|
| 431 |
+
return self.accelerator.unwrap_model(self._model)
|
| 432 |
+
else:
|
| 433 |
+
return self._model
|
| 434 |
+
|
| 435 |
+
@property
|
| 436 |
+
def eot_token_id(self):
|
| 437 |
+
return self.tokenizer.eos_token_id
|
| 438 |
+
|
| 439 |
+
@property
|
| 440 |
+
def prefix_token_id(self):
|
| 441 |
+
if self.custom_prefix_token_id is not None:
|
| 442 |
+
return self.custom_prefix_token_id
|
| 443 |
+
if self.tokenizer.bos_token_id is not None:
|
| 444 |
+
return self.tokenizer.bos_token_id
|
| 445 |
+
return self.tokenizer.eos_token_id
|
| 446 |
+
|
| 447 |
+
@property
|
| 448 |
+
def max_length(self):
|
| 449 |
+
if self._max_length:
|
| 450 |
+
return self._max_length
|
| 451 |
+
seqlen_config_attrs = ("n_positions", "max_position_embeddings", "n_ctx")
|
| 452 |
+
for attr in seqlen_config_attrs:
|
| 453 |
+
if hasattr(self.model.config, attr):
|
| 454 |
+
return getattr(self.model.config, attr)
|
| 455 |
+
if hasattr(self.tokenizer, "model_max_length"):
|
| 456 |
+
if self.tokenizer.model_max_length > 1e10:
|
| 457 |
+
return self._DEFAULT_MAX_LENGTH
|
| 458 |
+
return self.tokenizer.model_max_length
|
| 459 |
+
return self._DEFAULT_MAX_LENGTH
|
| 460 |
+
|
| 461 |
+
@property
|
| 462 |
+
def max_gen_toks(self) -> int:
|
| 463 |
+
return 256
|
| 464 |
+
|
| 465 |
+
@property
|
| 466 |
+
def batch_size(self):
|
| 467 |
+
return self.batch_size_per_gpu
|
| 468 |
+
|
| 469 |
+
@property
|
| 470 |
+
def device(self):
|
| 471 |
+
return self._device
|
| 472 |
+
|
| 473 |
+
@property
|
| 474 |
+
def tokenizer_name(self) -> str:
|
| 475 |
+
return self.tokenizer.name_or_path.replace("/", "__")
|
| 476 |
+
|
| 477 |
+
def _get_backend(self, config, backend, trust_remote_code):
|
| 478 |
+
"""Get backend type - same as LLaDA2"""
|
| 479 |
+
assert backend in ["default", "causal", "seq2seq"]
|
| 480 |
+
if backend != "default":
|
| 481 |
+
self.backend = backend
|
| 482 |
+
eval_logger.info(f"Overrode HF model backend type, and using type '{self.backend}'")
|
| 483 |
+
else:
|
| 484 |
+
if getattr(config, "model_type") in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
|
| 485 |
+
self.backend = "seq2seq"
|
| 486 |
+
elif getattr(self.config, "model_type") in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
|
| 487 |
+
self.backend = "causal"
|
| 488 |
+
else:
|
| 489 |
+
eval_logger.warning("HF model type is neither CausalLM nor Seq2SeqLM. Assuming CausalLM.")
|
| 490 |
+
self.backend = "causal"
|
| 491 |
+
|
| 492 |
+
    def _get_config(self, pretrained, revision, trust_remote_code, gguf_file):
        """Load and cache the HF config for `pretrained` into `self._config`.

        NOTE(review): `gguf_file` is accepted but not forwarded to
        `AutoConfig.from_pretrained` — confirm whether GGUF checkpoints
        are expected to work here.
        """
        self._config = transformers.AutoConfig.from_pretrained(
            pretrained, revision=revision, trust_remote_code=trust_remote_code
        )
|
| 497 |
+
|
| 498 |
+
    def _create_model(self, pretrained, revision, dtype, trust_remote_code, parallelize,
                      gpus, max_memory_per_gpu, max_cpu_memory, offload_folder,
                      peft, delta, autogptq, gptqmodel, gguf_file, **kwargs):
        """Load the HF model (plus optional PEFT adapter) into `self._model`.

        NOTE(review): `delta` and `gguf_file` are accepted but unused, and
        the final `.to(torch.bfloat16)` overrides whatever `dtype` asked
        for — confirm both are intentional.
        """
        if autogptq or gptqmodel:
            raise NotImplementedError("Quantization options are not implemented.")
        model_dtype = get_dtype(dtype)
        eval_logger.info(f"Loading model with dtype: {model_dtype}")
        model_kwargs = kwargs if kwargs else {}
        if not parallelize:
            # Derive device_map / max_memory placement hints.
            model_kwargs.update(
                self._get_accelerate_args(
                    parallelize=parallelize,
                    gpus=gpus,
                    max_memory_per_gpu=max_memory_per_gpu,
                    max_cpu_memory=max_cpu_memory,
                    offload_folder=offload_folder
                )
            )
        self._model = transformers.AutoModelForCausalLM.from_pretrained(
            pretrained, revision=revision, torch_dtype=model_dtype,
            trust_remote_code=trust_remote_code, **model_kwargs
        )
        if peft:
            # Lazy import: peft is only needed when an adapter is given.
            from peft import PeftModel
            eval_logger.info(f"Loading PEFT model from {peft}")
            self._model = PeftModel.from_pretrained(self._model, peft, torch_dtype=model_dtype)
        if not parallelize:
            self._model = self._model.to(self.device)
        # Force bfloat16 inference regardless of the requested dtype.
        self._model = self._model.to(torch.bfloat16)
        self._model.eval()
|
| 529 |
+
|
| 530 |
+
def _create_tokenizer(self, pretrained, tokenizer, revision, trust_remote_code,
|
| 531 |
+
use_fast_tokenizer, gguf_file, add_bos_token):
|
| 532 |
+
"""Create tokenizer - same as LLaDA2"""
|
| 533 |
+
kwargs = {
|
| 534 |
+
"revision": revision,
|
| 535 |
+
"trust_remote_code": trust_remote_code,
|
| 536 |
+
"use_fast": use_fast_tokenizer
|
| 537 |
+
}
|
| 538 |
+
if add_bos_token:
|
| 539 |
+
kwargs["add_bos_token"] = True
|
| 540 |
+
if tokenizer:
|
| 541 |
+
if isinstance(tokenizer, str):
|
| 542 |
+
self.tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer, **kwargs)
|
| 543 |
+
else:
|
| 544 |
+
self.tokenizer = tokenizer
|
| 545 |
+
else:
|
| 546 |
+
model_name = pretrained if isinstance(pretrained, str) else self.model.name_or_path
|
| 547 |
+
self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_name, **kwargs)
|
| 548 |
+
|
| 549 |
+
def tok_encode(self, string, left_truncate_len=None, add_special_tokens=None):
|
| 550 |
+
"""Tokenize string - same as LLaDA2"""
|
| 551 |
+
special_tokens_kwargs = {}
|
| 552 |
+
if add_special_tokens is None:
|
| 553 |
+
if self.backend == "causal":
|
| 554 |
+
special_tokens_kwargs["add_special_tokens"] = self.add_bos_token
|
| 555 |
+
else:
|
| 556 |
+
special_tokens_kwargs["add_special_tokens"] = add_special_tokens
|
| 557 |
+
encoding = self.tokenizer.encode(string, **special_tokens_kwargs)
|
| 558 |
+
if left_truncate_len:
|
| 559 |
+
encoding = encoding[-left_truncate_len:]
|
| 560 |
+
return encoding
|
| 561 |
+
|
| 562 |
+
def tok_batch_encode(self, strings, padding_side="left", left_truncate_len=None, truncation=False):
|
| 563 |
+
"""Batch tokenize - same as LLaDA2"""
|
| 564 |
+
old_padding_side = self.tokenizer.padding_side
|
| 565 |
+
self.tokenizer.padding_side = padding_side
|
| 566 |
+
add_special_tokens = {"add_special_tokens": self.add_bos_token} if self.backend == "causal" else {}
|
| 567 |
+
encoding = self.tokenizer(
|
| 568 |
+
strings, truncation=truncation, padding="longest",
|
| 569 |
+
return_tensors="pt", **add_special_tokens
|
| 570 |
+
)
|
| 571 |
+
if left_truncate_len and encoding["input_ids"].size(1) > left_truncate_len:
|
| 572 |
+
eval_logger.warning(f"Left-truncating from {encoding['input_ids'].size(1)} to {left_truncate_len} tokens.")
|
| 573 |
+
encoding["input_ids"] = encoding["input_ids"][:, -left_truncate_len:]
|
| 574 |
+
encoding["attention_mask"] = encoding["attention_mask"][:, -left_truncate_len:]
|
| 575 |
+
self.tokenizer.padding_side = old_padding_side
|
| 576 |
+
return encoding["input_ids"].to(self.device), encoding["attention_mask"].to(self.device)
|
| 577 |
+
|
| 578 |
+
def tok_decode(self, tokens, skip_special_tokens=False):
|
| 579 |
+
"""Decode tokens - same as LLaDA2"""
|
| 580 |
+
return self.tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)
|
| 581 |
+
|
| 582 |
+
def _model_call(self, inps, attn_mask=None, labels=None):
|
| 583 |
+
"""Model forward call - same as LLaDA2"""
|
| 584 |
+
with torch.no_grad():
|
| 585 |
+
if self.backend == "seq2seq":
|
| 586 |
+
return self.model(input_ids=inps, attention_mask=attn_mask, labels=labels).logits
|
| 587 |
+
else:
|
| 588 |
+
return self.model(inps, attention_mask=attn_mask).logits
|
| 589 |
+
|
| 590 |
+
    def _loglikelihood_tokens(self, requests, **kwargs) -> List[Tuple[float, bool]]:
        """Loglikelihood scoring is not supported by this diffusion-LM wrapper."""
        raise NotImplementedError

    def loglikelihood_rolling(
        self, requests: List[Instance], disable_tqdm: bool = False
    ) -> List[float]:
        """Rolling loglikelihood is not supported by this diffusion-LM wrapper."""
        raise NotImplementedError

    def loglikelihood(self, requests):
        """Loglikelihood is not supported by this diffusion-LM wrapper."""
        raise NotImplementedError
|
| 600 |
+
|
| 601 |
+
def generate_until(self, requests: List[Instance]) -> List[str]:
|
| 602 |
+
"""Generate until - adapted for LLaDA v1 """
|
| 603 |
+
res = []
|
| 604 |
+
gen_kwargs = requests[0].args[1]
|
| 605 |
+
use_hts = gen_kwargs.get("use_hts", False)
|
| 606 |
+
|
| 607 |
+
realtime_output = gen_kwargs.get("realtime_output", "realtime_hts_results.jsonl")
|
| 608 |
+
baseline_realtime_output = gen_kwargs.get("realtime_output", "realtime_baseline_results.jsonl")
|
| 609 |
+
|
| 610 |
+
if not use_hts and "realtime_output" not in gen_kwargs:
|
| 611 |
+
baseline_realtime_output = "realtime_baseline_results.jsonl"
|
| 612 |
+
|
| 613 |
+
if not use_hts:
|
| 614 |
+
bar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Running Baseline (LLaDA v1)")
|
| 615 |
+
|
| 616 |
+
for req in requests:
|
| 617 |
+
prompt_text = req.args[0]
|
| 618 |
+
local_gen_kwargs = req.args[1] if len(req.args) > 1 else {}
|
| 619 |
+
|
| 620 |
+
context_enc, _ = self.tok_batch_encode([prompt_text])
|
| 621 |
+
|
| 622 |
+
final_codes, stats = self.hts_sampler.generate_hts(
|
| 623 |
+
prompt_text=prompt_text,
|
| 624 |
+
input_ids=context_enc,
|
| 625 |
+
initial_N=1,
|
| 626 |
+
final_K=1,
|
| 627 |
+
hts_survivor_k=1,
|
| 628 |
+
hts_mode=False,
|
| 629 |
+
hts_start_pct=0.0,
|
| 630 |
+
hts_end_pct=0.0,
|
| 631 |
+
decay_factor=1.5,
|
| 632 |
+
pruning_interval=0,
|
| 633 |
+
reward_mode="confidence",
|
| 634 |
+
task_type=local_gen_kwargs.get("task_type", "code"),
|
| 635 |
+
steps=int(local_gen_kwargs.get("steps", 32)),
|
| 636 |
+
gen_length=int(local_gen_kwargs.get("gen_length", 512)),
|
| 637 |
+
block_length=int(local_gen_kwargs.get("block_length", 32)),
|
| 638 |
+
temperature=float(local_gen_kwargs.get("temperature", 0.0)),
|
| 639 |
+
top_p=float(local_gen_kwargs.get("top_p", 0.95)),
|
| 640 |
+
top_k=local_gen_kwargs.get("top_k", None),
|
| 641 |
+
threshold=float(local_gen_kwargs.get("threshold", 0.85)),
|
| 642 |
+
mask_id=self.mask_id,
|
| 643 |
+
eos_id=self.eot_token_id,
|
| 644 |
+
until=local_gen_kwargs.get("until", []),
|
| 645 |
+
)
|
| 646 |
+
|
| 647 |
+
processed_codes = []
|
| 648 |
+
for code in final_codes:
|
| 649 |
+
code = code.strip()
|
| 650 |
+
if not self.escape_until:
|
| 651 |
+
until_terms = local_gen_kwargs.get("until", [])
|
| 652 |
+
for term in until_terms:
|
| 653 |
+
if len(term) > 0 and term in code:
|
| 654 |
+
code = code.split(term)[0]
|
| 655 |
+
processed_codes.append(code)
|
| 656 |
+
|
| 657 |
+
final_choice = processed_codes[0] if processed_codes else ""
|
| 658 |
+
res.append(final_choice)
|
| 659 |
+
|
| 660 |
+
target_val = getattr(req, "target", None)
|
| 661 |
+
if target_val is None or target_val == "N/A":
|
| 662 |
+
if "test" in req.doc and "entry_point" in req.doc:
|
| 663 |
+
target_val = req.doc["test"] + "\ncheck(" + req.doc["entry_point"] + ")"
|
| 664 |
+
else:
|
| 665 |
+
target_val = req.doc.get("answer", req.doc.get("solution", "N/A"))
|
| 666 |
+
|
| 667 |
+
output_dir = os.path.dirname(baseline_realtime_output)
|
| 668 |
+
if output_dir:
|
| 669 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 670 |
+
with open(baseline_realtime_output, "a", encoding="utf-8") as f:
|
| 671 |
+
all_resps = [[code] for code in processed_codes]
|
| 672 |
+
output_data = {
|
| 673 |
+
"doc": req.doc,
|
| 674 |
+
"target": target_val,
|
| 675 |
+
"resps": all_resps,
|
| 676 |
+
"prompt": prompt_text,
|
| 677 |
+
"entropy_history": stats.get("entropy_history", []),
|
| 678 |
+
"pruning_history": stats.get("pruning_history", []),
|
| 679 |
+
"final_scores": stats.get("final_scores", []),
|
| 680 |
+
"all_trajectories": stats.get("all_trajectories", []),
|
| 681 |
+
"nfe": stats.get("nfe", 0),
|
| 682 |
+
"first_block_nfe": stats.get("first_block_nfe", 0),
|
| 683 |
+
"svf_calls": stats.get("svf_calls", 0),
|
| 684 |
+
"total_steps": stats.get("total_steps", 0),
|
| 685 |
+
"num_gen_blocks": stats.get("num_gen_blocks", []),
|
| 686 |
+
"steps_per_block": stats.get("steps_per_block", [])
|
| 687 |
+
}
|
| 688 |
+
f.write(json.dumps(output_data, ensure_ascii=False) + "\n")
|
| 689 |
+
f.flush()
|
| 690 |
+
|
| 691 |
+
bar.update(1)
|
| 692 |
+
bar.close()
|
| 693 |
+
|
| 694 |
+
else:
|
| 695 |
+
bar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Running HTS+SVF (LLaDA v1)")
|
| 696 |
+
for req in requests:
|
| 697 |
+
prompt_text = req.args[0]
|
| 698 |
+
local_gen_kwargs = req.args[1] if len(req.args) > 1 else {}
|
| 699 |
+
context_enc, _ = self.tok_batch_encode([prompt_text])
|
| 700 |
+
|
| 701 |
+
p_interval = int(local_gen_kwargs.get("pruning_interval", 0))
|
| 702 |
+
|
| 703 |
+
final_codes, stats = self.hts_sampler.generate_hts(
|
| 704 |
+
prompt_text=prompt_text,
|
| 705 |
+
input_ids=context_enc,
|
| 706 |
+
initial_N=int(local_gen_kwargs.get("hts_N", 4)),
|
| 707 |
+
final_K=int(local_gen_kwargs.get("final_K", 1)),
|
| 708 |
+
hts_survivor_k=int(local_gen_kwargs.get("hts_survivor_k", 4)),
|
| 709 |
+
hts_mode=local_gen_kwargs.get("hts_mode", True),
|
| 710 |
+
hts_start_pct=float(local_gen_kwargs.get("hts_start_pct", 0.1)),
|
| 711 |
+
hts_end_pct=float(local_gen_kwargs.get("hts_end_pct", 0.6)),
|
| 712 |
+
decay_factor=float(local_gen_kwargs.get("decay_factor", 1.5)),
|
| 713 |
+
pruning_interval=p_interval,
|
| 714 |
+
reward_mode=local_gen_kwargs.get("reward_mode", "svf"),
|
| 715 |
+
task_type=local_gen_kwargs.get("task_type", "code"),
|
| 716 |
+
steps=int(local_gen_kwargs.get("steps", 32)),
|
| 717 |
+
gen_length=int(local_gen_kwargs.get("gen_length", 512)),
|
| 718 |
+
block_length=int(local_gen_kwargs.get("block_length", 32)),
|
| 719 |
+
temperature=float(local_gen_kwargs.get("temperature", 0.7)),
|
| 720 |
+
top_p=float(local_gen_kwargs.get("top_p", 0.95)),
|
| 721 |
+
top_k=local_gen_kwargs.get("top_k", None),
|
| 722 |
+
threshold=float(local_gen_kwargs.get("threshold", 0.85)),
|
| 723 |
+
mask_id=self.mask_id,
|
| 724 |
+
eos_id=self.eot_token_id,
|
| 725 |
+
until=local_gen_kwargs.get("until", []),
|
| 726 |
+
)
|
| 727 |
+
|
| 728 |
+
processed_codes = []
|
| 729 |
+
for code in final_codes:
|
| 730 |
+
code = code.strip()
|
| 731 |
+
if not self.escape_until:
|
| 732 |
+
until_terms = local_gen_kwargs.get("until", [])
|
| 733 |
+
for term in until_terms:
|
| 734 |
+
if len(term) > 0 and term in code:
|
| 735 |
+
code = code.split(term)[0]
|
| 736 |
+
processed_codes.append(code)
|
| 737 |
+
|
| 738 |
+
final_choice = processed_codes[0]
|
| 739 |
+
res.append(final_choice)
|
| 740 |
+
|
| 741 |
+
target_val = getattr(req, "target", None)
|
| 742 |
+
if target_val is None or target_val == "N/A":
|
| 743 |
+
if "test" in req.doc and "entry_point" in req.doc:
|
| 744 |
+
target_val = req.doc["test"] + "\ncheck(" + req.doc["entry_point"] + ")"
|
| 745 |
+
else:
|
| 746 |
+
target_val = req.doc.get("answer", req.doc.get("solution", "N/A"))
|
| 747 |
+
|
| 748 |
+
output_dir = os.path.dirname(realtime_output)
|
| 749 |
+
if output_dir:
|
| 750 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 751 |
+
with open(realtime_output, "a", encoding="utf-8") as f:
|
| 752 |
+
all_resps = [[code] for code in processed_codes]
|
| 753 |
+
output_data = {
|
| 754 |
+
"doc": req.doc,
|
| 755 |
+
"target": target_val,
|
| 756 |
+
"resps": all_resps,
|
| 757 |
+
"prompt": prompt_text,
|
| 758 |
+
"entropy_history": stats.get("entropy_history", []),
|
| 759 |
+
"pruning_history": stats.get("pruning_history", []),
|
| 760 |
+
"final_scores": stats.get("final_scores", []),
|
| 761 |
+
"all_trajectories": stats.get("all_trajectories", []),
|
| 762 |
+
"nfe": stats.get("nfe", 0),
|
| 763 |
+
"first_block_nfe": stats.get("first_block_nfe", 0),
|
| 764 |
+
"svf_calls": stats.get("svf_calls", 0),
|
| 765 |
+
"total_steps": stats.get("total_steps", 0),
|
| 766 |
+
"num_gen_blocks": stats.get("num_gen_blocks", []),
|
| 767 |
+
"steps_per_block": stats.get("steps_per_block", [])
|
| 768 |
+
}
|
| 769 |
+
f.write(json.dumps(output_data, ensure_ascii=False) + "\n")
|
| 770 |
+
f.flush()
|
| 771 |
+
|
| 772 |
+
bar.update(1)
|
| 773 |
+
bar.close()
|
| 774 |
+
|
| 775 |
+
return res
|
| 776 |
+
|
| 777 |
+
def apply_chat_template(
|
| 778 |
+
self, chat_history: List[Dict[str, str]], add_generation_prompt: bool = True
|
| 779 |
+
) -> str:
|
| 780 |
+
"""Apply chat template - same as LLaDA2"""
|
| 781 |
+
chat_templated = self.tokenizer.apply_chat_template(
|
| 782 |
+
chat_history, tokenize=False, add_generation_prompt=add_generation_prompt
|
| 783 |
+
)
|
| 784 |
+
if self.assistant_prefix:
|
| 785 |
+
chat_templated += self.assistant_prefix
|
| 786 |
+
return chat_templated
|
Prism/LLaDA/LLaDA_Baseline/dllm_eval/models/__init__.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Package initializer for dllm_eval.models.
# Importing the submodules here is done for side effects: it presumably runs
# their @register_model decorators (as seen in huggingface.py) so the model
# registry is populated — TODO confirm for LLaDA.
from . import (
    LLaDA,
    huggingface,
)
# from .configuration_llada import LLaDAConfig
# from .modeling_llada import LLaDAModelLM


try:
    # enable hf hub transfer if available
    # (hf_transfer is an optional accelerator for HuggingFace Hub downloads;
    # when it is not installed we silently fall back to the default transport)
    import hf_transfer  # type: ignore # noqa
    import huggingface_hub.constants  # type: ignore

    huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER = True
except ImportError:
    pass


# __all__ = ['LLaDAConfig', 'LLaDAModelLM']
|
Prism/LLaDA/LLaDA_Baseline/dllm_eval/models/dummy.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
|
| 3 |
+
from tqdm import tqdm
|
| 4 |
+
|
| 5 |
+
from dllm_eval.api.model import LM
|
| 6 |
+
from dllm_eval.api.registry import register_model
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
@register_model("dummy")
class DummyLM(LM):
    """A stub LM that emits placeholder outputs.

    Useful for exercising the evaluation pipeline without loading a real
    model: log-likelihoods are random negatives and generations are a fixed
    dummy string.
    """

    def __init__(self) -> None:
        super().__init__()

    @classmethod
    def create_from_arg_string(cls, arg_string, additional_config=None):
        # No configuration is honoured — every instance is identical.
        return cls()

    def loglikelihood(self, requests, disable_tqdm: bool = False):
        # One (negative pseudo-score, is_greedy=False) pair per request.
        return [
            (-random.random(), False)
            for _ in tqdm(requests, disable=disable_tqdm)
        ]

    def generate_until(self, requests, disable_tqdm: bool = False):
        outputs = []
        for request in tqdm(requests, disable=disable_tqdm):
            outputs.append("lol")
            # Sanity check (same order as original): the prompt must be non-empty.
            assert request.arguments[0].strip() != ""
        return outputs

    def loglikelihood_rolling(self, requests, disable_tqdm: bool = False):
        # A single random negative pseudo log-likelihood per request.
        return [-random.random() for _ in tqdm(requests, disable=disable_tqdm)]
|
Prism/LLaDA/LLaDA_Baseline/dllm_eval/models/hts_sampler.py
ADDED
|
@@ -0,0 +1,315 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
import numpy as np
|
| 4 |
+
from .verifier import CodeVerifier
|
| 5 |
+
import logging
|
| 6 |
+
import re
|
| 7 |
+
import math
|
| 8 |
+
|
| 9 |
+
logger = logging.getLogger(__name__)
|
| 10 |
+
|
| 11 |
+
class HTSSampler:
    """Hierarchical tree-search style sampler for a masked diffusion LM.

    Maintains a batch of candidate sequences, iteratively unmasks tokens
    block-by-block, and periodically prunes / re-branches candidates using a
    reward from ``CodeVerifier`` plus a lightweight structural heuristic.
    """

    def __init__(self, model, tokenizer, device="cuda"):
        # model: the language model called via self.model(input_ids=..., attention_mask=...)
        # tokenizer: used for batch_decode/decode of generated token ids
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        # Reward scorer used at pruning time and for final ranking.
        self.verifier = CodeVerifier(model, tokenizer, device)

    def _get_num_transfer_tokens(self, block_length, steps):
        """Split ``block_length`` tokens across ``steps`` as evenly as possible.

        Returns an int64 tensor of length ``steps``; the first
        ``block_length % steps`` entries get one extra token.
        """
        if steps == 0: return torch.tensor([], dtype=torch.int64)
        base = block_length // steps
        remainder = block_length % steps
        num_transfer_tokens = torch.full((steps,), base, dtype=torch.int64)
        num_transfer_tokens[:remainder] += 1
        return num_transfer_tokens

    def _sample_with_temperature(self, logits, temperature, top_k, top_p):
        """Sample token ids from ``logits`` via the Gumbel-max trick.

        Returns (x0, x0_p): the sampled ids and, as confidence, the max
        softmax probability of the ORIGINAL (pre-temperature) distribution.
        NOTE(review): ``top_p`` is accepted but never used in this body —
        nucleus sampling is not implemented; confirm whether intentional.
        """
        logits = logits.to(torch.float32)

        # Confidence is taken from the untempered distribution on purpose:
        # tempering below only affects which token is picked, not its score.
        orig_probs = torch.softmax(logits, dim=-1)
        x0_p, _ = torch.max(orig_probs, dim=-1)

        if temperature > 0.0:
            # Gumbel-max: argmax(logits/T + G) samples from softmax(logits/T).
            noise = torch.rand_like(logits, dtype=torch.float32)
            gumbel_noise = -torch.log(-torch.log(noise + 1e-10) + 1e-10)
            logits = logits / temperature + gumbel_noise

        if top_k is not None and top_k > 0:
            # Mask everything below the k-th largest (post-noise) logit.
            indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
            logits[indices_to_remove] = -float('Inf')

        x0 = torch.argmax(logits, dim=-1)

        return x0, x0_p

    def _safe_scalar(self, val):
        """Coerce a reward value (tensor or number) to a Python float.

        Multi-element tensors are reduced by mean.
        """
        if isinstance(val, torch.Tensor):
            if val.numel() > 1: return val.mean().item()
            return val.item()
        return float(val)

    def _analyze_structure(self, text, task_type="code"):
        """Cheap heuristic bonus/penalty based on surface structure of ``text``.

        Very short outputs are penalized (-0.1); task-specific markers
        (keywords/colons for code, \\boxed{}/"The answer is" for math) add
        small bonuses. Returned value is added to the verifier reward.
        """
        score = 0.0
        stripped = text.strip()
        if task_type == "code":
            if len(stripped) < 5: return -0.1
            keywords = ["return", "print", "yield", "lambda", "class ", "def "]
            if any(k in stripped for k in keywords): score += 0.05
            if ":" in stripped: score += 0.02
            # NOTE(review): this literal looks like it may have been an
            # indentation check ("    ") collapsed to a single space by the
            # diff renderer — a single space matches almost any text; verify.
            if " " in text: score += 0.03
        elif task_type == "math":
            if "\\boxed{" in stripped: score += 0.1
            if "The answer is" in stripped: score += 0.05
            if len(stripped) < 10: return -0.1
            # Penalize run-away step-by-step rambling.
            if "Step" in stripped and stripped.count("Step") > 15: score -= 0.2
        return score

    def _chunked_forward(self, x, chunk_size=96, slice_indices=None):
        """Run the model over batch ``x`` in chunks to bound peak memory.

        Args:
            x: (batch, seq) token ids.
            chunk_size: max sub-batch per forward pass.
            slice_indices: optional (start, end) — keep only that sequence
                slice of the logits before concatenating.

        Returns:
            Concatenated logits for the whole batch (detached copies).
        """
        total_batch = x.shape[0]
        logits_list = []
        for i in range(0, total_batch, chunk_size):
            end_idx = min(i + chunk_size, total_batch)
            sub_x = x[i:end_idx]
            # Full attention over the whole (prompt + masked generation) window.
            sub_mask = torch.ones_like(sub_x, device=self.device)
            with torch.no_grad():
                outputs = self.model(input_ids=sub_x, attention_mask=sub_mask)
                sub_logits = outputs.logits
                if slice_indices is not None:
                    s_start, s_end = slice_indices
                    sub_logits = sub_logits[:, s_start:s_end, :]
                # detach().clone() frees the rest of the logits tensor early.
                logits_list.append(sub_logits.detach().clone())
        return torch.cat(logits_list, dim=0)

    def _branch_and_resample(self, x, conf_scores, survivor_indices, target_width, mask_id,
                             prompt_length, resample_window=5, task_type="code"):
        """Expand surviving candidates back up to ``target_width`` rows.

        Each survivor is kept once verbatim; extra copies are "perturbed" by
        re-masking a few of their lowest-confidence already-decoded tokens
        (chosen randomly from the bottom-``2*window`` pool), which forces
        those positions to be re-sampled in later steps.

        Returns the new (x, conf_scores) tensors of ``target_width`` rows.
        """
        num_survivors = len(survivor_indices)
        if num_survivors == 0: return x[:target_width].clone(), conf_scores[:target_width].clone()

        # NOTE(review): all three branches set the same value (6), overriding
        # the resample_window parameter for every known task_type — confirm
        # whether per-task tuning was intended here.
        if task_type == "math": resample_window = 6
        elif task_type == "reasoning": resample_window = 6
        elif task_type == "code": resample_window = 6

        # Distribute target_width copies across survivors as evenly as possible.
        base_repeat = target_width // num_survivors
        remainder = target_width % num_survivors
        new_x_list = []
        new_conf_list = []

        for i in range(num_survivors):
            count = base_repeat + (1 if i < remainder else 0)
            if count == 0: continue

            survivor_x = x[survivor_indices[i]]
            survivor_conf = conf_scores[survivor_indices[i]]

            # First copy of each survivor is kept unmodified.
            new_x_list.append(survivor_x.unsqueeze(0))
            new_conf_list.append(survivor_conf.unsqueeze(0))

            if count > 1:
                gen_part = survivor_x[prompt_length:]
                gen_conf = survivor_conf[prompt_length:]
                # Positions in the generation region that are already decoded.
                non_mask_indices = (gen_part != mask_id).nonzero(as_tuple=True)[0]

                for _ in range(count - 1):
                    perturbed_x = survivor_x.clone()
                    perturbed_conf = survivor_conf.clone()

                    if len(non_mask_indices) > 0:
                        # Candidate pool: the 2*window lowest-confidence tokens.
                        pool_size = min(resample_window * 2, len(non_mask_indices))
                        current_token_confs = gen_conf[non_mask_indices]

                        _, candidate_indices = torch.topk(current_token_confs, k=pool_size, largest=False)

                        # Randomly re-mask up to `resample_window` of them.
                        num_to_perturb = min(resample_window, pool_size)
                        rand_indices = torch.randperm(pool_size, device=self.device)[:num_to_perturb]
                        selected_sub_indices = candidate_indices[rand_indices]

                        target_indices_in_x = prompt_length + non_mask_indices[selected_sub_indices]
                        perturbed_x[target_indices_in_x] = mask_id
                        perturbed_conf[target_indices_in_x] = 0.0

                    new_x_list.append(perturbed_x.unsqueeze(0))
                    new_conf_list.append(perturbed_conf.unsqueeze(0))

        return torch.cat(new_x_list, dim=0), torch.cat(new_conf_list, dim=0)

    @torch.no_grad()
    def generate_hts(self, prompt_text, input_ids, problem_data=None,
                     initial_N=1, final_K=1, survivor_K=None,
                     prune_step_pct=0.0, reward_mode="confidence",
                     temperature=0.7, block_length=32, steps=64, gen_length=1024,
                     top_p=0.95, top_k=None, minimal_topk=1, threshold=0.9,
                     eos_id=156892, mask_id=156895,
                     hts_mode=False, hts_start_pct=0.1, hts_end_pct=0.6, decay_factor=1.5,
                     hts_survivor_k=4, task_type="code", until=None, pruning_interval=0):
        """Generate with block-wise iterative unmasking plus candidate pruning.

        Two pruning regimes:
          * hts_mode=True: between steps [hts_start_pct, hts_end_pct)*steps the
            batch width decays geometrically (decay_factor) toward the final K,
            keeping ``hts_survivor_k`` parents at each prune.
          * hts_mode=False: explicit schedule — prune_step_pct / final_K /
            survivor_K (scalars or parallel lists) map specific steps to
            (target_width, parents).

        Returns:
            (responses, stats): reward-ranked decoded strings (best first) and
            a stats dict (nfe, svf_calls, pruning/final scores, trajectories…).

        NOTE(review): ``top_p`` and ``minimal_topk`` are accepted but unused in
        this body; ``eos_id``/``mask_id`` defaults are model-specific special
        token ids — TODO confirm they match the loaded tokenizer.
        """
        input_ids = input_ids.to(self.device)
        # Fan a single prompt out to initial_N parallel candidates.
        if input_ids.shape[0] == 1: input_ids = input_ids.repeat(initial_N, 1)

        schedule_map = {}
        ts_start, tr_end = 0, 0
        if not hts_mode:
            # Normalize scalar/list schedule arguments into parallel lists.
            final_K_list = [final_K] if not isinstance(final_K, list) else final_K
            prune_pct_list = [prune_step_pct] if not isinstance(prune_step_pct, list) else prune_step_pct
            survivor_K_list = final_K_list if survivor_K is None else ([survivor_K] if not isinstance(survivor_K, list) else survivor_K)
            # Pad survivor list with the trailing final_K entries if shorter.
            if len(survivor_K_list) < len(final_K_list): survivor_K_list.extend(final_K_list[len(survivor_K_list):])
            for pct, width, parents in zip(prune_pct_list, final_K_list, survivor_K_list):
                if pct > 0:
                    s = int(steps * pct)
                    schedule_map[s] = (width, parents)
        else:
            # NOTE(review): both branches of this conditional produce [final_K];
            # if final_K is ever a list in hts_mode, final_K_list[-1] below
            # would be a list and max(...) would fail — verify intent.
            final_K_list = [final_K] if not isinstance(final_K, int) else [final_K]
            ts_start, tr_end = int(steps * hts_start_pct), int(steps * hts_end_pct)


        prompt_length = input_ids.shape[1]
        # Ceil-divide so gen_length is fully covered by whole blocks.
        num_blocks = (gen_length + block_length - 1) // block_length
        total_length = prompt_length + num_blocks * block_length

        # Canvas: prompt followed by all-mask generation region.
        x = torch.full((initial_N, total_length), mask_id, dtype=torch.long, device=self.device)
        x[:, :prompt_length] = input_ids.clone()

        # Per-token confidence; prompt tokens are fully trusted.
        conf_scores = torch.zeros((initial_N, total_length), dtype=torch.float32, device=self.device)
        conf_scores[:, :prompt_length] = 1.0

        # NOTE(review): prefill_blocks is never used after this assignment.
        prefill_blocks = 0
        num_gen_blocks = num_blocks
        current_bsz = initial_N

        next_allowed_pruning_step = ts_start if hts_mode else 0

        # nfe counts forward-pass candidate rows; svf_calls counts reward calls.
        stats = {
            "initial_n": initial_N, "final_k": final_K_list[-1],
            "pruning_history": [], "entropy_history": [], "nfe": 0.0,
            "svf_calls": 0, "final_scores": [], "total_steps": steps,
            "first_block_nfe": 0.0, "num_gen_blocks": [], "steps_per_block": []
        }

        for num_block in range(num_gen_blocks):
            stats["num_gen_blocks"].append(num_block)

            # Current decoding window within the canvas.
            window_start = prompt_length + num_block * block_length
            window_end = window_start + block_length

            # How many tokens to commit at each of the `steps` iterations.
            schedule = self._get_num_transfer_tokens(block_length, steps)

            steps_this_block = 0
            for step in range(steps):
                steps_this_block += 1
                cur_full_x = x[:current_bsz, :]

                perform_pruning = False
                num_parents_to_select = 0

                if hts_mode and step >= next_allowed_pruning_step and step < tr_end:
                    # Geometric width decay, floored at the final K.
                    target_width = max(final_K_list[-1], math.ceil(initial_N * (decay_factor ** -(step - ts_start))))
                    if current_bsz > target_width:
                        perform_pruning = True
                        num_parents_to_select = hts_survivor_k
                elif not hts_mode and step in schedule_map:
                    target_width, num_parents_to_select = schedule_map[step]
                    if current_bsz > target_width: perform_pruning = True

                if perform_pruning:
                    # The scoring forward pass counts toward NFE as well.
                    stats["nfe"] += current_bsz
                    if num_block == 0: stats["first_block_nfe"] += current_bsz
                    stats["svf_calls"] += current_bsz

                    # Greedy "rough decode" of the whole generation region for scoring.
                    gen_logits = self._chunked_forward(cur_full_x, chunk_size=64, slice_indices=(prompt_length, total_length))
                    rough_ids = torch.argmax(gen_logits, dim=-1)
                    rough_codes_snippet = self.tokenizer.batch_decode(rough_ids, skip_special_tokens=True)
                    candidates = []
                    for i in range(current_bsz):
                        full_code = rough_codes_snippet[i]
                        s = self._safe_scalar(self.verifier.get_reward(prompt_text, full_code, mode=reward_mode, problem_data=problem_data, current_logits=gen_logits[i] if reward_mode != "svf" else None, task_type=task_type))
                        s += self._analyze_structure(full_code, task_type=task_type)
                        # Dedup key: hash of whitespace-stripped head+tail of the text.
                        clean_content = full_code.strip().replace(" ", "").replace("\n", "")
                        candidates.append({'score': s, 'idx': i, 'key': hash(clean_content[:200] + clean_content[-200:])})

                    stats["pruning_history"].append({"step": step, "scores": [c['score'] for c in candidates]})
                    candidates.sort(key=lambda x: x['score'], reverse=True)

                    # Pick top-scoring, content-distinct parents first…
                    selected_indices, seen_keys = [], set()
                    for cand in candidates:
                        if len(selected_indices) >= num_parents_to_select: break
                        if cand['key'] not in seen_keys:
                            selected_indices.append(cand['idx']); seen_keys.add(cand['key'])

                    # …then backfill with duplicates if diversity fell short.
                    if len(selected_indices) < num_parents_to_select:
                        for cand in candidates:
                            if len(selected_indices) >= num_parents_to_select: break
                            if cand['idx'] not in selected_indices: selected_indices.append(cand['idx'])

                    top_indices = torch.tensor(selected_indices, device=self.device)
                    x, conf_scores = self._branch_and_resample(x, conf_scores, top_indices, target_width, mask_id, prompt_length, task_type=task_type)

                    current_bsz = target_width
                    cur_full_x = x[:current_bsz, :]
                    # Enforce a cool-down before the next prune (hts_mode path).
                    next_allowed_pruning_step = step + 1 + pruning_interval

                stats["nfe"] += current_bsz
                if num_block == 0: stats["first_block_nfe"] += current_bsz

                # Forward pass restricted to the current block's logits.
                active_logits = self._chunked_forward(cur_full_x, chunk_size=32, slice_indices=(window_start, window_end))
                # Suppress EOS inside the block so generation keeps filling it.
                active_logits[:, :, eos_id] = -1e10

                x0, x0_p = self._sample_with_temperature(active_logits, temperature, top_k, top_p)

                # Positions in the window still masked (eligible for commit).
                active_mask = x[:current_bsz, window_start:window_end] == mask_id

                num_transfer = schedule[step].item()
                confidence = torch.where(active_mask, x0_p, -torch.inf)
                transfer_idx = torch.zeros_like(x0, dtype=torch.bool)

                for b in range(current_bsz):
                    mask_count = active_mask[b].sum().item()
                    if mask_count > 0:
                        k_transfer = min(num_transfer, mask_count)
                        active_indices = torch.where(active_mask[b])[0]
                        high_conf_mask = (confidence[b] > threshold) & active_mask[b]
                        if high_conf_mask.sum().item() >= k_transfer:
                            # Commit ALL above-threshold positions (may exceed the
                            # per-step quota) — accelerates confident regions.
                            conf_indices = torch.where(high_conf_mask)[0]
                            transfer_idx[b, conf_indices] = True
                        else:
                            # Otherwise commit exactly the top-k most confident.
                            _, topk_indices = torch.topk(confidence[b][active_indices], k=min(k_transfer, len(active_indices)))
                            transfer_idx[b, active_indices[topk_indices]] = True

                if transfer_idx.any():
                    x[:current_bsz, window_start:window_end][transfer_idx] = x0[transfer_idx]
                    conf_scores[:current_bsz, window_start:window_end][transfer_idx] = x0_p[transfer_idx]

                # Early stop: once a completion marker appears, flood the rest
                # of the canvas with EOS so later blocks are no-ops for row b.
                if task_type in ["math", "reasoning"]:
                    for b in range(current_bsz):
                        text_snippet = self.tokenizer.decode(x[b, prompt_length:window_end], skip_special_tokens=True)
                        should_stop = False
                        if task_type == "reasoning" and ("###" in text_snippet): should_stop = True
                        if task_type == "math" and ("\\boxed{" in text_snippet and "}" in text_snippet.split("\\boxed{")[-1]): should_stop = True

                        if should_stop:
                            after_mask = (x[b, window_start:total_length] == mask_id)
                            x[b, window_start:total_length][after_mask] = eos_id

            stats["steps_per_block"].append(steps_this_block)
        x = x[:current_bsz]

        stats["nfe"] = int(round(stats["nfe"]))
        stats["first_block_nfe"] = int(round(stats["first_block_nfe"]))

        # Final scoring pass over the decoded texts (after `until` truncation).
        final_gen_tokens = x[:current_bsz, prompt_length:]
        final_codes = self.tokenizer.batch_decode(final_gen_tokens, skip_special_tokens=True)
        final_candidates = []

        stats["svf_calls"] += len(final_codes)
        for i in range(len(final_codes)):
            txt = final_codes[i]
            if until:
                for term in until:
                    if term in txt: txt = txt.split(term)[0]
            s = self._safe_scalar(self.verifier.get_reward(prompt_text, txt, mode=reward_mode, task_type=task_type))
            s += self._analyze_structure(txt, task_type)
            final_candidates.append({'resp': txt, 'score': s})

        final_candidates.sort(key=lambda x: x['score'], reverse=True)
        stats["final_scores"] = [c['score'] for c in final_candidates]
        stats["all_trajectories"] = [{"rank": i+1, "resp": c['resp'], "score": c['score']} for i, c in enumerate(final_candidates)]

        return [c['resp'] for c in final_candidates], stats
|
Prism/LLaDA/LLaDA_Baseline/dllm_eval/models/huggingface.py
ADDED
|
@@ -0,0 +1,1489 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import logging
|
| 3 |
+
import os
|
| 4 |
+
from datetime import timedelta
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
|
| 7 |
+
|
| 8 |
+
import jinja2
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
import transformers
|
| 12 |
+
from accelerate import (
|
| 13 |
+
Accelerator,
|
| 14 |
+
InitProcessGroupKwargs,
|
| 15 |
+
find_executable_batch_size,
|
| 16 |
+
)
|
| 17 |
+
from accelerate.utils import get_max_memory
|
| 18 |
+
from huggingface_hub import HfApi
|
| 19 |
+
from packaging import version
|
| 20 |
+
from peft import PeftModel
|
| 21 |
+
from peft import __version__ as PEFT_VERSION
|
| 22 |
+
from tqdm import tqdm
|
| 23 |
+
from transformers.models.auto.modeling_auto import (
|
| 24 |
+
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
|
| 25 |
+
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
from dllm_eval import utils
|
| 29 |
+
from dllm_eval.api.instance import Instance
|
| 30 |
+
from dllm_eval.api.model import TemplateLM
|
| 31 |
+
from dllm_eval.api.registry import register_model
|
| 32 |
+
from dllm_eval.models.utils import (
|
| 33 |
+
Collator,
|
| 34 |
+
clear_torch_cache,
|
| 35 |
+
configure_pad_token,
|
| 36 |
+
get_dtype,
|
| 37 |
+
handle_stop_sequences,
|
| 38 |
+
pad_and_concat,
|
| 39 |
+
stop_sequences_criteria,
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
eval_logger = logging.getLogger(__name__)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
@register_model("hf-auto", "hf", "huggingface")
|
| 47 |
+
class HFLM(TemplateLM):
|
| 48 |
+
"""
|
| 49 |
+
An abstracted Huggingface model class. Enables usage with both models of
|
| 50 |
+
`transformers.AutoModelForCausalLM` and `transformers.AutoModelForSeq2SeqLM` classes.
|
| 51 |
+
|
| 52 |
+
Supports data-parallel multi-GPU with HF Accelerate.
|
| 53 |
+
"""
|
| 54 |
+
|
| 55 |
+
AUTO_MODEL_CLASS = None
|
| 56 |
+
_DEFAULT_MAX_LENGTH = 2048
|
| 57 |
+
|
| 58 |
+
def __init__(
    self,
    pretrained: Union[str, transformers.PreTrainedModel],
    backend: Literal["default", "causal", "seq2seq"] = "default",
    # override whether the model should be treated as decoder-only (causal) or encoder-decoder (seq2seq)
    revision: Optional[str] = "main",
    subfolder: str = "",
    tokenizer: Optional[
        Union[
            str,
            transformers.PreTrainedTokenizer,
            transformers.PreTrainedTokenizerFast,
        ]
    ] = None,
    truncation: Optional[bool] = False,
    logits_cache: bool = True,
    max_length: Optional[int] = None,
    device: Optional[str] = "cuda",
    dtype: Optional[Union[str, torch.dtype]] = "auto",
    softmax_dtype: Optional[Union[str, torch.dtype]] = None,
    batch_size: Optional[Union[int, str]] = 1,
    max_batch_size: Optional[int] = 64,
    trust_remote_code: Optional[bool] = False,
    use_fast_tokenizer: Optional[bool] = True,
    add_bos_token: Optional[bool] = False,
    prefix_token_id: Optional[int] = None,
    # arguments used for splitting a model across GPUs naively.
    # only used if `parallelize=True`.
    parallelize: Optional[bool] = False,
    max_memory_per_gpu: Optional[Union[int, str]] = None,
    max_cpu_memory: Optional[Union[int, str]] = None,
    offload_folder: Optional[Union[str, os.PathLike]] = "./offload",
    # PEFT, delta weights and quantization options
    peft: Optional[str] = None,
    delta: Optional[str] = None,
    autogptq: Optional[Union[bool, str]] = False,
    gptqmodel: Optional[bool] = False,
    gguf_file: Optional[str] = None,
    **kwargs,
) -> None:
    """Set up the HF model, tokenizer, device placement and batch config.

    `pretrained` may be a hub/model-path string (model is loaded here) or an
    already-initialized `PreTrainedModel` (then device/parallelism args are
    largely ignored). `batch_size` may be an int or "auto[:schedule]".
    Extra `**kwargs` are forwarded to `from_pretrained` via `_create_model`.
    """
    super().__init__()
    # optionally: take in an already-initialized transformers.PreTrainedModel
    if not isinstance(pretrained, str):
        eval_logger.warning(
            "`pretrained` model kwarg is not of type `str`. Many other model arguments may be ignored. Please do not launch via accelerate or use `parallelize=True` if passing an existing model this way."
        )
        assert not parallelize, (
            "`parallelize=True` is not compatible with passing pre-initialized model to `pretrained`"
        )
        # adopt the caller's model/device/config as-is; no distributed setup
        self._model = pretrained
        self._device = self._model.device
        self._config = self._model.config
        gpus = 0

    else:
        assert isinstance(device, str)
        assert isinstance(pretrained, str)
        assert isinstance(batch_size, (int, str))

        gpus = torch.cuda.device_count()
        # very long timeout: ranks may idle while others finish large eval shards
        accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52))
        accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs])
        if accelerator.num_processes > 1:
            self.accelerator = accelerator

        if "npu" in accelerator.device.type:
            gpus = torch.npu.device_count()

        # using one process with no model parallelism
        if not (parallelize or accelerator.num_processes > 1):
            # use user-passed device
            device_list = set(
                ["cuda", "cpu"]
                + [f"cuda:{i}" for i in range(gpus)]
                + ["mps", "mps:0"]
                + [f"npu:{i}" for i in range(gpus)]
            )
            if device and device in device_list:
                self._device = torch.device(device)
                eval_logger.info(f"Using device '{device}'")
                if device in ("mps", "mps:0") and version.parse(
                    torch.__version__
                ) < version.parse("2.1"):
                    raise RuntimeError(
                        f"mps requires torch >= 2.1. You have {torch.__version__}"
                    )
            else:
                eval_logger.info("Device not specified")
                eval_logger.info(f"Cuda Available? {torch.cuda.is_available()}")
                self._device = (
                    torch.device("cuda")
                    if torch.cuda.is_available()
                    else torch.device("cpu")
                )
        else:  # Parallelism managed by accelerate
            if device != "cuda":
                eval_logger.info(
                    f"Using `accelerate launch` or `parallelize=True`, device '{device}' will be overridden when placing model."
                )
            # TODO: include in warning that `load_in_8bit` etc. affect this too
            self._device = (
                self.accelerator.device
                if hasattr(self, "accelerator")
                else torch.device(device)
            )

        revision = str(revision)  # cast to string if not already one

        self._get_config(
            pretrained,
            revision=revision,
            trust_remote_code=trust_remote_code,
            gguf_file=gguf_file,
            subfolder=subfolder,
        )

    # determine which of 'causal' and 'seq2seq' backends to use for HF models
    self._get_backend(
        config=self.config, backend=backend, trust_remote_code=trust_remote_code
    )

    # load tokenizer so we know tokenizer vocabulary size before loading model and PEFT
    self._create_tokenizer(
        pretrained,
        tokenizer,
        revision=revision,
        subfolder=subfolder,
        trust_remote_code=trust_remote_code,
        use_fast_tokenizer=use_fast_tokenizer,
        gguf_file=gguf_file,
        add_bos_token=add_bos_token,
    )

    # if we passed `pretrained` as a string, initialize our model now
    if isinstance(pretrained, str):
        self._create_model(
            pretrained=pretrained,
            revision=revision,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
            parallelize=parallelize,
            gpus=gpus,
            max_memory_per_gpu=max_memory_per_gpu,
            max_cpu_memory=max_cpu_memory,
            offload_folder=offload_folder,
            peft=peft,
            delta=delta,
            autogptq=autogptq,
            gptqmodel=gptqmodel,
            gguf_file=gguf_file,
            quantization_config=getattr(self.config, "quantization_config", None),
            subfolder=subfolder,
            **kwargs,
        )

    # access self._model through self.model property outside this method
    if isinstance(self.model, torch.nn.Module):
        self.model.eval()
        self.model.tie_weights()

    self.truncation = truncation
    self.logits_cache = logits_cache
    self.vocab_size = self.tokenizer.vocab_size
    # select (or create) a pad token to use
    self.tokenizer = configure_pad_token(self.tokenizer, model_config=self.config)

    self.add_bos_token = add_bos_token
    if "gemma" in getattr(self.config, "model_type", ""):
        self.add_bos_token = True
        eval_logger.info(
            f"Model type is '{self.config.model_type}', part of the Gemma family--a BOS token will be used as Gemma underperforms without it."
        )

    self._max_length = max_length
    self.pretrained = pretrained
    self.delta = delta
    self.peft = peft
    self.revision = revision
    self.batch_schedule = 1
    self.batch_sizes = {}
    self.max_batch_size = max_batch_size
    self.softmax_dtype = (
        get_dtype(softmax_dtype) if softmax_dtype is not None else None
    )

    # "auto" or "auto:N" -> dynamic batch sizing; N is the re-detection schedule
    if str(batch_size).startswith("auto"):
        batch_size = batch_size.split(":")
        self.batch_size_per_gpu = batch_size[0]
        self.batch_schedule = float(batch_size[1]) if len(batch_size) > 1 else 1
    else:
        self.batch_size_per_gpu = int(batch_size)

    if isinstance(pretrained, str):
        if gpus >= 1 or str(self.device) == "mps":
            # TODO: can remove this whole snippet except in the mps case, perhaps?
            if not (parallelize or autogptq or hasattr(self, "accelerator")):
                # place model onto device requested manually,
                # if not using HF Accelerate or device_map
                # or any other option that preloads model onto device
                try:
                    self.model.to(self.device)
                except ValueError:
                    eval_logger.debug(
                        "Failed to place model onto specified device. This may be because the model is quantized via `bitsandbytes` or `device_map` is provided. If the desired GPU is being used, this message is safe to ignore."
                    )
        # multigpu data-parallel support when launched with accelerate
        if gpus > 1:
            if accelerator.num_processes > 1:
                if parallelize:
                    eval_logger.warning(
                        "You are both using a HF Accelerate `device_map` (`--model_args parallelize=True`) and launching via `accelerate launch`. This will attempt to do model and data parallelism depending on the resources available."
                    )
                elif gpus > accelerator.num_processes:
                    eval_logger.warning(
                        "WARNING: The number of total system GPUs does not match the number of spawned processes. "
                        "If you would like to use data parallelism, please launch the script "
                        "with 'accelerate launch *script*'. "
                        f"Current run will proceed with {accelerator.num_processes} devices."
                    )
                if self.accelerator.is_local_main_process:
                    eval_logger.info(
                        f"Using {gpus} devices with data parallelism"
                    )

                self._device = torch.device(f"{accelerator.device}")
                self.accelerator = accelerator

                self._rank = self.accelerator.local_process_index
                self._world_size = self.accelerator.num_processes
            else:
                # if we aren't launching via accelerate, ditch
                self._rank = 0
                self._world_size = 1
        # NOTE(review): when gpus <= 1, _rank/_world_size are presumably
        # inherited from the TemplateLM base class — confirm upstream.
    else:
        # if a PreTrainedModel was passed into HFLM, we forgo distributed setup.
        eval_logger.warning(
            "Passed an already-initialized model through `pretrained`, assuming single-process call to evaluate() or custom distributed integration"
        )
        self._rank = 0
        self._world_size = 1

    self.custom_prefix_token_id = prefix_token_id
    if prefix_token_id is not None:
        eval_logger.info(
            f"Loglikelihood prefix token id used in evaluation: {self.prefix_token_id}"
        )
|
| 304 |
+
|
| 305 |
+
def _get_accelerate_args(
    self,
    parallelize: Optional[bool] = None,
    device_map: Optional[str] = "auto",
    max_memory_per_gpu: Optional[Union[int, str]] = None,
    max_cpu_memory: Optional[Union[int, str]] = None,
    offload_folder: Optional[str] = "./offload",
    gpus: Optional[int] = None,
) -> dict:
    """Returns the kwargs needed to apply `accelerate` in `AutoModel.from_pretrained`.

    Args:
        parallelize: force model parallelism on/off; None means auto-decide
            based on whether there are spare GPUs beyond the local processes.
        device_map: accelerate device map ("auto" lets accelerate shard).
        max_memory_per_gpu: memory cap applied uniformly to each of `gpus` GPUs.
        max_cpu_memory: memory cap for CPU-offloaded weights.
        offload_folder: disk folder for weights that fit on neither GPU nor CPU.
        gpus: GPU count used to expand `max_memory_per_gpu` into a per-device map.

    Returns:
        dict with `device_map`, `max_memory` and (when parallelizing)
        `offload_folder` entries for `from_pretrained`.
    """
    num_local_processes = int(os.environ.get("LOCAL_WORLD_SIZE", 1))
    num_machines = int(os.environ.get("WORLD_SIZE", 0)) // num_local_processes
    if (
        num_machines == 0
        and hasattr(self, "accelerator")
        and self.accelerator is not None
    ):
        eval_logger.info(
            "We are not in a distributed setting for accelerate. Setting model_parallel to False."
        )
        parallelize = False

    if parallelize is None:
        # If parallelism is unset by the user, we automatically assign model parallelism
        # if enough extra GPUs are available
        max_memory_all_gpus = get_max_memory()
        # We just want gpu, not cpu, max memory
        if "cpu" in max_memory_all_gpus:
            del max_memory_all_gpus["cpu"]
        parallelize = bool(num_local_processes < len(max_memory_all_gpus))
        eval_logger.info(
            f"Setting model parallel to {parallelize} since "
            f"the number of local processes is {num_local_processes} "
            f"and the number of GPUs is {len(max_memory_all_gpus)}"
        )

    args = {}
    if parallelize:  # Model parallelism will be used
        if max_memory_per_gpu is not None:  # Using the provided memory requirements
            max_memory_per_gpu_map = {
                device_idx: max_memory_per_gpu for device_idx in range(gpus)
            }
        else:  # Estimating the possible memory requirements
            max_memory_all_gpus = get_max_memory()
            if "cpu" in max_memory_all_gpus:
                del max_memory_all_gpus["cpu"]
            if not hasattr(self, "accelerator"):
                max_memory_per_gpu_map = {
                    k: v for k, v in max_memory_all_gpus.items()
                }
            else:
                # use only 1 / num_processes of the GPUs if we are running under accelerate launch
                max_memory_per_gpu_map = {
                    k: v
                    for k, v in max_memory_all_gpus.items()
                    if k % num_local_processes
                    == (self.accelerator.process_index % num_local_processes)
                }
        # BUGFIX: previously the CPU cap was written into a dead local dict
        # (`max_memory = {}`) that never reached `args`, so `max_cpu_memory`
        # was silently ignored. Accelerate's `max_memory` accepts a "cpu" key
        # alongside integer GPU indices, so fold it into the real map.
        if max_cpu_memory is not None:
            max_memory_per_gpu_map["cpu"] = max_cpu_memory
        args["max_memory"] = max_memory_per_gpu_map
        args["device_map"] = "auto" if device_map is None else device_map
        eval_logger.info(
            f"Model parallel was set to True, setting max memory per GPU to {max_memory_per_gpu_map} and device map to {args.get('device_map')}"
        )

        args["offload_folder"] = offload_folder
    elif (
        device_map is None
    ):  # No model parallelism, we use the default provided device for our model
        if hasattr(self, "accelerator"):
            device_map = {"": f"{self.accelerator.device}"}
        else:
            device_map = {"": str(self.device)}
        args["max_memory"] = None
        args["device_map"] = device_map
        eval_logger.info(
            f"Model parallel was set to False, max memory was not set, and device map was set to {device_map}"
        )
    else:
        args["max_memory"] = None
        args["device_map"] = None
        eval_logger.info("Model parallel was set to False.")

    return args
|
| 392 |
+
|
| 393 |
+
@property
def config(self):
    """The `transformers.PretrainedConfig` associated with this model."""
    return self._config
|
| 397 |
+
|
| 398 |
+
@property
def model(self):
    """The underlying model, unwrapped from any Accelerate wrapper."""
    if hasattr(self, "accelerator"):
        return self.accelerator.unwrap_model(self._model)
    return self._model
|
| 405 |
+
|
| 406 |
+
@property
def eot_token_id(self):
    """End-of-text token id (the tokenizer's EOS id; "end of text" is the
    more accurate notion here than "end of sentence")."""
    return self.tokenizer.eos_token_id
|
| 410 |
+
|
| 411 |
+
@property
def prefix_token_id(self):
    """Token id used as the loglikelihood prefix.

    Preference order: user-supplied override, then BOS, then EOS.
    """
    override = self.custom_prefix_token_id
    if override is not None:
        return override
    bos = self.tokenizer.bos_token_id
    return bos if bos is not None else self.tokenizer.eos_token_id
|
| 419 |
+
|
| 420 |
+
@property
def max_length(self):
    """Maximum sequence length, resolved in priority order: explicit
    `max_length` ctor argument, model config attributes, tokenizer's
    `model_max_length`, then the class default."""
    if self._max_length:  # if max length manually set, return it
        return self._max_length
    for candidate in ("n_positions", "max_position_embeddings", "n_ctx"):
        if hasattr(self.model.config, candidate):
            return getattr(self.model.config, candidate)
    if hasattr(self.tokenizer, "model_max_length"):
        # HF tokenizers report this sentinel (int(1e30)) when no real limit is set
        if self.tokenizer.model_max_length == 1000000000000000019884624838656:
            return self._DEFAULT_MAX_LENGTH
        return self.tokenizer.model_max_length
    return self._DEFAULT_MAX_LENGTH
|
| 433 |
+
|
| 434 |
+
@property
def max_gen_toks(self) -> int:
    """Default cap on the number of tokens generated per request."""
    return 256
|
| 437 |
+
|
| 438 |
+
@property
def batch_size(self):
    """Per-GPU batch size (int, or the string "auto" before detection)."""
    return self.batch_size_per_gpu
|
| 441 |
+
|
| 442 |
+
@property
def device(self):
    """The torch device the model is placed on."""
    return self._device
|
| 445 |
+
|
| 446 |
+
@property
def rank(self):
    """This process's rank within the (possibly single-process) run."""
    return self._rank
|
| 449 |
+
|
| 450 |
+
@property
def world_size(self):
    """Total number of processes participating in the run."""
    return self._world_size
|
| 453 |
+
|
| 454 |
+
@property
def tokenizer_name(self) -> str:
    """Tokenizer identifier with '/' replaced so it is filesystem-safe
    (used e.g. as a cache-key component)."""
    raw_name = self.tokenizer.name_or_path
    return raw_name.replace("/", "__")
|
| 457 |
+
|
| 458 |
+
def _get_backend(
    self,
    config: Union[transformers.PretrainedConfig, transformers.AutoConfig],
    backend: Literal["default", "causal", "seq2seq"] = "default",
    trust_remote_code: Optional[bool] = False,
) -> None:
    """
    Helper method during initialization.
    Determines the backend ("causal" (decoder-only) or "seq2seq" (encoder-decoder)) model type to be used.
    sets `self.AUTO_MODEL_CLASS` appropriately if not already set.

    **If not calling HFLM.__init__() or HFLM._get_backend() within a subclass of HFLM,
    user must set `self.backend` to be either "causal" or "seq2seq" manually!**
    """

    assert backend in ["default", "causal", "seq2seq"]

    if backend != "default":
        # if we've settled on non-default backend, use that manually
        if backend == "causal":
            self.backend = backend
        elif backend == "seq2seq":
            self.backend = backend
        eval_logger.info(
            f"Overrode HF model backend type, and using type '{self.backend}'"
        )
    else:
        # determine and use the default HF backend for this model, based on its config + metadata.
        if (
            getattr(config, "model_type")
            in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
        ):
            # first check if model type is listed under seq2seq models, since some
            # models like MBart are listed in both seq2seq and causal mistakenly in HF transformers.
            # these special cases should be treated as seq2seq models.
            self.backend = "seq2seq"
            eval_logger.debug(f"Using model type '{self.backend}'")
        elif (
            getattr(self.config, "model_type") in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
        ):
            self.backend = "causal"
            eval_logger.debug(f"Using model type '{self.backend}'")
        else:
            if not trust_remote_code:
                eval_logger.warning(
                    "HF model type is neither marked as CausalLM or Seq2SeqLM. \
This is expected if your model requires `trust_remote_code=True` but may be an error otherwise."
                    "Setting backend to causal"
                )
            # if model type is neither in HF transformers causal or seq2seq model registries
            # then we default to assuming AutoModelForCausalLM
            self.backend = "causal"
            eval_logger.info(
                f"Model type cannot be determined. Using default model type '{self.backend}'"
            )

    # only bind the Auto* class once, so subclasses may preset it
    if self.AUTO_MODEL_CLASS is None:
        if self.backend == "causal":
            self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM
        elif self.backend == "seq2seq":
            self.AUTO_MODEL_CLASS = transformers.AutoModelForSeq2SeqLM
|
| 519 |
+
|
| 520 |
+
def _get_config(
    self,
    pretrained: str,
    revision: str = "main",
    trust_remote_code: bool = False,
    gguf_file: Optional[str] = None,
    subfolder: str = "",
) -> None:
    """Load the HF `AutoConfig` for `pretrained` and store it on `self._config`."""
    load_kwargs = dict(
        revision=revision,
        trust_remote_code=trust_remote_code,
        gguf_file=gguf_file,
        subfolder=subfolder,
    )
    self._config = transformers.AutoConfig.from_pretrained(
        pretrained, **load_kwargs
    )
|
| 536 |
+
|
| 537 |
+
def _create_model(
    self,
    pretrained: str,
    revision: Optional[str] = "main",
    dtype: Optional[Union[str, torch.dtype]] = "auto",
    trust_remote_code: Optional[bool] = False,
    # arguments used for splitting a model across GPUs naively.
    # only used if `parallelize=True`.
    # (accelerate naive PP (device_map) options)
    parallelize: Optional[bool] = False,
    gpus: Optional[int] = None,
    max_memory_per_gpu: Optional[Union[int, str]] = None,
    max_cpu_memory: Optional[Union[int, str]] = None,
    offload_folder: Optional[str] = "./offload",
    # PEFT, delta weights and quantization options
    peft: Optional[str] = None,
    delta: Optional[str] = None,
    autogptq: Optional[Union[bool, str]] = False,
    gptqmodel: Optional[bool] = False,
    gguf_file: Optional[str] = None,
    quantization_config: Optional[Dict[str, Any]] = None,
    subfolder: str = "",
    **kwargs,
) -> None:
    """
    Initializes an HF or HF-compatible PreTrainedModel from scratch
    inside HFLM, using the kwargs passed into self.__init__().

    Also handles functionality such as AutoGPTQ usage and PEFT wrapping.

    For future similar extensions to AutoGPTQ that are not core to HF's ecosystem,
    (such as PyTorch models that are nearly, but not quite, fully mirroring
    HF's public interface relied on in this HFLM class)
    please consider subclassing HFLM and overriding this and other methods as needed.
    """

    model_kwargs = kwargs if kwargs else {}

    # merge accelerate placement kwargs (device_map / max_memory / offload)
    model_kwargs.update(
        self._get_accelerate_args(
            parallelize=parallelize,
            device_map=kwargs.get("device_map", None),
            max_memory_per_gpu=max_memory_per_gpu,
            max_cpu_memory=max_cpu_memory,
            offload_folder=offload_folder,
            gpus=gpus,
        )
    )

    if not autogptq and not gptqmodel:
        # plain HF path (optionally bitsandbytes 4-bit)
        if model_kwargs.get("load_in_4bit", None):
            assert transformers.__version__ >= "4.30.0", (
                "load_in_4bit requires transformers >= 4.30.0"
            )
        if transformers.__version__ >= "4.30.0":
            if model_kwargs.get("load_in_4bit", None):
                if model_kwargs.get("bnb_4bit_compute_dtype", None):
                    # normalize string dtype names to torch dtypes
                    model_kwargs["bnb_4bit_compute_dtype"] = get_dtype(
                        model_kwargs["bnb_4bit_compute_dtype"]
                    )

        self._model = self.AUTO_MODEL_CLASS.from_pretrained(
            pretrained,
            revision=revision,
            torch_dtype=get_dtype(dtype),
            trust_remote_code=trust_remote_code,
            gguf_file=gguf_file,
            quantization_config=quantization_config,
            subfolder=subfolder,
            **model_kwargs,
        )
    else:
        if autogptq and gptqmodel:
            raise ValueError(
                "Cannot use both 'autogptq' and 'gptqmodel' options at the same time."
            )

        if autogptq:
            try:
                from auto_gptq import AutoGPTQForCausalLM
            except ModuleNotFoundError as exception:
                raise type(exception)(
                    "Tried to load auto_gptq, but auto-gptq is not installed ",
                    "please install auto-gptq via pip install lm-eval[gptq] or pip install -e .[gptq]",
                )

            # `autogptq` may be True (use defaults) or a checkpoint filename
            self._model = AutoGPTQForCausalLM.from_quantized(
                pretrained,
                trust_remote_code=trust_remote_code,
                model_basename=None if autogptq is True else Path(autogptq).stem,
                use_safetensors=True
                if autogptq is True
                else autogptq.endswith(".safetensors"),
                **model_kwargs,
            )

        if gptqmodel:
            try:
                from gptqmodel import GPTQModel
            except ModuleNotFoundError as exception:
                raise type(exception)(
                    "Tried to load gptqmodel, but gptqmodel is not installed ",
                    "please install gptqmodel via `pip install gptqmodel --no-build-isolation` or `pip install lm-eval[gptqmodel] --no-build-isolation`",
                )

            self._model = GPTQModel.from_quantized(
                pretrained, trust_remote_code=trust_remote_code, **model_kwargs
            )

    if peft and delta:
        raise ValueError(
            "Cannot use both 'peft' and 'delta' options at the same time."
        )

    if peft:
        if model_kwargs.get("load_in_4bit", None):
            if version.parse(PEFT_VERSION) < version.parse("0.4.0"):
                raise AssertionError("load_in_4bit requires peft >= 0.4.0")
        if self._model.config.vocab_size != len(self.tokenizer):
            # resize model for LoRAs with added tokens
            eval_logger.info(
                f"Model config indicates vocab_size='{self._model.config.vocab_size}', but found tokenizer with vocab size '{len(self.tokenizer)}'. Resizing model embedding layer..."
            )
            self._model.resize_token_embeddings(len(self.tokenizer))
        self._model = PeftModel.from_pretrained(
            self._model, peft, revision=revision
        )
    elif delta:
        if autogptq:
            eval_logger.warning(
                "Delta weights might trigger unexpected behavior when used with AutoGPTQ."
            )
        # load the delta checkpoint and add its weights onto the base model in place
        _model_delta = self.AUTO_MODEL_CLASS.from_pretrained(
            delta,
            revision=revision,
            torch_dtype=get_dtype(dtype),
            trust_remote_code=trust_remote_code,
            **model_kwargs,
        )
        for name, param in self._model.state_dict().items():
            try:
                param.data += _model_delta.state_dict()[name]
            except KeyError:
                raise KeyError(f"Delta model is missing weights for layer: {name}")
            except Exception as e:
                raise RuntimeError(
                    f"Failed to add delta weights to layer {name}. Error: {e}"
                )

        del _model_delta

    return None
|
| 689 |
+
|
| 690 |
+
def _create_tokenizer(
    self,
    pretrained: Union[str, transformers.PreTrainedModel],
    tokenizer: Optional[
        Union[
            str,
            transformers.PreTrainedTokenizer,
            transformers.PreTrainedTokenizerFast,
        ]
    ],
    revision: Optional[str] = "main",
    trust_remote_code: Optional[bool] = False,
    use_fast_tokenizer: Optional[bool] = True,
    gguf_file: Optional[str] = None,
    add_bos_token: Optional[bool] = False,
    subfolder: Optional[str] = "",
) -> None:
    """
    Helper method during initialization.

    Create a tokenizer object corresponding to the correct
    tokenizer for value of `pretrained`, or use the pre-initialized tokenizer passed.
    """
    kwargs = {
        "revision": revision,
        "trust_remote_code": trust_remote_code,
    }

    # gguf format embeds tokenizer and is not compatible with hf tokenizer `use_fast` param
    if gguf_file is not None:
        kwargs["gguf_file"] = gguf_file
    else:
        kwargs["use_fast"] = use_fast_tokenizer

    if add_bos_token:
        kwargs["add_bos_token"] = True

    if subfolder:
        kwargs["subfolder"] = subfolder

    if tokenizer:
        # caller supplied a tokenizer: a hub name/path, or an instance used as-is
        if isinstance(tokenizer, str):
            self.tokenizer = transformers.AutoTokenizer.from_pretrained(
                tokenizer, **kwargs
            )
        else:
            assert isinstance(
                tokenizer, transformers.PreTrainedTokenizer
            ) or isinstance(tokenizer, transformers.PreTrainedTokenizerFast)
            self.tokenizer = tokenizer
    else:
        # Get tokenizer based on 'pretrained'
        if isinstance(pretrained, str):
            model_name = pretrained
        else:
            # get the HF hub name via accessor on model
            model_name = self.model.name_or_path
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
            model_name, **kwargs
        )
    return None
|
| 751 |
+
|
| 752 |
+
def _detect_batch_size(self, requests=None, pos: int = 0):
    """Find the largest batch size that runs without OOM.

    Sizes a probe batch from the request at `requests[pos]` (or from
    `self.max_length` if no requests are given), runs a few forward passes
    while `find_executable_batch_size` halves the batch on OOM, and — in
    multi-process runs — takes the minimum batch size across all ranks.
    """
    if requests:
        _, context_enc, continuation_enc = requests[pos]
        # clamp probe lengths to the model's context window (+1 then drop
        # the final token, mirroring how scoring windows are built)
        max_length = len(
            (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1]
        )
        max_context_enc = len(context_enc[-(self.max_length + 1) :])
        max_cont_enc = len(continuation_enc[-(self.max_length + 1) :])
    else:
        max_length = self.max_length
        max_context_enc = max_length
        max_cont_enc = max_length

    # if OOM, then halves batch_size and tries again
    @find_executable_batch_size(starting_batch_size=self.max_batch_size)
    def forward_batch(batch_size):
        if self.backend == "seq2seq":
            length = max(max_context_enc, max_cont_enc)
            batched_conts = torch.ones(
                (batch_size, length), device=self.device
            ).long()
            test_batch = torch.ones((batch_size, length), device=self.device).long()
            call_kwargs = {
                "attn_mask": test_batch,
                "labels": batched_conts,
            }
        else:
            call_kwargs = {}
            test_batch = torch.ones(
                (batch_size, max_length), device=self.device
            ).long()
        # several repetitions to make transient allocation spikes show up
        for _ in range(5):
            out = F.log_softmax(  # noqa: F841
                self._model_call(test_batch, **call_kwargs),
                dim=-1,
                dtype=self.softmax_dtype,
            )

        return batch_size

    try:
        batch_size = forward_batch()
    except RuntimeError as e:
        # accelerate raises this when even batch_size=1 fails its probe
        if "No executable batch size found" in str(e):
            batch_size = 1
        else:
            raise

    if self.world_size > 1:
        # if multi-GPU, always take minimum over all selected batch sizes
        max_rnk_bs = torch.tensor([batch_size], device=self.device)
        gathered = (
            self.accelerator.gather(max_rnk_bs).cpu().detach().numpy().tolist()
        )
        batch_size = min(gathered)
        clear_torch_cache()
        return batch_size

    clear_torch_cache()
    return batch_size
|
| 812 |
+
|
| 813 |
+
def tok_encode(
    self, string: str, left_truncate_len=None, add_special_tokens=None
) -> List[int]:
    """Tokenize a single string into a list of token ids.

    :param string: text to encode
    :param left_truncate_len: if set, keep only the last `left_truncate_len` ids
    :param add_special_tokens: explicit override for the tokenizer's
        `add_special_tokens` flag; when None, causal backends default to
        `self.add_bos_token` and all other backends use the tokenizer default.
    :return: list of token ids
    """
    # Decide what (if anything) to pass for `add_special_tokens`.
    if add_special_tokens is not None:
        # caller explicitly chose the behavior
        special_tokens_kwargs = {"add_special_tokens": add_special_tokens}
    elif self.backend == "causal":
        # causal LMs: off unless a BOS token was requested
        special_tokens_kwargs = {"add_special_tokens": self.add_bos_token}
    else:
        # other backends: defer to the tokenizer's predefined default
        special_tokens_kwargs = {}

    encoding = self.tokenizer.encode(string, **special_tokens_kwargs)

    # left-truncate to at most `left_truncate_len` tokens
    if left_truncate_len:
        encoding = encoding[-left_truncate_len:]

    return encoding
|
| 838 |
+
|
| 839 |
+
def tok_batch_encode(
    self,
    strings: List[str],
    padding_side: str = "left",
    left_truncate_len: int = None,
    truncation: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Encode a batch of strings into padded tensors.

    Unlike `tok_encode`, this returns `(input_ids, attention_mask)` tensors,
    padded to the longest sequence in the batch.

    :param strings: texts to encode
    :param padding_side: which side the tokenizer pads on ("left" is the
        usual choice for causal generation)
    :param left_truncate_len: if set, keep only the last `left_truncate_len`
        token columns of the batch
    :param truncation: forwarded to the tokenizer
    :return: (input_ids, attention_mask), each of shape [batch, seq]
    """
    old_padding_side = self.tokenizer.padding_side
    self.tokenizer.padding_side = padding_side
    try:
        add_special_tokens = {}
        if self.backend == "causal":
            add_special_tokens = {"add_special_tokens": False or self.add_bos_token}

        encoding = self.tokenizer(
            strings,
            truncation=truncation,
            padding="longest",
            return_tensors="pt",
            **add_special_tokens,
        )
        if left_truncate_len:
            original_lengths = encoding["input_ids"].size(1)
            if original_lengths > left_truncate_len:
                # fix: `Logger.warn` is a deprecated alias of `warning`
                eval_logger.warning(
                    f"Left truncation applied. Original sequence length was {original_lengths}, "
                    f"truncating to last {left_truncate_len} tokens. Some content will be lost.",
                )
            encoding["input_ids"] = encoding["input_ids"][:, -left_truncate_len:]
            encoding["attention_mask"] = encoding["attention_mask"][
                :, -left_truncate_len:
            ]
    finally:
        # fix: always restore the tokenizer's padding side, even if
        # the tokenizer call raises, so a failure here cannot corrupt
        # later encodes.
        self.tokenizer.padding_side = old_padding_side

    return encoding["input_ids"], encoding["attention_mask"]
|
| 875 |
+
|
| 876 |
+
def tok_decode(self, tokens, skip_special_tokens=True):
    """Decode token ids back into text via the underlying tokenizer."""
    text = self.tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)
    return text
|
| 878 |
+
|
| 879 |
+
def _model_call(self, inps, attn_mask=None, labels=None):
    """
    :param inps: torch.Tensor
        A torch tensor of shape [batch, (sequence_ctx + sequence_cont)] or of shape
        [batch, sequence_ctx]. the size of sequence may vary from call to call
    :param attn_mask: torch.Tensor, optional
        A torch tensor of shape [batch, (sequence_ctx + sequence_cont)]. Only passed
        (and must be passed) if self.AUTO_MODEL_CLASS is transformers.AutoModelForSeq2SeqLM
    :param labels: torch.Tensor, optional
        A torch tensor of shape [batch, (sequence_ctx + sequence_cont)]. Only passed
        (and must be passed) if self.AUTO_MODEL_CLASS is transformers.AutoModelForSeq2SeqLM
    :return
        A torch tensor of shape [batch, sequence, vocab] with the
        logits returned from the model's decoder
    """
    # A seq2seq-style call is signalled by the presence of either kwarg;
    # both must then be supplied together.
    seq2seq_call = attn_mask is not None or labels is not None
    with torch.no_grad():
        if seq2seq_call:
            assert attn_mask is not None and labels is not None
            assert self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM
            output = self.model(
                input_ids=inps, attention_mask=attn_mask, labels=labels
            )
        else:
            assert self.AUTO_MODEL_CLASS in (
                transformers.AutoModelForCausalLM,
                transformers.AutoModelForVision2Seq,
            )
            output = self.model(inps)
        return output.logits
|
| 907 |
+
|
| 908 |
+
def _model_generate(self, context, max_length, stop, **generation_kwargs):
    """Generate continuations for `context` with HF `generate`, stopping on
    any of the `stop` sequences.

    temperature defaults to 0.0; a zero temperature with an unset
    `do_sample` is normalized to greedy decoding (do_sample=False), and the
    now-redundant temperature is dropped so HF does not emit a warning.
    """
    generation_kwargs.setdefault("temperature", 0.0)
    do_sample = generation_kwargs.get("do_sample", None)

    # temperature must be strictly positive for sampling; 0.0 means greedy
    if do_sample is None and generation_kwargs.get("temperature") == 0.0:
        do_sample = False
        generation_kwargs["do_sample"] = False

    if do_sample is False and generation_kwargs.get("temperature") == 0.0:
        del generation_kwargs["temperature"]

    # build stopping criteria from the stop strings and context geometry
    batch, ctx_len = context.shape[0], context.shape[1]
    stopping_criteria = stop_sequences_criteria(self.tokenizer, stop, ctx_len, batch)

    return self.model.generate(
        input_ids=context,
        max_length=max_length,
        stopping_criteria=stopping_criteria,
        pad_token_id=self.tokenizer.pad_token_id,
        use_cache=True,
        **generation_kwargs,
    )
|
| 934 |
+
|
| 935 |
+
def _select_cont_toks(
    self, logits: torch.Tensor, contlen: int = None, inplen: int = None
) -> torch.Tensor:
    """Slice out only the continuation-token logits from a per-sequence
    logits tensor, discarding right-padding (and, for causal LMs, the
    context positions)."""
    if self.backend == "causal":
        assert contlen and inplen, (
            "Must pass input len and cont. len to select scored logits for causal LM"
        )
        # keep the last `contlen` positions of the unpadded input window:
        # drops right-padding and the context tokens in one slice
        return logits[inplen - contlen : inplen]
    if self.backend == "seq2seq":
        assert contlen and not inplen, (
            "Selecting scored logits for Seq2SeqLM requires only cont. len"
        )
        # decoder-side logits only; nothing but right-padding to drop
        return logits[:contlen]
    # other backends: return the logits unchanged
    return logits
|
| 954 |
+
|
| 955 |
+
def loglikelihood_rolling(
    self, requests: List[Instance], disable_tqdm: bool = False
) -> List[float]:
    """Compute the summed log-likelihood of each request's full string by
    scoring it in disjoint rolling windows of at most `self.max_length`
    tokens, batching all windows from all requests together.

    :param requests: one Instance per string; `req.args == (string,)`
    :param disable_tqdm: suppress the progress bar
    :return: one float (sum of per-window log-likelihoods) per request
    """
    adaptive_batch_size = None
    if self.batch_size == "auto":
        # using rolling window with maximum context
        print("Passed argument batch_size = auto. Detecting largest batch size")
        batch_size = self._detect_batch_size()
        print(f"Determined Largest batch size: {batch_size}")
        adaptive_batch_size = batch_size

    # First, collect all windows from all requests
    all_windows = []  # List of (request_idx, window) tuples
    request_window_counts = []  # Track number of windows per request

    for req_idx, (string,) in enumerate(
        tqdm(
            [req.args for req in requests],
            disable=(disable_tqdm or (self.rank != 0)),
        )
    ):
        # split the token stream into non-overlapping (context, continuation)
        # windows so every token is scored exactly once
        rolling_token_windows: List[Tuple[List[int], List[int]]] = list(
            map(
                utils.make_disjoint_window,
                utils.get_rolling_token_windows(
                    token_list=self.tok_encode(string),
                    prefix_token=self.prefix_token_id,
                    max_seq_len=self.max_length,
                    context_len=1,
                ),
            )
        )

        # TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder, in seq2seq case
        # prepend None as the cache key: rolling windows are not cached per-window
        windows = [(None,) + x for x in rolling_token_windows]

        # Store windows with their request index
        all_windows.extend((req_idx, window) for window in windows)
        request_window_counts.append(len(windows))

    # Handle distributed case padding
    pad_amnt = 0
    if self.world_size > 1:
        # pad this rank's window list up to the largest rank's count so all
        # ranks execute the same number of forward passes
        mytensor = torch.tensor(len(all_windows), device=self.device)
        gathered = self.accelerator.gather(mytensor).cpu().detach().numpy().tolist()
        pad_amnt = max(gathered) - gathered[self.rank]
        if pad_amnt > 0:
            all_windows += pad_amnt * [all_windows[0]]

    all_nlls = []
    batch_size = adaptive_batch_size or self.batch_size
    for i in range(0, len(all_windows), batch_size):
        batch = all_windows[i : i + batch_size]
        # Extract just the windows for processing, keeping track of request indices
        batch_indices, batch_windows = zip(*batch)

        batch_nlls = self._loglikelihood_tokens(
            requests=batch_windows,
            disable_tqdm=False,
            override_bs=len(batch_windows),
        )
        # Store results with their request indices
        all_nlls.extend(zip(batch_indices, batch_nlls))

    # Remove padding if necessary
    if (self.world_size > 1) and (pad_amnt > 0):
        all_nlls = all_nlls[:-pad_amnt]

    # Reconstruct per-request loglikelihoods
    loglikelihoods = []
    current_idx = 0
    for window_count in request_window_counts:
        # Get all nlls for this request
        request_nlls = all_nlls[current_idx : current_idx + window_count]
        # Sum up the nlls for this request (discarding is_greedy)
        request_total = sum(nll[0] for _, nll in request_nlls)
        loglikelihoods.append(request_total)
        current_idx += window_count

        # NOTE(review): the render lost indentation here; the cache update is
        # reconstructed inside the per-request loop (each request is cached
        # with its own total), which the `len(loglikelihoods) - 1` index
        # presupposes — confirm against the original file.
        string = requests[len(loglikelihoods) - 1].args[0]
        self.cache_hook.add_partial(
            "loglikelihood_rolling", (string,), request_total
        )

    return loglikelihoods
|
| 1040 |
+
|
| 1041 |
+
def _batch_scheduler(self, pos, n_reordered_requests):
    """Return (and memoize) the batch size to use at position `pos`,
    re-probing at most `self.batch_schedule` times over the request list."""
    bucket = pos // int(len(n_reordered_requests) / self.batch_schedule)
    if bucket in self.batch_sizes:
        return self.batch_sizes[bucket]

    previous_hit_ceiling = (
        len(self.batch_sizes) > 1
        and self.batch_sizes[bucket - 1] == self.max_batch_size
    )
    if previous_hit_ceiling:
        # if previous batch size is already maximal, skip recomputation
        self.batch_sizes[bucket] = self.max_batch_size
        return self.batch_sizes[bucket]

    print(
        f"Passed argument batch_size = auto:{self.batch_schedule}. Detecting largest batch size"
    )
    self.batch_sizes[bucket] = self._detect_batch_size(n_reordered_requests, pos)
    print(f"Determined largest batch size: {self.batch_sizes[bucket]}")
    return self.batch_sizes[bucket]
|
| 1057 |
+
|
| 1058 |
+
def _loglikelihood_tokens(
    self,
    requests: List[Tuple[Tuple[str, str], List[int], List[int]]],
    disable_tqdm: bool = False,
    override_bs: int = None,
) -> List[Tuple[float, bool]]:
    """Score pre-tokenized (context, continuation) pairs.

    :param requests: tuples of (cache_key, context_token_ids, continuation_token_ids);
        cache_key may be None (e.g. rolling-loglikelihood windows)
    :param disable_tqdm: suppress the progress bar
    :param override_bs: force this batch size (used by loglikelihood_rolling)
    :return: per request, (sum of continuation log-probs, whether the greedy
        decoding of the model matches the continuation exactly), in the
        original request order
    """
    # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
    res = []

    def _collate(req: Tuple[Tuple[str, str], List[int], List[int]]):
        """Defines the key for the sorted method"""
        # the negative sign on len(toks) sorts descending - this has a few advantages:
        # - time estimates will always be over not underestimates, which is more useful for planning
        # - to know the size of a batch when going through the list, you know the first one is always the batch
        #   padded context length. this is useful to simplify the batching logic and more importantly to make
        #   automatic adaptive batches much much easier to implement
        # - any OOMs will happen right away rather than near the end
        toks = req[1] + req[2]
        return -len(toks), tuple(toks)

    def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]):
        """Defines the key to group and lookup one-token continuations"""
        # Use with group_by="contexts" (optional)"
        # allows for the creation of a lookup, so we can reuse logits in case of one-token continuations.
        # speeds up some multiple-choice tasks proportionally to the number of choices.
        # groups requests by context+continuation[:-1] and infer on one request/group.
        return req[-2] + req[-1][:-1]

    re_ord = Collator(
        requests,
        sort_fn=_collate,
        group_by="contexts"
        if self.backend == "causal" and self.logits_cache
        else None,
        group_fn=_lookup_one_token_cont,
    )

    # automatic (variable) batch size detection for vectorization
    # pull longest context sample from request
    n_reordered_requests = len(re_ord)
    batch_size = (
        self.batch_size
        if self.batch_size != "auto"
        else override_bs
        if override_bs is not None
        else 0
    )
    batch_fn = (
        self._batch_scheduler
        if self.batch_size == "auto"
        and n_reordered_requests > 0
        and not override_bs
        else None
    )

    chunks = re_ord.get_batched(n=batch_size, batch_fn=batch_fn)
    pbar = tqdm(
        total=len(requests),
        disable=(disable_tqdm or (self.rank != 0)),
        desc="Running loglikelihood requests",
    )
    for chunk in chunks:
        inps = []
        cont_toks_list = []
        inplens = []

        conts = []
        encoder_attns = []

        padding_len_inp = None
        padding_len_cont = None
        # because vectorizing is annoying, we first convert each (context, continuation) pair to padded
        # tensors, then we pack them together into a batch, call the model, and then pick it all apart
        # again because vectorizing is annoying

        for _, context_enc, continuation_enc in chunk:
            # sanity check
            assert len(context_enc) > 0
            assert len(continuation_enc) > 0
            assert len(continuation_enc) <= self.max_length

            # how this all works (illustrated on a causal decoder-only setup):
            #          CTX      CONT
            # inp    0 1 2 3|4 5 6 7 8 9   <- last token is deleted by inp[:, :-1]
            # model  \               \
            # logits   1 2 3|4 5 6 7 8 9   <- the ctx half gets tossed out by the
            # cont_toks      4 5 6 7 8 9      [:, -len(continuation_enc):, :self.vocab_size] slice

            # when too long to fit in context, truncate from the left
            if self.backend == "causal":
                total_length = len(context_enc) + len(continuation_enc)
                if total_length > self.max_length + 1:
                    eval_logger.warning(
                        f"Combined length of context ({len(context_enc)}) and continuation ({len(continuation_enc)}) "
                        f"exceeds model's maximum length ({self.max_length}). "
                        f"Truncating {total_length - self.max_length + 1} tokens from the left."
                    )
                inp = torch.tensor(
                    (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1],
                    dtype=torch.long,
                    device=self.device,
                )
                (inplen,) = inp.shape
            elif self.backend == "seq2seq":
                inp = torch.tensor(
                    (context_enc)[-self.max_length :],
                    dtype=torch.long,
                    device=self.device,
                )
                (inplen,) = inp.shape

                # build encoder attn masks
                encoder_attns.append(torch.ones_like(inp))

                cont = torch.tensor(
                    (continuation_enc)[-self.max_length :],
                    # TODO: left-shift these?
                    # TODO: our code assumes we never end up truncating conts for either model type
                    dtype=torch.long,
                    device=self.device,
                )
                (contlen,) = cont.shape

                conts.append(cont)

                padding_len_cont = (
                    max(padding_len_cont, contlen)
                    if padding_len_cont is not None
                    else contlen
                )

            padding_len_inp = (
                max(padding_len_inp, inplen)
                if padding_len_inp is not None
                else inplen
            )

            inps.append(inp)  # [1, inp_length]
            cont_toks_list.append(continuation_enc)
            inplens.append(inplen)

        # create encoder attn mask and batched conts, if seq2seq
        call_kwargs = {}
        if self.backend == "causal":
            batched_inps = pad_and_concat(
                padding_len_inp, inps, padding_side="right"
            )  # [batch, padding_len_inp]
        elif self.backend == "seq2seq":
            # TODO: left-pad encoder inps and mask?
            batched_inps = pad_and_concat(
                padding_len_inp, inps
            )  # [batch, padding_len_inp]
            batched_conts = pad_and_concat(
                padding_len_cont, conts
            )  # [batch, padding_len_cont]
            batched_encoder_mask = pad_and_concat(
                padding_len_inp, encoder_attns
            )  # [batch, padding_len_inp]
            call_kwargs = {
                "attn_mask": batched_encoder_mask,
                "labels": batched_conts,
            }

        multi_logits = F.log_softmax(
            self._model_call(batched_inps, **call_kwargs),
            dim=-1,
            dtype=self.softmax_dtype,
        )  # [batch, padding_length (inp or cont), vocab]

        for (request_str, ctx_tokens, _), logits, inplen, cont_toks in zip(
            chunk, multi_logits, inplens, cont_toks_list
        ):
            # Slice to original seq length
            contlen = len(cont_toks)
            # take only logits in the continuation
            # (discard context toks if decoder-only ; discard right-padding)
            # also discards + checks for "virtual tokens" in the causal LM's input window
            # from prompt/prefix tuning tokens, if applicable
            ctx_len = (
                inplen + (logits.shape[0] - padding_len_inp)
                if self.backend == "causal"
                else None
            )
            logits = self._select_cont_toks(logits, contlen=contlen, inplen=ctx_len)
            logits = logits.unsqueeze(0)  # [1, seq, vocab]

            # Check if per-token argmax is exactly equal to continuation
            greedy_tokens = logits.argmax(dim=-1)

            # check for one-token continuation cache hits.
            # noop in case group_by != "contexts" or no cache hit and returns the
            # original args. Otherwise, expands the logits batch dimension and yields each
            # batch along with matching continuation tokens and prompt strings.
            # logits -> [1, seq, vocab]
            for request_str, cont_toks, logits in re_ord.get_cache(
                req_str=request_str,
                cxt_toks=ctx_tokens,
                cont_toks=cont_toks,
                logits=logits,
            ):
                cont_toks = torch.tensor(
                    cont_toks, dtype=torch.long, device=self.device
                ).unsqueeze(0)  # [1, seq]
                # Use trailing slice [-cont_toks.shape[1]:] to handle variable length cont_len (but same ctx+cont[:-1]).
                # i.e. continuations can be sliced at diff points. Collator ensures we have sufficient greedy_tokens
                # by choosing key with longest cont if group_by="contexts".
                max_equal = (
                    greedy_tokens[:, -cont_toks.shape[1] :] == cont_toks
                ).all()

                # Obtain log-probs at the corresponding continuation token indices
                # last_token_slice = logits[:, -1, :].squeeze(0).tolist()
                logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(
                    -1
                )  # [1, seq]

                # Answer: (log prob, is-exact-match)
                answer = (float(logits.sum()), bool(max_equal))

                res.append(answer)

                if request_str is not None:
                    # special case: loglikelihood_rolling produces a number of loglikelihood requests
                    # all with cache key None. instead do add_partial on the per-example level
                    # in the loglikelihood_rolling() function for those.
                    self.cache_hook.add_partial(
                        "loglikelihood", request_str, answer
                    )
                pbar.update(1)

    pbar.close()

    return re_ord.get_original(res)
|
| 1292 |
+
|
| 1293 |
+
def generate_until(
    self, requests: List[Instance], disable_tqdm: bool = False
) -> List[str]:
    """Generate free-form text for each request until a stop sequence or
    token budget is hit.

    :param requests: Instances whose args are (context_string, gen_kwargs_dict)
    :param disable_tqdm: suppress the progress bar
    :return: one generated string per request, in the original order
    """
    res = []

    def _collate(req: Tuple[str, dict]):
        """Defines the key for the sorted method"""
        # the negative sign on len(toks) sorts descending - this has a few advantages:
        # - time estimates will always be over not underestimates, which is more useful for planning
        # - to know the size of a batch when going through the list, you know the first one is always the batch
        #   padded context length. this is useful to simplify the batching logic and more importantly to make
        #   automatic adaptive batches much much easier to implement
        # - any OOMs will happen right away rather than near the end
        toks = self.tok_encode(req[0])
        return -len(toks), req[0]

    pbar = tqdm(
        total=len(requests),
        disable=(disable_tqdm or (self.rank != 0)),
        desc="Running generate_until requests",
    )
    adaptive_batch_size = None
    if self.batch_size == "auto":
        # using rolling window with maximum context
        print("Passed argument batch_size = auto. Detecting largest batch size")
        batch_size = self._detect_batch_size()
        print(f"Determined Largest batch size: {batch_size}")
        adaptive_batch_size = batch_size
    # for each different set of kwargs, we execute all requests, by batch.
    batch_size = (
        self.batch_size
        if self.batch_size != "auto"
        else adaptive_batch_size
        if adaptive_batch_size is not None
        else 0
    )
    batch_fn = (
        self._batch_scheduler
        if self.batch_size == "auto" and not adaptive_batch_size
        else None
    )

    # we group requests by their generation_kwargs,
    # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
    # in the same batch.
    # group_fn=lambda x: x[1] -> x=(context, gen_kwargs)
    re_ords = Collator(
        [reg.args for reg in requests],
        sort_fn=_collate,
        group_by="gen_kwargs",
        group_fn=lambda x: x[1],
    )
    chunks = re_ords.get_batched(n=batch_size, batch_fn=batch_fn)
    eos = self.tok_decode(self.eot_token_id, skip_special_tokens=False)
    for chunk in chunks:
        contexts, all_gen_kwargs = zip(*chunk)
        # we assume all gen kwargs in the batch are the same
        # this is safe to assume because the `grouper` object ensures it.
        gen_kwargs = all_gen_kwargs[0]
        # unpack our keyword arguments.
        if isinstance(gen_kwargs, dict):
            kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
            # add EOS token to stop sequences
            until = handle_stop_sequences(kwargs.pop("until", None), eos=eos)
        else:
            raise ValueError(
                f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
            )
        if "max_gen_toks" in kwargs.keys():
            max_gen_toks = kwargs.pop("max_gen_toks")
        else:
            max_gen_toks = self.max_gen_toks

        # set the max length in tokens of inputs ("context_enc")
        if self.backend == "causal":
            # max len for inputs = max length, minus room to generate the max new tokens
            max_ctx_len = self.max_length - max_gen_toks
            assert max_ctx_len > 0, (
                f"Invalid configuration: requested max tokens to generate ({max_gen_toks}) must be less than model's maximum sequence length ({self.max_length})."
            )
        elif self.backend == "seq2seq":
            # max len for inputs = encoder's whole max_length
            max_ctx_len = self.max_length

        # encode, pad, and truncate contexts for this batch
        context_enc, attn_masks = self.tok_batch_encode(
            contexts,
            left_truncate_len=max_ctx_len,
            truncation=self.truncation,
        )
        context_enc = context_enc.to(self.device)
        attn_masks = attn_masks.to(self.device)

        if "max_length" not in kwargs:
            kwargs["max_length"] = context_enc.shape[1] + max_gen_toks

        # perform batched generation
        cont = self._model_generate(
            context=context_enc,
            attention_mask=attn_masks,
            stop=until,
            **kwargs,
        )

        cont_toks_list = cont.tolist()
        for cont_toks, context in zip(cont_toks_list, contexts):
            # discard context + left-padding toks if using causal decoder-only LM
            if self.backend == "causal":
                cont_toks = cont_toks[context_enc.shape[1] :]

            s = self.tok_decode(cont_toks)

            # use secondary stop seqs to cut off should-have-been-stopped content post-hoc
            for term in until:
                if len(term) > 0:
                    # ignore '' separator,
                    # for seq2seq case where self.tok_decode(self.eot_token_id) = ''
                    s = s.split(term)[0]

            res.append(s)

            self.cache_hook.add_partial("generate_until", (context, gen_kwargs), s)
            pbar.update(1)
    # reorder this group of results back to original unsorted form
    res = re_ords.get_original(res)

    pbar.close()

    return res
|
| 1422 |
+
|
| 1423 |
+
def apply_chat_template(
    self, chat_history: List[Dict[str, str]], add_generation_prompt: bool = True
) -> str:
    """
    Method to apply a chat template to a list of chat history between user and model.

    Falls back to stripping any "system" turns when the tokenizer's template
    rejects them.
    """
    template_kwargs = dict(
        tokenize=False,
        add_generation_prompt=add_generation_prompt,
        continue_final_message=not add_generation_prompt,
    )
    try:
        return self.tokenizer.apply_chat_template(chat_history, **template_kwargs)
    except jinja2.exceptions.TemplateError:
        # some chat templates raise on a system role; retry without it
        eval_logger.warning(
            "Failed to apply chat template. removing the system role in chat history."
        )
        trimmed_history = [msg for msg in chat_history if msg["role"] != "system"]
        return self.tokenizer.apply_chat_template(trimmed_history, **template_kwargs)
|
| 1449 |
+
|
| 1450 |
+
def get_model_info(self) -> dict:
    """
    Method to get Hugging Face model information for experiment reproducibility.
    """

    def _num_params(model) -> int:
        # prefer HF's own counter; otherwise sum parameter tensors; -1 if neither exists
        if hasattr(model, "num_parameters"):
            return model.num_parameters()
        if hasattr(model, "parameters"):
            return sum(p.numel() for p in model.parameters())
        return -1

    def _dtype(model) -> str:
        return model.dtype if hasattr(model, "dtype") else ""

    def _sha(pretrained: str, revision: str) -> str:
        # best-effort hub lookup; any failure degrades to an empty SHA
        try:
            return HfApi().model_info(repo_id=pretrained, revision=revision).sha
        except Exception as e:
            eval_logger.debug(
                f"Failed to get model SHA for {pretrained} at revision {revision}. Error: {e}"
            )
            return ""

    info = {
        "model_num_parameters": _num_params(self._model),
        "model_dtype": _dtype(self._model),
        "model_revision": self.revision,
        "model_sha": _sha(self.pretrained, self.revision),
    }
    if self.peft:
        info["peft_sha"] = _sha(self.peft, self.revision)
    if self.delta:
        info["delta_sha"] = _sha(self.delta, self.revision)
    return info
|
Prism/LLaDA/LLaDA_Baseline/dllm_eval/models/utils.py
ADDED
|
@@ -0,0 +1,854 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import collections
|
| 2 |
+
import fnmatch
|
| 3 |
+
import gc
|
| 4 |
+
import itertools
|
| 5 |
+
import logging
|
| 6 |
+
import time
|
| 7 |
+
from functools import wraps
|
| 8 |
+
from typing import (
|
| 9 |
+
TYPE_CHECKING,
|
| 10 |
+
Any,
|
| 11 |
+
Callable,
|
| 12 |
+
Dict,
|
| 13 |
+
Iterable,
|
| 14 |
+
Iterator,
|
| 15 |
+
List,
|
| 16 |
+
Literal,
|
| 17 |
+
Optional,
|
| 18 |
+
Tuple,
|
| 19 |
+
Type,
|
| 20 |
+
Union,
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
import torch
|
| 24 |
+
import transformers
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
eval_logger = logging.getLogger(__name__)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
if TYPE_CHECKING:
|
| 31 |
+
from PIL import Image
|
| 32 |
+
from transformers import PreTrainedTokenizerBase
|
| 33 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def chunks(iter, n: int = 0, fn=None):
    """
    Split an iterable into successive chunks; useful for batching.

    Parameters:
    - iter: The input iterable to be divided into chunks.
    - n: Fixed chunk size (ignored when `fn` is supplied). Default is 0.
    - fn: Optional callable `(index, iterable) -> int` returning the target
      size of the chunk currently being filled.

    Yields:
    Lists of consecutive elements; the trailing chunk may be shorter.

    Example:
        >>> list(chunks([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], 3))
        [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10]]
    """
    buffer = []
    for index, element in enumerate(iter):
        buffer.append(element)
        # A dynamic size function takes precedence over the fixed size n.
        target = fn(index, iter) if fn else n
        if len(buffer) == target:
            yield buffer
            buffer = []

    # Flush whatever is left over as a final (possibly short) chunk.
    if buffer:
        yield buffer
|
| 73 |
+
|
| 74 |
+
class MultiChoice:
    """Container of valid task names supporting wildcard membership checks."""

    def __init__(self, choices) -> None:
        self.choices = choices

    # Simple wildcard support (linux filename patterns)
    def __contains__(self, values) -> bool:
        """`values` is a comma-separated pattern string; every pattern must
        match at least one choice (fnmatch-style wildcards), else raise."""
        for pattern in values.split(","):
            if not fnmatch.filter(self.choices, pattern):
                eval_logger.info("Available tasks to choose:")
                for choice in self.choices:
                    eval_logger.info(f"  - {choice}")
                raise ValueError("'{}' is not in task list".format(pattern))
        return True

    def __iter__(self) -> Iterator:
        yield from self.choices
+
|
| 92 |
+
|
| 93 |
+
class Grouper:
    """
    Partition `arr` by `fn`: maps each key `fn(ob)` to the list of elements of
    `arr` with that key, remembering the original positions so that per-group
    results can later be restored to the input order.
    """

    def __init__(self, arr, fn) -> None:
        self.size = len(arr)
        indexed = list(enumerate(arr))

        buckets = collections.defaultdict(list)
        for pair in indexed:
            # Key on the element itself (pair[1]); keep (index, element) pairs.
            buckets[fn(pair[1])].append(pair)

        # self.arr has format Dict[key, List[Tuple[int, <entry from orig. arr>]]]
        self.arr = buckets
        self._grouped = None

    def get_grouped(self):
        """Return {key: [elements]} with the index bookkeeping stripped."""
        if self._grouped:
            return self._grouped
        self._grouped = {
            key: [pair[1] for pair in pairs] for key, pairs in self.arr.items()
        }
        return self._grouped

    def get_original(self, grouped_dict):
        """Invert get_grouped(): given per-key result lists (each in the same
        order as that key's grouped elements), return a single flat list in
        the original input order."""
        res = [None] * self.size
        cov = [False] * self.size

        assert grouped_dict.keys() == self.arr.keys()

        for key, values in grouped_dict.items():
            for (index, _), value in zip(self.arr[key], values):
                res[index] = value
                cov[index] = True

        # Every original position must have been filled exactly once.
        assert all(cov)

        return res
+
|
| 150 |
+
|
| 151 |
+
def pad_and_concat(
    max_length: int,
    tensors: List[torch.Tensor],
    padding_side: Literal["right", "left"] = "right",
):
    """
    Zero-pad each 1-D (or [1, seq]) tensor in `tensors` up to `max_length` on
    the requested side, then stack them into one [batch, max_length] tensor.
    Used for batching inputs and continuations in seq2seq models.

    Note: the `tensors` list is modified in place.
    """
    assert padding_side in ("left", "right"), (
        f"Unrecognized padding type: '{padding_side}' not 'left' or 'right'"
    )

    for idx, tensor in enumerate(tensors):
        if tensor.dim() == 2:
            tensor = tensor.squeeze(0)  # tolerate [1, seq] shaped inputs
        length = tensor.shape[0]
        if length >= max_length:
            # Already long enough; just add the batch dimension.
            tensors[idx] = tensor.unsqueeze(0)
            continue
        padding = torch.zeros(
            max_length - length,
            dtype=torch.long,
            device=tensor.device,
        )
        pieces = [tensor, padding] if padding_side == "right" else [padding, tensor]
        tensors[idx] = torch.cat(pieces, dim=0).unsqueeze(0)

    return torch.cat(tensors, dim=0)
+
|
| 201 |
+
|
| 202 |
+
def clear_torch_cache() -> None:
    """Free unreferenced tensors and return cached CUDA memory to the GPU."""
    # Collect Python garbage first so tensors with no remaining references
    # are actually freed before asking CUDA to release its cached blocks.
    gc.collect()
    # No-op when CUDA has not been initialized, so safe on CPU-only runs.
    torch.cuda.empty_cache()
+
|
| 206 |
+
|
| 207 |
+
def get_dtype(dtype: Union[str, torch.dtype]) -> torch.dtype:
    """Converts `dtype` from `str` to torch.dtype when possible. Does not use
    an instantiated HF AutoConfig. The literal string "auto" and any
    torch.dtype value are passed through unchanged."""
    if not isinstance(dtype, str) or dtype == "auto":
        return dtype
    # Convert `str` args to a torch dtype: `float16` -> `torch.float16`
    return getattr(torch, dtype)
+
|
| 216 |
+
|
| 217 |
+
class MultiTokenEOSCriteria(transformers.StoppingCriteria):
    """Criteria to stop on the specified multi-token sequence."""

    def __init__(
        self,
        sequence: str,
        tokenizer: transformers.PreTrainedTokenizer,
        initial_decoder_input_length: int,
        batch_size: int,
    ) -> None:
        self.initial_decoder_input_length = initial_decoder_input_length
        self.done_tracker = [False] * batch_size
        self.sequence = sequence
        self.sequence_ids = tokenizer.encode(sequence, add_special_tokens=False)
        # Look back 2 tokens beyond the encoded stop sequence: the model may
        # produce the same string under a different tokenization (e.g.
        # ['\n', '\n'] vs ['\n\n']) and we still want to detect it.
        # NOTE: the extra lookback could in principle reach into the prompt and
        # trigger an immediate stop, but with only 2 extra tokens the risk is
        # minimal, and the slice on initial_decoder_input_length in __call__
        # prevents the lookback window from ever covering prompt tokens.
        self.sequence_id_len = len(self.sequence_ids) + 2
        self.tokenizer = tokenizer

    def __call__(self, input_ids, scores, **kwargs) -> bool:
        # Only inspect the generated portion of the batch, and only its tail,
        # since the stop string can only appear in the most recent tokens.
        generated_batch = input_ids[:, self.initial_decoder_input_length :]
        tail_batch = generated_batch[:, -self.sequence_id_len :]

        decoded_batch = self.tokenizer.batch_decode(tail_batch)

        for i, finished in enumerate(self.done_tracker):
            if not finished:
                self.done_tracker[i] = self.sequence in decoded_batch[i]
        return False not in self.done_tracker
+
|
| 256 |
+
|
| 257 |
+
def stop_sequences_criteria(
    tokenizer: transformers.PreTrainedTokenizer,
    stop_sequences: List[str],
    initial_decoder_input_length: int,
    batch_size: int,
) -> transformers.StoppingCriteriaList:
    """Build a StoppingCriteriaList containing one MultiTokenEOSCriteria per
    stop string in `stop_sequences`."""
    criteria = [
        MultiTokenEOSCriteria(
            sequence, tokenizer, initial_decoder_input_length, batch_size
        )
        for sequence in stop_sequences
    ]
    return transformers.StoppingCriteriaList(criteria)
+
|
| 274 |
+
|
| 275 |
+
def undistribute(iterable):
    """
    Undoes https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.distribute .

    Re-interleaves results that have been split using more_itertools.distribute:
        >>> group_1, group_2 = distribute(2, [1, 2, 3, 4, 5, 6])
        >>> list(group_1)
        [1, 3, 5]
        >>> list(group_2)
        [2, 4, 6]
        >>> undistribute([group_1, group_2])
        [1, 2, 3, 4, 5, 6]

    Handles non-uniform component lengths:

        >>> children = distribute(3, [1, 2, 3, 4, 5, 6, 7])
        >>> [list(c) for c in children]
        [[1, 4, 7], [2, 5], [3, 6]]
        >>> undistribute(children)
        [1, 2, 3, 4, 5, 6, 7]

    Also handles when some iterables are empty:

        >>> children = distribute(5, [1, 2, 3])
        >>> [list(c) for c in children]
        [[1], [2], [3], [], []]
        >>> undistribute(children)
        [1, 2, 3]

    """
    # Use a private sentinel (rather than None) as zip_longest's fill value so
    # that genuine None elements in the data survive the round-trip. The old
    # `if x is not None` filter silently dropped real None values.
    _fill = object()
    return [
        x
        for x in itertools.chain.from_iterable(
            itertools.zip_longest(*[list(x) for x in iterable], fillvalue=_fill)
        )
        if x is not _fill
    ]
+
|
| 314 |
+
|
| 315 |
+
def retry_on_specific_exceptions(
    on_exceptions: List[Type[Exception]],
    max_retries: Optional[int] = None,
    backoff_time: float = 3.0,
    backoff_multiplier: float = 1.5,
    on_exception_callback: Optional[Callable[[Exception, float], Any]] = None,
):
    """Retry on an LLM Provider's rate limit error with exponential backoff
    For example, to use for OpenAI, do the following:
    ```
    from openai import RateLimitError

    # Recommend specifying max_retries to avoid infinite loops!
    @retry_on_specific_exceptions([RateLimitError], max_retries=3)
    def completion(...):
        # Wrap OpenAI completion function here
        ...
    ```

    Parameters:
    - on_exceptions: exception types that trigger a retry; anything else propagates.
    - max_retries: attempt limit; None retries forever.
    - backoff_time: initial sleep (seconds) between attempts.
    - backoff_multiplier: factor applied to the sleep after each failure.
    - on_exception_callback: optional hook called with (exception, sleep_time)
      before each backoff sleep.

    Raises:
    The last caught exception once `max_retries` attempts are exhausted.
    """

    def decorator(func: Callable):
        @wraps(func)
        def wrapper(*args, **kwargs):
            sleep_time = backoff_time
            attempt = 0
            last_exception: Optional[Exception] = None
            while max_retries is None or attempt < max_retries:
                try:
                    return func(*args, **kwargs)
                except tuple(on_exceptions) as e:
                    last_exception = e
                    if on_exception_callback is not None:
                        on_exception_callback(e, sleep_time)
                    time.sleep(sleep_time)
                    sleep_time *= backoff_multiplier
                    attempt += 1
            # BUGFIX: previously fell through and silently returned None once
            # max_retries was exhausted; re-raise the last error instead.
            if last_exception is not None:
                raise last_exception

        return wrapper

    return decorator
+
|
| 354 |
+
|
| 355 |
+
class Collator:
    """
    A class for reordering and batching elements of an array.

    This class allows for sorting an array based on a provided sorting function, grouping elements based on a grouping function, and generating batches from the sorted and grouped data.

    Objects of this class have the group_by attribute which determines the method for grouping
    the data while batching it. Three options include "gen_kwargs", "contexts", or None:
        If group_by == "gen_kwargs" then requests will be grouped by gen_kwargs
        If group_by == "contexts" then requests will be grouped by context + cont[:-1]
        If None then requests will just be reordered by length descending.

    Typical usage: construct, iterate get_batched() to process elements in the
    reordered/grouped order, then call get_original() with the flat results to
    restore the caller's original ordering.
    """

    def __init__(
        self,
        arr: List,
        sort_fn: Callable = lambda x: x,
        group_fn: Callable = lambda x: x[1],
        group_by: Union[Literal["gen_kwargs", "contexts"], None] = None,
    ) -> None:
        self._group_by = group_by
        # 0 indices are enumerated indices. Apply functions to original arr.
        self._sort_fn = lambda x: sort_fn(x[1])
        self._group_fn = lambda x: group_fn(x[1])
        # Original positions, appended as elements are consumed by _reorder()
        # (or by get_cache() in "contexts" mode) so get_original() can invert
        # the reordering. Order of appends therefore matters.
        self._reorder_indices: List = []
        self._size = len(arr)
        self._arr_with_indices: Union[Dict, Tuple[Tuple[int, Any], ...]] = tuple(
            enumerate(arr)
        )  # [indices, (arr)]
        if self._group_by == "contexts":
            self._group_by_context()
        elif self._group_by == "gen_kwargs":
            self._group_by_index()

    def _group_by_index(self) -> None:
        """Group the elements of a list based on their indices."""
        self._arr_with_indices = self.group(
            self._arr_with_indices, fn=self._group_fn, group_by="gen_kwargs"
        )

    def _group_by_context(self) -> None:
        """Group the array with indices by context."""
        self._arr_with_indices = self.group(
            self._arr_with_indices, fn=self._group_fn, group_by="contexts"
        )

    def get_batched(self, n: int = 1, batch_fn: Optional[Callable] = None) -> Iterator:
        """
        Generates and yields batches from the reordered array. The method of grouping and batching
        depends on the parameter `group_by`.
        If `group_by` is set to "gen_kwargs", it will batch the
        re-ordered values with same gen_kwargs for each batch.
        If `group_by` is "contexts", it caches the requests by context before batching.
        If `group_by` is neither "gen_kwargs" nor "contexts", it yields the reordered array

        Parameters:
        - n (int): The size of each batch. Defaults to 1.
        - batch_fn ([Callable[[int, Iterable], int]] | None): A function to determine the size of
          each batch. Optional, defaults to None.

        Returns:
        Iterator: An iterator over batches of reordered elements grouped as per the `group_by`
          attribute.

        Yields:
        List of batched elements according to the `group_by` attribute.
        """
        if self._group_by == "gen_kwargs":
            # Batches never mix elements with different gen_kwargs.
            for (
                key,
                values,
            ) in self._arr_with_indices.items():  # type: ignore
                values = self._reorder(values)
                batch = self.get_chunks(values, n=n, fn=batch_fn)
                yield from batch
        elif self._group_by == "contexts":
            # Get one sample from each key.
            # Select longest continuation per group to ensure sufficient context logits
            values = self._reorder(
                [
                    max(value, key=lambda x: len(x[1][-1]))
                    for value in self._arr_with_indices.values()
                ]
            )
            batch = self.get_chunks(values, n=n, fn=batch_fn)
            yield from batch
        else:
            values = self._reorder(self._arr_with_indices)  # type: ignore
            batch = self.get_chunks(values, n=n, fn=batch_fn)
            yield from batch

    def get_cache(
        self,
        req_str: Tuple[str, str] = None,
        cxt_toks: List[int] = None,
        cont_toks: List[int] = None,
        logits: torch.Tensor = None,
    ) -> Iterator[Tuple[Tuple[str, str], List[int], torch.Tensor]]:
        """
        Retrieves cached single-token continuations and their associated arguments, updating indices as necessary.

        The behavior of this function varies depending on how the `group_by` attribute is set:

        - When `group_by` is "contexts":
            The function identifies single-token continuations by checking for keys that equate to
            [context+continuation][-1] and logs the indices for re-ordering.
            In this mode, this function can work in two scenarios:

            1. Cache Hit - Single Match:
                If a single matching context-continuation pair is found in the cache,
                the function yields the original arguments.

            2. Cache Hit - Multiple Matches:
                If multiple matching context-continuation pairs are found in the cache,
                the function expands the logits batch dimension to match the number of cache hits.
                It updates the original requests and continuation tokens.

        - When `group_by` is not set to "contexts":
            This method yields the original arguments, logits and continuation tokens,
            without checking for one-token continuations.

        Parameters:
        - req_str (tuple[str, str]): Original strings used for CachingLM.
        - cxt_toks (list[int]): Full context tokens used for lookup.
        - cont_toks (list[int]): Continuation tokens for which logits were generated.
        - logits (torch.Tensor [1, seq_length, vocab_size]): Logits generated by the model given context and continuation keys.

        Yields:
        - Iterator:
            - req_str (tuple[str, str]): strings used for CachingLM.
            - cont_toks (list[int]) : continuation tokens.
            - logits (torch.Tensor [1, seq_length, vocab_size]): The original logits (repeated cache hit times)
        """
        if self._group_by == "contexts":
            # pop() removes the group so each cached context is served once.
            cache_hit: List[
                Tuple[int, Tuple[Tuple[str, str], List[int], List[int]]]
            ] = self._arr_with_indices.pop(tuple(cxt_toks + cont_toks[:-1]))
            if (cache_size := len(cache_hit)) == 1:
                self._reorder_indices.extend(x[0] for x in cache_hit)
                yield req_str, cont_toks, logits
            else:
                # If we have matching requests then expand the batch dimension (no-op) and
                # yield each along with its corresponding args.
                multilogits = logits.expand(cache_size, -1, -1).chunk(cache_size)
                indices, req_str, cont_toks = zip(
                    *[(x[0], x[1][0], x[-1][-1]) for x in cache_hit]
                )
                self._reorder_indices.extend(indices)
                for c_key, cont_tok, logit in zip(req_str, cont_toks, multilogits):
                    yield c_key, cont_tok, logit
        else:
            yield req_str, cont_toks, logits

    def _reorder(self, arr: Union[List, Tuple[Tuple[int, Any], ...]]) -> Iterator:
        """
        Reorders the elements in the array based on the sorting function.

        Parameters:
        - arr (list | tuple[tuple[int, Any], ...]]): The array or iterable to be reordered.

        Yields:
        Iterator
        """
        arr = sorted(arr, key=self._sort_fn)
        if not self._group_by == "contexts":
            # If grouped by contexts then indices will be set in get_cache()
            self._reorder_indices.extend([x[0] for x in arr])
        yield from [x[1] for x in arr]

    def get_original(self, newarr: List) -> List:
        """
        Restores the original order of elements from the reordered list.

        Parameters:
        - newarr (list): The reordered array.

        Returns:
        list: The array with elements restored to their original order.
        """
        res = [None] * self._size
        cov = [False] * self._size

        for ind, v in zip(self._reorder_indices, newarr):
            res[ind] = v
            cov[ind] = True

        # Every original position must have been produced exactly once.
        assert all(cov)

        return res

    def __len__(self):
        return self._size

    @staticmethod
    def group(
        arr: Iterable,
        fn: Callable,
        group_by: Literal["gen_kwargs", "contexts"] = "gen_kwargs",
    ) -> dict:
        """
        Groups elements of an iterable based on a provided function.


        The `group_by` parameter determines the method of grouping.
        If `group_by` is "contexts", the elements are grouped by [context + cont][:-1].
        If `group_by` is "gen_kwargs", the elements are grouped based on the gen_kwargs dict.

        Parameters:
        - arr (Iterable): The iterable to be grouped.
        - fn (Callable): The function to determine the grouping.
        - group_by: "contexts" or "gen_kwargs"; selects the key construction.

        Returns:
        dict: mapping of hashable group key -> list of grouped elements.
        """
        res = collections.defaultdict(list)
        for ob in arr:
            # where ob == [context + cont]
            if group_by == "contexts":
                res[tuple(fn(ob))].append(ob)
            else:
                try:
                    # Build a hashable key from the gen_kwargs dict: sort items
                    # for determinism and tuple-ize iterable values.
                    hashable_dict = tuple(
                        (
                            key,
                            tuple(value)
                            if isinstance(value, collections.abc.Iterable)
                            else value,
                        )
                        for key, value in sorted(fn(ob).items())
                    )
                    res[hashable_dict].append(ob)
                except (TypeError, AttributeError):
                    # fn(ob) was not a dict (or not sortable); fall back to a
                    # plain tuple key.
                    res[tuple(fn(ob))].append(ob)
        return res

    @staticmethod
    def get_chunks(_iter, n: int = 0, fn=None):
        """
        Divides an iterable into chunks of specified size or based on a given function.
        Useful for batching

        Parameters:
        - _iter: The input iterable to be divided into chunks.
        - n: An integer representing the size of each chunk. Default is 0.
        - fn: A function that takes the current index and the iterable as arguments and returns the size of the chunk. Default is None.

        Returns:
        An iterator that yields chunks of the input iterable.

        Example usage:
        ```
        data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
        for chunk in chunks(data, 3):
            print(chunk)
        ```
        Output:
        ```
        [1, 2, 3]
        [4, 5, 6]
        [7, 8, 9]
        [10]
        ```
        """
        arr = []
        # Materialize so fn(i, _iter) can inspect the whole sequence.
        _iter = tuple(_iter)
        for i, x in enumerate(_iter):
            arr.append(x)
            if len(arr) == (fn(i, _iter) if fn else n):
                yield arr
                arr = []

        if arr:
            yield arr
|
| 630 |
+
|
| 631 |
+
def configure_pad_token(
    tokenizer: "PreTrainedTokenizerBase",
    model_config: Optional["PretrainedConfig"] = None,
) -> "PreTrainedTokenizerBase":
    """
    Ensure the (Hugging Face) tokenizer has a padding token, setting one if absent.
    Preference order: an existing pad token, then the unk token, then the eos
    token, then model-specific special cases.

    Args:
        tokenizer: The tokenizer for which the padding token is to be handled.
        model_config: The configuration of the model. Default is None.

    Returns:
        The tokenizer after the padding token has been handled.

    Raises:
        AssertionError: If the tokenizer is of type RWKVWorldTokenizer or Rwkv5Tokenizer and the padding token id is not 0.
    """
    if tokenizer.pad_token:
        return tokenizer

    if tokenizer.unk_token:
        tokenizer.pad_token_id = tokenizer.unk_token_id
    elif tokenizer.eos_token:
        tokenizer.pad_token_id = tokenizer.eos_token_id
    # handle special cases
    elif model_config and getattr(model_config, "model_type", None) == "qwen":
        # Qwen's trust_remote_code tokenizer does not allow for adding special tokens
        tokenizer.pad_token = "<|endoftext|>"
    elif tokenizer.__class__.__name__ in ("RWKVWorldTokenizer", "Rwkv5Tokenizer"):
        # The RWKV world tokenizer does not allow adding special tokens / setting
        # the pad token (which is fixed at 0). The class-name check is needed
        # because rwkv4 models exist with the neox tokenizer.
        # Note the world tokenizer class name might change after the final
        # huggingface merge: https://github.com/huggingface/transformers/pull/26963
        assert tokenizer.pad_token_id == 0
    else:
        tokenizer.add_special_tokens({"pad_token": "<|pad|>"})

    return tokenizer
+
|
| 675 |
+
|
| 676 |
+
def replace_placeholders(
    string: str, default_placeholder: str, image_token: str, max_images: int
):
    """
    Utility for local multimodal models: swap placeholder tags for model tokens.

    Splits `string` on `default_placeholder` and rebuilds it, substituting the
    first `max_images` occurrences with `image_token`. Occurrences beyond that
    limit are kept as the original placeholder (unless the placeholder already
    equals the image token, in which case the extras are dropped).

    :param string: The original string containing placeholders.
    :param default_placeholder: The placeholder text to be replaced.
    :param image_token: The token to replace the placeholder with.
    :param max_images: The maximum number of replacements to make.
    :return: The string with placeholders replaced.
    """
    segments = string.split(default_placeholder)
    pieces = []
    replaced = 0

    # Every element of segments[:-1] was followed by a placeholder in the input.
    for segment in segments[:-1]:
        pieces.append(segment)
        if replaced < max_images:
            pieces.append(image_token)
            replaced += 1
        elif default_placeholder != image_token:
            # Past the limit: keep the original placeholder text.
            pieces.append(default_placeholder)

    pieces.append(segments[-1])
    return "".join(pieces)
|
| 708 |
+
|
| 709 |
+
|
| 710 |
+
def flatten_image_list(images: List[List]):
    """
    Concatenate a list of lists of images into one flat list, preserving order.

    Needed by multimodal models (e.g. Llava-1.5) whose image processors expect
    a single flat list rather than a nested one.

    :param images: A list of lists of PIL images.
    :return: a list of PIL images, via concatenating all the sub-lists in order.
    """
    flattened = []
    for image_group in images:
        flattened.extend(image_group)
    return flattened
|
| 719 |
+
|
| 720 |
+
|
| 721 |
+
def handle_stop_sequences(
    until: Union[str, List[str], None], eos: Optional[str]
) -> List[str]:
    """Normalize `until` to a list of stop sequences, appending `eos` if absent."""
    if until is None:
        until = []
    elif isinstance(until, str):
        until = [until]
    elif not isinstance(until, list):
        raise ValueError(
            f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
        )

    # Add the EOS token as an implicit stop sequence, avoiding duplicates.
    if eos is not None and eos not in until:
        until.append(eos)
    return until
|
| 737 |
+
|
| 738 |
+
|
| 739 |
+
def resize_image(
    image: "Image.Image",
    width: Optional[int] = None,
    height: Optional[int] = None,
    max_dimension: Optional[int] = None,
    keep_aspect_ratio: bool = True,
    resample_filter: Union[int, str] = "Image.BICUBIC",
    min_width: int = 1,
    min_height: int = 1,
) -> "Image.Image":
    """
    Resizes a PIL Image object with flexible options.

    Args:
        image: The PIL Image object to resize.
        width: Target width in pixels.
        height: Target height in pixels.
        max_dimension: Maximum size for the longer dimension of the image.
        keep_aspect_ratio: If True (default) and both width and height are provided,
                           the image is resized to fit within these dimensions while
                           maintaining its aspect ratio. If False, the image is stretched
                           to the exact width and height.
        resample_filter: The resampling filter to use for resizing. Accepts either
                         a PIL integer constant or a string name such as
                         "Image.BICUBIC". Defaults to Image.BICUBIC.
        min_width: Minimum width for the resized image. Defaults to 1.
        min_height: Minimum height for the resized image. Defaults to 1.

    Returns:
        The resized PIL Image object. If no resize parameters are provided
        or if the image already meets the criteria, the original image is returned.

    Order of precedence for resizing:
        1. If width AND height are provided:
           - If keep_aspect_ratio is True: Fits image within bounds, preserving aspect ratio.
           - If keep_aspect_ratio is False: Resizes to exact dimensions (may distort).
        2. Else if only width is provided: Calculates height proportionally.
        3. Else if only height is provided: Calculates width proportionally.
        4. Else if max_dimension is provided: Resizes the longest side to max_dimension
           and scales the other side proportionally.
        5. If none of the above are provided, returns the original image.
    """
    original_width, original_height = image.size

    # If no arguments are provided, return the original image.
    if width is None and height is None and max_dimension is None:
        return image

    new_width = original_width
    new_height = original_height

    if width is not None and height is not None:
        # No resize needed if image is already smaller than target dimensions.
        if original_width <= width and original_height <= height:
            return image

        if keep_aspect_ratio:
            # Scale by the tighter of the two bounds so both fit.
            ratio = min(width / original_width, height / original_height)
            new_width = int(original_width * ratio)
            new_height = int(original_height * ratio)
        else:
            # Stretch to exact dimensions (may distort).
            new_width = width
            new_height = height
    elif width is not None:
        # No resize needed if width is already smaller.
        if original_width <= width:
            return image
        # Calculate height proportionally.
        new_width = width
        new_height = int((original_height / original_width) * new_width)
    elif height is not None:
        # No resize needed if height is already smaller.
        if original_height <= height:
            return image
        # Calculate width proportionally.
        new_height = height
        new_width = int((original_width / original_height) * new_height)
    elif max_dimension is not None:
        # No resize needed if both dimensions are smaller than max_dimension.
        if max(original_height, original_width) <= max_dimension:
            return image

        if original_width > original_height:
            # Width is the longer side.
            new_width = max_dimension
            new_height = int((original_height / original_width) * new_width)
        else:
            # Height is the longer side or sides are equal.
            new_height = max_dimension
            new_width = int((original_width / original_height) * new_height)

    # Ensure dimensions are at least minimum values.
    new_width = max(min_width, new_width)
    new_height = max(min_height, new_height)

    # BUG FIX: the default resample filter was the *string* "Image.BICUBIC",
    # which PIL's Image.resize does not accept (it expects an integer/enum
    # constant), so every actual resize with the default would raise. Resolve
    # string names to the corresponding PIL constant, falling back to the
    # numeric value of Image.BICUBIC (3) if PIL is unavailable or the name is
    # unknown.
    if isinstance(resample_filter, str):
        try:
            from PIL import Image

            resample_filter = getattr(Image, resample_filter.rsplit(".", 1)[-1])
        except (ImportError, AttributeError):
            resample_filter = 3  # numeric value of PIL's Image.BICUBIC

    # Perform the resize operation with the calculated dimensions.
    return image.resize((new_width, new_height), resample_filter)
|
| 837 |
+
|
| 838 |
+
|
| 839 |
+
def truncate_tokens(
    tokens: List[int],
    max_length: int,
    tokenizer: "PreTrainedTokenizerBase",
    strategy: str = "left",
) -> List[int]:
    """Truncate a token-id sequence to at most `max_length` tokens.

    Args:
        tokens: The token ids to truncate.
        max_length: Maximum number of tokens to keep.
        tokenizer: Kept for interface compatibility; not used here.
        strategy: "left" keeps the last tokens, "right" keeps the first,
            "middle" keeps both ends and drops the middle.

    Returns:
        The truncated token list.

    Raises:
        ValueError: for an unknown strategy (previously this silently
            returned None, hiding caller typos).
    """
    # BUG FIX: tokens[-0:] returns the *whole* list, so max_length == 0 with
    # the "left" strategy previously returned everything instead of nothing.
    if max_length <= 0:
        return []
    if strategy == "left":
        return tokens[-max_length:]
    elif strategy == "right":
        return tokens[:max_length]
    elif strategy == "middle":
        # BUG FIX: when the sequence already fits, the head and tail slices
        # overlap and the old code returned duplicated tokens.
        if len(tokens) <= max_length:
            return tokens[:]
        # Keep the first half from the front and the rest from the back.
        left_length = max_length // 2
        right_length = max_length - left_length
        return tokens[:left_length] + tokens[-right_length:]
    raise ValueError(f"Unknown truncation strategy: {strategy!r}")
|
Prism/LLaDA/LLaDA_Baseline/dllm_eval/models/verifier.py
ADDED
|
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import logging
|
| 3 |
+
import ast
|
| 4 |
+
import re
|
| 5 |
+
import numpy as np
|
| 6 |
+
import textwrap
|
| 7 |
+
|
| 8 |
+
logger = logging.getLogger(__name__)
|
| 9 |
+
|
| 10 |
+
class CodeVerifier:
    """LM-as-judge reward model for generated code / math / reading answers.

    Wraps a causal LM and its tokenizer and exposes two reward signals:
    - `compute_confidence`: geometric mean of per-position max probabilities
      over supplied logits.
    - `svf_score`: self-verification probability, i.e. P("Yes") vs P("No")
      on a judge prompt built around the candidate answer.
    """

    def __init__(self, model, tokenizer, device="cuda"):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device

        # Collect the last token id of several surface forms of Yes/No so the
        # score is robust to tokenizer-specific leading-space handling.
        self.yes_ids, self.no_ids = [], []
        for t in ["Yes", " Yes", "YES"]:
            ids = self.tokenizer.encode(t, add_special_tokens=False)
            if len(ids) > 0:
                self.yes_ids.append(ids[-1])
        for t in ["No", " No", "NO"]:
            ids = self.tokenizer.encode(t, add_special_tokens=False)
            if len(ids) > 0:
                self.no_ids.append(ids[-1])

        self.yes_ids = list(set(self.yes_ids))
        self.no_ids = list(set(self.no_ids))

    def _extract_python_code(self, text):
        """Return the contents of the first ```python (or generic ```) fence,
        or the stripped text when no fence is present."""
        text = text.strip()
        match = re.search(r"```python\s*(.*?)```", text, re.DOTALL)
        if match:
            return match.group(1)
        match_generic = re.search(r"```\s*(.*?)```", text, re.DOTALL)
        if match_generic:
            return match_generic.group(1)
        return text

    def check_syntax(self, code_str):
        """True when the (de-fenced) code is non-trivial and parses as Python."""
        clean_code = self._extract_python_code(code_str)
        try:
            # Fewer than 5 characters cannot be a meaningful solution.
            if len(clean_code.strip()) < 5:
                return False
            ast.parse(clean_code)
            return True
        # BUG FIX: was a bare `except:`, which also swallowed
        # KeyboardInterrupt/SystemExit. ast.parse raises SyntaxError,
        # ValueError (e.g. null bytes) and, pathologically, RecursionError.
        except (SyntaxError, ValueError, RecursionError):
            return False

    def compute_confidence(self, logits):
        """Geometric mean of the per-position max softmax probability (0.0 if None)."""
        if logits is None:
            return 0.0
        probs = torch.softmax(logits, dim=-1)
        max_probs, _ = torch.max(probs, dim=-1)
        # 1e-10 guards log(0) for fully-saturated positions.
        log_probs = torch.log(max_probs + 1e-10)
        return torch.exp(torch.mean(log_probs)).item()

    def svf_score(self, prompt, code_str, task_type="code"):
        """Self-verification score in [0, 1]: softmax of the judge's Yes vs No logits.

        Args:
            prompt: the original problem statement / context.
            code_str: the candidate solution text.
            task_type: "code", "math", "reasoning", or anything else for a
                generic yes/no prompt.
        """
        # Cap the candidate text so the judge prompt stays within context.
        max_len = 2000
        if len(code_str) > max_len:
            if task_type == "reasoning":
                # Keep both ends: the head carries context, the tail the answer.
                truncated_code = code_str[:500] + "\n...[truncated]...\n" + code_str[-(max_len-500):]
            else:
                truncated_code = code_str[-max_len:]
        else:
            truncated_code = code_str

        if task_type == "code":
            prompt_template = f"""
You are an expert programming contest judge. Your task is to evaluate a generated solution for a given problem based on correctness, efficiency, and adherence to constraints.

[Problem Statement]
{prompt}
[/Problem Statement]

[Proposed Python Solution]
```python
{truncated_code}
```
[/Proposed Python Solution]

**Analysis Steps:**
1. Correctness: Does the core algorithm correctly solve the problem?
2. Efficiency: Is the time complexity acceptable for the given constraints?
3. Edge Cases & Constraints: Does the code handle all rules and edge cases?

**Conclusion**: Based on your analysis, is the solution likely to be fully correct? Answer with a single word: Yes or No.
**Answer:** """

        elif task_type == "math":
            prompt_template = f"""
You are an expert mathematician and competition judge. Your task is to evaluate a proposed mathematical solution for a given problem based on its logical rigor and accuracy.

[Math Problem]
{prompt}
[/Math Problem]

[Proposed Mathematical Solution]
{truncated_code}
[/Proposed Mathematical Solution]

**Analysis Steps:**
1. Reasoning Validity: Are the logical steps and mathematical properties applied correctly?
2. Calculation Accuracy: Are the intermediate calculations or algebraic manipulations accurate?
3. Goal Alignment: Does the current reasoning path directly lead toward the final answer required by the problem?

**Conclusion**: Based on your analysis, is this solution path sound and likely to result in the correct final answer? Answer with a single word: Yes or No.
**Answer:** """

        elif task_type == "reasoning":
            prompt_template = f"""
You are an expert reading comprehension and faithfulness judge. Your task is to evaluate a generated answer based on the provided context and question.

[Context and Question]
{prompt}
[/Context and Question]

[Proposed Answer]
{truncated_code}
[/Proposed Answer]

**Analysis Steps :**
1. Faithfulness: Is the answer an exact, literal span from the context?
2. Relevance: Does the answer directly address the specific question asked without hallucinating external information?
3. Accuracy: Does the provided context strictly support this answer?

**Conclusion**: Based on your analysis, is the answer fully faithful to the context and correct? Answer with a single word: Yes or No.
**Answer:** """

        else:
            prompt_template = f"Is the following answer correct?\nQuestion: {prompt}\nAnswer: {truncated_code}\nAnswer Yes or No.\nAnswer:"

        verify_text = textwrap.dedent(prompt_template).strip()
        input_ids = self.tokenizer(verify_text, return_tensors="pt").input_ids.to(self.device)

        # Respect the model's positional limit; configs name it differently.
        max_pos = getattr(self.model.config, "max_position_embeddings",
                  getattr(self.model.config, "n_positions",
                  getattr(self.model.config, "max_sequence_length", 20480)))

        if input_ids.shape[1] > max_pos - 16:
            logger.warning("Verifier input is too long, truncating from the left.")
            input_ids = input_ids[:, -(max_pos - 16):]

        with torch.no_grad():
            outputs = self.model(input_ids)
            logits = outputs.logits[0, -1, :]

        # Best logit among the Yes/No variants; guard out-of-vocab ids.
        yes_score = max((logits[i].item() for i in self.yes_ids if i < logits.shape[-1]), default=-float('inf'))
        no_score = max((logits[i].item() for i in self.no_ids if i < logits.shape[-1]), default=-float('inf'))

        # Neither token representable: fall back to a neutral score.
        if yes_score == -float('inf') and no_score == -float('inf'):
            return 0.5

        probs = torch.softmax(torch.tensor([yes_score, no_score]), dim=0)
        return probs[0].item()

    def get_reward(self, prompt, code_str, mode="confidence", problem_data=None, current_logits=None, task_type="code"):
        """Dispatch to `svf_score` (mode == "svf") or `compute_confidence` (default)."""
        if mode == "svf":
            return self.svf_score(prompt, code_str, task_type=task_type)
        else:
            return self.compute_confidence(current_logits)
|
Prism/LLaDA/LLaDA_Baseline/dllm_eval/prompts/__init__.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import ast
|
| 2 |
+
import logging
|
| 3 |
+
import os
|
| 4 |
+
from typing import Dict
|
| 5 |
+
|
| 6 |
+
from dllm_eval import utils
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
eval_logger = logging.getLogger(__name__)
|
| 10 |
+
|
| 11 |
+
# Prompt library.
# Stores prompts in a dictionary indexed by 2 levels:
# prompt category name, and prompt name.
# This allows us to access prompts via get_prompt("<category>:<name>"),
# e.g. get_prompt("qa-basic:q-newline-a").
# NOTE(review): the {{question}} markers look like template placeholders,
# presumably rendered via utils.apply_template — confirm against the
# template engine used downstream.
PROMPT_REGISTRY: Dict[str, Dict[str, str]] = {
    "qa-basic": {
        "question-newline-answer": "Question: {{question}}\nAnswer:",
        "q-newline-a": "Q: {{question}}\nA:",
    },
}
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def get_prompt(prompt_id: str, dataset_name: str = None, subset_name: str = None):
    """Look up a prompt template by a "<category>:<name>" identifier.

    Supported categories:
    - "promptsource": load the named promptsource template for the dataset.
    - a path containing ".yaml": load from that file's "prompts" section,
      wrapped in PromptString.
    - otherwise: a category key of PROMPT_REGISTRY.

    Args:
        prompt_id: identifier of the form "<category>:<prompt_name>".
        dataset_name: dataset the prompt applies to (promptsource only).
        subset_name: optional dataset subset (promptsource only).

    Raises:
        ValueError: when the dataset/prompt cannot be found or the id is malformed.
    """
    # unpack prompt name
    category_name, prompt_name = prompt_id.split(":")
    if subset_name is None:
        dataset_full_name = dataset_name
    else:
        dataset_full_name = f"{dataset_name}-{subset_name}"
    eval_logger.info(f"Loading prompt from {category_name} for {dataset_full_name}")
    if category_name == "promptsource":
        try:
            from promptsource.templates import DatasetTemplates
        except ModuleNotFoundError as exception:
            # BUG FIX: the two message fragments were previously passed as two
            # separate constructor arguments (a stray comma instead of string
            # concatenation), so the raised exception carried a tuple instead
            # of one readable message.
            raise type(exception)(
                "Tried to load a Promptsource template, but promptsource is not installed "
                "please install promptsource via pip install lm-eval[promptsource] or pip install -e .[promptsource]"
            )
        try:
            if subset_name is None:
                prompts = DatasetTemplates(dataset_name=dataset_name)
            else:
                prompts = DatasetTemplates(
                    dataset_name=dataset_name, subset_name=subset_name
                )
        except Exception:
            raise ValueError(f"{dataset_name} and {subset_name} not found")
        if prompt_name in prompts.all_template_names:
            return prompts[prompt_name]
        else:
            raise ValueError(
                f"{prompt_name} not in prompt list {prompts.all_template_names}"
            )
    elif ".yaml" in category_name:
        import yaml

        with open(category_name, "rb") as file:
            prompt_yaml_file = yaml.full_load(file)

        prompt_string = prompt_yaml_file["prompts"][prompt_name]
        return PromptString(prompt_string)
    else:
        try:
            return PROMPT_REGISTRY[category_name][prompt_name]
        except Exception:
            raise ValueError(
                f"expected only a single `:` as separator between "
                f"prompt category and name, but got `{prompt_id}` instead"
            )
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def load_prompt_list(
    use_prompt: str, dataset_name=None, subset_name=None, yaml_path=None, **kwargs
):
    """Expand a (possibly wildcarded) "<category>:<pattern>" spec into a list of
    concrete "<category>:<prompt>" identifiers.

    For the "promptsource" category the pattern is matched against the
    dataset's template names; for a ".yaml" category it is matched against the
    keys of that file's "prompts" section (resolved relative to `yaml_path`
    when given).
    """
    category_name, prompt_name = use_prompt.split(":")

    if category_name == "promptsource":
        from promptsource.templates import DatasetTemplates

        if subset_name is None:
            templates = DatasetTemplates(dataset_name=dataset_name)
        else:
            templates = DatasetTemplates(
                dataset_name=dataset_name, subset_name=subset_name
            )
        prompt_list = utils.pattern_match(prompt_name, templates.all_template_names)

    elif ".yaml" in category_name:
        import yaml

        if yaml_path is not None:
            category_name = os.path.realpath(os.path.join(yaml_path, category_name))

        with open(category_name, "rb") as file:
            yaml_contents = yaml.full_load(file)

        prompt_list = utils.pattern_match(prompt_name, yaml_contents["prompts"].keys())

    return [f"{category_name}:{prompt}" for prompt in prompt_list]
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
class PromptString:
    """Minimal prompt template backed by a dict holding "doc_to_text" and
    "doc_to_target" template strings."""

    def __init__(self, prompt_string):
        self.prompt_string = prompt_string

    def apply(self, doc):
        """Render the text and target templates against `doc`, returning
        [text, target]."""
        text_template = self.prompt_string["doc_to_text"]
        target_template = self.prompt_string["doc_to_target"]

        # "doc_to_choice" templates are not supported yet.
        if "doc_to_choice" in self.prompt_string:
            raise NotImplementedError("Not yet implemented to accept doc_to_choice")

        return [
            utils.apply_template(text_template, doc),
            utils.apply_template(target_template, doc),
        ]
|
Prism/LLaDA/LLaDA_Baseline/dllm_eval/tasks/__init__.py
ADDED
|
@@ -0,0 +1,670 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import collections
|
| 2 |
+
import inspect
|
| 3 |
+
import logging
|
| 4 |
+
import os
|
| 5 |
+
from functools import partial
|
| 6 |
+
from typing import Dict, List, Mapping, Optional, Union
|
| 7 |
+
|
| 8 |
+
from dllm_eval import utils
|
| 9 |
+
from dllm_eval.api.group import ConfigurableGroup, GroupConfig
|
| 10 |
+
from dllm_eval.api.task import ConfigurableTask, Task
|
| 11 |
+
from dllm_eval.evaluator_utils import get_subtask_list
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
GROUP_ONLY_KEYS = list(GroupConfig().to_dict().keys())
|
| 15 |
+
|
| 16 |
+
eval_logger = logging.getLogger(__name__)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class TaskManager:
|
| 20 |
+
"""TaskManager indexes all tasks from the default `dllm_eval/tasks/`
|
| 21 |
+
and an optional directory if provided.
|
| 22 |
+
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
def __init__(
|
| 26 |
+
self,
|
| 27 |
+
verbosity: Optional[str] = None,
|
| 28 |
+
include_path: Optional[Union[str, List]] = None,
|
| 29 |
+
include_defaults: bool = True,
|
| 30 |
+
metadata: Optional[dict] = None,
|
| 31 |
+
) -> None:
|
| 32 |
+
if verbosity is not None:
|
| 33 |
+
utils.setup_logging(verbosity)
|
| 34 |
+
self.include_path = include_path
|
| 35 |
+
self.metadata = metadata
|
| 36 |
+
self._task_index = self.initialize_tasks(
|
| 37 |
+
include_path=include_path, include_defaults=include_defaults
|
| 38 |
+
)
|
| 39 |
+
self._all_tasks = sorted(list(self._task_index.keys()))
|
| 40 |
+
|
| 41 |
+
self._all_groups = sorted(
|
| 42 |
+
[x for x in self._all_tasks if self._task_index[x]["type"] == "group"]
|
| 43 |
+
)
|
| 44 |
+
self._all_subtasks = sorted(
|
| 45 |
+
[
|
| 46 |
+
x
|
| 47 |
+
for x in self._all_tasks
|
| 48 |
+
if self._task_index[x]["type"] in ["task", "python_task"]
|
| 49 |
+
]
|
| 50 |
+
)
|
| 51 |
+
self._all_tags = sorted(
|
| 52 |
+
[x for x in self._all_tasks if self._task_index[x]["type"] == "tag"]
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
self.task_group_map = collections.defaultdict(list)
|
| 56 |
+
|
| 57 |
+
def initialize_tasks(
|
| 58 |
+
self,
|
| 59 |
+
include_path: Optional[Union[str, List]] = None,
|
| 60 |
+
include_defaults: bool = True,
|
| 61 |
+
) -> dict[str, dict]:
|
| 62 |
+
"""Creates a dictionary of tasks indexes.
|
| 63 |
+
|
| 64 |
+
:param include_path: Union[str, List] = None
|
| 65 |
+
An additional path to be searched for tasks recursively.
|
| 66 |
+
Can provide more than one such path as a list.
|
| 67 |
+
:param include_defaults: bool = True
|
| 68 |
+
If set to false, default tasks (those in dllm_eval/tasks/) are not indexed.
|
| 69 |
+
return
|
| 70 |
+
Dictionary of task names as key and task metadata
|
| 71 |
+
"""
|
| 72 |
+
if include_defaults:
|
| 73 |
+
all_paths = [os.path.dirname(os.path.abspath(__file__)) + "/"]
|
| 74 |
+
else:
|
| 75 |
+
all_paths = []
|
| 76 |
+
if include_path is not None:
|
| 77 |
+
if isinstance(include_path, str):
|
| 78 |
+
include_path = [include_path]
|
| 79 |
+
all_paths.extend(include_path)
|
| 80 |
+
|
| 81 |
+
task_index = {}
|
| 82 |
+
for task_dir in all_paths:
|
| 83 |
+
tasks = self._get_task_and_group(task_dir)
|
| 84 |
+
task_index = {**tasks, **task_index}
|
| 85 |
+
|
| 86 |
+
return task_index
|
| 87 |
+
|
| 88 |
+
    @property
    def all_tasks(self):
        """Sorted names of every indexed entry (groups, tags, and tasks alike)."""
        return self._all_tasks

    @property
    def all_groups(self):
        """Sorted names of entries indexed with type "group"."""
        return self._all_groups

    @property
    def all_subtasks(self):
        """Sorted names of entries indexed with type "task" or "python_task"."""
        return self._all_subtasks

    @property
    def all_tags(self):
        """Sorted names of entries indexed with type "tag"."""
        return self._all_tags

    @property
    def task_index(self):
        """Mapping of entry name to its metadata dict (includes "type" and "yaml_path")."""
        return self._task_index
|
| 107 |
+
|
| 108 |
+
def list_all_tasks(
|
| 109 |
+
self, list_groups=True, list_tags=True, list_subtasks=True
|
| 110 |
+
) -> str:
|
| 111 |
+
from pytablewriter import MarkdownTableWriter
|
| 112 |
+
|
| 113 |
+
def sanitize_path(path):
|
| 114 |
+
# don't print full path if we are within the dllm_eval/tasks dir !
|
| 115 |
+
# if we aren't though, provide the full path.
|
| 116 |
+
if "dllm_eval/tasks/" in path:
|
| 117 |
+
return "dllm_eval/tasks/" + path.split("dllm_eval/tasks/")[-1]
|
| 118 |
+
else:
|
| 119 |
+
return path
|
| 120 |
+
|
| 121 |
+
group_table = MarkdownTableWriter()
|
| 122 |
+
group_table.headers = ["Group", "Config Location"]
|
| 123 |
+
gt_values = []
|
| 124 |
+
for g in self.all_groups:
|
| 125 |
+
path = self.task_index[g]["yaml_path"]
|
| 126 |
+
if path == -1:
|
| 127 |
+
path = "---"
|
| 128 |
+
else:
|
| 129 |
+
path = sanitize_path(path)
|
| 130 |
+
gt_values.append([g, path])
|
| 131 |
+
group_table.value_matrix = gt_values
|
| 132 |
+
|
| 133 |
+
tag_table = MarkdownTableWriter()
|
| 134 |
+
tag_table.headers = ["Tag"]
|
| 135 |
+
tag_table.value_matrix = [[t] for t in self.all_tags]
|
| 136 |
+
|
| 137 |
+
subtask_table = MarkdownTableWriter()
|
| 138 |
+
subtask_table.headers = ["Task", "Config Location", "Output Type"]
|
| 139 |
+
st_values = []
|
| 140 |
+
for t in self.all_subtasks:
|
| 141 |
+
path = self.task_index[t]["yaml_path"]
|
| 142 |
+
|
| 143 |
+
output_type = ""
|
| 144 |
+
|
| 145 |
+
# read the yaml file to determine the output type
|
| 146 |
+
if path != -1:
|
| 147 |
+
config = utils.load_yaml_config(path, mode="simple")
|
| 148 |
+
if "output_type" in config:
|
| 149 |
+
output_type = config["output_type"]
|
| 150 |
+
elif (
|
| 151 |
+
"include" in config
|
| 152 |
+
): # if no output type, check if there is an include with an output type
|
| 153 |
+
include_path = path.split("/")[:-1] + config["include"]
|
| 154 |
+
include_config = utils.load_yaml_config(include_path, mode="simple")
|
| 155 |
+
if "output_type" in include_config:
|
| 156 |
+
output_type = include_config["output_type"]
|
| 157 |
+
|
| 158 |
+
if path == -1:
|
| 159 |
+
path = "---"
|
| 160 |
+
else:
|
| 161 |
+
path = sanitize_path(path)
|
| 162 |
+
st_values.append([t, path, output_type])
|
| 163 |
+
subtask_table.value_matrix = st_values
|
| 164 |
+
|
| 165 |
+
result = "\n"
|
| 166 |
+
if list_groups:
|
| 167 |
+
result += group_table.dumps() + "\n\n"
|
| 168 |
+
if list_tags:
|
| 169 |
+
result += tag_table.dumps() + "\n\n"
|
| 170 |
+
if list_subtasks:
|
| 171 |
+
result += subtask_table.dumps() + "\n\n"
|
| 172 |
+
return result
|
| 173 |
+
|
| 174 |
+
def match_tasks(self, task_list: list[str]) -> list[str]:
|
| 175 |
+
return utils.pattern_match(task_list, self.all_tasks)
|
| 176 |
+
|
| 177 |
+
def _name_is_registered(self, name: str) -> bool:
|
| 178 |
+
if name in self.all_tasks:
|
| 179 |
+
return True
|
| 180 |
+
return False
|
| 181 |
+
|
| 182 |
+
def _name_is_task(self, name: str) -> bool:
|
| 183 |
+
if self._name_is_registered(name) and (self.task_index[name]["type"] == "task"):
|
| 184 |
+
return True
|
| 185 |
+
return False
|
| 186 |
+
|
| 187 |
+
def _name_is_tag(self, name: str) -> bool:
|
| 188 |
+
if self._name_is_registered(name) and (self.task_index[name]["type"] == "tag"):
|
| 189 |
+
return True
|
| 190 |
+
return False
|
| 191 |
+
|
| 192 |
+
def _name_is_group(self, name: str) -> bool:
|
| 193 |
+
if self._name_is_registered(name) and (
|
| 194 |
+
self.task_index[name]["type"] == "group"
|
| 195 |
+
):
|
| 196 |
+
return True
|
| 197 |
+
return False
|
| 198 |
+
|
| 199 |
+
def _name_is_python_task(self, name: str) -> bool:
|
| 200 |
+
if self._name_is_registered(name) and (
|
| 201 |
+
self.task_index[name]["type"] == "python_task"
|
| 202 |
+
):
|
| 203 |
+
return True
|
| 204 |
+
return False
|
| 205 |
+
|
| 206 |
+
def _config_is_task(self, config: dict) -> bool:
|
| 207 |
+
if ("task" in config) and isinstance(config["task"], str):
|
| 208 |
+
return True
|
| 209 |
+
return False
|
| 210 |
+
|
| 211 |
+
def _config_is_group(self, config: dict) -> bool:
|
| 212 |
+
if ("task" in config) and isinstance(config["task"], list):
|
| 213 |
+
return True
|
| 214 |
+
return False
|
| 215 |
+
|
| 216 |
+
def _config_is_python_task(self, config: dict) -> bool:
|
| 217 |
+
if "class" in config:
|
| 218 |
+
return True
|
| 219 |
+
return False
|
| 220 |
+
|
| 221 |
+
def _get_yaml_path(self, name: str):
|
| 222 |
+
if name not in self.task_index:
|
| 223 |
+
raise ValueError
|
| 224 |
+
return self.task_index[name]["yaml_path"]
|
| 225 |
+
|
| 226 |
+
def _get_config(self, name):
|
| 227 |
+
if name not in self.task_index:
|
| 228 |
+
raise ValueError
|
| 229 |
+
yaml_path = self._get_yaml_path(name)
|
| 230 |
+
if yaml_path == -1:
|
| 231 |
+
return {}
|
| 232 |
+
else:
|
| 233 |
+
return utils.load_yaml_config(yaml_path, mode="full")
|
| 234 |
+
|
| 235 |
+
def _get_tasklist(self, name):
|
| 236 |
+
if self._name_is_task(name):
|
| 237 |
+
raise ValueError
|
| 238 |
+
return self.task_index[name]["task"]
|
| 239 |
+
|
| 240 |
+
def _process_alias(self, config, group=None):
|
| 241 |
+
# If the group is not the same as the original
|
| 242 |
+
# group which the group alias was intended for,
|
| 243 |
+
# Set the group_alias to None instead.
|
| 244 |
+
if ("group_alias" in config) and ("group" in config) and group is not None:
|
| 245 |
+
if config["group"] != group:
|
| 246 |
+
config["group_alias"] = None
|
| 247 |
+
return config
|
| 248 |
+
|
| 249 |
+
def _class_has_config_in_constructor(self, cls):
|
| 250 |
+
constructor = getattr(cls, "__init__", None)
|
| 251 |
+
return (
|
| 252 |
+
"config" in inspect.signature(constructor).parameters
|
| 253 |
+
if constructor
|
| 254 |
+
else False
|
| 255 |
+
)
|
| 256 |
+
|
| 257 |
+
    def _load_individual_task_or_group(
        self,
        name_or_config: Optional[Union[str, dict]] = None,
        parent_name: Optional[str] = None,
        update_config: Optional[dict] = None,
    ) -> Mapping:
        """Resolve one task/group/tag reference into {name: task_object} pairs.

        Accepts either a registered name (str) or an inline config (dict),
        recursing through tags and groups. `parent_name` is set when called
        for a group's subtasks (used for duplicate-name disambiguation);
        `update_config` carries overrides to merge into each resolved config.
        """

        def _load_task(config, task):
            # Build a single task object from a fully-merged config.
            if "include" in config:
                # Resolve the included yaml first; explicit keys win over it.
                config = {
                    **utils.load_yaml_config(
                        yaml_path=None,
                        yaml_config={"include": config.pop("include")},
                        mode="full",
                    ),
                    **config,
                }
            if self._config_is_python_task(config):
                # Instantiate the user-supplied class, passing config only if
                # its constructor accepts one.
                if self._class_has_config_in_constructor(config["class"]):
                    task_object = config["class"](config=config)
                else:
                    task_object = config["class"]()
                if isinstance(task_object, ConfigurableTask):
                    # very scuffed: set task name here. TODO: fixme?
                    task_object.config.task = task
            else:
                # Declarative task: merge manager-level metadata (manager
                # values win via dict union) and build a ConfigurableTask.
                if self.metadata is not None:
                    config["metadata"] = config.get("metadata", {}) | self.metadata
                else:
                    config["metadata"] = config.get("metadata", {})
                task_object = ConfigurableTask(config=config)

            return {task: task_object}

        def _get_group_and_subtask_from_config(
            config: dict,
        ) -> tuple[ConfigurableGroup, list[str]]:
            # Wrap a group config in a ConfigurableGroup and flatten any tag
            # entries in its task list into their member tasks.
            if self.metadata is not None:
                config["metadata"] = config.get("metadata", {}) | self.metadata
            group_name = ConfigurableGroup(config=config)
            subtask_list = []
            for task in group_name.config["task"]:
                if isinstance(task, str) and self._name_is_tag(task):
                    subtask_list.extend(self._get_tasklist(task))
                else:
                    subtask_list.append(task)
            return group_name, subtask_list

        def _process_group_config(
            config: dict, update_config: dict = None
        ) -> tuple[dict, dict]:
            # Split a (possibly override-merged) config into the keys that
            # belong to the group itself vs. those pushed down to subtasks.
            if update_config is not None:
                config = {**config, **update_config}
            _update_config = {
                k: v for k, v in config.items() if k not in GROUP_ONLY_KEYS
            }
            if not bool(_update_config):
                _update_config = None

            group_config = {k: v for k, v in config.items() if k in GROUP_ONLY_KEYS}
            return group_config, _update_config

        if isinstance(name_or_config, str):
            if update_config is not None:
                # Process name_or_config as a dict instead
                name_or_config = {"task": name_or_config, **update_config}
            elif self._name_is_task(name_or_config) or self._name_is_python_task(
                name_or_config
            ):
                task_config = self._get_config(name_or_config)
                return _load_task(task_config, task=name_or_config)
            else:
                subtask_list = self._get_tasklist(name_or_config)
                if subtask_list == -1:
                    # Group indexed lazily from its own yaml: load it now.
                    group_config = self._get_config(name_or_config)
                    group_config, update_config = _process_group_config(group_config)
                    group_name, subtask_list = _get_group_and_subtask_from_config(
                        group_config
                    )
                else:
                    if self._name_is_tag(name_or_config):
                        # Tags fan out to their member tasks with no group node.
                        fn = partial(
                            self._load_individual_task_or_group,
                            update_config=name_or_config
                            if isinstance(name_or_config, dict)
                            else None,
                        )
                        return dict(
                            collections.ChainMap(*map(fn, reversed(subtask_list)))
                        )
                    else:
                        group_name = ConfigurableGroup(
                            config={"group": name_or_config, "task": subtask_list}
                        )

        if isinstance(name_or_config, dict):
            if self._config_is_task(name_or_config):
                name = name_or_config.pop("task")
                if update_config is not None:
                    name_or_config = {**name_or_config, **update_config}
                # If the name is registered as a group
                if self._name_is_group(name):
                    group_config = self._get_config(name)

                    group_config, update_config = _process_group_config(
                        group_config, name_or_config
                    )
                    group_name, subtask_list = _get_group_and_subtask_from_config(
                        group_config
                    )
                elif self._name_is_tag(name):
                    subtask_list = self._get_tasklist(name)
                    fn = partial(
                        self._load_individual_task_or_group,
                        update_config=name_or_config,
                    )
                    return dict(collections.ChainMap(*map(fn, reversed(subtask_list))))
                else:
                    if self._name_is_registered(name):
                        base_task_config = self._get_config(name)

                        # Check if this is a duplicate.
                        if parent_name is not None:
                            num_duplicate = len(
                                list(
                                    filter(
                                        lambda x: x.startswith(name),
                                        self.task_group_map[parent_name],
                                    )
                                )
                            )
                            if num_duplicate > 0:
                                # Same task requested twice under one parent:
                                # give this instance a suffixed name.
                                name = f"{name}-{num_duplicate}"
                            self.task_group_map[parent_name].append(name)

                        task_config = {
                            **base_task_config,
                            **name_or_config,
                        }
                    else:
                        task_config = name_or_config
                    return _load_task(task_config, task=name)
            else:
                group_config, update_config = _process_group_config(name_or_config)
                group_name, subtask_list = _get_group_and_subtask_from_config(
                    group_config
                )

        # Group path: recurse into each subtask; ChainMap over the reversed
        # list gives earlier subtasks precedence on key collisions.
        fn = partial(
            self._load_individual_task_or_group,
            parent_name=group_name,
            update_config=update_config,
        )
        return {
            group_name: dict(collections.ChainMap(*map(fn, reversed(subtask_list))))
        }
|
| 412 |
+
|
| 413 |
+
def load_task_or_group(self, task_list: Optional[Union[str, list]] = None) -> dict:
|
| 414 |
+
"""Loads a dictionary of task objects from a list
|
| 415 |
+
|
| 416 |
+
:param task_list: Union[str, list] = None
|
| 417 |
+
Single string or list of string of task names to be loaded
|
| 418 |
+
|
| 419 |
+
:return
|
| 420 |
+
Dictionary of task objects
|
| 421 |
+
"""
|
| 422 |
+
if isinstance(task_list, str):
|
| 423 |
+
task_list = [task_list]
|
| 424 |
+
|
| 425 |
+
all_loaded_tasks = dict(
|
| 426 |
+
collections.ChainMap(
|
| 427 |
+
*map(
|
| 428 |
+
lambda task: self._load_individual_task_or_group(task),
|
| 429 |
+
task_list,
|
| 430 |
+
)
|
| 431 |
+
)
|
| 432 |
+
)
|
| 433 |
+
return all_loaded_tasks
|
| 434 |
+
|
| 435 |
+
def load_config(self, config: Dict):
|
| 436 |
+
return self._load_individual_task_or_group(config)
|
| 437 |
+
|
| 438 |
+
def _get_task_and_group(self, task_dir: str):
|
| 439 |
+
"""Creates a dictionary of tasks index with the following metadata,
|
| 440 |
+
- `type`, that can be either `task`, `python_task`, `group` or `tags`.
|
| 441 |
+
`task` refer to regular task configs, `python_task` are special
|
| 442 |
+
yaml files that only consists of `task` and `class` parameters.
|
| 443 |
+
`group` are group configs. `tags` are labels that can be assigned
|
| 444 |
+
to tasks to assist in sorting and calling tasks of certain themes.
|
| 445 |
+
- `yaml_path`, path to the yaml file. If the entry is a `group` that
|
| 446 |
+
was configured through a task config, the yaml_path will be -1
|
| 447 |
+
and all subtasks will be listed in `task` (see below)
|
| 448 |
+
- `task`, reserved for entries with `type` as `group`. This will list
|
| 449 |
+
all subtasks. When a group config is created (as opposed to task
|
| 450 |
+
config having `group` parameter set), this will be set to -1 to
|
| 451 |
+
avoid recursive indexing. The whole list of subtasks will be loaded
|
| 452 |
+
at evaluation.
|
| 453 |
+
|
| 454 |
+
:param task_dir: str
|
| 455 |
+
A directory to check for tasks
|
| 456 |
+
|
| 457 |
+
:return
|
| 458 |
+
Dictionary of task names as key and task metadata
|
| 459 |
+
"""
|
| 460 |
+
|
| 461 |
+
def _populate_tags_and_groups(config, task, tasks_and_groups, print_info):
|
| 462 |
+
# TODO: remove group in next release
|
| 463 |
+
if "tag" in config:
|
| 464 |
+
attr_list = config["tag"]
|
| 465 |
+
if isinstance(attr_list, str):
|
| 466 |
+
attr_list = [attr_list]
|
| 467 |
+
|
| 468 |
+
for tag in attr_list:
|
| 469 |
+
if tag not in tasks_and_groups:
|
| 470 |
+
tasks_and_groups[tag] = {
|
| 471 |
+
"type": "tag",
|
| 472 |
+
"task": [task],
|
| 473 |
+
"yaml_path": -1,
|
| 474 |
+
}
|
| 475 |
+
elif tasks_and_groups[tag]["type"] != "tag":
|
| 476 |
+
eval_logger.info(
|
| 477 |
+
f"The tag '{tag}' is already registered as a group, this tag will not be registered. "
|
| 478 |
+
"This may affect tasks you want to call."
|
| 479 |
+
)
|
| 480 |
+
break
|
| 481 |
+
else:
|
| 482 |
+
tasks_and_groups[tag]["task"].append(task)
|
| 483 |
+
|
| 484 |
+
# TODO: remove group in next release
|
| 485 |
+
print_info = True
|
| 486 |
+
ignore_dirs = [
|
| 487 |
+
"__pycache__",
|
| 488 |
+
".ipynb_checkpoints",
|
| 489 |
+
]
|
| 490 |
+
tasks_and_groups = collections.defaultdict()
|
| 491 |
+
for root, dirs, file_list in os.walk(task_dir):
|
| 492 |
+
dirs[:] = [d for d in dirs if d not in ignore_dirs]
|
| 493 |
+
for f in file_list:
|
| 494 |
+
if f.endswith(".yaml"):
|
| 495 |
+
yaml_path = os.path.join(root, f)
|
| 496 |
+
print(yaml_path)
|
| 497 |
+
config = utils.load_yaml_config(yaml_path, mode="simple")
|
| 498 |
+
if self._config_is_python_task(config):
|
| 499 |
+
# This is a python class config
|
| 500 |
+
task = config["task"]
|
| 501 |
+
tasks_and_groups[task] = {
|
| 502 |
+
"type": "python_task",
|
| 503 |
+
"yaml_path": yaml_path,
|
| 504 |
+
}
|
| 505 |
+
_populate_tags_and_groups(
|
| 506 |
+
config, task, tasks_and_groups, print_info
|
| 507 |
+
)
|
| 508 |
+
elif self._config_is_group(config):
|
| 509 |
+
# This is a group config
|
| 510 |
+
tasks_and_groups[config["group"]] = {
|
| 511 |
+
"type": "group",
|
| 512 |
+
"task": -1, # This signals that
|
| 513 |
+
# we don't need to know
|
| 514 |
+
# the task list for indexing
|
| 515 |
+
# as it can be loaded
|
| 516 |
+
# when called.
|
| 517 |
+
"yaml_path": yaml_path,
|
| 518 |
+
}
|
| 519 |
+
|
| 520 |
+
# # Registered the level 1 tasks from a group config
|
| 521 |
+
# for config in config["task"]:
|
| 522 |
+
# if isinstance(config, dict) and self._config_is_task(config):
|
| 523 |
+
# task = config["task"]
|
| 524 |
+
# tasks_and_groups[task] = {
|
| 525 |
+
# "type": "task",
|
| 526 |
+
# "yaml_path": yaml_path,
|
| 527 |
+
# }
|
| 528 |
+
|
| 529 |
+
elif self._config_is_task(config):
|
| 530 |
+
# This is a task config
|
| 531 |
+
task = config["task"]
|
| 532 |
+
tasks_and_groups[task] = {
|
| 533 |
+
"type": "task",
|
| 534 |
+
"yaml_path": yaml_path,
|
| 535 |
+
}
|
| 536 |
+
_populate_tags_and_groups(
|
| 537 |
+
config, task, tasks_and_groups, print_info
|
| 538 |
+
)
|
| 539 |
+
else:
|
| 540 |
+
eval_logger.debug(f"File {f} in {root} could not be loaded")
|
| 541 |
+
|
| 542 |
+
return tasks_and_groups
|
| 543 |
+
|
| 544 |
+
|
| 545 |
+
def get_task_name_from_config(task_config: Dict[str, str]) -> str:
    """Derive a task's display name from its config.

    Prefers an explicit 'task' entry; otherwise falls back to the dataset
    path, suffixed with the dataset name when one is present.
    """
    explicit = task_config.get("task")
    if explicit is not None:
        return explicit
    if "dataset_name" in task_config:
        return f"{task_config['dataset_path']}_{task_config['dataset_name']}"
    return f"{task_config['dataset_path']}"
|
| 552 |
+
|
| 553 |
+
|
| 554 |
+
def get_task_name_from_object(task_object):
    """Best-effort name for an already-constructed task object."""
    # Configured tasks carry their name in their config mapping.
    # NOTE: checks for `config` but reads `_config` — original behavior kept.
    if hasattr(task_object, "config"):
        return task_object._config["task"]

    # TODO: scrap this
    # this gives a mechanism for non-registered tasks to have a custom name anyways when reporting
    if hasattr(task_object, "EVAL_HARNESS_NAME"):
        return task_object.EVAL_HARNESS_NAME
    return type(task_object).__name__
|
| 565 |
+
|
| 566 |
+
|
| 567 |
+
def _check_duplicates(task_dict: dict) -> None:
|
| 568 |
+
"""helper function solely used in validating get_task_dict output.
|
| 569 |
+
Takes the output of dllm_eval.evaluator_utils.get_subtask_list and
|
| 570 |
+
returns a list of all leaf subtasks contained within, and errors if any such leaf subtasks are
|
| 571 |
+
"oversubscribed" to several disjoint groups.
|
| 572 |
+
"""
|
| 573 |
+
subtask_names = []
|
| 574 |
+
for key, value in task_dict.items():
|
| 575 |
+
subtask_names.extend(value)
|
| 576 |
+
|
| 577 |
+
duplicate_tasks = {
|
| 578 |
+
task_name for task_name in subtask_names if subtask_names.count(task_name) > 1
|
| 579 |
+
}
|
| 580 |
+
|
| 581 |
+
# locate the potentially problematic groups that seem to 'compete' for constituent subtasks
|
| 582 |
+
competing_groups = [
|
| 583 |
+
group
|
| 584 |
+
for group in task_dict.keys()
|
| 585 |
+
if len(set(task_dict[group]).intersection(duplicate_tasks)) > 0
|
| 586 |
+
]
|
| 587 |
+
|
| 588 |
+
if len(duplicate_tasks) > 0:
|
| 589 |
+
raise ValueError(
|
| 590 |
+
f"Found 1 or more tasks while trying to call get_task_dict() that were members of more than 1 called group: {list(duplicate_tasks)}. Offending groups: {competing_groups}. Please call groups which overlap their constituent tasks in separate evaluation runs."
|
| 591 |
+
)
|
| 592 |
+
|
| 593 |
+
|
| 594 |
+
def get_task_dict(
    task_name_list: Union[str, List[Union[str, Dict, Task]]],
    task_manager: Optional[TaskManager] = None,
):
    """Creates a dictionary of task objects from either a name of task, config, or prepared Task object.

    :param task_name_list: List[Union[str, Dict, Task]]
        Name of model or LM object, see dllm_eval.models.get_model
    :param task_manager: TaskManager = None
        A TaskManager object that stores indexed tasks. If not set,
        task_manager will load one. This should be set by the user
        if there are additional paths that want to be included
        via `include_path`

    :return
        Dictionary of task objects
    """
    loaded_from_string: dict = {}
    loaded_from_config: dict = {}
    loaded_from_object: dict = {}

    if isinstance(task_name_list, str):
        task_name_list = [task_name_list]
    elif isinstance(task_name_list, list):
        for entry in task_name_list:
            if not isinstance(entry, (str, dict, Task)):
                raise TypeError(
                    "Expected all list items to be of types 'str', 'dict', or 'Task', but at least one entry did not match."
                )
    else:
        raise TypeError(
            f"Expected a 'str' or 'list' but received {type(task_name_list)}."
        )

    string_entries = [entry for entry in task_name_list if isinstance(entry, str)]
    other_entries = [entry for entry in task_name_list if not isinstance(entry, str)]

    if string_entries:
        if task_manager is None:
            task_manager = TaskManager()
        loaded_from_string = task_manager.load_task_or_group(string_entries)

    for entry in other_entries:
        if isinstance(entry, dict):
            loaded_from_config.update(task_manager.load_config(config=entry))
        elif isinstance(entry, Task):
            loaded_from_object[get_task_name_from_object(entry)] = entry

    # A name may not be requested both by string and as a prepared object.
    if not set(loaded_from_string).isdisjoint(set(loaded_from_object)):
        raise ValueError

    final_task_dict = {
        **loaded_from_string,
        **loaded_from_config,
        **loaded_from_object,
    }

    # behavior can get odd if one tries to invoke several groups that "compete" for the same task.
    # (notably, because one could request several num_fewshot values at once in GroupConfig overrides for the subtask
    # and we'd be unsure which to use and report.)
    # we explicitly check and error in this case.
    _check_duplicates(get_subtask_list(final_task_dict))

    return final_task_dict
|
Prism/LLaDA/LLaDA_Baseline/dllm_eval/tasks/gsm8k/gsm8k.yaml
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# GSM8K task config: grade-school math word problems scored by generated answers.
task: gsm8k
dataset_path: openai/gsm8k
dataset_name: main
output_type: generate_until
training_split: train
fewshot_split: train
test_split: test
# Prompt text is built in Python (see utils.gsm_prompt); the gold target is
# the text after the '####' marker in the dataset's answer field.
doc_to_text: !function utils.gsm_prompt
doc_to_target: "{{answer.split('####')[-1].strip()}}"
generation_kwargs:
  until:
    # Placeholder stop string that never occurs -> effectively no early stop.
    - "[NO_UNTIL_PLACEHOLDER]"
  do_sample: false
repeats: 1
num_fewshot: 0
|
Prism/LLaDA/LLaDA_Baseline/dllm_eval/tasks/gsm8k/utils.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
def gsm_prompt(doc):
    """Build the zero-shot GSM8K prompt: fixed instructions, a blank line,
    the question from *doc*, and a trailing blank line."""
    instruction_lines = [
        "You are a math expert. You will be given a question to solve. Solve it step by step. Wrap the final answer in a \\boxed{}. ",
        "Respond in the following format:",
        "<reasoning>",
        "Your reasoning here",
        "</reasoning>",
        "<answer>",
        "\\boxed{...}",
        "</answer>",
    ]
    system_prompt = "\n".join(instruction_lines)
    return "".join([system_prompt, "\n\n", doc["question"], "\n\n"])
|
Prism/LLaDA/LLaDA_Baseline/dllm_eval/tasks/humaneval/humaneval.yaml
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# HumanEval task config: Python code generation scored by executing the
# model's completion against the dataset's unit tests (hence unsafe_code).
task: humaneval
dataset_path: openai/openai_humaneval
unsafe_code: true
output_type: generate_until
test_split: test
doc_to_text: "Write a solution to the following problem and make sure that it passes the tests:\n{{prompt}}\n\nFirst, reason about the solution step-by-step. Then, write the code.\nRespond in the following format:\n<reasoning>\nYour reasoning here\n</reasoning>\n<answer>\n```python\nThe complete implementation of the {{entry_point}} function\n```\n</answer>"
# Target is the dataset's test harness plus a call to its check() function.
doc_to_target: "{{test}}\ncheck({{entry_point}})"
generation_kwargs:
  until:
    # Placeholder stop string that never occurs -> effectively no early stop.
    - "[NO_UNTIL_PLACEHOLDER]"
  do_sample: false
repeats: 1
num_fewshot: 0
|
Prism/LLaDA/LLaDA_Baseline/dllm_eval/tasks/humaneval/utils.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import evaluate as hf_evaluate
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
# Load the HF `code_eval` metric once at import time and run a tiny smoke
# test so that an environment with code execution disabled (the metric
# requires it to be explicitly enabled) fails fast here rather than mid-run.
# The previous `try: ... except Exception as e: raise e` wrapper was a no-op
# (it re-raised the same exception unchanged) and has been removed;
# exceptions still propagate identically.
compute_ = hf_evaluate.load("code_eval")
test_cases = ["assert add(2, 3)==5"]
candidates = [["def add(a,b): return a*b"]]
results = compute_.compute(references=test_cases, predictions=candidates, k=[1])
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def pass_at_k(references: list[str], predictions: list[list[str]], k: list[int] = None):
    """Score *predictions* against *references* with the preloaded HF
    `code_eval` metric and return its pass@k dictionary (the first element
    of compute()'s result). *k* is required; a bare int is promoted to a
    one-element list."""
    global compute_
    assert k is not None
    ks = [k] if isinstance(k, int) else k
    outcome = compute_.compute(
        references=references,
        predictions=predictions,
        k=ks,
    )
    return outcome[0]
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def clean_response_string(r: str) -> str:
    """Trim a model response down to its final fenced python block.

    Keeps everything from the last "```python" marker (inclusive), drops
    everything from the last "```" onward, then drops any trailing
    `if __name__ == "__main__":` section.
    """
    start = r.rfind("```python")
    text = r if start == -1 else r[start:]
    close = text.rfind("```")
    if close != -1:
        text = text[:close]
    guard = text.rfind("if __name__ == \"__main__\":")
    if guard != -1:
        text = text[:guard]
    return text
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
# NOTE(review): this definition is DEAD CODE — a second `build_predictions`
# defined later in this module shadows it, so the prompt-prefixing behavior
# here is never used. Confirm which variant is intended and delete the other.
def build_predictions(resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
    # Prepend each doc's prompt to every candidate completion for that doc.
    return [[doc["prompt"] + r for r in resp] for resp, doc in zip(resps, docs)]
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def build_predictions(
    resps: list[list[str]], docs: list[dict]
) -> list[list[str]]:
    """Clean every candidate response with clean_response_string.

    Note: zipping with *docs* truncates to the shorter of the two lists,
    matching the original behavior; the doc itself is otherwise unused.
    """
    cleaned = []
    for resp, _doc in zip(resps, docs):
        cleaned.append([clean_response_string(candidate) for candidate in resp])
    return cleaned
|
Prism/LLaDA/LLaDA_Baseline/dllm_eval/tasks/mbpp/mbpp.yaml
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# MBPP task config: Python code generation scored by executing the model's
# code against the dataset's three asserts (hence unsafe_code).
task: mbpp
dataset_path: google-research-datasets/mbpp
dataset_name: full
unsafe_code: true
output_type: generate_until
test_split: test
doc_to_text: "\n{{text}} Your code should pass these tests:\n\n{{test_list[0]}}\n{{test_list[1]}}\n{{test_list[2]}} \n\nFirst, reason about the solution step-by-step. Then, write the code.\nRespond in the following format:\n<reasoning>\nYour reasoning here\n</reasoning>\n<answer>\n```python\nThe complete implementation of the function\n```\n</answer>"
# Few-shot exemplars (is_fewshot defined) show the reference code; real
# test docs target the three asserts instead.
doc_to_target: "{% if is_fewshot is defined %}{{code}}\n[DONE]{% else %}{{test_list[0]}}\n{{test_list[1]}}\n{{test_list[2]}}{% endif %}"
target_delimiter: ""
generation_kwargs:
  until:
    # Placeholder stop string that never occurs -> effectively no early stop.
    - "[NO_UNTIL_PLACEHOLDER]"
  do_sample: false
num_fewshot: 0
|
Prism/LLaDA/LLaDA_Baseline/dllm_eval/tasks/mbpp/utils.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
from typing import Union
|
| 3 |
+
|
| 4 |
+
import evaluate as hf_evaluate
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
# Load the HF `code_eval` metric once at import time.
# The previous `try: ... except Exception as e: raise e` wrapper was a no-op
# (it re-raised the same exception unchanged) and has been removed;
# exceptions still propagate identically.
pass_at_k = hf_evaluate.load("code_eval")

# run simple test to check code execution is enabled before model generation
test_cases = ["assert add(2, 3)==5"]
candidates = [["def add(a,b): return a*b"]]
results = pass_at_k.compute(references=test_cases, predictions=candidates, k=[1])
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def pass_at_1(
    references: Union[str, list[str]], predictions: Union[str, list[list[str]]]
) -> float:
    """Compute pass@1 for *predictions* against *references* using the
    preloaded HF `code_eval` metric.

    A single reference string and/or flat prediction strings are promoted
    to the list shapes the metric expects.
    """
    refs = [references] if isinstance(references, str) else references
    preds = predictions
    if isinstance(preds[0], str):
        preds = [[candidate] for candidate in preds]
    scores = pass_at_k.compute(
        references=refs,
        predictions=preds,
        k=[1],
        num_workers=48,
    )
    return scores[0]["pass@1"]
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def extract_code_blocks(text: str) -> str:
    """Strip generation sentinel tokens from a model response.

    The original used re.sub on fully-escaped literals; plain str.replace
    is equivalent here.
    """
    for sentinel in ("[DONE]", "<|eot_id|>", "<|endoftext|>"):
        text = text.replace(sentinel, "")
    return text
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def build_predictions(resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
    """Strip sentinel tokens from every candidate; *docs* is accepted for
    API symmetry with other tasks but unused here."""
    cleaned = []
    for resp in resps:
        cleaned.append([extract_code_blocks(candidate) for candidate in resp])
    return cleaned
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def list_fewshot_samples():
    """Return the three hand-picked MBPP few-shot exemplars (task_ids 2-4).

    Each dict mirrors an MBPP record; `is_fewshot: True` is read by the
    task's doc_to_target template to emit the reference code (plus [DONE])
    instead of the test asserts.
    """
    return [
        {
            "task_id": 2,
            "text": "Write a function to find the similar elements from the given two tuple lists.",
            "code": "def similar_elements(test_tup1, test_tup2):\r\n  res = tuple(set(test_tup1) & set(test_tup2))\r\n  return (res) ",
            "test_list": [
                "assert similar_elements((3, 4, 5, 6),(5, 7, 4, 10)) == (4, 5)",
                "assert similar_elements((1, 2, 3, 4),(5, 4, 3, 7)) == (3, 4)",
                "assert similar_elements((11, 12, 14, 13),(17, 15, 14, 13)) == (13, 14)",
            ],
            "is_fewshot": True,
        },
        {
            "task_id": 3,
            "text": "Write a python function to identify non-prime numbers.",
            "code": "import math\r\ndef is_not_prime(n):\r\n    result = False\r\n    for i in range(2,int(math.sqrt(n)) + 1):\r\n        if n % i == 0:\r\n            result = True\r\n    return result",
            "test_list": [
                "assert is_not_prime(2) == False",
                "assert is_not_prime(10) == True",
                "assert is_not_prime(35) == True",
            ],
            "is_fewshot": True,
        },
        {
            "task_id": 4,
            "text": "Write a function to find the largest integers from a given list of numbers using heap queue algorithm.",
            "code": "import heapq as hq\r\ndef heap_queue_largest(nums,n):\r\n  largest_nums = hq.nlargest(n, nums)\r\n  return largest_nums",
            "test_list": [
                "assert heap_queue_largest( [25, 35, 22, 85, 14, 65, 75, 22, 58],3)==[85, 75, 65] ",
                "assert heap_queue_largest( [25, 35, 22, 85, 14, 65, 75, 22, 58],2)==[85, 75] ",
                "assert heap_queue_largest( [25, 35, 22, 85, 14, 65, 75, 22, 58],5)==[85, 75, 65, 58, 35]",
            ],
            "is_fewshot": True,
        },
    ]
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/certifi/__main__.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse

from certifi import contents, where

# Vendored certifi CLI (python -m certifi): with -c/--contents print the
# bundled CA certificates themselves, otherwise print the filesystem path
# to the CA bundle.
parser = argparse.ArgumentParser()
parser.add_argument("-c", "--contents", action="store_true")
args = parser.parse_args()

if args.contents:
    print(contents())
else:
    print(where())
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/frozenlist-1.5.0.dist-info/INSTALLER
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
pip
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/frozenlist-1.5.0.dist-info/LICENSE
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
| 177 |
+
|
| 178 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 179 |
+
|
| 180 |
+
To apply the Apache License to your work, attach the following
|
| 181 |
+
boilerplate notice, with the fields enclosed by brackets "{}"
|
| 182 |
+
replaced with your own identifying information. (Don't include
|
| 183 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 184 |
+
comment syntax for the file format. We also recommend that a
|
| 185 |
+
file or class name and description of purpose be included on the
|
| 186 |
+
same "printed page" as the copyright notice for easier
|
| 187 |
+
identification within third-party archives.
|
| 188 |
+
|
| 189 |
+
Copyright 2013-2019 Nikolay Kim and Andrew Svetlov
|
| 190 |
+
|
| 191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 192 |
+
you may not use this file except in compliance with the License.
|
| 193 |
+
You may obtain a copy of the License at
|
| 194 |
+
|
| 195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 196 |
+
|
| 197 |
+
Unless required by applicable law or agreed to in writing, software
|
| 198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 200 |
+
See the License for the specific language governing permissions and
|
| 201 |
+
limitations under the License.
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/frozenlist-1.5.0.dist-info/METADATA
ADDED
|
@@ -0,0 +1,477 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Metadata-Version: 2.1
|
| 2 |
+
Name: frozenlist
|
| 3 |
+
Version: 1.5.0
|
| 4 |
+
Summary: A list-like structure which implements collections.abc.MutableSequence
|
| 5 |
+
Home-page: https://github.com/aio-libs/frozenlist
|
| 6 |
+
Maintainer: aiohttp team <team@aiohttp.org>
|
| 7 |
+
Maintainer-email: team@aiohttp.org
|
| 8 |
+
License: Apache 2
|
| 9 |
+
Project-URL: Chat: Matrix, https://matrix.to/#/#aio-libs:matrix.org
|
| 10 |
+
Project-URL: Chat: Matrix Space, https://matrix.to/#/#aio-libs-space:matrix.org
|
| 11 |
+
Project-URL: CI: Github Actions, https://github.com/aio-libs/frozenlist/actions
|
| 12 |
+
Project-URL: Code of Conduct, https://github.com/aio-libs/.github/blob/master/CODE_OF_CONDUCT.md
|
| 13 |
+
Project-URL: Coverage: codecov, https://codecov.io/github/aio-libs/frozenlist
|
| 14 |
+
Project-URL: Docs: Changelog, https://github.com/aio-libs/frozenlist/blob/master/CHANGES.rst#changelog
|
| 15 |
+
Project-URL: Docs: RTD, https://frozenlist.aio-libs.org
|
| 16 |
+
Project-URL: GitHub: issues, https://github.com/aio-libs/frozenlist/issues
|
| 17 |
+
Project-URL: GitHub: repo, https://github.com/aio-libs/frozenlist
|
| 18 |
+
Classifier: Development Status :: 5 - Production/Stable
|
| 19 |
+
Classifier: Intended Audience :: Developers
|
| 20 |
+
Classifier: License :: OSI Approved :: Apache Software License
|
| 21 |
+
Classifier: Operating System :: POSIX
|
| 22 |
+
Classifier: Operating System :: MacOS :: MacOS X
|
| 23 |
+
Classifier: Operating System :: Microsoft :: Windows
|
| 24 |
+
Classifier: Programming Language :: Cython
|
| 25 |
+
Classifier: Programming Language :: Python
|
| 26 |
+
Classifier: Programming Language :: Python :: 3
|
| 27 |
+
Classifier: Programming Language :: Python :: 3.8
|
| 28 |
+
Classifier: Programming Language :: Python :: 3.9
|
| 29 |
+
Classifier: Programming Language :: Python :: 3.10
|
| 30 |
+
Classifier: Programming Language :: Python :: 3.11
|
| 31 |
+
Classifier: Programming Language :: Python :: 3.12
|
| 32 |
+
Classifier: Programming Language :: Python :: 3.13
|
| 33 |
+
Classifier: Programming Language :: Python :: Implementation :: CPython
|
| 34 |
+
Classifier: Programming Language :: Python :: Implementation :: PyPy
|
| 35 |
+
Requires-Python: >=3.8
|
| 36 |
+
Description-Content-Type: text/x-rst
|
| 37 |
+
License-File: LICENSE
|
| 38 |
+
|
| 39 |
+
frozenlist
|
| 40 |
+
==========
|
| 41 |
+
|
| 42 |
+
.. image:: https://github.com/aio-libs/frozenlist/workflows/CI/badge.svg
|
| 43 |
+
:target: https://github.com/aio-libs/frozenlist/actions
|
| 44 |
+
:alt: GitHub status for master branch
|
| 45 |
+
|
| 46 |
+
.. image:: https://codecov.io/gh/aio-libs/frozenlist/branch/master/graph/badge.svg
|
| 47 |
+
:target: https://codecov.io/gh/aio-libs/frozenlist
|
| 48 |
+
:alt: codecov.io status for master branch
|
| 49 |
+
|
| 50 |
+
.. image:: https://img.shields.io/pypi/v/frozenlist.svg?logo=Python&logoColor=white
|
| 51 |
+
:target: https://pypi.org/project/frozenlist
|
| 52 |
+
:alt: frozenlist @ PyPI
|
| 53 |
+
|
| 54 |
+
.. image:: https://readthedocs.org/projects/frozenlist/badge/?version=latest
|
| 55 |
+
:target: https://frozenlist.aio-libs.org
|
| 56 |
+
:alt: Read The Docs build status badge
|
| 57 |
+
|
| 58 |
+
.. image:: https://img.shields.io/matrix/aio-libs:matrix.org?label=Discuss%20on%20Matrix%20at%20%23aio-libs%3Amatrix.org&logo=matrix&server_fqdn=matrix.org&style=flat
|
| 59 |
+
:target: https://matrix.to/#/%23aio-libs:matrix.org
|
| 60 |
+
:alt: Matrix Room — #aio-libs:matrix.org
|
| 61 |
+
|
| 62 |
+
.. image:: https://img.shields.io/matrix/aio-libs-space:matrix.org?label=Discuss%20on%20Matrix%20at%20%23aio-libs-space%3Amatrix.org&logo=matrix&server_fqdn=matrix.org&style=flat
|
| 63 |
+
:target: https://matrix.to/#/%23aio-libs-space:matrix.org
|
| 64 |
+
:alt: Matrix Space — #aio-libs-space:matrix.org
|
| 65 |
+
|
| 66 |
+
Introduction
|
| 67 |
+
------------
|
| 68 |
+
|
| 69 |
+
``frozenlist.FrozenList`` is a list-like structure which implements
|
| 70 |
+
``collections.abc.MutableSequence``. The list is *mutable* until ``FrozenList.freeze``
|
| 71 |
+
is called, after which list modifications raise ``RuntimeError``:
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
>>> from frozenlist import FrozenList
|
| 75 |
+
>>> fl = FrozenList([17, 42])
|
| 76 |
+
>>> fl.append('spam')
|
| 77 |
+
>>> fl.append('Vikings')
|
| 78 |
+
>>> fl
|
| 79 |
+
<FrozenList(frozen=False, [17, 42, 'spam', 'Vikings'])>
|
| 80 |
+
>>> fl.freeze()
|
| 81 |
+
>>> fl
|
| 82 |
+
<FrozenList(frozen=True, [17, 42, 'spam', 'Vikings'])>
|
| 83 |
+
>>> fl.frozen
|
| 84 |
+
True
|
| 85 |
+
>>> fl.append("Monty")
|
| 86 |
+
Traceback (most recent call last):
|
| 87 |
+
File "<stdin>", line 1, in <module>
|
| 88 |
+
File "frozenlist/_frozenlist.pyx", line 97, in frozenlist._frozenlist.FrozenList.append
|
| 89 |
+
self._check_frozen()
|
| 90 |
+
File "frozenlist/_frozenlist.pyx", line 19, in frozenlist._frozenlist.FrozenList._check_frozen
|
| 91 |
+
raise RuntimeError("Cannot modify frozen list.")
|
| 92 |
+
RuntimeError: Cannot modify frozen list.
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
FrozenList is also hashable, but only when frozen. Otherwise it also throws a RuntimeError:
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
>>> fl = FrozenList([17, 42, 'spam'])
|
| 99 |
+
>>> hash(fl)
|
| 100 |
+
Traceback (most recent call last):
|
| 101 |
+
File "<stdin>", line 1, in <module>
|
| 102 |
+
File "frozenlist/_frozenlist.pyx", line 111, in frozenlist._frozenlist.FrozenList.__hash__
|
| 103 |
+
raise RuntimeError("Cannot hash unfrozen list.")
|
| 104 |
+
RuntimeError: Cannot hash unfrozen list.
|
| 105 |
+
>>> fl.freeze()
|
| 106 |
+
>>> hash(fl)
|
| 107 |
+
3713081631934410656
|
| 108 |
+
>>> dictionary = {fl: 'Vikings'} # frozen fl can be a dict key
|
| 109 |
+
>>> dictionary
|
| 110 |
+
{<FrozenList(frozen=True, [1, 2])>: 'Vikings'}
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
Installation
|
| 114 |
+
------------
|
| 115 |
+
|
| 116 |
+
::
|
| 117 |
+
|
| 118 |
+
$ pip install frozenlist
|
| 119 |
+
|
| 120 |
+
The library requires Python 3.8 or newer.
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
Documentation
|
| 124 |
+
-------------
|
| 125 |
+
|
| 126 |
+
https://frozenlist.aio-libs.org
|
| 127 |
+
|
| 128 |
+
Communication channels
|
| 129 |
+
----------------------
|
| 130 |
+
|
| 131 |
+
We have a *Matrix Space* `#aio-libs-space:matrix.org
|
| 132 |
+
<https://matrix.to/#/%23aio-libs-space:matrix.org>`_ which is
|
| 133 |
+
also accessible via Gitter.
|
| 134 |
+
|
| 135 |
+
Requirements
|
| 136 |
+
------------
|
| 137 |
+
|
| 138 |
+
- Python >= 3.8
|
| 139 |
+
|
| 140 |
+
License
|
| 141 |
+
-------
|
| 142 |
+
|
| 143 |
+
``frozenlist`` is offered under the Apache 2 license.
|
| 144 |
+
|
| 145 |
+
Source code
|
| 146 |
+
-----------
|
| 147 |
+
|
| 148 |
+
The project is hosted on GitHub_
|
| 149 |
+
|
| 150 |
+
Please file an issue in the `bug tracker
|
| 151 |
+
<https://github.com/aio-libs/frozenlist/issues>`_ if you have found a bug
|
| 152 |
+
or have some suggestions to improve the library.
|
| 153 |
+
|
| 154 |
+
.. _GitHub: https://github.com/aio-libs/frozenlist
|
| 155 |
+
|
| 156 |
+
=========
|
| 157 |
+
Changelog
|
| 158 |
+
=========
|
| 159 |
+
|
| 160 |
+
..
|
| 161 |
+
You should *NOT* be adding new change log entries to this file, this
|
| 162 |
+
file is managed by towncrier. You *may* edit previous change logs to
|
| 163 |
+
fix problems like typo corrections or such.
|
| 164 |
+
To add a new change log entry, please see
|
| 165 |
+
https://pip.pypa.io/en/latest/development/contributing/#news-entries
|
| 166 |
+
we named the news folder "changes".
|
| 167 |
+
|
| 168 |
+
WARNING: Don't drop the next directive!
|
| 169 |
+
|
| 170 |
+
.. towncrier release notes start
|
| 171 |
+
|
| 172 |
+
1.5.0 (2024-10-22)
|
| 173 |
+
==================
|
| 174 |
+
|
| 175 |
+
Bug fixes
|
| 176 |
+
---------
|
| 177 |
+
|
| 178 |
+
- An incorrect signature of the ``__class_getitem__`` class method
|
| 179 |
+
has been fixed, adding a missing ``class_item`` argument under
|
| 180 |
+
Python 3.8 and older.
|
| 181 |
+
|
| 182 |
+
This change also improves the code coverage of this method that
|
| 183 |
+
was previously missing -- by `@webknjaz <https://github.com/sponsors/webknjaz>`__.
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
*Related issues and pull requests on GitHub:*
|
| 187 |
+
`#567 <https://github.com/aio-libs/frozenlist/issues/567>`__, `#571 <https://github.com/aio-libs/frozenlist/issues/571>`__.
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
Improved documentation
|
| 191 |
+
----------------------
|
| 192 |
+
|
| 193 |
+
- Rendered issue, PR, and commit links now lead to
|
| 194 |
+
``frozenlist``'s repo instead of ``yarl``'s repo.
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
*Related issues and pull requests on GitHub:*
|
| 198 |
+
`#573 <https://github.com/aio-libs/frozenlist/issues/573>`__.
|
| 199 |
+
|
| 200 |
+
- On the ``Contributing docs`` page,
|
| 201 |
+
a link to the ``Towncrier philosophy`` has been fixed.
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
*Related issues and pull requests on GitHub:*
|
| 205 |
+
`#574 <https://github.com/aio-libs/frozenlist/issues/574>`__.
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
Packaging updates and notes for downstreams
|
| 209 |
+
-------------------------------------------
|
| 210 |
+
|
| 211 |
+
- A name of a temporary building directory now reflects
|
| 212 |
+
that it's related to ``frozenlist``, not ``yarl``.
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
*Related issues and pull requests on GitHub:*
|
| 216 |
+
`#573 <https://github.com/aio-libs/frozenlist/issues/573>`__.
|
| 217 |
+
|
| 218 |
+
- Declared Python 3.13 supported officially in the distribution package metadata.
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
*Related issues and pull requests on GitHub:*
|
| 222 |
+
`#595 <https://github.com/aio-libs/frozenlist/issues/595>`__.
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
----
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
1.4.1 (2023-12-15)
|
| 229 |
+
==================
|
| 230 |
+
|
| 231 |
+
Packaging updates and notes for downstreams
|
| 232 |
+
-------------------------------------------
|
| 233 |
+
|
| 234 |
+
- Declared Python 3.12 and PyPy 3.8-3.10 supported officially
|
| 235 |
+
in the distribution package metadata.
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
*Related issues and pull requests on GitHub:*
|
| 239 |
+
`#553 <https://github.com/aio-libs/frozenlist/issues/553>`__.
|
| 240 |
+
|
| 241 |
+
- Replaced the packaging is replaced from an old-fashioned ``setup.py`` to an
|
| 242 |
+
in-tree `PEP 517 <https://peps.python.org/pep-517>`__ build backend -- by `@webknjaz <https://github.com/sponsors/webknjaz>`__.
|
| 243 |
+
|
| 244 |
+
Whenever the end-users or downstream packagers need to build ``frozenlist``
|
| 245 |
+
from source (a Git checkout or an sdist), they may pass a ``config_settings``
|
| 246 |
+
flag ``pure-python``. If this flag is not set, a C-extension will be built
|
| 247 |
+
and included into the distribution.
|
| 248 |
+
|
| 249 |
+
Here is how this can be done with ``pip``:
|
| 250 |
+
|
| 251 |
+
.. code-block:: console
|
| 252 |
+
|
| 253 |
+
$ python3 -m pip install . --config-settings=pure-python=
|
| 254 |
+
|
| 255 |
+
This will also work with ``-e | --editable``.
|
| 256 |
+
|
| 257 |
+
The same can be achieved via ``pypa/build``:
|
| 258 |
+
|
| 259 |
+
.. code-block:: console
|
| 260 |
+
|
| 261 |
+
$ python3 -m build --config-setting=pure-python=
|
| 262 |
+
|
| 263 |
+
Adding ``-w | --wheel`` can force ``pypa/build`` produce a wheel from source
|
| 264 |
+
directly, as opposed to building an ``sdist`` and then building from it.
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
*Related issues and pull requests on GitHub:*
|
| 268 |
+
`#560 <https://github.com/aio-libs/frozenlist/issues/560>`__.
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
Contributor-facing changes
|
| 272 |
+
--------------------------
|
| 273 |
+
|
| 274 |
+
- It is now possible to request line tracing in Cython builds using the
|
| 275 |
+
``with-cython-tracing`` `PEP 517 <https://peps.python.org/pep-517>`__ config setting
|
| 276 |
+
-- `@webknjaz <https://github.com/sponsors/webknjaz>`__.
|
| 277 |
+
|
| 278 |
+
This can be used in CI and development environment to measure coverage
|
| 279 |
+
on Cython modules, but is not normally useful to the end-users or
|
| 280 |
+
downstream packagers.
|
| 281 |
+
|
| 282 |
+
Here's a usage example:
|
| 283 |
+
|
| 284 |
+
.. code-block:: console
|
| 285 |
+
|
| 286 |
+
$ python3 -Im pip install . --config-settings=with-cython-tracing=true
|
| 287 |
+
|
| 288 |
+
For editable installs, this setting is on by default. Otherwise, it's
|
| 289 |
+
off unless requested explicitly.
|
| 290 |
+
|
| 291 |
+
The following produces C-files required for the Cython coverage
|
| 292 |
+
plugin to map the measurements back to the PYX-files:
|
| 293 |
+
|
| 294 |
+
.. code-block:: console
|
| 295 |
+
|
| 296 |
+
$ python -Im pip install -e .
|
| 297 |
+
|
| 298 |
+
Alternatively, the ``FROZENLIST_CYTHON_TRACING=1`` environment variable
|
| 299 |
+
can be set to do the same as the `PEP 517 <https://peps.python.org/pep-517>`__ config setting.
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
*Related issues and pull requests on GitHub:*
|
| 303 |
+
`#560 <https://github.com/aio-libs/frozenlist/issues/560>`__.
|
| 304 |
+
|
| 305 |
+
- Coverage collection has been implemented for the Cython modules
|
| 306 |
+
-- by `@webknjaz <https://github.com/sponsors/webknjaz>`__.
|
| 307 |
+
|
| 308 |
+
It will also be reported to Codecov from any non-release CI jobs.
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
*Related issues and pull requests on GitHub:*
|
| 312 |
+
`#561 <https://github.com/aio-libs/frozenlist/issues/561>`__.
|
| 313 |
+
|
| 314 |
+
- A step-by-step ``Release Guide`` guide has
|
| 315 |
+
been added, describing how to release *frozenlist* -- by `@webknjaz <https://github.com/sponsors/webknjaz>`__.
|
| 316 |
+
|
| 317 |
+
This is primarily targeting the maintainers.
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
*Related issues and pull requests on GitHub:*
|
| 321 |
+
`#563 <https://github.com/aio-libs/frozenlist/issues/563>`__.
|
| 322 |
+
|
| 323 |
+
- Detailed ``Contributing Guidelines`` on
|
| 324 |
+
authoring the changelog fragments have been published in the
|
| 325 |
+
documentation -- by `@webknjaz <https://github.com/sponsors/webknjaz>`__.
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
*Related issues and pull requests on GitHub:*
|
| 329 |
+
`#564 <https://github.com/aio-libs/frozenlist/issues/564>`__.
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
----
|
| 333 |
+
|
| 334 |
+
|
| 335 |
+
1.4.0 (2023-07-12)
|
| 336 |
+
==================
|
| 337 |
+
|
| 338 |
+
The published source distribution package became buildable
|
| 339 |
+
under Python 3.12.
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
----
|
| 343 |
+
|
| 344 |
+
|
| 345 |
+
Bugfixes
|
| 346 |
+
--------
|
| 347 |
+
|
| 348 |
+
- Removed an unused ``typing.Tuple`` import
|
| 349 |
+
`#411 <https://github.com/aio-libs/frozenlist/issues/411>`_
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
Deprecations and Removals
|
| 353 |
+
-------------------------
|
| 354 |
+
|
| 355 |
+
- Dropped Python 3.7 support.
|
| 356 |
+
`#413 <https://github.com/aio-libs/frozenlist/issues/413>`_
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
Misc
|
| 360 |
+
----
|
| 361 |
+
|
| 362 |
+
- `#410 <https://github.com/aio-libs/frozenlist/issues/410>`_, `#433 <https://github.com/aio-libs/frozenlist/issues/433>`_
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
----
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
1.3.3 (2022-11-08)
|
| 369 |
+
==================
|
| 370 |
+
|
| 371 |
+
- Fixed CI runs when creating a new release, where new towncrier versions
|
| 372 |
+
fail when the current version section is already present.
|
| 373 |
+
|
| 374 |
+
|
| 375 |
+
----
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
1.3.2 (2022-11-08)
|
| 379 |
+
==================
|
| 380 |
+
|
| 381 |
+
Misc
|
| 382 |
+
----
|
| 383 |
+
|
| 384 |
+
- Updated the CI runs to better check for test results and to avoid deprecated syntax. `#327 <https://github.com/aio-libs/frozenlist/issues/327>`_
|
| 385 |
+
|
| 386 |
+
|
| 387 |
+
----
|
| 388 |
+
|
| 389 |
+
|
| 390 |
+
1.3.1 (2022-08-02)
|
| 391 |
+
==================
|
| 392 |
+
|
| 393 |
+
The published source distribution package became buildable
|
| 394 |
+
under Python 3.11.
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
----
|
| 398 |
+
|
| 399 |
+
|
| 400 |
+
1.3.0 (2022-01-18)
|
| 401 |
+
==================
|
| 402 |
+
|
| 403 |
+
Bugfixes
|
| 404 |
+
--------
|
| 405 |
+
|
| 406 |
+
- Do not install C sources with binary distributions.
|
| 407 |
+
`#250 <https://github.com/aio-libs/frozenlist/issues/250>`_
|
| 408 |
+
|
| 409 |
+
|
| 410 |
+
Deprecations and Removals
|
| 411 |
+
-------------------------
|
| 412 |
+
|
| 413 |
+
- Dropped Python 3.6 support
|
| 414 |
+
`#274 <https://github.com/aio-libs/frozenlist/issues/274>`_
|
| 415 |
+
|
| 416 |
+
|
| 417 |
+
----
|
| 418 |
+
|
| 419 |
+
|
| 420 |
+
1.2.0 (2021-10-16)
|
| 421 |
+
==================
|
| 422 |
+
|
| 423 |
+
Features
|
| 424 |
+
--------
|
| 425 |
+
|
| 426 |
+
- ``FrozenList`` now supports being used as a generic type as per PEP 585, e.g. ``frozen_int_list: FrozenList[int]`` (requires Python 3.9 or newer).
|
| 427 |
+
`#172 <https://github.com/aio-libs/frozenlist/issues/172>`_
|
| 428 |
+
- Added support for Python 3.10.
|
| 429 |
+
`#227 <https://github.com/aio-libs/frozenlist/issues/227>`_
|
| 430 |
+
- Started shipping platform-specific wheels with the ``musl`` tag targeting typical Alpine Linux runtimes.
|
| 431 |
+
`#227 <https://github.com/aio-libs/frozenlist/issues/227>`_
|
| 432 |
+
- Started shipping platform-specific arm64 wheels for Apple Silicon.
|
| 433 |
+
`#227 <https://github.com/aio-libs/frozenlist/issues/227>`_
|
| 434 |
+
|
| 435 |
+
|
| 436 |
+
----
|
| 437 |
+
|
| 438 |
+
|
| 439 |
+
1.1.1 (2020-11-14)
|
| 440 |
+
==================
|
| 441 |
+
|
| 442 |
+
Bugfixes
|
| 443 |
+
--------
|
| 444 |
+
|
| 445 |
+
- Provide x86 Windows wheels.
|
| 446 |
+
`#169 <https://github.com/aio-libs/frozenlist/issues/169>`_
|
| 447 |
+
|
| 448 |
+
|
| 449 |
+
----
|
| 450 |
+
|
| 451 |
+
|
| 452 |
+
1.1.0 (2020-10-13)
|
| 453 |
+
==================
|
| 454 |
+
|
| 455 |
+
Features
|
| 456 |
+
--------
|
| 457 |
+
|
| 458 |
+
- Add support for hashing of a frozen list.
|
| 459 |
+
`#136 <https://github.com/aio-libs/frozenlist/issues/136>`_
|
| 460 |
+
|
| 461 |
+
- Support Python 3.8 and 3.9.
|
| 462 |
+
|
| 463 |
+
- Provide wheels for ``aarch64``, ``i686``, ``ppc64le``, ``s390x`` architectures on
|
| 464 |
+
Linux as well as ``x86_64``.
|
| 465 |
+
|
| 466 |
+
|
| 467 |
+
----
|
| 468 |
+
|
| 469 |
+
|
| 470 |
+
1.0.0 (2019-11-09)
|
| 471 |
+
==================
|
| 472 |
+
|
| 473 |
+
Deprecations and Removals
|
| 474 |
+
-------------------------
|
| 475 |
+
|
| 476 |
+
- Dropped support for Python 3.5; only 3.6, 3.7 and 3.8 are supported going forward.
|
| 477 |
+
`#24 <https://github.com/aio-libs/frozenlist/issues/24>`_
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/frozenlist-1.5.0.dist-info/RECORD
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
frozenlist-1.5.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
|
| 2 |
+
frozenlist-1.5.0.dist-info/LICENSE,sha256=b9UkPpLdf5jsacesN3co50kFcJ_1J6W_mNbQJjwE9bY,11332
|
| 3 |
+
frozenlist-1.5.0.dist-info/METADATA,sha256=BpQvB7z2NbU3f4XTQDvhAZ9L08WR4XiYajilj9IY6Yk,13762
|
| 4 |
+
frozenlist-1.5.0.dist-info/RECORD,,
|
| 5 |
+
frozenlist-1.5.0.dist-info/WHEEL,sha256=64hRuO2b8JU2aeheZgbK9oQwal3JVqwtqRhpQNr8ZdQ,224
|
| 6 |
+
frozenlist-1.5.0.dist-info/top_level.txt,sha256=jivtxsPXA3nK3WBWW2LW5Mtu_GHt8UZA13NeCs2cKuA,11
|
| 7 |
+
frozenlist/__init__.py,sha256=ymVtnW3MinO-Ux3cBj_PLEpXnmLawk45el8vcX6IkWY,2371
|
| 8 |
+
frozenlist/__init__.pyi,sha256=vMEoES1xGegPtVXoCi9XydEeHsyuIq-KdeXwP5PdsaA,1470
|
| 9 |
+
frozenlist/__pycache__/__init__.cpython-312.pyc,,
|
| 10 |
+
frozenlist/_frozenlist.cpython-312-x86_64-linux-gnu.so,sha256=n65G8t1lqSUcWICd9rjOJujV1lxtniI2JJQQXtc7BjQ,961592
|
| 11 |
+
frozenlist/_frozenlist.pyx,sha256=4YturclNF7wioO7YX3Vzl7Ldb2-iswe6UrjJOMKSswU,2993
|
| 12 |
+
frozenlist/py.typed,sha256=sow9soTwP9T_gEAQSVh7Gb8855h04Nwmhs2We-JRgZM,7
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/frozenlist-1.5.0.dist-info/WHEEL
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Wheel-Version: 1.0
|
| 2 |
+
Generator: setuptools (75.2.0)
|
| 3 |
+
Root-Is-Purelib: false
|
| 4 |
+
Tag: cp312-cp312-manylinux_2_5_x86_64
|
| 5 |
+
Tag: cp312-cp312-manylinux1_x86_64
|
| 6 |
+
Tag: cp312-cp312-manylinux_2_17_x86_64
|
| 7 |
+
Tag: cp312-cp312-manylinux2014_x86_64
|
| 8 |
+
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/frozenlist-1.5.0.dist-info/top_level.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
frozenlist
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/h11-0.14.0.dist-info/INSTALLER
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
pip
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/h11-0.14.0.dist-info/LICENSE.txt
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
The MIT License (MIT)
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2016 Nathaniel J. Smith <njs@pobox.com> and other contributors
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining
|
| 6 |
+
a copy of this software and associated documentation files (the
|
| 7 |
+
"Software"), to deal in the Software without restriction, including
|
| 8 |
+
without limitation the rights to use, copy, modify, merge, publish,
|
| 9 |
+
distribute, sublicense, and/or sell copies of the Software, and to
|
| 10 |
+
permit persons to whom the Software is furnished to do so, subject to
|
| 11 |
+
the following conditions:
|
| 12 |
+
|
| 13 |
+
The above copyright notice and this permission notice shall be
|
| 14 |
+
included in all copies or substantial portions of the Software.
|
| 15 |
+
|
| 16 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
| 17 |
+
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
| 18 |
+
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
|
| 19 |
+
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
|
| 20 |
+
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
|
| 21 |
+
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
|
| 22 |
+
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/h11-0.14.0.dist-info/METADATA
ADDED
|
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Metadata-Version: 2.1
|
| 2 |
+
Name: h11
|
| 3 |
+
Version: 0.14.0
|
| 4 |
+
Summary: A pure-Python, bring-your-own-I/O implementation of HTTP/1.1
|
| 5 |
+
Home-page: https://github.com/python-hyper/h11
|
| 6 |
+
Author: Nathaniel J. Smith
|
| 7 |
+
Author-email: njs@pobox.com
|
| 8 |
+
License: MIT
|
| 9 |
+
Classifier: Development Status :: 3 - Alpha
|
| 10 |
+
Classifier: Intended Audience :: Developers
|
| 11 |
+
Classifier: License :: OSI Approved :: MIT License
|
| 12 |
+
Classifier: Programming Language :: Python :: Implementation :: CPython
|
| 13 |
+
Classifier: Programming Language :: Python :: Implementation :: PyPy
|
| 14 |
+
Classifier: Programming Language :: Python :: 3
|
| 15 |
+
Classifier: Programming Language :: Python :: 3 :: Only
|
| 16 |
+
Classifier: Programming Language :: Python :: 3.7
|
| 17 |
+
Classifier: Programming Language :: Python :: 3.8
|
| 18 |
+
Classifier: Programming Language :: Python :: 3.9
|
| 19 |
+
Classifier: Programming Language :: Python :: 3.10
|
| 20 |
+
Classifier: Topic :: Internet :: WWW/HTTP
|
| 21 |
+
Classifier: Topic :: System :: Networking
|
| 22 |
+
Requires-Python: >=3.7
|
| 23 |
+
License-File: LICENSE.txt
|
| 24 |
+
Requires-Dist: typing-extensions ; python_version < "3.8"
|
| 25 |
+
|
| 26 |
+
h11
|
| 27 |
+
===
|
| 28 |
+
|
| 29 |
+
.. image:: https://travis-ci.org/python-hyper/h11.svg?branch=master
|
| 30 |
+
:target: https://travis-ci.org/python-hyper/h11
|
| 31 |
+
:alt: Automated test status
|
| 32 |
+
|
| 33 |
+
.. image:: https://codecov.io/gh/python-hyper/h11/branch/master/graph/badge.svg
|
| 34 |
+
:target: https://codecov.io/gh/python-hyper/h11
|
| 35 |
+
:alt: Test coverage
|
| 36 |
+
|
| 37 |
+
.. image:: https://readthedocs.org/projects/h11/badge/?version=latest
|
| 38 |
+
:target: http://h11.readthedocs.io/en/latest/?badge=latest
|
| 39 |
+
:alt: Documentation Status
|
| 40 |
+
|
| 41 |
+
This is a little HTTP/1.1 library written from scratch in Python,
|
| 42 |
+
heavily inspired by `hyper-h2 <https://hyper-h2.readthedocs.io/>`_.
|
| 43 |
+
|
| 44 |
+
It's a "bring-your-own-I/O" library; h11 contains no IO code
|
| 45 |
+
whatsoever. This means you can hook h11 up to your favorite network
|
| 46 |
+
API, and that could be anything you want: synchronous, threaded,
|
| 47 |
+
asynchronous, or your own implementation of `RFC 6214
|
| 48 |
+
<https://tools.ietf.org/html/rfc6214>`_ -- h11 won't judge you.
|
| 49 |
+
(Compare this to the current state of the art, where every time a `new
|
| 50 |
+
network API <https://trio.readthedocs.io/>`_ comes along then someone
|
| 51 |
+
gets to start over reimplementing the entire HTTP protocol from
|
| 52 |
+
scratch.) Cory Benfield made an `excellent blog post describing the
|
| 53 |
+
benefits of this approach
|
| 54 |
+
<https://lukasa.co.uk/2015/10/The_New_Hyper/>`_, or if you like video
|
| 55 |
+
then here's his `PyCon 2016 talk on the same theme
|
| 56 |
+
<https://www.youtube.com/watch?v=7cC3_jGwl_U>`_.
|
| 57 |
+
|
| 58 |
+
This also means that h11 is not immediately useful out of the box:
|
| 59 |
+
it's a toolkit for building programs that speak HTTP, not something
|
| 60 |
+
that could directly replace ``requests`` or ``twisted.web`` or
|
| 61 |
+
whatever. But h11 makes it much easier to implement something like
|
| 62 |
+
``requests`` or ``twisted.web``.
|
| 63 |
+
|
| 64 |
+
At a high level, working with h11 goes like this:
|
| 65 |
+
|
| 66 |
+
1) First, create an ``h11.Connection`` object to track the state of a
|
| 67 |
+
single HTTP/1.1 connection.
|
| 68 |
+
|
| 69 |
+
2) When you read data off the network, pass it to
|
| 70 |
+
``conn.receive_data(...)``; you'll get back a list of objects
|
| 71 |
+
representing high-level HTTP "events".
|
| 72 |
+
|
| 73 |
+
3) When you want to send a high-level HTTP event, create the
|
| 74 |
+
corresponding "event" object and pass it to ``conn.send(...)``;
|
| 75 |
+
this will give you back some bytes that you can then push out
|
| 76 |
+
through the network.
|
| 77 |
+
|
| 78 |
+
For example, a client might instantiate and then send a
|
| 79 |
+
``h11.Request`` object, then zero or more ``h11.Data`` objects for the
|
| 80 |
+
request body (e.g., if this is a POST), and then a
|
| 81 |
+
``h11.EndOfMessage`` to indicate the end of the message. Then the
|
| 82 |
+
server would then send back a ``h11.Response``, some ``h11.Data``, and
|
| 83 |
+
its own ``h11.EndOfMessage``. If either side violates the protocol,
|
| 84 |
+
you'll get a ``h11.ProtocolError`` exception.
|
| 85 |
+
|
| 86 |
+
h11 is suitable for implementing both servers and clients, and has a
|
| 87 |
+
pleasantly symmetric API: the events you send as a client are exactly
|
| 88 |
+
the ones that you receive as a server and vice-versa.
|
| 89 |
+
|
| 90 |
+
`Here's an example of a tiny HTTP client
|
| 91 |
+
<https://github.com/python-hyper/h11/blob/master/examples/basic-client.py>`_
|
| 92 |
+
|
| 93 |
+
It also has `a fine manual <https://h11.readthedocs.io/>`_.
|
| 94 |
+
|
| 95 |
+
FAQ
|
| 96 |
+
---
|
| 97 |
+
|
| 98 |
+
*Whyyyyy?*
|
| 99 |
+
|
| 100 |
+
I wanted to play with HTTP in `Curio
|
| 101 |
+
<https://curio.readthedocs.io/en/latest/tutorial.html>`__ and `Trio
|
| 102 |
+
<https://trio.readthedocs.io>`__, which at the time didn't have any
|
| 103 |
+
HTTP libraries. So I thought, no big deal, Python has, like, a dozen
|
| 104 |
+
different implementations of HTTP, surely I can find one that's
|
| 105 |
+
reusable. I didn't find one, but I did find Cory's call-to-arms
|
| 106 |
+
blog-post. So I figured, well, fine, if I have to implement HTTP from
|
| 107 |
+
scratch, at least I can make sure no-one *else* has to ever again.
|
| 108 |
+
|
| 109 |
+
*Should I use it?*
|
| 110 |
+
|
| 111 |
+
Maybe. You should be aware that it's a very young project. But, it's
|
| 112 |
+
feature complete and has an exhaustive test-suite and complete docs,
|
| 113 |
+
so the next step is for people to try using it and see how it goes
|
| 114 |
+
:-). If you do then please let us know -- if nothing else we'll want
|
| 115 |
+
to talk to you before making any incompatible changes!
|
| 116 |
+
|
| 117 |
+
*What are the features/limitations?*
|
| 118 |
+
|
| 119 |
+
Roughly speaking, it's trying to be a robust, complete, and non-hacky
|
| 120 |
+
implementation of the first "chapter" of the HTTP/1.1 spec: `RFC 7230:
|
| 121 |
+
HTTP/1.1 Message Syntax and Routing
|
| 122 |
+
<https://tools.ietf.org/html/rfc7230>`_. That is, it mostly focuses on
|
| 123 |
+
implementing HTTP at the level of taking bytes on and off the wire,
|
| 124 |
+
and the headers related to that, and tries to be anal about spec
|
| 125 |
+
conformance. It doesn't know about higher-level concerns like URL
|
| 126 |
+
routing, conditional GETs, cross-origin cookie policies, or content
|
| 127 |
+
negotiation. But it does know how to take care of framing,
|
| 128 |
+
cross-version differences in keep-alive handling, and the "obsolete
|
| 129 |
+
line folding" rule, so you can focus your energies on the hard /
|
| 130 |
+
interesting parts for your application, and it tries to support the
|
| 131 |
+
full specification in the sense that any useful HTTP/1.1 conformant
|
| 132 |
+
application should be able to use h11.
|
| 133 |
+
|
| 134 |
+
It's pure Python, and has no dependencies outside of the standard
|
| 135 |
+
library.
|
| 136 |
+
|
| 137 |
+
It has a test suite with 100.0% coverage for both statements and
|
| 138 |
+
branches.
|
| 139 |
+
|
| 140 |
+
Currently it supports Python 3 (testing on 3.7-3.10) and PyPy 3.
|
| 141 |
+
The last Python 2-compatible version was h11 0.11.x.
|
| 142 |
+
(Originally it had a Cython wrapper for `http-parser
|
| 143 |
+
<https://github.com/nodejs/http-parser>`_ and a beautiful nested state
|
| 144 |
+
machine implemented with ``yield from`` to postprocess the output. But
|
| 145 |
+
I had to take these out -- the new *parser* needs fewer lines-of-code
|
| 146 |
+
than the old *parser wrapper*, is written in pure Python, uses no
|
| 147 |
+
exotic language syntax, and has more features. It's sad, really; that
|
| 148 |
+
old state machine was really slick. I just need a few sentences here
|
| 149 |
+
to mourn that.)
|
| 150 |
+
|
| 151 |
+
I don't know how fast it is. I haven't benchmarked or profiled it yet,
|
| 152 |
+
so it's probably got a few pointless hot spots, and I've been trying
|
| 153 |
+
to err on the side of simplicity and robustness instead of
|
| 154 |
+
micro-optimization. But at the architectural level I tried hard to
|
| 155 |
+
avoid fundamentally bad decisions, e.g., I believe that all the
|
| 156 |
+
parsing algorithms remain linear-time even in the face of pathological
|
| 157 |
+
input like slowloris, and there are no byte-by-byte loops. (I also
|
| 158 |
+
believe that it maintains bounded memory usage in the face of
|
| 159 |
+
arbitrary/pathological input.)
|
| 160 |
+
|
| 161 |
+
The whole library is ~800 lines-of-code. You can read and understand
|
| 162 |
+
the whole thing in less than an hour. Most of the energy invested in
|
| 163 |
+
this so far has been spent on trying to keep things simple by
|
| 164 |
+
minimizing special-cases and ad hoc state manipulation; even though it
|
| 165 |
+
is now quite small and simple, I'm still annoyed that I haven't
|
| 166 |
+
figured out how to make it even smaller and simpler. (Unfortunately,
|
| 167 |
+
HTTP does not lend itself to simplicity.)
|
| 168 |
+
|
| 169 |
+
The API is ~feature complete and I don't expect the general outlines
|
| 170 |
+
to change much, but you can't judge an API's ergonomics until you
|
| 171 |
+
actually document and use it, so I'd expect some changes in the
|
| 172 |
+
details.
|
| 173 |
+
|
| 174 |
+
*How do I try it?*
|
| 175 |
+
|
| 176 |
+
.. code-block:: sh
|
| 177 |
+
|
| 178 |
+
$ pip install h11
|
| 179 |
+
$ git clone git@github.com:python-hyper/h11
|
| 180 |
+
$ cd h11/examples
|
| 181 |
+
$ python basic-client.py
|
| 182 |
+
|
| 183 |
+
and go from there.
|
| 184 |
+
|
| 185 |
+
*License?*
|
| 186 |
+
|
| 187 |
+
MIT
|
| 188 |
+
|
| 189 |
+
*Code of conduct?*
|
| 190 |
+
|
| 191 |
+
Contributors are requested to follow our `code of conduct
|
| 192 |
+
<https://github.com/python-hyper/h11/blob/master/CODE_OF_CONDUCT.md>`_ in
|
| 193 |
+
all project spaces.
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/h11-0.14.0.dist-info/RECORD
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
h11-0.14.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
|
| 2 |
+
h11-0.14.0.dist-info/LICENSE.txt,sha256=N9tbuFkm2yikJ6JYZ_ELEjIAOuob5pzLhRE4rbjm82E,1124
|
| 3 |
+
h11-0.14.0.dist-info/METADATA,sha256=B7pZ0m7WBXNs17vl6hUH9bJTL9s37DaGvY31w7jNxSg,8175
|
| 4 |
+
h11-0.14.0.dist-info/RECORD,,
|
| 5 |
+
h11-0.14.0.dist-info/WHEEL,sha256=ewwEueio1C2XeHTvT17n8dZUJgOvyCWCt0WVNLClP9o,92
|
| 6 |
+
h11-0.14.0.dist-info/top_level.txt,sha256=F7dC4jl3zeh8TGHEPaWJrMbeuoWbS379Gwdi-Yvdcis,4
|
| 7 |
+
h11/__init__.py,sha256=iO1KzkSO42yZ6ffg-VMgbx_ZVTWGUY00nRYEWn-s3kY,1507
|
| 8 |
+
h11/__pycache__/__init__.cpython-312.pyc,,
|
| 9 |
+
h11/__pycache__/_abnf.cpython-312.pyc,,
|
| 10 |
+
h11/__pycache__/_connection.cpython-312.pyc,,
|
| 11 |
+
h11/__pycache__/_events.cpython-312.pyc,,
|
| 12 |
+
h11/__pycache__/_headers.cpython-312.pyc,,
|
| 13 |
+
h11/__pycache__/_readers.cpython-312.pyc,,
|
| 14 |
+
h11/__pycache__/_receivebuffer.cpython-312.pyc,,
|
| 15 |
+
h11/__pycache__/_state.cpython-312.pyc,,
|
| 16 |
+
h11/__pycache__/_util.cpython-312.pyc,,
|
| 17 |
+
h11/__pycache__/_version.cpython-312.pyc,,
|
| 18 |
+
h11/__pycache__/_writers.cpython-312.pyc,,
|
| 19 |
+
h11/_abnf.py,sha256=ybixr0xsupnkA6GFAyMubuXF6Tc1lb_hF890NgCsfNc,4815
|
| 20 |
+
h11/_connection.py,sha256=eS2sorMD0zKLCFiB9lW9W9F_Nzny2tjHa4e6s1ujr1c,26539
|
| 21 |
+
h11/_events.py,sha256=LEfuvg1AbhHaVRwxCd0I-pFn9-ezUOaoL8o2Kvy1PBA,11816
|
| 22 |
+
h11/_headers.py,sha256=RqB8cd8CN0blYPzcLe5qeCh-phv6D1U_CHj4hs67lgQ,10230
|
| 23 |
+
h11/_readers.py,sha256=EbSed0jzwVUiD1nOPAeUcVE4Flf3wXkxfb8c06-OTBM,8383
|
| 24 |
+
h11/_receivebuffer.py,sha256=xrspsdsNgWFxRfQcTXxR8RrdjRXXTK0Io5cQYWpJ1Ws,5252
|
| 25 |
+
h11/_state.py,sha256=k1VL6SDbaPkSrZ-49ewCXDpuiUS69_46YhbWjuV1qEY,13300
|
| 26 |
+
h11/_util.py,sha256=LWkkjXyJaFlAy6Lt39w73UStklFT5ovcvo0TkY7RYuk,4888
|
| 27 |
+
h11/_version.py,sha256=LVyTdiZRzIIEv79UyOgbM5iUrJUllEzlCWaJEYBY1zc,686
|
| 28 |
+
h11/_writers.py,sha256=oFKm6PtjeHfbj4RLX7VB7KDc1gIY53gXG3_HR9ltmTA,5081
|
| 29 |
+
h11/py.typed,sha256=sow9soTwP9T_gEAQSVh7Gb8855h04Nwmhs2We-JRgZM,7
|
| 30 |
+
h11/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
| 31 |
+
h11/tests/__pycache__/__init__.cpython-312.pyc,,
|
| 32 |
+
h11/tests/__pycache__/helpers.cpython-312.pyc,,
|
| 33 |
+
h11/tests/__pycache__/test_against_stdlib_http.cpython-312.pyc,,
|
| 34 |
+
h11/tests/__pycache__/test_connection.cpython-312.pyc,,
|
| 35 |
+
h11/tests/__pycache__/test_events.cpython-312.pyc,,
|
| 36 |
+
h11/tests/__pycache__/test_headers.cpython-312.pyc,,
|
| 37 |
+
h11/tests/__pycache__/test_helpers.cpython-312.pyc,,
|
| 38 |
+
h11/tests/__pycache__/test_io.cpython-312.pyc,,
|
| 39 |
+
h11/tests/__pycache__/test_receivebuffer.cpython-312.pyc,,
|
| 40 |
+
h11/tests/__pycache__/test_state.cpython-312.pyc,,
|
| 41 |
+
h11/tests/__pycache__/test_util.cpython-312.pyc,,
|
| 42 |
+
h11/tests/data/test-file,sha256=ZJ03Rqs98oJw29OHzJg7LlMzyGQaRAY0r3AqBeM2wVU,65
|
| 43 |
+
h11/tests/helpers.py,sha256=a1EVG_p7xU4wRsa3tMPTRxuaKCmretok9sxXWvqfmQA,3355
|
| 44 |
+
h11/tests/test_against_stdlib_http.py,sha256=cojCHgHXFQ8gWhNlEEwl3trmOpN-5uDukRoHnElqo3A,3995
|
| 45 |
+
h11/tests/test_connection.py,sha256=ZbPLDPclKvjgjAhgk-WlCPBaf17c4XUIV2tpaW08jOI,38720
|
| 46 |
+
h11/tests/test_events.py,sha256=LPVLbcV-NvPNK9fW3rraR6Bdpz1hAlsWubMtNaJ5gHg,4657
|
| 47 |
+
h11/tests/test_headers.py,sha256=qd8T1Zenuz5GbD6wklSJ5G8VS7trrYgMV0jT-SMvqg8,5612
|
| 48 |
+
h11/tests/test_helpers.py,sha256=kAo0CEM4LGqmyyP2ZFmhsyq3UFJqoFfAbzu3hbWreRM,794
|
| 49 |
+
h11/tests/test_io.py,sha256=uCZVnjarkRBkudfC1ij-KSCQ71XWJhnkgkgWWkKgYPQ,16386
|
| 50 |
+
h11/tests/test_receivebuffer.py,sha256=3jGbeJM36Akqg_pAhPb7XzIn2NS6RhPg-Ryg8Eu6ytk,3454
|
| 51 |
+
h11/tests/test_state.py,sha256=rqll9WqFsJPE0zSrtCn9LH659mPKsDeXZ-DwXwleuBQ,8928
|
| 52 |
+
h11/tests/test_util.py,sha256=VO5L4nSFe4pgtSwKuv6u_6l0H7UeizF5WKuHTWreg70,2970
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/h11-0.14.0.dist-info/WHEEL
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Wheel-Version: 1.0
|
| 2 |
+
Generator: bdist_wheel (0.37.0)
|
| 3 |
+
Root-Is-Purelib: true
|
| 4 |
+
Tag: py3-none-any
|
| 5 |
+
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/h11-0.14.0.dist-info/top_level.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
h11
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/multidict-6.1.0.dist-info/INSTALLER
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
pip
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/multidict-6.1.0.dist-info/LICENSE
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Copyright 2016 Andrew Svetlov and aio-libs contributors
|
| 2 |
+
|
| 3 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
you may not use this file except in compliance with the License.
|
| 5 |
+
You may obtain a copy of the License at
|
| 6 |
+
|
| 7 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
|
| 9 |
+
Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
See the License for the specific language governing permissions and
|
| 13 |
+
limitations under the License.
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/multidict-6.1.0.dist-info/METADATA
ADDED
|
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Metadata-Version: 2.1
|
| 2 |
+
Name: multidict
|
| 3 |
+
Version: 6.1.0
|
| 4 |
+
Summary: multidict implementation
|
| 5 |
+
Home-page: https://github.com/aio-libs/multidict
|
| 6 |
+
Author: Andrew Svetlov
|
| 7 |
+
Author-email: andrew.svetlov@gmail.com
|
| 8 |
+
License: Apache 2
|
| 9 |
+
Project-URL: Chat: Matrix, https://matrix.to/#/#aio-libs:matrix.org
|
| 10 |
+
Project-URL: Chat: Matrix Space, https://matrix.to/#/#aio-libs-space:matrix.org
|
| 11 |
+
Project-URL: CI: GitHub, https://github.com/aio-libs/multidict/actions
|
| 12 |
+
Project-URL: Code of Conduct, https://github.com/aio-libs/.github/blob/master/CODE_OF_CONDUCT.md
|
| 13 |
+
Project-URL: Coverage: codecov, https://codecov.io/github/aio-libs/multidict
|
| 14 |
+
Project-URL: Docs: Changelog, https://multidict.aio-libs.org/en/latest/changes/
|
| 15 |
+
Project-URL: Docs: RTD, https://multidict.aio-libs.org
|
| 16 |
+
Project-URL: GitHub: issues, https://github.com/aio-libs/multidict/issues
|
| 17 |
+
Project-URL: GitHub: repo, https://github.com/aio-libs/multidict
|
| 18 |
+
Classifier: Development Status :: 5 - Production/Stable
|
| 19 |
+
Classifier: Intended Audience :: Developers
|
| 20 |
+
Classifier: License :: OSI Approved :: Apache Software License
|
| 21 |
+
Classifier: Programming Language :: Python
|
| 22 |
+
Classifier: Programming Language :: Python :: 3
|
| 23 |
+
Classifier: Programming Language :: Python :: 3.8
|
| 24 |
+
Classifier: Programming Language :: Python :: 3.9
|
| 25 |
+
Classifier: Programming Language :: Python :: 3.10
|
| 26 |
+
Classifier: Programming Language :: Python :: 3.11
|
| 27 |
+
Classifier: Programming Language :: Python :: 3.12
|
| 28 |
+
Classifier: Programming Language :: Python :: 3.13
|
| 29 |
+
Requires-Python: >=3.8
|
| 30 |
+
Description-Content-Type: text/x-rst
|
| 31 |
+
License-File: LICENSE
|
| 32 |
+
Requires-Dist: typing-extensions >=4.1.0 ; python_version < "3.11"
|
| 33 |
+
|
| 34 |
+
=========
|
| 35 |
+
multidict
|
| 36 |
+
=========
|
| 37 |
+
|
| 38 |
+
.. image:: https://github.com/aio-libs/multidict/actions/workflows/ci-cd.yml/badge.svg
|
| 39 |
+
:target: https://github.com/aio-libs/multidict/actions
|
| 40 |
+
:alt: GitHub status for master branch
|
| 41 |
+
|
| 42 |
+
.. image:: https://codecov.io/gh/aio-libs/multidict/branch/master/graph/badge.svg
|
| 43 |
+
:target: https://codecov.io/gh/aio-libs/multidict
|
| 44 |
+
:alt: Coverage metrics
|
| 45 |
+
|
| 46 |
+
.. image:: https://img.shields.io/pypi/v/multidict.svg
|
| 47 |
+
:target: https://pypi.org/project/multidict
|
| 48 |
+
:alt: PyPI
|
| 49 |
+
|
| 50 |
+
.. image:: https://readthedocs.org/projects/multidict/badge/?version=latest
|
| 51 |
+
:target: https://multidict.aio-libs.org
|
| 52 |
+
:alt: Read The Docs build status badge
|
| 53 |
+
|
| 54 |
+
.. image:: https://img.shields.io/pypi/pyversions/multidict.svg
|
| 55 |
+
:target: https://pypi.org/project/multidict
|
| 56 |
+
:alt: Python versions
|
| 57 |
+
|
| 58 |
+
.. image:: https://img.shields.io/matrix/aio-libs:matrix.org?label=Discuss%20on%20Matrix%20at%20%23aio-libs%3Amatrix.org&logo=matrix&server_fqdn=matrix.org&style=flat
|
| 59 |
+
:target: https://matrix.to/#/%23aio-libs:matrix.org
|
| 60 |
+
:alt: Matrix Room — #aio-libs:matrix.org
|
| 61 |
+
|
| 62 |
+
.. image:: https://img.shields.io/matrix/aio-libs-space:matrix.org?label=Discuss%20on%20Matrix%20at%20%23aio-libs-space%3Amatrix.org&logo=matrix&server_fqdn=matrix.org&style=flat
|
| 63 |
+
:target: https://matrix.to/#/%23aio-libs-space:matrix.org
|
| 64 |
+
:alt: Matrix Space — #aio-libs-space:matrix.org
|
| 65 |
+
|
| 66 |
+
Multidict is dict-like collection of *key-value pairs* where key
|
| 67 |
+
might occur more than once in the container.
|
| 68 |
+
|
| 69 |
+
Introduction
|
| 70 |
+
------------
|
| 71 |
+
|
| 72 |
+
*HTTP Headers* and *URL query string* require specific data structure:
|
| 73 |
+
*multidict*. It behaves mostly like a regular ``dict`` but it may have
|
| 74 |
+
several *values* for the same *key* and *preserves insertion ordering*.
|
| 75 |
+
|
| 76 |
+
The *key* is ``str`` (or ``istr`` for case-insensitive dictionaries).
|
| 77 |
+
|
| 78 |
+
``multidict`` has four multidict classes:
|
| 79 |
+
``MultiDict``, ``MultiDictProxy``, ``CIMultiDict``
|
| 80 |
+
and ``CIMultiDictProxy``.
|
| 81 |
+
|
| 82 |
+
Immutable proxies (``MultiDictProxy`` and
|
| 83 |
+
``CIMultiDictProxy``) provide a dynamic view for the
|
| 84 |
+
proxied multidict, the view reflects underlying collection changes. They
|
| 85 |
+
implement the ``collections.abc.Mapping`` interface.
|
| 86 |
+
|
| 87 |
+
Regular mutable (``MultiDict`` and ``CIMultiDict``) classes
|
| 88 |
+
implement ``collections.abc.MutableMapping`` and allows them to change
|
| 89 |
+
their own content.
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
*Case insensitive* (``CIMultiDict`` and
|
| 93 |
+
``CIMultiDictProxy``) assume the *keys* are case
|
| 94 |
+
insensitive, e.g.::
|
| 95 |
+
|
| 96 |
+
>>> dct = CIMultiDict(key='val')
|
| 97 |
+
>>> 'Key' in dct
|
| 98 |
+
True
|
| 99 |
+
>>> dct['Key']
|
| 100 |
+
'val'
|
| 101 |
+
|
| 102 |
+
*Keys* should be ``str`` or ``istr`` instances.
|
| 103 |
+
|
| 104 |
+
The library has optional C Extensions for speed.
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
License
|
| 108 |
+
-------
|
| 109 |
+
|
| 110 |
+
Apache 2
|
| 111 |
+
|
| 112 |
+
Library Installation
|
| 113 |
+
--------------------
|
| 114 |
+
|
| 115 |
+
.. code-block:: bash
|
| 116 |
+
|
| 117 |
+
$ pip install multidict
|
| 118 |
+
|
| 119 |
+
The library is Python 3 only!
|
| 120 |
+
|
| 121 |
+
PyPI contains binary wheels for Linux, Windows and MacOS. If you want to install
|
| 122 |
+
``multidict`` on another operating system (or *Alpine Linux* inside a Docker) the
|
| 123 |
+
tarball will be used to compile the library from source. It requires a C compiler and
|
| 124 |
+
Python headers to be installed.
|
| 125 |
+
|
| 126 |
+
To skip the compilation, please use the `MULTIDICT_NO_EXTENSIONS` environment variable,
|
| 127 |
+
e.g.:
|
| 128 |
+
|
| 129 |
+
.. code-block:: bash
|
| 130 |
+
|
| 131 |
+
$ MULTIDICT_NO_EXTENSIONS=1 pip install multidict
|
| 132 |
+
|
| 133 |
+
Please note, the pure Python (uncompiled) version is about 20-50 times slower depending on
|
| 134 |
+
the usage scenario!!!
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
Changelog
|
| 139 |
+
---------
|
| 140 |
+
See `RTD page <http://multidict.aio-libs.org/en/latest/changes>`_.
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/multidict-6.1.0.dist-info/RECORD
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
multidict-6.1.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
|
| 2 |
+
multidict-6.1.0.dist-info/LICENSE,sha256=k9Ealo4vDzY3PECBH_bSDhc_WMPKtYhM1mF7v9eVSSo,611
|
| 3 |
+
multidict-6.1.0.dist-info/METADATA,sha256=OnCx5DR4XPf64GIDK4XmcA2e7HLQ_784vMfEQy287kM,4979
|
| 4 |
+
multidict-6.1.0.dist-info/RECORD,,
|
| 5 |
+
multidict-6.1.0.dist-info/WHEEL,sha256=3FRagTIevYnyede1Gym_XNKguJrd07UOyEdLNhxNq20,151
|
| 6 |
+
multidict-6.1.0.dist-info/top_level.txt,sha256=-euDElkk5_qkmfIJ7WiqCab02ZlSFZWynejKg59qZQQ,10
|
| 7 |
+
multidict/__init__.py,sha256=p60Ag5UVACSli1txazSi85foCmHN-cg3qZDCuWdOKng,928
|
| 8 |
+
multidict/__init__.pyi,sha256=SbgC2ew1NvNXWlRKs9o0KhW4moozgMqgQ0OA4Re5JQQ,4840
|
| 9 |
+
multidict/__pycache__/__init__.cpython-312.pyc,,
|
| 10 |
+
multidict/__pycache__/_abc.cpython-312.pyc,,
|
| 11 |
+
multidict/__pycache__/_compat.cpython-312.pyc,,
|
| 12 |
+
multidict/__pycache__/_multidict_base.cpython-312.pyc,,
|
| 13 |
+
multidict/__pycache__/_multidict_py.cpython-312.pyc,,
|
| 14 |
+
multidict/_abc.py,sha256=Zvnrn4SBkrv4QTD7-ZzqNcoxw0f8KStLMPzGvBuGT2w,1190
|
| 15 |
+
multidict/_compat.py,sha256=uCNUpVHJSFOiKUJmRcz3SDqMpkb37C_csc29ijr8Evo,352
|
| 16 |
+
multidict/_multidict.cpython-312-x86_64-linux-gnu.so,sha256=6BwP62oLns2chEgPfwAa8DseIoF0wOWBe81pHjnlqhs,418968
|
| 17 |
+
multidict/_multidict_base.py,sha256=ZndtnZ5oc1sODKmXsv6F9kWvVNCda9xAEEFXkaPoFoA,3979
|
| 18 |
+
multidict/_multidict_py.py,sha256=57h4sYrRIu7EjMX4YpHVIZVrV9-q1KCW3F6rao10D3U,15050
|
| 19 |
+
multidict/py.typed,sha256=e9bmbH3UFxsabQrnNFPG9qxIXztwbcM6IKDYnvZwprY,15
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/multidict-6.1.0.dist-info/WHEEL
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Wheel-Version: 1.0
|
| 2 |
+
Generator: setuptools (74.1.2)
|
| 3 |
+
Root-Is-Purelib: false
|
| 4 |
+
Tag: cp312-cp312-manylinux_2_17_x86_64
|
| 5 |
+
Tag: cp312-cp312-manylinux2014_x86_64
|
| 6 |
+
|