BryanW committed on
Commit
f12e61a
·
verified ·
1 Parent(s): a60ff7d

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. Prism/LLaDA/LLaDA_Baseline/dllm_eval/api/__init__.py +0 -0
  2. Prism/LLaDA/LLaDA_Baseline/dllm_eval/api/filter.py +56 -0
  3. Prism/LLaDA/LLaDA_Baseline/dllm_eval/api/group.py +115 -0
  4. Prism/LLaDA/LLaDA_Baseline/dllm_eval/api/instance.py +38 -0
  5. Prism/LLaDA/LLaDA_Baseline/dllm_eval/api/metrics.py +578 -0
  6. Prism/LLaDA/LLaDA_Baseline/dllm_eval/api/model.py +493 -0
  7. Prism/LLaDA/LLaDA_Baseline/dllm_eval/api/registry.py +196 -0
  8. Prism/LLaDA/LLaDA_Baseline/dllm_eval/api/samplers.py +232 -0
  9. Prism/LLaDA/LLaDA_Baseline/dllm_eval/api/task.py +1881 -0
  10. Prism/LLaDA/LLaDA_Baseline/dllm_eval/caching/__init__.py +0 -0
  11. Prism/LLaDA/LLaDA_Baseline/dllm_eval/caching/cache.py +59 -0
  12. Prism/LLaDA/LLaDA_Baseline/dllm_eval/decontamination/__init__.py +0 -0
  13. Prism/LLaDA/LLaDA_Baseline/dllm_eval/decontamination/janitor.py +328 -0
  14. Prism/LLaDA/LLaDA_Baseline/dllm_eval/loggers/__init__.py +2 -0
  15. Prism/LLaDA/LLaDA_Baseline/dllm_eval/loggers/evaluation_tracker.py +530 -0
  16. Prism/LLaDA/LLaDA_Baseline/dllm_eval/loggers/utils.py +149 -0
  17. Prism/LLaDA/LLaDA_Baseline/dllm_eval/loggers/wandb_logger.py +358 -0
  18. Prism/LLaDA/LLaDA_Baseline/dllm_eval/models/LLaDA.py +786 -0
  19. Prism/LLaDA/LLaDA_Baseline/dllm_eval/models/__init__.py +19 -0
  20. Prism/LLaDA/LLaDA_Baseline/dllm_eval/models/dummy.py +41 -0
  21. Prism/LLaDA/LLaDA_Baseline/dllm_eval/models/hts_sampler.py +315 -0
  22. Prism/LLaDA/LLaDA_Baseline/dllm_eval/models/huggingface.py +1489 -0
  23. Prism/LLaDA/LLaDA_Baseline/dllm_eval/models/utils.py +854 -0
  24. Prism/LLaDA/LLaDA_Baseline/dllm_eval/models/verifier.py +154 -0
  25. Prism/LLaDA/LLaDA_Baseline/dllm_eval/prompts/__init__.py +128 -0
  26. Prism/LLaDA/LLaDA_Baseline/dllm_eval/tasks/__init__.py +670 -0
  27. Prism/LLaDA/LLaDA_Baseline/dllm_eval/tasks/gsm8k/gsm8k.yaml +15 -0
  28. Prism/LLaDA/LLaDA_Baseline/dllm_eval/tasks/gsm8k/utils.py +13 -0
  29. Prism/LLaDA/LLaDA_Baseline/dllm_eval/tasks/humaneval/humaneval.yaml +13 -0
  30. Prism/LLaDA/LLaDA_Baseline/dllm_eval/tasks/humaneval/utils.py +43 -0
  31. Prism/LLaDA/LLaDA_Baseline/dllm_eval/tasks/mbpp/mbpp.yaml +14 -0
  32. Prism/LLaDA/LLaDA_Baseline/dllm_eval/tasks/mbpp/utils.py +79 -0
  33. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/certifi/__main__.py +12 -0
  34. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/frozenlist-1.5.0.dist-info/INSTALLER +1 -0
  35. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/frozenlist-1.5.0.dist-info/LICENSE +201 -0
  36. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/frozenlist-1.5.0.dist-info/METADATA +477 -0
  37. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/frozenlist-1.5.0.dist-info/RECORD +12 -0
  38. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/frozenlist-1.5.0.dist-info/WHEEL +8 -0
  39. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/frozenlist-1.5.0.dist-info/top_level.txt +1 -0
  40. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/h11-0.14.0.dist-info/INSTALLER +1 -0
  41. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/h11-0.14.0.dist-info/LICENSE.txt +22 -0
  42. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/h11-0.14.0.dist-info/METADATA +193 -0
  43. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/h11-0.14.0.dist-info/RECORD +52 -0
  44. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/h11-0.14.0.dist-info/WHEEL +5 -0
  45. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/h11-0.14.0.dist-info/top_level.txt +1 -0
  46. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/multidict-6.1.0.dist-info/INSTALLER +1 -0
  47. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/multidict-6.1.0.dist-info/LICENSE +13 -0
  48. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/multidict-6.1.0.dist-info/METADATA +140 -0
  49. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/multidict-6.1.0.dist-info/RECORD +19 -0
  50. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/multidict-6.1.0.dist-info/WHEEL +6 -0
Prism/LLaDA/LLaDA_Baseline/dllm_eval/api/__init__.py ADDED
File without changes
Prism/LLaDA/LLaDA_Baseline/dllm_eval/api/filter.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ from dataclasses import dataclass
3
+ from typing import Callable, Iterable, List, Union
4
+
5
+ from dllm_eval.api.instance import Instance
6
+
7
+
8
class Filter(ABC):
    """Base class for per-task post-processing of model responses.

    A Filter sees the raw model outputs (`instance.resps`) for every
    instance of a task at once and may transform them. Any number of
    filters or filter pipelines can be configured for a single run.
    """

    def __init__(self, **kwargs) -> None:
        """Subclasses may accept configuration kwargs and keep per-filter state."""
        pass

    @abstractmethod
    def apply(self, resps: Union[List, Iterable], docs: List[dict]) -> Iterable:
        """Transform the per-instance response lists.

        `resps` is ordered like the task's instances; implementations must
        return the (filtered) response lists in that exact same order, e.g.
        input [<resps for instance 0>, <resps for instance 1>] yields
        [<filtered resps for instance 0>, <filtered resps for instance 1>].
        """
        return resps
31
+
32
+
33
@dataclass
class FilterEnsemble:
    """
    FilterEnsemble creates a pipeline applying multiple filters.
    Its intended usage is to stack multiple post-processing steps in order.
    `task.apply_filters` should use a list of FilterEnsemble classes that it stores, to apply each
    pipeline separately.
    """

    name: str
    filters: List[Callable[[], Filter]]

    def apply(self, instances: List[Instance]) -> None:
        """Run every filter in sequence and store results under `self.name`.

        Fix: the original `resps, docs = zip(*(...))` raised ValueError when
        `instances` was empty; an empty task is now a harmless no-op.
        """
        # gather per-instance responses and docs, preserving instance order
        resps = [inst.resps for inst in instances]
        docs = [inst.doc for inst in instances]

        for f in self.filters:
            # apply filters in sequence; each entry is a factory producing a fresh Filter
            resps = f().apply(resps, docs)

        # add the end results after filtering to filtered_resps of their respective
        # source instances, keyed by `self.name`: each FilterEnsemble applied in a
        # given run should use a different name.
        for inst, resp in zip(instances, resps):
            inst.filtered_resps[self.name] = resp
Prism/LLaDA/LLaDA_Baseline/dllm_eval/api/group.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import abc
2
+ from dataclasses import asdict, dataclass
3
+ from inspect import getsource
4
+ from typing import Any, Callable, List, Optional, Union
5
+
6
+
7
@dataclass
class AggMetricConfig(dict):
    """Configuration for aggregating one metric across a group's subtasks."""

    metric: Optional[str] = None
    aggregation: Optional[str] = "mean"
    # whether subtask scores are weighted by subtask size when aggregating
    # (fix: annotation was Optional[str] although the default/usage is boolean)
    weight_by_size: Optional[bool] = False
    # list of filter names which should be incorporated into the aggregated metric.
    filter_list: Optional[Union[str, list]] = "none"

    def __post_init__(self):
        # only "mean" (or a user-supplied callable) is supported across subtasks
        if self.aggregation != "mean" and not callable(self.aggregation):
            raise ValueError(
                f"Currently, 'mean' is the only pre-defined aggregation across groups' subtasks. Got '{self.aggregation}'."
            )

        # normalize a bare filter name to a single-element list
        if isinstance(self.filter_list, str):
            self.filter_list = [self.filter_list]


@dataclass
class GroupConfig(dict):
    """Declarative configuration of a task group and its metric aggregation."""

    group: Optional[str] = None
    group_alias: Optional[str] = None
    task: Optional[Union[str, list]] = None
    aggregate_metric_list: Optional[
        Union[List[AggMetricConfig], AggMetricConfig, dict]
    ] = None
    metadata: Optional[dict] = (
        None  # by default, not used in the code. allows for users to pass arbitrary info to tasks
    )

    def __getitem__(self, item):
        # allow dict-style access to dataclass fields
        return getattr(self, item)

    def __setitem__(self, item, value):
        return setattr(self, item, value)

    def __post_init__(self):
        # coerce aggregate_metric_list into a list of AggMetricConfig objects
        if self.aggregate_metric_list is not None:
            if isinstance(self.aggregate_metric_list, dict):
                self.aggregate_metric_list = [self.aggregate_metric_list]

            self.aggregate_metric_list = [
                AggMetricConfig(**item) if isinstance(item, dict) else item
                for item in self.aggregate_metric_list
            ]

    def to_dict(self, keep_callable: bool = False) -> dict:
        """dumps the current config as a dictionary object, as a printable format.
        null fields will not be printed.
        Used for dumping results alongside full task configuration

        :return: dict
            A printable dictionary version of the TaskConfig object.

        # TODO: should any default value in the TaskConfig not be printed?
        """
        cfg_dict = asdict(self)
        # remove values that are `None` (fix: the loop previously only
        # serialized callables and never dropped nulls, contradicting the
        # docstring's "null fields will not be printed")
        for k, v in list(cfg_dict.items()):
            if v is None:
                del cfg_dict[k]
            elif callable(v):
                cfg_dict[k] = self.serialize_function(v, keep_callable=keep_callable)
        return cfg_dict

    def serialize_function(
        self, value: Union[Callable, str], keep_callable=False
    ) -> Union[Callable, str]:
        """Serializes a given function or string.

        If 'keep_callable' is True, the original callable is returned.
        Otherwise, attempts to return the source code of the callable using 'getsource'.
        """
        if keep_callable:
            return value
        else:
            try:
                return getsource(value)
            except (TypeError, OSError):
                # built-ins / dynamically created callables have no source
                return str(value)
    
85
+
86
+
87
class ConfigurableGroup(abc.ABC):
    """Thin wrapper exposing a GroupConfig's fields as read-only properties."""

    def __init__(
        self,
        config: Optional[dict] = None,
    ) -> None:
        # config is expanded into a GroupConfig; a dict is expected here
        self._config = GroupConfig(**config)

    @property
    def group(self):
        return self._config.group

    @property
    def group_alias(self):
        return self._config.group_alias

    @property
    def version(self):
        # Fix: GroupConfig declares no `version` field, so the original
        # attribute access always raised AttributeError; report None instead.
        return getattr(self._config, "version", None)

    @property
    def config(self):
        # printable dict form (see GroupConfig.to_dict)
        return self._config.to_dict()

    @property
    def group_name(self) -> Any:
        return self._config.group

    def __repr__(self):
        return f"ConfigurableGroup(group={self.group},group_alias={self.group_alias})"
Prism/LLaDA/LLaDA_Baseline/dllm_eval/api/instance.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+ from typing import Literal, Optional, Tuple
3
+
4
+
5
# The four request types a task can issue to a model.
OutputType = Literal[
    "loglikelihood", "loglikelihood_rolling", "generate_until", "multiple_choice"
]


@dataclass
class Instance:
    """One evaluation request: a document plus the arguments sent to the model."""

    request_type: OutputType
    doc: dict
    arguments: tuple
    idx: int
    # packed (task_name, doc_id, repeats); unpacked in __post_init__
    metadata: Tuple[Optional[str], Optional[int], Optional[int]] = field(
        default_factory=lambda: (None, None, None)
    )
    resps: list = field(default_factory=list)
    filtered_resps: dict = field(default_factory=dict)

    # initialized after init (mirrors of the metadata tuple)
    task_name: Optional[str] = None
    doc_id: Optional[int] = None
    repeats: Optional[int] = None

    def __post_init__(self) -> None:
        # unpack metadata field into the named attributes
        self.task_name, self.doc_id, self.repeats = self.metadata

    @property
    def args(self):
        """
        Returns (string,) where `string` is the string to calculate loglikelihood over
        """
        if isinstance(self.arguments, tuple):
            return self.arguments
        return (self.arguments,)
Prism/LLaDA/LLaDA_Baseline/dllm_eval/api/metrics.py ADDED
@@ -0,0 +1,578 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import math
3
+ import random
4
+ import re
5
+ import string
6
+ from collections.abc import Iterable
7
+ from typing import List
8
+
9
+ import numpy as np
10
+ import sacrebleu
11
+
12
+ from dllm_eval.api.registry import register_aggregation, register_metric
13
+
14
+
15
+ eval_logger = logging.getLogger(__name__)
16
+
17
+
18
+ # Register Aggregations First
19
# Register Aggregations First
@register_aggregation("bypass")
def bypass_agg(arr):
    # Sentinel aggregation: always 999, marking "bypass" metrics as
    # not-actually-aggregated in results tables.
    return 999


@register_aggregation("nanmean")
def nanmean(arr):
    # Mean that ignores NaN entries; NaN for empty or all-NaN input
    # (guard avoids numpy's "mean of empty slice" warning).
    if len(arr) == 0 or all(np.isnan(arr)):
        return np.nan
    return np.nanmean(arr)


@register_aggregation("mean")
def mean(arr):
    # Plain arithmetic mean; raises ZeroDivisionError on empty input.
    return sum(arr) / len(arr)
34
+
35
+
36
@register_aggregation("median")
def median(arr):
    """Return the (upper) median of *arr*.

    Fix: the original indexed the unsorted list (`arr[len(arr) // 2]`),
    which is just the middle element in arrival order, not the median.
    Sort first so the result is order-independent.
    """
    return sorted(arr)[len(arr) // 2]
39
+
40
+
41
# Certain metrics must be calculated across all documents in a benchmark.
# We use them as aggregation metrics, paired with no-op passthrough metric fns.
@register_aggregation("perplexity")
def perplexity(items):
    # items: log-likelihoods; perplexity = exp(-average log-likelihood)
    return math.exp(-mean(items))


@register_aggregation("weighted_perplexity")
def weighted_perplexity(items):
    # items: (log-likelihood, weight) pairs, see weighted_mean below
    return math.exp(-weighted_mean(items))


@register_aggregation("bits_per_byte")
def bits_per_byte(items):
    # converts the weighted average negative log-likelihood (nats) to bits
    return -weighted_mean(items) / math.log(2)
56
+
57
+
58
@register_aggregation("f1")
def f1_score(items):
    # items: (gold, pred) pairs; binary F1 via scikit-learn.
    from sklearn.metrics import f1_score

    unzipped_list = list(zip(*items))
    golds = unzipped_list[0]
    preds = unzipped_list[1]
    fscore = f1_score(golds, preds)

    # f1_score returns a scalar for binary labels, so np.max is a no-op here
    return np.max(fscore)


@register_aggregation("matthews_corrcoef")
def matthews_corrcoef(items):
    # items: (gold, pred) pairs; Matthews correlation coefficient via scikit-learn.
    from sklearn.metrics import matthews_corrcoef

    unzipped_list = list(zip(*items))
    golds = unzipped_list[0]
    preds = unzipped_list[1]
    return matthews_corrcoef(golds, preds)
78
+
79
+
80
@register_aggregation("bleu")
def bleu(items):
    """The Bilingual Evaluation Understudy Score, or BLEU for short, is a metric
    for evaluating a generated sentence to a reference sentence. It counts matching
    n-grams in the candidate translation to n-grams in the reference text, where
    1-gram or unigram would be each token and a bigram comparison would be each
    word pair. The comparison is made regardless of word order
    Source: https://machinelearningmastery.com/calculate-bleu-score-for-text-python/
    Paper: https://www.aclweb.org/anthology/P02-1040/

    Higher is better
    """
    # items: (reference, prediction) pairs accumulated across the whole task
    refs = list(zip(*items))[0]
    preds = list(zip(*items))[1]
    refs, preds = _sacreformat(refs, preds)
    return sacrebleu.corpus_bleu(preds, refs).score


@register_aggregation("chrf")
def chrf(items):
    """chrF++ is a tool for automatic evaluation of machine translation output
    based on character n-gram precision and recall enhanced with word n-grams.
    Source: https://github.com/m-popovic/chrF
    Paper: https://www.aclweb.org/anthology/W15-3049.pdf

    Higher is better # TODO I think
    """
    # items: (reference, prediction) pairs accumulated across the whole task
    refs = list(zip(*items))[0]
    preds = list(zip(*items))[1]
    refs, preds = _sacreformat(refs, preds)
    return sacrebleu.corpus_chrf(preds, refs).score


@register_aggregation("ter")
def ter(items):
    """Translation Error Rate is an error metric for machine translation that
    measures the number of edits required to change a system output into one
    of the references
    Source: http://www.cs.umd.edu/~snover/tercom/
    Paper: http://mt-archive.info/AMTA-2006-Snover.pdf

    Lower is better
    """
    # items: (reference, prediction) pairs accumulated across the whole task
    refs = list(zip(*items))[0]
    preds = list(zip(*items))[1]
    refs, preds = _sacreformat(refs, preds)
    return sacrebleu.corpus_ter(preds, refs).score
127
+
128
+
129
@register_aggregation("brier_score")
def brier_score(items):  # This is a passthrough function
    """Mean Brier score over (gold_label, class_probabilities) pairs."""
    golds, probs = zip(*items)
    probs = np.asarray(probs)
    num_class = probs.shape[1]

    # one-hot encode the gold labels against the number of classes
    one_hot = np.eye(num_class)[list(golds)]
    return np.mean(np.sum((probs - one_hot) ** 2, axis=1))
137
+
138
+
139
# Passthrough metric fns: per-instance values are collected as-is and the
# paired aggregation (see above) produces the final score.
@register_metric(
    metric="brier_score",
    higher_is_better=False,
    output_type=["multiple_choice"],
    aggregation="brier_score",
)
def brier_score_fn(items):  # This is a passthrough function
    return items


@register_metric(
    metric="acc",
    higher_is_better=True,
    output_type=["loglikelihood", "multiple_choice"],
    aggregation="mean",
)
def acc_fn(items):  # This is a passthrough function
    return items


@register_metric(
    metric="acc_norm",
    higher_is_better=True,
    output_type=["loglikelihood", "multiple_choice"],
    aggregation="mean",
)
def acc_norm_fn(items):  # This is a passthrough function
    return items


@register_metric(
    metric="acc_mutual_info",
    higher_is_better=True,
    output_type="multiple_choice",
    aggregation="mean",
)
def acc_mutual_info_fn(items):  # This is a passthrough function
    return items
177
+
178
+
179
+ ### the code used in the `exact_match_hf_evaluate` function is ported from
180
+ ### https://github.com/huggingface/evaluate/blob/main/metrics/exact_match/exact_match.py
181
+ ### which is under the apache license.
182
+
183
+ # Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
184
+
185
+ # Licensed under the Apache License, Version 2.0 (the "License");
186
+ # you may not use this file except in compliance with the License.
187
+ # You may obtain a copy of the License at
188
+
189
+ # http://www.apache.org/licenses/LICENSE-2.0
190
+
191
+
192
+ # Unless required by applicable law or agreed to in writing, software
193
+ # distributed under the License is distributed on an "AS IS" BASIS,
194
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
195
+ # See the License for the specific language governing permissions and
196
+ # limitations under the License.
197
def exact_match_hf_evaluate(
    predictions,
    references,
    regexes_to_ignore=None,
    ignore_case=False,
    ignore_punctuation=False,
    ignore_numbers=False,
):
    """Fraction of predictions that exactly match their reference.

    Ported from HuggingFace `evaluate`'s exact_match metric (Apache-2.0).
    Optional normalization steps strip regex matches, case, punctuation,
    and/or digits from both sides before comparing.
    """
    if regexes_to_ignore is not None:
        # strip every ignored pattern from both sides
        for pattern in regexes_to_ignore:
            predictions = np.array([re.sub(pattern, "", p) for p in predictions])
            references = np.array([re.sub(pattern, "", r) for r in references])
    else:
        predictions = np.asarray(predictions)
        references = np.asarray(references)

    if ignore_case:
        predictions = np.char.lower(predictions)
        references = np.char.lower(references)

    if ignore_punctuation:
        table = str.maketrans("", "", string.punctuation)
        predictions = np.char.translate(predictions, table=table)
        references = np.char.translate(references, table=table)

    if ignore_numbers:
        table = str.maketrans("", "", string.digits)
        predictions = np.char.translate(predictions, table=table)
        references = np.char.translate(references, table=table)

    # elementwise equality -> mean gives the exact-match rate
    return {"exact_match": np.mean(predictions == references)}
230
+
231
+
232
+ ###
233
+
234
+
235
@register_metric(
    metric="exact_match",
    higher_is_better=True,
    output_type="generate_until",
    aggregation="mean",
)
def exact_match_fn(**kwargs):
    # Thin registry wrapper around the HF-evaluate-ported implementation above.
    return exact_match_hf_evaluate(**kwargs)


# The following are passthrough fns: the paired aggregation does the real work.
@register_metric(
    metric="perplexity",
    higher_is_better=False,
    output_type="loglikelihood",
    aggregation="perplexity",
)
def perplexity_fn(items):  # This is a passthrough function
    return items


@register_metric(
    metric="word_perplexity",
    higher_is_better=False,
    output_type="loglikelihood_rolling",
    aggregation="weighted_perplexity",
)
def word_perplexity_fn(items):  # This is a passthrough function
    return items


@register_metric(
    metric="byte_perplexity",
    higher_is_better=False,
    output_type="loglikelihood_rolling",
    aggregation="weighted_perplexity",
)
def byte_perplexity_fn(items):  # This is a passthrough function
    return items


@register_metric(
    metric="bits_per_byte",
    higher_is_better=False,
    output_type="loglikelihood_rolling",
    aggregation="bits_per_byte",
)
def bits_per_byte_fn(items):  # This is a passthrough function
    return items
283
+
284
+
285
def pop_stddev(arr):
    """Population standard deviation (divides by N)."""
    mu = sum(arr) / len(arr)
    return math.sqrt(sum((x - mu) ** 2 for x in arr) / len(arr))


def sample_stddev(arr):
    """Sample standard deviation with Bessel's correction (divides by N - 1)."""
    mu = sum(arr) / len(arr)
    return math.sqrt(sum((x - mu) ** 2 for x in arr) / (len(arr) - 1))


def mean_stderr(arr):
    """Standard error of the mean: s / sqrt(N)."""
    return sample_stddev(arr) / math.sqrt(len(arr))
297
+
298
+
299
@register_metric(
    metric="bypass",
    higher_is_better=True,
    output_type=["loglikelihood", "multiple_choice", "generate_until"],
    aggregation="bypass",
)
def bypass(items):
    # Deliberately discards per-instance values; paired "bypass" aggregation
    # returns a sentinel so results are recognizable as unscored.
    return None


# Passthrough fns below: the paired aggregation does the real scoring.
@register_metric(
    metric="mcc",
    higher_is_better=True,
    output_type="multiple_choice",
    aggregation="matthews_corrcoef",
)
def mcc_fn(items):  # This is a passthrough function
    return items


@register_metric(
    metric="f1",
    higher_is_better=True,
    output_type="multiple_choice",
    aggregation="f1",
)
def f1_fn(items):  # This is a passthrough function
    return items


@register_metric(
    metric="bleu",
    higher_is_better=True,
    output_type="generate_until",
    aggregation="bleu",
)
def bleu_fn(items):  # This is a passthrough function
    return items


@register_metric(
    metric="chrf",
    higher_is_better=True,
    output_type="generate_until",
    aggregation="chrf",
)
def chrf_fn(items):  # This is a passthrough function
    return items


# NOTE(review): TER is an error rate (the `ter` aggregation's docstring says
# "Lower is better"), yet it is registered with higher_is_better=True here —
# confirm whether this direction flag is intentional.
@register_metric(
    metric="ter",
    higher_is_better=True,
    output_type="generate_until",
    aggregation="ter",
)
def ter_fn(items):  # This is a passthrough function
    return items
357
+
358
+
359
@register_metric(
    metric="acc_all",
    higher_is_better=True,
    output_type="loglikelihood",
    aggregation="mean",
)
def acc_all(items):
    """Accuracy that counts a question correct only if ALL of its answers are.

    items: (pred, doc) pairs; docs carry nested (paragraph, question) indices
    and a 0/1 `label` per candidate answer.
    """
    preds, docs = zip(*items)

    # group per-answer correctness flags by (paragraph, question)
    correct_by_question = {}
    for doc, pred in zip(docs, preds):
        key = (doc["idx"]["paragraph"], doc["idx"]["question"])
        gold_label = doc["label"] == 1
        correct_by_question.setdefault(key, []).append(gold_label == pred)

    return np.mean([int(all(flags)) for flags in correct_by_question.values()])
382
+
383
+
384
def acc_all_stderr(items):
    """Stderr companion to `acc_all`; groups answers by question id only."""
    preds, docs = zip(*items)

    # group per-answer correctness flags by question id
    correct_by_question = {}
    for doc, pred in zip(docs, preds):
        qid = doc["idx"]["question"]
        gold_label = doc["label"] == 1
        correct_by_question.setdefault(qid, []).append(gold_label == pred)

    return mean_stderr([int(all(flags)) for flags in correct_by_question.values()])
400
+
401
+
402
def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    """Compute max metric between prediction and each ground truth."""
    return max(metric_fn(prediction, gt) for gt in ground_truths)
409
+
410
+
411
def weighted_mean(items):
    """Ratio of sums over (numerator, denominator) pairs."""
    numerators, denominators = zip(*items)
    return sum(numerators) / sum(denominators)
414
+
415
+
416
def is_non_str_iterable(obj):
    """Return True if *obj* is iterable but not a plain string."""
    if isinstance(obj, str):
        return False
    return isinstance(obj, Iterable)
418
+
419
+
420
+ def _sacreformat(refs, preds):
421
+ """Format refs and preds for sacrebleu corpus calculation. It is very particular"""
422
+ # Sacrebleu expects (List[str], List[List[str])
423
+ # e.g. sacrebleu.corpus_bleu([pred_t], [[ref1_stream], [ref2_stream], ...])
424
+
425
+ # Note [ref1_stream] is the first reference for each pred.
426
+ # So lists are size N and (M, N) for N preds and M possible refs for each pred
427
+ # This is a different order of dimensions that I would expect
428
+
429
+ # We expect refs to be List[str] or List[List[str]], the outer list corresponding to preds
430
+ # Must become List[List[str]] with the inner list corresponding to preds
431
+ if not is_non_str_iterable(refs):
432
+ refs = list(refs)
433
+ if not is_non_str_iterable(refs[0]):
434
+ refs = [[ref] for ref in refs]
435
+ refs = list(zip(*refs))
436
+ # Note the number of refs in each ref list much match the number of preds
437
+
438
+ # We expect preds to be List[str] or List[List[str]]. Must become List[str]
439
+ if not is_non_str_iterable(preds):
440
+ preds = list(preds)
441
+ if is_non_str_iterable(preds[0]):
442
+ assert len(preds[0]) == 1, f"Pred must be a str, was {preds[0]}"
443
+ preds = [pred[0] for pred in preds]
444
+
445
+ return refs, preds
446
+
447
+
448
+ # stderr stuff
449
+
450
+
451
+ class _bootstrap_internal:
452
+ def __init__(self, f, n) -> None:
453
+ self.f = f
454
+ self.n = n
455
+
456
+ def __call__(self, v):
457
+ i, xs = v
458
+ rnd = random.Random()
459
+ rnd.seed(i)
460
+ res = []
461
+ for _ in range(self.n):
462
+ res.append(self.f(rnd.choices(xs, k=len(xs))))
463
+ return res
464
+
465
+
466
def bootstrap_stderr(f, xs, iters):
    """Estimate the standard error of statistic `f` over sample `xs` by bootstrap.

    Fans out ~`iters` resamples across a process pool in chunks (each chunk
    deterministically seeded by its index), then returns the sample stddev
    of the bootstrap distribution. Side effects: spawns a pool sized to
    cpu_count, prints progress via tqdm.
    """
    import multiprocessing as mp

    pool = mp.Pool(mp.cpu_count())
    # this gives a biased estimate of the stderr (i.e w/ the mean, it gives something
    # equivalent to stderr calculated without Bessel's correction in the stddev.
    # Unfortunately, I haven't been able to figure out what the right correction is
    # to make the bootstrap unbiased - i considered multiplying by sqrt(n/(n-1)) but
    # that would be ad-hoc and I can't prove that that would actually be an unbiased estimator)
    # Thankfully, shouldn't matter because our samples are pretty big usually anyways
    res = []
    chunk_size = min(1000, iters)
    from tqdm import tqdm

    print("bootstrapping for stddev:", f.__name__)
    # NOTE: integer division means at most (iters // chunk_size) * chunk_size
    # resamples are drawn; iters below 1000 that don't divide evenly are truncated.
    for bootstrap in tqdm(
        pool.imap(
            _bootstrap_internal(f, chunk_size),
            [(i, xs) for i in range(iters // chunk_size)],
        ),
        total=iters // chunk_size,
    ):
        # sample w replacement
        res.extend(bootstrap)

    pool.close()
    return sample_stddev(res)
493
+
494
+
495
def stderr_for_metric(metric, bootstrap_iters: int):
    """Return a stderr function for `metric`, or None if none applies.

    Matching is by function *identity* against the aggregations defined in
    this module, so only those module-level functions get a stderr.
    """
    if bootstrap_iters <= 0:
        # return no function (don't compute stderr) if bootstrap iters = 0
        return None

    # aggregations whose stderr we estimate by bootstrap resampling
    bootstrappable = [
        median,
        matthews_corrcoef,
        f1_score,
        perplexity,
        bleu,
        chrf,
        ter,
        nanmean,
    ]

    if metric in bootstrappable:
        return lambda x: bootstrap_stderr(metric, x, iters=bootstrap_iters)

    # closed-form stderrs for the remaining known aggregations
    stderr = {mean: mean_stderr, acc_all: acc_all_stderr}

    return stderr.get(metric, None)
517
+
518
+
519
def pooled_sample_stderr(stderrs: List[float], sizes: List[int]):
    """Aggregate per-subtask stderrs into one stderr, weighting by subtask size.

    Used for bootstrapped stderrs across subtasks in a group. Formula source:
    https://en.wikipedia.org/wiki/Pooled_variance (see also
    https://stats.stackexchange.com/a/4841331); empirically matches running
    `stderr_for_metric` on all subtask instances concatenated together.
    """
    assert len(stderrs) == len(sizes)

    total = sum(sizes)
    pooled_sample_var = sum(
        (size - 1) * stderr**2 * size for size, stderr in zip(sizes, stderrs)
    ) / (total - len(sizes))

    return np.sqrt(pooled_sample_var / total)
535
+
536
+
537
def combined_sample_stderr(stderrs: List[float], sizes: List[int], metrics=None):
    """Pairwise-combined stderr across subtasks; needs each subtask's mean score.

    Kept for reference only — see the inline warnings; prefer
    `pooled_sample_stderr` for group aggregation.
    """
    assert metrics is not None, (
        "Need to pass a list of each subtask's metric for this stderr aggregation"
    )
    assert len(stderrs) == len(sizes) and len(sizes) == len(metrics)

    # See https://github.com/EleutherAI/lm-evaluation-harness/pull/1390 for more documentation.
    # This formula depends on sample means.
    # removed because it seems to give erroneously huge stderrs for groupings of tasks
    # and does not seem to match up with bootstrap-calculated stderrs for groups.

    ### don't use this unless a statistician has told you it's the right thing to do ###

    # accumulators: we'll aggregate pairwise N - 1 times
    variance = stderrs[0] ** 2
    curr_size = sizes[0]
    curr_score = metrics[0]

    # NOTE(review): `curr_size` is never updated inside the loop, so for more
    # than two subtasks every merge still uses the first subtask's size —
    # confirm whether that is intentional.
    for stderr, size, score in zip(stderrs[1:], sizes[1:], metrics[1:]):
        curr_score = ((curr_score * curr_size) + (score * size)) / (
            curr_size + size
        )  # NOTE: this assumes our aggregation fn is "mean"

        variance = ((curr_size - 1) * variance + (size - 1) * (stderr**2)) / (
            curr_size + size - 1
        ) + curr_size * size / ((curr_size + size) * (curr_size + size - 1)) * (
            curr_score - score
        ) ** 2

    return np.sqrt(variance)
567
+
568
+
569
def aggregate_subtask_metrics(metrics, sizes, weight_by_size=True):
    """Weighted (or unweighted) mean of subtask scores for group aggregation.

    TODO: does not hold for non-mean aggregations.
    """
    # unweighted mean == every subtask weighted equally
    weights = sizes if weight_by_size else [1] * len(sizes)

    assert len(metrics) == len(weights)

    return sum(m * w for m, w in zip(metrics, weights)) / sum(weights)
Prism/LLaDA/LLaDA_Baseline/dllm_eval/api/model.py ADDED
@@ -0,0 +1,493 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import abc
2
+ import hashlib
3
+ import json
4
+ import logging
5
+ import os
6
+ from typing import Dict, List, Optional, Tuple, Type, TypeVar, Union
7
+
8
+ import transformers
9
+ from sqlitedict import SqliteDict
10
+ from tqdm import tqdm
11
+
12
+ from dllm_eval import utils
13
+
14
+
15
+ eval_logger = logging.getLogger(__name__)
16
+
17
+ T = TypeVar("T", bound="LM")
18
+
19
+
20
class LM(abc.ABC):
    """Base interface that every language-model backend must implement.

    LMs take text (strings) as input and yield strings as output, so the
    interface is tokenization-agnostic.
    """

    def __init__(self) -> None:
        # Single-process defaults; distributed subclasses overwrite these.
        self._rank = 0
        self._world_size = 1
        self.cache_hook = CacheHook(None)

    @abc.abstractmethod
    def loglikelihood(self, requests) -> List[Tuple[float, bool]]:
        """Compute the log-likelihood of a continuation given a context.

        Downstream tasks should prefer this over other LM calls when possible.

        :param requests: list[Instance]
            Each instance's ``args`` is ``(context, continuation)``. ``context``
            may be the empty string; any word-boundary space belongs to the
            continuation (e.g. context="hello", continuation=" world").
        :return: list[tuple[float, bool]]
            Per request: ``logprob``, the log probability of the continuation,
            and ``isgreedy``, whether greedy sampling from ``context`` would
            generate the continuation.
        """
        pass

    @abc.abstractmethod
    def loglikelihood_rolling(self, requests) -> List[float]:
        """Compute the full, untruncated log-likelihood of a string (for perplexity).

        - Uses the model's full max context length; longer inputs are split into
          chunks of at most that length.
        - Each document is scored *separately* (never concatenated with others),
          and every token is predicted exactly once.
        - Context is maximized: the final chunk still receives a full-sized
          context even though only its trailing tokens are scored.

        Example (max context 4, prefix BOS/EOS, input tokens [0..9]):
            INPUT: BOS 0 1 2   PRED: 0 1 2 3
            INPUT:   3 4 5 6   PRED: 4 5 6 7
            INPUT:   5 6 7 8   PRED: 8 9

        :param requests: list[Instance]
            Each instance's ``args`` is ``(string,)`` — the text to score.
        :return: list[tuple[float]]
            Per request: the log probability of the string conditioned on the
            BOS/EOS token (overridable via ``prefix_token_id``).
        """
        pass

    # TODO: Add an optional max length
    @abc.abstractmethod
    def generate_until(self, requests) -> List[str]:
        """Generate greedily from each context until a stopping sequence is hit.

        :param requests: list[Instance]
            Each instance's ``args`` is ``(context, gen_kwargs)``, where
            ``gen_kwargs`` is a dict of generation options (top_k, until, ...).
        :return: list[str]
            One generated continuation per request.
        """
        pass

    def apply_chat_template(
        self, chat_history: List[Dict[str, str]], add_generation_prompt=True
    ) -> str:
        """Render a chat history into a prompt string usable as LM input.

        :param chat_history: list[dict[str, str]]
            Messages as dicts with 'role' and 'content' string values.
        :param add_generation_prompt: bool
            Whether to append the assistant generation prefix (e.g.
            <|assistant|>); pass False when prefilling an assistant message.
        :return: str
            The rendered prompt.
        """
        raise NotImplementedError(
            "To use this model with chat templates, please implement the 'apply_chat_template' method for your model type."
        )

    @classmethod
    def create_from_arg_string(
        cls: Type[T], arg_string: str, additional_config: Optional[dict] = None
    ) -> T:
        """Build an instance from a ``key1=value1,key2=value2`` argument string,
        merged with the non-None entries of ``additional_config``.

        :return: an instance of the LM subclass.
        """
        parsed = utils.simple_parse_args_string(arg_string)
        overrides = {
            k: v for k, v in (additional_config or {}).items() if v is not None
        }
        return cls(**parsed, **overrides)

    @classmethod
    def create_from_arg_obj(
        cls: Type[T], arg_dict: dict, additional_config: Optional[dict] = None
    ) -> T:
        """Build an instance from an argument dict, merged with the non-None
        entries of ``additional_config``.

        :return: an instance of the LM subclass.
        """
        overrides = {
            k: v for k, v in (additional_config or {}).items() if v is not None
        }
        return cls(**arg_dict, **overrides)

    @property
    def rank(self):
        # Hardcoded for API-style models which neither support nor expect
        # multi-device parallelism.
        return self._rank

    @property
    def world_size(self):
        # Hardcoded for API-style models which neither support nor expect
        # multi-device parallelism.
        return self._world_size

    @property
    def tokenizer_name(self) -> str:
        """Name of the tokenizer / chat template in use.

        Must be defined by subclasses implementing chat templating; used only
        to fingerprint caches under ``--cache_requests``.
        """
        raise NotImplementedError(
            "To use this model with chat templates, please implement the 'tokenizer_name' property."
        )

    def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
        """Return the chat template structure for user/assistant messages.

        Intended to be overridden by subclasses that define a real template;
        this default (for models without chat-template support) returns the
        empty string.
        """
        return ""

    def set_cache_hook(self, cache_hook) -> None:
        self.cache_hook = cache_hook
209
+
210
+
211
### SQLite-based caching of LM responses
def hash_args(attr, args):
    """Deterministic cache key: SHA-256 hex digest of the JSON-encoded
    ``[attr, *args]`` list (``attr`` is the LM method name)."""
    payload = json.dumps([attr, *args])
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()
215
+
216
+
217
class CacheHook:
    """Write-through hook letting an LM record fresh results into a CachingLM store."""

    def __init__(self, cachinglm) -> None:
        # A hook built with no CachingLM is a no-op sink.
        self.dbdict = None if cachinglm is None else cachinglm.dbdict

    def add_partial(self, attr, req, res) -> None:
        """Store one (request -> result) pair under the LM method name ``attr``."""
        if self.dbdict is None:
            return
        self.dbdict[hash_args(attr, req)] = res
230
+
231
+
232
class CachingLM:
    """Transparent SQLite-backed cache around an LM's request-level methods."""

    def __init__(self, lm, cache_db) -> None:
        """LM wrapper that returns cached results if they exist, and uses the underlying LM if not.

        :param lm: LM
            Underlying LM
        :param cache_db: str
            Path to cache db
        """
        self.lm = lm
        self.cache_db = cache_db
        if os.path.dirname(cache_db):
            os.makedirs(os.path.dirname(cache_db), exist_ok=True)
        self.dbdict = SqliteDict(cache_db, autocommit=True)

        # add hook to lm so the wrapped model writes fresh results back into this cache
        lm.set_cache_hook(self.get_cache_hook())

    def __getattr__(self, attr: str):
        # Intercept only the three request-level entry points; every other
        # attribute access is forwarded to the wrapped LM unchanged.
        lm_attr = getattr(self.lm, attr)
        if attr not in ["loglikelihood", "loglikelihood_rolling", "generate_until"]:
            eval_logger.debug(f"Passing through attribute '{attr}' to underlying LM")
            return lm_attr

        def fn(requests):
            res = []
            remaining_reqs = []
            warned = False
            # figure out which ones are cached and which ones are new
            eval_logger.info(
                f"Loading '{attr}' responses from cache '{self.cache_db}' where possible..."
            )
            for req in tqdm(requests, desc="Checking cached requests"):
                hsh = hash_args(attr, req.args)
                if attr == "generate_until" and req.args[1].get("do_sample", False):
                    # when we are doing non-greedy generation, don't use the cache
                    # (else every "randomly sampled" generation would be identical for repeats > 1).
                    if not warned:
                        eval_logger.warning(
                            f"Arguments to lm.generate_until() '{req.args[1]}' include non-deterministic sampling. Caching will not be performed for such requests."
                        )
                        warned = True
                    res.append(None)
                    remaining_reqs.append(req)
                elif hsh in self.dbdict:
                    ob = self.dbdict[hsh]

                    assert ob is not None

                    res.append(ob)
                else:
                    # placeholder; filled in below once the LM has answered
                    res.append(None)
                    remaining_reqs.append(req)
            eval_logger.info(
                f"Cached requests: {len(requests) - len(remaining_reqs)}, Requests remaining: {len(remaining_reqs)}"
            )
            if remaining_reqs:
                # actually run the LM on the requests that do not have cached results
                rem_res = getattr(self.lm, attr)(remaining_reqs)
            else:
                rem_res = []

            # stick the new ones back into the list and also cache any of the new ones
            resptr = 0
            for req, r in zip(remaining_reqs, rem_res):
                # advance to the next placeholder slot left by a cache miss
                while res[resptr] is not None:
                    resptr += 1

                res[resptr] = r

                # caching
                hsh = hash_args(attr, req.args)
                self.dbdict[hsh] = r
            self.dbdict.commit()

            return res

        return fn

    def get_cache_hook(self):
        """Return a CacheHook bound to this wrapper's backing store."""
        return CacheHook(self)
313
+
314
+
315
class TemplateLM(LM):
    """
    A class acting as intermediary between the LM base class
    and boilerplate often included in other LM subclasses.
    """

    # Subclasses supporting chat templating set this to their HF tokenizer.
    tokenizer = None

    @property
    @abc.abstractmethod
    def eot_token_id(self):
        # End-of-text token id for this model's vocabulary.
        pass

    @property
    def prefix_token_id(self):
        # it is used as prefix for loglikelihood
        return self.eot_token_id

    @abc.abstractmethod
    def tok_encode(self, string: str, **kwargs) -> List[int]:
        """
        Tokenize a string using the model's tokenizer and return a list of token IDs.
        """
        pass

    @abc.abstractmethod
    def _loglikelihood_tokens(self, requests, **kwargs) -> List[Tuple[float, bool]]:
        # Scores pre-tokenized (context, continuation) requests; implemented by subclasses.
        pass

    def _encode_pair(
        self, context: str, continuation: str
    ) -> Tuple[List[int], List[int]]:
        # Move trailing whitespace of the context onto the continuation, so the
        # boundary tokenizes the same way the joined text would.
        n_spaces = len(context) - len(context.rstrip())
        if n_spaces > 0:
            continuation = context[-n_spaces:] + continuation
            context = context[:-n_spaces]

        model_class = getattr(self, "AUTO_MODEL_CLASS", None)

        if model_class == transformers.AutoModelForSeq2SeqLM:
            # Encoder-decoder models: encode the two pieces independently.
            context_enc = self.tok_encode(context)
            continuation_enc = self.tok_encode(continuation, add_special_tokens=False)
        else:
            # Decoder-only models: encode the joined text and split at the
            # context's token length, so merged boundary tokens fall on the
            # continuation side.
            whole_enc = self.tok_encode(context + continuation)
            context_enc = self.tok_encode(context)

            context_enc_len = len(context_enc)
            continuation_enc = whole_enc[context_enc_len:]

        return context_enc, continuation_enc

    def loglikelihood(
        self, requests, disable_tqdm: bool = False
    ) -> List[Tuple[float, bool]]:
        # Tokenize each (context, continuation) pair, then defer to the
        # subclass's token-level scorer.
        new_reqs = []
        for context, continuation in [req.args for req in requests]:
            if context == "":
                # BOS or EOS as context
                context_enc, continuation_enc = (
                    [self.prefix_token_id],
                    self.tok_encode(continuation),
                )
            else:
                context_enc, continuation_enc = self._encode_pair(context, continuation)

            new_reqs.append(((context, continuation), context_enc, continuation_enc))

        return self._loglikelihood_tokens(new_reqs, disable_tqdm=disable_tqdm)

    @abc.abstractmethod
    def loglikelihood_rolling(
        self, requests, disable_tqdm: bool = False
    ) -> List[float]:
        pass

    @abc.abstractmethod
    def generate_until(self, requests, disable_tqdm: bool = False) -> List[str]:
        pass

    def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
        """
        Set and get the appropriate chat template for the model.
        This method sets the tokenizer's chat_template and returns the template string for reproducibility.

        The template selection logic is adapted from the Transformers library's
        `apply_chat_template` method; original implementation:
        https://github.com/huggingface/transformers/blob/fc35907f95459d7a6c5281dfadd680b6f7b620e3/src/transformers/tokenization_utils_base.py#L1687

        Selection rules:
        0. No 'tokenizer' attribute: a single provider-side template is assumed; returns "".
        1. Tokenizer holds a dict of templates: use the named one if given, else
           the "default" entry; raise ValueError otherwise.
        2. Tokenizer holds a single template (or none): use
           `tokenizer.chat_template`, falling back to the legacy class-level
           `default_chat_template`.

        Args:
            chat_template (Union[bool, str]): False/None -> no template;
                True -> default/only template; str -> template of that name.

        Returns:
            Optional[str]: the selected template, or None if no template is applied.
        """
        if self.tokenizer is None:
            return ""

        if chat_template is False or chat_template is None:
            eval_logger.warning(
                "model.chat_template was called with the chat_template set to False or None. "
                "Therefore no chat template will be applied. Make sure this is an intended behavior."
            )
            return None

        # Convert boolean chat_template to None to ensure compatibility with the adapted logic
        if isinstance(chat_template, bool):
            chat_template = None
        using_default_template = False

        # First, handle the cases when the model has a dict of multiple templates
        try:
            template = (
                self.tokenizer.chat_template or self.tokenizer.default_chat_template
            )
        except AttributeError:
            # Tokenizer exposes neither attribute: no template available.
            return None

        if isinstance(template, dict):
            using_default_dict = self.tokenizer.chat_template is None

            if chat_template is not None:
                if chat_template in template:
                    selected_template = template[chat_template]
                    if using_default_dict:
                        using_default_template = True
                else:
                    raise ValueError(
                        f"The specified chat template '{chat_template}' is not available. "
                        f"Available template names are {sorted(template.keys())}."
                    )
            else:
                # If user didn't pass a chat template, use the default template from the dict
                if "default" in template:
                    selected_template = template["default"]
                    using_default_template = True
                else:
                    raise ValueError(
                        "This model has multiple chat templates with no default specified! Please either pass a chat "
                        "template or the name of the template you wish to use to the `chat_template` argument. Available "
                        f"template names are {sorted(template.keys())}."
                    )

        # Cases when the model has a single template or no template
        else:
            # priority: `chat_template` argument > `tokenizer.chat_template` > `tokenizer.default_chat_template
            if isinstance(chat_template, str):
                eval_logger.warning(
                    "Chat template name provided, but the tokenizer's chat template is not a dictionary. "
                    "Using the tokenizer's chat template or the default template instead."
                )
            if self.tokenizer.chat_template is not None:
                selected_template = self.tokenizer.chat_template
            else:
                selected_template = self.tokenizer.default_chat_template
                using_default_template = True

        if using_default_template:
            eval_logger.warning(
                "No chat template is set for this tokenizer, falling back to a default class-level template. This is "
                "very error-prone, because models are often trained with templates different from the class default! "
                "Default chat templates are a legacy feature and will be removed in Transformers v4.43, at which "
                "point any code depending on them will stop working. We recommend setting a valid chat template before "
                "then to ensure that this model continues working without issues."
            )

        return selected_template
Prism/LLaDA/LLaDA_Baseline/dllm_eval/api/registry.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from typing import Callable, Dict, Union
3
+
4
+ import evaluate as hf_evaluate
5
+
6
+ from dllm_eval.api.model import LM
7
+
8
+
9
+ eval_logger = logging.getLogger(__name__)
10
+
11
# Alias -> LM subclass mapping populated by @register_model.
MODEL_REGISTRY = {}


def register_model(*names):
    """Class decorator registering an LM subclass under one or more aliases."""

    def _bind(cls):
        for name in names:
            assert issubclass(cls, LM), (
                f"Model '{name}' ({cls.__name__}) must extend LM class"
            )

            assert name not in MODEL_REGISTRY, (
                f"Model named '{name}' conflicts with existing model! Please register with a non-conflicting alias instead."
            )

            MODEL_REGISTRY[name] = cls
        return cls

    return _bind


def get_model(model_name):
    """Look up a registered model class by alias; raise ValueError if unknown."""
    if model_name in MODEL_REGISTRY:
        return MODEL_REGISTRY[model_name]
    raise ValueError(
        f"Attempted to load model '{model_name}', but no model for this name found! Supported model names: {', '.join(MODEL_REGISTRY.keys())}"
    )
41
+
42
+
43
# Task name -> constructor fn; group name -> member task names; plus reverse
# index from function __name__ to registered task alias.
TASK_REGISTRY = {}
GROUP_REGISTRY = {}
ALL_TASKS = set()
func2task_index = {}


def register_task(name):
    """Decorator: register a task-construction function under ``name``."""

    def _bind(fn):
        assert name not in TASK_REGISTRY, (
            f"task named '{name}' conflicts with existing registered task!"
        )

        TASK_REGISTRY[name] = fn
        ALL_TASKS.add(name)
        func2task_index[fn.__name__] = name
        return fn

    return _bind


def register_group(name):
    """Decorator: add an already-registered task to the group ``name``.

    Must be applied after ``register_task`` on the same function, since the
    task alias is looked up through ``func2task_index``.
    """

    def _bind(fn):
        task_name = func2task_index[fn.__name__]
        GROUP_REGISTRY.setdefault(name, []).append(task_name)
        ALL_TASKS.add(name)
        return fn

    return _bind
74
+
75
+
76
OUTPUT_TYPE_REGISTRY = {}
METRIC_REGISTRY = {}
METRIC_AGGREGATION_REGISTRY = {}
AGGREGATION_REGISTRY: Dict[str, Callable[[], Dict[str, Callable]]] = {}
HIGHER_IS_BETTER_REGISTRY = {}
FILTER_REGISTRY = {}

# Default metrics applied when a task config specifies none, keyed by the
# task's request/output type.
DEFAULT_METRIC_REGISTRY = {
    "loglikelihood": [
        "perplexity",
        "acc",
    ],
    "loglikelihood_rolling": ["word_perplexity", "byte_perplexity", "bits_per_byte"],
    "multiple_choice": ["acc", "acc_norm"],
    "generate_until": ["exact_match"],
}


def register_metric(**args):
    """Decorator factory registering a metric function plus its metadata.

    Keyword arguments:
        metric: name to register the decorated function under (required).
        higher_is_better: optional bool, recorded in HIGHER_IS_BETTER_REGISTRY.
        aggregation: optional name of a registered aggregation; the resolved
            callable is recorded in METRIC_AGGREGATION_REGISTRY.
    """

    # TODO: do we want to enforce a certain interface to registered metrics?
    def decorate(fn):
        assert "metric" in args
        name = args["metric"]

        for key, registry in [
            ("metric", METRIC_REGISTRY),
            ("higher_is_better", HIGHER_IS_BETTER_REGISTRY),
            ("aggregation", METRIC_AGGREGATION_REGISTRY),
        ]:
            if key in args:
                value = args[key]
                # Fixed: the duplicate check must test the *metric name* against
                # the registry keys. The previous code tested `value`, which for
                # "higher_is_better"/"aggregation" compared a bool/aggregation
                # name against metric-name keys and could fire spuriously (e.g.
                # a metric sharing its name with an aggregation).
                assert name not in registry, (
                    f"{key} for metric '{name}' conflicts with an existing registered {key}!"
                )

                if key == "metric":
                    registry[name] = fn
                elif key == "aggregation":
                    registry[name] = AGGREGATION_REGISTRY[value]
                else:
                    registry[name] = value

        return fn

    return decorate
121
+
122
+
123
def get_metric(name: str, hf_evaluate_metric=False) -> Callable:
    """Resolve a metric function by name.

    Checks the local METRIC_REGISTRY first (unless ``hf_evaluate_metric`` is
    True), then falls back to loading from the HF Evaluate library. Logs an
    error and returns None when neither source knows the name.
    """
    if not hf_evaluate_metric:
        if name in METRIC_REGISTRY:
            return METRIC_REGISTRY[name]
        eval_logger.warning(
            f"Could not find registered metric '{name}' in lm-eval, searching in HF Evaluate library..."
        )

    try:
        return hf_evaluate.load(name).compute
    except Exception:
        eval_logger.error(
            f"{name} not found in the evaluate library! Please check https://huggingface.co/evaluate-metric",
        )
139
+
140
+
141
def register_aggregation(name: str):
    """Decorator: register an aggregation function under ``name``."""

    def _bind(fn):
        assert name not in AGGREGATION_REGISTRY, (
            f"aggregation named '{name}' conflicts with existing registered aggregation!"
        )
        AGGREGATION_REGISTRY[name] = fn
        return fn

    return _bind
151
+
152
+
153
def get_aggregation(name: str) -> Callable[[], Dict[str, Callable]]:
    """Return the aggregation registered under ``name`` (None + warning if absent)."""
    if name in AGGREGATION_REGISTRY:
        return AGGREGATION_REGISTRY[name]
    eval_logger.warning(f"{name} not a registered aggregation metric!")
158
+
159
+
160
def get_metric_aggregation(name: str) -> Callable[[], Dict[str, Callable]]:
    """Return the default aggregation assigned to metric ``name`` (None + warning if absent)."""
    if name in METRIC_AGGREGATION_REGISTRY:
        return METRIC_AGGREGATION_REGISTRY[name]
    eval_logger.warning(f"{name} metric is not assigned a default aggregation!")
165
+
166
+
167
def is_higher_better(metric_name) -> bool:
    """Whether larger values of ``metric_name`` are better (None + warning if unknown)."""
    if metric_name in HIGHER_IS_BETTER_REGISTRY:
        return HIGHER_IS_BETTER_REGISTRY[metric_name]
    eval_logger.warning(
        f"higher_is_better not specified for metric '{metric_name}'!"
    )
174
+
175
+
176
def register_filter(name):
    """Decorator: register a filter class under ``name``.

    Re-registration is allowed (logged at info level) and overwrites the
    previous entry.
    """

    def _bind(cls):
        if name in FILTER_REGISTRY:
            eval_logger.info(
                f"Registering filter `{name}` that is already in Registry {FILTER_REGISTRY}"
            )
        FILTER_REGISTRY[name] = cls
        return cls

    return _bind
186
+
187
+
188
def get_filter(filter_name: Union[str, Callable]) -> Callable:
    """Resolve a filter by registered name; pass callables through unchanged.

    Raises the underlying KeyError (after logging) for unknown string names.
    """
    try:
        return FILTER_REGISTRY[filter_name]
    except KeyError:
        if not callable(filter_name):
            eval_logger.warning(f"filter `{filter_name}` is not registered!")
            raise
        return filter_name
Prism/LLaDA/LLaDA_Baseline/dllm_eval/api/samplers.py ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import warnings
3
+ from functools import partial
4
+ from typing import TYPE_CHECKING, Iterable, Optional, Union
5
+
6
+ import datasets
7
+
8
+
9
+ if TYPE_CHECKING:
10
+ from random import Random
11
+
12
+ from dllm_eval.api.task import ConfigurableTask, Task
13
+
14
+ eval_logger = logging.getLogger("lm-eval")
15
+
16
+
17
class ContextSampler:
    """Draws few-shot example documents and formats them into a prompt context."""

    def __init__(
        self,
        docs: list[dict],
        task: Union["Task", "ConfigurableTask"],
        fewshot_indices: Optional[Iterable] = None,
        rnd: Optional["Random"] = None,
    ) -> None:
        self.rnd = rnd
        if not self.rnd:
            raise ValueError(
                "A `random.Random` generator argument must be provided to `rnd` of FewShotSampler!"
            )

        self.task = task
        self.config = task._config

        self.target_delimiter = self.config.target_delimiter
        self.fewshot_delimiter = self.config.fewshot_delimiter

        # Fixed: the override-resolution logic was triplicated; each of the
        # task's doc_to_* callables may be overridden via fewshot_config.
        self.doc_to_text = self._resolve_override("doc_to_text")
        self.doc_to_target = self._resolve_override("doc_to_target")
        self.doc_to_choice = self._resolve_override("doc_to_choice")

        self.docs = docs  # HF dataset split, provided by task._fewshot_docs()
        if fewshot_indices:  # subset few-shot docs from
            if not isinstance(self.docs, datasets.Dataset):
                raise ValueError(
                    "Got `fewshot_indices` but fewshot_docs are not a HF dataset. Don't use both `fewshot_indices` and a user-defined few-shot sample list simultaneously"
                )
            self.docs = self.docs.select(fewshot_indices)

    def _resolve_override(self, attr: str):
        """Return the task's ``attr`` callable, partially applied with the
        fewshot_config override of the same name when one is present."""
        base = getattr(self.task, attr)
        fewshot_cfg = self.config.fewshot_config
        override = fewshot_cfg.get(attr, None) if fewshot_cfg is not None else None
        if override is not None:
            return partial(base, **{attr: override})
        return base

    def get_context(self, doc: dict, num_fewshot: int, gen_prefix: Optional[str] = None):
        """Build the few-shot prompt string preceding ``doc``.

        :param doc: the document under evaluation (excluded from the shots).
        :param num_fewshot: number of examples to include.
        :param gen_prefix: optional response prefix inserted before each target.
        """
        prefix = gen_prefix + " " if gen_prefix else ""
        # draw an extra fewshot sample if using same split as evaluating on
        n_samples = (
            num_fewshot + 1
            if self.config.fewshot_split == self.config.test_split
            else num_fewshot
        )

        # draw `n_samples` docs from fewshot_docs
        fewshotex = self.sample(n_samples)

        # get rid of the doc that's the one we're evaluating, if it's in the fewshot
        # TODO: should we just stop people from using fewshot from same split as evaluating?
        selected_docs = [x for x in fewshotex if x != doc][:num_fewshot]

        labeled_examples = ""
        # Fixed: loop variable renamed so it no longer shadows the `doc` parameter.
        for shot in selected_docs:
            shot_content = self.doc_to_text(shot)
            shot_target = self.doc_to_target(shot)

            if self.config.doc_to_choice is None or isinstance(shot_content, str):
                labeled_examples += shot_content
            else:
                labeled_examples += self.doc_to_choice(shot)[shot_content]

            if shot_target != "":
                if self.target_delimiter.isspace() and str(shot_target)[0].isspace():
                    # TODO: add logger warn once here.
                    warnings.warn(
                        "Both target_delimiter and target start with a space. This may cause issues.",
                        Warning,
                        stacklevel=2,
                    )
                labeled_examples += self.target_delimiter
                labeled_examples += prefix
                if isinstance(shot_target, list):
                    labeled_examples += str(shot_target[0])
                elif self.config.doc_to_choice is None or isinstance(shot_target, str):
                    labeled_examples += shot_target
                else:
                    labeled_examples += str(self.doc_to_choice(shot)[shot_target])
            # NOTE(review): delimiter appended per-example regardless of an empty
            # target — confirm against the original file's indentation.
            labeled_examples += self.fewshot_delimiter

        return labeled_examples

    def get_chat_context(
        self,
        doc: dict,
        num_fewshot: int,
        fewshot_as_multiturn: bool = False,
        gen_prefix: Optional[str] = None,
    ):
        """Build the few-shot context as chat messages.

        With ``fewshot_as_multiturn`` each shot becomes a user/assistant turn
        pair; otherwise all shots are flattened into a single user message via
        ``get_context``.
        """
        # TODO: Do we need any other delimiter
        prefix = gen_prefix + " " if gen_prefix else ""
        chat_history = []
        # draw an extra fewshot sample if using same split as evaluating on
        n_samples = (
            num_fewshot + 1
            if self.config.fewshot_split == self.config.test_split
            else num_fewshot
        )
        # draw `n_samples` docs from fewshot_docs (done before branching so the
        # RNG state advances identically in both modes)
        fewshotex = self.sample(n_samples)

        # get rid of the doc that's the one we're evaluating, if it's in the fewshot
        # TODO: should we just stop people from using fewshot from same split as evaluating?
        selected_docs = [x for x in fewshotex if x != doc][:num_fewshot]

        if not fewshot_as_multiturn:
            # get fewshot context as one user turn
            chat_history.append(
                {
                    "role": "user",
                    "content": self.get_context(doc, num_fewshot, gen_prefix=gen_prefix),
                }
            )
            return chat_history

        for shot in selected_docs:
            shot_content = self.doc_to_text(shot)
            shot_target = self.doc_to_target(shot)

            if self.config.doc_to_choice is None or isinstance(shot_content, str):
                user_content = shot_content
            else:
                user_content = self.doc_to_choice(shot)[shot_content]
            chat_history.append({"role": "user", "content": user_content})

            if isinstance(shot_target, list):
                assistant_content = prefix + str(shot_target[0])
            elif self.config.doc_to_choice is None or isinstance(shot_target, str):
                assistant_content = prefix + shot_target
            else:
                assistant_content = prefix + str(self.doc_to_choice(shot)[shot_target])
            chat_history.append({"role": "assistant", "content": assistant_content})

        return chat_history

    def sample(self, n: int):
        """
        Draw `n` samples from our fewshot docs. This method should be overridden by subclasses.
        """
        return self.rnd.sample(self.docs, n)
190
+
191
+
192
class FirstNSampler(ContextSampler):
    def sample(self, n: int) -> list:
        """
        Draw the first `n` samples in order from the specified split.
        Used for tasks with "canonical" ordered fewshot examples, such as MMLU and CMMLU.
        """
        # Fixed: return annotation was `-> None` although a document slice is returned.
        assert n <= len(self.docs), (
            f"Error: number of fewshot samples requested exceeds the {len(self.docs)} that are available."
        )
        return self.docs[:n]
202
+
203
+
204
class BalancedSampler(ContextSampler):
    def sample(self, n: int) -> None:
        """
        TODO: this should return approximately class-balanced samples from our fewshot examples.
        TODO: what order should they be in? maybe random?
        """
        # Unimplemented stub: currently returns None, so selecting this sampler
        # would break few-shot context construction downstream.
        pass
212
+
213
+
214
class ManualSampler(ContextSampler):
    def sample(self, n: int) -> None:
        """Unimplemented placeholder sampler; currently returns None."""
        pass
218
+
219
+
220
# Named few-shot sampling strategies selectable from task configs.
SAMPLER_REGISTRY = {
    "default": ContextSampler,
    "first_n": FirstNSampler,
}


def get_sampler(name: str):
    """Look up a registered few-shot sampler class by name."""
    if name in SAMPLER_REGISTRY:
        return SAMPLER_REGISTRY[name]
    raise ValueError(
        f"Attempted to use contextsampler '{name}', but no sampling strategy for this name found! Supported model names: {', '.join(SAMPLER_REGISTRY.keys())}"
    )
Prism/LLaDA/LLaDA_Baseline/dllm_eval/api/task.py ADDED
@@ -0,0 +1,1881 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import abc
2
+ import ast
3
+ import logging
4
+ import random
5
+ import re
6
+ from collections.abc import Callable
7
+ from copy import deepcopy
8
+ from dataclasses import asdict, dataclass
9
+ from inspect import getsource
10
+ from typing import (
11
+ Any,
12
+ Dict,
13
+ Iterable,
14
+ Iterator,
15
+ List,
16
+ Literal,
17
+ Mapping,
18
+ Optional,
19
+ Tuple,
20
+ Union,
21
+ )
22
+
23
+ import datasets
24
+ import numpy as np
25
+ from tqdm import tqdm
26
+
27
+ from dllm_eval import utils
28
+ from dllm_eval.api import samplers
29
+ from dllm_eval.api.instance import Instance, OutputType
30
+ from dllm_eval.api.metrics import bits_per_byte, mean, weighted_perplexity
31
+ from dllm_eval.api.registry import (
32
+ AGGREGATION_REGISTRY,
33
+ DEFAULT_METRIC_REGISTRY,
34
+ get_aggregation,
35
+ get_metric,
36
+ get_metric_aggregation,
37
+ is_higher_better,
38
+ )
39
+ from dllm_eval.caching.cache import load_from_cache, save_to_cache
40
+ from dllm_eval.filters import build_filter_ensemble
41
+ from dllm_eval.prompts import get_prompt
42
+
43
+
44
# The set of request/output types a task config may declare; validated against
# `TaskConfig.output_type` when a ConfigurableTask is built.
ALL_OUTPUT_TYPES = [
    "loglikelihood",
    "multiple_choice",
    "loglikelihood_rolling",
    "generate_until",
]

# Module-level logger used throughout this file for task-config warnings/info.
eval_logger = logging.getLogger(__name__)
52
+
53
+
54
@dataclass
class TaskConfig(dict):
    """Declarative configuration for one evaluation task: dataset source and
    splits, prompt formatting, fewshot behavior, generation settings, and
    metric/scoring options.

    Subclasses ``dict`` so it can be passed where a mapping is expected, but
    item access is routed through attributes via ``__getitem__``/``__setitem__``
    below — the underlying dict storage itself is not populated by field
    assignment (NOTE(review): so ``dict.get`` on an instance will not see
    dataclass fields; verify callers that use ``.get``).
    """

    # task naming/registry
    task: Optional[str] = None
    task_alias: Optional[str] = None
    tag: Optional[Union[str, list]] = None
    # HF dataset options.
    # which dataset to use,
    # and what splits for what purpose
    custom_dataset: Optional[Callable] = None
    dataset_path: Optional[str] = None
    dataset_name: Optional[str] = None
    dataset_kwargs: Optional[dict] = None
    training_split: Optional[str] = None
    validation_split: Optional[str] = None
    test_split: Optional[str] = None
    fewshot_split: Optional[str] = (
        None  # TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaluating (?)
    )
    # formatting / prompting options.
    # see docs/advanced_task_guide.md for more info
    process_docs: Optional[Callable] = None
    doc_to_text: Optional[Union[Callable, str]] = None
    doc_to_target: Optional[Union[Callable, str]] = None
    doc_to_image: Union[Callable, str] = None
    doc_to_audio: Union[Callable, str] = None
    unsafe_code: bool = False
    doc_to_choice: Optional[Union[Callable, str, dict, list]] = None
    process_results: Optional[Union[Callable, str]] = None
    use_prompt: Optional[str] = None
    description: str = ""
    target_delimiter: str = " "
    fewshot_delimiter: str = "\n\n"
    fewshot_config: Optional[dict] = None
    # runtime configuration options
    num_fewshot: Optional[int] = None
    # scoring options
    metric_list: Optional[list] = None
    output_type: OutputType = "generate_until"
    generation_kwargs: Optional[dict] = None
    repeats: int = 1
    filter_list: Optional[Union[str, list]] = None
    should_decontaminate: bool = False
    doc_to_decontamination_query: Optional[str] = None
    gen_prefix: Optional[str] = None
    metadata: Optional[dict] = (
        None  # by default, not used in the code. allows for users to pass arbitrary info to tasks
    )

    def __post_init__(self) -> None:
        # Normalize/validate generation settings after dataclass construction.
        if self.generation_kwargs is not None:
            if self.output_type != "generate_until":
                # generation_kwargs only apply to generative tasks; warn but keep them.
                eval_logger.warning(
                    f"[{self.task}] passed `generation_kwargs`, but not using `output_type: generate_until`!"
                )

            # YAML may parse temperature as a string/int; coerce to float.
            if "temperature" in self.generation_kwargs:
                self.generation_kwargs["temperature"] = float(
                    self.generation_kwargs["temperature"]
                )

            # Without explicit stop sequences, stop at the fewshot delimiter.
            if "until" not in self.generation_kwargs:
                eval_logger.warning(
                    f"{self.task}: No `until` specified in `generation_kwargs`! Defaulting to the fewshot_delimiter={repr(self.fewshot_delimiter)}"
                )
                self.generation_kwargs["until"] = [self.fewshot_delimiter]
        else:
            if self.output_type == "generate_until":
                # ensure that we greedily generate in absence of explicit arguments otherwise
                self.generation_kwargs = {
                    "until": (
                        None
                        if self.fewshot_delimiter is None
                        else [self.fewshot_delimiter]
                    ),
                    "do_sample": False,
                    "temperature": 0,
                }
                eval_logger.warning(
                    f"{self.task}: No `generation_kwargs` specified in task config, defaulting to {self.generation_kwargs}"
                )

    def __getitem__(self, item):
        # Mapping-style access reads dataclass attributes, not the dict storage.
        return getattr(self, item)

    def __setitem__(self, item, value):
        # Mapping-style assignment writes dataclass attributes, not the dict storage.
        return setattr(self, item, value)

    def to_dict(self, keep_callable: bool = False) -> dict:
        """dumps the current config as a dictionary object, as a printable format.
        null fields will not be printed.
        Used for dumping results alongside full task configuration

        :param keep_callable: if True, callables are kept as-is instead of
            being serialized to their source text.
        :return: dict
            A printable dictionary version of the TaskConfig object.

        # TODO: should any default value in the TaskConfig not be printed?
        """
        cfg_dict = asdict(self)
        # remove values that are `None`
        for k, v in list(cfg_dict.items()):
            if v is None:
                cfg_dict.pop(k)
            elif k == "metric_list":
                # metric_list entries may carry callables (custom metrics/aggregations);
                # serialize each so the dump stays printable.
                for metric_dict in v:
                    for metric_key, metric_value in metric_dict.items():
                        if callable(metric_value):
                            metric_dict[metric_key] = self.serialize_function(
                                metric_value, keep_callable=keep_callable
                            )
                cfg_dict[k] = v
            elif callable(v):
                cfg_dict[k] = self.serialize_function(v, keep_callable=keep_callable)
        return cfg_dict

    def serialize_function(
        self, value: Union[Callable, str], keep_callable=False
    ) -> Union[Callable, str]:
        """Serializes a given function or string.

        If 'keep_callable' is True, the original callable is returned.
        Otherwise, attempts to return the source code of the callable using 'getsource'.
        Falls back to `str(value)` when source is unavailable (builtins, lambdas
        defined interactively, etc.).
        """
        if keep_callable:
            return value
        else:
            try:
                return getsource(value)
            except (TypeError, OSError):
                return str(value)
184
+
185
+
186
class Task(abc.ABC):
    """A task represents an entire benchmark including its dataset, problems,
    answers, and evaluation methods. See BoolQ for a simple example implementation

    A `doc` can be any python object which represents one instance of evaluation.
    This is usually a dictionary e.g.
        {"question": ..., "answer": ...} or
        {"question": ..., question, answer)
    """

    # Task version, bumped when a task's prompting/scoring changes.
    VERSION: Optional[Union[int, str]] = None

    # The name of the `Task` benchmark as denoted in the HuggingFace datasets Hub
    # or a path to a custom `datasets` loading script.
    DATASET_PATH: Optional[str] = None

    # The name of a subset within `DATASET_PATH`.
    DATASET_NAME: Optional[str] = None

    # One of ALL_OUTPUT_TYPES; determines the request type built per doc.
    OUTPUT_TYPE: Optional[OutputType] = None

    def __init__(
        self,
        data_dir: Optional[str] = None,
        cache_dir: Optional[str] = None,
        download_mode: Optional[datasets.DownloadMode] = None,
        config: Optional[Mapping] = None,  # Union[dict, TaskConfig]
    ) -> None:
        """
        :param data_dir: str
            Stores the path to a local folder containing the `Task`'s data files.
            Use this to specify the path to manually downloaded data (usually when
            the dataset is not publicly accessible).
        :param cache_dir: str
            The directory to read/write the `Task` dataset. This follows the
            HuggingFace `datasets` API with the default cache directory located at:
                `~/.cache/huggingface/datasets`
            NOTE: You can change the cache location globally for a given process
            to another directory:
                `export HF_DATASETS_CACHE="/path/to/another/directory"`
        :param download_mode: datasets.DownloadMode
            How to treat pre-existing `Task` downloads and data.
            - `datasets.DownloadMode.REUSE_DATASET_IF_EXISTS`
                Reuse download and reuse dataset.
            - `datasets.DownloadMode.REUSE_CACHE_IF_EXISTS`
                Reuse download with fresh dataset.
            - `datasets.DownloadMode.FORCE_REDOWNLOAD`
                Fresh download and fresh dataset.
        """
        # Eagerly fetch the dataset; subclasses may override `download`.
        self.download(data_dir, cache_dir, download_mode)
        self._training_docs: Optional[list] = None
        self._fewshot_docs: Optional[list] = None
        self._instances: Optional[List[Instance]] = None

        self._config: TaskConfig = TaskConfig({**config}) if config else TaskConfig()

        # Default filter pipeline: take the first model response per request.
        self._filters = [build_filter_ensemble("none", [["take_first", None]])]
        self.fewshot_rnd: Optional[random.Random] = (
            None  # purposely induce errors in case of improper usage
        )

    def download(
        self,
        data_dir: Optional[str] = None,
        cache_dir: Optional[str] = None,
        download_mode=None,
    ) -> None:
        """Downloads and returns the task dataset.
        Override this method to download the dataset from a custom API.

        :param data_dir: str
            Stores the path to a local folder containing the `Task`'s data files.
            Use this to specify the path to manually downloaded data (usually when
            the dataset is not publicly accessible).
        :param cache_dir: str
            The directory to read/write the `Task` dataset. This follows the
            HuggingFace `datasets` API with the default cache directory located at:
                `~/.cache/huggingface/datasets`
            NOTE: You can change the cache location globally for a given process
            by setting the shell environment variable, `HF_DATASETS_CACHE`,
            to another directory:
                `export HF_DATASETS_CACHE="/path/to/another/directory"`
        :param download_mode: datasets.DownloadMode
            How to treat pre-existing `Task` downloads and data.
            - `datasets.DownloadMode.REUSE_DATASET_IF_EXISTS`
                Reuse download and reuse dataset.
            - `datasets.DownloadMode.REUSE_CACHE_IF_EXISTS`
                Reuse download with fresh dataset.
            - `datasets.DownloadMode.FORCE_REDOWNLOAD`
                Fresh download and fresh dataset.
        """
        self.dataset = datasets.load_dataset(
            path=self.DATASET_PATH,
            name=self.DATASET_NAME,
            data_dir=data_dir,
            cache_dir=cache_dir,
            download_mode=download_mode,
        )

    @property
    def config(self) -> TaskConfig:
        """Returns the TaskConfig associated with this class."""
        return self._config

    @abc.abstractmethod
    def has_training_docs(self):
        """Whether the task has a training set"""
        pass

    @abc.abstractmethod
    def has_validation_docs(self):
        """Whether the task has a validation set"""
        pass

    @abc.abstractmethod
    def has_test_docs(self):
        """Whether the task has a test set"""
        pass

    def training_docs(self) -> Iterable:
        """
        :return: Iterable[obj]
            A iterable of any object, that doc_to_text can handle
        """
        return []

    def validation_docs(self) -> Iterable:
        """
        :return: Iterable[obj]
            A iterable of any object, that doc_to_text can handle
        """
        return []

    def test_docs(self) -> Iterable:
        """
        :return: Iterable[obj]
            A iterable of any object, that doc_to_text can handle
        """
        return []

    def fewshot_docs(self) -> Iterable:
        """
        :return: Iterable[obj]
            A iterable of any object, that doc_to_text can handle

        Preference order for the fewshot pool: train, then validation, then
        (with a warning) test.
        """
        if self.has_training_docs():
            return self.training_docs()
        elif self.has_validation_docs():
            return self.validation_docs()
        else:
            # NOTE(review): `self.config.get(...)` calls dict.get on TaskConfig,
            # whose dict storage is never populated by attribute assignment —
            # this likely always yields 0; verify against TaskConfig semantics.
            if self.config.get("num_fewshot", 0) > 0:
                eval_logger.warning(
                    f"[Task: {self.config.task}] has_training_docs and has_validation_docs are False"
                    ", using test_docs as fewshot_docs but this is not recommended."
                )
            return self.test_docs()

    def _process_doc(self, doc: dict) -> dict:
        """
        Override this to process (detokenize, strip, replace, etc.) individual
        documents. This can be used in a map over documents of a data split.
        E.g. `map(self._process_doc, self.dataset["validation"])`

        :return: dict
            The processed version of the specified `doc`.
        """
        return doc

    @property
    def instances(self) -> List[Instance]:
        """After calling `task.build_all_requests()`, tasks
        maintain a list of the dataset instances which will be evaluated.
        """
        return self._instances

    def fewshot_examples(self, k, rnd):
        # Lazily materialize and cache the training docs, then sample k of them.
        if self._training_docs is None:
            self._training_docs = list(self.training_docs())

        return rnd.sample(self._training_docs, k)

    def doc_to_decontamination_query(self, doc):
        raise NotImplementedError(
            "Override doc_to_decontamination_query with document specific decontamination query."
        )

    @abc.abstractmethod
    def doc_to_text(self, doc):
        pass

    @abc.abstractmethod
    def doc_to_target(self, doc):
        pass

    # not an abstractmethod because not every language-only task has to implement this
    def doc_to_image(self, doc):
        raise NotImplementedError

    def doc_to_audio(self, doc):
        raise NotImplementedError

    def doc_to_prefix(self, doc):
        # Optional generation prefix prepended to model output; empty by default.
        return ""

    def build_all_requests(
        self,
        *,
        limit: Union[int, None] = None,
        samples: Optional[List[int]] = None,
        rank: int = 0,
        world_size: int = 1,
        cache_requests: bool = False,
        rewrite_requests_cache: bool = False,
        system_instruction: Optional[str] = None,
        apply_chat_template: bool = False,
        fewshot_as_multiturn: bool = False,
        chat_template: Optional[Callable] = None,
        tokenizer_name: str = "",
    ) -> None:
        """Build a set of Instances for a task, and store them in task.instances"""

        # used with caching
        og_limit = limit

        # Cache key encodes everything that changes the built contexts, so a
        # stale cache is never reused across differing prompt settings.
        cache_key = f"requests-{self._config.task}-{self.config.num_fewshot}shot-rank{rank}-world_size{world_size}"
        cache_key += "-chat_template" if apply_chat_template else ""
        cache_key += "-fewshot_as_multiturn" if fewshot_as_multiturn else ""
        cache_key += (
            f"-system_prompt_hash{utils.hash_string(system_instruction)}"
            if system_instruction is not None
            else ""
        )
        cache_key += f"-tokenizer{tokenizer_name}"

        cached_instances = load_from_cache(file_name=cache_key, cache=cache_requests)

        if cache_requests and cached_instances and not rewrite_requests_cache:
            cached_instances = cached_instances[:limit]

            flattened_instances = [
                instance
                for instance_group in cached_instances
                for instance in instance_group
            ]

            self._instances = flattened_instances
            return

        eval_logger.info(f"Building contexts for {self.config.task} on rank {rank}...")

        instances = []

        # process all documents when caching is specified for simplicity
        if (
            cache_requests
            and (not cached_instances or rewrite_requests_cache)
            and limit is not None
        ):
            limit = None

        doc_id_docs = list(
            self.doc_iterator(
                rank=rank, limit=limit, samples=samples, world_size=world_size
            )
        )

        num_docs = len(doc_id_docs)

        for doc_id, doc in tqdm(
            doc_id_docs,
            total=num_docs,
        ):
            # sample fewshot context #TODO: need to offset doc_id by rank now!
            # NOTE(review): passes chat-template kwargs that the base
            # `fewshot_context` below does not accept — presumably relies on a
            # subclass override; confirm for direct Task subclasses.
            fewshot_ctx = self.fewshot_context(
                doc,
                num_fewshot=0
                if self.config.num_fewshot is None
                else self.config.num_fewshot,
                system_instruction=system_instruction,
                apply_chat_template=apply_chat_template,
                fewshot_as_multiturn=fewshot_as_multiturn,
                chat_template=chat_template,
                gen_prefix=self.doc_to_prefix(doc),
            )

            # TODO: we should override self.config.repeats if doing greedy gen so users don't waste time+compute
            inst = self.construct_requests(
                doc=doc,
                ctx=fewshot_ctx,
                metadata=(self.config["task"], doc_id, self.config.repeats),
                apply_chat_template=apply_chat_template,
                chat_template=chat_template,
            )

            if not isinstance(inst, list):
                inst = [inst]

            instances.append(inst)

        # now flatten, this is to allow slicing to work with pickles

        sliced_instances = instances[:og_limit]

        flattened_instances = [
            instance
            for instance_group in sliced_instances
            for instance in instance_group
        ]

        self._instances = flattened_instances

        if len(self._instances) == 0:
            raise ValueError("task.build_requests() did not find any docs!")

        if cache_requests and (not cached_instances or rewrite_requests_cache):
            save_to_cache(file_name=cache_key, obj=instances)

    @abc.abstractmethod
    def construct_requests(self, doc, ctx, **kwargs):
        """Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
        :param doc_idx: int
            The index of a document within `self.test_docs()` or `self.validation_docs()`,
            whichever is the main split used.
        :param repeats: int
            TODO: update this docstring
            The number of times each instance in a dataset is inferred on. Defaults to 1,
            can be increased for techniques like majority voting.
        """
        pass

    @abc.abstractmethod
    def process_results(self, doc, results):
        """Take a single document and the LM results and evaluates, returning a
        dict where keys are the names of submetrics and values are the values of
        the metric for that one document

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param results:
            The results of the requests created in construct_requests.
        """
        pass

    @abc.abstractmethod
    def aggregation(self):
        """
        :returns: {str: [metric_score] -> float}
            A dictionary where keys are the names of submetrics and values are
            functions that aggregate a list of metric scores
        """
        pass

    @abc.abstractmethod
    def higher_is_better(self):
        """
        :returns: {str: bool}
            A dictionary where keys are the names of submetrics and values are
            whether a higher value of the submetric is better
        """
        pass

    def get_config(self, key: str) -> Any:
        # Safe config attribute lookup; returns None for unknown keys.
        return getattr(self._config, key, None)

    @classmethod
    def count_bytes(cls, doc):
        """Used for byte-level perplexity metrics in rolling loglikelihood"""
        return len(doc.encode("utf-8"))

    @classmethod
    def count_words(cls, doc):
        """Downstream loglikelihood_rolling perplexity tasks with custom word boundaries should override this!"""
        return len(re.split(r"\s+", doc))

    @utils.positional_deprecated
    def fewshot_context(self, doc, num_fewshot, rnd=None, description=None, **kwargs):
        """Returns a fewshot context string that is made up of a prepended description
        (if provided), the `num_fewshot` number of examples, and an appended prompt example.

        :param doc: str
            The document as returned from training_docs, validation_docs, or test_docs.
        :param num_fewshot: int
            The number of fewshot examples to provide in the returned context string.
        :param rnd: random.Random
            The pseudo-random number generator used to randomly sample examples.
            WARNING: This is currently a required arg although it's optionalized with a default `None`.
        :param description: str
            The task's description that will be prepended to the fewshot examples.
        :returns: str
            The fewshot context.
        """
        if rnd is None:
            if self.fewshot_rnd is not None:
                rnd = self.fewshot_rnd
            else:
                raise ValueError(
                    "A `random.Random` generator argument must be provided to `rnd`"
                )

        description = description if description else ""

        if num_fewshot == 0:
            labeled_examples = ""
        else:
            # for sets with no training docs, draw from other set *but ensure no overlap with current doc*
            if self.has_training_docs():
                fewshotex = self.fewshot_examples(k=num_fewshot, rnd=rnd)
            else:
                if self._fewshot_docs is None:
                    self._fewshot_docs = list(
                        self.validation_docs()
                        if self.has_validation_docs()
                        else self.test_docs()
                    )

                # Oversample by one so filtering out `doc` still leaves num_fewshot.
                fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1)

                # get rid of the doc that's the one we're evaluating, if it's in the fewshot
                fewshotex = [x for x in fewshotex if x != doc][:num_fewshot]

            labeled_examples = (
                "\n\n".join(
                    [
                        self.doc_to_text(doc) + self.doc_to_target(doc)
                        for doc in fewshotex
                    ]
                )
                + "\n\n"
            )

        example = self.doc_to_text(doc)
        return description + labeled_examples + example

    def apply_filters(self) -> Optional[List[Instance]]:
        """Iterates over FilterEnsembles and applies them to instances"""
        # NOTE(review): instances are only *returned* on the no-filter path;
        # when filters exist they mutate self._instances in place and this
        # method returns None.
        if hasattr(self, "_filters"):
            for f in self._filters:
                f.apply(self._instances)
        else:
            eval_logger.warning("No filter defined, passing through instances")
            return self._instances

    def dump_config(self) -> dict:
        """Returns the config as a dictionary."""
        # TODO: this should only return the overrides applied to a non-YAML task's configuration.
        # (num_fewshot)
        return self.config.to_dict()

    def set_config(self, key: str, value: Any, update: bool = False) -> None:
        """Set or update the configuration for a given key."""
        if key is None:
            raise ValueError("Key must be provided.")

        if update:
            # Merge into an existing dict-valued config entry rather than replace.
            current_value = getattr(self._config, key, {})
            if not isinstance(current_value, dict):
                raise TypeError(
                    f"Expected a dict for key '{key}', got {type(current_value).__name__} instead."
                )
            current_value.update(value)
        else:
            setattr(self._config, key, value)

    def override_metric(self, metric_name: str) -> None:
        """
        Override the default metrics used for evaluation with custom metrics.

        Parameters:
        - metric_name (str): The name of the custom metric to override. Should be registered in api.metrics.
        """
        # Reset all metric bookkeeping, then register only the given metric.
        (
            self._metric_fn_list,
            self._aggregation_list,
            self._metric_fn_kwargs,
            self._higher_is_better,
        ) = ({}, {}, {}, {})
        self._metric_fn_list[metric_name] = get_metric(metric_name)
        self._aggregation_list[metric_name] = get_metric_aggregation(metric_name)
        self._higher_is_better[metric_name] = is_higher_better(metric_name)
        self._metric_fn_kwargs[metric_name] = {}
        if not isinstance(self, ConfigurableTask):
            self.process_results = lambda x, y: {metric_name: get_metric(metric_name)}
            self.aggregation = lambda: {
                metric_name: get_metric_aggregation(metric_name)
            }
        setattr(self._config, "metric_list", [{"metric": metric_name}])
        setattr(self._config, "process_results", None)

    def set_fewshot_seed(self, seed: Optional[int] = None) -> None:
        # Seed the fewshot RNG (and the sampler's, if a sampler exists).
        self.fewshot_rnd = random.Random(seed)
        if hasattr(self, "sampler"):
            self.sampler.rnd = self.fewshot_rnd

    @property
    def eval_docs(self) -> Union[datasets.Dataset, List[dict]]:
        # The split actually evaluated on: test preferred, else validation.
        if self.has_test_docs():
            return self.test_docs()
        elif self.has_validation_docs():
            return self.validation_docs()
        else:
            raise ValueError(
                f"Task dataset (path={self.DATASET_PATH}, name={self.DATASET_NAME}) must have valid or test docs!"
            )

    def doc_iterator(
        self,
        *,
        rank: int = 0,
        limit: Union[int, None] = None,
        world_size: int = 1,
        samples: Optional[List[int]] = None,
    ) -> Iterator[Tuple[int, Any]]:
        # Yields (doc_id, doc) pairs for this rank's shard; `samples` selects
        # explicit doc indices and takes precedence over `limit`.
        if samples:
            n = len(self.eval_docs)
            assert all([e < n for e in samples]), (
                f"Elements of --samples should be in the interval [0,k-1] where k is the number of total examples. In this case, k={n}."
            )
            eval_logger.info(
                f"{self.config.task}: Evaluating on {len(samples)} examples"
            )
            doc_iterator = utils.create_iterator(
                enumerate(x for i, x in enumerate(self.eval_docs) if i in samples),
                rank=int(rank),
                limit=None,  # limit does not matter here since we are selecting samples directly
                world_size=int(world_size),
            )
        else:
            limit = int(limit) if limit else None
            doc_iterator = utils.create_iterator(
                enumerate(self.eval_docs),
                rank=int(rank),
                limit=limit,
                world_size=int(world_size),
            )
        return doc_iterator
729
+
730
+
731
+ class ConfigurableTask(Task):
732
+ VERSION = "Yaml"
733
+ OUTPUT_TYPE = None
734
+ CONFIG = None
735
+
736
+ def __init__(
737
+ self,
738
+ data_dir=None,
739
+ cache_dir=None,
740
+ download_mode=None,
741
+ config: Optional[dict] = None,
742
+ ) -> None: # TODO no super() call here
743
+ # Get pre-configured attributes
744
+ self._config = self.CONFIG
745
+
746
+ # Use new configurations if there was no preconfiguration
747
+ if self.config is None:
748
+ self._config = TaskConfig(**config)
749
+ # Overwrite configs
750
+ else:
751
+ if config is not None:
752
+ self._config.__dict__.update(config)
753
+
754
+ if self.config is None:
755
+ raise ValueError(
756
+ "Must pass a config to ConfigurableTask, either in cls.CONFIG or `config` kwarg"
757
+ )
758
+
759
+ if isinstance(self.config.metadata, dict):
760
+ if "version" in self.config.metadata:
761
+ self.VERSION = self.config.metadata["version"]
762
+
763
+ if self.config.output_type is not None:
764
+ if self.config.output_type not in ALL_OUTPUT_TYPES:
765
+ raise ValueError(
766
+ f"Got invalid output_type '{self.config.output_type}', must be in '{','.join(ALL_OUTPUT_TYPES)}'"
767
+ )
768
+ self.OUTPUT_TYPE = self.config.output_type
769
+
770
+ if self.config.doc_to_image is not None:
771
+ # mark the task as requiring multimodality.
772
+ self.MULTIMODAL = True
773
+
774
+ if self.config.doc_to_audio:
775
+ # mark the task as requiring multimodality.
776
+ self.MULTIMODAL = True
777
+
778
+ if self.config.unsafe_code is not False:
779
+ self.UNSAFE_CODE = True
780
+
781
+ if self.config.dataset_path is not None:
782
+ self.DATASET_PATH = self.config.dataset_path
783
+
784
+ if self.config.dataset_name is not None:
785
+ self.DATASET_NAME = self.config.dataset_name
786
+
787
+ self._metric_fn_list = {}
788
+ self._metric_fn_kwargs = {}
789
+ self._aggregation_list = {}
790
+ self._higher_is_better = {}
791
+
792
+ if self.config.metric_list is None:
793
+ # TODO: handle this in TaskConfig.__post_init__ ?
794
+ _metric_list = DEFAULT_METRIC_REGISTRY[self.config.output_type]
795
+
796
+ for metric_name in _metric_list:
797
+ self._metric_fn_list[metric_name] = get_metric(metric_name)
798
+ self._metric_fn_kwargs[metric_name] = {}
799
+ self._aggregation_list[metric_name] = get_metric_aggregation(
800
+ metric_name
801
+ )
802
+ self._higher_is_better[metric_name] = is_higher_better(metric_name)
803
+ else:
804
+ for metric_config in self.config.metric_list:
805
+ if "metric" not in metric_config:
806
+ raise ValueError(
807
+ "'metric' key not provided for an entry in 'metric_list', must be specified!"
808
+ )
809
+ metric_name = metric_config["metric"]
810
+ kwargs = {
811
+ key: metric_config[key]
812
+ for key in metric_config
813
+ if key
814
+ not in ["metric", "aggregation", "higher_is_better", "hf_evaluate"]
815
+ }
816
+ hf_evaluate_metric = (
817
+ "hf_evaluate" in metric_config
818
+ and metric_config["hf_evaluate"] is True
819
+ )
820
+
821
+ if self.config.process_results is not None:
822
+ self._metric_fn_list[metric_name] = None
823
+ self._metric_fn_kwargs[metric_name] = {}
824
+ elif callable(metric_name):
825
+ metric_fn = metric_name.__call__
826
+ metric_name = metric_name.__name__
827
+ self._metric_fn_list[metric_name] = metric_fn
828
+ self._metric_fn_kwargs[metric_name] = kwargs
829
+ else:
830
+ self._metric_fn_list[metric_name] = get_metric(
831
+ metric_name, hf_evaluate_metric
832
+ )
833
+ self._metric_fn_kwargs[metric_name] = kwargs
834
+
835
+ if "aggregation" in metric_config:
836
+ agg_name = metric_config["aggregation"]
837
+ if isinstance(agg_name, str):
838
+ self._aggregation_list[metric_name] = get_aggregation(agg_name)
839
+ elif callable(agg_name): # noqa: E721
840
+ self._aggregation_list[metric_name] = metric_config[
841
+ "aggregation"
842
+ ]
843
+ else:
844
+ INV_AGG_REGISTRY = {v: k for k, v in AGGREGATION_REGISTRY.items()}
845
+ metric_agg = get_metric_aggregation(metric_name)
846
+ eval_logger.warning(
847
+ f"[Task: {self.config.task}] metric {metric_name} is defined, but aggregation is not. "
848
+ f"using default "
849
+ f"aggregation={INV_AGG_REGISTRY[metric_agg]}"
850
+ )
851
+ self._aggregation_list[metric_name] = metric_agg
852
+
853
+ if "higher_is_better" in metric_config:
854
+ self._higher_is_better[metric_name] = metric_config[
855
+ "higher_is_better"
856
+ ]
857
+ else:
858
+ eval_logger.warning(
859
+ f"[Task: {self.config.task}] metric {metric_name} is defined, but higher_is_better is not. "
860
+ f"using default "
861
+ f"higher_is_better={is_higher_better(metric_name)}"
862
+ )
863
+ self._higher_is_better[metric_name] = is_higher_better(metric_name)
864
+
865
+ self.download(self.config.dataset_kwargs)
866
+ self._training_docs = None
867
+ self._fewshot_docs = None
868
+
869
+ if self.config.filter_list is not None:
870
+ self._filters = []
871
+ for filter_config in self.config.filter_list:
872
+ filter_name = filter_config["name"]
873
+ filter_functions = filter_config["filter"]
874
+ components = []
875
+ for function in filter_functions:
876
+ kwargs = {
877
+ key: function[key] for key in function if key != "function"
878
+ }
879
+ components.append([function["function"], kwargs])
880
+ filter_pipeline = build_filter_ensemble(filter_name, components)
881
+ self._filters.append(filter_pipeline)
882
+ else:
883
+ # TODO: handle repeats in a more general way rather than just discarding
884
+ eval_logger.debug(
885
+ "No custom filters defined. Using default 'take_first' filter for handling repeats."
886
+ )
887
+ self._filters = [build_filter_ensemble("none", [["take_first", None]])]
888
+
889
+ if self.config.use_prompt is not None:
890
+ eval_logger.info(f"loading prompt {self.config.use_prompt}")
891
+ self.prompt = get_prompt(
892
+ self.config.use_prompt, self.DATASET_PATH, self.DATASET_NAME
893
+ )
894
+ else:
895
+ self.prompt = None
896
+
897
+ if self.fewshot_docs() is not None:
898
+ self.fewshot_rnd = (
899
+ random.Random()
900
+ ) # setting with no seed, to be overridden at a later time
901
+ config_sampler: Union[str, Callable] = (
902
+ self.config.fewshot_config.get("sampler", "default")
903
+ if self.config.fewshot_config
904
+ else "default"
905
+ )
906
+ if isinstance(config_sampler, str):
907
+ self.sampler = samplers.get_sampler(config_sampler)(
908
+ list(self.fewshot_docs()), self, rnd=self.fewshot_rnd
909
+ )
910
+ elif callable(config_sampler) and issubclass(
911
+ config_sampler, samplers.ContextSampler
912
+ ):
913
+ self.sampler = config_sampler(
914
+ docs=list(self.fewshot_docs()), task=self, rnd=self.fewshot_rnd
915
+ )
916
+ else:
917
+ raise TypeError(
918
+ f"fewshot_config.sampler should be a string or callable of ContextSampler type, "
919
+ f"not {type(config_sampler)}"
920
+ )
921
+
922
+ self.task_docs = self.eval_docs
923
+
924
+ # Test One Doc
925
+ self.features = list(self.task_docs.features.keys())
926
+ self.multiple_input = 0
927
+ self.multiple_target = 0
928
+ test_doc = self.task_docs[0]
929
+ test_text = self.doc_to_text(test_doc)
930
+ test_target = self.doc_to_target(test_doc)
931
+
932
+ if self.config.doc_to_choice is not None:
933
+ test_choice = self.doc_to_choice(test_doc)
934
+ if not isinstance(test_choice, list):
935
+ eval_logger.error("doc_to_choice must return list")
936
+ else:
937
+ num_choice = len(test_choice)
938
+
939
+ if isinstance(test_text, int):
940
+ eval_logger.debug(
941
+ "doc_to_text returned an int. Assuming multiple inputs."
942
+ )
943
+ self.multiple_input = num_choice
944
+ else:
945
+ test_choice = None
946
+
947
+ if isinstance(test_target, list):
948
+ eval_logger.debug(
949
+ "doc_to_target returned a list. Assuming multiple targets."
950
+ )
951
+ self.multiple_target = len(test_target)
952
+ else:
953
+ if (isinstance(test_target, int)) and (test_choice is not None):
954
+ test_target = test_choice[test_target]
955
+ else:
956
+ test_target = str(test_target)
957
+
958
+ if test_choice is not None:
959
+ check_choices = test_choice
960
+ else:
961
+ check_choices = [test_target]
962
+ if self.config.doc_to_choice is not None:
963
+ for choice in check_choices:
964
+ choice_has_whitespace = True if choice[0].isspace() else False
965
+ delimiter_has_whitespace = (
966
+ True
967
+ if self.config.target_delimiter.rstrip()
968
+ != self.config.target_delimiter
969
+ else False
970
+ )
971
+
972
+ if delimiter_has_whitespace and choice_has_whitespace:
973
+ eval_logger.debug(
974
+ f'Both target_delimiter "{self.config.target_delimiter}" and target choice: "{choice}" have whitespace'
975
+ )
976
+ elif (not delimiter_has_whitespace) and (not choice_has_whitespace):
977
+ eval_logger.debug(
978
+ f'Both target_delimiter "{self.config.target_delimiter}" and target choice: "{choice}" do not have whitespace, ignore if the language you are evaluating on does not require/use whitespace'
979
+ )
980
+
981
    def download(
        self, dataset_kwargs: Optional[Dict[str, Any]] = None, **kwargs
    ) -> None:
        """Load this task's dataset into ``self.dataset``.

        If the config supplies a ``custom_dataset`` callable it is invoked with
        the task's metadata and dataset kwargs merged together; otherwise the
        dataset is fetched via ``datasets.load_dataset`` using the class-level
        ``DATASET_PATH`` / ``DATASET_NAME``.

        :param dataset_kwargs: extra keyword arguments forwarded to
            ``datasets.load_dataset`` (ignored on the custom-dataset path).
        """
        if isinstance(self.config.custom_dataset, Callable):
            # Custom loaders take their configuration from task metadata, not
            # from the `dataset_kwargs` parameter — warn users where to put it.
            eval_logger.warning(
                f"{self.config.task}: Custom kwargs can be passed to `--metadata` in console (as json string) or to the TaskManager."
                + "\nFor example --metadata='{\"max_seq_lengths\":[4096, 8192]}'. For details see task Readme."
            )
            self.dataset = self.config.custom_dataset(
                **(self.config.metadata or {}), **(self.config.dataset_kwargs or {})
            )
        else:
            self.dataset = datasets.load_dataset(
                path=self.DATASET_PATH,
                name=self.DATASET_NAME,
                **dataset_kwargs if dataset_kwargs is not None else {},
            )
998
+
999
+ def has_training_docs(self) -> bool:
1000
+ if self.config.training_split is not None:
1001
+ return True
1002
+ else:
1003
+ return False
1004
+
1005
+ def has_validation_docs(self) -> bool:
1006
+ if self.config.validation_split is not None:
1007
+ return True
1008
+ else:
1009
+ return False
1010
+
1011
+ def has_test_docs(self) -> bool:
1012
+ if self.config.test_split is not None:
1013
+ return True
1014
+ else:
1015
+ return False
1016
+
1017
+ def training_docs(self) -> datasets.Dataset:
1018
+ if self.has_training_docs():
1019
+ if self.config.process_docs is not None:
1020
+ return self.config.process_docs(
1021
+ self.dataset[self.config.training_split]
1022
+ )
1023
+ return self.dataset[self.config.training_split]
1024
+
1025
+ def validation_docs(self) -> datasets.Dataset:
1026
+ if self.has_validation_docs():
1027
+ if self.config.process_docs is not None:
1028
+ return self.config.process_docs(
1029
+ self.dataset[self.config.validation_split]
1030
+ )
1031
+ return self.dataset[self.config.validation_split]
1032
+
1033
+ def test_docs(self) -> datasets.Dataset:
1034
+ if self.has_test_docs():
1035
+ if self.config.process_docs is not None:
1036
+ return self.config.process_docs(self.dataset[self.config.test_split])
1037
+ return self.dataset[self.config.test_split]
1038
+
1039
    def fewshot_docs(self):
        """Resolve the document pool used to sample few-shot examples.

        Resolution order:
        1. ``config.fewshot_split`` — a dataset split (optionally post-processed
           by ``config.process_docs``);
        2. ``config.fewshot_config["samples"]`` — an inline list of sample
           dicts, or a zero-arg callable producing such a list;
        3. otherwise, fall back to the parent class's split-preference rule
           (warning if few-shot examples were requested without a source).
        """
        if self.config.fewshot_split is not None:
            if self.config.process_docs is not None:
                return self.config.process_docs(self.dataset[self.config.fewshot_split])
            return self.dataset[self.config.fewshot_split]
        elif (
            self.config.fewshot_config is not None
            and self.config.fewshot_config.get("samples", None) is not None
        ):
            if isinstance(self.config.fewshot_config["samples"], list):
                # Inline list of example dicts.
                return self.config.fewshot_config["samples"]
            elif callable(self.config.fewshot_config["samples"]):
                # Factory function producing the example list.
                return self.config.fewshot_config["samples"]()
            else:
                raise Exception(
                    "`fewshot_config['samples']` was incorrectly defined in the configuration. It should be either a list of samples as a dict, or function returning this list."
                )
        else:
            if (self.config.num_fewshot is not None) and (self.config.num_fewshot > 0):
                eval_logger.warning(
                    f"[Task: {self.config.task}] "
                    "num_fewshot > 0 but fewshot_split is None. "
                    "using preconfigured rule."
                )
            return super().fewshot_docs()
1064
+
1065
+ @staticmethod
1066
+ def append_target_question(
1067
+ labeled_examples: List[Dict[str, str]],
1068
+ question: str,
1069
+ fewshot_as_multiturn: bool = False,
1070
+ gen_prefix: Optional[str] = None,
1071
+ ) -> None:
1072
+ """Adds a target question to the labeled examples list.
1073
+ If fewshot_as_multiturn is True, or labeled_examples is empty, or the last entry is a system turn, appends the question as a new user entry.
1074
+ Otherwise, it is appended to the last user entry, ensuring that the conversation alternates between the user and the assistant.
1075
+ """
1076
+ if not fewshot_as_multiturn:
1077
+ # if no messages or last message is system, append as new user entry
1078
+ if len(labeled_examples) == 0 or labeled_examples[-1]["role"] == "system":
1079
+ labeled_examples.append({"role": "user", "content": question})
1080
+ # if last message is user, append to it to avoid two user messages in a row
1081
+ else:
1082
+ labeled_examples[-1]["content"] += question
1083
+ else:
1084
+ # if fewshot_as_multiturn is True, append as next user entry (last is always assistant)
1085
+ labeled_examples.append({"role": "user", "content": question})
1086
+ if gen_prefix:
1087
+ labeled_examples.append({"role": "assistant", "content": gen_prefix})
1088
+
1089
    @utils.positional_deprecated
    def fewshot_context(
        self,
        doc: dict,
        num_fewshot: int,
        system_instruction: Optional[str] = None,
        apply_chat_template: bool = False,
        fewshot_as_multiturn: bool = False,
        chat_template: Optional[Callable] = None,
        gen_prefix: Optional[str] = None,
    ) -> Union[str, List[str]]:
        """Returns a fewshot context string that is made up of a prepended description
        (if provided), the `num_fewshot` number of examples, and an appended prompt example.

        :param doc: str
            The document as returned from training_docs, validation_docs, or test_docs.
        :param num_fewshot: int
            The number of fewshot examples to provide in the returned context string.
        :param system_instruction: str
            System instruction to be applied to the prompt.
        :param apply_chat_template: bool
            Whether to apply the chat template to the fewshot context.
        :param fewshot_as_multiturn: bool
            Whether to provide the fewshot examples as a multiturn conversation or a single user turn.
        :param chat_template:
            callable (from lm.apply_chat_template) that takes in a list[Dict] chat transcript and renders it into a string.
        :param gen_prefix:
            String to append after the <|assistant|> token.
        :returns: str
            The fewshot context.
        """
        # Chat-template mode accumulates a message list; plain mode a string.
        if apply_chat_template:
            labeled_examples = []
        else:
            labeled_examples = ""

        # get task description
        if description := self.config.description:
            description = utils.apply_template(self.config.description, doc)

        # create system prompt based on the provided system instruction and description
        if system_instruction is not None and description:
            system_prompt = (
                f"{system_instruction}{self.sampler.fewshot_delimiter}{description}"
            )
        elif system_instruction is not None:
            system_prompt = system_instruction
        elif description:
            system_prompt = description
        else:
            system_prompt = ""

        # add system prompt if specified
        if system_prompt:
            if apply_chat_template:
                labeled_examples.append({"role": "system", "content": system_prompt})
            else:
                labeled_examples = system_prompt
        # if few-shot - append examples after the system prompt
        if num_fewshot > 0:
            if apply_chat_template:
                labeled_examples.extend(
                    self.sampler.get_chat_context(
                        doc,
                        num_fewshot,
                        fewshot_as_multiturn,
                        gen_prefix=gen_prefix,
                    )
                )
            else:
                labeled_examples += self.sampler.get_context(
                    doc, num_fewshot, gen_prefix=gen_prefix
                )

        example = self.doc_to_text(doc)
        if apply_chat_template:
            if self.multiple_input:
                # Choices live in the context for multiple-input tasks.
                # TODO: append prefill?
                if not labeled_examples:
                    return ""
                return chat_template(labeled_examples)
            if isinstance(example, str):
                self.append_target_question(
                    labeled_examples,
                    example,
                    fewshot_as_multiturn,
                    gen_prefix=gen_prefix,
                )
            # for loglikelihood create a list of questions with appended choices
            elif isinstance(example, list):
                labeled_examples_list = []
                # copy chat history for each example and append the answer
                for ex in example:
                    chat = deepcopy(labeled_examples)
                    self.append_target_question(
                        chat,
                        ex,
                        fewshot_as_multiturn,
                        gen_prefix=gen_prefix,
                    )
                    # TODO: append prefill?
                    labeled_examples_list.append(
                        chat_template(
                            chat,
                            add_generation_prompt=False if gen_prefix else True,
                        )
                    )
                return labeled_examples_list
            # if example is an integer, append the choice or convert to string
            elif isinstance(example, int):
                if self.config.doc_to_choice is not None:
                    choices = self.doc_to_choice(doc)
                    self.append_target_question(
                        labeled_examples,
                        choices[example],
                        fewshot_as_multiturn,
                        gen_prefix=gen_prefix,
                    )
                else:
                    self.append_target_question(
                        labeled_examples,
                        str(example),
                        fewshot_as_multiturn,
                        gen_prefix=gen_prefix,
                    )
            # return lm.apply_chat_template(labeled_examples)
            # NOTE(review): with a gen_prefix the transcript already ends in an
            # assistant turn, so the generation prompt is suppressed.
            return chat_template(
                labeled_examples,
                add_generation_prompt=False if gen_prefix else True,
            )
        else:
            # Plain-string mode: gen_prefix is appended after the target delimiter.
            prefix = (
                self.config.target_delimiter + gen_prefix
                if gen_prefix is not None
                else ""
            )
            if self.multiple_input:
                return labeled_examples
            if isinstance(example, str):
                return labeled_examples + example + prefix
            elif isinstance(example, list):
                return [labeled_examples + ex + prefix for ex in example]
            elif isinstance(example, int):
                if self.config.doc_to_choice is not None:
                    choices = self.doc_to_choice(doc)
                    return labeled_examples + choices[example] + prefix
                else:
                    return labeled_examples + str(example) + prefix
1237
+
1238
+ def apply_filters(self) -> Optional[List[Instance]]:
1239
+ """Iterates over FilterEnsembles and applies them to instances"""
1240
+ if hasattr(self, "_filters"):
1241
+ for f in self._filters:
1242
+ f.apply(self._instances)
1243
+ else:
1244
+ eval_logger.warning("No filter defined, passing through instances")
1245
+ return self._instances
1246
+
1247
+ def should_decontaminate(self):
1248
+ return self.config.should_decontaminate
1249
+
1250
    def doc_to_decontamination_query(self, doc: dict):
        """Build the text used to check this document against a contamination index.

        Falls back to ``doc_to_text`` when no explicit query is configured.
        A configured query may be a feature/column name, a callable, or a
        template string (whose rendered output is parsed as a Python literal).
        Returns None implicitly when decontamination is disabled.
        """
        if self.config.should_decontaminate:
            if self.config.doc_to_decontamination_query is None:
                return self.doc_to_text(doc)
            else:
                doc_to_decontamination_query = self.config.doc_to_decontamination_query
                if doc_to_decontamination_query in self.features:
                    # Column name: use the raw feature value.
                    return doc[doc_to_decontamination_query]
                elif callable(doc_to_decontamination_query):
                    return doc_to_decontamination_query(doc)
                else:
                    # Template string: render, then parse the result as a literal.
                    return ast.literal_eval(
                        utils.apply_template(
                            self.config.doc_to_decontamination_query, doc
                        )
                    )
1266
+
1267
+ def _process_doc(self, doc: dict) -> dict:
1268
+ """
1269
+ Override this to process (detokenize, strip, replace, etc.) individual
1270
+ documents. This can be used in a map over documents of a data split.
1271
+ E.g. `map(self._process_doc, self.dataset["validation"])`
1272
+
1273
+ :return: dict
1274
+ The processed version of the specified `doc`.
1275
+ """
1276
+ return doc
1277
+
1278
+ def doc_to_text(self, doc, doc_to_text=None):
1279
+ if self.prompt is not None:
1280
+ doc_to_text = self.prompt
1281
+ elif doc_to_text is not None:
1282
+ doc_to_text = doc_to_text
1283
+ else:
1284
+ doc_to_text = self.config.doc_to_text
1285
+
1286
+ if isinstance(doc_to_text, int):
1287
+ return doc_to_text
1288
+ elif isinstance(doc_to_text, str):
1289
+ if doc_to_text in self.features:
1290
+ # if self.config.doc_to_choice is not None:
1291
+ # return self.doc_to_choice(doc)[doc[doc_to_text]]
1292
+ # else:
1293
+ return doc[doc_to_text]
1294
+ else:
1295
+ text_string = utils.apply_template(doc_to_text, doc)
1296
+ if text_string.isdigit() and self._config.doc_to_choice is not None:
1297
+ return ast.literal_eval(text_string)
1298
+ else:
1299
+ return text_string
1300
+ elif callable(doc_to_text):
1301
+ return doc_to_text(doc)
1302
+ # Used when applying a Promptsource template
1303
+ elif hasattr(doc_to_text, "apply"):
1304
+ applied_prompt = doc_to_text.apply(doc)
1305
+ if len(applied_prompt) == 2:
1306
+ return applied_prompt[0]
1307
+ else:
1308
+ eval_logger.warning("Applied prompt returns empty string")
1309
+ return self.config.fewshot_delimiter
1310
+ else:
1311
+ print(type(doc_to_text))
1312
+ raise TypeError
1313
+
1314
    def doc_to_target(self, doc: Mapping, doc_to_target=None) -> Union[int, str, list]:
        """Resolve the gold target(s) for a document.

        Template resolution order: loaded prompt, explicit argument, then
        ``config.doc_to_target``. A string template renders via Jinja; a
        digit-only render with choices defined is parsed as a choice index, and
        a ``[...]``-bracketed render is parsed as a literal list when possible
        (falling back to the raw string on parse failure).
        """
        if self.prompt is not None:
            doc_to_target = self.prompt
        elif doc_to_target is not None:
            doc_to_target = doc_to_target
        else:
            doc_to_target = self.config.doc_to_target

        if isinstance(doc_to_target, int):
            # Fixed choice index.
            return doc_to_target
        elif isinstance(doc_to_target, str):
            if doc_to_target in self.features:
                # Column name: return the raw feature value.
                return doc[doc_to_target]
            else:
                target_string = utils.apply_template(doc_to_target, doc)
                if target_string.isdigit() and self._config.doc_to_choice is not None:
                    # Digit-only render with choices defined -> choice index.
                    return ast.literal_eval(target_string)
                elif (
                    len(target_string) >= 2
                    and (target_string[0] == "[")
                    and (target_string[-1] == "]")
                ):
                    # Looks like a literal list (multiple targets).
                    try:
                        return ast.literal_eval(target_string)
                    except (SyntaxError, ValueError):
                        return target_string
                else:
                    return target_string
        elif isinstance(doc_to_target, list):
            return doc_to_target
        elif callable(doc_to_target):
            return doc_to_target(doc)
        # Used when applying a Promptsource template
        elif hasattr(doc_to_target, "apply"):
            applied_prompt = doc_to_target.apply(doc)
            if len(applied_prompt) == 2:
                return applied_prompt[1]
            else:
                eval_logger.warning("Applied prompt returns empty string")
                return self.config.fewshot_delimiter
        else:
            raise TypeError
1359
+
1360
    def doc_to_choice(self, doc: Any, doc_to_choice=None) -> List[str]:
        """Resolve the answer-choice list for a document.

        The choice spec may be a feature/column name, a template string parsed
        as a literal list, a plain list, a dict (values used), a callable, or a
        Promptsource template. NOTE(review): if the config does not define
        ``doc_to_choice`` this logs an error and then falls through to the
        final ``raise TypeError`` — preserved as-is.
        """
        if self.prompt is not None:
            doc_to_choice = self.prompt
        elif doc_to_choice is not None:
            doc_to_choice = doc_to_choice
        elif self.config.doc_to_choice is None:
            eval_logger.error("doc_to_choice was called but not set in config")
        else:
            doc_to_choice = self.config.doc_to_choice

        if isinstance(doc_to_choice, str):
            if doc_to_choice in self.features:
                # Column name: choices stored directly in the document.
                return doc[doc_to_choice]
            else:
                # Template string whose render is a literal list.
                return ast.literal_eval(utils.apply_template(doc_to_choice, doc))
        elif isinstance(doc_to_choice, list):
            return doc_to_choice
        elif isinstance(doc_to_choice, dict):
            return list(doc_to_choice.values())
        elif callable(doc_to_choice):
            return doc_to_choice(doc)
        elif hasattr(doc_to_choice, "get_answer_choices_list"):
            # Promptsource template.
            return doc_to_choice.get_answer_choices_list(doc)
        else:
            raise TypeError
1385
+
1386
+ def doc_to_image(self, doc: Any, doc_to_image=None) -> Union[int, str, list]:
1387
+ if doc_to_image is not None:
1388
+ doc_to_image = doc_to_image
1389
+ elif self.config.doc_to_image is not None:
1390
+ doc_to_image = self.config.doc_to_image
1391
+ else:
1392
+ return None
1393
+
1394
+ if isinstance(doc_to_image, list):
1395
+ image_feature = [
1396
+ self.doc_to_image(doc, feature) for feature in doc_to_image
1397
+ ]
1398
+ return [feature for feature in image_feature if feature is not None]
1399
+ elif isinstance(doc_to_image, str):
1400
+ if doc_to_image in self.features:
1401
+ return doc[doc_to_image]
1402
+ else:
1403
+ return ast.literal_eval(utils.apply_template(doc_to_image, doc))
1404
+ elif callable(doc_to_image):
1405
+ return doc_to_image(doc)
1406
+ else:
1407
+ return None
1408
+
1409
+ def doc_to_audio(self, doc: Any, doc_to_audio=None) -> Union[int, str, list]:
1410
+ if doc_to_audio is not None:
1411
+ doc_to_audio = doc_to_audio
1412
+ elif self.config.doc_to_audio is not None:
1413
+ doc_to_audio = self.config.doc_to_audio
1414
+ else:
1415
+ return None
1416
+
1417
+ if isinstance(doc_to_audio, list):
1418
+ audio_feature = [
1419
+ self.doc_to_audio(doc, feature) for feature in doc_to_audio
1420
+ ]
1421
+ return [feature for feature in audio_feature if feature is not None]
1422
+ elif isinstance(doc_to_audio, str):
1423
+ if doc_to_audio in self.features:
1424
+ return doc[doc_to_audio]
1425
+ else:
1426
+ return ast.literal_eval(utils.apply_template(doc_to_audio, doc))
1427
+ elif callable(doc_to_audio):
1428
+ return doc_to_audio(doc)
1429
+ else:
1430
+ return None
1431
+
1432
+ def doc_to_prefix(self, doc):
1433
+ if (gen_prefix := self.config.gen_prefix) is not None:
1434
+ if gen_prefix in self.features:
1435
+ return doc[gen_prefix]
1436
+ else:
1437
+ return utils.apply_template(gen_prefix, doc)
1438
+ return None
1439
+
1440
    def construct_requests(
        self, doc: dict, ctx: str, **kwargs
    ) -> Union[List[Instance], Instance]:
        """Build the LM request Instance(s) for one document.

        The argument tuple shape depends on ``OUTPUT_TYPE``:
        - loglikelihood: (context, target)
        - loglikelihood_rolling: (target,)
        - multiple_choice: one (context, delimiter+choice) pair per choice,
          emitted as a list of "loglikelihood" Instances
        - generate_until: (context, generation_kwargs)
        Multimodal features, when configured, are appended as a trailing dict.
        """
        apply_chat_template = kwargs.pop("apply_chat_template", False)
        chat_template: Callable | None = kwargs.pop("chat_template", None)

        aux_arguments = None

        if self.OUTPUT_TYPE == "loglikelihood":
            arguments = (ctx, self.doc_to_target(doc))
        elif self.OUTPUT_TYPE == "loglikelihood_rolling":
            arguments = (self.doc_to_target(doc),)
        elif self.OUTPUT_TYPE == "multiple_choice":
            choices = self.doc_to_choice(doc)
            target_delimiter = self.config.target_delimiter
            if apply_chat_template:
                # Chat templates handle their own separators.
                target_delimiter = ""
            if self.multiple_input:
                # If there are multiple inputs, choices are placed in the ctx
                # apply chat_template to choices if apply_chat_template
                cont = self.doc_to_target(doc)

                arguments = [
                    (
                        ctx
                        + (
                            chat_template([{"role": "user", "content": choice}])
                            if apply_chat_template
                            else choice
                        ),
                        f"{target_delimiter}{cont}",
                    )
                    for choice in choices
                ]
            else:
                # Otherwise they are placed in the continuation
                arguments = [(ctx, f"{target_delimiter}{cont}") for cont in choices]

            # TODO: we should raise a warning telling users this will at most ~2x runtime.
            if "acc_mutual_info" in self._metric_fn_list.keys():
                # if we are calculating multiple choice accuracy
                # using mutual information instead of raw loglikelihood as metric, need unconditional lls.

                # here mutual info refers to calculating
                # log(P(choice|ctx) / P(choice)) = log(P(choice|ctx)) - log(P(choice))
                # in other words normalizing by subtracting the unconditional logprob of each choice.
                # TODO: should these be strided? will have to modify the processing in process_results if so
                aux_arguments = [
                    ("", f"{target_delimiter}{choice}") for choice in choices
                ]

                arguments.extend(aux_arguments)

        elif self.OUTPUT_TYPE == "generate_until":
            # deepcopy so per-request mutation cannot leak into the shared config.
            arguments = (ctx, deepcopy(self.config.generation_kwargs))

        multimodal_arg = {}
        if (
            self.config.doc_to_image
        ):  # TODO: ensure that non-multimodal tasks aren't getting visual args
            multimodal_arg = {
                **multimodal_arg,
                **{"visual": self.doc_to_image(doc)},
            }

        if (
            self.config.doc_to_audio
        ):  # TODO: ensure that non-multimodal tasks aren't getting audio args
            multimodal_arg = {
                **multimodal_arg,
                **{"audio": self.doc_to_audio(doc)},
            }

        if bool(multimodal_arg):
            # Append the multimodal payload as the last tuple element.
            if isinstance(arguments, list):
                arguments = [arg + (multimodal_arg,) for arg in arguments]
            else:
                arguments = arguments + (multimodal_arg,)

        if self.OUTPUT_TYPE == "multiple_choice":
            # One loglikelihood request per choice (plus unconditional extras).
            request_list = [
                Instance(
                    request_type="loglikelihood",
                    doc=doc,
                    arguments=arg,
                    idx=i,
                    **kwargs,
                )
                for i, arg in enumerate(arguments)
            ]

            return request_list

        return Instance(
            request_type=self.OUTPUT_TYPE,
            doc=doc,
            arguments=arguments,
            idx=0,
            **kwargs,
        )
1540
+
1541
    def process_results(self, doc, results):
        """Score one document's model outputs into per-metric results.

        Delegates to ``config.process_results`` when provided. Otherwise the
        scoring shape depends on ``OUTPUT_TYPE`` — see inline comments.
        Returns a dict mapping metric name to a per-document value (later
        aggregated by ``aggregation()``).
        """
        if callable(self.config.process_results):
            return self.config.process_results(doc, results)

        result_dict = {}
        use_metric = list(self._metric_fn_list.keys())
        if self.OUTPUT_TYPE == "loglikelihood":
            results = results[0]
            ll, is_greedy = results
            return {
                **({"perplexity": ll} if "perplexity" in use_metric else {}),
                **({"acc": int(is_greedy)} if "acc" in use_metric else {}),
            }
        elif self.OUTPUT_TYPE == "loglikelihood_rolling":
            (loglikelihood,) = results
            _words = self.count_words(self.doc_to_target(doc))
            _bytes = self.count_bytes(self.doc_to_target(doc))
            return {
                **(
                    {"word_perplexity": (loglikelihood, _words)}
                    if "word_perplexity" in use_metric
                    else {}
                ),
                **(
                    {"byte_perplexity": (loglikelihood, _bytes)}
                    if "byte_perplexity" in use_metric
                    else {}
                ),
                **(
                    {"bits_per_byte": (loglikelihood, _bytes)}
                    if "bits_per_byte" in use_metric
                    else {}
                ),
            }
        elif self.OUTPUT_TYPE == "multiple_choice":
            lls, is_greedy = zip(*results)

            # retrieve choices in List[str] form, to compute choice lengths, etc.
            choices = self.doc_to_choice(doc)
            completion_len = np.array([float(len(i)) for i in choices])

            if (
                2 * len(choices) == len(lls)
                and "acc_mutual_info" in self._metric_fn_list.keys()
            ):
                # then we are doing mutual info.
                # this stores the "dryrun" / unconditional answer loglikelihoods
                # as we extend the args list with unconditional ("", continuation) pairs
                lls_unconditional = lls[len(choices) :]
                if len(lls_unconditional) != len(choices):
                    raise ValueError
                # and this stores our "regular" conditional loglikelihoods
                lls = lls[: len(choices)]

            pred = np.argmax(lls)
            # length-normalized prediction for acc_norm.
            pred_norm = np.argmax(lls / completion_len)

            if self.multiple_input:
                gold = self.doc_to_text(doc)
            else:
                gold = self.doc_to_target(doc)

            # Map gold to choice indices; -100 marks an out-of-range label.
            gold_index_error = False
            if isinstance(gold, list):
                gold = [i if i < len(choices) else -100 for i in gold]
                if -100 in gold:
                    gold_index_error = True
            else:
                if isinstance(gold, int):
                    gold = gold if gold < len(choices) else -100
                elif isinstance(gold, str):
                    gold = choices.index(gold) if gold in choices else -100

                if gold == -100:
                    gold_index_error = True

            if gold_index_error:
                eval_logger.warning(
                    f"Label index was not in within range of available choices,"
                    f"Sample:\n\n{doc}\n\n"
                )

            if self.multiple_target:
                acc = 1.0 if pred in gold else 0.0
                acc_norm = 1.0 if pred_norm in gold else 0.0
                exact_match = int(any([is_greedy[i] if i != -100 else 0 for i in gold]))
            else:
                acc = 1.0 if pred == gold else 0.0
                acc_norm = 1.0 if pred_norm == gold else 0.0
                # TODO: this gets score of 0 on arc_challenge for pythia-70m. need to test that this works properly
                exact_match = int(is_greedy[gold]) if gold != -100 else 0

            prob_norm = utils.softmax(lls)

            # TODO use keyword arguments to the metric?
            # gold, pred, norm stuff, the original lls,
            result_dict = {
                **({"acc": acc} if "acc" in use_metric else {}),
                **({"f1": (gold, pred)} if "f1" in use_metric else {}),
                **({"mcc": (gold, pred)} if "mcc" in use_metric else {}),
                **({"acc_norm": acc_norm} if "acc_norm" in use_metric else {}),
                **({"exact_match": exact_match} if "exact_match" in use_metric else {}),
                **(
                    {"brier_score": (gold, prob_norm)}
                    if "brier_score" in use_metric
                    else {}
                ),
            }

            if "acc_mutual_info" in use_metric:
                lls_mutual_info = [
                    ll_c - ll_u for ll_c, ll_u in zip(lls, lls_unconditional)
                ]
                acc_mutual_info = 1.0 if np.argmax(lls_mutual_info) == gold else 0.0
                result_dict["acc_mutual_info"] = acc_mutual_info

        elif self.OUTPUT_TYPE == "generate_until":
            gold = self.doc_to_target(doc)
            result = results[0]
            if self.config.doc_to_choice is not None:
                # If you set doc_to_choice,
                # it assumes that doc_to_target returns a number.
                choices = self.doc_to_choice(doc)
                gold = choices[gold]
            # we expect multiple_targets to be a list.
            elif self.multiple_target:
                gold = list(gold)
            # TODO: handle this better
            elif type(gold) is not type(result) and not (
                "bypass" in self._metric_fn_list.keys() or isinstance(result, list)
            ):
                # cast gold to the same type as result
                gold = type(result)(gold)

            for metric in self._metric_fn_list.keys():
                if self.multiple_target:
                    # in the case where we have multiple targets,
                    # return true if any are true
                    # TODO: this may break for multipLe_target, non zero-or-1 metrics
                    scores = []
                    if not isinstance(gold, list):
                        # sometimes, a multiple_target dataset has exceptions where one doc has only one string answer
                        gold = [gold]
                    if metric == "exact_match":
                        # exact_match scores all targets in one batched call.
                        result = [result for _ in range(len(gold))]
                        scores = self._metric_fn_list[metric](
                            references=gold,
                            predictions=result,
                            **self._metric_fn_kwargs[metric],
                        )[metric]
                        result_score = 1.0 if scores > 0.0 else 0.0
                    else:
                        for gold_option in gold:
                            try:
                                result_score = self._metric_fn_list[metric](
                                    references=[gold_option],
                                    predictions=[result],
                                    **self._metric_fn_kwargs[metric],
                                )
                            except (
                                TypeError
                            ):  # TODO: this is hacky and I don't want to do it
                                result_score = self._metric_fn_list[metric](
                                    [gold_option, result]
                                )
                            if isinstance(result_score, dict):
                                # TODO: this handles the case where HF evaluate returns a dict.
                                result_score = result_score[metric]
                            scores.append(result_score)
                        if any(scores):
                            result_score = 1.0
                        else:
                            result_score = 0.0
                else:
                    try:
                        result_score = self._metric_fn_list[metric](
                            references=[gold],
                            predictions=[result],
                            **self._metric_fn_kwargs[metric],
                        )
                    except TypeError:  # needed for now in order to use a different interface between our own metrics and HF Evaluate metrics
                        result_score = self._metric_fn_list[metric]([gold, result])
                if isinstance(result_score, dict):
                    # TODO: this handles the case where HF evaluate returns a dict.
                    # This allows for multiple metrics to be returned from the same function
                    for k, v in result_score.items():
                        result_dict[k] = v
                else:
                    result_dict[metric] = result_score
        else:
            raise ValueError(
                f"Passed invalid output_type '{self.OUTPUT_TYPE}' ! Please use one of ",
                "'loglikelihood', 'loglikelihood_rolling', 'generate_until' or 'multiple_choice'",
            )

        return result_dict
1738
+
1739
def aggregation(self) -> dict:
    """Return the mapping from metric name to its aggregation function (built at init)."""
    return self._aggregation_list
1741
+
1742
def higher_is_better(self) -> dict:
    """Return the mapping from metric name to whether larger values are better."""
    return self._higher_is_better
1744
+
1745
def get_config(self, key: str) -> Any:
    """Return attribute `key` from the task config, or None if it is not set."""
    return getattr(self._config, key, None)
1747
+
1748
@property
def task_name(self) -> Any:
    """Registered name of this task from the config (None if unset)."""
    return getattr(self.config, "task", None)
1751
+
1752
def __repr__(self):
    """Compact summary: task name, output type, fewshot count, and doc count."""
    parts = [
        f"task_name={getattr(self.config, 'task', None)}",
        f"output_type={self.OUTPUT_TYPE}",
        f"num_fewshot={getattr(self.config, 'num_fewshot', None)}",
        f"num_samples={len(self.eval_docs)}",
    ]
    return "ConfigurableTask(" + ",".join(parts) + ")"
1759
+
1760
+
1761
class MultipleChoiceTask(Task):
    """Task whose answer is picked among fixed options by loglikelihood."""

    OUTPUT_TYPE = "loglikelihood"

    def doc_to_target(self, doc: dict) -> str:
        """Return the gold choice text, space-prefixed for tokenization."""
        gold_idx = doc["gold"]
        return " " + doc["choices"][gold_idx]

    def construct_requests(self, doc: dict, ctx: str, **kwargs) -> List[Instance]:
        """Build one loglikelihood request per candidate answer."""
        # TODO: add mutual info here?
        requests = []
        for option_idx, option in enumerate(doc["choices"]):
            requests.append(
                Instance(
                    request_type="loglikelihood",
                    doc=doc,
                    arguments=(ctx, " {}".format(option)),
                    idx=option_idx,
                    **kwargs,
                )
            )
        return requests

    def process_results(self, doc: dict, results: Iterable[Tuple[float, bool]]) -> dict:
        """Score one doc: raw and length-normalized argmax accuracy."""
        # Keep only the loglikelihood of each (ll, is_greedy) pair.
        # TODO: do we need is_greedy anywhere?
        lls = [pair[0] for pair in results]
        gold = doc["gold"]

        correct = np.argmax(lls) == gold
        # Normalize by character length of each choice before taking argmax.
        choice_lens = np.array([float(len(c)) for c in doc["choices"]])
        correct_norm = np.argmax(lls / choice_lens) == gold

        return {
            "acc": 1.0 if correct else 0.0,
            "acc_norm": 1.0 if correct_norm else 0.0,
        }

    def higher_is_better(self) -> dict:
        return {"acc": True, "acc_norm": True}

    def aggregation(self) -> dict:
        return {"acc": mean, "acc_norm": mean}
1806
+
1807
+
1808
class PerplexityTask(Task):
    """Scores whole documents by rolling loglikelihood; fewshot is disallowed."""

    OUTPUT_TYPE = "loglikelihood_rolling"

    def has_training_docs(self) -> bool:
        return False

    def fewshot_examples(self, k: int, rnd) -> List:
        """Perplexity tasks never use fewshot examples; k must be 0."""
        if k != 0:
            raise ValueError(
                "The number of fewshot examples must be 0 for perplexity tasks."
            )
        return []

    def fewshot_context(self, doc: dict, num_fewshot: int) -> Literal[""]:
        """Context is always empty: perplexity is measured on the raw document."""
        if num_fewshot != 0:
            raise ValueError(
                "The number of fewshot examples must be 0 for perplexity tasks."
            )
        return ""

    def higher_is_better(self) -> dict:
        # All perplexity-style metrics improve as they decrease.
        return {
            "word_perplexity": False,
            "byte_perplexity": False,
            "bits_per_byte": False,
        }

    def doc_to_decontamination_query(self, doc):
        return doc

    def doc_to_text(self, doc) -> str:
        return ""

    def doc_to_target(self, doc):
        return doc

    def construct_requests(self, doc: dict, ctx: Optional[str], **kwargs):
        """Build the single rolling-loglikelihood request for this document."""
        # A non-empty context would contradict fewshot_context() above.
        if bool(ctx):
            raise ValueError

        return Instance(
            request_type=self.OUTPUT_TYPE,
            doc=doc,
            arguments=(self.doc_to_target(doc),),
            idx=0,
            **kwargs,
        )

    def process_results(self, doc: dict, results: Tuple[float]) -> dict:
        """Pair the doc loglikelihood with its word/byte counts for aggregation."""
        (ll,) = results
        target = self.doc_to_target(doc)
        n_words = self.count_words(target)
        n_bytes = self.count_bytes(target)
        return {
            "word_perplexity": (ll, n_words),
            "byte_perplexity": (ll, n_bytes),
            "bits_per_byte": (ll, n_bytes),
        }

    def aggregation(self) -> dict:
        return {
            "word_perplexity": weighted_perplexity,
            "byte_perplexity": weighted_perplexity,
            "bits_per_byte": bits_per_byte,
        }

    @classmethod
    def count_bytes(cls, doc) -> int:
        """Length of the document in UTF-8 bytes."""
        return len(doc.encode("utf-8"))

    @classmethod
    def count_words(cls, doc) -> int:
        """Downstream tasks with custom word boundaries should override this!

        NOTE(review): re.split(r"\\s+", ...) yields empty fields for
        leading/trailing whitespace, so the count can differ from
        str.split(); kept as-is for parity with published metric values.
        """
        return len(re.split(r"\s+", doc))
Prism/LLaDA/LLaDA_Baseline/dllm_eval/caching/__init__.py ADDED
File without changes
Prism/LLaDA/LLaDA_Baseline/dllm_eval/caching/cache.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hashlib
2
+ import logging
3
+ import os
4
+
5
+ import dill
6
+
7
+
8
eval_logger = logging.getLogger(__name__)


MODULE_DIR = os.path.dirname(os.path.realpath(__file__))

# Optional override for where cache files are written.
OVERRIDE_PATH = os.getenv("LM_HARNESS_CACHE_PATH")


# Default cache directory lives next to this module unless overridden.
PATH = OVERRIDE_PATH if OVERRIDE_PATH else f"{MODULE_DIR}/.cache"

# This should be sufficient for uniqueness
HASH_INPUT = "EleutherAI-lm-evaluation-harness"

HASH_PREFIX = hashlib.sha256(HASH_INPUT.encode("utf-8")).hexdigest()

# Every cache file carries this suffix so delete_cache() can identify ours.
FILE_SUFFIX = f".{HASH_PREFIX}.pickle"
24
+
25
+
26
def load_from_cache(file_name: str, cache: bool = False):
    """Return the cached object for `file_name`, or None on a miss or when disabled."""
    if not cache:
        return
    try:
        cache_file = f"{PATH}/{file_name}{FILE_SUFFIX}"

        with open(cache_file, "rb") as fh:
            return dill.loads(fh.read())

    except Exception:
        # Any failure (missing or unreadable file) is treated as a cache miss.
        eval_logger.debug(f"{file_name} is not cached, generating...")
        pass
39
+
40
+
41
def save_to_cache(file_name, obj):
    """Serialize `obj` with dill and write it into the cache directory.

    Args:
        file_name: Base name of the cache entry (hash suffix is appended).
        obj: Any dill-serializable object.
    """
    # makedirs(exist_ok=True) handles nested override paths (LM_HARNESS_CACHE_PATH)
    # and is safe against the create-race when several eval processes start at
    # once; the previous exists()+mkdir pair was neither.
    os.makedirs(PATH, exist_ok=True)

    file_path = f"{PATH}/{file_name}{FILE_SUFFIX}"

    eval_logger.debug(f"Saving {file_path} to cache...")
    with open(file_path, "wb") as file:
        file.write(dill.dumps(obj))
50
+
51
+
52
def delete_cache(key: str = ""):
    """Delete every cache file whose name starts with `key`.

    NOTE: the "key" param is to allow for flexibility.
    """
    for entry in os.listdir(PATH):
        if entry.startswith(key) and entry.endswith(FILE_SUFFIX):
            os.unlink(f"{PATH}/{entry}")
Prism/LLaDA/LLaDA_Baseline/dllm_eval/decontamination/__init__.py ADDED
File without changes
Prism/LLaDA/LLaDA_Baseline/dllm_eval/decontamination/janitor.py ADDED
@@ -0,0 +1,328 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pickle
2
+ import re
3
+ import string
4
+ import traceback
5
+ from typing import Iterator, List, Sequence, Tuple, TypeVar
6
+
7
+
8
# This is a cpp module. Compile janitor_util.cpp with:
# c++ -O3 -Wall -shared -std=c++11 -fPIC $(python3 -m pybind11 --includes) janitor_util.cpp -o janitor_util$(python3-config --extension-suffix) -undefined dynamic_lookup
try:
    import janitor_util

    JANITOR_CPP = True
except Exception:
    # Fall back to the (much slower) pure-Python implementation below.
    print("WARNING: C++ module could not be loaded. Janitor running in python mode")
    traceback.print_exc()
    JANITOR_CPP = False

# Generic element type for the ngram helpers.
T = TypeVar("T")
20
+
21
+
22
+ # Implementation from nltk source
23
+ # https://www.nltk.org/_modules/nltk/util.html
24
def form_ngrams(sequence: Iterator[T], n: int) -> Iterator[Tuple[T, ...]]:
    """Yield successive n-grams (as tuples) from `sequence`.

    Implementation adapted from nltk (https://www.nltk.org/_modules/nltk/util.html).
    Yields nothing if the sequence has fewer than n items.
    """
    window = []
    # Pre-fill the sliding window with the first n-1 elements.
    for _ in range(n - 1):
        # PEP 479: catch StopIteration explicitly so it does not escape
        # the generator as a RuntimeError.
        try:
            window.append(next(sequence))
        except StopIteration:
            return
    for element in sequence:
        window.append(element)
        yield tuple(window)
        window.pop(0)
+
40
+
41
def word_ngrams(s: str, n: int) -> Iterator[str]:
    """Splits a string into ngram words"""
    words = s.split()  # not a generator :(
    return (" ".join(gram) for gram in form_ngrams(iter(words), n))
+ return (" ".join(ngram) for ngram in ngram_seqs)
46
+
47
+
48
+ # Does character sequences only - combined faster function to play around with later
49
+ # def word_ngrams_indices_combined(sequence, n):
50
+ # current_word = ""
51
+ # history = []
52
+ # gap = False;
53
+ # start = 0
54
+ # end = 0
55
+ # for character in sequence:
56
+ # if character == " ":
57
+ # if not gap:
58
+ # gap = True
59
+ # history.append(current_word)
60
+ # end += len(current_word) - 1
61
+ # current_word = ""
62
+ # if len(history) == n:
63
+ # yield (tuple(history), start, end)
64
+ # del history[0]
65
+ # start = end + 1
66
+ # end = start
67
+ # else:
68
+ # gap = False
69
+ # current_word += character
70
+
71
+
72
def split_indices(s: str) -> Iterator[Tuple[str, Tuple[int, int]]]:
    """Split on whitespace, pairing each token with its (start, end) indices.

    Indices are inclusive: `end` is the index of the token's last character.
    (cf. https://stackoverflow.com/questions/13734451/string-split-with-indices-in-python)
    @:return generator((word, (start_idx, end_idx)), ...)
    """
    for match in re.finditer(r"\S+", s):
        yield match.group(0), (match.start(), match.end() - 1)
78
+
79
+
80
def word_ngrams_indices(s: str, n: int) -> Iterator[Tuple[str, Tuple[int, int]]]:
    """Splits a string into pairs of (ngram words, their start/end indices)"""
    indexed_tokens = split_indices(s)

    # Each ngram is a sequence of (word, (start, end)) pairs; transpose it
    # into (words, spans) and emit the joined ngram together with the span
    # running from the first token's start to the last token's end.
    for gram in form_ngrams(indexed_tokens, n):
        words, spans = zip(*gram)
        yield " ".join(words), (spans[0][0], spans[-1][1])
106
+
107
+
108
class Janitor:
    """Removes registered "contaminant" ngrams (e.g. benchmark test text)
    from documents by excising a window of text around every match.

    # FIXME delete_chars: Should anything else go here? Special chars?
    """

    def __init__(
        self,
        ngram_n: int = 13,
        window_to_remove: int = 200,
        too_dirty_cutoff: int = 10,
        minimum_slice_length: int = 200,
        delete_chars: str = string.punctuation,
    ) -> None:
        """
        Args:
            ngram_n: word-ngram size used to index contamination.
            window_to_remove: characters excised around each dirty ngram.
            too_dirty_cutoff: number of matches after which the whole
                document is dropped instead of sliced.
            minimum_slice_length: shortest clean chunk worth keeping.
            delete_chars: characters stripped during normalization.
        """
        self.ngram_n = ngram_n
        self.window_to_remove = window_to_remove
        self.too_dirty_cutoff = too_dirty_cutoff
        self.minimum_slice_length = minimum_slice_length
        self.delete_chars = delete_chars

        self.dirt_ngrams = set()

        # If in python, we'll translate uppercase to lowercase and delete naughty characters.
        # This is fast by python standards
        # https://stackoverflow.com/questions/638893/what-is-the-most-efficient-way-in-python-to-convert-a-string-to-all-lowercase-st
        self.translation_table = str.maketrans(
            string.ascii_lowercase + string.ascii_uppercase,  # These characters
            string.ascii_lowercase * 2,  # Become these characters
            self.delete_chars,  # These are deleted
        )

    ##############
    # I/O for saving contamination ngrams
    ##############

    def save_contamination_ngrams(self, filename: str) -> None:
        """Persist the registered ngram set to `filename` with pickle."""
        with open(filename, "wb") as fp:
            # BUG FIX: this previously pickled the *filename* string instead
            # of the ngram set, so load_contamination_ngrams() could never
            # round-trip a saved set.
            pickle.dump(self.dirt_ngrams, fp)

    def load_contamination_ngrams(self, filename: str) -> None:
        """Replace the registered set with one saved by save_contamination_ngrams()."""
        with open(filename, "rb") as fp:
            self.dirt_ngrams = pickle.load(fp)

    ##############
    # Call these :)
    ##############

    def register_contaminant(self, dirt_string: str) -> None:
        """Register a string as contamination to be removed, e.g. a test set
        This breaks the dirt_string into ngrams to store for future cleaning"""
        if JANITOR_CPP:
            return self.register_contaminant_cpp(dirt_string)
        else:
            print("WARNING: Janitor running in python mode")
            return self.register_contaminant_python(dirt_string)

    def clean(self, dirty_string: str) -> List[str]:
        """Clean a string (e.g. a training set) by removing all ngrams previously
        registered as contaminants. Returns a list of clean chunks, or empty if
        the string was too dirty"""
        if JANITOR_CPP:
            return self.clean_cpp(dirty_string)
        else:
            print("WARNING: Janitor running in python mode")
            return self.clean_python(dirty_string)

    def _split_chunks(
        self, dirty_string: str, dirty_parts: Sequence[Tuple]
    ) -> List[str]:
        """Cut `window_to_remove` chars around each dirty span and return the
        remaining chunks that are at least `minimum_slice_length` long."""
        clean_chunks = []
        splice_idx = 0
        end = -1
        for i, (ngram, start, end) in enumerate(dirty_parts):
            # Too many matches: the document is considered irrecoverable.
            if i >= self.too_dirty_cutoff:
                return []
            start = max(0, start - self.window_to_remove)
            end = min(len(dirty_string), end + self.window_to_remove)

            if start - splice_idx > self.minimum_slice_length:
                clean_chunks.append(dirty_string[splice_idx:start])
            splice_idx = end

        # Keep the tail after the last dirty window, if long enough.
        if end < len(dirty_string) - self.minimum_slice_length:
            clean_chunks.append(dirty_string[end + 1 :])

        return clean_chunks

    ##############
    # Fast C++
    ##############

    def register_contaminant_cpp(self, dirt_string) -> None:
        self.dirt_ngrams.update(
            janitor_util.clean_ngram(dirt_string, self.delete_chars, self.ngram_n)
        )

    def clean_cpp(self, dirty_string: str) -> List[str]:
        contamination_indices = janitor_util.clean_ngram_with_indices(
            dirty_string, self.delete_chars, self.ngram_n
        )
        return self._split_chunks(dirty_string, contamination_indices)

    ##############
    # Slow python
    ##############

    def normalize_string(self, s: str) -> str:
        """Lowercase and strip `delete_chars` in a single C-level pass."""
        return s.translate(self.translation_table)

    def register_contaminant_python(self, dirt_string: str) -> None:
        self.dirt_ngrams.update(
            word_ngrams(self.normalize_string(dirt_string), self.ngram_n)
        )

    def clean_python(self, dirty_string: str) -> List[str]:
        # (None, start, end) mirrors the cpp triple shape expected by _split_chunks.
        contamination_indices = (
            (None, *idx_pair)
            for dirty_ngram, idx_pair in word_ngrams_indices(dirty_string, self.ngram_n)
            if self.normalize_string(dirty_ngram) in self.dirt_ngrams
        )
        return self._split_chunks(dirty_string, contamination_indices)
225
+
226
+
227
+ ##################################################################
228
+ # Tests
229
+ #################################################################
230
+
231
+ # def print_cpp():
232
+ # source = """ ,, I'm a very !dirty,, ,, dirty boy. Clean me daddy. \n\nhe he he hehe heh. lastword """ * 2
233
+
234
+ # for i in range(1, 10, 2):
235
+ # pprint(janitor_util.clean_ngram(source, string.punctuation, i))
236
+ # for ngram, start, end in \
237
+ # janitor_util.clean_ngram_with_indices(source, string.punctuation, i):
238
+ # print(ngram, "\t", start, end, source[start:end].replace("\n", "\\n"))
239
+
240
+
241
+ # def test_cpp():
242
+ # source = """ ,, I'm a very !dirty,, ,, dirty boy. Clean me daddy. \n\nhe he he hehe heh. lastword """ * 2
243
+ # contaminant = "dirty boy. Clean he he"
244
+
245
+ # jan_python = Janitor()
246
+ # jan_cpp = Janitor()
247
+
248
+ # jan_python.register_contaminant_python(contaminant)
249
+ # jan_cpp.register_contaminant(contaminant)
250
+
251
+ # assert jan_python.dirt_ngrams == jan_cpp.dirt_ngrams, (jan_python.dirt_ngrams, jan_cpp.dirt_ngrams)
252
+
253
+ # assert jan_python.clean_python(source) == jan_cpp.clean(source), \
254
+ # (jan_python.clean_python(source), jan_cpp.clean(source))
255
+
256
+ # print("Passed test, python==cpp")
257
+
258
+
259
+ # def benchmark():
260
+ # # Download and put in data folder: enwik8 (100 MB) from https://cs.fit.edu/~mmahoney/compression/textdata.html
261
+ # setup = \
262
+ # """
263
+ # with open("data/enwik8", "r") as f:
264
+ # data = f.read()
265
+ # jan = Janitor(too_dirty_cutoff=1000)
266
+ # jan.register_contaminant('''
267
+ # theories is that there is a connection between &quot;geekdom&quot; and autism.
268
+ # This is hinted, for instance, by a ''Wired Magazine'' article in 2001 entitled &quot;
269
+ # The [[Geek]] Syndrome&quot;, which is a point argued by many in the autism rights
270
+ # movement{{ref|Wired}}. This article, many professionals assert, is just one example of
271
+ # the media's application of mental disease labels to what is actually variant normal behavior
272
+ # &amp;mdash;they argue that shyness, lack of athletic ability or social skills, and intellectual
273
+ # interests, even when they seem unusual to others, are not in themselves signs of autism or
274
+ # Asperger's syndrome. Others assert that it is actually the medical profession which is applying
275
+ # mental disease labels to children who in the past would have simply been accepted as a little
276
+ # different or even labeled 'gifted'. See [[clinomorphism]] for further discussion of this issue.
277
+ # Due to the recent publicity surrounding autism and autis
278
+ # ultan Al Nahyan]] granted [[Petroleum]] concessions, and oil was first found in 1958. At first,
279
+ # oil money had a marginal impact. A few lowrise concete buildings were erected, and the first
280
+ # paved road was completed in 1961, but Sheikh Shakbut, uncertain whether the new oil royalties
281
+ # would last, took a cautious approach, preferring to save the revenue rather than investing it in
282
+ # development. His brother, [[Zayed bin Sultan Al Nahayan]], saw that oil wealth had the potential
283
+ # to transform Abu Dhabi. The ruling Al Nahayan family decided that Sheikh Zayed should replace his
284
+ # brother as Ruler and carry out his vision of developing the country. On [[August 6]], [[1966]],
285
+ # with the assistance of the British, Sheikh Zayed became the new ruler. See generally, Al-Fahim, M,
286
+ # ''From Rags to Riches: A Story of Abu Dhabi'', Chapter Six (London Centre of Arab Studies, 1995),
287
+ # ISBN 1 900404 00 1. With the announcement by Britain in 1968 that it would withdraw from the
288
+ # Gulf area by 1971, Sheikh Zayed became the main driving force behind the formation of the
289
+ # [[United Arab Emirates]]. After the Emirates gained independence in 1971,
290
+ # ''')
291
+ # """
292
+
293
+ # n = 1
294
+ # print(f"Timing {n} run on 100 MB")
295
+ # print("Register contaminant")
296
+ # # print("\tPython", timeit.timeit("jan.register_contaminant_python(data)", setup=setup, globals=globals(), number=n))
297
+ # print("\tCpp", timeit.timeit("jan.register_contaminant(data)", setup=setup, globals=globals(), number=n))
298
+
299
+ # print("Clean")
300
+ # # print("\tPython", timeit.timeit("jan.clean_python(data)", setup=setup, globals=globals(), number=n))
301
+ # print("\tCpp", timeit.timeit("jan.clean(data)", setup=setup, globals=globals(), number=n))
302
+
303
+
304
+ # def test_janitor_general():
305
+ # source = """ ,, I'm a very !dirty,, ,, dirty boy. Clean me daddy. \n\nhe he he hehe heh. lastword """ * 2
306
+ # contaminant = "dirty boy. Clean he he"
307
+
308
+ # jan = Janitor(ngram_n=3)
309
+ # jan.register_contaminant(contaminant)
310
+ # cleaned = " ".join(jan.clean(source))
311
+ # for contam in jan.dirt_ngrams:
312
+ # assert contam not in cleaned, contam
313
+
314
+ # filename = "data/saved_contam"
315
+ # jan.save_contamination_ngrams(filename)
316
+
317
+ # jan = Janitor(ngram_n=3)
318
+ # jan.load_contamination_ngrams(filename)
319
+ # cleaned = " ".join(jan.clean(source))
320
+ # for contam in jan.dirt_ngrams:
321
+ # assert contam not in cleaned, contam
322
+
323
+
324
+ # if __name__ == "__main__":
325
+ # test()
326
+ # # print_cpp()
327
+ # # test_cpp()
328
+ # # benchmark()
Prism/LLaDA/LLaDA_Baseline/dllm_eval/loggers/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .evaluation_tracker import EvaluationTracker
2
+ from .wandb_logger import WandbLogger
Prism/LLaDA/LLaDA_Baseline/dllm_eval/loggers/evaluation_tracker.py ADDED
@@ -0,0 +1,530 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+ import os
4
+ import re
5
+ import time
6
+ from collections import defaultdict
7
+ from dataclasses import asdict, dataclass
8
+ from datetime import datetime
9
+ from pathlib import Path
10
+
11
+ from datasets import load_dataset
12
+ from datasets.utils.metadata import MetadataConfigs
13
+ from huggingface_hub import (
14
+ DatasetCard,
15
+ DatasetCardData,
16
+ HfApi,
17
+ hf_hub_url,
18
+ )
19
+ from huggingface_hub.utils import build_hf_headers, get_session, hf_raise_for_status
20
+
21
+ from dllm_eval.utils import (
22
+ get_file_datetime,
23
+ get_file_task_name,
24
+ get_results_filenames,
25
+ get_sample_results_filenames,
26
+ handle_non_serializable,
27
+ hash_string,
28
+ sanitize_list,
29
+ sanitize_model_name,
30
+ sanitize_task_name,
31
+ )
32
+
33
+
34
+ eval_logger = logging.getLogger(__name__)
35
+
36
+
37
@dataclass(init=False)
class GeneralConfigTracker:
    """
    Tracker for the evaluation parameters.

    Attributes:
        model_source (str): Source of the model (e.g. Hugging Face, GGUF, etc.)
        model_name (str): Name of the model.
        model_name_sanitized (str): Sanitized model name for directory creation.
        start_time (float): Start time of the experiment. Logged at class init.
        end_time (float): Start time of the experiment. Logged when calling [`GeneralConfigTracker.log_end_time`]
        total_evaluation_time_seconds (str): Inferred total evaluation time in seconds (from the start and end times).
    """

    model_source: str = None
    model_name: str = None
    model_name_sanitized: str = None
    system_instruction: str = None
    system_instruction_sha: str = None
    fewshot_as_multiturn: bool = None
    chat_template: str = None
    chat_template_sha: str = None
    start_time: float = None
    end_time: float = None
    total_evaluation_time_seconds: str = None

    def __init__(self) -> None:
        """Starts the evaluation timer."""
        self.start_time = time.perf_counter()

    @staticmethod
    def _get_model_name(model_args: str) -> str:
        """Extracts the model name from the model arguments."""
        # order does matter, e.g. peft and delta are provided together with pretrained
        for key in ("peft=", "delta=", "pretrained=", "model=", "path=", "engine="):
            if key in model_args:
                # Value is everything after the key, up to the next comma.
                return model_args.split(key)[1].split(",")[0]
        return ""

    def log_experiment_args(
        self,
        model_source: str,
        model_args: str,
        system_instruction: str,
        chat_template: str,
        fewshot_as_multiturn: bool,
    ) -> None:
        """Logs model parameters and job ID."""
        self.model_source = model_source
        self.model_name = GeneralConfigTracker._get_model_name(model_args)
        self.model_name_sanitized = sanitize_model_name(self.model_name)
        self.system_instruction = system_instruction
        self.system_instruction_sha = (
            hash_string(system_instruction) if system_instruction else None
        )
        self.chat_template = chat_template
        self.chat_template_sha = hash_string(chat_template) if chat_template else None
        self.fewshot_as_multiturn = fewshot_as_multiturn

    def log_end_time(self) -> None:
        """Logs the end time of the evaluation and calculates the total evaluation time."""
        self.end_time = time.perf_counter()
        self.total_evaluation_time_seconds = str(self.end_time - self.start_time)
107
+
108
+
109
+ class EvaluationTracker:
110
+ """
111
+ Keeps track and saves relevant information of the evaluation process.
112
+ Compiles the data from trackers and writes it to files, which can be published to the Hugging Face hub if requested.
113
+ """
114
+
115
def __init__(
    self,
    output_path: str = None,
    hub_results_org: str = "",
    hub_repo_name: str = "",
    details_repo_name: str = "",
    results_repo_name: str = "",
    push_results_to_hub: bool = False,
    push_samples_to_hub: bool = False,
    public_repo: bool = False,
    token: str = "",
    leaderboard_url: str = "",
    point_of_contact: str = "",
    gated: bool = False,
) -> None:
    """
    Creates all the necessary loggers for evaluation tracking.

    Args:
        output_path (str): Path to save the results. If not provided, the results won't be saved.
        hub_results_org (str): The Hugging Face organization to push the results to. If not provided, the results will be pushed to the owner of the Hugging Face token.
        hub_repo_name (str): The name of the Hugging Face repository to push the results to. If not provided, the results will be pushed to `lm-eval-results`.
        details_repo_name (str): The name of the Hugging Face repository to push the details to. If not provided, the results will be pushed to `lm-eval-results`.
        result_repo_name (str): The name of the Hugging Face repository to push the results to. If not provided, the results will not be pushed and will be found in the details_hub_repo.
        push_results_to_hub (bool): Whether to push the results to the Hugging Face hub.
        push_samples_to_hub (bool): Whether to push the samples to the Hugging Face hub.
        public_repo (bool): Whether to push the results to a public or private repository.
        token (str): Token to use when pushing to the Hugging Face hub. This token should have write access to `hub_results_org`.
        leaderboard_url (str): URL to the leaderboard on the Hugging Face hub on the dataset card.
        point_of_contact (str): Contact information on the Hugging Face hub dataset card.
        gated (bool): Whether to gate the repository.
    """
    # Constructing the tracker also starts the run's wall-clock timer.
    self.general_config_tracker = GeneralConfigTracker()

    self.output_path = output_path
    self.push_results_to_hub = push_results_to_hub
    self.push_samples_to_hub = push_samples_to_hub
    self.public_repo = public_repo
    self.leaderboard_url = leaderboard_url
    self.point_of_contact = point_of_contact
    # The HfApi client only exists when a token was supplied.
    self.api = HfApi(token=token) if token else None
    self.gated_repo = gated

    # Pushing anything to the hub requires an authenticated client.
    if not self.api and (push_results_to_hub or push_samples_to_hub):
        raise ValueError(
            "Hugging Face token is not defined, but 'push_results_to_hub' or 'push_samples_to_hub' is set to True. "
            "Please provide a valid Hugging Face token by setting the HF_TOKEN environment variable."
        )

    # No org given: default to the token owner's namespace.
    if (
        self.api
        and hub_results_org == ""
        and (push_results_to_hub or push_samples_to_hub)
    ):
        hub_results_org = self.api.whoami()["name"]
        eval_logger.warning(
            f"hub_results_org was not specified. Results will be pushed to '{hub_results_org}'."
        )

    # Resolve repo names: hub_repo_name (deprecated) overrides both;
    # otherwise results default to the details repo when unset.
    if hub_repo_name == "":
        details_repo_name = (
            details_repo_name if details_repo_name != "" else "lm-eval-results"
        )
        results_repo_name = (
            results_repo_name if results_repo_name != "" else details_repo_name
        )
    else:
        details_repo_name = hub_repo_name
        results_repo_name = hub_repo_name
        eval_logger.warning(
            "hub_repo_name was specified. Both details and results will be pushed to the same repository. Using hub_repo_name is no longer recommended, details_repo_name and results_repo_name should be used instead."
        )

    # Precompute both public and "-private" repo ids; which one is used
    # depends on self.public_repo at push time.
    self.details_repo = f"{hub_results_org}/{details_repo_name}"
    self.details_repo_private = f"{hub_results_org}/{details_repo_name}-private"
    self.results_repo = f"{hub_results_org}/{results_repo_name}"
    self.results_repo_private = f"{hub_results_org}/{results_repo_name}-private"
192
+
193
def save_results_aggregated(
    self,
    results: dict,
    samples: dict,
) -> None:
    """
    Saves the aggregated results and samples to the output path and pushes them to the Hugging Face hub if requested.

    Args:
        results (dict): The aggregated results to save.
        samples (dict): The samples results to save.
    """
    # Close out the run timer so total_evaluation_time_seconds is recorded.
    self.general_config_tracker.log_end_time()

    if self.output_path:
        try:
            eval_logger.info("Saving results aggregated")

            # calculate cumulative hash for each task - only if samples are provided
            task_hashes = {}
            if samples:
                for task_name, task_samples in samples.items():
                    sample_hashes = [
                        s["doc_hash"] + s["prompt_hash"] + s["target_hash"]
                        for s in task_samples
                    ]
                    task_hashes[task_name] = hash_string("".join(sample_hashes))

            # update initial results dict
            results.update({"task_hashes": task_hashes})
            results.update(asdict(self.general_config_tracker))
            dumped = json.dumps(
                results,
                indent=2,
                default=handle_non_serializable,
                ensure_ascii=False,
            )

            path = Path(self.output_path if self.output_path else Path.cwd())
            # ':' is not allowed in filenames on some platforms; use '-'.
            self.date_id = datetime.now().isoformat().replace(":", "-")
            # A .json output_path is treated as a filename template;
            # anything else is treated as a directory.
            if path.suffix == ".json":
                path.parent.mkdir(parents=True, exist_ok=True)
                file_results_aggregated = path.with_name(
                    f"{path.stem}_{self.date_id}.json"
                )
            else:
                path.mkdir(parents=True, exist_ok=True)
                file_results_aggregated = path.joinpath(
                    f"results_{self.date_id}.json"
                )

            file_results_aggregated.open("w", encoding="utf-8").write(dumped)

            if self.api and self.push_results_to_hub:
                # Select the public or "-private" repo per configuration.
                repo_id = (
                    self.results_repo
                    if self.public_repo
                    else self.results_repo_private
                )
                self.api.create_repo(
                    repo_id=repo_id,
                    repo_type="dataset",
                    private=not self.public_repo,
                    exist_ok=True,
                )
                self.api.upload_file(
                    repo_id=repo_id,
                    path_or_fileobj=str(file_results_aggregated),
                    path_in_repo=os.path.join(
                        self.general_config_tracker.model_name,
                        file_results_aggregated.name,
                    ),
                    repo_type="dataset",
                    commit_message=f"Adding aggregated results for {self.general_config_tracker.model_name}",
                )
                eval_logger.info(
                    "Successfully pushed aggregated results to the Hugging Face Hub. "
                    f"You can find them at: {repo_id}"
                )

        except Exception as e:
            # Saving is best-effort: log and continue rather than abort the run.
            eval_logger.warning("Could not save results aggregated")
            eval_logger.info(repr(e))
    else:
        eval_logger.info(
            "Output path not provided, skipping saving results aggregated"
        )
280
+
281
+ def save_results_samples(
282
+ self,
283
+ task_name: str,
284
+ samples: dict,
285
+ ) -> None:
286
+ """
287
+ Saves the samples results to the output path and pushes them to the Hugging Face hub if requested.
288
+
289
+ Args:
290
+ task_name (str): The task name to save the samples for.
291
+ samples (dict): The samples results to save.
292
+ """
293
+ if self.output_path:
294
+ try:
295
+ eval_logger.info(f"Saving per-sample results for: {task_name}")
296
+
297
+ path = Path(self.output_path if self.output_path else Path.cwd())
298
+ if path.suffix == ".json":
299
+ path = path.parent
300
+ path.mkdir(parents=True, exist_ok=True)
301
+
302
+ file_results_samples = path.joinpath(
303
+ f"samples_{task_name}_{self.date_id}.jsonl"
304
+ )
305
+
306
+ for sample in samples:
307
+ # we first need to sanitize arguments and resps
308
+ # otherwise we won't be able to load the dataset
309
+ # using the datasets library
310
+ arguments = {}
311
+ for i, arg in enumerate(sample["arguments"]):
312
+ arguments[f"gen_args_{i}"] = {}
313
+ for j, tmp in enumerate(arg):
314
+ arguments[f"gen_args_{i}"][f"arg_{j}"] = tmp
315
+
316
+ sample["resps"] = sanitize_list(sample["resps"])
317
+ sample["filtered_resps"] = sanitize_list(sample["filtered_resps"])
318
+ sample["arguments"] = arguments
319
+ sample["target"] = str(sample["target"])
320
+
321
+ sample_dump = (
322
+ json.dumps(
323
+ sample,
324
+ default=handle_non_serializable,
325
+ ensure_ascii=False,
326
+ )
327
+ + "\n"
328
+ )
329
+
330
+ with open(file_results_samples, "a", encoding="utf-8") as f:
331
+ f.write(sample_dump)
332
+
333
+ if self.api and self.push_samples_to_hub:
334
+ repo_id = (
335
+ self.details_repo
336
+ if self.public_repo
337
+ else self.details_repo_private
338
+ )
339
+ self.api.create_repo(
340
+ repo_id=repo_id,
341
+ repo_type="dataset",
342
+ private=not self.public_repo,
343
+ exist_ok=True,
344
+ )
345
+ try:
346
+ if self.gated_repo:
347
+ headers = build_hf_headers()
348
+ r = get_session().put(
349
+ url=f"https://huggingface.co/api/datasets/{repo_id}/settings",
350
+ headers=headers,
351
+ json={"gated": "auto"},
352
+ )
353
+ hf_raise_for_status(r)
354
+ except Exception as e:
355
+ eval_logger.warning("Could not gate the repository")
356
+ eval_logger.info(repr(e))
357
+ self.api.upload_folder(
358
+ repo_id=repo_id,
359
+ folder_path=str(path),
360
+ path_in_repo=self.general_config_tracker.model_name_sanitized,
361
+ repo_type="dataset",
362
+ commit_message=f"Adding samples results for {task_name} to {self.general_config_tracker.model_name}",
363
+ )
364
+ eval_logger.info(
365
+ f"Successfully pushed sample results for task: {task_name} to the Hugging Face Hub. "
366
+ f"You can find them at: {repo_id}"
367
+ )
368
+
369
+ except Exception as e:
370
+ eval_logger.warning("Could not save sample results")
371
+ eval_logger.info(repr(e))
372
+ else:
373
+ eval_logger.info("Output path not provided, skipping saving sample results")
374
+
375
    def recreate_metadata_card(self) -> None:
        """
        Creates a metadata card for the evaluation results dataset and pushes it to the Hugging Face hub.
        """

        eval_logger.info("Recreating metadata card")
        repo_id = self.details_repo if self.public_repo else self.details_repo_private

        # enumerate everything already in the details repo so the card can
        # reference all past runs, not just the current one
        files_in_repo = self.api.list_repo_files(repo_id=repo_id, repo_type="dataset")
        results_files = get_results_filenames(files_in_repo)
        sample_files = get_sample_results_filenames(files_in_repo)

        # Build a dictionary to store the latest evaluation datetime for:
        # - Each tested model and its aggregated results
        # - Each task and sample results, if existing
        # i.e. {
        #     "org__model_name__gsm8k": "2021-09-01T12:00:00",
        #     "org__model_name__ifeval": "2021-09-01T12:00:00",
        #     "org__model_name__results": "2021-09-01T12:00:00"
        # }
        # datetime.min is the default so any real timestamp compares greater
        latest_task_results_datetime = defaultdict(lambda: datetime.min.isoformat())

        for file_path in sample_files:
            file_path = Path(file_path)
            filename = file_path.name
            # repo layout is <model_name>/<samples file>, so the parent dir
            # identifies the model
            model_name = file_path.parent
            task_name = get_file_task_name(filename)
            results_datetime = get_file_datetime(filename)
            task_name_sanitized = sanitize_task_name(task_name)
            # Results and sample results for the same model and task will have the same datetime
            samples_key = f"{model_name}__{task_name_sanitized}"
            results_key = f"{model_name}__results"
            # ISO-8601 strings compare chronologically, so max() works on them
            latest_datetime = max(
                latest_task_results_datetime[samples_key],
                results_datetime,
            )
            latest_task_results_datetime[samples_key] = latest_datetime
            latest_task_results_datetime[results_key] = max(
                latest_task_results_datetime[results_key],
                latest_datetime,
            )

        # Create metadata card
        card_metadata = MetadataConfigs()

        # Add the latest aggregated results to the metadata card for easy access
        for file_path in results_files:
            file_path = Path(file_path)
            results_filename = file_path.name
            model_name = file_path.parent
            eval_date = get_file_datetime(results_filename)
            eval_date_sanitized = re.sub(r"[^\w\.]", "_", eval_date)
            # "**" glob so the split resolves regardless of the model subdir
            results_filename = Path("**") / Path(results_filename).name
            config_name = f"{model_name}__results"
            sanitized_last_eval_date_results = re.sub(
                r"[^\w\.]", "_", latest_task_results_datetime[config_name]
            )

            if eval_date_sanitized == sanitized_last_eval_date_results:
                # Ensure that all results files are listed in the metadata card
                current_results = card_metadata.get(config_name, {"data_files": []})
                current_results["data_files"].append(
                    {"split": eval_date_sanitized, "path": [str(results_filename)]}
                )
                card_metadata[config_name] = current_results
                # If the results file is the newest, update the "latest" field in the metadata card
                card_metadata[config_name]["data_files"].append(
                    {"split": "latest", "path": [str(results_filename)]}
                )

        # Add the tasks details configs
        for file_path in sample_files:
            file_path = Path(file_path)
            filename = file_path.name
            model_name = file_path.parent
            task_name = get_file_task_name(filename)
            eval_date = get_file_datetime(filename)
            task_name_sanitized = sanitize_task_name(task_name)
            eval_date_sanitized = re.sub(r"[^\w\.]", "_", eval_date)
            results_filename = Path("**") / Path(filename).name
            config_name = f"{model_name}__{task_name_sanitized}"
            sanitized_last_eval_date_results = re.sub(
                r"[^\w\.]", "_", latest_task_results_datetime[config_name]
            )
            if eval_date_sanitized == sanitized_last_eval_date_results:
                # Ensure that all sample results files are listed in the metadata card
                current_details_for_task = card_metadata.get(
                    config_name, {"data_files": []}
                )
                current_details_for_task["data_files"].append(
                    {"split": eval_date_sanitized, "path": [str(results_filename)]}
                )
                card_metadata[config_name] = current_details_for_task
                # If the samples results file is the newest, update the "latest" field in the metadata card
                card_metadata[config_name]["data_files"].append(
                    {"split": "latest", "path": [str(results_filename)]}
                )

        # Get latest results and extract info to update metadata card examples
        latest_datetime = max(latest_task_results_datetime.values())
        latest_model_name = max(
            latest_task_results_datetime, key=lambda k: latest_task_results_datetime[k]
        )
        # filenames use "-" where ISO timestamps use ":" (see save_results_aggregated)
        last_results_file = [
            f for f in results_files if latest_datetime.replace(":", "-") in f
        ][0]
        last_results_file_path = hf_hub_url(
            repo_id=repo_id, filename=last_results_file, repo_type="dataset"
        )
        latest_results_file = load_dataset(
            "json", data_files=last_results_file_path, split="train"
        )
        results_dict = latest_results_file["results"][0]
        new_dictionary = {"all": results_dict}
        new_dictionary.update(results_dict)
        results_string = json.dumps(new_dictionary, indent=4)

        # Build the free-text dataset summary shown on the Hub dataset page
        dataset_summary = (
            "Dataset automatically created during the evaluation run of model "
        )
        if self.general_config_tracker.model_source == "hf":
            dataset_summary += f"[{self.general_config_tracker.model_name}](https://huggingface.co/{self.general_config_tracker.model_name})\n"
        else:
            dataset_summary += f"{self.general_config_tracker.model_name}\n"
        dataset_summary += (
            f"The dataset is composed of {len(card_metadata) - 1} configuration(s), each one corresponding to one of the evaluated task.\n\n"
            f"The dataset has been created from {len(results_files)} run(s). Each run can be found as a specific split in each "
            'configuration, the split being named using the timestamp of the run.The "train" split is always pointing to the latest results.\n\n'
            'An additional configuration "results" store all the aggregated results of the run.\n\n'
            "To load the details from a run, you can for instance do the following:\n"
        )
        if self.general_config_tracker.model_source == "hf":
            dataset_summary += (
                "```python\nfrom datasets import load_dataset\n"
                f'data = load_dataset(\n\t"{repo_id}",\n\tname="{latest_model_name}",\n\tsplit="latest"\n)\n```\n\n'
            )
        dataset_summary += (
            "## Latest results\n\n"
            f"These are the [latest results from run {latest_datetime}]({last_results_file_path.replace('/resolve/', '/blob/')}) "
            "(note that there might be results for other tasks in the repos if successive evals didn't cover the same tasks. "
            'You find each in the results and the "latest" split for each eval):\n\n'
            f"```python\n{results_string}\n```"
        )
        card_data = DatasetCardData(
            dataset_summary=dataset_summary,
            repo_url=f"https://huggingface.co/{self.general_config_tracker.model_name}",
            pretty_name=f"Evaluation run of {self.general_config_tracker.model_name}",
            leaderboard_url=self.leaderboard_url,
            point_of_contact=self.point_of_contact,
        )
        card_metadata.to_dataset_card_data(card_data)
        card = DatasetCard.from_template(
            card_data,
            pretty_name=card_data.pretty_name,
        )
        card.push_to_hub(repo_id, repo_type="dataset")
Prism/LLaDA/LLaDA_Baseline/dllm_eval/loggers/utils.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import re
4
+ import subprocess
5
+ from importlib.metadata import version
6
+ from pathlib import Path
7
+ from typing import Any, Dict, Optional, Tuple, Union
8
+
9
+ import numpy as np
10
+ from torch.utils.collect_env import get_pretty_env_info
11
+ from transformers import __version__ as trans_version
12
+
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
def remove_none_pattern(input_string: str) -> Tuple[str, bool]:
    """Remove the ',none' substring from the input_string if it exists at the end.

    Args:
        input_string (str): The input string from which to remove the ',none' substring.

    Returns:
        Tuple[str, bool]: A tuple containing the modified input_string with the ',none' substring removed
            and a boolean indicating whether the modification was made (True) or not (False).
    """
    suffix = ",none"
    # only a trailing occurrence counts; interior ',none' is left untouched
    if input_string.endswith(suffix):
        return input_string[: -len(suffix)], True
    return input_string, False
37
+
38
+
39
+ def _handle_non_serializable(o: Any) -> Union[int, str, list]:
40
+ """Handle non-serializable objects by converting them to serializable types.
41
+
42
+ Args:
43
+ o (Any): The object to be handled.
44
+
45
+ Returns:
46
+ Union[int, str, list]: The converted object. If the object is of type np.int64 or np.int32,
47
+ it will be converted to int. If the object is of type set, it will be converted
48
+ to a list. Otherwise, it will be converted to str.
49
+ """
50
+ if isinstance(o, np.int64) or isinstance(o, np.int32):
51
+ return int(o)
52
+ elif isinstance(o, set):
53
+ return list(o)
54
+ else:
55
+ return str(o)
56
+
57
+
58
def get_commit_from_path(repo_path: Union[Path, str]) -> Optional[str]:
    """Read the current commit hash directly from a repo's ``.git`` metadata.

    Does not shell out to git. Returns ``None`` when ``repo_path`` is not a
    git checkout or when any error occurs while reading the metadata.
    """
    try:
        git_folder = Path(repo_path, ".git")
        if git_folder.is_file():
            # Submodule/worktree case: .git is a file whose first line is
            # "gitdir: <path>" pointing at the real git directory.
            git_folder = Path(
                git_folder.parent,
                git_folder.read_text(encoding="utf-8").split("\n")[0].split(" ")[-1],
            )
        if Path(git_folder, "HEAD").exists():
            # HEAD's first line is typically "ref: refs/heads/<branch>";
            # take the last space-separated token as the ref path.
            head_name = (
                Path(git_folder, "HEAD")
                .read_text(encoding="utf-8")
                .split("\n")[0]
                .split(" ")[-1]
            )
            head_ref = Path(git_folder, head_name)
            git_hash = head_ref.read_text(encoding="utf-8").replace("\n", "")
            # NOTE(review): a detached HEAD stores the hash itself, not a ref;
            # reading it as a ref path then raises and falls into the except
            # below, returning None — confirm this is the intended behavior.
        else:
            git_hash = None
    except Exception as err:
        logger.debug(
            f"Failed to retrieve a Git commit hash from path: {str(repo_path)}. Error: {err}"
        )
        return None
    return git_hash
83
+
84
+
85
def get_git_commit_hash():
    """
    Gets the git commit hash of your current repo (if it exists).
    Source: https://github.com/EleutherAI/gpt-neox/blob/b608043be541602170bfcfb8ec9bf85e8a0799e0/megatron/neox_arguments/neox_args.py#L42
    """
    try:
        return (
            subprocess.check_output(["git", "describe", "--always"]).strip().decode()
        )
    except (subprocess.CalledProcessError, FileNotFoundError):
        # FileNotFoundError occurs when git not installed on system;
        # fall back to reading .git metadata directly from the working dir.
        return get_commit_from_path(os.getcwd())  # git hash of repo if exists
97
+
98
+
99
def add_env_info(storage: Dict[str, Any]):
    """Merge environment metadata (torch env report, package versions, git
    hash of the parent directory) into *storage* in place."""
    try:
        pretty_env_info = get_pretty_env_info()
    except Exception as err:
        # never fail the run over diagnostics; record the error text instead
        pretty_env_info = str(err)
    try:
        dllm_eval_version = version("dllm_eval")
    except Exception as err:
        dllm_eval_version = str(err)
    storage.update(
        {
            "pretty_env_info": pretty_env_info,
            "transformers_version": trans_version,
            "dllm_eval_version": dllm_eval_version,
            # git hash of the directory above cwd, in case this repo is a submodule
            "upper_git_hash": get_commit_from_path(Path(os.getcwd(), "..")),
        }
    )
119
+
120
+
121
+ def add_tokenizer_info(storage: Dict[str, Any], lm):
122
+ if getattr(lm, "tokenizer", False):
123
+ try:
124
+ tokenizer_info = {
125
+ "tokenizer_pad_token": [
126
+ lm.tokenizer.pad_token,
127
+ str(lm.tokenizer.pad_token_id),
128
+ ],
129
+ "tokenizer_eos_token": [
130
+ lm.tokenizer.eos_token,
131
+ str(lm.tokenizer.eos_token_id),
132
+ ],
133
+ "tokenizer_bos_token": [
134
+ lm.tokenizer.bos_token,
135
+ str(lm.tokenizer.bos_token_id),
136
+ ],
137
+ "eot_token_id": getattr(lm, "eot_token_id", None),
138
+ "max_length": getattr(lm, "max_length", None),
139
+ }
140
+ storage.update(tokenizer_info)
141
+ except Exception as err:
142
+ logger.debug(
143
+ f"Logging detailed tokenizer info failed with {err}, skipping..."
144
+ )
145
+ # seems gguf and textsynth do not have tokenizer
146
+ else:
147
+ logger.debug(
148
+ "LM does not have a 'tokenizer' attribute, not logging tokenizer metadata to results."
149
+ )
Prism/LLaDA/LLaDA_Baseline/dllm_eval/loggers/wandb_logger.py ADDED
@@ -0,0 +1,358 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import json
3
+ import logging
4
+ from typing import Any, Dict, List, Literal, Tuple
5
+
6
+ import numpy as np
7
+ import pandas as pd
8
+ from packaging.version import Version
9
+
10
+ from dllm_eval.loggers.utils import _handle_non_serializable, remove_none_pattern
11
+
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
def get_wandb_printer() -> Literal["Printer"]:
    """Returns a wandb printer instance for pretty stdout."""
    # imported lazily so wandb stays an optional dependency
    from wandb.sdk.lib.printer import new_printer

    return new_printer()
22
+
23
+
24
class WandbLogger:
    def __init__(self, init_args=None, config_args=None) -> None:
        """Attaches to wandb logger if already initialized. Otherwise, passes init_args to wandb.init() and config_args to wandb.config.update()

        Args:
            init_args Optional[Dict]: Arguments for init configuration.
            config_args Optional[Dict]: Arguments for config

        Parse and log the results returned from evaluator.simple_evaluate() with:
            wandb_logger.post_init(results)
            wandb_logger.log_eval_result()
            wandb_logger.log_eval_samples(results["samples"])
        """
        try:
            import wandb

            # NOTE(review): the assert already rejects wandb < 0.13.6, so the
            # `wandb.require` branch below is unreachable — confirm intent.
            assert Version(wandb.__version__) >= Version("0.13.6")
            if Version(wandb.__version__) < Version("0.13.6"):
                wandb.require("report-editing:v0")
        except Exception as e:
            logger.warning(
                "To use the wandb reporting functionality please install wandb>=0.13.6.\n"
                "To install the latest version of wandb run `pip install wandb --upgrade`\n"
                f"{e}"
            )

        self.wandb_args: Dict[str, Any] = init_args or {}
        self.wandb_config_args: Dict[str, Any] = config_args or {}

        # pop the step key from the args to save for all logging calls
        self.step = self.wandb_args.pop("step", None)

        # initialize a W&B run
        if wandb.run is None:
            self.run = wandb.init(**self.wandb_args)
            if self.wandb_config_args:
                self.run.config.update(self.wandb_config_args)
        else:
            # reuse an already-active run instead of starting a new one
            self.run = wandb.run

        self.printer = get_wandb_printer()

    def post_init(self, results: Dict[str, Any]) -> None:
        # Snapshot the evaluator output; deepcopy so later sanitization
        # cannot mutate the caller's dict.
        self.results: Dict[str, Any] = copy.deepcopy(results)
        self.task_names: List[str] = list(results.get("results", {}).keys())
        self.group_names: List[str] = list(results.get("groups", {}).keys())

    def _get_config(self) -> Dict[str, Any]:
        """Get configuration parameters."""
        self.task_configs = self.results.get("configs", {})
        cli_configs = self.results.get("config", {})
        configs = {
            "task_configs": self.task_configs,
            "cli_configs": cli_configs,
        }

        return configs

    def _sanitize_results_dict(self) -> Tuple[Dict[str, str], Dict[str, Any]]:
        """Sanitize the results dictionary.

        Returns a tuple of (string-valued metrics destined for wandb summary,
        numeric metrics flattened to "task/metric" keys for wandb.log).
        """
        _results = copy.deepcopy(self.results.get("results", dict()))

        # Remove None from the metric string name
        # (iterate over a copy since we mutate _results while looping)
        tmp_results = copy.deepcopy(_results)
        for task_name in self.task_names:
            task_result = tmp_results.get(task_name, dict())
            for metric_name, metric_value in task_result.items():
                _metric_name, removed = remove_none_pattern(metric_name)
                if removed:
                    _results[task_name][_metric_name] = metric_value
                    _results[task_name].pop(metric_name)

        # remove string valued keys from the results dict
        # (wandb.log only accepts numeric values; strings go to run.summary)
        wandb_summary = {}
        for task in self.task_names:
            task_result = _results.get(task, dict())
            for metric_name, metric_value in task_result.items():
                if isinstance(metric_value, str):
                    wandb_summary[f"{task}/{metric_name}"] = metric_value

        for summary_metric, summary_value in wandb_summary.items():
            _task, _summary_metric = summary_metric.split("/")
            _results[_task].pop(_summary_metric)

        # flatten {task: {metric: value}} into {"task/metric": value}
        tmp_results = copy.deepcopy(_results)
        for task_name, task_results in tmp_results.items():
            for metric_name, metric_value in task_results.items():
                _results[f"{task_name}/{metric_name}"] = metric_value
                _results[task_name].pop(metric_name)
        for task in self.task_names:
            _results.pop(task)

        return wandb_summary, _results

    def _log_results_as_table(self) -> None:
        """Generate and log evaluation results as a table to W&B."""
        columns = [
            "Version",
            "Filter",
            "num_fewshot",
            "Metric",
            "Value",
            "Stderr",
        ]

        def make_table(columns: List[str], key: str = "results"):
            import wandb

            table = wandb.Table(columns=columns)
            results = copy.deepcopy(self.results)

            for k, dic in results.get(key).items():
                # groups are rendered in their own table, skip them here
                if k in self.group_names and not key == "groups":
                    continue
                version = results.get("versions").get(k)
                if version == "N/A":
                    version = None
                n = results.get("n-shot").get(k)

                # metric keys look like "<metric>,<filter>"
                for (mf), v in dic.items():
                    m, _, f = mf.partition(",")
                    if m.endswith("_stderr"):
                        continue
                    if m == "alias":
                        continue

                    if m + "_stderr" + "," + f in dic:
                        se = dic[m + "_stderr" + "," + f]
                        if se != "N/A":
                            se = "%.4f" % se
                        table.add_data(*[k, version, f, n, m, str(v), str(se)])
                    else:
                        table.add_data(*[k, version, f, n, m, str(v), ""])

            return table

        # log the complete eval result to W&B Table
        table = make_table(["Tasks"] + columns, "results")
        self.run.log({"evaluation/eval_results": table}, step=self.step)

        if "groups" in self.results.keys():
            table = make_table(["Groups"] + columns, "groups")
            self.run.log({"evaluation/group_eval_results": table}, step=self.step)

    def _log_results_as_artifact(self) -> None:
        """Log results as JSON artifact to W&B."""
        import wandb

        dumped = json.dumps(
            self.results, indent=2, default=_handle_non_serializable, ensure_ascii=False
        )
        artifact = wandb.Artifact("results", type="eval_results")
        with artifact.new_file("results.json", mode="w", encoding="utf-8") as f:
            f.write(dumped)
        self.run.log_artifact(artifact)

    def log_eval_result(self) -> None:
        """Log evaluation results to W&B."""
        # Log configs to wandb
        configs = self._get_config()
        self.run.config.update(configs, allow_val_change=self.step is not None)

        wandb_summary, self.wandb_results = self._sanitize_results_dict()
        # update wandb.run.summary with items that were removed
        self.run.summary.update(wandb_summary)
        # Log the evaluation metrics to wandb
        self.run.log(self.wandb_results, step=self.step)
        # Log the evaluation metrics as W&B Table
        self._log_results_as_table()
        # Log the results dict as json to W&B Artifacts
        self._log_results_as_artifact()

    def _generate_dataset(
        self, data: List[Dict[str, Any]], config: Dict[str, Any]
    ) -> pd.DataFrame:
        """Generate a dataset from evaluation data.

        Args:
            data (List[Dict[str, Any]]): The data to generate a dataset for.
            config (Dict[str, Any]): The configuration of the task.

        Returns:
            pd.DataFrame: A dataframe that is ready to be uploaded to W&B.
        """
        ids = [x["doc_id"] for x in data]
        labels = [x["target"] for x in data]
        instance = [""] * len(ids)
        resps = [""] * len(ids)
        filtered_resps = [""] * len(ids)
        model_outputs = {}

        metrics_list = config["metric_list"]
        metrics = {}
        for metric in metrics_list:
            metric = metric.get("metric")
            if metric in ["word_perplexity", "byte_perplexity", "bits_per_byte"]:
                # perplexity metrics are (loglikelihood, count) pairs
                metrics[f"{metric}_loglikelihood"] = [x[metric][0] for x in data]
                if metric in ["byte_perplexity", "bits_per_byte"]:
                    metrics[f"{metric}_bytes"] = [x[metric][1] for x in data]
                else:
                    metrics[f"{metric}_words"] = [x[metric][1] for x in data]
            else:
                metrics[metric] = [x[metric] for x in data]

        # shape of resps/filtered_resps depends on the task's output type
        if config["output_type"] == "loglikelihood":
            instance = [x["arguments"][0][0] for x in data]
            labels = [x["arguments"][0][1] for x in data]
            resps = [
                f"log probability of continuation is {x['resps'][0][0][0]} "
                + "\n\n"
                + "continuation will {} generated with greedy sampling".format(
                    "not be" if not x["resps"][0][0][1] else "be"
                )
                for x in data
            ]
            filtered_resps = [
                f"log probability of continuation is {x['filtered_resps'][0][0]} "
                + "\n\n"
                + "continuation will {} generated with greedy sampling".format(
                    "not be" if not x["filtered_resps"][0][1] else "be"
                )
                for x in data
            ]
        elif config["output_type"] == "multiple_choice":
            instance = [x["arguments"][0][0] for x in data]
            choices = [
                "\n".join([f"{idx}. {y[1]}" for idx, y in enumerate(x["arguments"])])
                for x in data
            ]
            # the predicted choice is the index with the highest loglikelihood
            resps = [np.argmax([n[0][0] for n in x["resps"]]) for x in data]
            filtered_resps = [
                np.argmax([n[0] for n in x["filtered_resps"]]) for x in data
            ]
        elif config["output_type"] == "loglikelihood_rolling":
            instance = [x["arguments"][0][0] for x in data]
            resps = [x["resps"][0][0] for x in data]
            filtered_resps = [x["filtered_resps"][0] for x in data]
        elif config["output_type"] == "generate_until":
            instance = [x["arguments"][0][0] for x in data]
            resps = [x["resps"][0][0] for x in data]
            filtered_resps = [x["filtered_resps"][0] for x in data]

        model_outputs["raw_predictions"] = resps
        model_outputs["filtered_predictions"] = filtered_resps

        df_data = {
            "id": ids,
            "data": instance,
        }
        if config["output_type"] == "multiple_choice":
            df_data["choices"] = choices

        tmp_data = {
            "input_len": [len(x) for x in instance],
            "labels": labels,
            "output_type": config["output_type"],
        }
        df_data.update(tmp_data)
        df_data.update(model_outputs)
        df_data.update(metrics)

        return pd.DataFrame(df_data)

    def _log_samples_as_artifact(
        self, data: List[Dict[str, Any]], task_name: str
    ) -> None:
        import wandb

        # log the samples as an artifact
        dumped = json.dumps(
            data,
            indent=2,
            default=_handle_non_serializable,
            ensure_ascii=False,
        )
        artifact = wandb.Artifact(f"{task_name}", type="samples_by_task")
        with artifact.new_file(
            f"{task_name}_eval_samples.json", mode="w", encoding="utf-8"
        ) as f:
            f.write(dumped)
        self.run.log_artifact(artifact)
        # artifact.wait()

    def log_eval_samples(self, samples: Dict[str, List[Dict[str, Any]]]) -> None:
        """Log evaluation samples to W&B.

        Args:
            samples (Dict[str, List[Dict[str, Any]]]): Evaluation samples for each task.
        """
        task_names: List[str] = [
            x for x in self.task_names if x not in self.group_names
        ]

        # partition tasks into those belonging to a group and standalone ones
        ungrouped_tasks = []
        tasks_by_groups = {}

        for task_name in task_names:
            group_names = self.task_configs[task_name].get("group", None)
            if group_names:
                if isinstance(group_names, str):
                    group_names = [group_names]

                for group_name in group_names:
                    if not tasks_by_groups.get(group_name):
                        tasks_by_groups[group_name] = [task_name]
                    else:
                        tasks_by_groups[group_name].append(task_name)
            else:
                ungrouped_tasks.append(task_name)

        for task_name in ungrouped_tasks:
            eval_preds = samples[task_name]

            # log the samples as a W&B Table
            df = self._generate_dataset(eval_preds, self.task_configs.get(task_name))
            self.run.log({f"{task_name}_eval_results": df}, step=self.step)

            # log the samples as a json file as W&B Artifact
            self._log_samples_as_artifact(eval_preds, task_name)

        for group, grouped_tasks in tasks_by_groups.items():
            # grouped tasks are concatenated into a single table tagged with
            # group/task columns
            grouped_df = pd.DataFrame()
            for task_name in grouped_tasks:
                eval_preds = samples[task_name]
                df = self._generate_dataset(
                    eval_preds, self.task_configs.get(task_name)
                )
                df["group"] = group
                df["task"] = task_name
                grouped_df = pd.concat([grouped_df, df], ignore_index=True)

                # log the samples as a json file as W&B Artifact
                self._log_samples_as_artifact(eval_preds, task_name)

            self.run.log({f"{group}_eval_results": grouped_df}, step=self.step)
Prism/LLaDA/LLaDA_Baseline/dllm_eval/models/LLaDA.py ADDED
@@ -0,0 +1,786 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ from datetime import timedelta
4
+ from typing import Dict, List, Literal, Optional, Tuple, Union, TypeVar
5
+ import torch
6
+ import torch.nn.functional as F
7
+ import numpy as np
8
+ import transformers
9
+ import json
10
+ from accelerate import (
11
+ Accelerator,
12
+ InitProcessGroupKwargs,
13
+ )
14
+ from datasets import Dataset
15
+ from accelerate.utils import get_max_memory
16
+ from packaging import version
17
+ from tqdm import tqdm
18
+ import torch.distributed as dist
19
+ from transformers.models.auto.modeling_auto import (
20
+ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
21
+ MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
22
+ )
23
+ from dllm_eval.api.instance import Instance
24
+ from dllm_eval.api.model import LM, TemplateLM
25
+ from dllm_eval.api.registry import register_model
26
+ from dllm_eval.models.utils import get_dtype, configure_pad_token
27
+
28
+ try:
29
+ from .hts_sampler import HTSSampler
30
+ except ImportError:
31
+ HTSSampler = None
32
+
33
+ eval_logger = logging.getLogger(__name__)
34
+ T = TypeVar("T", bound="LM")
35
+
36
+
37
def add_gumbel_noise(logits, temperature):
    """Perturb `logits` for Gumbel-style sampling.

    With temperature 0.0 the logits are returned untouched (greedy decoding).
    Otherwise the logits are exponentiated and divided by uniform noise raised
    to `temperature`, so an argmax over the result yields a random draw whose
    spread grows with temperature.
    """
    if temperature == 0.0:
        return logits
    # Work in float32 for numerical stability of exp/log.
    logits = logits.to(torch.float32)
    uniform = torch.rand_like(logits, dtype=torch.float32)
    return torch.exp(logits) / (-torch.log(uniform)) ** temperature
45
+
46
+
47
def get_num_transfer_tokens(mask_index, steps):
    """Per-sample schedule of how many masked tokens to commit at each step.

    Each row's masked-token count is split as evenly as possible over `steps`
    steps; when the count is not divisible, the earliest steps receive one
    extra token so the schedule sums exactly to the masked count.
    """
    masked_counts = mask_index.sum(dim=1, keepdim=True)
    per_step = masked_counts // steps
    leftover = masked_counts % steps
    schedule = per_step.expand(-1, steps).clone()
    if leftover.sum() > 0:
        step_ids = torch.arange(steps, device=mask_index.device)
        schedule[step_ids.unsqueeze(0) < leftover] += 1
    return schedule.to(torch.int64)
58
+
59
+
60
@torch.no_grad()
def generate_llada_v1(model, prompt, attention_mask=None, steps=128, gen_length=128,
                      block_length=128, temperature=0., cfg_scale=0.,
                      remasking='low_confidence', mask_id=126336,
                      logits_eos_inf=False, confidence_eos_eot_inf=False):
    """
    LLaDA v1 generation function
    This is the original generate function from LLaDA v1

    Iteratively denoises `gen_length` masked positions appended after `prompt`,
    processing the generation region block by block. Within each block,
    `steps_per_block` refinement rounds each commit the highest-confidence
    predictions (count given by get_num_transfer_tokens) and leave the rest
    masked for the next round.

    NOTE(review): if cfg_scale > 0 and attention_mask is None,
    `attention_mask_` below is referenced while unbound -> NameError.
    Confirm CFG is only used together with an attention mask.
    """
    # Working canvas: prompt tokens followed by gen_length mask tokens.
    x = torch.full((prompt.shape[0], prompt.shape[1] + gen_length), mask_id,
                   dtype=torch.long).to(model.device)
    x[:, :prompt.shape[1]] = prompt.clone()

    if attention_mask is not None:
        # Extend the mask so the generation region is attended to.
        attention_mask = torch.cat([
            attention_mask,
            torch.ones((prompt.shape[0], gen_length), dtype=attention_mask.dtype,
                       device=model.device)
        ], dim=-1)

    # True at prompt positions (everything that is not a mask token initially).
    prompt_index = (x != mask_id)

    assert gen_length % block_length == 0
    num_blocks = gen_length // block_length

    assert steps % num_blocks == 0
    steps_per_block = steps // num_blocks

    for num_block in range(num_blocks):
        # Masked positions inside the current generation block only.
        block_mask_index = (x[:, prompt.shape[1] + num_block * block_length:
                            prompt.shape[1] + (num_block + 1) * block_length] == mask_id)
        num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps_per_block)

        for i in range(steps_per_block):
            mask_index = (x == mask_id)

            if cfg_scale > 0.:
                # Classifier-free guidance: second forward pass with the
                # prompt masked out as the unconditional branch.
                un_x = x.clone()
                un_x[prompt_index] = mask_id
                x_ = torch.cat([x, un_x], dim=0)
                if attention_mask is not None:
                    attention_mask_ = torch.cat([attention_mask, attention_mask], dim=0)
                logits = model(x_, attention_mask=attention_mask_).logits
                logits, un_logits = torch.chunk(logits, 2, dim=0)
                logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
            else:
                logits = model(x, attention_mask=attention_mask).logits

            if logits_eos_inf:
                # Suppress token id 126081 entirely (presumably EOS — confirm
                # against the LLaDA vocabulary).
                logits[:, :, 126081] = -torch.inf

            logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
            x0 = torch.argmax(logits_with_noise, dim=-1)

            if confidence_eos_eot_inf:
                # NOTE(review): chained assignment writes id 126081 into
                # `logits_with_noise` (dead — x0 was already computed above)
                # and id 126348 into `logits` (which feeds the confidence
                # softmax below). Likely intended to set BOTH ids on `logits`;
                # confirm against upstream LLaDA.
                logits_with_noise[:, :, 126081] = logits[:, :, 126348] = -torch.inf

            if remasking == 'low_confidence':
                # Confidence = model probability of the chosen token.
                p = F.softmax(logits, dim=-1)
                x0_p = torch.squeeze(
                    torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1)
            elif remasking == 'random':
                x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
            else:
                raise NotImplementedError(remasking)

            # Never commit tokens beyond the current block.
            x0_p[:, prompt.shape[1] + (num_block + 1) * block_length:] = -np.inf

            # Keep already-committed tokens; candidates only at masked slots.
            x0 = torch.where(mask_index, x0, x)
            confidence = torch.where(mask_index, x0_p, -np.inf)

            # Commit the top-k most confident predictions per sample.
            transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
            for j in range(confidence.shape[0]):
                _, select_index = torch.topk(confidence[j], k=num_transfer_tokens[j, i])
                transfer_index[j, select_index] = True

            x[transfer_index] = x0[transfer_index]

    return x
139
+
140
+
141
+ @register_model("LLaDA")
142
+ class LLaDA(TemplateLM):
143
+ AUTO_MODEL_CLASS = transformers.AutoModel
144
+ _DEFAULT_MAX_LENGTH = 20480
145
+
146
    def __init__(
        self,
        pretrained: Union[str, transformers.PreTrainedModel],
        backend: Literal["default", "causal", "seq2seq"] = "causal",
        revision: Optional[str] = "main",
        subfolder: Optional[str] = None,
        tokenizer: Optional[
            Union[
                str,
                transformers.PreTrainedTokenizer,
                transformers.PreTrainedTokenizerFast,
            ]
        ] = None,
        truncation: Optional[bool] = False,
        logits_cache: bool = True,
        max_length: Optional[int] = None,
        device: Optional[str] = "cuda",
        dtype: Optional[Union[str, torch.dtype]] = "auto",
        batch_size: Optional[Union[int]] = 1,
        max_batch_size: Optional[int] = 64,
        trust_remote_code: Optional[bool] = True,
        use_fast_tokenizer: Optional[bool] = True,
        add_bos_token: Optional[bool] = False,
        escape_until: Optional[bool] = False,
        prefix_token_id: Optional[int] = None,
        parallelize: Optional[bool] = False,
        max_memory_per_gpu: Optional[Union[int, str]] = None,
        max_cpu_memory: Optional[Union[int, str]] = None,
        offload_folder: Optional[Union[str, os.PathLike]] = "./offload",
        peft: Optional[str] = None,
        delta: Optional[str] = None,
        autogptq: Optional[Union[bool, str]] = False,
        gptqmodel: Optional[bool] = False,
        gguf_file: Optional[str] = None,
        mc_num: int = 1024,
        remasking: str = "low_confidence",
        mask_id: int = 126336,  # LLaDA v1 default mask_id
        is_check_greedy: bool = True,
        assistant_prefix: Optional[str] = None,
        **kwargs,
    ) -> None:
        """Load a LLaDA v1 checkpoint (or adopt a pre-built model), set up the
        tokenizer, device placement / Accelerate multi-process state, batch
        sizing, and — when available — an HTSSampler bound to the model.

        `pretrained` may be a HF hub id / path, or an already-instantiated
        model (in which case most loading-related arguments are ignored).
        """
        super().__init__()
        self.mc_num = mc_num
        self.mask_id = mask_id
        self.remasking = remasking
        self.pretrained = pretrained
        self.is_check_greedy = is_check_greedy
        self.assistant_prefix = assistant_prefix
        self.add_bos_token = add_bos_token
        self.escape_until = escape_until

        if not isinstance(pretrained, str):
            # Caller passed an already-initialized model object.
            eval_logger.warning(
                "`pretrained` model kwarg is not of type `str`. Many other model arguments may be ignored."
            )
            assert not parallelize, (
                "`parallelize=True` is not compatible with passing pre-initialized model to `pretrained`"
            )
            self._model = pretrained
            self._device = self._model.device
            self._config = self._model.config
            gpus = 0
        else:
            assert isinstance(device, str)
            assert isinstance(pretrained, str)
            assert isinstance(batch_size, (int, str))
            gpus = torch.cuda.device_count()
            # Very long timeout so slow checkpoint loads don't kill the group.
            accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52))
            accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs])
            if accelerator.num_processes > 1:
                # Only keep the accelerator when actually multi-process;
                # downstream code checks `hasattr(self, "accelerator")`.
                self.accelerator = accelerator
            if "npu" in accelerator.device.type:
                gpus = torch.npu.device_count()
            if not (parallelize or accelerator.num_processes > 1):
                # Single-process path: resolve the requested device string.
                device_list = set(
                    ["cuda", "cpu"]
                    + [f"cuda:{i}" for i in range(gpus)]
                    + ["mps", "mps:0"]
                    + [f"npu:{i}" for i in range(gpus)]
                )
                if device and device in device_list:
                    self._device = torch.device(device)
                    eval_logger.info(f"Using device '{device}'")
                    if device in ("mps", "mps:0") and version.parse(
                        torch.__version__
                    ) < version.parse("2.1"):
                        raise RuntimeError(
                            f"mps requires torch >= 2.1. You have {torch.__version__}"
                        )
                else:
                    eval_logger.info("Device not specified")
                    eval_logger.info(f"Cuda Available? {torch.cuda.is_available()}")
                    self._device = (
                        torch.device("cuda")
                        if torch.cuda.is_available()
                        else torch.device("cpu")
                    )
            else:
                # Multi-process or model-parallel: Accelerate decides placement.
                if device != "cuda":
                    eval_logger.info(
                        f"Using `accelerate launch` or `parallelize=True`, device '{device}' will be overridden when placing model."
                    )
                self._device = (
                    self.accelerator.device
                    if hasattr(self, "accelerator")
                    else torch.device(device)
                )
            revision = str(revision)
            revision = revision + ("/" + subfolder if subfolder is not None else "")
            self._get_config(
                pretrained,
                revision=revision,
                trust_remote_code=trust_remote_code,
                gguf_file=gguf_file,
            )

        self._get_backend(
            config=self.config, backend=backend, trust_remote_code=trust_remote_code
        )
        self._create_tokenizer(
            pretrained,
            tokenizer,
            revision=revision,
            trust_remote_code=trust_remote_code,
            use_fast_tokenizer=use_fast_tokenizer,
            gguf_file=gguf_file,
            add_bos_token=add_bos_token,
        )

        if isinstance(pretrained, str):
            self._create_model(
                pretrained=pretrained,
                revision=revision,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                parallelize=parallelize,
                gpus=gpus,
                max_memory_per_gpu=max_memory_per_gpu,
                max_cpu_memory=max_cpu_memory,
                offload_folder=offload_folder,
                peft=peft,
                delta=delta,
                autogptq=autogptq,
                gptqmodel=gptqmodel,
                gguf_file=gguf_file,
                **kwargs,
            )

        if isinstance(self.model, torch.nn.Module):
            self.model.eval()
            self.model.tie_weights()

        self.truncation = truncation
        self.logits_cache = logits_cache
        # Capture vocab size before configure_pad_token may replace the tokenizer.
        self.vocab_size = self.tokenizer.vocab_size
        self.tokenizer = configure_pad_token(self.tokenizer, model_config=self.config)
        self.add_bos_token = add_bos_token

        if "gemma" in getattr(self.config, "model_type", ""):
            # Gemma checkpoints require a BOS token regardless of the flag.
            self.add_bos_token = True
            eval_logger.info(
                f"Model type is '{self.config.model_type}', part of the Gemma family--a BOS token will be used."
            )

        self._max_length = max_length
        self.pretrained = pretrained
        self.delta = delta
        self.peft = peft
        self.revision = revision
        self.batch_schedule = 1
        self.batch_sizes = {}
        self.max_batch_size = max_batch_size

        # "auto" or "auto:N" enables dynamic batch-size probing.
        if str(batch_size).startswith("auto"):
            batch_size = batch_size.split(":")
            self.batch_size_per_gpu = batch_size[0]
            self.batch_schedule = float(batch_size[1]) if len(batch_size) > 1 else 1
        else:
            self.batch_size_per_gpu = int(batch_size)

        if isinstance(pretrained, str):
            if gpus >= 1 or str(self.device) == "mps":
                if not (parallelize or autogptq or hasattr(self, "accelerator")):
                    # Manual placement only when nothing else manages devices.
                    try:
                        self.model.to(self.device)
                    except ValueError:
                        eval_logger.debug(
                            "Failed to place model onto specified device. This may be because the model is quantized via `bitsandbytes` or `device_map` is provided."
                        )
            if gpus > 1:
                if hasattr(self, "accelerator") and self.accelerator.num_processes > 1:
                    if parallelize:
                        eval_logger.warning(
                            "You are both using a HF Accelerate `device_map` and launching via `accelerate launch`."
                        )
                    elif gpus > self.accelerator.num_processes:
                        eval_logger.warning(
                            "WARNING: The number of total system GPUs does not match the number of spawned processes."
                        )
                    self._device = torch.device(f"{self.accelerator.device}")
                    self._rank = self.accelerator.local_process_index
                    self._world_size = self.accelerator.num_processes
                else:
                    self._rank = 0
                    self._world_size = 1
            else:
                self._rank = 0
                self._world_size = 1
        else:
            eval_logger.warning(
                "Passed an already-initialized model through `pretrained`, assuming single-process call."
            )
            self._rank = 0
            self._world_size = 1

        self.custom_prefix_token_id = prefix_token_id
        if prefix_token_id is not None:
            eval_logger.info(
                f"Loglikelihood prefix token id used in evaluation: {self.prefix_token_id}"
            )
        self.is_first_inference = True

        # HTSSampler is optional (import guarded at module top); generate_until
        # relies on it, so generation fails without the hts_sampler module.
        if HTSSampler is not None:
            self.hts_sampler = HTSSampler(self.model, self.tokenizer, device=self.device)
            eval_logger.info("HTSSampler initialized successfully.")
371
+
372
+ # Copy all the property and helper methods from LLaDA2
373
+ @property
374
+ def rank(self):
375
+ if hasattr(self, "_rank"):
376
+ return self._rank
377
+ if hasattr(self, "accelerator"):
378
+ return self.accelerator.local_process_index
379
+ return int(os.environ.get("LOCAL_RANK", 0))
380
+
381
+ @property
382
+ def world_size(self):
383
+ if hasattr(self, "_world_size"):
384
+ return self._world_size
385
+ if hasattr(self, "accelerator"):
386
+ return self.accelerator.num_processes
387
+ return int(os.environ.get("WORLD_SIZE", 1))
388
+
389
    def _get_accelerate_args(
        self,
        parallelize: Optional[bool] = None,
        device_map: Optional[str] = "auto",
        max_memory_per_gpu: Optional[Union[int, str]] = None,
        max_cpu_memory: Optional[Union[int, str]] = None,
        offload_folder: Optional[str] = "./offload",
        gpus: Optional[int] = None,
    ) -> dict:
        """Get accelerate arguments - same as LLaDA2.

        Builds the kwargs dict (`device_map`, `max_memory`, `offload_folder`)
        passed to `from_pretrained` for sharded loading. With parallelize off,
        pins the whole model to `self.device`.

        NOTE(review): the `device_map` parameter is never read — "auto" is
        hard-coded below. Also, when an accelerator is present, the
        per-process GPU filter overwrites `max_memory_per_gpu_map` and
        silently discards a user-provided `max_memory_per_gpu` — confirm
        that is intended.
        """
        num_local_processes = int(os.environ.get("LOCAL_WORLD_SIZE", 1))
        # With several GPUs and no explicit choice, default to sharding.
        if parallelize is None and gpus is not None and gpus > 1:
            parallelize = True
        args = {}
        if parallelize:
            max_memory_all_gpus = get_max_memory()
            if "cpu" in max_memory_all_gpus:
                del max_memory_all_gpus["cpu"]
            max_memory_per_gpu_map = {
                device_idx: max_memory_per_gpu for device_idx in range(len(max_memory_all_gpus))
            } if max_memory_per_gpu is not None else {k: v for k, v in max_memory_all_gpus.items()}
            if hasattr(self, "accelerator"):
                # Restrict each process to its own subset of local GPUs.
                max_memory_per_gpu_map = {
                    k: v for k, v in max_memory_all_gpus.items()
                    if k % num_local_processes == self.accelerator.process_index % num_local_processes
                }
            args["max_memory"] = max_memory_per_gpu_map
            args["device_map"] = "auto"
            args["offload_folder"] = offload_folder
            if max_cpu_memory is not None:
                args["max_memory"]["cpu"] = max_cpu_memory
        else:
            # Single-device placement: everything on self.device.
            args["device_map"] = {"": str(self.device)}
        return args
423
+
424
    @property
    def config(self):
        """The HF model config cached at construction time."""
        return self._config
427
+
428
+ @property
429
+ def model(self):
430
+ if hasattr(self, "accelerator"):
431
+ return self.accelerator.unwrap_model(self._model)
432
+ else:
433
+ return self._model
434
+
435
    @property
    def eot_token_id(self):
        """End-of-text token id (the tokenizer's EOS)."""
        return self.tokenizer.eos_token_id
438
+
439
+ @property
440
+ def prefix_token_id(self):
441
+ if self.custom_prefix_token_id is not None:
442
+ return self.custom_prefix_token_id
443
+ if self.tokenizer.bos_token_id is not None:
444
+ return self.tokenizer.bos_token_id
445
+ return self.tokenizer.eos_token_id
446
+
447
+ @property
448
+ def max_length(self):
449
+ if self._max_length:
450
+ return self._max_length
451
+ seqlen_config_attrs = ("n_positions", "max_position_embeddings", "n_ctx")
452
+ for attr in seqlen_config_attrs:
453
+ if hasattr(self.model.config, attr):
454
+ return getattr(self.model.config, attr)
455
+ if hasattr(self.tokenizer, "model_max_length"):
456
+ if self.tokenizer.model_max_length > 1e10:
457
+ return self._DEFAULT_MAX_LENGTH
458
+ return self.tokenizer.model_max_length
459
+ return self._DEFAULT_MAX_LENGTH
460
+
461
    @property
    def max_gen_toks(self) -> int:
        """Default maximum number of tokens to generate per request."""
        return 256
464
+
465
    @property
    def batch_size(self):
        """Per-GPU batch size (may be the string prefix from "auto")."""
        return self.batch_size_per_gpu
468
+
469
    @property
    def device(self):
        """Torch device inputs are placed on."""
        return self._device
472
+
473
    @property
    def tokenizer_name(self) -> str:
        """Tokenizer identifier with '/' flattened for use in cache paths."""
        return self.tokenizer.name_or_path.replace("/", "__")
476
+
477
+ def _get_backend(self, config, backend, trust_remote_code):
478
+ """Get backend type - same as LLaDA2"""
479
+ assert backend in ["default", "causal", "seq2seq"]
480
+ if backend != "default":
481
+ self.backend = backend
482
+ eval_logger.info(f"Overrode HF model backend type, and using type '{self.backend}'")
483
+ else:
484
+ if getattr(config, "model_type") in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
485
+ self.backend = "seq2seq"
486
+ elif getattr(self.config, "model_type") in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
487
+ self.backend = "causal"
488
+ else:
489
+ eval_logger.warning("HF model type is neither CausalLM nor Seq2SeqLM. Assuming CausalLM.")
490
+ self.backend = "causal"
491
+
492
    def _get_config(self, pretrained, revision, trust_remote_code, gguf_file):
        """Get model config - same as LLaDA2.

        Loads and caches the HF config for `pretrained` into `self._config`.

        NOTE(review): `gguf_file` is accepted but not forwarded to
        `AutoConfig.from_pretrained` — confirm whether GGUF checkpoints are
        expected to work through this wrapper.
        """
        self._config = transformers.AutoConfig.from_pretrained(
            pretrained, revision=revision, trust_remote_code=trust_remote_code
        )
497
+
498
    def _create_model(self, pretrained, revision, dtype, trust_remote_code, parallelize,
                      gpus, max_memory_per_gpu, max_cpu_memory, offload_folder,
                      peft, delta, autogptq, gptqmodel, gguf_file, **kwargs):
        """Create model - same as LLaDA2.

        Loads the checkpoint via AutoModelForCausalLM, optionally applies a
        PEFT adapter, moves it to `self.device` and fixes the dtype.

        NOTE(review): accelerate args are only added when `parallelize` is
        False (the non-sharded path); with `parallelize=True` no device_map
        is passed at all — confirm that is intended.
        NOTE(review): the final `.to(torch.bfloat16)` unconditionally
        overrides the user-requested `dtype`.
        `delta`, `gguf_file` and `gptqmodel`/`autogptq` (beyond the guard)
        are unused.
        """
        if autogptq or gptqmodel:
            raise NotImplementedError("Quantization options are not implemented.")
        model_dtype = get_dtype(dtype)
        eval_logger.info(f"Loading model with dtype: {model_dtype}")
        model_kwargs = kwargs if kwargs else {}
        if not parallelize:
            model_kwargs.update(
                self._get_accelerate_args(
                    parallelize=parallelize,
                    gpus=gpus,
                    max_memory_per_gpu=max_memory_per_gpu,
                    max_cpu_memory=max_cpu_memory,
                    offload_folder=offload_folder
                )
            )
        self._model = transformers.AutoModelForCausalLM.from_pretrained(
            pretrained, revision=revision, torch_dtype=model_dtype,
            trust_remote_code=trust_remote_code, **model_kwargs
        )
        if peft:
            # Lazy import: peft is only required when an adapter is used.
            from peft import PeftModel
            eval_logger.info(f"Loading PEFT model from {peft}")
            self._model = PeftModel.from_pretrained(self._model, peft, torch_dtype=model_dtype)
        if not parallelize:
            self._model = self._model.to(self.device)
        self._model = self._model.to(torch.bfloat16)
        self._model.eval()
529
+
530
+ def _create_tokenizer(self, pretrained, tokenizer, revision, trust_remote_code,
531
+ use_fast_tokenizer, gguf_file, add_bos_token):
532
+ """Create tokenizer - same as LLaDA2"""
533
+ kwargs = {
534
+ "revision": revision,
535
+ "trust_remote_code": trust_remote_code,
536
+ "use_fast": use_fast_tokenizer
537
+ }
538
+ if add_bos_token:
539
+ kwargs["add_bos_token"] = True
540
+ if tokenizer:
541
+ if isinstance(tokenizer, str):
542
+ self.tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer, **kwargs)
543
+ else:
544
+ self.tokenizer = tokenizer
545
+ else:
546
+ model_name = pretrained if isinstance(pretrained, str) else self.model.name_or_path
547
+ self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_name, **kwargs)
548
+
549
+ def tok_encode(self, string, left_truncate_len=None, add_special_tokens=None):
550
+ """Tokenize string - same as LLaDA2"""
551
+ special_tokens_kwargs = {}
552
+ if add_special_tokens is None:
553
+ if self.backend == "causal":
554
+ special_tokens_kwargs["add_special_tokens"] = self.add_bos_token
555
+ else:
556
+ special_tokens_kwargs["add_special_tokens"] = add_special_tokens
557
+ encoding = self.tokenizer.encode(string, **special_tokens_kwargs)
558
+ if left_truncate_len:
559
+ encoding = encoding[-left_truncate_len:]
560
+ return encoding
561
+
562
+ def tok_batch_encode(self, strings, padding_side="left", left_truncate_len=None, truncation=False):
563
+ """Batch tokenize - same as LLaDA2"""
564
+ old_padding_side = self.tokenizer.padding_side
565
+ self.tokenizer.padding_side = padding_side
566
+ add_special_tokens = {"add_special_tokens": self.add_bos_token} if self.backend == "causal" else {}
567
+ encoding = self.tokenizer(
568
+ strings, truncation=truncation, padding="longest",
569
+ return_tensors="pt", **add_special_tokens
570
+ )
571
+ if left_truncate_len and encoding["input_ids"].size(1) > left_truncate_len:
572
+ eval_logger.warning(f"Left-truncating from {encoding['input_ids'].size(1)} to {left_truncate_len} tokens.")
573
+ encoding["input_ids"] = encoding["input_ids"][:, -left_truncate_len:]
574
+ encoding["attention_mask"] = encoding["attention_mask"][:, -left_truncate_len:]
575
+ self.tokenizer.padding_side = old_padding_side
576
+ return encoding["input_ids"].to(self.device), encoding["attention_mask"].to(self.device)
577
+
578
    def tok_decode(self, tokens, skip_special_tokens=False):
        """Decode token ids back to text."""
        return self.tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)
581
+
582
+ def _model_call(self, inps, attn_mask=None, labels=None):
583
+ """Model forward call - same as LLaDA2"""
584
+ with torch.no_grad():
585
+ if self.backend == "seq2seq":
586
+ return self.model(input_ids=inps, attention_mask=attn_mask, labels=labels).logits
587
+ else:
588
+ return self.model(inps, attention_mask=attn_mask).logits
589
+
590
    def _loglikelihood_tokens(self, requests, **kwargs) -> List[Tuple[float, bool]]:
        """Not supported by this diffusion-LM wrapper."""
        raise NotImplementedError
592
+
593
    def loglikelihood_rolling(
        self, requests: List[Instance], disable_tqdm: bool = False
    ) -> List[float]:
        """Not supported by this diffusion-LM wrapper."""
        raise NotImplementedError
597
+
598
    def loglikelihood(self, requests):
        """Not supported by this diffusion-LM wrapper."""
        raise NotImplementedError
600
+
601
+ def generate_until(self, requests: List[Instance]) -> List[str]:
602
+ """Generate until - adapted for LLaDA v1 """
603
+ res = []
604
+ gen_kwargs = requests[0].args[1]
605
+ use_hts = gen_kwargs.get("use_hts", False)
606
+
607
+ realtime_output = gen_kwargs.get("realtime_output", "realtime_hts_results.jsonl")
608
+ baseline_realtime_output = gen_kwargs.get("realtime_output", "realtime_baseline_results.jsonl")
609
+
610
+ if not use_hts and "realtime_output" not in gen_kwargs:
611
+ baseline_realtime_output = "realtime_baseline_results.jsonl"
612
+
613
+ if not use_hts:
614
+ bar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Running Baseline (LLaDA v1)")
615
+
616
+ for req in requests:
617
+ prompt_text = req.args[0]
618
+ local_gen_kwargs = req.args[1] if len(req.args) > 1 else {}
619
+
620
+ context_enc, _ = self.tok_batch_encode([prompt_text])
621
+
622
+ final_codes, stats = self.hts_sampler.generate_hts(
623
+ prompt_text=prompt_text,
624
+ input_ids=context_enc,
625
+ initial_N=1,
626
+ final_K=1,
627
+ hts_survivor_k=1,
628
+ hts_mode=False,
629
+ hts_start_pct=0.0,
630
+ hts_end_pct=0.0,
631
+ decay_factor=1.5,
632
+ pruning_interval=0,
633
+ reward_mode="confidence",
634
+ task_type=local_gen_kwargs.get("task_type", "code"),
635
+ steps=int(local_gen_kwargs.get("steps", 32)),
636
+ gen_length=int(local_gen_kwargs.get("gen_length", 512)),
637
+ block_length=int(local_gen_kwargs.get("block_length", 32)),
638
+ temperature=float(local_gen_kwargs.get("temperature", 0.0)),
639
+ top_p=float(local_gen_kwargs.get("top_p", 0.95)),
640
+ top_k=local_gen_kwargs.get("top_k", None),
641
+ threshold=float(local_gen_kwargs.get("threshold", 0.85)),
642
+ mask_id=self.mask_id,
643
+ eos_id=self.eot_token_id,
644
+ until=local_gen_kwargs.get("until", []),
645
+ )
646
+
647
+ processed_codes = []
648
+ for code in final_codes:
649
+ code = code.strip()
650
+ if not self.escape_until:
651
+ until_terms = local_gen_kwargs.get("until", [])
652
+ for term in until_terms:
653
+ if len(term) > 0 and term in code:
654
+ code = code.split(term)[0]
655
+ processed_codes.append(code)
656
+
657
+ final_choice = processed_codes[0] if processed_codes else ""
658
+ res.append(final_choice)
659
+
660
+ target_val = getattr(req, "target", None)
661
+ if target_val is None or target_val == "N/A":
662
+ if "test" in req.doc and "entry_point" in req.doc:
663
+ target_val = req.doc["test"] + "\ncheck(" + req.doc["entry_point"] + ")"
664
+ else:
665
+ target_val = req.doc.get("answer", req.doc.get("solution", "N/A"))
666
+
667
+ output_dir = os.path.dirname(baseline_realtime_output)
668
+ if output_dir:
669
+ os.makedirs(output_dir, exist_ok=True)
670
+ with open(baseline_realtime_output, "a", encoding="utf-8") as f:
671
+ all_resps = [[code] for code in processed_codes]
672
+ output_data = {
673
+ "doc": req.doc,
674
+ "target": target_val,
675
+ "resps": all_resps,
676
+ "prompt": prompt_text,
677
+ "entropy_history": stats.get("entropy_history", []),
678
+ "pruning_history": stats.get("pruning_history", []),
679
+ "final_scores": stats.get("final_scores", []),
680
+ "all_trajectories": stats.get("all_trajectories", []),
681
+ "nfe": stats.get("nfe", 0),
682
+ "first_block_nfe": stats.get("first_block_nfe", 0),
683
+ "svf_calls": stats.get("svf_calls", 0),
684
+ "total_steps": stats.get("total_steps", 0),
685
+ "num_gen_blocks": stats.get("num_gen_blocks", []),
686
+ "steps_per_block": stats.get("steps_per_block", [])
687
+ }
688
+ f.write(json.dumps(output_data, ensure_ascii=False) + "\n")
689
+ f.flush()
690
+
691
+ bar.update(1)
692
+ bar.close()
693
+
694
+ else:
695
+ bar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Running HTS+SVF (LLaDA v1)")
696
+ for req in requests:
697
+ prompt_text = req.args[0]
698
+ local_gen_kwargs = req.args[1] if len(req.args) > 1 else {}
699
+ context_enc, _ = self.tok_batch_encode([prompt_text])
700
+
701
+ p_interval = int(local_gen_kwargs.get("pruning_interval", 0))
702
+
703
+ final_codes, stats = self.hts_sampler.generate_hts(
704
+ prompt_text=prompt_text,
705
+ input_ids=context_enc,
706
+ initial_N=int(local_gen_kwargs.get("hts_N", 4)),
707
+ final_K=int(local_gen_kwargs.get("final_K", 1)),
708
+ hts_survivor_k=int(local_gen_kwargs.get("hts_survivor_k", 4)),
709
+ hts_mode=local_gen_kwargs.get("hts_mode", True),
710
+ hts_start_pct=float(local_gen_kwargs.get("hts_start_pct", 0.1)),
711
+ hts_end_pct=float(local_gen_kwargs.get("hts_end_pct", 0.6)),
712
+ decay_factor=float(local_gen_kwargs.get("decay_factor", 1.5)),
713
+ pruning_interval=p_interval,
714
+ reward_mode=local_gen_kwargs.get("reward_mode", "svf"),
715
+ task_type=local_gen_kwargs.get("task_type", "code"),
716
+ steps=int(local_gen_kwargs.get("steps", 32)),
717
+ gen_length=int(local_gen_kwargs.get("gen_length", 512)),
718
+ block_length=int(local_gen_kwargs.get("block_length", 32)),
719
+ temperature=float(local_gen_kwargs.get("temperature", 0.7)),
720
+ top_p=float(local_gen_kwargs.get("top_p", 0.95)),
721
+ top_k=local_gen_kwargs.get("top_k", None),
722
+ threshold=float(local_gen_kwargs.get("threshold", 0.85)),
723
+ mask_id=self.mask_id,
724
+ eos_id=self.eot_token_id,
725
+ until=local_gen_kwargs.get("until", []),
726
+ )
727
+
728
+ processed_codes = []
729
+ for code in final_codes:
730
+ code = code.strip()
731
+ if not self.escape_until:
732
+ until_terms = local_gen_kwargs.get("until", [])
733
+ for term in until_terms:
734
+ if len(term) > 0 and term in code:
735
+ code = code.split(term)[0]
736
+ processed_codes.append(code)
737
+
738
+ final_choice = processed_codes[0]
739
+ res.append(final_choice)
740
+
741
+ target_val = getattr(req, "target", None)
742
+ if target_val is None or target_val == "N/A":
743
+ if "test" in req.doc and "entry_point" in req.doc:
744
+ target_val = req.doc["test"] + "\ncheck(" + req.doc["entry_point"] + ")"
745
+ else:
746
+ target_val = req.doc.get("answer", req.doc.get("solution", "N/A"))
747
+
748
+ output_dir = os.path.dirname(realtime_output)
749
+ if output_dir:
750
+ os.makedirs(output_dir, exist_ok=True)
751
+ with open(realtime_output, "a", encoding="utf-8") as f:
752
+ all_resps = [[code] for code in processed_codes]
753
+ output_data = {
754
+ "doc": req.doc,
755
+ "target": target_val,
756
+ "resps": all_resps,
757
+ "prompt": prompt_text,
758
+ "entropy_history": stats.get("entropy_history", []),
759
+ "pruning_history": stats.get("pruning_history", []),
760
+ "final_scores": stats.get("final_scores", []),
761
+ "all_trajectories": stats.get("all_trajectories", []),
762
+ "nfe": stats.get("nfe", 0),
763
+ "first_block_nfe": stats.get("first_block_nfe", 0),
764
+ "svf_calls": stats.get("svf_calls", 0),
765
+ "total_steps": stats.get("total_steps", 0),
766
+ "num_gen_blocks": stats.get("num_gen_blocks", []),
767
+ "steps_per_block": stats.get("steps_per_block", [])
768
+ }
769
+ f.write(json.dumps(output_data, ensure_ascii=False) + "\n")
770
+ f.flush()
771
+
772
+ bar.update(1)
773
+ bar.close()
774
+
775
+ return res
776
+
777
+ def apply_chat_template(
778
+ self, chat_history: List[Dict[str, str]], add_generation_prompt: bool = True
779
+ ) -> str:
780
+ """Apply chat template - same as LLaDA2"""
781
+ chat_templated = self.tokenizer.apply_chat_template(
782
+ chat_history, tokenize=False, add_generation_prompt=add_generation_prompt
783
+ )
784
+ if self.assistant_prefix:
785
+ chat_templated += self.assistant_prefix
786
+ return chat_templated
Prism/LLaDA/LLaDA_Baseline/dllm_eval/models/__init__.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Import model wrapper modules so their @register_model decorators run.
from . import (
    LLaDA,
    huggingface,
)
# from .configuration_llada import LLaDAConfig
# from .modeling_llada import LLaDAModelLM


try:
    # enable hf hub transfer if available
    import hf_transfer  # type: ignore # noqa
    import huggingface_hub.constants  # type: ignore

    huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER = True
except ImportError:
    # hf_transfer is optional; fall back to regular downloads.
    pass


# __all__ = ['LLaDAConfig', 'LLaDAModelLM']
Prism/LLaDA/LLaDA_Baseline/dllm_eval/models/dummy.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+
3
+ from tqdm import tqdm
4
+
5
+ from dllm_eval.api.model import LM
6
+ from dllm_eval.api.registry import register_model
7
+
8
+
9
+ @register_model("dummy")
10
+ class DummyLM(LM):
11
+ def __init__(self) -> None:
12
+ super().__init__()
13
+
14
+ @classmethod
15
+ def create_from_arg_string(cls, arg_string, additional_config=None):
16
+ return cls()
17
+
18
+ def loglikelihood(self, requests, disable_tqdm: bool = False):
19
+ res = []
20
+
21
+ for _ in tqdm(requests, disable=disable_tqdm):
22
+ res.append((-random.random(), False))
23
+
24
+ return res
25
+
26
+ def generate_until(self, requests, disable_tqdm: bool = False):
27
+ res = []
28
+
29
+ for request in tqdm(requests, disable=disable_tqdm):
30
+ res.append("lol")
31
+ assert request.arguments[0].strip() != ""
32
+
33
+ return res
34
+
35
+ def loglikelihood_rolling(self, requests, disable_tqdm: bool = False):
36
+ res = []
37
+
38
+ for _ in tqdm(requests, disable=disable_tqdm):
39
+ res.append(-random.random())
40
+
41
+ return res
Prism/LLaDA/LLaDA_Baseline/dllm_eval/models/hts_sampler.py ADDED
@@ -0,0 +1,315 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import numpy as np
4
+ from .verifier import CodeVerifier
5
+ import logging
6
+ import re
7
+ import math
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
class HTSSampler:
    """Hierarchical tree-search (HTS) sampler for masked-diffusion decoding.

    Maintains a population of candidate sequences, iteratively unmasks
    tokens block by block, and periodically prunes and re-branches
    candidates using a ``CodeVerifier`` reward plus cheap structural
    heuristics on the decoded text.
    """

    def __init__(self, model, tokenizer, device="cuda"):
        # model: masked-diffusion LM producing per-position vocabulary logits.
        # tokenizer: paired tokenizer used for decoding and verification.
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.verifier = CodeVerifier(model, tokenizer, device)

    def _get_num_transfer_tokens(self, block_length, steps):
        """Split ``block_length`` token slots as evenly as possible over ``steps``.

        Returns an int64 tensor of length ``steps`` summing to ``block_length``;
        the first ``block_length % steps`` entries receive one extra slot.
        """
        if steps == 0: return torch.tensor([], dtype=torch.int64)
        base = block_length // steps
        remainder = block_length % steps
        num_transfer_tokens = torch.full((steps,), base, dtype=torch.int64)
        num_transfer_tokens[:remainder] += 1
        return num_transfer_tokens

    def _sample_with_temperature(self, logits, temperature, top_k, top_p):
        """Sample token ids from ``logits`` via Gumbel-max with optional top-k.

        Returns ``(x0, x0_p)``: sampled ids and the *pre-noise* max softmax
        probability per position, later used as a confidence score.
        NOTE(review): ``top_p`` is accepted but never applied in this body.
        """
        logits = logits.to(torch.float32)

        # Confidence is taken from the untempered distribution.
        orig_probs = torch.softmax(logits, dim=-1)
        x0_p, _ = torch.max(orig_probs, dim=-1)

        if temperature > 0.0:
            # Gumbel-max trick: argmax over noised logits == categorical sample.
            noise = torch.rand_like(logits, dtype=torch.float32)
            gumbel_noise = -torch.log(-torch.log(noise + 1e-10) + 1e-10)
            logits = logits / temperature + gumbel_noise

        if top_k is not None and top_k > 0:
            # Mask out everything below the k-th largest logit.
            indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
            logits[indices_to_remove] = -float('Inf')

        x0 = torch.argmax(logits, dim=-1)

        return x0, x0_p

    def _safe_scalar(self, val):
        """Coerce a tensor or number to a Python float (mean of multi-element tensors)."""
        if isinstance(val, torch.Tensor):
            if val.numel() > 1: return val.mean().item()
            return val.item()
        return float(val)

    def _analyze_structure(self, text, task_type="code"):
        """Cheap structural prior added on top of the verifier reward.

        Rewards code-like / math-like surface features and penalizes
        degenerate outputs (too short, or excessively repetitive steps).
        """
        score = 0.0
        stripped = text.strip()
        if task_type == "code":
            if len(stripped) < 5: return -0.1
            keywords = ["return", "print", "yield", "lambda", "class ", "def "]
            if any(k in stripped for k in keywords): score += 0.05
            if ":" in stripped: score += 0.02
            # NOTE(review): a single-space check matches almost any text;
            # possibly intended to test for indentation — confirm original.
            if " " in text: score += 0.03
        elif task_type == "math":
            if "\\boxed{" in stripped: score += 0.1
            if "The answer is" in stripped: score += 0.05
            if len(stripped) < 10: return -0.1
            if "Step" in stripped and stripped.count("Step") > 15: score -= 0.2
        return score

    def _chunked_forward(self, x, chunk_size=96, slice_indices=None):
        """Run the model over batch ``x`` in sub-batches of ``chunk_size``.

        Optionally slices returned logits to ``[s_start, s_end)`` along the
        sequence dimension before concatenating, bounding peak memory.
        """
        total_batch = x.shape[0]
        logits_list = []
        for i in range(0, total_batch, chunk_size):
            end_idx = min(i + chunk_size, total_batch)
            sub_x = x[i:end_idx]
            # Full attention over the whole (prompt + generation) window.
            sub_mask = torch.ones_like(sub_x, device=self.device)
            with torch.no_grad():
                outputs = self.model(input_ids=sub_x, attention_mask=sub_mask)
            sub_logits = outputs.logits
            if slice_indices is not None:
                s_start, s_end = slice_indices
                sub_logits = sub_logits[:, s_start:s_end, :]
            logits_list.append(sub_logits.detach().clone())
        return torch.cat(logits_list, dim=0)

    def _branch_and_resample(self, x, conf_scores, survivor_indices, target_width, mask_id,
                             prompt_length, resample_window=5, task_type="code"):
        """Expand the ``survivor_indices`` rows of ``x`` back up to ``target_width``.

        Each survivor is kept verbatim once; extra copies are perturbed by
        re-masking a random subset of their lowest-confidence generated
        tokens so that branches explore different completions.
        Returns ``(new_x, new_conf)`` of shape ``(target_width, seq_len)``.
        """
        num_survivors = len(survivor_indices)
        if num_survivors == 0: return x[:target_width].clone(), conf_scores[:target_width].clone()

        # Task-specific perturbation width overrides the default argument.
        if task_type == "math": resample_window = 6
        elif task_type == "reasoning": resample_window = 6
        elif task_type == "code": resample_window = 6

        base_repeat = target_width // num_survivors
        remainder = target_width % num_survivors
        new_x_list = []
        new_conf_list = []

        for i in range(num_survivors):
            count = base_repeat + (1 if i < remainder else 0)
            if count == 0: continue

            survivor_x = x[survivor_indices[i]]
            survivor_conf = conf_scores[survivor_indices[i]]

            # Keep one unmodified copy of every survivor.
            new_x_list.append(survivor_x.unsqueeze(0))
            new_conf_list.append(survivor_conf.unsqueeze(0))

            if count > 1:
                gen_part = survivor_x[prompt_length:]
                gen_conf = survivor_conf[prompt_length:]
                non_mask_indices = (gen_part != mask_id).nonzero(as_tuple=True)[0]

                for _ in range(count - 1):
                    perturbed_x = survivor_x.clone()
                    perturbed_conf = survivor_conf.clone()

                    if len(non_mask_indices) > 0:
                        # Candidate pool: the 2*window lowest-confidence tokens.
                        pool_size = min(resample_window * 2, len(non_mask_indices))
                        current_token_confs = gen_conf[non_mask_indices]

                        _, candidate_indices = torch.topk(current_token_confs, k=pool_size, largest=False)

                        # Randomly re-mask `resample_window` tokens from the pool.
                        num_to_perturb = min(resample_window, pool_size)
                        rand_indices = torch.randperm(pool_size, device=self.device)[:num_to_perturb]
                        selected_sub_indices = candidate_indices[rand_indices]

                        target_indices_in_x = prompt_length + non_mask_indices[selected_sub_indices]
                        perturbed_x[target_indices_in_x] = mask_id
                        perturbed_conf[target_indices_in_x] = 0.0

                    new_x_list.append(perturbed_x.unsqueeze(0))
                    new_conf_list.append(perturbed_conf.unsqueeze(0))

        return torch.cat(new_x_list, dim=0), torch.cat(new_conf_list, dim=0)

    @torch.no_grad()
    def generate_hts(self, prompt_text, input_ids, problem_data=None,
                     initial_N=1, final_K=1, survivor_K=None,
                     prune_step_pct=0.0, reward_mode="confidence",
                     temperature=0.7, block_length=32, steps=64, gen_length=1024,
                     top_p=0.95, top_k=None, minimal_topk=1, threshold=0.9,
                     eos_id=156892, mask_id=156895,
                     hts_mode=False, hts_start_pct=0.1, hts_end_pct=0.6, decay_factor=1.5,
                     hts_survivor_k=4, task_type="code", until=None, pruning_interval=0):
        """Generate with block-wise diffusion decoding plus population pruning.

        Starts ``initial_N`` candidates, unmasks ``block_length`` tokens per
        block over ``steps`` diffusion steps, and at scheduled steps scores
        candidates (verifier reward + structure heuristic), keeps the best
        parents, and re-branches back up to the target width.

        Returns ``(responses, stats)``: decoded responses sorted by final
        score (best first) and a dict of bookkeeping statistics (NFE counts,
        pruning history, per-block step counts, ranked trajectories).
        NOTE(review): ``minimal_topk``, ``top_p`` and ``prefill_blocks`` are
        currently unused in this body.
        """
        input_ids = input_ids.to(self.device)
        # Broadcast a single prompt to the initial population width.
        if input_ids.shape[0] == 1: input_ids = input_ids.repeat(initial_N, 1)

        schedule_map = {}
        ts_start, tr_end = 0, 0
        if not hts_mode:
            # Fixed pruning schedule: at step int(steps*pct) shrink to `width`
            # candidates grown from `parents` survivors.
            final_K_list = [final_K] if not isinstance(final_K, list) else final_K
            prune_pct_list = [prune_step_pct] if not isinstance(prune_step_pct, list) else prune_step_pct
            survivor_K_list = final_K_list if survivor_K is None else ([survivor_K] if not isinstance(survivor_K, list) else survivor_K)
            if len(survivor_K_list) < len(final_K_list): survivor_K_list.extend(final_K_list[len(survivor_K_list):])
            for pct, width, parents in zip(prune_pct_list, final_K_list, survivor_K_list):
                if pct > 0:
                    s = int(steps * pct)
                    schedule_map[s] = (width, parents)
        else:
            # NOTE(review): both branches of this conditional produce
            # [final_K]; if final_K is a list it gets nested — confirm intent.
            final_K_list = [final_K] if not isinstance(final_K, int) else [final_K]
            # Exponential-decay pruning is active only in [ts_start, tr_end).
            ts_start, tr_end = int(steps * hts_start_pct), int(steps * hts_end_pct)


        prompt_length = input_ids.shape[1]
        # Round the generation length up to a whole number of blocks.
        num_blocks = (gen_length + block_length - 1) // block_length
        total_length = prompt_length + num_blocks * block_length

        # Canvas: prompt tokens followed by all-mask generation slots.
        x = torch.full((initial_N, total_length), mask_id, dtype=torch.long, device=self.device)
        x[:, :prompt_length] = input_ids.clone()

        # Per-token confidence; prompt positions are fully trusted.
        conf_scores = torch.zeros((initial_N, total_length), dtype=torch.float32, device=self.device)
        conf_scores[:, :prompt_length] = 1.0

        prefill_blocks = 0
        num_gen_blocks = num_blocks
        current_bsz = initial_N

        next_allowed_pruning_step = ts_start if hts_mode else 0

        stats = {
            "initial_n": initial_N, "final_k": final_K_list[-1],
            "pruning_history": [], "entropy_history": [], "nfe": 0.0,
            "svf_calls": 0, "final_scores": [], "total_steps": steps,
            "first_block_nfe": 0.0, "num_gen_blocks": [], "steps_per_block": []
        }

        for num_block in range(num_gen_blocks):
            stats["num_gen_blocks"].append(num_block)

            # The slice of the canvas being denoised during this block.
            window_start = prompt_length + num_block * block_length
            window_end = window_start + block_length

            schedule = self._get_num_transfer_tokens(block_length, steps)

            steps_this_block = 0
            for step in range(steps):
                steps_this_block += 1
                cur_full_x = x[:current_bsz, :]

                perform_pruning = False
                num_parents_to_select = 0

                if hts_mode and step >= next_allowed_pruning_step and step < tr_end:
                    # Width decays geometrically from initial_N, floored at final_K.
                    target_width = max(final_K_list[-1], math.ceil(initial_N * (decay_factor ** -(step - ts_start))))
                    if current_bsz > target_width:
                        perform_pruning = True
                        num_parents_to_select = hts_survivor_k
                elif not hts_mode and step in schedule_map:
                    target_width, num_parents_to_select = schedule_map[step]
                    if current_bsz > target_width: perform_pruning = True

                if perform_pruning:
                    stats["nfe"] += current_bsz
                    if num_block == 0: stats["first_block_nfe"] += current_bsz
                    stats["svf_calls"] += current_bsz

                    # Greedy-decode the whole generation region to score candidates.
                    gen_logits = self._chunked_forward(cur_full_x, chunk_size=64, slice_indices=(prompt_length, total_length))
                    rough_ids = torch.argmax(gen_logits, dim=-1)
                    rough_codes_snippet = self.tokenizer.batch_decode(rough_ids, skip_special_tokens=True)
                    candidates = []
                    for i in range(current_bsz):
                        full_code = rough_codes_snippet[i]
                        s = self._safe_scalar(self.verifier.get_reward(prompt_text, full_code, mode=reward_mode, problem_data=problem_data, current_logits=gen_logits[i] if reward_mode != "svf" else None, task_type=task_type))
                        s += self._analyze_structure(full_code, task_type=task_type)
                        # Hash of head+tail for cheap near-duplicate detection.
                        clean_content = full_code.strip().replace(" ", "").replace("\n", "")
                        candidates.append({'score': s, 'idx': i, 'key': hash(clean_content[:200] + clean_content[-200:])})

                    stats["pruning_history"].append({"step": step, "scores": [c['score'] for c in candidates]})
                    candidates.sort(key=lambda x: x['score'], reverse=True)

                    # Select best-scoring, mutually distinct parents first.
                    selected_indices, seen_keys = [], set()
                    for cand in candidates:
                        if len(selected_indices) >= num_parents_to_select: break
                        if cand['key'] not in seen_keys:
                            selected_indices.append(cand['idx']); seen_keys.add(cand['key'])

                    # Backfill with duplicates if diversity fell short.
                    if len(selected_indices) < num_parents_to_select:
                        for cand in candidates:
                            if len(selected_indices) >= num_parents_to_select: break
                            if cand['idx'] not in selected_indices: selected_indices.append(cand['idx'])

                    top_indices = torch.tensor(selected_indices, device=self.device)
                    x, conf_scores = self._branch_and_resample(x, conf_scores, top_indices, target_width, mask_id, prompt_length, task_type=task_type)

                    current_bsz = target_width
                    cur_full_x = x[:current_bsz, :]
                    next_allowed_pruning_step = step + 1 + pruning_interval

                stats["nfe"] += current_bsz
                if num_block == 0: stats["first_block_nfe"] += current_bsz

                # Denoise only the current block window.
                active_logits = self._chunked_forward(cur_full_x, chunk_size=32, slice_indices=(window_start, window_end))
                # Forbid sampling EOS inside the active window.
                active_logits[:, :, eos_id] = -1e10

                x0, x0_p = self._sample_with_temperature(active_logits, temperature, top_k, top_p)

                active_mask = x[:current_bsz, window_start:window_end] == mask_id

                num_transfer = schedule[step].item()
                confidence = torch.where(active_mask, x0_p, -torch.inf)
                transfer_idx = torch.zeros_like(x0, dtype=torch.bool)

                for b in range(current_bsz):
                    mask_count = active_mask[b].sum().item()
                    if mask_count > 0:
                        k_transfer = min(num_transfer, mask_count)
                        active_indices = torch.where(active_mask[b])[0]
                        # Commit every token above the confidence threshold at
                        # once; otherwise fall back to the scheduled top-k.
                        high_conf_mask = (confidence[b] > threshold) & active_mask[b]
                        if high_conf_mask.sum().item() >= k_transfer:
                            conf_indices = torch.where(high_conf_mask)[0]
                            transfer_idx[b, conf_indices] = True
                        else:
                            _, topk_indices = torch.topk(confidence[b][active_indices], k=min(k_transfer, len(active_indices)))
                            transfer_idx[b, active_indices[topk_indices]] = True

                if transfer_idx.any():
                    x[:current_bsz, window_start:window_end][transfer_idx] = x0[transfer_idx]
                    conf_scores[:current_bsz, window_start:window_end][transfer_idx] = x0_p[transfer_idx]

            # Early stop for math/reasoning: once an answer marker appears,
            # flood the remaining masked slots with EOS.
            if task_type in ["math", "reasoning"]:
                for b in range(current_bsz):
                    text_snippet = self.tokenizer.decode(x[b, prompt_length:window_end], skip_special_tokens=True)
                    should_stop = False
                    if task_type == "reasoning" and ("###" in text_snippet): should_stop = True
                    if task_type == "math" and ("\\boxed{" in text_snippet and "}" in text_snippet.split("\\boxed{")[-1]): should_stop = True

                    if should_stop:
                        after_mask = (x[b, window_start:total_length] == mask_id)
                        x[b, window_start:total_length][after_mask] = eos_id

            stats["steps_per_block"].append(steps_this_block)
            x = x[:current_bsz]

        stats["nfe"] = int(round(stats["nfe"]))
        stats["first_block_nfe"] = int(round(stats["first_block_nfe"]))

        # Final scoring pass over the decoded candidates.
        final_gen_tokens = x[:current_bsz, prompt_length:]
        final_codes = self.tokenizer.batch_decode(final_gen_tokens, skip_special_tokens=True)
        final_candidates = []

        stats["svf_calls"] += len(final_codes)
        for i in range(len(final_codes)):
            txt = final_codes[i]
            if until:
                # Truncate at the first stop sequence, if any.
                for term in until:
                    if term in txt: txt = txt.split(term)[0]
            s = self._safe_scalar(self.verifier.get_reward(prompt_text, txt, mode=reward_mode, task_type=task_type))
            s += self._analyze_structure(txt, task_type)
            final_candidates.append({'resp': txt, 'score': s})

        final_candidates.sort(key=lambda x: x['score'], reverse=True)
        stats["final_scores"] = [c['score'] for c in final_candidates]
        stats["all_trajectories"] = [{"rank": i+1, "resp": c['resp'], "score": c['score']} for i, c in enumerate(final_candidates)]

        return [c['resp'] for c in final_candidates], stats
Prism/LLaDA/LLaDA_Baseline/dllm_eval/models/huggingface.py ADDED
@@ -0,0 +1,1489 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import logging
3
+ import os
4
+ from datetime import timedelta
5
+ from pathlib import Path
6
+ from typing import Any, Dict, List, Literal, Optional, Tuple, Union
7
+
8
+ import jinja2
9
+ import torch
10
+ import torch.nn.functional as F
11
+ import transformers
12
+ from accelerate import (
13
+ Accelerator,
14
+ InitProcessGroupKwargs,
15
+ find_executable_batch_size,
16
+ )
17
+ from accelerate.utils import get_max_memory
18
+ from huggingface_hub import HfApi
19
+ from packaging import version
20
+ from peft import PeftModel
21
+ from peft import __version__ as PEFT_VERSION
22
+ from tqdm import tqdm
23
+ from transformers.models.auto.modeling_auto import (
24
+ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
25
+ MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
26
+ )
27
+
28
+ from dllm_eval import utils
29
+ from dllm_eval.api.instance import Instance
30
+ from dllm_eval.api.model import TemplateLM
31
+ from dllm_eval.api.registry import register_model
32
+ from dllm_eval.models.utils import (
33
+ Collator,
34
+ clear_torch_cache,
35
+ configure_pad_token,
36
+ get_dtype,
37
+ handle_stop_sequences,
38
+ pad_and_concat,
39
+ stop_sequences_criteria,
40
+ )
41
+
42
+
43
+ eval_logger = logging.getLogger(__name__)
44
+
45
+
46
+ @register_model("hf-auto", "hf", "huggingface")
47
+ class HFLM(TemplateLM):
48
+ """
49
+ An abstracted Huggingface model class. Enables usage with both models of
50
+ `transformers.AutoModelForCausalLM` and `transformers.AutoModelForSeq2SeqLM` classes.
51
+
52
+ Supports data-parallel multi-GPU with HF Accelerate.
53
+ """
54
+
55
+ AUTO_MODEL_CLASS = None
56
+ _DEFAULT_MAX_LENGTH = 2048
57
+
58
    def __init__(
        self,
        pretrained: Union[str, transformers.PreTrainedModel],
        backend: Literal["default", "causal", "seq2seq"] = "default",
        # override whether the model should be treated as decoder-only (causal) or encoder-decoder (seq2seq)
        revision: Optional[str] = "main",
        subfolder: str = "",
        tokenizer: Optional[
            Union[
                str,
                transformers.PreTrainedTokenizer,
                transformers.PreTrainedTokenizerFast,
            ]
        ] = None,
        truncation: Optional[bool] = False,
        logits_cache: bool = True,
        max_length: Optional[int] = None,
        device: Optional[str] = "cuda",
        dtype: Optional[Union[str, torch.dtype]] = "auto",
        softmax_dtype: Optional[Union[str, torch.dtype]] = None,
        batch_size: Optional[Union[int, str]] = 1,
        max_batch_size: Optional[int] = 64,
        trust_remote_code: Optional[bool] = False,
        use_fast_tokenizer: Optional[bool] = True,
        add_bos_token: Optional[bool] = False,
        prefix_token_id: Optional[int] = None,
        # arguments used for splitting a model across GPUs naively.
        # only used if `parallelize=True`.
        parallelize: Optional[bool] = False,
        max_memory_per_gpu: Optional[Union[int, str]] = None,
        max_cpu_memory: Optional[Union[int, str]] = None,
        offload_folder: Optional[Union[str, os.PathLike]] = "./offload",
        # PEFT, delta weights and quantization options
        peft: Optional[str] = None,
        delta: Optional[str] = None,
        autogptq: Optional[Union[bool, str]] = False,
        gptqmodel: Optional[bool] = False,
        gguf_file: Optional[str] = None,
        **kwargs,
    ) -> None:
        """Load (or adopt) a HuggingFace model + tokenizer for evaluation.

        Args:
            pretrained: model name/path, or an already-initialized
                ``transformers.PreTrainedModel`` (most other loading args are
                then ignored and no distributed setup is performed).
            backend: force "causal" or "seq2seq"; "default" auto-detects.
            tokenizer: name/path or tokenizer object; defaults to `pretrained`.
            truncation / logits_cache / max_length: inference behavior knobs.
            device / dtype / softmax_dtype: placement and precision.
            batch_size: an int, or "auto"/"auto:N" for automatic batch sizing
                (the optional suffix is the re-detection schedule).
            parallelize & memory args: naive multi-GPU model sharding via
                accelerate's `device_map` machinery.
            peft / delta / autogptq / gptqmodel / gguf_file: adapter, delta
                weight and quantization loading options.
            prefix_token_id: custom token id prepended for loglikelihood.
            **kwargs: forwarded to model construction (`_create_model`).
        """
        super().__init__()
        # optionally: take in an already-initialized transformers.PreTrainedModel
        if not isinstance(pretrained, str):
            eval_logger.warning(
                "`pretrained` model kwarg is not of type `str`. Many other model arguments may be ignored. Please do not launch via accelerate or use `parallelize=True` if passing an existing model this way."
            )
            assert not parallelize, (
                "`parallelize=True` is not compatible with passing pre-initialized model to `pretrained`"
            )
            self._model = pretrained
            self._device = self._model.device
            self._config = self._model.config
            gpus = 0

        else:
            assert isinstance(device, str)
            assert isinstance(pretrained, str)
            assert isinstance(batch_size, (int, str))

            gpus = torch.cuda.device_count()
            # Very long timeout so multi-process runs don't die while ranks
            # wait on slow model downloads / loading.
            accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52))
            accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs])
            if accelerator.num_processes > 1:
                self.accelerator = accelerator

            if "npu" in accelerator.device.type:
                gpus = torch.npu.device_count()

            # using one process with no model parallelism
            if not (parallelize or accelerator.num_processes > 1):
                # use user-passed device
                device_list = set(
                    ["cuda", "cpu"]
                    + [f"cuda:{i}" for i in range(gpus)]
                    + ["mps", "mps:0"]
                    + [f"npu:{i}" for i in range(gpus)]
                )
                if device and device in device_list:
                    self._device = torch.device(device)
                    eval_logger.info(f"Using device '{device}'")
                    if device in ("mps", "mps:0") and version.parse(
                        torch.__version__
                    ) < version.parse("2.1"):
                        raise RuntimeError(
                            f"mps requires torch >= 2.1. You have {torch.__version__}"
                        )
                else:
                    eval_logger.info("Device not specified")
                    eval_logger.info(f"Cuda Available? {torch.cuda.is_available()}")
                    self._device = (
                        torch.device("cuda")
                        if torch.cuda.is_available()
                        else torch.device("cpu")
                    )
            else:  # Parallelism managed by accelerate
                if device != "cuda":
                    eval_logger.info(
                        f"Using `accelerate launch` or `parallelize=True`, device '{device}' will be overridden when placing model."
                    )
                # TODO: include in warning that `load_in_8bit` etc. affect this too
                self._device = (
                    self.accelerator.device
                    if hasattr(self, "accelerator")
                    else torch.device(device)
                )

            revision = str(revision)  # cast to string if not already one

            self._get_config(
                pretrained,
                revision=revision,
                trust_remote_code=trust_remote_code,
                gguf_file=gguf_file,
                subfolder=subfolder,
            )

        # determine which of 'causal' and 'seq2seq' backends to use for HF models
        self._get_backend(
            config=self.config, backend=backend, trust_remote_code=trust_remote_code
        )

        # load tokenizer so we know tokenizer vocabulary size before loading model and PEFT
        self._create_tokenizer(
            pretrained,
            tokenizer,
            revision=revision,
            subfolder=subfolder,
            trust_remote_code=trust_remote_code,
            use_fast_tokenizer=use_fast_tokenizer,
            gguf_file=gguf_file,
            add_bos_token=add_bos_token,
        )

        # if we passed `pretrained` as a string, initialize our model now
        if isinstance(pretrained, str):
            self._create_model(
                pretrained=pretrained,
                revision=revision,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                parallelize=parallelize,
                gpus=gpus,
                max_memory_per_gpu=max_memory_per_gpu,
                max_cpu_memory=max_cpu_memory,
                offload_folder=offload_folder,
                peft=peft,
                delta=delta,
                autogptq=autogptq,
                gptqmodel=gptqmodel,
                gguf_file=gguf_file,
                quantization_config=getattr(self.config, "quantization_config", None),
                subfolder=subfolder,
                **kwargs,
            )

        # access self._model through self.model property outside this method
        if isinstance(self.model, torch.nn.Module):
            self.model.eval()
            self.model.tie_weights()

        self.truncation = truncation
        self.logits_cache = logits_cache
        self.vocab_size = self.tokenizer.vocab_size
        # select (or create) a pad token to use
        self.tokenizer = configure_pad_token(self.tokenizer, model_config=self.config)

        self.add_bos_token = add_bos_token
        if "gemma" in getattr(self.config, "model_type", ""):
            self.add_bos_token = True
            eval_logger.info(
                f"Model type is '{self.config.model_type}', part of the Gemma family--a BOS token will be used as Gemma underperforms without it."
            )

        self._max_length = max_length
        self.pretrained = pretrained
        self.delta = delta
        self.peft = peft
        self.revision = revision
        self.batch_schedule = 1
        self.batch_sizes = {}
        self.max_batch_size = max_batch_size
        self.softmax_dtype = (
            get_dtype(softmax_dtype) if softmax_dtype is not None else None
        )

        # "auto" or "auto:N" enables automatic batch-size detection.
        if str(batch_size).startswith("auto"):
            batch_size = batch_size.split(":")
            self.batch_size_per_gpu = batch_size[0]
            self.batch_schedule = float(batch_size[1]) if len(batch_size) > 1 else 1
        else:
            self.batch_size_per_gpu = int(batch_size)

        if isinstance(pretrained, str):
            if gpus >= 1 or str(self.device) == "mps":
                # TODO: can remove this whole snippet except in the mps case, perhaps?
                if not (parallelize or autogptq or hasattr(self, "accelerator")):
                    # place model onto device requested manually,
                    # if not using HF Accelerate or device_map
                    # or any other option that preloads model onto device
                    try:
                        self.model.to(self.device)
                    except ValueError:
                        eval_logger.debug(
                            "Failed to place model onto specified device. This may be because the model is quantized via `bitsandbytes` or `device_map` is provided. If the desired GPU is being used, this message is safe to ignore."
                        )
            # multigpu data-parallel support when launched with accelerate
            if gpus > 1:
                if accelerator.num_processes > 1:
                    if parallelize:
                        eval_logger.warning(
                            "You are both using a HF Accelerate `device_map` (`--model_args parallelize=True`) and launching via `accelerate launch`. This will attempt to do model and data parallelism depending on the resources available."
                        )
                    elif gpus > accelerator.num_processes:
                        eval_logger.warning(
                            "WARNING: The number of total system GPUs does not match the number of spawned processes. "
                            "If you would like to use data parallelism, please launch the script "
                            "with 'accelerate launch *script*'. "
                            f"Current run will proceed with {accelerator.num_processes} devices."
                        )
                    if self.accelerator.is_local_main_process:
                        eval_logger.info(
                            f"Using {gpus} devices with data parallelism"
                        )

                    self._device = torch.device(f"{accelerator.device}")
                    self.accelerator = accelerator

                    self._rank = self.accelerator.local_process_index
                    self._world_size = self.accelerator.num_processes
                else:
                    # if we aren't launching via accelerate, ditch
                    self._rank = 0
                    self._world_size = 1
        else:
            # if a PreTrainedModel was passed into HFLM, we forgo distributed setup.
            eval_logger.warning(
                "Passed an already-initialized model through `pretrained`, assuming single-process call to evaluate() or custom distributed integration"
            )
            self._rank = 0
            self._world_size = 1

        self.custom_prefix_token_id = prefix_token_id
        if prefix_token_id is not None:
            eval_logger.info(
                f"Loglikelihood prefix token id used in evaluation: {self.prefix_token_id}"
            )
304
+
305
+ def _get_accelerate_args(
306
+ self,
307
+ parallelize: Optional[bool] = None,
308
+ device_map: Optional[str] = "auto",
309
+ max_memory_per_gpu: Optional[Union[int, str]] = None,
310
+ max_cpu_memory: Optional[Union[int, str]] = None,
311
+ offload_folder: Optional[str] = "./offload",
312
+ gpus: Optional[int] = None,
313
+ ) -> dict:
314
+ """Returns the kwargs needed to apply `accelerate` in `AutoModel.from_pretrained`."""
315
+ num_local_processes = int(os.environ.get("LOCAL_WORLD_SIZE", 1))
316
+ num_machines = int(os.environ.get("WORLD_SIZE", 0)) // num_local_processes
317
+ if (
318
+ num_machines == 0
319
+ and hasattr(self, "accelerator")
320
+ and self.accelerator is not None
321
+ ):
322
+ eval_logger.info(
323
+ "We are not in a distributed setting for accelerate. Setting model_parallel to False."
324
+ )
325
+ parallelize = False
326
+
327
+ if parallelize is None:
328
+ # If parallelism is unset by the user, we automatically assign model parallelism
329
+ # if enough extra GPUs are available
330
+ max_memory_all_gpus = get_max_memory()
331
+ # We just want gpu, not cpu, max memory
332
+ if "cpu" in max_memory_all_gpus:
333
+ del max_memory_all_gpus["cpu"]
334
+ parallelize = bool(num_local_processes < len(max_memory_all_gpus))
335
+ eval_logger.info(
336
+ f"Setting model parallel to {parallelize} since "
337
+ f"the number of local processes is {num_local_processes} "
338
+ f"and the number of GPUs is {len(max_memory_all_gpus)}"
339
+ )
340
+
341
+ args = {}
342
+ if parallelize: # Model parallelism will be used
343
+ max_memory = {}
344
+ if max_memory_per_gpu is not None: # Using the provided memory requirements
345
+ max_memory_per_gpu_map = {
346
+ device_idx: max_memory_per_gpu for device_idx in range(gpus)
347
+ }
348
+ else: # Estimating the possible memory requirements
349
+ max_memory_all_gpus = get_max_memory()
350
+ if "cpu" in max_memory_all_gpus:
351
+ del max_memory_all_gpus["cpu"]
352
+ if not hasattr(self, "accelerator"):
353
+ max_memory_per_gpu_map = {
354
+ k: v for k, v in max_memory_all_gpus.items()
355
+ }
356
+ else:
357
+ # use only 1 / num_processes of the GPUs if we are running under accelerate launch
358
+ max_memory_per_gpu_map = {
359
+ k: v
360
+ for k, v in max_memory_all_gpus.items()
361
+ if k % num_local_processes
362
+ == (self.accelerator.process_index % num_local_processes)
363
+ }
364
+ args["max_memory"] = max_memory_per_gpu_map
365
+ args["device_map"] = "auto" if device_map is None else device_map
366
+ eval_logger.info(
367
+ f"Model parallel was set to True, setting max memory per GPU to {max_memory_per_gpu_map} and device map to {args.get('device_map')}"
368
+ )
369
+
370
+ if max_cpu_memory is not None:
371
+ max_memory["cpu"] = max_cpu_memory
372
+
373
+ args["offload_folder"] = offload_folder
374
+ elif (
375
+ device_map is None
376
+ ): # No model parallelism, we use the default provided device for our model
377
+ if hasattr(self, "accelerator"):
378
+ device_map = {"": f"{self.accelerator.device}"}
379
+ else:
380
+ device_map = {"": str(self.device)}
381
+ args["max_memory"] = None
382
+ args["device_map"] = device_map
383
+ eval_logger.info(
384
+ f"Model parallel was set to False, max memory was not set, and device map was set to {device_map}"
385
+ )
386
+ else:
387
+ args["max_memory"] = None
388
+ args["device_map"] = None
389
+ eval_logger.info("Model parallel was set to False.")
390
+
391
+ return args
392
+
393
+ @property
394
+ def config(self):
395
+ # return the associated transformers.AutoConfig for the given pretrained model.
396
+ return self._config
397
+
398
+ @property
399
+ def model(self):
400
+ # returns the model, unwrapping it if using Accelerate
401
+ if hasattr(self, "accelerator"):
402
+ return self.accelerator.unwrap_model(self._model)
403
+ else:
404
+ return self._model
405
+
406
+ @property
407
+ def eot_token_id(self):
408
+ # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
409
+ return self.tokenizer.eos_token_id
410
+
411
+ @property
412
+ def prefix_token_id(self):
413
+ # it is used as prefix for loglikelihood
414
+ if self.custom_prefix_token_id is not None:
415
+ return self.custom_prefix_token_id
416
+ if self.tokenizer.bos_token_id is not None:
417
+ return self.tokenizer.bos_token_id
418
+ return self.tokenizer.eos_token_id
419
+
420
+ @property
421
+ def max_length(self):
422
+ if self._max_length: # if max length manually set, return it
423
+ return self._max_length
424
+ seqlen_config_attrs = ("n_positions", "max_position_embeddings", "n_ctx")
425
+ for attr in seqlen_config_attrs:
426
+ if hasattr(self.model.config, attr):
427
+ return getattr(self.model.config, attr)
428
+ if hasattr(self.tokenizer, "model_max_length"):
429
+ if self.tokenizer.model_max_length == 1000000000000000019884624838656:
430
+ return self._DEFAULT_MAX_LENGTH
431
+ return self.tokenizer.model_max_length
432
+ return self._DEFAULT_MAX_LENGTH
433
+
434
+ @property
435
+ def max_gen_toks(self) -> int:
436
+ return 256
437
+
438
+ @property
439
+ def batch_size(self):
440
+ return self.batch_size_per_gpu
441
+
442
+ @property
443
+ def device(self):
444
+ return self._device
445
+
446
+ @property
447
+ def rank(self):
448
+ return self._rank
449
+
450
+ @property
451
+ def world_size(self):
452
+ return self._world_size
453
+
454
+ @property
455
+ def tokenizer_name(self) -> str:
456
+ return self.tokenizer.name_or_path.replace("/", "__")
457
+
458
    def _get_backend(
        self,
        config: Union[transformers.PretrainedConfig, transformers.AutoConfig],
        backend: Literal["default", "causal", "seq2seq"] = "default",
        trust_remote_code: Optional[bool] = False,
    ) -> None:
        """
        Helper method during initialization.
        Determines the backend ("causal" (decoder-only) or "seq2seq" (encoder-decoder)) model type to be used.
        sets `self.AUTO_MODEL_CLASS` appropriately if not already set.

        :param config: the loaded model config, consulted via its `model_type` attribute
        :param backend: "causal"/"seq2seq" forces that backend; "default" auto-detects
        :param trust_remote_code: only used to soften the warning when auto-detection fails

        **If not calling HFLM.__init__() or HFLM._get_backend() within a subclass of HFLM,
        user must set `self.backend` to be either "causal" or "seq2seq" manually!**
        """

        assert backend in ["default", "causal", "seq2seq"]

        if backend != "default":
            # if we've settled on non-default backend, use that manually
            if backend == "causal":
                self.backend = backend
            elif backend == "seq2seq":
                self.backend = backend
            eval_logger.info(
                f"Overrode HF model backend type, and using type '{self.backend}'"
            )
        else:
            # determine and use the default HF backend for this model, based on its config + metadata.
            if (
                getattr(config, "model_type")
                in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
            ):
                # first check if model type is listed under seq2seq models, since some
                # models like MBart are listed in both seq2seq and causal mistakenly in HF transformers.
                # these special cases should be treated as seq2seq models.
                self.backend = "seq2seq"
                eval_logger.debug(f"Using model type '{self.backend}'")
            elif (
                # NOTE(review): this branch reads `self.config` while the branch above
                # reads the `config` parameter — presumably both refer to the same
                # object (callers passing `self.config`); confirm before relying on it.
                getattr(self.config, "model_type") in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
            ):
                self.backend = "causal"
                eval_logger.debug(f"Using model type '{self.backend}'")
            else:
                if not trust_remote_code:
                    eval_logger.warning(
                        "HF model type is neither marked as CausalLM or Seq2SeqLM. \
                    This is expected if your model requires `trust_remote_code=True` but may be an error otherwise."
                        "Setting backend to causal"
                    )
                # if model type is neither in HF transformers causal or seq2seq model registries
                # then we default to assuming AutoModelForCausalLM
                self.backend = "causal"
                eval_logger.info(
                    f"Model type cannot be determined. Using default model type '{self.backend}'"
                )

        # map the chosen backend onto an AutoModel class, unless a subclass
        # already pinned AUTO_MODEL_CLASS itself
        if self.AUTO_MODEL_CLASS is None:
            if self.backend == "causal":
                self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM
            elif self.backend == "seq2seq":
                self.AUTO_MODEL_CLASS = transformers.AutoModelForSeq2SeqLM
519
+
520
    def _get_config(
        self,
        pretrained: str,
        revision: str = "main",
        trust_remote_code: bool = False,
        gguf_file: Optional[str] = None,
        subfolder: str = "",
    ) -> None:
        """Load the HuggingFace `AutoConfig` for `pretrained` and store it on
        `self._config`.

        Note: this returns None; the loaded config is exposed via the `config`
        property, not returned here.
        """
        self._config = transformers.AutoConfig.from_pretrained(
            pretrained,
            revision=revision,
            trust_remote_code=trust_remote_code,
            gguf_file=gguf_file,
            subfolder=subfolder,
        )
536
+
537
    def _create_model(
        self,
        pretrained: str,
        revision: Optional[str] = "main",
        dtype: Optional[Union[str, torch.dtype]] = "auto",
        trust_remote_code: Optional[bool] = False,
        # arguments used for splitting a model across GPUs naively.
        # only used if `parallelize=True`.
        # (accelerate naive PP (device_map) options)
        parallelize: Optional[bool] = False,
        gpus: Optional[int] = None,
        max_memory_per_gpu: Optional[Union[int, str]] = None,
        max_cpu_memory: Optional[Union[int, str]] = None,
        offload_folder: Optional[str] = "./offload",
        # PEFT, delta weights and quantization options
        peft: Optional[str] = None,
        delta: Optional[str] = None,
        autogptq: Optional[Union[bool, str]] = False,
        gptqmodel: Optional[bool] = False,
        gguf_file: Optional[str] = None,
        quantization_config: Optional[Dict[str, Any]] = None,
        subfolder: str = "",
        **kwargs,
    ) -> None:
        """
        Initializes an HF or HF-compatible PreTrainedModel from scratch
        inside HFLM, using the kwargs passed into self.__init__().

        Also handles functionality such as AutoGPTQ usage and PEFT wrapping.

        The loaded model is stored on `self._model`; nothing is returned.

        For future similar extensions to AutoGPTQ that are not core to HF's ecosystem,
        (such as PyTorch models that are nearly, but not quite, fully mirroring
        HF's public interface relied on in this HFLM class)
        please consider subclassing HFLM and overriding this and other methods as needed.
        """

        model_kwargs = kwargs if kwargs else {}

        # merge in device_map / max_memory decided by the accelerate helper
        model_kwargs.update(
            self._get_accelerate_args(
                parallelize=parallelize,
                device_map=kwargs.get("device_map", None),
                max_memory_per_gpu=max_memory_per_gpu,
                max_cpu_memory=max_cpu_memory,
                offload_folder=offload_folder,
                gpus=gpus,
            )
        )

        if not autogptq and not gptqmodel:
            # plain HF loading path (optionally 4-bit via bitsandbytes kwargs)
            if model_kwargs.get("load_in_4bit", None):
                assert transformers.__version__ >= "4.30.0", (
                    "load_in_4bit requires transformers >= 4.30.0"
                )
            if transformers.__version__ >= "4.30.0":
                if model_kwargs.get("load_in_4bit", None):
                    if model_kwargs.get("bnb_4bit_compute_dtype", None):
                        # normalize a string dtype (e.g. "bfloat16") to torch.dtype
                        model_kwargs["bnb_4bit_compute_dtype"] = get_dtype(
                            model_kwargs["bnb_4bit_compute_dtype"]
                        )

            self._model = self.AUTO_MODEL_CLASS.from_pretrained(
                pretrained,
                revision=revision,
                torch_dtype=get_dtype(dtype),
                trust_remote_code=trust_remote_code,
                gguf_file=gguf_file,
                quantization_config=quantization_config,
                subfolder=subfolder,
                **model_kwargs,
            )
        else:
            # quantized loading path: exactly one of autogptq / gptqmodel
            if autogptq and gptqmodel:
                raise ValueError(
                    "Cannot use both 'autogptq' and 'gptqmodel' options at the same time."
                )

            if autogptq:
                try:
                    from auto_gptq import AutoGPTQForCausalLM
                except ModuleNotFoundError as exception:
                    raise type(exception)(
                        "Tried to load auto_gptq, but auto-gptq is not installed ",
                        "please install auto-gptq via pip install lm-eval[gptq] or pip install -e .[gptq]",
                    )

                self._model = AutoGPTQForCausalLM.from_quantized(
                    pretrained,
                    trust_remote_code=trust_remote_code,
                    # autogptq may be True (use defaults) or a path to the quantized weights
                    model_basename=None if autogptq is True else Path(autogptq).stem,
                    use_safetensors=True
                    if autogptq is True
                    else autogptq.endswith(".safetensors"),
                    **model_kwargs,
                )

            if gptqmodel:
                try:
                    from gptqmodel import GPTQModel
                except ModuleNotFoundError as exception:
                    raise type(exception)(
                        "Tried to load gptqmodel, but gptqmodel is not installed ",
                        "please install gptqmodel via `pip install gptqmodel --no-build-isolation` or `pip install lm-eval[gptqmodel] --no-build-isolation`",
                    )

                self._model = GPTQModel.from_quantized(
                    pretrained, trust_remote_code=trust_remote_code, **model_kwargs
                )

        # optional post-processing: PEFT adapter or additive delta weights
        if peft and delta:
            raise ValueError(
                "Cannot use both 'peft' and 'delta' options at the same time."
            )

        if peft:
            if model_kwargs.get("load_in_4bit", None):
                if version.parse(PEFT_VERSION) < version.parse("0.4.0"):
                    raise AssertionError("load_in_4bit requires peft >= 0.4.0")
            if self._model.config.vocab_size != len(self.tokenizer):
                # resize model for LoRAs with added tokens
                eval_logger.info(
                    f"Model config indicates vocab_size='{self._model.config.vocab_size}', but found tokenizer with vocab size '{len(self.tokenizer)}'. Resizing model embedding layer..."
                )
                self._model.resize_token_embeddings(len(self.tokenizer))
            self._model = PeftModel.from_pretrained(
                self._model, peft, revision=revision
            )
        elif delta:
            if autogptq:
                eval_logger.warning(
                    "Delta weights might trigger unexpected behavior when used with AutoGPTQ."
                )
            # load the delta model and add its weights into the base model in place
            _model_delta = self.AUTO_MODEL_CLASS.from_pretrained(
                delta,
                revision=revision,
                torch_dtype=get_dtype(dtype),
                trust_remote_code=trust_remote_code,
                **model_kwargs,
            )
            for name, param in self._model.state_dict().items():
                try:
                    param.data += _model_delta.state_dict()[name]
                except KeyError:
                    raise KeyError(f"Delta model is missing weights for layer: {name}")
                except Exception as e:
                    raise RuntimeError(
                        f"Failed to add delta weights to layer {name}. Error: {e}"
                    )

            del _model_delta

        return None
689
+
690
    def _create_tokenizer(
        self,
        pretrained: Union[str, transformers.PreTrainedModel],
        tokenizer: Optional[
            Union[
                str,
                transformers.PreTrainedTokenizer,
                transformers.PreTrainedTokenizerFast,
            ]
        ],
        revision: Optional[str] = "main",
        trust_remote_code: Optional[bool] = False,
        use_fast_tokenizer: Optional[bool] = True,
        gguf_file: Optional[str] = None,
        add_bos_token: Optional[bool] = False,
        subfolder: Optional[str] = "",
    ) -> None:
        """
        Helper method during initialization.

        Create a tokenizer object corresponding to the correct
        tokenizer for value of `pretrained`, or use the pre-initialized tokenizer passed.

        The result is stored on `self.tokenizer`; nothing is returned.
        """
        kwargs = {
            "revision": revision,
            "trust_remote_code": trust_remote_code,
        }

        # gguf format embeds tokenizer and is not compatible with hf tokenizer `use_fast` param
        if gguf_file is not None:
            kwargs["gguf_file"] = gguf_file
        else:
            kwargs["use_fast"] = use_fast_tokenizer

        if add_bos_token:
            kwargs["add_bos_token"] = True

        if subfolder:
            kwargs["subfolder"] = subfolder

        if tokenizer:
            # caller supplied either a tokenizer name or an already-built tokenizer
            if isinstance(tokenizer, str):
                self.tokenizer = transformers.AutoTokenizer.from_pretrained(
                    tokenizer, **kwargs
                )
            else:
                assert isinstance(
                    tokenizer, transformers.PreTrainedTokenizer
                ) or isinstance(tokenizer, transformers.PreTrainedTokenizerFast)
                self.tokenizer = tokenizer
        else:
            # Get tokenizer based on 'pretrained'
            if isinstance(pretrained, str):
                model_name = pretrained
            else:
                # get the HF hub name via accessor on model
                model_name = self.model.name_or_path
            self.tokenizer = transformers.AutoTokenizer.from_pretrained(
                model_name, **kwargs
            )
        return None
751
+
752
    def _detect_batch_size(self, requests=None, pos: int = 0):
        """Empirically find the largest batch size that fits in memory.

        Runs dummy forward passes (starting at `self.max_batch_size`, halving on
        OOM via accelerate's `find_executable_batch_size`) with sequence lengths
        derived from `requests[pos]` if given, else `self.max_length`.
        Under multi-GPU, the minimum batch size across ranks is used so all
        ranks agree.
        """
        if requests:
            # size the probe tensors from the actual request at `pos`
            _, context_enc, continuation_enc = requests[pos]
            max_length = len(
                (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1]
            )
            max_context_enc = len(context_enc[-(self.max_length + 1) :])
            max_cont_enc = len(continuation_enc[-(self.max_length + 1) :])
        else:
            max_length = self.max_length
            max_context_enc = max_length
            max_cont_enc = max_length

        # if OOM, then halves batch_size and tries again
        @find_executable_batch_size(starting_batch_size=self.max_batch_size)
        def forward_batch(batch_size):
            if self.backend == "seq2seq":
                length = max(max_context_enc, max_cont_enc)
                batched_conts = torch.ones(
                    (batch_size, length), device=self.device
                ).long()
                test_batch = torch.ones((batch_size, length), device=self.device).long()
                call_kwargs = {
                    "attn_mask": test_batch,
                    "labels": batched_conts,
                }
            else:
                call_kwargs = {}
                test_batch = torch.ones(
                    (batch_size, max_length), device=self.device
                ).long()
            # repeat a few times so transient allocations surface an OOM reliably
            for _ in range(5):
                out = F.log_softmax(  # noqa: F841
                    self._model_call(test_batch, **call_kwargs),
                    dim=-1,
                    dtype=self.softmax_dtype,
                )

            return batch_size

        try:
            batch_size = forward_batch()
        except RuntimeError as e:
            if "No executable batch size found" in str(e):
                # even batch size 1 OOMed during probing; fall back to 1 anyway
                batch_size = 1
            else:
                raise

        if self.world_size > 1:
            # if multi-GPU, always take minimum over all selected batch sizes
            max_rnk_bs = torch.tensor([batch_size], device=self.device)
            gathered = (
                self.accelerator.gather(max_rnk_bs).cpu().detach().numpy().tolist()
            )
            batch_size = min(gathered)
            clear_torch_cache()
            return batch_size

        clear_torch_cache()
        return batch_size
812
+
813
+ def tok_encode(
814
+ self, string: str, left_truncate_len=None, add_special_tokens=None
815
+ ) -> List[int]:
816
+ """ """
817
+ # default for None - empty dict, use predefined tokenizer param
818
+ # used for all models except for CausalLM or predefined value
819
+ special_tokens_kwargs = {}
820
+
821
+ # by default for CausalLM - false or self.add_bos_token is set
822
+ if add_special_tokens is None:
823
+ if self.backend == "causal":
824
+ special_tokens_kwargs = {
825
+ "add_special_tokens": False or self.add_bos_token
826
+ }
827
+ # otherwise the method explicitly defines the value
828
+ else:
829
+ special_tokens_kwargs = {"add_special_tokens": add_special_tokens}
830
+
831
+ encoding = self.tokenizer.encode(string, **special_tokens_kwargs)
832
+
833
+ # left-truncate the encoded context to be at most `left_truncate_len` tokens long
834
+ if left_truncate_len:
835
+ encoding = encoding[-left_truncate_len:]
836
+
837
+ return encoding
838
+
839
+ def tok_batch_encode(
840
+ self,
841
+ strings: List[str],
842
+ padding_side: str = "left",
843
+ left_truncate_len: int = None,
844
+ truncation: bool = False,
845
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
846
+ # encode a batch of strings. converts to tensors and pads automatically, unlike tok_encode.
847
+ old_padding_side = self.tokenizer.padding_side
848
+ self.tokenizer.padding_side = padding_side
849
+
850
+ add_special_tokens = {}
851
+ if self.backend == "causal":
852
+ add_special_tokens = {"add_special_tokens": False or self.add_bos_token}
853
+
854
+ encoding = self.tokenizer(
855
+ strings,
856
+ truncation=truncation,
857
+ padding="longest",
858
+ return_tensors="pt",
859
+ **add_special_tokens,
860
+ )
861
+ if left_truncate_len:
862
+ original_lengths = encoding["input_ids"].size(1)
863
+ if original_lengths > left_truncate_len:
864
+ eval_logger.warn(
865
+ f"Left truncation applied. Original sequence length was {original_lengths}, "
866
+ f"truncating to last {left_truncate_len} tokens. Some content will be lost.",
867
+ )
868
+ encoding["input_ids"] = encoding["input_ids"][:, -left_truncate_len:]
869
+ encoding["attention_mask"] = encoding["attention_mask"][
870
+ :, -left_truncate_len:
871
+ ]
872
+ self.tokenizer.padding_side = old_padding_side
873
+
874
+ return encoding["input_ids"], encoding["attention_mask"]
875
+
876
+ def tok_decode(self, tokens, skip_special_tokens=True):
877
+ return self.tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)
878
+
879
    def _model_call(self, inps, attn_mask=None, labels=None):
        """
        Run a forward pass and return raw logits (no gradient tracking).

        :param inps: torch.Tensor
            A torch tensor of shape [batch, (sequence_ctx + sequence_cont)] or of shape
            [batch, sequence_ctx]. the size of sequence may vary from call to call
        :param attn_mask: torch.Tensor, optional
            A torch tensor of shape [batch, (sequence_ctx + sequence_cont)]. Only passed
            (and must be passed) if self.AUTO_MODEL_CLASS is transformers.AutoModelForSeq2SeqLM
        :param labels: torch.Tensor, optional
            A torch tensor of shape [batch, (sequence_ctx + sequence_cont)]. Only passed
            (and must be passed) if self.AUTO_MODEL_CLASS is transformers.AutoModelForSeq2SeqLM
        :return
            A torch tensor of shape [batch, sequence, vocab] with the
            logits returned from the model's decoder
        """
        # inference only — never build an autograd graph here
        with torch.no_grad():
            if attn_mask is not None or labels is not None:
                # seq2seq path: both mask and labels must be provided together
                assert attn_mask is not None and labels is not None
                assert self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM
                return self.model(
                    input_ids=inps, attention_mask=attn_mask, labels=labels
                ).logits
            else:
                # decoder-only (or vision-to-seq) path: plain forward on input ids
                assert self.AUTO_MODEL_CLASS in (
                    transformers.AutoModelForCausalLM,
                    transformers.AutoModelForVision2Seq,
                )
                return self.model(inps).logits
907
+
908
    def _model_generate(self, context, max_length, stop, **generation_kwargs):
        """Generate continuations for `context` up to `max_length` total tokens,
        halting early on any of the `stop` sequences.

        Defaults to greedy decoding (do_sample=False) when temperature is unset/0.0.
        """
        # temperature = 0.0 if not set
        # if do_sample is false and temp==0.0:
        # remove temperature, as do_sample=False takes care of this
        # and we don't want a warning from HF
        generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0)
        do_sample = generation_kwargs.get("do_sample", None)

        # The temperature has to be a strictly positive float -- if it is 0.0, use greedy decoding strategies
        if generation_kwargs.get("temperature") == 0.0 and do_sample is None:
            generation_kwargs["do_sample"] = do_sample = False

        if do_sample is False and generation_kwargs.get("temperature") == 0.0:
            generation_kwargs.pop("temperature")
        # build stopping criteria
        stopping_criteria = stop_sequences_criteria(
            self.tokenizer, stop, context.shape[1], context.shape[0]
        )
        return self.model.generate(
            input_ids=context,
            max_length=max_length,
            stopping_criteria=stopping_criteria,
            pad_token_id=self.tokenizer.pad_token_id,
            use_cache=True,
            **generation_kwargs,
        )
934
+
935
    def _select_cont_toks(
        self, logits: torch.Tensor, contlen: int = None, inplen: int = None
    ) -> torch.Tensor:
        """Slice out only the continuation-token logits from a sequence of logits.

        :param logits: [seq, vocab] logits for one example
        :param contlen: number of continuation tokens (required for both backends)
        :param inplen: input length including context (required only for causal)
        :return: [contlen, vocab] logits covering just the continuation
        """
        if self.backend == "causal":
            assert contlen and inplen, (
                "Must pass input len and cont. len to select scored logits for causal LM"
            )
            # discard right-padding.
            # also discard the input/context tokens. we'll only score continuations.
            logits = logits[inplen - contlen : inplen]
        elif self.backend == "seq2seq":
            assert contlen and not inplen, (
                "Selecting scored logits for Seq2SeqLM requires only cont. len"
            )
            # only discard right-padding.
            # the logits input to this fn only contain decoder-side tokens.
            logits = logits[:contlen]

        return logits
954
+
955
    def loglikelihood_rolling(
        self, requests: List[Instance], disable_tqdm: bool = False
    ) -> List[float]:
        """Compute the full rolling log-likelihood of each request's string.

        Each string is split into disjoint max-length windows, all windows from
        all requests are scored in batches, and per-request totals are summed
        back together. Returns one float (total loglikelihood) per request.
        """
        adaptive_batch_size = None
        if self.batch_size == "auto":
            # using rolling window with maximum context
            print("Passed argument batch_size = auto. Detecting largest batch size")
            batch_size = self._detect_batch_size()
            print(f"Determined Largest batch size: {batch_size}")
            adaptive_batch_size = batch_size

        # First, collect all windows from all requests
        all_windows = []  # List of (request_idx, window) tuples
        request_window_counts = []  # Track number of windows per request

        for req_idx, (string,) in enumerate(
            tqdm(
                [req.args for req in requests],
                disable=(disable_tqdm or (self.rank != 0)),
            )
        ):
            rolling_token_windows: List[Tuple[List[int], List[int]]] = list(
                map(
                    utils.make_disjoint_window,
                    utils.get_rolling_token_windows(
                        token_list=self.tok_encode(string),
                        prefix_token=self.prefix_token_id,
                        max_seq_len=self.max_length,
                        context_len=1,
                    ),
                )
            )

            # TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder, in seq2seq case
            # prepend None as the cache key: these windows must not be cached
            # individually under "loglikelihood" (see _loglikelihood_tokens)
            windows = [(None,) + x for x in rolling_token_windows]

            # Store windows with their request index
            all_windows.extend((req_idx, window) for window in windows)
            request_window_counts.append(len(windows))

        # Handle distributed case padding
        # (every rank must run the same number of batches, so shorter ranks
        # duplicate their first window and drop the extra results afterwards)
        pad_amnt = 0
        if self.world_size > 1:
            mytensor = torch.tensor(len(all_windows), device=self.device)
            gathered = self.accelerator.gather(mytensor).cpu().detach().numpy().tolist()
            pad_amnt = max(gathered) - gathered[self.rank]
            if pad_amnt > 0:
                all_windows += pad_amnt * [all_windows[0]]

        all_nlls = []
        batch_size = adaptive_batch_size or self.batch_size
        for i in range(0, len(all_windows), batch_size):
            batch = all_windows[i : i + batch_size]
            # Extract just the windows for processing, keeping track of request indices
            batch_indices, batch_windows = zip(*batch)

            batch_nlls = self._loglikelihood_tokens(
                requests=batch_windows,
                disable_tqdm=False,
                override_bs=len(batch_windows),
            )
            # Store results with their request indices
            all_nlls.extend(zip(batch_indices, batch_nlls))

        # Remove padding if necessary
        if (self.world_size > 1) and (pad_amnt > 0):
            all_nlls = all_nlls[:-pad_amnt]

        # Reconstruct per-request loglikelihoods
        loglikelihoods = []
        current_idx = 0
        for window_count in request_window_counts:
            # Get all nlls for this request
            request_nlls = all_nlls[current_idx : current_idx + window_count]
            # Sum up the nlls for this request (discarding is_greedy)
            request_total = sum(nll[0] for _, nll in request_nlls)
            loglikelihoods.append(request_total)
            current_idx += window_count

            # cache at the per-example level here, since the individual windows
            # were scored with a None cache key above
            string = requests[len(loglikelihoods) - 1].args[0]
            self.cache_hook.add_partial(
                "loglikelihood_rolling", (string,), request_total
            )

        return loglikelihoods
1040
+
1041
    def _batch_scheduler(self, pos, n_reordered_requests):
        """Return the batch size to use at position `pos`, re-detecting it
        `self.batch_schedule` times over the run (memoized in `self.batch_sizes`).

        :param pos: index of the current request in the reordered list
        :param n_reordered_requests: the reordered request list (its length sets the schedule)
        """
        # which of the `batch_schedule` segments `pos` falls in
        sched = pos // int(len(n_reordered_requests) / self.batch_schedule)
        if sched in self.batch_sizes:
            return self.batch_sizes[sched]
        if (len(self.batch_sizes) > 1) and (
            self.batch_sizes[sched - 1] == self.max_batch_size
        ):
            # if previous batch size is already maximal, skip recomputation
            self.batch_sizes[sched] = self.max_batch_size
            return self.batch_sizes[sched]
        print(
            f"Passed argument batch_size = auto:{self.batch_schedule}. Detecting largest batch size"
        )
        self.batch_sizes[sched] = self._detect_batch_size(n_reordered_requests, pos)
        print(f"Determined largest batch size: {self.batch_sizes[sched]}")
        return self.batch_sizes[sched]
1057
+
1058
    def _loglikelihood_tokens(
        self,
        requests: List[Tuple[Tuple[str, str], List[int], List[int]]],
        disable_tqdm: bool = False,
        override_bs: int = None,
    ) -> List[Tuple[float, bool]]:
        """Score (context, continuation) token pairs.

        For each request returns (sum of continuation log-probs, whether greedy
        argmax decoding would reproduce the continuation exactly).

        :param requests: tuples of (cache key or None, context token ids, continuation token ids)
        :param disable_tqdm: suppress the progress bar
        :param override_bs: force this batch size (used by loglikelihood_rolling)
        """
        # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
        res = []

        def _collate(req: Tuple[Tuple[str, str], List[int], List[int]]):
            """Defines the key for the sorted method"""
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over not underestimates, which is more useful for planning
            # - to know the size of a batch when going through the list, you know the first one is always the batch
            #   padded context length. this is useful to simplify the batching logic and more importantly to make
            #   automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end

            toks = req[1] + req[2]
            return -len(toks), tuple(toks)

        def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]):
            """Defines the key to group and lookup one-token continuations"""
            # Use with group_by="contexts" (optional)"
            # allows for the creation of a lookup, so we can reuse logits in case of one-token continuations.
            # speeds up some multiple-choice tasks proportionally to the number of choices.
            # groups requests by context+continuation[:-1] and infer on one request/group.
            return req[-2] + req[-1][:-1]

        re_ord = Collator(
            requests,
            sort_fn=_collate,
            group_by="contexts"
            if self.backend == "causal" and self.logits_cache
            else None,
            group_fn=_lookup_one_token_cont,
        )

        # automatic (variable) batch size detection for vectorization
        # pull longest context sample from request
        n_reordered_requests = len(re_ord)
        batch_size = (
            self.batch_size
            if self.batch_size != "auto"
            else override_bs
            if override_bs is not None
            else 0
        )
        batch_fn = (
            self._batch_scheduler
            if self.batch_size == "auto"
            and n_reordered_requests > 0
            and not override_bs
            else None
        )

        chunks = re_ord.get_batched(n=batch_size, batch_fn=batch_fn)
        pbar = tqdm(
            total=len(requests),
            disable=(disable_tqdm or (self.rank != 0)),
            desc="Running loglikelihood requests",
        )
        for chunk in chunks:
            inps = []
            cont_toks_list = []
            inplens = []

            conts = []
            encoder_attns = []

            padding_len_inp = None
            padding_len_cont = None
            # because vectorizing is annoying, we first convert each (context, continuation) pair to padded
            # tensors, then we pack them together into a batch, call the model, and then pick it all apart
            # again because vectorizing is annoying

            for _, context_enc, continuation_enc in chunk:
                # sanity check
                assert len(context_enc) > 0
                assert len(continuation_enc) > 0
                assert len(continuation_enc) <= self.max_length

                # how this all works (illustrated on a causal decoder-only setup):
                #          CTX      CONT
                # inp    0 1 2 3|4 5 6 7 8 9   <- last token is deleted by inp[:, :-1]
                # model  \               \
                # logits   1 2 3|4 5 6 7 8 9   <- the ctx half gets tossed out by the
                # cont_toks      4 5 6 7 8 9      [:, -len(continuation_enc):, :self.vocab_size] slice

                # when too long to fit in context, truncate from the left
                if self.backend == "causal":
                    total_length = len(context_enc) + len(continuation_enc)
                    if total_length > self.max_length + 1:
                        eval_logger.warning(
                            f"Combined length of context ({len(context_enc)}) and continuation ({len(continuation_enc)}) "
                            f"exceeds model's maximum length ({self.max_length}). "
                            f"Truncating {total_length - self.max_length + 1} tokens from the left."
                        )
                    inp = torch.tensor(
                        (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1],
                        dtype=torch.long,
                        device=self.device,
                    )
                    (inplen,) = inp.shape
                elif self.backend == "seq2seq":
                    inp = torch.tensor(
                        (context_enc)[-self.max_length :],
                        dtype=torch.long,
                        device=self.device,
                    )
                    (inplen,) = inp.shape

                    # build encoder attn masks
                    encoder_attns.append(torch.ones_like(inp))

                    cont = torch.tensor(
                        (continuation_enc)[-self.max_length :],
                        # TODO: left-shift these?
                        # TODO: our code assumes we never end up truncating conts for either model type
                        dtype=torch.long,
                        device=self.device,
                    )
                    (contlen,) = cont.shape

                    conts.append(cont)

                    padding_len_cont = (
                        max(padding_len_cont, contlen)
                        if padding_len_cont is not None
                        else contlen
                    )

                padding_len_inp = (
                    max(padding_len_inp, inplen)
                    if padding_len_inp is not None
                    else inplen
                )

                inps.append(inp)  # [1, inp_length]
                cont_toks_list.append(continuation_enc)
                inplens.append(inplen)

            # create encoder attn mask and batched conts, if seq2seq
            call_kwargs = {}
            if self.backend == "causal":
                batched_inps = pad_and_concat(
                    padding_len_inp, inps, padding_side="right"
                )  # [batch, padding_len_inp]
            elif self.backend == "seq2seq":
                # TODO: left-pad encoder inps and mask?
                batched_inps = pad_and_concat(
                    padding_len_inp, inps
                )  # [batch, padding_len_inp]
                batched_conts = pad_and_concat(
                    padding_len_cont, conts
                )  # [batch, padding_len_cont]
                batched_encoder_mask = pad_and_concat(
                    padding_len_inp, encoder_attns
                )  # [batch, padding_len_inp]
                call_kwargs = {
                    "attn_mask": batched_encoder_mask,
                    "labels": batched_conts,
                }

            multi_logits = F.log_softmax(
                self._model_call(batched_inps, **call_kwargs),
                dim=-1,
                dtype=self.softmax_dtype,
            )  # [batch, padding_length (inp or cont), vocab]

            for (request_str, ctx_tokens, _), logits, inplen, cont_toks in zip(
                chunk, multi_logits, inplens, cont_toks_list
            ):
                # Slice to original seq length
                contlen = len(cont_toks)
                # take only logits in the continuation
                # (discard context toks if decoder-only ; discard right-padding)
                # also discards + checks for "virtual tokens" in the causal LM's input window
                # from prompt/prefix tuning tokens, if applicable
                ctx_len = (
                    inplen + (logits.shape[0] - padding_len_inp)
                    if self.backend == "causal"
                    else None
                )
                logits = self._select_cont_toks(logits, contlen=contlen, inplen=ctx_len)
                logits = logits.unsqueeze(0)  # [1, seq, vocab]

                # Check if per-token argmax is exactly equal to continuation
                greedy_tokens = logits.argmax(dim=-1)

                # check for one-token continuation cache hits.
                # noop in case group_by != "contexts" or no cache hit and returns the
                # original args. Otherwise, expands the logits batch dimension and yields each
                # batch along with matching continuation tokens and prompt strings.
                # logits -> [1, seq, vocab]
                for request_str, cont_toks, logits in re_ord.get_cache(
                    req_str=request_str,
                    cxt_toks=ctx_tokens,
                    cont_toks=cont_toks,
                    logits=logits,
                ):
                    cont_toks = torch.tensor(
                        cont_toks, dtype=torch.long, device=self.device
                    ).unsqueeze(0)  # [1, seq]
                    # Use trailing slice [-cont_toks.shape[1]:] to handle variable length cont_len (but same ctx+cont[:-1]).
                    # i.e. continuations can be sliced at diff points. Collator ensures we have sufficient greedy_tokens
                    # by choosing key with longest cont if group_by="contexts".
                    max_equal = (
                        greedy_tokens[:, -cont_toks.shape[1] :] == cont_toks
                    ).all()

                    # Obtain log-probs at the corresponding continuation token indices
                    # last_token_slice = logits[:, -1, :].squeeze(0).tolist()
                    logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(
                        -1
                    )  # [1, seq]

                    # Answer: (log prob, is-exact-match)
                    answer = (float(logits.sum()), bool(max_equal))

                    res.append(answer)

                    if request_str is not None:
                        # special case: loglikelihood_rolling produces a number of loglikelihood requests
                        # all with cache key None. instead do add_partial on the per-example level
                        # in the loglikelihood_rolling() function for those.
                        self.cache_hook.add_partial(
                            "loglikelihood", request_str, answer
                        )
                    pbar.update(1)

        pbar.close()

        # restore the caller's original request ordering
        return re_ord.get_original(res)
1292
+
1293
    def generate_until(
        self, requests: List[Instance], disable_tqdm: bool = False
    ) -> List[str]:
        """
        Generate free-form continuations for a list of `generate_until` requests.

        Requests are sorted by descending tokenized-context length, grouped by
        their generation kwargs (so e.g. greedy and temp=0.8 sampling never share
        a batch), generated batch-by-batch, truncated at stop sequences, and
        finally returned in the original request order.

        Args:
            requests: Instances whose `.args` are `(context_str, gen_kwargs_dict)`.
            disable_tqdm: suppress the progress bar (also suppressed on rank != 0).

        Returns:
            One decoded continuation string per request, in input order.
        """
        res = []

        def _collate(req: Tuple[str, dict]):
            """Defines the key for the sorted method"""
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over not underestimates, which is more useful for planning
            # - to know the size of a batch when going through the list, you know the first one is always the batch
            # padded context length. this is useful to simplify the batching logic and more importantly to make
            # automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end
            toks = self.tok_encode(req[0])
            return -len(toks), req[0]

        pbar = tqdm(
            total=len(requests),
            disable=(disable_tqdm or (self.rank != 0)),
            desc="Running generate_until requests",
        )
        adaptive_batch_size = None
        if self.batch_size == "auto":
            # using rolling window with maximum context
            print("Passed argument batch_size = auto. Detecting largest batch size")
            batch_size = self._detect_batch_size()
            print(f"Determined Largest batch size: {batch_size}")
            adaptive_batch_size = batch_size
        # for each different set of kwargs, we execute all requests, by batch.
        batch_size = (
            self.batch_size
            if self.batch_size != "auto"
            else adaptive_batch_size
            if adaptive_batch_size is not None
            else 0
        )
        # only use the dynamic batch scheduler when "auto" was requested but no
        # fixed size has been detected yet
        batch_fn = (
            self._batch_scheduler
            if self.batch_size == "auto" and not adaptive_batch_size
            else None
        )

        # we group requests by their generation_kwargs,
        # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
        # in the same batch.
        # group_fn=lambda x: x[1] -> x=(context, gen_kwargs)
        re_ords = Collator(
            [reg.args for reg in requests],
            sort_fn=_collate,
            group_by="gen_kwargs",
            group_fn=lambda x: x[1],
        )
        chunks = re_ords.get_batched(n=batch_size, batch_fn=batch_fn)
        eos = self.tok_decode(self.eot_token_id, skip_special_tokens=False)
        for chunk in chunks:
            contexts, all_gen_kwargs = zip(*chunk)
            # we assume all gen kwargs in the batch are the same
            # this is safe to assume because the `grouper` object ensures it.
            gen_kwargs = all_gen_kwargs[0]
            # unpack our keyword arguments.
            if isinstance(gen_kwargs, dict):
                kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
                # add EOS token to stop sequences
                until = handle_stop_sequences(kwargs.pop("until", None), eos=eos)
            else:
                raise ValueError(
                    f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
                )
            if "max_gen_toks" in kwargs.keys():
                max_gen_toks = kwargs.pop("max_gen_toks")
            else:
                max_gen_toks = self.max_gen_toks

            # set the max length in tokens of inputs ("context_enc")
            # NOTE(review): assumes self.backend is "causal" or "seq2seq";
            # for any other value max_ctx_len would be unbound below — confirm
            # upstream validation.
            if self.backend == "causal":
                # max len for inputs = max length, minus room to generate the max new tokens
                max_ctx_len = self.max_length - max_gen_toks
                assert max_ctx_len > 0, (
                    f"Invalid configuration: requested max tokens to generate ({max_gen_toks}) must be less than model's maximum sequence length ({self.max_length})."
                )
            elif self.backend == "seq2seq":
                # max len for inputs = encoder's whole max_length
                max_ctx_len = self.max_length

            # encode, pad, and truncate contexts for this batch
            context_enc, attn_masks = self.tok_batch_encode(
                contexts,
                left_truncate_len=max_ctx_len,
                truncation=self.truncation,
            )
            context_enc = context_enc.to(self.device)
            attn_masks = attn_masks.to(self.device)

            if "max_length" not in kwargs:
                kwargs["max_length"] = context_enc.shape[1] + max_gen_toks

            # perform batched generation
            cont = self._model_generate(
                context=context_enc,
                attention_mask=attn_masks,
                stop=until,
                **kwargs,
            )

            cont_toks_list = cont.tolist()
            for cont_toks, context in zip(cont_toks_list, contexts):
                # discard context + left-padding toks if using causal decoder-only LM
                if self.backend == "causal":
                    cont_toks = cont_toks[context_enc.shape[1] :]

                s = self.tok_decode(cont_toks)

                # use secondary stop seqs to cut off should-have-been-stopped content post-hoc
                for term in until:
                    if len(term) > 0:
                        # ignore '' separator,
                        # for seq2seq case where self.tok_decode(self.eot_token_id) = ''
                        s = s.split(term)[0]

                res.append(s)

                self.cache_hook.add_partial("generate_until", (context, gen_kwargs), s)
                pbar.update(1)
        # reorder this group of results back to original unsorted form
        res = re_ords.get_original(res)

        pbar.close()

        return res
1422
+
1423
+ def apply_chat_template(
1424
+ self, chat_history: List[Dict[str, str]], add_generation_prompt: bool = True
1425
+ ) -> str:
1426
+ """
1427
+ Method to apply a chat template to a list of chat history between user and model.
1428
+ """
1429
+ try:
1430
+ chat_templated = self.tokenizer.apply_chat_template(
1431
+ chat_history,
1432
+ tokenize=False,
1433
+ add_generation_prompt=add_generation_prompt,
1434
+ continue_final_message=not add_generation_prompt,
1435
+ )
1436
+ except jinja2.exceptions.TemplateError:
1437
+ eval_logger.warning(
1438
+ "Failed to apply chat template. removing the system role in chat history."
1439
+ )
1440
+ chat_history = [msg for msg in chat_history if msg["role"] != "system"]
1441
+ chat_templated = self.tokenizer.apply_chat_template(
1442
+ chat_history,
1443
+ tokenize=False,
1444
+ add_generation_prompt=add_generation_prompt,
1445
+ continue_final_message=not add_generation_prompt,
1446
+ )
1447
+
1448
+ return chat_templated
1449
+
1450
+ def get_model_info(self) -> dict:
1451
+ """
1452
+ Method to get Hugging Face model information for experiment reproducibility.
1453
+ """
1454
+
1455
+ def get_model_num_params(model) -> int:
1456
+ if hasattr(model, "num_parameters"):
1457
+ return model.num_parameters()
1458
+ if hasattr(model, "parameters"):
1459
+ return sum(p.numel() for p in model.parameters())
1460
+ else:
1461
+ return -1
1462
+
1463
+ def get_model_dtype(model) -> str:
1464
+ if hasattr(model, "dtype"):
1465
+ return model.dtype
1466
+ else:
1467
+ return ""
1468
+
1469
+ def get_model_sha(pretrained: str, revision: str) -> str:
1470
+ try:
1471
+ model_info = HfApi().model_info(repo_id=pretrained, revision=revision)
1472
+ return model_info.sha
1473
+ except Exception as e:
1474
+ eval_logger.debug(
1475
+ f"Failed to get model SHA for {pretrained} at revision {revision}. Error: {e}"
1476
+ )
1477
+ return ""
1478
+
1479
+ model_info = {
1480
+ "model_num_parameters": get_model_num_params(self._model),
1481
+ "model_dtype": get_model_dtype(self._model),
1482
+ "model_revision": self.revision,
1483
+ "model_sha": get_model_sha(self.pretrained, self.revision),
1484
+ }
1485
+ if self.peft:
1486
+ model_info["peft_sha"] = get_model_sha(self.peft, self.revision)
1487
+ if self.delta:
1488
+ model_info["delta_sha"] = get_model_sha(self.delta, self.revision)
1489
+ return model_info
Prism/LLaDA/LLaDA_Baseline/dllm_eval/models/utils.py ADDED
@@ -0,0 +1,854 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import collections
2
+ import fnmatch
3
+ import gc
4
+ import itertools
5
+ import logging
6
+ import time
7
+ from functools import wraps
8
+ from typing import (
9
+ TYPE_CHECKING,
10
+ Any,
11
+ Callable,
12
+ Dict,
13
+ Iterable,
14
+ Iterator,
15
+ List,
16
+ Literal,
17
+ Optional,
18
+ Tuple,
19
+ Type,
20
+ Union,
21
+ )
22
+
23
+ import torch
24
+ import transformers
25
+
26
+
27
+ eval_logger = logging.getLogger(__name__)
28
+
29
+
30
+ if TYPE_CHECKING:
31
+ from PIL import Image
32
+ from transformers import PreTrainedTokenizerBase
33
+ from transformers.configuration_utils import PretrainedConfig
34
+
35
+
36
def chunks(iter, n: int = 0, fn=None):
    """
    Yield successive batches from an iterable. Useful for batching.

    Batches hold ``n`` items each (the last batch may be shorter). When *fn*
    is supplied it overrides ``n``: it is called as ``fn(index, iter)`` and
    must return the desired size of the batch currently being filled.

    Parameters:
    - iter: the input iterable to be divided into chunks.
    - n: size of each chunk. Default is 0.
    - fn: optional callable ``(current_index, iterable) -> chunk_size``.

    Example:
        >>> list(chunks([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], 3))
        [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10]]
    """
    batch = []
    for idx, item in enumerate(iter):
        batch.append(item)
        target = fn(idx, iter) if fn else n
        if len(batch) == target:
            yield batch
            batch = []
    # flush the final, possibly short, batch
    if batch:
        yield batch
72
+
73
+
74
class MultiChoice:
    """Iterable container of valid names supporting comma-separated,
    shell-wildcard membership checks (used for task selection)."""

    def __init__(self, choices) -> None:
        self.choices = choices

    # Simple wildcard support (linux filename patterns)
    def __contains__(self, values) -> bool:
        """Return True when every comma-separated pattern in *values* matches
        at least one choice; otherwise log the available choices and raise."""
        for pattern in values.split(","):
            if not fnmatch.filter(self.choices, pattern):
                eval_logger.info("Available tasks to choose:")
                for choice in self.choices:
                    eval_logger.info(f"  - {choice}")
                raise ValueError("'{}' is not in task list".format(pattern))
        return True

    def __iter__(self) -> Iterator:
        yield from self.choices
91
+
92
+
93
class Grouper:
    """
    Partition an array by a key function while remembering original positions.

    Builds a dict mapping ``fn(ob)`` to the list of all elements of ``arr``
    with that key; ``get_original`` later stitches per-group results back
    into the original input order.
    """

    def __init__(self, arr, fn) -> None:
        self.size = len(arr)
        # Pair every element with its original index, then bucket by key.
        buckets = collections.defaultdict(list)
        for idx, ob in enumerate(arr):
            buckets[fn(ob)].append((idx, ob))
        # self.arr has format Dict[key, List[Tuple[int, element]]]
        self.arr = buckets
        self._grouped = None

    def get_grouped(self):
        """Return {key: [elements]} with the bookkeeping indices stripped."""
        if self._grouped:
            return self._grouped
        self._grouped = {
            key: [element for _, element in pairs] for key, pairs in self.arr.items()
        }
        return self._grouped

    def get_original(self, grouped_dict):
        """Given per-group results (in the same per-group order produced by
        ``get_grouped``), return one flat list in the original input order."""
        res = [None] * self.size
        cov = [False] * self.size

        assert grouped_dict.keys() == self.arr.keys()

        for key, values in grouped_dict.items():
            for (idx, _), value in zip(self.arr[key], values):
                res[idx] = value
                cov[idx] = True

        # every original slot must have been filled
        assert all(cov)

        return res
149
+
150
+
151
def pad_and_concat(
    max_length: int,
    tensors: List[torch.Tensor],
    padding_side: Literal["right", "left"] = "right",
):
    """
    Pad each 1-D (or ``[1, seq]``) tensor up to *max_length* and stack them
    into a single ``[batch, max_length]`` tensor. Used for batching inputs
    and continuations in seq2seq models.

    Args:
        max_length: target length; tensors longer than this are not truncated.
        tensors: list of ``[seq]`` or ``[1, seq]`` token-id tensors.
        padding_side: "right" (default) appends zero padding, "left" prepends it.

    Returns:
        ``torch.Tensor`` of shape ``[len(tensors), max_length]``.

    Note:
        Fixes two issues of the earlier revision: the input list is no longer
        mutated in place, and padding zeros use each tensor's own dtype
        (previously hard-coded ``torch.long``), generalizing to float tensors
        while remaining identical for integer token ids.
    """
    assert padding_side in ("left", "right"), (
        f"Unrecognized padding type: '{padding_side}' not 'left' or 'right'"
    )

    padded = []
    for tensor in tensors:
        if tensor.dim() == 2:
            # accept [1, seq] inputs as well as [seq]
            tensor = tensor.squeeze(0)
        gap = max_length - tensor.shape[0]
        if gap > 0:
            pad = torch.zeros(gap, dtype=tensor.dtype, device=tensor.device)
            pieces = [tensor, pad] if padding_side == "right" else [pad, tensor]
            tensor = torch.cat(pieces, dim=0)
        padded.append(tensor.unsqueeze(0))

    return torch.cat(padded, dim=0)
200
+
201
+
202
def clear_torch_cache() -> None:
    """Run a Python GC pass and release cached CUDA memory (when CUDA exists).

    Guarding the CUDA call makes the helper safe on CPU-only installs, where
    some torch builds raise from ``torch.cuda.empty_cache()``.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
205
+
206
+
207
def get_dtype(dtype: Union[str, torch.dtype]) -> torch.dtype:
    """Converts `dtype` from `str` to torch.dtype when possible. Does not use an instantiated HF AutoConfig"""
    # "auto" and actual torch.dtype values pass through untouched.
    if not isinstance(dtype, str) or dtype == "auto":
        return dtype
    # `str` arg -> torch dtype: `float16` -> `torch.float16`
    return getattr(torch, dtype)
215
+
216
+
217
class MultiTokenEOSCriteria(transformers.StoppingCriteria):
    """Criteria to stop on the specified multi-token sequence.

    Generation halts once every sequence in the batch has produced the stop
    string *somewhere* in its generated suffix (checked on decoded text, so
    alternative tokenizations of the same string still trigger a stop).
    """

    def __init__(
        self,
        sequence: str,
        tokenizer: transformers.PreTrainedTokenizer,
        initial_decoder_input_length: int,
        batch_size: int,
    ) -> None:
        self.initial_decoder_input_length = initial_decoder_input_length
        self.done_tracker = [False] * batch_size
        self.sequence = sequence
        self.sequence_ids = tokenizer.encode(sequence, add_special_tokens=False)
        # Look back two extra tokens beyond the encoded stop sequence: a model
        # may emit the same string under a different tokenization (e.g. two
        # '\n' tokens instead of one '\n\n') and we still want to stop.
        # The small over-read cannot reach into the prompt, because the
        # lookback window in __call__ is sliced *after* dropping the initial
        # decoder input.
        self.sequence_id_len = len(self.sequence_ids) + 2
        self.tokenizer = tokenizer

    def __call__(self, input_ids, scores, **kwargs) -> bool:
        # Only decode the tail of the *generated* part of each sequence.
        generated = input_ids[:, self.initial_decoder_input_length :]
        lookback = generated[:, -self.sequence_id_len :]
        decoded = self.tokenizer.batch_decode(lookback)
        # Once a row is done it stays done; re-check the rest.
        self.done_tracker = [
            done or self.sequence in text
            for done, text in zip(self.done_tracker, decoded)
        ]
        return all(self.done_tracker)
255
+
256
+
257
def stop_sequences_criteria(
    tokenizer: transformers.PreTrainedTokenizer,
    stop_sequences: List[str],
    initial_decoder_input_length: int,
    batch_size: int,
) -> transformers.StoppingCriteriaList:
    """Build a StoppingCriteriaList containing one MultiTokenEOSCriteria per
    stop string, all sharing the same tokenizer and batch geometry."""
    criteria = [
        MultiTokenEOSCriteria(
            sequence, tokenizer, initial_decoder_input_length, batch_size
        )
        for sequence in stop_sequences
    ]
    return transformers.StoppingCriteriaList(criteria)
273
+
274
+
275
def undistribute(iterable):
    """
    Undoes https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.distribute .

    Re-interleaves results that have been split using more_itertools.distribute:
        >>> group_1, group_2 = distribute(2, [1, 2, 3, 4, 5, 6])
        >>> list(group_1)
        [1, 3, 5]
        >>> list(group_2)
        [2, 4, 6]
        >>> undistribute([group_1, group_2])
        [1, 2, 3, 4, 5, 6]

    Handles non-uniform component lengths:

        >>> children = distribute(3, [1, 2, 3, 4, 5, 6, 7])
        >>> [list(c) for c in children]
        [[1, 4, 7], [2, 5], [3, 6]]
        >>> undistribute(children)
        [1, 2, 3, 4, 5, 6, 7]

    Also handles when some iterables are empty:

        >>> children = distribute(5, [1, 2, 3])
        >>> [list(c) for c in children]
        [[1], [2], [3], [], []]
        >>> undistribute(children)
        [1, 2, 3]

    """
    # Use a unique sentinel (rather than None) as the zip fill value so that
    # genuine None values inside the data survive the round trip; the previous
    # revision filtered on `is not None` and silently dropped them.
    _fill = object()
    return [
        x
        for x in itertools.chain.from_iterable(
            itertools.zip_longest(*[list(x) for x in iterable], fillvalue=_fill)
        )
        if x is not _fill
    ]
313
+
314
+
315
def retry_on_specific_exceptions(
    on_exceptions: List[Type[Exception]],
    max_retries: Optional[int] = None,
    backoff_time: float = 3.0,
    backoff_multiplier: float = 1.5,
    on_exception_callback: Optional[Callable[[Exception, float], Any]] = None,
):
    """Retry on an LLM Provider's rate limit error with exponential backoff.

    Args:
        on_exceptions: exception types that trigger a retry; anything else
            propagates immediately.
        max_retries: total number of attempts; None retries forever.
        backoff_time: initial sleep between attempts, in seconds.
        backoff_multiplier: factor applied to the sleep after each failure.
        on_exception_callback: optional hook, called as
            ``callback(exc, sleep_time)`` before sleeping.

    Unlike the previous revision, once the retry budget is exhausted the final
    exception is re-raised instead of the wrapper silently returning None.

    For example, to use for OpenAI, do the following:
    ```
    from openai import RateLimitError

    # Recommend specifying max_retries to avoid infinite loops!
    @retry_on_specific_exceptions([RateLimitError], max_retries=3)
    def completion(...):
        # Wrap OpenAI completion function here
        ...
    ```
    """

    def decorator(func: Callable):
        @wraps(func)
        def wrapper(*args, **kwargs):
            sleep_time = backoff_time
            attempt = 0
            while True:
                try:
                    return func(*args, **kwargs)
                except tuple(on_exceptions) as e:
                    attempt += 1
                    if max_retries is not None and attempt >= max_retries:
                        # retry budget exhausted: surface the error
                        raise
                    if on_exception_callback is not None:
                        on_exception_callback(e, sleep_time)
                    time.sleep(sleep_time)
                    sleep_time *= backoff_multiplier

        return wrapper

    return decorator
353
+
354
+
355
class Collator:
    """
    A class for reordering and batching elements of an array.

    This class allows for sorting an array based on a provided sorting function, grouping elements based on a grouping function, and generating batches from the sorted and grouped data.

    Objects of this class have the group_by attribute which determines the method for grouping
    the data while batching it. Three options include "gen_kwargs", "contexts", or None:
        If group_by == "gen_kwargs" then requests will be grouped by gen_kwargs
        If group_by == "contexts" then requests will be grouped by context + cont[:-1]
        If None then requests will just be reordered by length descending.
    """

    def __init__(
        self,
        arr: List,
        sort_fn: Callable = lambda x: x,
        group_fn: Callable = lambda x: x[1],
        group_by: Union[Literal["gen_kwargs", "contexts"], None] = None,
    ) -> None:
        self._group_by = group_by
        # 0 indices are enumerated indices. Apply functions to original arr.
        self._sort_fn = lambda x: sort_fn(x[1])
        self._group_fn = lambda x: group_fn(x[1])
        # Original positions of yielded elements, recorded in yield order;
        # consumed by get_original() to undo the sorting/grouping.
        self._reorder_indices: List = []
        self._size = len(arr)
        self._arr_with_indices: Union[Dict, Tuple[Tuple[int, Any], ...]] = tuple(
            enumerate(arr)
        )  # [indices, (arr)]
        if self._group_by == "contexts":
            self._group_by_context()
        elif self._group_by == "gen_kwargs":
            self._group_by_index()

    def _group_by_index(self) -> None:
        """Group the elements of a list based on their indices."""
        self._arr_with_indices = self.group(
            self._arr_with_indices, fn=self._group_fn, group_by="gen_kwargs"
        )

    def _group_by_context(self) -> None:
        """Group the array with indices by context."""
        self._arr_with_indices = self.group(
            self._arr_with_indices, fn=self._group_fn, group_by="contexts"
        )

    def get_batched(self, n: int = 1, batch_fn: Optional[Callable] = None) -> Iterator:
        """
        Generates and yields batches from the reordered array. The method of grouping and batching
        depends on the parameter `group_by`.
        If `group_by` is set to "gen_kwargs", it will batch the
        re-ordered values with same gen_kwargs for each batch.
        If `group_by` is "contexts", it caches the requests by context before batching.
        If `group_by` is neither "gen_kwargs" nor "contexts", it yields the reordered array

        Parameters:
        - n (int): The size of each batch. Defaults to 1.
        - batch_fn ([Callable[[int, Iterable], int]] | None): A function to determine the size of
          each batch. Optional, defaults to None.

        Returns:
        Iterator: An iterator over batches of reordered elements grouped as per the `group_by`
                  attribute.

        Yields:
        List of batched elements according to the `group_by` attribute.
        """
        if self._group_by == "gen_kwargs":
            for (
                key,
                values,
            ) in self._arr_with_indices.items():  # type: ignore
                values = self._reorder(values)
                batch = self.get_chunks(values, n=n, fn=batch_fn)
                yield from batch
        elif self._group_by == "contexts":
            # Get one sample from each key.
            # Select longest continuation per group to ensure sufficient context logits
            values = self._reorder(
                [
                    max(value, key=lambda x: len(x[1][-1]))
                    for value in self._arr_with_indices.values()
                ]
            )
            batch = self.get_chunks(values, n=n, fn=batch_fn)
            yield from batch
        else:
            values = self._reorder(self._arr_with_indices)  # type: ignore
            batch = self.get_chunks(values, n=n, fn=batch_fn)
            yield from batch

    def get_cache(
        self,
        req_str: Tuple[str, str] = None,
        cxt_toks: List[int] = None,
        cont_toks: List[int] = None,
        logits: torch.Tensor = None,
    ) -> Iterator[Tuple[Tuple[str, str], List[int], torch.Tensor]]:
        """
        Retrieves cached single-token continuations and their associated arguments, updating indices as necessary.

        The behavior of this function varies depending on how the `group_by` attribute is set:

        - When `group_by` is "contexts":
            The function identifies single-token continuations by checking for keys that equate to
            [context+continuation][-1] and logs the indices for re-ordering.
            In this mode, this function can work in two scenarios:

            1. Cache Hit - Single Match:
                If a single matching context-continuation pair is found in the cache,
                the function yields the original arguments.

            2. Cache Hit - Multiple Matches:
                If multiple matching context-continuation pairs are found in the cache,
                the function expands the logits batch dimension to match the number of cache hits.
                It updates the original requests and continuation tokens.

        - When `group_by` is not set to "contexts":
            This method yields the original arguments, logits and continuation tokens,
            without checking for one-token continuations.

        Parameters:
        - req_str (tuple[str, str]): Original strings used for CachingLM.
        - cxt_toks (list[int]): Full context tokens used for lookup.
        - cont_toks (list[int]): Continuation tokens for which logits were generated.
        - logits (torch.Tensor [1, seq_length, vocab_size]): Logits generated by the model given context and continuation keys.

        Yields:
        - Iterator:
            - req_str (tuple[str, str]): strings used for CachingLM.
            - cont_toks (list[int]) : continuation tokens.
            - logits (torch.Tensor [1, seq_length, vocab_size]): The original logits (repeated cache hit times)
        """
        if self._group_by == "contexts":
            # pop() ensures each cache group is consumed exactly once.
            cache_hit: List[
                Tuple[int, Tuple[Tuple[str, str], List[int], List[int]]]
            ] = self._arr_with_indices.pop(tuple(cxt_toks + cont_toks[:-1]))
            if (cache_size := len(cache_hit)) == 1:
                self._reorder_indices.extend(x[0] for x in cache_hit)
                yield req_str, cont_toks, logits
            else:
                # If we have matching requests then expand the batch dimension (no-op) and
                # yield each along with its corresponding args.
                multilogits = logits.expand(cache_size, -1, -1).chunk(cache_size)
                indices, req_str, cont_toks = zip(
                    *[(x[0], x[1][0], x[-1][-1]) for x in cache_hit]
                )
                self._reorder_indices.extend(indices)
                for c_key, cont_tok, logit in zip(req_str, cont_toks, multilogits):
                    yield c_key, cont_tok, logit
        else:
            yield req_str, cont_toks, logits

    def _reorder(self, arr: Union[List, Tuple[Tuple[int, Any], ...]]) -> Iterator:
        """
        Reorders the elements in the array based on the sorting function.

        Parameters:
        - arr (list | tuple[tuple[int, Any], ...]]): The array or iterable to be reordered.

        Yields:
        Iterator
        """
        arr = sorted(arr, key=self._sort_fn)
        if not self._group_by == "contexts":
            # If grouped by contexts then indices will be set in get_cache()
            self._reorder_indices.extend([x[0] for x in arr])
        yield from [x[1] for x in arr]

    def get_original(self, newarr: List) -> List:
        """
        Restores the original order of elements from the reordered list.

        Parameters:
        - newarr (list): The reordered array.

        Returns:
        list: The array with elements restored to their original order.
        """
        res = [None] * self._size
        cov = [False] * self._size

        # _reorder_indices[i] is the original position of newarr[i].
        for ind, v in zip(self._reorder_indices, newarr):
            res[ind] = v
            cov[ind] = True

        # every original slot must have been produced exactly once
        assert all(cov)

        return res

    def __len__(self):
        return self._size

    @staticmethod
    def group(
        arr: Iterable,
        fn: Callable,
        group_by: Literal["gen_kwargs", "contexts"] = "gen_kwargs",
    ) -> dict:
        """
        Groups elements of an iterable based on a provided function.


        The `group_by` parameter determines the method of grouping.
        If `group_by` is "contexts", the elements are grouped by [context + cont][:-1].
        If `group_by` is "gen_kwargs", the elements are grouped based on the gen_kwargs dict.

        Parameters:
        - arr (Iterable): The iterable to be grouped.
        - fn (Callable): The function to determine the grouping.
        - group_by ("gen_kwargs" | "contexts"): method of grouping, see above.

        Returns:
        dict: mapping from group key to the list of grouped elements.
        """
        res = collections.defaultdict(list)
        for ob in arr:
            # where ob == [context + cont]
            if group_by == "contexts":
                res[tuple(fn(ob))].append(ob)
            else:
                try:
                    # gen_kwargs dicts are unhashable; convert to a sorted
                    # tuple of (key, value) pairs, tupling iterable values.
                    hashable_dict = tuple(
                        (
                            key,
                            tuple(value)
                            if isinstance(value, collections.abc.Iterable)
                            else value,
                        )
                        for key, value in sorted(fn(ob).items())
                    )
                    res[hashable_dict].append(ob)
                except (TypeError, AttributeError):
                    # fall back for non-dict group keys
                    res[tuple(fn(ob))].append(ob)
        return res

    @staticmethod
    def get_chunks(_iter, n: int = 0, fn=None):
        """
        Divides an iterable into chunks of specified size or based on a given function.
        Useful for batching

        Parameters:
        - iter: The input iterable to be divided into chunks.
        - n: An integer representing the size of each chunk. Default is 0.
        - fn: A function that takes the current index and the iterable as arguments and returns the size of the chunk. Default is None.

        Returns:
        An iterator that yields chunks of the input iterable.

        Example usage:
        ```
        data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
        for chunk in chunks(data, 3):
            print(chunk)
        ```
        Output:
        ```
        [1, 2, 3]
        [4, 5, 6]
        [7, 8, 9]
        [10]
        ```
        """
        arr = []
        _iter = tuple(_iter)
        for i, x in enumerate(_iter):
            arr.append(x)
            if len(arr) == (fn(i, _iter) if fn else n):
                yield arr
                arr = []

        if arr:
            yield arr
629
+
630
+
631
def configure_pad_token(
    tokenizer: "PreTrainedTokenizerBase",
    model_config: Optional["PretrainedConfig"] = None,
) -> "PreTrainedTokenizerBase":
    """
    Ensure a (Hugging Face) tokenizer has a usable padding token, assigning one
    when missing. Preference order: existing pad token, then unk, then eos,
    then model-specific special cases, and finally a newly added "<|pad|>".

    Args:
        tokenizer: The tokenizer for which the padding token is to be handled.
        model_config: The configuration of the model. Default is None.

    Returns:
        The tokenizer after the padding token has been handled.

    Raises:
        AssertionError: If the tokenizer is of type RWKVWorldTokenizer or Rwkv5Tokenizer and the padding token id is not 0.
    """
    if tokenizer.pad_token:
        return tokenizer

    if tokenizer.unk_token:
        tokenizer.pad_token_id = tokenizer.unk_token_id
        return tokenizer

    if tokenizer.eos_token:
        tokenizer.pad_token_id = tokenizer.eos_token_id
        return tokenizer

    # handle special cases
    if model_config and getattr(model_config, "model_type", None) == "qwen":
        # Qwen's trust_remote_code tokenizer does not allow for adding special tokens
        tokenizer.pad_token = "<|endoftext|>"
    elif type(tokenizer).__name__ in ("RWKVWorldTokenizer", "Rwkv5Tokenizer"):
        # The RWKV world tokenizer does not allow adding special tokens /
        # setting the pad token (which is fixed at 0). The name check is
        # needed because rwkv4 models exist with a neox tokenizer.
        # Note the world tokenizer class name might change after the final
        # huggingface merge: https://github.com/huggingface/transformers/pull/26963
        assert tokenizer.pad_token_id == 0
    else:
        tokenizer.add_special_tokens({"pad_token": "<|pad|>"})

    return tokenizer
674
+
675
+
676
def replace_placeholders(
    string: str, default_placeholder: str, image_token: str, max_images: int
):
    """
    Utility for local multimodal models: replace the first `max_images`
    occurrences of `default_placeholder` in `string` with `image_token`.

    Occurrences beyond `max_images` are kept as-is when the placeholder differs
    from the image token, and dropped entirely when they are the same (to avoid
    feeding the model more image tokens than images).

    :param string: The original string containing placeholders.
    :param default_placeholder: The placeholder text to be replaced.
    :param image_token: The token to replace the placeholder with.
    :param max_images: The maximum number of replacements to make.
    :return: The string with placeholders replaced.
    """
    segments = string.split(default_placeholder)
    pieces = [segments[0]]
    replaced = 0

    for segment in segments[1:]:
        if replaced < max_images:
            pieces.append(image_token)
            replaced += 1
        elif default_placeholder != image_token:
            # Past the limit: keep the original placeholder text.
            pieces.append(default_placeholder)
        pieces.append(segment)

    return "".join(pieces)
708
+
709
+
710
def flatten_image_list(images: List[List]):
    """
    Concatenate a list of lists of images into one flat list, preserving order.

    Used for some multimodal models like Llava-1.5 whose image processor
    expects this flattened-list format.

    :param images: A list of lists of PIL images.
    :return: a single list of PIL images, sub-lists concatenated in order.
    """
    flat = []
    for sub_list in images:
        flat.extend(sub_list)
    return flat
719
+
720
+
721
def handle_stop_sequences(
    until: Union[str, List[str], None], eos: Optional[str]
) -> List[str]:
    """Normalize `until` into a list of stop sequences and append the EOS token
    if it is not already present. A list argument is mutated in place."""
    if isinstance(until, str):
        stop_list = [until]
    elif isinstance(until, list):
        stop_list = until
    elif until is None:
        stop_list = []
    else:
        raise ValueError(
            f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
        )

    if eos is not None and eos not in stop_list:
        stop_list.append(eos)
    return stop_list
737
+
738
+
739
def resize_image(
    image: "Image.Image",
    width: Optional[int] = None,
    height: Optional[int] = None,
    max_dimension: Optional[int] = None,
    keep_aspect_ratio: bool = True,
    resample_filter: Union[int, str] = "Image.BICUBIC",
    min_width: int = 1,
    min_height: int = 1,
) -> "Image.Image":
    """
    Resizes a PIL Image object with flexible options.

    Args:
        image: The PIL Image object to resize.
        width: Target width in pixels.
        height: Target height in pixels.
        max_dimension: Maximum size for the longer dimension of the image.
        keep_aspect_ratio: If True (default) and both width and height are provided,
            the image is resized to fit within these dimensions while maintaining
            its aspect ratio. If False, the image is stretched to the exact
            width and height.
        resample_filter: The resampling filter, either a PIL integer code or a
            filter name such as "Image.BICUBIC" / "bicubic". Defaults to bicubic.
        min_width: Minimum width for the resized image. Defaults to 1.
        min_height: Minimum height for the resized image. Defaults to 1.

    Returns:
        The resized PIL Image object. If no resize parameters are provided
        or if the image already meets the criteria, the original image is returned.

    Order of precedence for resizing:
        1. width AND height: fit within bounds (or stretch if keep_aspect_ratio=False).
        2. only width: height scaled proportionally.
        3. only height: width scaled proportionally.
        4. max_dimension: longest side capped, other side scaled proportionally.
        5. none of the above: the original image is returned.
    """
    original_width, original_height = image.size

    # If no arguments are provided, return the original image
    if width is None and height is None and max_dimension is None:
        return image

    # Resolve a string filter name to the integer code PIL expects.
    # FIX: the previous default passed the literal string "Image.BICUBIC"
    # straight to Image.resize(), which raises a TypeError.
    if isinstance(resample_filter, str):
        # Values mirror PIL.Image.Resampling (NEAREST..HAMMING).
        resample_codes = {
            "NEAREST": 0,
            "LANCZOS": 1,
            "BILINEAR": 2,
            "BICUBIC": 3,
            "BOX": 4,
            "HAMMING": 5,
        }
        filter_name = resample_filter.rsplit(".", 1)[-1].upper()
        resample_filter = resample_codes.get(filter_name, 3)  # fall back to BICUBIC

    new_width = original_width
    new_height = original_height

    if width is not None and height is not None:
        # No resize needed if image is already smaller than target dimensions
        if original_width <= width and original_height <= height:
            return image

        if keep_aspect_ratio:
            # Calculate the ratio to fit within the target dimensions
            ratio = min(width / original_width, height / original_height)
            new_width = int(original_width * ratio)
            new_height = int(original_height * ratio)
        else:
            # Stretch to exact dimensions
            new_width = width
            new_height = height
    elif width is not None:
        # No resize needed if width is already smaller
        if original_width <= width:
            return image
        # Calculate height proportionally
        new_width = width
        new_height = int((original_height / original_width) * new_width)
    elif height is not None:
        # No resize needed if height is already smaller
        if original_height <= height:
            return image
        # Calculate width proportionally
        new_height = height
        new_width = int((original_width / original_height) * new_height)
    elif max_dimension is not None:
        # No resize needed if both dimensions are smaller than max_dimension
        if max(original_height, original_width) <= max_dimension:
            return image

        if original_width > original_height:
            # Width is the longer side
            new_width = max_dimension
            new_height = int((original_height / original_width) * new_width)
        else:
            # Height is the longer side or sides are equal
            new_height = max_dimension
            new_width = int((original_width / original_height) * new_height)

    # Ensure dimensions are at least minimum values (proportional scaling can
    # round a very thin side down to 0).
    new_width = max(min_width, new_width)
    new_height = max(min_height, new_height)

    # Perform the resize operation with the calculated dimensions
    return image.resize((new_width, new_height), resample_filter)
837
+
838
+
839
def truncate_tokens(
    tokens: List[int],
    max_length: int,
    tokenizer: "PreTrainedTokenizerBase",
    strategy: str = "left",
):
    """Truncate `tokens` to at most `max_length` items.

    Args:
        tokens: the token id sequence to truncate.
        max_length: maximum number of tokens to keep; non-positive yields [].
        tokenizer: kept for interface compatibility (unused here).
        strategy: "left" (keep the tail), "right" (keep the head), or
            "middle" (keep head and tail, dropping the middle).

    Returns:
        A list of at most `max_length` tokens.

    Raises:
        ValueError: for an unknown strategy (previously returned None silently).
    """
    if max_length <= 0:
        # FIX: tokens[-0:] returns the *whole* list, so handle 0 explicitly.
        return []
    if len(tokens) <= max_length:
        # FIX: without this, "middle" would duplicate tokens when the sequence
        # is already short enough (e.g. [1,2] with max_length=3 -> [1,1,2]).
        return list(tokens)
    if strategy == "left":
        return tokens[-max_length:]
    if strategy == "right":
        return tokens[:max_length]
    if strategy == "middle":
        # Keep the first half and last half, dropping the middle of the sequence.
        left_length = max_length // 2
        right_length = max_length - left_length
        return tokens[:left_length] + tokens[-right_length:]
    raise ValueError(f"Unknown truncation strategy: {strategy!r}")
Prism/LLaDA/LLaDA_Baseline/dllm_eval/models/verifier.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import logging
3
+ import ast
4
+ import re
5
+ import numpy as np
6
+ import textwrap
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
class CodeVerifier:
    """Scores a generated answer either by the generator's own token confidence
    or by prompting `model` as a yes/no judge ("self-verification", SVF)."""

    def __init__(self, model, tokenizer, device="cuda"):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device

        # Collect the final token id of several spellings of "Yes"/"No"; these
        # are the candidate answer tokens whose logits are compared in svf_score.
        self.yes_ids, self.no_ids = [], []
        for text in ["Yes", " Yes", "YES"]:
            ids = self.tokenizer.encode(text, add_special_tokens=False)
            if ids:
                self.yes_ids.append(ids[-1])
        for text in ["No", " No", "NO"]:
            ids = self.tokenizer.encode(text, add_special_tokens=False)
            if ids:
                self.no_ids.append(ids[-1])

        # FIX: deduplicate deterministically; list(set(...)) has unstable order
        # across runs (hash randomization).
        self.yes_ids = list(dict.fromkeys(self.yes_ids))
        self.no_ids = list(dict.fromkeys(self.no_ids))

    def _extract_python_code(self, text):
        """Return the contents of the first fenced code block in `text`
        (preferring a ```python fence), or the stripped text itself."""
        text = text.strip()
        match = re.search(r"```python\s*(.*?)```", text, re.DOTALL)
        if match:
            return match.group(1)
        match_generic = re.search(r"```\s*(.*?)```", text, re.DOTALL)
        if match_generic:
            return match_generic.group(1)
        return text

    def check_syntax(self, code_str):
        """True iff the extracted code is non-trivial and parses as Python."""
        clean_code = self._extract_python_code(code_str)
        if len(clean_code.strip()) < 5:
            # Reject near-empty snippets that would "parse" vacuously.
            return False
        try:
            ast.parse(clean_code)
            return True
        except (SyntaxError, ValueError):
            # FIX: was a bare `except:` that also swallowed KeyboardInterrupt /
            # SystemExit. ast.parse raises SyntaxError for malformed code and
            # ValueError for e.g. source containing NUL bytes.
            return False

    def compute_confidence(self, logits):
        """Geometric mean of the per-position top-1 probabilities of `logits`
        (softmax over the last dim); 0.0 when no logits are available."""
        if logits is None:
            return 0.0
        probs = torch.softmax(logits, dim=-1)
        max_probs, _ = torch.max(probs, dim=-1)
        log_probs = torch.log(max_probs + 1e-10)
        return torch.exp(torch.mean(log_probs)).item()

    def svf_score(self, prompt, code_str, task_type="code"):
        """Self-verification: ask the model whether the answer is correct and
        return P("Yes") over the {"Yes", "No"} candidates at the answer slot.

        task_type selects the judging template: "code", "math", "reasoning",
        or a generic fallback.
        """
        # Truncate very long answers; for reasoning keep both head and tail.
        max_len = 2000
        if len(code_str) > max_len:
            if task_type == "reasoning":
                truncated_code = code_str[:500] + "\n...[truncated]...\n" + code_str[-(max_len - 500):]
            else:
                truncated_code = code_str[-max_len:]
        else:
            truncated_code = code_str

        if task_type == "code":
            prompt_template = f"""
You are an expert programming contest judge. Your task is to evaluate a generated solution for a given problem based on correctness, efficiency, and adherence to constraints.

[Problem Statement]
{prompt}
[/Problem Statement]

[Proposed Python Solution]
```python
{truncated_code}
```
[/Proposed Python Solution]

**Analysis Steps:**
1. Correctness: Does the core algorithm correctly solve the problem?
2. Efficiency: Is the time complexity acceptable for the given constraints?
3. Edge Cases & Constraints: Does the code handle all rules and edge cases?

**Conclusion**: Based on your analysis, is the solution likely to be fully correct? Answer with a single word: Yes or No.
**Answer:** """

        elif task_type == "math":
            prompt_template = f"""
You are an expert mathematician and competition judge. Your task is to evaluate a proposed mathematical solution for a given problem based on its logical rigor and accuracy.

[Math Problem]
{prompt}
[/Math Problem]

[Proposed Mathematical Solution]
{truncated_code}
[/Proposed Mathematical Solution]

**Analysis Steps:**
1. Reasoning Validity: Are the logical steps and mathematical properties applied correctly?
2. Calculation Accuracy: Are the intermediate calculations or algebraic manipulations accurate?
3. Goal Alignment: Does the current reasoning path directly lead toward the final answer required by the problem?

**Conclusion**: Based on your analysis, is this solution path sound and likely to result in the correct final answer? Answer with a single word: Yes or No.
**Answer:** """

        elif task_type == "reasoning":
            prompt_template = f"""
You are an expert reading comprehension and faithfulness judge. Your task is to evaluate a generated answer based on the provided context and question.

[Context and Question]
{prompt}
[/Context and Question]

[Proposed Answer]
{truncated_code}
[/Proposed Answer]

**Analysis Steps :**
1. Faithfulness: Is the answer an exact, literal span from the context?
2. Relevance: Does the answer directly address the specific question asked without hallucinating external information?
3. Accuracy: Does the provided context strictly support this answer?

**Conclusion**: Based on your analysis, is the answer fully faithful to the context and correct? Answer with a single word: Yes or No.
**Answer:** """

        else:
            prompt_template = f"Is the following answer correct?\nQuestion: {prompt}\nAnswer: {truncated_code}\nAnswer Yes or No.\nAnswer:"

        verify_text = textwrap.dedent(prompt_template).strip()
        input_ids = self.tokenizer(verify_text, return_tensors="pt").input_ids.to(self.device)

        # Respect the model's positional limit, whichever config attr exists.
        max_pos = getattr(self.model.config, "max_position_embeddings",
                          getattr(self.model.config, "n_positions",
                                  getattr(self.model.config, "max_sequence_length", 20480)))

        if input_ids.shape[1] > max_pos - 16:
            logger.warning("Verifier input is too long, truncating from the left.")
            input_ids = input_ids[:, -(max_pos - 16):]

        with torch.no_grad():
            outputs = self.model(input_ids)
            # Logits at the last position = the judge's next-token distribution.
            logits = outputs.logits[0, -1, :]

        # Best logit among the Yes / No candidate ids, guarding against ids
        # that fall outside this model's vocabulary.
        yes_score = max((logits[i].item() for i in self.yes_ids if i < logits.shape[-1]), default=-float('inf'))
        no_score = max((logits[i].item() for i in self.no_ids if i < logits.shape[-1]), default=-float('inf'))

        if yes_score == -float('inf') and no_score == -float('inf'):
            # Neither candidate token exists in the vocab: uninformative verdict.
            return 0.5

        probs = torch.softmax(torch.tensor([yes_score, no_score]), dim=0)
        return probs[0].item()

    def get_reward(self, prompt, code_str, mode="confidence", problem_data=None, current_logits=None, task_type="code"):
        """Dispatch: "svf" -> model-as-judge score; anything else -> raw
        confidence computed from `current_logits`."""
        if mode == "svf":
            return self.svf_score(prompt, code_str, task_type=task_type)
        return self.compute_confidence(current_logits)
Prism/LLaDA/LLaDA_Baseline/dllm_eval/prompts/__init__.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ast
2
+ import logging
3
+ import os
4
+ from typing import Dict
5
+
6
+ from dllm_eval import utils
7
+
8
+
9
+ eval_logger = logging.getLogger(__name__)
10
+
11
# Prompt library.
# Stores prompts in a dictionary indexed by 2 levels:
# prompt category name, and prompt name.
# This lets get_prompt() / load_prompt_list() resolve a "category:name"
# identifier to a template string. Values use {{field}} placeholders that are
# rendered against a document via utils.apply_template.
PROMPT_REGISTRY: Dict[str, Dict[str, str]] = {
    "qa-basic": {
        "question-newline-answer": "Question: {{question}}\nAnswer:",
        "q-newline-a": "Q: {{question}}\nA:",
    },
}
21
+
22
+
23
def get_prompt(prompt_id: str, dataset_name: str = None, subset_name: str = None):
    """Resolve a prompt template from a `category:name` identifier.

    Supported categories:
      - "promptsource": look the prompt up in the promptsource template library
        for (dataset_name, subset_name).
      - a path containing ".yaml": load the named prompt from the file's
        "prompts" mapping and wrap it in PromptString.
      - anything else: a category in the local PROMPT_REGISTRY.

    Raises:
        ValueError: if the dataset, category, or prompt cannot be found.
        ModuleNotFoundError: if promptsource is requested but not installed.
    """
    # unpack prompt name
    category_name, prompt_name = prompt_id.split(":")
    if subset_name is None:
        dataset_full_name = dataset_name
    else:
        dataset_full_name = f"{dataset_name}-{subset_name}"
    eval_logger.info(f"Loading prompt from {category_name} for {dataset_full_name}")
    if category_name == "promptsource":
        try:
            from promptsource.templates import DatasetTemplates
        except ModuleNotFoundError as exception:
            # FIX: the original passed two separate strings to the exception
            # constructor (producing a confusing tuple-valued message) and did
            # not chain the original exception.
            raise type(exception)(
                "Tried to load a Promptsource template, but promptsource is not installed. "
                "Please install promptsource via pip install lm-eval[promptsource] or pip install -e .[promptsource]"
            ) from exception
        try:
            if subset_name is None:
                prompts = DatasetTemplates(dataset_name=dataset_name)
            else:
                prompts = DatasetTemplates(
                    dataset_name=dataset_name, subset_name=subset_name
                )
        except Exception as exception:
            raise ValueError(f"{dataset_name} and {subset_name} not found") from exception
        if prompt_name in prompts.all_template_names:
            return prompts[prompt_name]
        raise ValueError(
            f"{prompt_name} not in prompt list {prompts.all_template_names}"
        )
    elif ".yaml" in category_name:
        # Category is a path to a YAML file with a top-level "prompts" mapping.
        import yaml

        with open(category_name, "rb") as file:
            prompt_yaml_file = yaml.full_load(file)

        prompt_string = prompt_yaml_file["prompts"][prompt_name]
        return PromptString(prompt_string)
    else:
        try:
            return PROMPT_REGISTRY[category_name][prompt_name]
        except KeyError as exception:
            # FIX: previously any Exception here was converted into a misleading
            # message about ":" separators; report the actual lookup failure.
            raise ValueError(
                f"Prompt `{prompt_id}` not found in PROMPT_REGISTRY"
            ) from exception
70
+
71
+
72
def load_prompt_list(
    use_prompt: str, dataset_name=None, subset_name=None, yaml_path=None, **kwargs
):
    """Expand a (possibly wildcard) `category:name` prompt id into a list of
    concrete `category:name` ids.

    Args:
        use_prompt: "category:name" where name may contain wildcards matched
            via utils.pattern_match.
        dataset_name / subset_name: promptsource dataset coordinates.
        yaml_path: base directory used to resolve a relative ".yaml" category.
        **kwargs: ignored; accepted for call-site compatibility.

    Returns:
        A list of fully-qualified "category:name" prompt ids.

    Raises:
        ValueError: for an unsupported category (previously this fell through
            with `prompt_list` undefined, producing an opaque NameError).
    """
    category_name, prompt_name = use_prompt.split(":")

    if category_name == "promptsource":
        from promptsource.templates import DatasetTemplates

        if subset_name is None:
            prompts = DatasetTemplates(dataset_name=dataset_name)
        else:
            prompts = DatasetTemplates(
                dataset_name=dataset_name, subset_name=subset_name
            )

        prompt_list = utils.pattern_match(prompt_name, prompts.all_template_names)

    elif ".yaml" in category_name:
        import yaml

        if yaml_path is not None:
            category_name = os.path.realpath(os.path.join(yaml_path, category_name))

        with open(category_name, "rb") as file:
            prompt_yaml_file = yaml.full_load(file)

        prompt_list = utils.pattern_match(
            prompt_name, prompt_yaml_file["prompts"].keys()
        )

    else:
        raise ValueError(
            f"Unsupported prompt category `{category_name}` in `{use_prompt}`"
        )

    # TODO allow multiple prompt names, e.g. category:name1:name2
    return [":".join([category_name, prompt]) for prompt in prompt_list]
111
+
112
+
113
class PromptString:
    """Thin wrapper around a YAML-defined prompt mapping holding
    "doc_to_text" / "doc_to_target" template strings."""

    def __init__(self, prompt_string):
        # Mapping with the raw template strings.
        self.prompt_string = prompt_string

    def apply(self, doc):
        """Render the text and target templates against `doc`.

        Returns a two-element list: [rendered_text, rendered_target].
        """
        spec = self.prompt_string
        text_template = spec["doc_to_text"]
        target_template = spec["doc_to_target"]

        # TODO need a way to process doc_to_choice
        if "doc_to_choice" in spec:
            raise NotImplementedError("Not yet implemented to accept doc_to_choice")

        return [
            utils.apply_template(text_template, doc),
            utils.apply_template(target_template, doc),
        ]
Prism/LLaDA/LLaDA_Baseline/dllm_eval/tasks/__init__.py ADDED
@@ -0,0 +1,670 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import collections
2
+ import inspect
3
+ import logging
4
+ import os
5
+ from functools import partial
6
+ from typing import Dict, List, Mapping, Optional, Union
7
+
8
+ from dllm_eval import utils
9
+ from dllm_eval.api.group import ConfigurableGroup, GroupConfig
10
+ from dllm_eval.api.task import ConfigurableTask, Task
11
+ from dllm_eval.evaluator_utils import get_subtask_list
12
+
13
+
14
+ GROUP_ONLY_KEYS = list(GroupConfig().to_dict().keys())
15
+
16
+ eval_logger = logging.getLogger(__name__)
17
+
18
+
19
+ class TaskManager:
20
+ """TaskManager indexes all tasks from the default `dllm_eval/tasks/`
21
+ and an optional directory if provided.
22
+
23
+ """
24
+
25
+ def __init__(
26
+ self,
27
+ verbosity: Optional[str] = None,
28
+ include_path: Optional[Union[str, List]] = None,
29
+ include_defaults: bool = True,
30
+ metadata: Optional[dict] = None,
31
+ ) -> None:
32
+ if verbosity is not None:
33
+ utils.setup_logging(verbosity)
34
+ self.include_path = include_path
35
+ self.metadata = metadata
36
+ self._task_index = self.initialize_tasks(
37
+ include_path=include_path, include_defaults=include_defaults
38
+ )
39
+ self._all_tasks = sorted(list(self._task_index.keys()))
40
+
41
+ self._all_groups = sorted(
42
+ [x for x in self._all_tasks if self._task_index[x]["type"] == "group"]
43
+ )
44
+ self._all_subtasks = sorted(
45
+ [
46
+ x
47
+ for x in self._all_tasks
48
+ if self._task_index[x]["type"] in ["task", "python_task"]
49
+ ]
50
+ )
51
+ self._all_tags = sorted(
52
+ [x for x in self._all_tasks if self._task_index[x]["type"] == "tag"]
53
+ )
54
+
55
+ self.task_group_map = collections.defaultdict(list)
56
+
57
+ def initialize_tasks(
58
+ self,
59
+ include_path: Optional[Union[str, List]] = None,
60
+ include_defaults: bool = True,
61
+ ) -> dict[str, dict]:
62
+ """Creates a dictionary of tasks indexes.
63
+
64
+ :param include_path: Union[str, List] = None
65
+ An additional path to be searched for tasks recursively.
66
+ Can provide more than one such path as a list.
67
+ :param include_defaults: bool = True
68
+ If set to false, default tasks (those in dllm_eval/tasks/) are not indexed.
69
+ return
70
+ Dictionary of task names as key and task metadata
71
+ """
72
+ if include_defaults:
73
+ all_paths = [os.path.dirname(os.path.abspath(__file__)) + "/"]
74
+ else:
75
+ all_paths = []
76
+ if include_path is not None:
77
+ if isinstance(include_path, str):
78
+ include_path = [include_path]
79
+ all_paths.extend(include_path)
80
+
81
+ task_index = {}
82
+ for task_dir in all_paths:
83
+ tasks = self._get_task_and_group(task_dir)
84
+ task_index = {**tasks, **task_index}
85
+
86
+ return task_index
87
+
88
    @property
    def all_tasks(self):
        """Sorted names of every indexed entry (tasks, groups, and tags)."""
        return self._all_tasks

    @property
    def all_groups(self):
        """Sorted names of entries indexed with type "group"."""
        return self._all_groups

    @property
    def all_subtasks(self):
        """Sorted names of entries indexed with type "task" or "python_task"."""
        return self._all_subtasks

    @property
    def all_tags(self):
        """Sorted names of entries indexed with type "tag"."""
        return self._all_tags

    @property
    def task_index(self):
        """The raw task-name -> metadata index built by initialize_tasks()."""
        return self._task_index
107
+
108
+ def list_all_tasks(
109
+ self, list_groups=True, list_tags=True, list_subtasks=True
110
+ ) -> str:
111
+ from pytablewriter import MarkdownTableWriter
112
+
113
+ def sanitize_path(path):
114
+ # don't print full path if we are within the dllm_eval/tasks dir !
115
+ # if we aren't though, provide the full path.
116
+ if "dllm_eval/tasks/" in path:
117
+ return "dllm_eval/tasks/" + path.split("dllm_eval/tasks/")[-1]
118
+ else:
119
+ return path
120
+
121
+ group_table = MarkdownTableWriter()
122
+ group_table.headers = ["Group", "Config Location"]
123
+ gt_values = []
124
+ for g in self.all_groups:
125
+ path = self.task_index[g]["yaml_path"]
126
+ if path == -1:
127
+ path = "---"
128
+ else:
129
+ path = sanitize_path(path)
130
+ gt_values.append([g, path])
131
+ group_table.value_matrix = gt_values
132
+
133
+ tag_table = MarkdownTableWriter()
134
+ tag_table.headers = ["Tag"]
135
+ tag_table.value_matrix = [[t] for t in self.all_tags]
136
+
137
+ subtask_table = MarkdownTableWriter()
138
+ subtask_table.headers = ["Task", "Config Location", "Output Type"]
139
+ st_values = []
140
+ for t in self.all_subtasks:
141
+ path = self.task_index[t]["yaml_path"]
142
+
143
+ output_type = ""
144
+
145
+ # read the yaml file to determine the output type
146
+ if path != -1:
147
+ config = utils.load_yaml_config(path, mode="simple")
148
+ if "output_type" in config:
149
+ output_type = config["output_type"]
150
+ elif (
151
+ "include" in config
152
+ ): # if no output type, check if there is an include with an output type
153
+ include_path = path.split("/")[:-1] + config["include"]
154
+ include_config = utils.load_yaml_config(include_path, mode="simple")
155
+ if "output_type" in include_config:
156
+ output_type = include_config["output_type"]
157
+
158
+ if path == -1:
159
+ path = "---"
160
+ else:
161
+ path = sanitize_path(path)
162
+ st_values.append([t, path, output_type])
163
+ subtask_table.value_matrix = st_values
164
+
165
+ result = "\n"
166
+ if list_groups:
167
+ result += group_table.dumps() + "\n\n"
168
+ if list_tags:
169
+ result += tag_table.dumps() + "\n\n"
170
+ if list_subtasks:
171
+ result += subtask_table.dumps() + "\n\n"
172
+ return result
173
+
174
    def match_tasks(self, task_list: list[str]) -> list[str]:
        """Expand the (possibly wildcard) names in `task_list` against all
        registered task names via utils.pattern_match."""
        return utils.pattern_match(task_list, self.all_tasks)
176
+
177
+ def _name_is_registered(self, name: str) -> bool:
178
+ if name in self.all_tasks:
179
+ return True
180
+ return False
181
+
182
+ def _name_is_task(self, name: str) -> bool:
183
+ if self._name_is_registered(name) and (self.task_index[name]["type"] == "task"):
184
+ return True
185
+ return False
186
+
187
+ def _name_is_tag(self, name: str) -> bool:
188
+ if self._name_is_registered(name) and (self.task_index[name]["type"] == "tag"):
189
+ return True
190
+ return False
191
+
192
+ def _name_is_group(self, name: str) -> bool:
193
+ if self._name_is_registered(name) and (
194
+ self.task_index[name]["type"] == "group"
195
+ ):
196
+ return True
197
+ return False
198
+
199
+ def _name_is_python_task(self, name: str) -> bool:
200
+ if self._name_is_registered(name) and (
201
+ self.task_index[name]["type"] == "python_task"
202
+ ):
203
+ return True
204
+ return False
205
+
206
+ def _config_is_task(self, config: dict) -> bool:
207
+ if ("task" in config) and isinstance(config["task"], str):
208
+ return True
209
+ return False
210
+
211
+ def _config_is_group(self, config: dict) -> bool:
212
+ if ("task" in config) and isinstance(config["task"], list):
213
+ return True
214
+ return False
215
+
216
+ def _config_is_python_task(self, config: dict) -> bool:
217
+ if "class" in config:
218
+ return True
219
+ return False
220
+
221
+ def _get_yaml_path(self, name: str):
222
+ if name not in self.task_index:
223
+ raise ValueError
224
+ return self.task_index[name]["yaml_path"]
225
+
226
+ def _get_config(self, name):
227
+ if name not in self.task_index:
228
+ raise ValueError
229
+ yaml_path = self._get_yaml_path(name)
230
+ if yaml_path == -1:
231
+ return {}
232
+ else:
233
+ return utils.load_yaml_config(yaml_path, mode="full")
234
+
235
+ def _get_tasklist(self, name):
236
+ if self._name_is_task(name):
237
+ raise ValueError
238
+ return self.task_index[name]["task"]
239
+
240
+ def _process_alias(self, config, group=None):
241
+ # If the group is not the same as the original
242
+ # group which the group alias was intended for,
243
+ # Set the group_alias to None instead.
244
+ if ("group_alias" in config) and ("group" in config) and group is not None:
245
+ if config["group"] != group:
246
+ config["group_alias"] = None
247
+ return config
248
+
249
+ def _class_has_config_in_constructor(self, cls):
250
+ constructor = getattr(cls, "__init__", None)
251
+ return (
252
+ "config" in inspect.signature(constructor).parameters
253
+ if constructor
254
+ else False
255
+ )
256
+
257
+ def _load_individual_task_or_group(
258
+ self,
259
+ name_or_config: Optional[Union[str, dict]] = None,
260
+ parent_name: Optional[str] = None,
261
+ update_config: Optional[dict] = None,
262
+ ) -> Mapping:
263
+ def _load_task(config, task):
264
+ if "include" in config:
265
+ config = {
266
+ **utils.load_yaml_config(
267
+ yaml_path=None,
268
+ yaml_config={"include": config.pop("include")},
269
+ mode="full",
270
+ ),
271
+ **config,
272
+ }
273
+ if self._config_is_python_task(config):
274
+ if self._class_has_config_in_constructor(config["class"]):
275
+ task_object = config["class"](config=config)
276
+ else:
277
+ task_object = config["class"]()
278
+ if isinstance(task_object, ConfigurableTask):
279
+ # very scuffed: set task name here. TODO: fixme?
280
+ task_object.config.task = task
281
+ else:
282
+ if self.metadata is not None:
283
+ config["metadata"] = config.get("metadata", {}) | self.metadata
284
+ else:
285
+ config["metadata"] = config.get("metadata", {})
286
+ task_object = ConfigurableTask(config=config)
287
+
288
+ return {task: task_object}
289
+
290
+ def _get_group_and_subtask_from_config(
291
+ config: dict,
292
+ ) -> tuple[ConfigurableGroup, list[str]]:
293
+ if self.metadata is not None:
294
+ config["metadata"] = config.get("metadata", {}) | self.metadata
295
+ group_name = ConfigurableGroup(config=config)
296
+ subtask_list = []
297
+ for task in group_name.config["task"]:
298
+ if isinstance(task, str) and self._name_is_tag(task):
299
+ subtask_list.extend(self._get_tasklist(task))
300
+ else:
301
+ subtask_list.append(task)
302
+ return group_name, subtask_list
303
+
304
+ def _process_group_config(
305
+ config: dict, update_config: dict = None
306
+ ) -> tuple[dict, dict]:
307
+ if update_config is not None:
308
+ config = {**config, **update_config}
309
+ _update_config = {
310
+ k: v for k, v in config.items() if k not in GROUP_ONLY_KEYS
311
+ }
312
+ if not bool(_update_config):
313
+ _update_config = None
314
+
315
+ group_config = {k: v for k, v in config.items() if k in GROUP_ONLY_KEYS}
316
+ return group_config, _update_config
317
+
318
+ if isinstance(name_or_config, str):
319
+ if update_config is not None:
320
+ # Process name_or_config as a dict instead
321
+ name_or_config = {"task": name_or_config, **update_config}
322
+ elif self._name_is_task(name_or_config) or self._name_is_python_task(
323
+ name_or_config
324
+ ):
325
+ task_config = self._get_config(name_or_config)
326
+ return _load_task(task_config, task=name_or_config)
327
+ else:
328
+ subtask_list = self._get_tasklist(name_or_config)
329
+ if subtask_list == -1:
330
+ group_config = self._get_config(name_or_config)
331
+ group_config, update_config = _process_group_config(group_config)
332
+ group_name, subtask_list = _get_group_and_subtask_from_config(
333
+ group_config
334
+ )
335
+ else:
336
+ if self._name_is_tag(name_or_config):
337
+ fn = partial(
338
+ self._load_individual_task_or_group,
339
+ update_config=name_or_config
340
+ if isinstance(name_or_config, dict)
341
+ else None,
342
+ )
343
+ return dict(
344
+ collections.ChainMap(*map(fn, reversed(subtask_list)))
345
+ )
346
+ else:
347
+ group_name = ConfigurableGroup(
348
+ config={"group": name_or_config, "task": subtask_list}
349
+ )
350
+
351
+ if isinstance(name_or_config, dict):
352
+ if self._config_is_task(name_or_config):
353
+ name = name_or_config.pop("task")
354
+ if update_config is not None:
355
+ name_or_config = {**name_or_config, **update_config}
356
+ # If the name is registered as a group
357
+ if self._name_is_group(name):
358
+ group_config = self._get_config(name)
359
+
360
+ group_config, update_config = _process_group_config(
361
+ group_config, name_or_config
362
+ )
363
+ group_name, subtask_list = _get_group_and_subtask_from_config(
364
+ group_config
365
+ )
366
+ elif self._name_is_tag(name):
367
+ subtask_list = self._get_tasklist(name)
368
+ fn = partial(
369
+ self._load_individual_task_or_group,
370
+ update_config=name_or_config,
371
+ )
372
+ return dict(collections.ChainMap(*map(fn, reversed(subtask_list))))
373
+ else:
374
+ if self._name_is_registered(name):
375
+ base_task_config = self._get_config(name)
376
+
377
+ # Check if this is a duplicate.
378
+ if parent_name is not None:
379
+ num_duplicate = len(
380
+ list(
381
+ filter(
382
+ lambda x: x.startswith(name),
383
+ self.task_group_map[parent_name],
384
+ )
385
+ )
386
+ )
387
+ if num_duplicate > 0:
388
+ name = f"{name}-{num_duplicate}"
389
+ self.task_group_map[parent_name].append(name)
390
+
391
+ task_config = {
392
+ **base_task_config,
393
+ **name_or_config,
394
+ }
395
+ else:
396
+ task_config = name_or_config
397
+ return _load_task(task_config, task=name)
398
+ else:
399
+ group_config, update_config = _process_group_config(name_or_config)
400
+ group_name, subtask_list = _get_group_and_subtask_from_config(
401
+ group_config
402
+ )
403
+
404
+ fn = partial(
405
+ self._load_individual_task_or_group,
406
+ parent_name=group_name,
407
+ update_config=update_config,
408
+ )
409
+ return {
410
+ group_name: dict(collections.ChainMap(*map(fn, reversed(subtask_list))))
411
+ }
412
+
413
def load_task_or_group(self, task_list: Optional[Union[str, list]] = None) -> dict:
    """Loads a dictionary of task objects from a task name or list of names.

    :param task_list: Optional[Union[str, list]]
        Single task name or list of task names to be loaded. If None
        (the declared default), an empty dict is returned.

    :return
        Dictionary of task objects

    BUGFIX: the original crashed with a TypeError when called with the
    default `task_list=None` (mapping over None); None now yields {}.
    """
    if task_list is None:
        return {}
    if isinstance(task_list, str):
        task_list = [task_list]

    # Later entries win on name collisions, matching ChainMap semantics.
    return dict(
        collections.ChainMap(
            *(self._load_individual_task_or_group(task) for task in task_list)
        )
    )
434
+
435
def load_config(self, config: Dict):
    """Load a single task or group directly from an in-memory config dict."""
    loaded = self._load_individual_task_or_group(config)
    return loaded
437
+
438
def _get_task_and_group(self, task_dir: str):
    """Creates a dictionary of tasks indexed with the following metadata:
    - `type`, one of `task`, `python_task`, `group` or `tag`.
        `task` refers to regular task configs, `python_task` are special
        yaml files that only consist of `task` and `class` parameters.
        `group` are group configs. `tag`s are labels that can be assigned
        to tasks to assist in sorting and calling tasks of certain themes.
    - `yaml_path`, path to the yaml file. If the entry is a `group` that
        was configured through a task config, the yaml_path will be -1
        and all subtasks will be listed in `task` (see below).
    - `task`, reserved for entries with `type` as `group`. When a group
        config is created (as opposed to a task config having a `group`
        parameter set), this is set to -1 to avoid recursive indexing;
        the full subtask list is loaded lazily at evaluation time.

    :param task_dir: str
        A directory to check for tasks

    :return
        Dictionary of task names as key and task metadata
    """

    def _populate_tags_and_groups(config, task, tasks_and_groups, print_info):
        # Register `task` under every tag listed in its config.
        # (`print_info` is threaded through for a pending deprecation
        # notice but is currently unused.)
        # TODO: remove group in next release
        if "tag" in config:
            attr_list = config["tag"]
            if isinstance(attr_list, str):
                attr_list = [attr_list]

            for tag in attr_list:
                if tag not in tasks_and_groups:
                    tasks_and_groups[tag] = {
                        "type": "tag",
                        "task": [task],
                        "yaml_path": -1,
                    }
                elif tasks_and_groups[tag]["type"] != "tag":
                    eval_logger.info(
                        f"The tag '{tag}' is already registered as a group, this tag will not be registered. "
                        "This may affect tasks you want to call."
                    )
                    break
                else:
                    tasks_and_groups[tag]["task"].append(task)

    # TODO: remove group in next release
    print_info = True
    ignore_dirs = [
        "__pycache__",
        ".ipynb_checkpoints",
    ]
    tasks_and_groups = collections.defaultdict()
    for root, dirs, file_list in os.walk(task_dir):
        # Prune ignored directories in place so os.walk skips them.
        dirs[:] = [d for d in dirs if d not in ignore_dirs]
        for f in file_list:
            if f.endswith(".yaml"):
                yaml_path = os.path.join(root, f)
                # BUGFIX: removed a stray debug `print(yaml_path)` that
                # spammed stdout for every yaml file during indexing.
                config = utils.load_yaml_config(yaml_path, mode="simple")
                if self._config_is_python_task(config):
                    # This is a python class config
                    task = config["task"]
                    tasks_and_groups[task] = {
                        "type": "python_task",
                        "yaml_path": yaml_path,
                    }
                    _populate_tags_and_groups(
                        config, task, tasks_and_groups, print_info
                    )
                elif self._config_is_group(config):
                    # This is a group config. `task: -1` signals that we
                    # don't need to index the subtask list here; it can be
                    # loaded lazily when the group is called.
                    tasks_and_groups[config["group"]] = {
                        "type": "group",
                        "task": -1,
                        "yaml_path": yaml_path,
                    }
                elif self._config_is_task(config):
                    # This is a task config
                    task = config["task"]
                    tasks_and_groups[task] = {
                        "type": "task",
                        "yaml_path": yaml_path,
                    }
                    _populate_tags_and_groups(
                        config, task, tasks_and_groups, print_info
                    )
                else:
                    eval_logger.debug(f"File {f} in {root} could not be loaded")

    return tasks_and_groups
543
+
544
+
545
def get_task_name_from_config(task_config: Dict[str, str]) -> str:
    """Derive a display name for a task config.

    Prefers the explicit `task` key; otherwise falls back to
    `dataset_path` optionally suffixed with `_dataset_name`.
    """
    if "task" in task_config:
        return task_config["task"]
    dataset_path = task_config["dataset_path"]
    if "dataset_name" in task_config:
        return f"{dataset_path}_{task_config['dataset_name']}"
    return f"{dataset_path}"
552
+
553
+
554
def get_task_name_from_object(task_object):
    """Best-effort name for an already-constructed task object.

    NOTE(review): the guard checks for a `config` attribute but then reads
    `_config` — presumably both exist on ConfigurableTask; confirm upstream.
    """
    if hasattr(task_object, "config"):
        return task_object._config["task"]

    # TODO: scrap this
    # this gives a mechanism for non-registered tasks to have a custom name anyways when reporting
    return getattr(task_object, "EVAL_HARNESS_NAME", type(task_object).__name__)
565
+
566
+
567
+ def _check_duplicates(task_dict: dict) -> None:
568
+ """helper function solely used in validating get_task_dict output.
569
+ Takes the output of dllm_eval.evaluator_utils.get_subtask_list and
570
+ returns a list of all leaf subtasks contained within, and errors if any such leaf subtasks are
571
+ "oversubscribed" to several disjoint groups.
572
+ """
573
+ subtask_names = []
574
+ for key, value in task_dict.items():
575
+ subtask_names.extend(value)
576
+
577
+ duplicate_tasks = {
578
+ task_name for task_name in subtask_names if subtask_names.count(task_name) > 1
579
+ }
580
+
581
+ # locate the potentially problematic groups that seem to 'compete' for constituent subtasks
582
+ competing_groups = [
583
+ group
584
+ for group in task_dict.keys()
585
+ if len(set(task_dict[group]).intersection(duplicate_tasks)) > 0
586
+ ]
587
+
588
+ if len(duplicate_tasks) > 0:
589
+ raise ValueError(
590
+ f"Found 1 or more tasks while trying to call get_task_dict() that were members of more than 1 called group: {list(duplicate_tasks)}. Offending groups: {competing_groups}. Please call groups which overlap their constituent tasks in separate evaluation runs."
591
+ )
592
+
593
+
594
def get_task_dict(
    task_name_list: Union[str, List[Union[str, Dict, Task]]],
    task_manager: Optional[TaskManager] = None,
):
    """Creates a dictionary of task objects from either a name of task, config, or prepared Task object.

    :param task_name_list: List[Union[str, Dict, Task]]
        Task names (str), task config dicts, or pre-built Task objects.
    :param task_manager: TaskManager = None
        A TaskManager object that stores indexed tasks. If not set,
        one is created on demand. This should be set by the user
        if there are additional paths that want to be included
        via `include_path`.

    :return
        Dictionary of task objects
    """

    task_name_from_string_dict = {}
    task_name_from_config_dict = {}
    task_name_from_object_dict = {}

    if isinstance(task_name_list, str):
        task_name_list = [task_name_list]
    elif isinstance(task_name_list, list):
        if not all(isinstance(task, (str, dict, Task)) for task in task_name_list):
            raise TypeError(
                "Expected all list items to be of types 'str', 'dict', or 'Task', but at least one entry did not match."
            )
    else:
        raise TypeError(
            f"Expected a 'str' or 'list' but received {type(task_name_list)}."
        )

    string_task_name_list = [task for task in task_name_list if isinstance(task, str)]
    others_task_name_list = [
        task for task in task_name_list if not isinstance(task, str)
    ]

    # BUGFIX: previously a TaskManager was only instantiated when string task
    # names were present, so calling with only config dicts (and no explicit
    # task_manager) crashed with AttributeError on `None.load_config`.
    needs_manager = bool(string_task_name_list) or any(
        isinstance(task, dict) for task in others_task_name_list
    )
    if task_manager is None and needs_manager:
        task_manager = TaskManager()

    if len(string_task_name_list) > 0:
        task_name_from_string_dict = task_manager.load_task_or_group(
            string_task_name_list
        )

    for task_element in others_task_name_list:
        if isinstance(task_element, dict):
            task_name_from_config_dict = {
                **task_name_from_config_dict,
                **task_manager.load_config(config=task_element),
            }
        elif isinstance(task_element, Task):
            task_name_from_object_dict = {
                **task_name_from_object_dict,
                get_task_name_from_object(task_element): task_element,
            }

    if not set(task_name_from_string_dict.keys()).isdisjoint(
        set(task_name_from_object_dict.keys())
    ):
        # Previously a bare `raise ValueError` with no message.
        raise ValueError(
            "Task names resolved from strings overlap with names of the "
            "pre-built Task objects passed in; each task may only be "
            "provided once."
        )

    final_task_dict = {
        **task_name_from_string_dict,
        **task_name_from_config_dict,
        **task_name_from_object_dict,
    }

    # behavior can get odd if one tries to invoke several groups that "compete" for the same task.
    # (notably, because one could request several num_fewshot values at once in GroupConfig overrides for the subtask
    # and we'd be unsure which to use and report.)
    # we explicitly check and error in this case.
    _check_duplicates(get_subtask_list(final_task_dict))

    return final_task_dict
Prism/LLaDA/LLaDA_Baseline/dllm_eval/tasks/gsm8k/gsm8k.yaml ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ task: gsm8k
2
+ dataset_path: openai/gsm8k
3
+ dataset_name: main
4
+ output_type: generate_until
5
+ training_split: train
6
+ fewshot_split: train
7
+ test_split: test
8
+ doc_to_text: !function utils.gsm_prompt
9
+ doc_to_target: "{{answer.split('####')[-1].strip()}}"
10
+ generation_kwargs:
11
+ until:
12
+ - "[NO_UNTIL_PLACEHOLDER]"
13
+ do_sample: false
14
+ repeats: 1
15
+ num_fewshot: 0
Prism/LLaDA/LLaDA_Baseline/dllm_eval/tasks/gsm8k/utils.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
def gsm_prompt(doc):
    """Render the GSM8K question in ``doc`` into the full generation prompt.

    The fixed instruction header asks for step-by-step reasoning in
    <reasoning> tags and a \\boxed{} final answer in <answer> tags; the
    question text is appended after a blank line.
    """
    header = (
        "You are a math expert. You will be given a question to solve. Solve it step by step. Wrap the final answer in a \\boxed{}. \n"
        "Respond in the following format:\n"
        "<reasoning>\n"
        "Your reasoning here\n"
        "</reasoning>\n"
        "<answer>\n"
        "\\boxed{...}\n"
        "</answer>"
    )
    return f"{header}\n\n{doc['question']}\n\n"
Prism/LLaDA/LLaDA_Baseline/dllm_eval/tasks/humaneval/humaneval.yaml ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ task: humaneval
2
+ dataset_path: openai/openai_humaneval
3
+ unsafe_code: true
4
+ output_type: generate_until
5
+ test_split: test
6
+ doc_to_text: "Write a solution to the following problem and make sure that it passes the tests:\n{{prompt}}\n\nFirst, reason about the solution step-by-step. Then, write the code.\nRespond in the following format:\n<reasoning>\nYour reasoning here\n</reasoning>\n<answer>\n```python\nThe complete implementation of the {{entry_point}} function\n```\n</answer>"
7
+ doc_to_target: "{{test}}\ncheck({{entry_point}})"
8
+ generation_kwargs:
9
+ until:
10
+ - "[NO_UNTIL_PLACEHOLDER]"
11
+ do_sample: false
12
+ repeats: 1
13
+ num_fewshot: 0
Prism/LLaDA/LLaDA_Baseline/dllm_eval/tasks/humaneval/utils.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import evaluate as hf_evaluate
2
+
3
+
4
# Load the HF `code_eval` metric and run a tiny sanity check at import time,
# so a broken install or a missing HF_ALLOW_CODE_EVAL opt-in fails fast,
# before any model generation is spent.
try:
    compute_ = hf_evaluate.load("code_eval")
    test_cases = ["assert add(2, 3)==5"]
    candidates = [["def add(a,b): return a*b"]]
    results = compute_.compute(references=test_cases, predictions=candidates, k=[1])
except Exception:
    # Bare `raise` preserves the original traceback; the previous
    # `except Exception as e: raise e` added a useless re-raise frame.
    raise
11
+
12
+
13
def pass_at_k(
    references: list[str],
    predictions: list[list[str]],
    k: int | list[int] | None = None,
):
    """Compute HF `code_eval` pass@k scores.

    :param references: one test-harness string per problem
    :param predictions: per-problem lists of candidate programs
    :param k: the k value(s) to score; an int is promoted to a one-element
        list. Must be provided.
    :return: dict mapping "pass@k" keys to float scores (first element of
        the metric's (scores, details) tuple)
    :raises ValueError: if `k` is None (was an `assert`, which is stripped
        under `python -O`)
    """
    if k is None:
        raise ValueError("`k` must be provided (an int or a list of ints).")
    if isinstance(k, int):
        k = [k]
    # `compute_` is the module-level code_eval metric; reading a global
    # needs no `global` declaration (the original's was redundant).
    res = compute_.compute(
        references=references,
        predictions=predictions,
        k=k,
    )
    return res[0]
24
+
25
+
26
def clean_response_string(r: str) -> str:
    """Trim a model response down to its final fenced Python block.

    Keeps everything from the last "```python" marker (inclusive), cuts at
    the last "```" occurrence, then cuts at any trailing
    `if __name__ == "__main__":` guard. Each cut is skipped when its marker
    is absent.
    """
    out = r
    start = out.rfind("```python")
    if start != -1:
        out = out[start:]
    end = out.rfind("```")
    if end != -1:
        out = out[:end]
    guard = out.rfind("if __name__ == \"__main__\":")
    if guard != -1:
        out = out[:guard]
    return out
31
+
32
+
33
def build_predictions(
    resps: list[list[str]], docs: list[dict]
) -> list[list[str]]:
    """Post-process raw model responses into code_eval candidate lists.

    NOTE(review): an earlier, identically-named definition that prepended
    ``doc["prompt"]`` to each response was immediately shadowed by this one
    and has been removed as dead code.

    :param resps: per-document lists of raw model completions
    :param docs: corresponding dataset documents (only used for pairing)
    """
    return [
        [clean_response_string(r) for r in resp]
        for resp, doc in zip(resps, docs)
    ]
Prism/LLaDA/LLaDA_Baseline/dllm_eval/tasks/mbpp/mbpp.yaml ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ task: mbpp
2
+ dataset_path: google-research-datasets/mbpp
3
+ dataset_name: full
4
+ unsafe_code: true
5
+ output_type: generate_until
6
+ test_split: test
7
+ doc_to_text: "\n{{text}} Your code should pass these tests:\n\n{{test_list[0]}}\n{{test_list[1]}}\n{{test_list[2]}} \n\nFirst, reason about the solution step-by-step. Then, write the code.\nRespond in the following format:\n<reasoning>\nYour reasoning here\n</reasoning>\n<answer>\n```python\nThe complete implementation of the function\n```\n</answer>"
8
+ doc_to_target: "{% if is_fewshot is defined %}{{code}}\n[DONE]{% else %}{{test_list[0]}}\n{{test_list[1]}}\n{{test_list[2]}}{% endif %}"
9
+ target_delimiter: ""
10
+ generation_kwargs:
11
+ until:
12
+ - "[NO_UNTIL_PLACEHOLDER]"
13
+ do_sample: false
14
+ num_fewshot: 0
Prism/LLaDA/LLaDA_Baseline/dllm_eval/tasks/mbpp/utils.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from typing import Union
3
+
4
+ import evaluate as hf_evaluate
5
+
6
+
7
# Load the HF `code_eval` metric and sanity-check it at import time, so that
# code execution being disabled (HF_ALLOW_CODE_EVAL unset) surfaces before
# any model generation is spent.
try:
    pass_at_k = hf_evaluate.load("code_eval")

    # run simple test to check code execution is enabled before model generation
    test_cases = ["assert add(2, 3)==5"]
    candidates = [["def add(a,b): return a*b"]]
    results = pass_at_k.compute(references=test_cases, predictions=candidates, k=[1])
except Exception:
    # Bare `raise` keeps the original traceback; `raise e` obscured the
    # real failure site with an extra frame.
    raise
16
+
17
+
18
def pass_at_1(
    references: Union[str, list[str]], predictions: Union[str, list[list[str]]]
) -> float:
    """Compute pass@1 for MBPP via the module-level `pass_at_k` code_eval metric.

    :param references: test harness string(s); a bare str is wrapped in a list
    :param predictions: candidate program(s); a flat list of strings is
        promoted to one single-candidate list per entry
    :return: the "pass@1" score as a float

    Robustness fix: an empty `predictions` no longer raises IndexError on
    the `predictions[0]` type probe.
    """
    if isinstance(references, str):
        references = [references]
    if predictions and isinstance(predictions[0], str):
        predictions = [[p] for p in predictions]
    return pass_at_k.compute(
        references=references,
        predictions=predictions,
        k=[1],
        num_workers=48,  # parallel sandbox workers for test execution
    )[0]["pass@1"]
31
+
32
+
33
def extract_code_blocks(text: str) -> str:
    """Strip generation-terminator markers from a completion.

    NOTE: despite the name, this does not extract fenced code blocks — it
    only removes the literal stop/EOS tokens "[DONE]", "<|eot_id|>" and
    "<|endoftext|>" (plain substring removal; the original regexes matched
    these literals only).
    """
    for marker in ("[DONE]", "<|eot_id|>", "<|endoftext|>"):
        text = text.replace(marker, "")
    return text
38
+
39
+
40
def build_predictions(resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
    """Strip terminator markers from every completion, preserving shape.

    `docs` is accepted for interface parity with other tasks but unused.
    """
    cleaned = []
    for resp in resps:
        cleaned.append([extract_code_blocks(r) for r in resp])
    return cleaned
42
+
43
+
44
def list_fewshot_samples():
    """Return the three hand-picked MBPP few-shot exemplars.

    Each dict mirrors an MBPP dataset row (`task_id`, `text`, `code`,
    `test_list`) plus `is_fewshot: True`, which the task's `doc_to_target`
    template uses to emit the gold code followed by the "[DONE]" terminator
    instead of the test assertions.
    """
    return [
        {
            "task_id": 2,
            "text": "Write a function to find the similar elements from the given two tuple lists.",
            "code": "def similar_elements(test_tup1, test_tup2):\r\n res = tuple(set(test_tup1) & set(test_tup2))\r\n return (res) ",
            "test_list": [
                "assert similar_elements((3, 4, 5, 6),(5, 7, 4, 10)) == (4, 5)",
                "assert similar_elements((1, 2, 3, 4),(5, 4, 3, 7)) == (3, 4)",
                "assert similar_elements((11, 12, 14, 13),(17, 15, 14, 13)) == (13, 14)",
            ],
            "is_fewshot": True,
        },
        {
            "task_id": 3,
            "text": "Write a python function to identify non-prime numbers.",
            "code": "import math\r\ndef is_not_prime(n):\r\n result = False\r\n for i in range(2,int(math.sqrt(n)) + 1):\r\n if n % i == 0:\r\n result = True\r\n return result",
            "test_list": [
                "assert is_not_prime(2) == False",
                "assert is_not_prime(10) == True",
                "assert is_not_prime(35) == True",
            ],
            "is_fewshot": True,
        },
        {
            "task_id": 4,
            "text": "Write a function to find the largest integers from a given list of numbers using heap queue algorithm.",
            "code": "import heapq as hq\r\ndef heap_queue_largest(nums,n):\r\n largest_nums = hq.nlargest(n, nums)\r\n return largest_nums",
            "test_list": [
                "assert heap_queue_largest( [25, 35, 22, 85, 14, 65, 75, 22, 58],3)==[85, 75, 65] ",
                "assert heap_queue_largest( [25, 35, 22, 85, 14, 65, 75, 22, 58],2)==[85, 75] ",
                "assert heap_queue_largest( [25, 35, 22, 85, 14, 65, 75, 22, 58],5)==[85, 75, 65, 58, 35]",
            ],
            "is_fewshot": True,
        },
    ]
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/certifi/__main__.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import argparse

from certifi import contents, where

# CLI entry point for `python -m certifi` (vendored certifi package):
# prints the filesystem path of the bundled CA certificate file, or the
# PEM contents themselves when -c/--contents is passed.
parser = argparse.ArgumentParser()
parser.add_argument("-c", "--contents", action="store_true")
args = parser.parse_args()

if args.contents:
    print(contents())
else:
    print(where())
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/frozenlist-1.5.0.dist-info/INSTALLER ADDED
@@ -0,0 +1 @@
 
 
1
+ pip
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/frozenlist-1.5.0.dist-info/LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "{}"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright 2013-2019 Nikolay Kim and Andrew Svetlov
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/frozenlist-1.5.0.dist-info/METADATA ADDED
@@ -0,0 +1,477 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Metadata-Version: 2.1
2
+ Name: frozenlist
3
+ Version: 1.5.0
4
+ Summary: A list-like structure which implements collections.abc.MutableSequence
5
+ Home-page: https://github.com/aio-libs/frozenlist
6
+ Maintainer: aiohttp team <team@aiohttp.org>
7
+ Maintainer-email: team@aiohttp.org
8
+ License: Apache 2
9
+ Project-URL: Chat: Matrix, https://matrix.to/#/#aio-libs:matrix.org
10
+ Project-URL: Chat: Matrix Space, https://matrix.to/#/#aio-libs-space:matrix.org
11
+ Project-URL: CI: Github Actions, https://github.com/aio-libs/frozenlist/actions
12
+ Project-URL: Code of Conduct, https://github.com/aio-libs/.github/blob/master/CODE_OF_CONDUCT.md
13
+ Project-URL: Coverage: codecov, https://codecov.io/github/aio-libs/frozenlist
14
+ Project-URL: Docs: Changelog, https://github.com/aio-libs/frozenlist/blob/master/CHANGES.rst#changelog
15
+ Project-URL: Docs: RTD, https://frozenlist.aio-libs.org
16
+ Project-URL: GitHub: issues, https://github.com/aio-libs/frozenlist/issues
17
+ Project-URL: GitHub: repo, https://github.com/aio-libs/frozenlist
18
+ Classifier: Development Status :: 5 - Production/Stable
19
+ Classifier: Intended Audience :: Developers
20
+ Classifier: License :: OSI Approved :: Apache Software License
21
+ Classifier: Operating System :: POSIX
22
+ Classifier: Operating System :: MacOS :: MacOS X
23
+ Classifier: Operating System :: Microsoft :: Windows
24
+ Classifier: Programming Language :: Cython
25
+ Classifier: Programming Language :: Python
26
+ Classifier: Programming Language :: Python :: 3
27
+ Classifier: Programming Language :: Python :: 3.8
28
+ Classifier: Programming Language :: Python :: 3.9
29
+ Classifier: Programming Language :: Python :: 3.10
30
+ Classifier: Programming Language :: Python :: 3.11
31
+ Classifier: Programming Language :: Python :: 3.12
32
+ Classifier: Programming Language :: Python :: 3.13
33
+ Classifier: Programming Language :: Python :: Implementation :: CPython
34
+ Classifier: Programming Language :: Python :: Implementation :: PyPy
35
+ Requires-Python: >=3.8
36
+ Description-Content-Type: text/x-rst
37
+ License-File: LICENSE
38
+
39
+ frozenlist
40
+ ==========
41
+
42
+ .. image:: https://github.com/aio-libs/frozenlist/workflows/CI/badge.svg
43
+ :target: https://github.com/aio-libs/frozenlist/actions
44
+ :alt: GitHub status for master branch
45
+
46
+ .. image:: https://codecov.io/gh/aio-libs/frozenlist/branch/master/graph/badge.svg
47
+ :target: https://codecov.io/gh/aio-libs/frozenlist
48
+ :alt: codecov.io status for master branch
49
+
50
+ .. image:: https://img.shields.io/pypi/v/frozenlist.svg?logo=Python&logoColor=white
51
+ :target: https://pypi.org/project/frozenlist
52
+ :alt: frozenlist @ PyPI
53
+
54
+ .. image:: https://readthedocs.org/projects/frozenlist/badge/?version=latest
55
+ :target: https://frozenlist.aio-libs.org
56
+ :alt: Read The Docs build status badge
57
+
58
+ .. image:: https://img.shields.io/matrix/aio-libs:matrix.org?label=Discuss%20on%20Matrix%20at%20%23aio-libs%3Amatrix.org&logo=matrix&server_fqdn=matrix.org&style=flat
59
+ :target: https://matrix.to/#/%23aio-libs:matrix.org
60
+ :alt: Matrix Room — #aio-libs:matrix.org
61
+
62
+ .. image:: https://img.shields.io/matrix/aio-libs-space:matrix.org?label=Discuss%20on%20Matrix%20at%20%23aio-libs-space%3Amatrix.org&logo=matrix&server_fqdn=matrix.org&style=flat
63
+ :target: https://matrix.to/#/%23aio-libs-space:matrix.org
64
+ :alt: Matrix Space — #aio-libs-space:matrix.org
65
+
66
+ Introduction
67
+ ------------
68
+
69
+ ``frozenlist.FrozenList`` is a list-like structure which implements
70
+ ``collections.abc.MutableSequence``. The list is *mutable* until ``FrozenList.freeze``
71
+ is called, after which list modifications raise ``RuntimeError``:
72
+
73
+
74
+ >>> from frozenlist import FrozenList
75
+ >>> fl = FrozenList([17, 42])
76
+ >>> fl.append('spam')
77
+ >>> fl.append('Vikings')
78
+ >>> fl
79
+ <FrozenList(frozen=False, [17, 42, 'spam', 'Vikings'])>
80
+ >>> fl.freeze()
81
+ >>> fl
82
+ <FrozenList(frozen=True, [17, 42, 'spam', 'Vikings'])>
83
+ >>> fl.frozen
84
+ True
85
+ >>> fl.append("Monty")
86
+ Traceback (most recent call last):
87
+ File "<stdin>", line 1, in <module>
88
+ File "frozenlist/_frozenlist.pyx", line 97, in frozenlist._frozenlist.FrozenList.append
89
+ self._check_frozen()
90
+ File "frozenlist/_frozenlist.pyx", line 19, in frozenlist._frozenlist.FrozenList._check_frozen
91
+ raise RuntimeError("Cannot modify frozen list.")
92
+ RuntimeError: Cannot modify frozen list.
93
+
94
+
95
+ FrozenList is also hashable, but only when frozen. Otherwise it also throws a RuntimeError:
96
+
97
+
98
+ >>> fl = FrozenList([17, 42, 'spam'])
99
+ >>> hash(fl)
100
+ Traceback (most recent call last):
101
+ File "<stdin>", line 1, in <module>
102
+ File "frozenlist/_frozenlist.pyx", line 111, in frozenlist._frozenlist.FrozenList.__hash__
103
+ raise RuntimeError("Cannot hash unfrozen list.")
104
+ RuntimeError: Cannot hash unfrozen list.
105
+ >>> fl.freeze()
106
+ >>> hash(fl)
107
+ 3713081631934410656
108
+ >>> dictionary = {fl: 'Vikings'} # frozen fl can be a dict key
109
+ >>> dictionary
110
+ {<FrozenList(frozen=True, [1, 2])>: 'Vikings'}
111
+
112
+
113
+ Installation
114
+ ------------
115
+
116
+ ::
117
+
118
+ $ pip install frozenlist
119
+
120
+ The library requires Python 3.8 or newer.
121
+
122
+
123
+ Documentation
124
+ -------------
125
+
126
+ https://frozenlist.aio-libs.org
127
+
128
+ Communication channels
129
+ ----------------------
130
+
131
+ We have a *Matrix Space* `#aio-libs-space:matrix.org
132
+ <https://matrix.to/#/%23aio-libs-space:matrix.org>`_ which is
133
+ also accessible via Gitter.
134
+
135
+ Requirements
136
+ ------------
137
+
138
+ - Python >= 3.8
139
+
140
+ License
141
+ -------
142
+
143
+ ``frozenlist`` is offered under the Apache 2 license.
144
+
145
+ Source code
146
+ -----------
147
+
148
+ The project is hosted on GitHub_
149
+
150
+ Please file an issue in the `bug tracker
151
+ <https://github.com/aio-libs/frozenlist/issues>`_ if you have found a bug
152
+ or have some suggestions to improve the library.
153
+
154
+ .. _GitHub: https://github.com/aio-libs/frozenlist
155
+
156
+ =========
157
+ Changelog
158
+ =========
159
+
160
+ ..
161
+ You should *NOT* be adding new change log entries to this file, this
162
+ file is managed by towncrier. You *may* edit previous change logs to
163
+ fix problems like typo corrections or such.
164
+ To add a new change log entry, please see
165
+ https://pip.pypa.io/en/latest/development/contributing/#news-entries
166
+ we named the news folder "changes".
167
+
168
+ WARNING: Don't drop the next directive!
169
+
170
+ .. towncrier release notes start
171
+
172
+ 1.5.0 (2024-10-22)
173
+ ==================
174
+
175
+ Bug fixes
176
+ ---------
177
+
178
+ - An incorrect signature of the ``__class_getitem__`` class method
179
+ has been fixed, adding a missing ``class_item`` argument under
180
+ Python 3.8 and older.
181
+
182
+ This change also improves the code coverage of this method that
183
+ was previously missing -- by `@webknjaz <https://github.com/sponsors/webknjaz>`__.
184
+
185
+
186
+ *Related issues and pull requests on GitHub:*
187
+ `#567 <https://github.com/aio-libs/frozenlist/issues/567>`__, `#571 <https://github.com/aio-libs/frozenlist/issues/571>`__.
188
+
189
+
190
+ Improved documentation
191
+ ----------------------
192
+
193
+ - Rendered issue, PR, and commit links now lead to
194
+ ``frozenlist``'s repo instead of ``yarl``'s repo.
195
+
196
+
197
+ *Related issues and pull requests on GitHub:*
198
+ `#573 <https://github.com/aio-libs/frozenlist/issues/573>`__.
199
+
200
+ - On the ``Contributing docs`` page,
201
+ a link to the ``Towncrier philosophy`` has been fixed.
202
+
203
+
204
+ *Related issues and pull requests on GitHub:*
205
+ `#574 <https://github.com/aio-libs/frozenlist/issues/574>`__.
206
+
207
+
208
+ Packaging updates and notes for downstreams
209
+ -------------------------------------------
210
+
211
+ - A name of a temporary building directory now reflects
212
+ that it's related to ``frozenlist``, not ``yarl``.
213
+
214
+
215
+ *Related issues and pull requests on GitHub:*
216
+ `#573 <https://github.com/aio-libs/frozenlist/issues/573>`__.
217
+
218
+ - Declared Python 3.13 supported officially in the distribution package metadata.
219
+
220
+
221
+ *Related issues and pull requests on GitHub:*
222
+ `#595 <https://github.com/aio-libs/frozenlist/issues/595>`__.
223
+
224
+
225
+ ----
226
+
227
+
228
+ 1.4.1 (2023-12-15)
229
+ ==================
230
+
231
+ Packaging updates and notes for downstreams
232
+ -------------------------------------------
233
+
234
+ - Declared Python 3.12 and PyPy 3.8-3.10 supported officially
235
+ in the distribution package metadata.
236
+
237
+
238
+ *Related issues and pull requests on GitHub:*
239
+ `#553 <https://github.com/aio-libs/frozenlist/issues/553>`__.
240
+
241
+ - Replaced the packaging is replaced from an old-fashioned ``setup.py`` to an
242
+ in-tree `PEP 517 <https://peps.python.org/pep-517>`__ build backend -- by `@webknjaz <https://github.com/sponsors/webknjaz>`__.
243
+
244
+ Whenever the end-users or downstream packagers need to build ``frozenlist``
245
+ from source (a Git checkout or an sdist), they may pass a ``config_settings``
246
+ flag ``pure-python``. If this flag is not set, a C-extension will be built
247
+ and included into the distribution.
248
+
249
+ Here is how this can be done with ``pip``:
250
+
251
+ .. code-block:: console
252
+
253
+ $ python3 -m pip install . --config-settings=pure-python=
254
+
255
+ This will also work with ``-e | --editable``.
256
+
257
+ The same can be achieved via ``pypa/build``:
258
+
259
+ .. code-block:: console
260
+
261
+ $ python3 -m build --config-setting=pure-python=
262
+
263
+ Adding ``-w | --wheel`` can force ``pypa/build`` produce a wheel from source
264
+ directly, as opposed to building an ``sdist`` and then building from it.
265
+
266
+
267
+ *Related issues and pull requests on GitHub:*
268
+ `#560 <https://github.com/aio-libs/frozenlist/issues/560>`__.
269
+
270
+
271
+ Contributor-facing changes
272
+ --------------------------
273
+
274
+ - It is now possible to request line tracing in Cython builds using the
275
+ ``with-cython-tracing`` `PEP 517 <https://peps.python.org/pep-517>`__ config setting
276
+ -- `@webknjaz <https://github.com/sponsors/webknjaz>`__.
277
+
278
+ This can be used in CI and development environment to measure coverage
279
+ on Cython modules, but is not normally useful to the end-users or
280
+ downstream packagers.
281
+
282
+ Here's a usage example:
283
+
284
+ .. code-block:: console
285
+
286
+ $ python3 -Im pip install . --config-settings=with-cython-tracing=true
287
+
288
+ For editable installs, this setting is on by default. Otherwise, it's
289
+ off unless requested explicitly.
290
+
291
+ The following produces C-files required for the Cython coverage
292
+ plugin to map the measurements back to the PYX-files:
293
+
294
+ .. code-block:: console
295
+
296
+ $ python -Im pip install -e .
297
+
298
+ Alternatively, the ``FROZENLIST_CYTHON_TRACING=1`` environment variable
299
+ can be set to do the same as the `PEP 517 <https://peps.python.org/pep-517>`__ config setting.
300
+
301
+
302
+ *Related issues and pull requests on GitHub:*
303
+ `#560 <https://github.com/aio-libs/frozenlist/issues/560>`__.
304
+
305
+ - Coverage collection has been implemented for the Cython modules
306
+ -- by `@webknjaz <https://github.com/sponsors/webknjaz>`__.
307
+
308
+ It will also be reported to Codecov from any non-release CI jobs.
309
+
310
+
311
+ *Related issues and pull requests on GitHub:*
312
+ `#561 <https://github.com/aio-libs/frozenlist/issues/561>`__.
313
+
314
+ - A step-by-step ``Release Guide`` guide has
315
+ been added, describing how to release *frozenlist* -- by `@webknjaz <https://github.com/sponsors/webknjaz>`__.
316
+
317
+ This is primarily targeting the maintainers.
318
+
319
+
320
+ *Related issues and pull requests on GitHub:*
321
+ `#563 <https://github.com/aio-libs/frozenlist/issues/563>`__.
322
+
323
+ - Detailed ``Contributing Guidelines`` on
324
+ authoring the changelog fragments have been published in the
325
+ documentation -- by `@webknjaz <https://github.com/sponsors/webknjaz>`__.
326
+
327
+
328
+ *Related issues and pull requests on GitHub:*
329
+ `#564 <https://github.com/aio-libs/frozenlist/issues/564>`__.
330
+
331
+
332
+ ----
333
+
334
+
335
+ 1.4.0 (2023-07-12)
336
+ ==================
337
+
338
+ The published source distribution package became buildable
339
+ under Python 3.12.
340
+
341
+
342
+ ----
343
+
344
+
345
+ Bugfixes
346
+ --------
347
+
348
+ - Removed an unused ``typing.Tuple`` import
349
+ `#411 <https://github.com/aio-libs/frozenlist/issues/411>`_
350
+
351
+
352
+ Deprecations and Removals
353
+ -------------------------
354
+
355
+ - Dropped Python 3.7 support.
356
+ `#413 <https://github.com/aio-libs/frozenlist/issues/413>`_
357
+
358
+
359
+ Misc
360
+ ----
361
+
362
+ - `#410 <https://github.com/aio-libs/frozenlist/issues/410>`_, `#433 <https://github.com/aio-libs/frozenlist/issues/433>`_
363
+
364
+
365
+ ----
366
+
367
+
368
+ 1.3.3 (2022-11-08)
369
+ ==================
370
+
371
+ - Fixed CI runs when creating a new release, where new towncrier versions
372
+ fail when the current version section is already present.
373
+
374
+
375
+ ----
376
+
377
+
378
+ 1.3.2 (2022-11-08)
379
+ ==================
380
+
381
+ Misc
382
+ ----
383
+
384
+ - Updated the CI runs to better check for test results and to avoid deprecated syntax. `#327 <https://github.com/aio-libs/frozenlist/issues/327>`_
385
+
386
+
387
+ ----
388
+
389
+
390
+ 1.3.1 (2022-08-02)
391
+ ==================
392
+
393
+ The published source distribution package became buildable
394
+ under Python 3.11.
395
+
396
+
397
+ ----
398
+
399
+
400
+ 1.3.0 (2022-01-18)
401
+ ==================
402
+
403
+ Bugfixes
404
+ --------
405
+
406
+ - Do not install C sources with binary distributions.
407
+ `#250 <https://github.com/aio-libs/frozenlist/issues/250>`_
408
+
409
+
410
+ Deprecations and Removals
411
+ -------------------------
412
+
413
+ - Dropped Python 3.6 support
414
+ `#274 <https://github.com/aio-libs/frozenlist/issues/274>`_
415
+
416
+
417
+ ----
418
+
419
+
420
+ 1.2.0 (2021-10-16)
421
+ ==================
422
+
423
+ Features
424
+ --------
425
+
426
+ - ``FrozenList`` now supports being used as a generic type as per PEP 585, e.g. ``frozen_int_list: FrozenList[int]`` (requires Python 3.9 or newer).
427
+ `#172 <https://github.com/aio-libs/frozenlist/issues/172>`_
428
+ - Added support for Python 3.10.
429
+ `#227 <https://github.com/aio-libs/frozenlist/issues/227>`_
430
+ - Started shipping platform-specific wheels with the ``musl`` tag targeting typical Alpine Linux runtimes.
431
+ `#227 <https://github.com/aio-libs/frozenlist/issues/227>`_
432
+ - Started shipping platform-specific arm64 wheels for Apple Silicon.
433
+ `#227 <https://github.com/aio-libs/frozenlist/issues/227>`_
434
+
435
+
436
+ ----
437
+
438
+
439
+ 1.1.1 (2020-11-14)
440
+ ==================
441
+
442
+ Bugfixes
443
+ --------
444
+
445
+ - Provide x86 Windows wheels.
446
+ `#169 <https://github.com/aio-libs/frozenlist/issues/169>`_
447
+
448
+
449
+ ----
450
+
451
+
452
+ 1.1.0 (2020-10-13)
453
+ ==================
454
+
455
+ Features
456
+ --------
457
+
458
+ - Add support for hashing of a frozen list.
459
+ `#136 <https://github.com/aio-libs/frozenlist/issues/136>`_
460
+
461
+ - Support Python 3.8 and 3.9.
462
+
463
+ - Provide wheels for ``aarch64``, ``i686``, ``ppc64le``, ``s390x`` architectures on
464
+ Linux as well as ``x86_64``.
465
+
466
+
467
+ ----
468
+
469
+
470
+ 1.0.0 (2019-11-09)
471
+ ==================
472
+
473
+ Deprecations and Removals
474
+ -------------------------
475
+
476
+ - Dropped support for Python 3.5; only 3.6, 3.7 and 3.8 are supported going forward.
477
+ `#24 <https://github.com/aio-libs/frozenlist/issues/24>`_
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/frozenlist-1.5.0.dist-info/RECORD ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ frozenlist-1.5.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
2
+ frozenlist-1.5.0.dist-info/LICENSE,sha256=b9UkPpLdf5jsacesN3co50kFcJ_1J6W_mNbQJjwE9bY,11332
3
+ frozenlist-1.5.0.dist-info/METADATA,sha256=BpQvB7z2NbU3f4XTQDvhAZ9L08WR4XiYajilj9IY6Yk,13762
4
+ frozenlist-1.5.0.dist-info/RECORD,,
5
+ frozenlist-1.5.0.dist-info/WHEEL,sha256=64hRuO2b8JU2aeheZgbK9oQwal3JVqwtqRhpQNr8ZdQ,224
6
+ frozenlist-1.5.0.dist-info/top_level.txt,sha256=jivtxsPXA3nK3WBWW2LW5Mtu_GHt8UZA13NeCs2cKuA,11
7
+ frozenlist/__init__.py,sha256=ymVtnW3MinO-Ux3cBj_PLEpXnmLawk45el8vcX6IkWY,2371
8
+ frozenlist/__init__.pyi,sha256=vMEoES1xGegPtVXoCi9XydEeHsyuIq-KdeXwP5PdsaA,1470
9
+ frozenlist/__pycache__/__init__.cpython-312.pyc,,
10
+ frozenlist/_frozenlist.cpython-312-x86_64-linux-gnu.so,sha256=n65G8t1lqSUcWICd9rjOJujV1lxtniI2JJQQXtc7BjQ,961592
11
+ frozenlist/_frozenlist.pyx,sha256=4YturclNF7wioO7YX3Vzl7Ldb2-iswe6UrjJOMKSswU,2993
12
+ frozenlist/py.typed,sha256=sow9soTwP9T_gEAQSVh7Gb8855h04Nwmhs2We-JRgZM,7
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/frozenlist-1.5.0.dist-info/WHEEL ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ Wheel-Version: 1.0
2
+ Generator: setuptools (75.2.0)
3
+ Root-Is-Purelib: false
4
+ Tag: cp312-cp312-manylinux_2_5_x86_64
5
+ Tag: cp312-cp312-manylinux1_x86_64
6
+ Tag: cp312-cp312-manylinux_2_17_x86_64
7
+ Tag: cp312-cp312-manylinux2014_x86_64
8
+
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/frozenlist-1.5.0.dist-info/top_level.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ frozenlist
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/h11-0.14.0.dist-info/INSTALLER ADDED
@@ -0,0 +1 @@
 
 
1
+ pip
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/h11-0.14.0.dist-info/LICENSE.txt ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ The MIT License (MIT)
2
+
3
+ Copyright (c) 2016 Nathaniel J. Smith <njs@pobox.com> and other contributors
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining
6
+ a copy of this software and associated documentation files (the
7
+ "Software"), to deal in the Software without restriction, including
8
+ without limitation the rights to use, copy, modify, merge, publish,
9
+ distribute, sublicense, and/or sell copies of the Software, and to
10
+ permit persons to whom the Software is furnished to do so, subject to
11
+ the following conditions:
12
+
13
+ The above copyright notice and this permission notice shall be
14
+ included in all copies or substantial portions of the Software.
15
+
16
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
17
+ EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
18
+ MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
19
+ NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
20
+ LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
21
+ OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
22
+ WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/h11-0.14.0.dist-info/METADATA ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Metadata-Version: 2.1
2
+ Name: h11
3
+ Version: 0.14.0
4
+ Summary: A pure-Python, bring-your-own-I/O implementation of HTTP/1.1
5
+ Home-page: https://github.com/python-hyper/h11
6
+ Author: Nathaniel J. Smith
7
+ Author-email: njs@pobox.com
8
+ License: MIT
9
+ Classifier: Development Status :: 3 - Alpha
10
+ Classifier: Intended Audience :: Developers
11
+ Classifier: License :: OSI Approved :: MIT License
12
+ Classifier: Programming Language :: Python :: Implementation :: CPython
13
+ Classifier: Programming Language :: Python :: Implementation :: PyPy
14
+ Classifier: Programming Language :: Python :: 3
15
+ Classifier: Programming Language :: Python :: 3 :: Only
16
+ Classifier: Programming Language :: Python :: 3.7
17
+ Classifier: Programming Language :: Python :: 3.8
18
+ Classifier: Programming Language :: Python :: 3.9
19
+ Classifier: Programming Language :: Python :: 3.10
20
+ Classifier: Topic :: Internet :: WWW/HTTP
21
+ Classifier: Topic :: System :: Networking
22
+ Requires-Python: >=3.7
23
+ License-File: LICENSE.txt
24
+ Requires-Dist: typing-extensions ; python_version < "3.8"
25
+
26
+ h11
27
+ ===
28
+
29
+ .. image:: https://travis-ci.org/python-hyper/h11.svg?branch=master
30
+ :target: https://travis-ci.org/python-hyper/h11
31
+ :alt: Automated test status
32
+
33
+ .. image:: https://codecov.io/gh/python-hyper/h11/branch/master/graph/badge.svg
34
+ :target: https://codecov.io/gh/python-hyper/h11
35
+ :alt: Test coverage
36
+
37
+ .. image:: https://readthedocs.org/projects/h11/badge/?version=latest
38
+ :target: http://h11.readthedocs.io/en/latest/?badge=latest
39
+ :alt: Documentation Status
40
+
41
+ This is a little HTTP/1.1 library written from scratch in Python,
42
+ heavily inspired by `hyper-h2 <https://hyper-h2.readthedocs.io/>`_.
43
+
44
+ It's a "bring-your-own-I/O" library; h11 contains no IO code
45
+ whatsoever. This means you can hook h11 up to your favorite network
46
+ API, and that could be anything you want: synchronous, threaded,
47
+ asynchronous, or your own implementation of `RFC 6214
48
+ <https://tools.ietf.org/html/rfc6214>`_ -- h11 won't judge you.
49
+ (Compare this to the current state of the art, where every time a `new
50
+ network API <https://trio.readthedocs.io/>`_ comes along then someone
51
+ gets to start over reimplementing the entire HTTP protocol from
52
+ scratch.) Cory Benfield made an `excellent blog post describing the
53
+ benefits of this approach
54
+ <https://lukasa.co.uk/2015/10/The_New_Hyper/>`_, or if you like video
55
+ then here's his `PyCon 2016 talk on the same theme
56
+ <https://www.youtube.com/watch?v=7cC3_jGwl_U>`_.
57
+
58
+ This also means that h11 is not immediately useful out of the box:
59
+ it's a toolkit for building programs that speak HTTP, not something
60
+ that could directly replace ``requests`` or ``twisted.web`` or
61
+ whatever. But h11 makes it much easier to implement something like
62
+ ``requests`` or ``twisted.web``.
63
+
64
+ At a high level, working with h11 goes like this:
65
+
66
+ 1) First, create an ``h11.Connection`` object to track the state of a
67
+ single HTTP/1.1 connection.
68
+
69
+ 2) When you read data off the network, pass it to
70
+ ``conn.receive_data(...)``; you'll get back a list of objects
71
+ representing high-level HTTP "events".
72
+
73
+ 3) When you want to send a high-level HTTP event, create the
74
+ corresponding "event" object and pass it to ``conn.send(...)``;
75
+ this will give you back some bytes that you can then push out
76
+ through the network.
77
+
78
+ For example, a client might instantiate and then send a
79
+ ``h11.Request`` object, then zero or more ``h11.Data`` objects for the
80
+ request body (e.g., if this is a POST), and then a
81
+ ``h11.EndOfMessage`` to indicate the end of the message. Then the
82
+ server would then send back a ``h11.Response``, some ``h11.Data``, and
83
+ its own ``h11.EndOfMessage``. If either side violates the protocol,
84
+ you'll get a ``h11.ProtocolError`` exception.
85
+
86
+ h11 is suitable for implementing both servers and clients, and has a
87
+ pleasantly symmetric API: the events you send as a client are exactly
88
+ the ones that you receive as a server and vice-versa.
89
+
90
+ `Here's an example of a tiny HTTP client
91
+ <https://github.com/python-hyper/h11/blob/master/examples/basic-client.py>`_
92
+
93
+ It also has `a fine manual <https://h11.readthedocs.io/>`_.
94
+
95
+ FAQ
96
+ ---
97
+
98
+ *Whyyyyy?*
99
+
100
+ I wanted to play with HTTP in `Curio
101
+ <https://curio.readthedocs.io/en/latest/tutorial.html>`__ and `Trio
102
+ <https://trio.readthedocs.io>`__, which at the time didn't have any
103
+ HTTP libraries. So I thought, no big deal, Python has, like, a dozen
104
+ different implementations of HTTP, surely I can find one that's
105
+ reusable. I didn't find one, but I did find Cory's call-to-arms
106
+ blog-post. So I figured, well, fine, if I have to implement HTTP from
107
+ scratch, at least I can make sure no-one *else* has to ever again.
108
+
109
+ *Should I use it?*
110
+
111
+ Maybe. You should be aware that it's a very young project. But, it's
112
+ feature complete and has an exhaustive test-suite and complete docs,
113
+ so the next step is for people to try using it and see how it goes
114
+ :-). If you do then please let us know -- if nothing else we'll want
115
+ to talk to you before making any incompatible changes!
116
+
117
+ *What are the features/limitations?*
118
+
119
+ Roughly speaking, it's trying to be a robust, complete, and non-hacky
120
+ implementation of the first "chapter" of the HTTP/1.1 spec: `RFC 7230:
121
+ HTTP/1.1 Message Syntax and Routing
122
+ <https://tools.ietf.org/html/rfc7230>`_. That is, it mostly focuses on
123
+ implementing HTTP at the level of taking bytes on and off the wire,
124
+ and the headers related to that, and tries to be anal about spec
125
+ conformance. It doesn't know about higher-level concerns like URL
126
+ routing, conditional GETs, cross-origin cookie policies, or content
127
+ negotiation. But it does know how to take care of framing,
128
+ cross-version differences in keep-alive handling, and the "obsolete
129
+ line folding" rule, so you can focus your energies on the hard /
130
+ interesting parts for your application, and it tries to support the
131
+ full specification in the sense that any useful HTTP/1.1 conformant
132
+ application should be able to use h11.
133
+
134
+ It's pure Python, and has no dependencies outside of the standard
135
+ library.
136
+
137
+ It has a test suite with 100.0% coverage for both statements and
138
+ branches.
139
+
140
+ Currently it supports Python 3 (testing on 3.7-3.10) and PyPy 3.
141
+ The last Python 2-compatible version was h11 0.11.x.
142
+ (Originally it had a Cython wrapper for `http-parser
143
+ <https://github.com/nodejs/http-parser>`_ and a beautiful nested state
144
+ machine implemented with ``yield from`` to postprocess the output. But
145
+ I had to take these out -- the new *parser* needs fewer lines-of-code
146
+ than the old *parser wrapper*, is written in pure Python, uses no
147
+ exotic language syntax, and has more features. It's sad, really; that
148
+ old state machine was really slick. I just need a few sentences here
149
+ to mourn that.)
150
+
151
+ I don't know how fast it is. I haven't benchmarked or profiled it yet,
152
+ so it's probably got a few pointless hot spots, and I've been trying
153
+ to err on the side of simplicity and robustness instead of
154
+ micro-optimization. But at the architectural level I tried hard to
155
+ avoid fundamentally bad decisions, e.g., I believe that all the
156
+ parsing algorithms remain linear-time even in the face of pathological
157
+ input like slowloris, and there are no byte-by-byte loops. (I also
158
+ believe that it maintains bounded memory usage in the face of
159
+ arbitrary/pathological input.)
160
+
161
+ The whole library is ~800 lines-of-code. You can read and understand
162
+ the whole thing in less than an hour. Most of the energy invested in
163
+ this so far has been spent on trying to keep things simple by
164
+ minimizing special-cases and ad hoc state manipulation; even though it
165
+ is now quite small and simple, I'm still annoyed that I haven't
166
+ figured out how to make it even smaller and simpler. (Unfortunately,
167
+ HTTP does not lend itself to simplicity.)
168
+
169
+ The API is ~feature complete and I don't expect the general outlines
170
+ to change much, but you can't judge an API's ergonomics until you
171
+ actually document and use it, so I'd expect some changes in the
172
+ details.
173
+
174
+ *How do I try it?*
175
+
176
+ .. code-block:: sh
177
+
178
+ $ pip install h11
179
+ $ git clone git@github.com:python-hyper/h11
180
+ $ cd h11/examples
181
+ $ python basic-client.py
182
+
183
+ and go from there.
184
+
185
+ *License?*
186
+
187
+ MIT
188
+
189
+ *Code of conduct?*
190
+
191
+ Contributors are requested to follow our `code of conduct
192
+ <https://github.com/python-hyper/h11/blob/master/CODE_OF_CONDUCT.md>`_ in
193
+ all project spaces.
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/h11-0.14.0.dist-info/RECORD ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ h11-0.14.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
2
+ h11-0.14.0.dist-info/LICENSE.txt,sha256=N9tbuFkm2yikJ6JYZ_ELEjIAOuob5pzLhRE4rbjm82E,1124
3
+ h11-0.14.0.dist-info/METADATA,sha256=B7pZ0m7WBXNs17vl6hUH9bJTL9s37DaGvY31w7jNxSg,8175
4
+ h11-0.14.0.dist-info/RECORD,,
5
+ h11-0.14.0.dist-info/WHEEL,sha256=ewwEueio1C2XeHTvT17n8dZUJgOvyCWCt0WVNLClP9o,92
6
+ h11-0.14.0.dist-info/top_level.txt,sha256=F7dC4jl3zeh8TGHEPaWJrMbeuoWbS379Gwdi-Yvdcis,4
7
+ h11/__init__.py,sha256=iO1KzkSO42yZ6ffg-VMgbx_ZVTWGUY00nRYEWn-s3kY,1507
8
+ h11/__pycache__/__init__.cpython-312.pyc,,
9
+ h11/__pycache__/_abnf.cpython-312.pyc,,
10
+ h11/__pycache__/_connection.cpython-312.pyc,,
11
+ h11/__pycache__/_events.cpython-312.pyc,,
12
+ h11/__pycache__/_headers.cpython-312.pyc,,
13
+ h11/__pycache__/_readers.cpython-312.pyc,,
14
+ h11/__pycache__/_receivebuffer.cpython-312.pyc,,
15
+ h11/__pycache__/_state.cpython-312.pyc,,
16
+ h11/__pycache__/_util.cpython-312.pyc,,
17
+ h11/__pycache__/_version.cpython-312.pyc,,
18
+ h11/__pycache__/_writers.cpython-312.pyc,,
19
+ h11/_abnf.py,sha256=ybixr0xsupnkA6GFAyMubuXF6Tc1lb_hF890NgCsfNc,4815
20
+ h11/_connection.py,sha256=eS2sorMD0zKLCFiB9lW9W9F_Nzny2tjHa4e6s1ujr1c,26539
21
+ h11/_events.py,sha256=LEfuvg1AbhHaVRwxCd0I-pFn9-ezUOaoL8o2Kvy1PBA,11816
22
+ h11/_headers.py,sha256=RqB8cd8CN0blYPzcLe5qeCh-phv6D1U_CHj4hs67lgQ,10230
23
+ h11/_readers.py,sha256=EbSed0jzwVUiD1nOPAeUcVE4Flf3wXkxfb8c06-OTBM,8383
24
+ h11/_receivebuffer.py,sha256=xrspsdsNgWFxRfQcTXxR8RrdjRXXTK0Io5cQYWpJ1Ws,5252
25
+ h11/_state.py,sha256=k1VL6SDbaPkSrZ-49ewCXDpuiUS69_46YhbWjuV1qEY,13300
26
+ h11/_util.py,sha256=LWkkjXyJaFlAy6Lt39w73UStklFT5ovcvo0TkY7RYuk,4888
27
+ h11/_version.py,sha256=LVyTdiZRzIIEv79UyOgbM5iUrJUllEzlCWaJEYBY1zc,686
28
+ h11/_writers.py,sha256=oFKm6PtjeHfbj4RLX7VB7KDc1gIY53gXG3_HR9ltmTA,5081
29
+ h11/py.typed,sha256=sow9soTwP9T_gEAQSVh7Gb8855h04Nwmhs2We-JRgZM,7
30
+ h11/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
31
+ h11/tests/__pycache__/__init__.cpython-312.pyc,,
32
+ h11/tests/__pycache__/helpers.cpython-312.pyc,,
33
+ h11/tests/__pycache__/test_against_stdlib_http.cpython-312.pyc,,
34
+ h11/tests/__pycache__/test_connection.cpython-312.pyc,,
35
+ h11/tests/__pycache__/test_events.cpython-312.pyc,,
36
+ h11/tests/__pycache__/test_headers.cpython-312.pyc,,
37
+ h11/tests/__pycache__/test_helpers.cpython-312.pyc,,
38
+ h11/tests/__pycache__/test_io.cpython-312.pyc,,
39
+ h11/tests/__pycache__/test_receivebuffer.cpython-312.pyc,,
40
+ h11/tests/__pycache__/test_state.cpython-312.pyc,,
41
+ h11/tests/__pycache__/test_util.cpython-312.pyc,,
42
+ h11/tests/data/test-file,sha256=ZJ03Rqs98oJw29OHzJg7LlMzyGQaRAY0r3AqBeM2wVU,65
43
+ h11/tests/helpers.py,sha256=a1EVG_p7xU4wRsa3tMPTRxuaKCmretok9sxXWvqfmQA,3355
44
+ h11/tests/test_against_stdlib_http.py,sha256=cojCHgHXFQ8gWhNlEEwl3trmOpN-5uDukRoHnElqo3A,3995
45
+ h11/tests/test_connection.py,sha256=ZbPLDPclKvjgjAhgk-WlCPBaf17c4XUIV2tpaW08jOI,38720
46
+ h11/tests/test_events.py,sha256=LPVLbcV-NvPNK9fW3rraR6Bdpz1hAlsWubMtNaJ5gHg,4657
47
+ h11/tests/test_headers.py,sha256=qd8T1Zenuz5GbD6wklSJ5G8VS7trrYgMV0jT-SMvqg8,5612
48
+ h11/tests/test_helpers.py,sha256=kAo0CEM4LGqmyyP2ZFmhsyq3UFJqoFfAbzu3hbWreRM,794
49
+ h11/tests/test_io.py,sha256=uCZVnjarkRBkudfC1ij-KSCQ71XWJhnkgkgWWkKgYPQ,16386
50
+ h11/tests/test_receivebuffer.py,sha256=3jGbeJM36Akqg_pAhPb7XzIn2NS6RhPg-Ryg8Eu6ytk,3454
51
+ h11/tests/test_state.py,sha256=rqll9WqFsJPE0zSrtCn9LH659mPKsDeXZ-DwXwleuBQ,8928
52
+ h11/tests/test_util.py,sha256=VO5L4nSFe4pgtSwKuv6u_6l0H7UeizF5WKuHTWreg70,2970
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/h11-0.14.0.dist-info/WHEEL ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ Wheel-Version: 1.0
2
+ Generator: bdist_wheel (0.37.0)
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
5
+
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/h11-0.14.0.dist-info/top_level.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ h11
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/multidict-6.1.0.dist-info/INSTALLER ADDED
@@ -0,0 +1 @@
 
 
1
+ pip
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/multidict-6.1.0.dist-info/LICENSE ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Copyright 2016 Andrew Svetlov and aio-libs contributors
2
+
3
+ Licensed under the Apache License, Version 2.0 (the "License");
4
+ you may not use this file except in compliance with the License.
5
+ You may obtain a copy of the License at
6
+
7
+ http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ Unless required by applicable law or agreed to in writing, software
10
+ distributed under the License is distributed on an "AS IS" BASIS,
11
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ See the License for the specific language governing permissions and
13
+ limitations under the License.
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/multidict-6.1.0.dist-info/METADATA ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Metadata-Version: 2.1
2
+ Name: multidict
3
+ Version: 6.1.0
4
+ Summary: multidict implementation
5
+ Home-page: https://github.com/aio-libs/multidict
6
+ Author: Andrew Svetlov
7
+ Author-email: andrew.svetlov@gmail.com
8
+ License: Apache 2
9
+ Project-URL: Chat: Matrix, https://matrix.to/#/#aio-libs:matrix.org
10
+ Project-URL: Chat: Matrix Space, https://matrix.to/#/#aio-libs-space:matrix.org
11
+ Project-URL: CI: GitHub, https://github.com/aio-libs/multidict/actions
12
+ Project-URL: Code of Conduct, https://github.com/aio-libs/.github/blob/master/CODE_OF_CONDUCT.md
13
+ Project-URL: Coverage: codecov, https://codecov.io/github/aio-libs/multidict
14
+ Project-URL: Docs: Changelog, https://multidict.aio-libs.org/en/latest/changes/
15
+ Project-URL: Docs: RTD, https://multidict.aio-libs.org
16
+ Project-URL: GitHub: issues, https://github.com/aio-libs/multidict/issues
17
+ Project-URL: GitHub: repo, https://github.com/aio-libs/multidict
18
+ Classifier: Development Status :: 5 - Production/Stable
19
+ Classifier: Intended Audience :: Developers
20
+ Classifier: License :: OSI Approved :: Apache Software License
21
+ Classifier: Programming Language :: Python
22
+ Classifier: Programming Language :: Python :: 3
23
+ Classifier: Programming Language :: Python :: 3.8
24
+ Classifier: Programming Language :: Python :: 3.9
25
+ Classifier: Programming Language :: Python :: 3.10
26
+ Classifier: Programming Language :: Python :: 3.11
27
+ Classifier: Programming Language :: Python :: 3.12
28
+ Classifier: Programming Language :: Python :: 3.13
29
+ Requires-Python: >=3.8
30
+ Description-Content-Type: text/x-rst
31
+ License-File: LICENSE
32
+ Requires-Dist: typing-extensions >=4.1.0 ; python_version < "3.11"
33
+
34
+ =========
35
+ multidict
36
+ =========
37
+
38
+ .. image:: https://github.com/aio-libs/multidict/actions/workflows/ci-cd.yml/badge.svg
39
+ :target: https://github.com/aio-libs/multidict/actions
40
+ :alt: GitHub status for master branch
41
+
42
+ .. image:: https://codecov.io/gh/aio-libs/multidict/branch/master/graph/badge.svg
43
+ :target: https://codecov.io/gh/aio-libs/multidict
44
+ :alt: Coverage metrics
45
+
46
+ .. image:: https://img.shields.io/pypi/v/multidict.svg
47
+ :target: https://pypi.org/project/multidict
48
+ :alt: PyPI
49
+
50
+ .. image:: https://readthedocs.org/projects/multidict/badge/?version=latest
51
+ :target: https://multidict.aio-libs.org
52
+ :alt: Read The Docs build status badge
53
+
54
+ .. image:: https://img.shields.io/pypi/pyversions/multidict.svg
55
+ :target: https://pypi.org/project/multidict
56
+ :alt: Python versions
57
+
58
+ .. image:: https://img.shields.io/matrix/aio-libs:matrix.org?label=Discuss%20on%20Matrix%20at%20%23aio-libs%3Amatrix.org&logo=matrix&server_fqdn=matrix.org&style=flat
59
+ :target: https://matrix.to/#/%23aio-libs:matrix.org
60
+ :alt: Matrix Room — #aio-libs:matrix.org
61
+
62
+ .. image:: https://img.shields.io/matrix/aio-libs-space:matrix.org?label=Discuss%20on%20Matrix%20at%20%23aio-libs-space%3Amatrix.org&logo=matrix&server_fqdn=matrix.org&style=flat
63
+ :target: https://matrix.to/#/%23aio-libs-space:matrix.org
64
+ :alt: Matrix Space — #aio-libs-space:matrix.org
65
+
66
+ Multidict is dict-like collection of *key-value pairs* where key
67
+ might occur more than once in the container.
68
+
69
+ Introduction
70
+ ------------
71
+
72
+ *HTTP Headers* and *URL query string* require specific data structure:
73
+ *multidict*. It behaves mostly like a regular ``dict`` but it may have
74
+ several *values* for the same *key* and *preserves insertion ordering*.
75
+
76
+ The *key* is ``str`` (or ``istr`` for case-insensitive dictionaries).
77
+
78
+ ``multidict`` has four multidict classes:
79
+ ``MultiDict``, ``MultiDictProxy``, ``CIMultiDict``
80
+ and ``CIMultiDictProxy``.
81
+
82
+ Immutable proxies (``MultiDictProxy`` and
83
+ ``CIMultiDictProxy``) provide a dynamic view for the
84
+ proxied multidict, the view reflects underlying collection changes. They
85
+ implement the ``collections.abc.Mapping`` interface.
86
+
87
+ Regular mutable (``MultiDict`` and ``CIMultiDict``) classes
88
+ implement ``collections.abc.MutableMapping`` and allows them to change
89
+ their own content.
90
+
91
+
92
+ *Case insensitive* (``CIMultiDict`` and
93
+ ``CIMultiDictProxy``) assume the *keys* are case
94
+ insensitive, e.g.::
95
+
96
+ >>> dct = CIMultiDict(key='val')
97
+ >>> 'Key' in dct
98
+ True
99
+ >>> dct['Key']
100
+ 'val'
101
+
102
+ *Keys* should be ``str`` or ``istr`` instances.
103
+
104
+ The library has optional C Extensions for speed.
105
+
106
+
107
+ License
108
+ -------
109
+
110
+ Apache 2
111
+
112
+ Library Installation
113
+ --------------------
114
+
115
+ .. code-block:: bash
116
+
117
+ $ pip install multidict
118
+
119
+ The library is Python 3 only!
120
+
121
+ PyPI contains binary wheels for Linux, Windows and MacOS. If you want to install
122
+ ``multidict`` on another operating system (or *Alpine Linux* inside a Docker) the
123
+ tarball will be used to compile the library from source. It requires a C compiler and
124
+ Python headers to be installed.
125
+
126
+ To skip the compilation, please use the `MULTIDICT_NO_EXTENSIONS` environment variable,
127
+ e.g.:
128
+
129
+ .. code-block:: bash
130
+
131
+ $ MULTIDICT_NO_EXTENSIONS=1 pip install multidict
132
+
133
+ Please note, the pure Python (uncompiled) version is about 20-50 times slower depending on
134
+ the usage scenario!!!
135
+
136
+
137
+
138
+ Changelog
139
+ ---------
140
+ See `RTD page <http://multidict.aio-libs.org/en/latest/changes>`_.
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/multidict-6.1.0.dist-info/RECORD ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ multidict-6.1.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
2
+ multidict-6.1.0.dist-info/LICENSE,sha256=k9Ealo4vDzY3PECBH_bSDhc_WMPKtYhM1mF7v9eVSSo,611
3
+ multidict-6.1.0.dist-info/METADATA,sha256=OnCx5DR4XPf64GIDK4XmcA2e7HLQ_784vMfEQy287kM,4979
4
+ multidict-6.1.0.dist-info/RECORD,,
5
+ multidict-6.1.0.dist-info/WHEEL,sha256=3FRagTIevYnyede1Gym_XNKguJrd07UOyEdLNhxNq20,151
6
+ multidict-6.1.0.dist-info/top_level.txt,sha256=-euDElkk5_qkmfIJ7WiqCab02ZlSFZWynejKg59qZQQ,10
7
+ multidict/__init__.py,sha256=p60Ag5UVACSli1txazSi85foCmHN-cg3qZDCuWdOKng,928
8
+ multidict/__init__.pyi,sha256=SbgC2ew1NvNXWlRKs9o0KhW4moozgMqgQ0OA4Re5JQQ,4840
9
+ multidict/__pycache__/__init__.cpython-312.pyc,,
10
+ multidict/__pycache__/_abc.cpython-312.pyc,,
11
+ multidict/__pycache__/_compat.cpython-312.pyc,,
12
+ multidict/__pycache__/_multidict_base.cpython-312.pyc,,
13
+ multidict/__pycache__/_multidict_py.cpython-312.pyc,,
14
+ multidict/_abc.py,sha256=Zvnrn4SBkrv4QTD7-ZzqNcoxw0f8KStLMPzGvBuGT2w,1190
15
+ multidict/_compat.py,sha256=uCNUpVHJSFOiKUJmRcz3SDqMpkb37C_csc29ijr8Evo,352
16
+ multidict/_multidict.cpython-312-x86_64-linux-gnu.so,sha256=6BwP62oLns2chEgPfwAa8DseIoF0wOWBe81pHjnlqhs,418968
17
+ multidict/_multidict_base.py,sha256=ZndtnZ5oc1sODKmXsv6F9kWvVNCda9xAEEFXkaPoFoA,3979
18
+ multidict/_multidict_py.py,sha256=57h4sYrRIu7EjMX4YpHVIZVrV9-q1KCW3F6rao10D3U,15050
19
+ multidict/py.typed,sha256=e9bmbH3UFxsabQrnNFPG9qxIXztwbcM6IKDYnvZwprY,15
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/multidict-6.1.0.dist-info/WHEEL ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ Wheel-Version: 1.0
2
+ Generator: setuptools (74.1.2)
3
+ Root-Is-Purelib: false
4
+ Tag: cp312-cp312-manylinux_2_17_x86_64
5
+ Tag: cp312-cp312-manylinux2014_x86_64
6
+