Thomas Lemberger commited on
Commit
ad1b251
·
1 Parent(s): f737059

uodate loading script

Browse files
Files changed (1) hide show
  1. bio-lm.py +176 -0
bio-lm.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+
17
+ # template from : https://github.com/huggingface/datasets/blob/master/templates/new_dataset_script.py
18
+
19
+ """Loading script for the biolang dataset for language modeling in biology."""
20
+
21
+ from __future__ import absolute_import, division, print_function
22
+
23
+ import json
24
+ import datasets
25
+
26
+
27
+ class BioLang(datasets.GeneratorBasedBuilder):
28
+ """BioLang: a dataset to train language models in biology."""
29
+
30
+ _CITATION = """\
31
+ @Unpublished{
32
+ huggingface: dataset,
33
+ title = {biolang},
34
+ authors={Thomas Lemberger, EMBO},
35
+ year={2021}
36
+ }
37
+ """
38
+
39
+ _DESCRIPTION = """\
40
+ This dataset is based on abstracts from the open access section of EuropePubMed Central to train language models in the domain of biology.
41
+ """
42
+
43
+ _HOMEPAGE = "https://europepmc.org/downloads/openaccess"
44
+
45
+ _LICENSE = "CC BY 4.0"
46
+
47
+ _URLS = {
48
+ "biolang": "https://huggingface.co/datasets/EMBO/biolang/resolve/main/oapmc_abstracts_figs.zip",
49
+ }
50
+
51
+ VERSION = datasets.Version("0.0.1")
52
+
53
+ BUILDER_CONFIGS = [
54
+ datasets.BuilderConfig(name="SEQ2SEQ", version="0.0.1", description="Control dataset with no masking for seq2seq task."),
55
+ datasets.BuilderConfig(name="MLM", version="0.0.1", description="Dataset for general masked language model."),
56
+ datasets.BuilderConfig(name="DET", version="0.0.1", description="Dataset for part-of-speech (determinant) masked language model."),
57
+ datasets.BuilderConfig(name="VERB", version="0.0.1", description="Dataset for part-of-speech (verbs) masked language model."),
58
+ datasets.BuilderConfig(name="SMALL", version="0.0.1", description="Dataset for part-of-speech (determinants, conjunctions, prepositions, pronouns) masked language model."),
59
+ datasets.BuilderConfig(name="NOUN", version="0.0.1", description="Dataset for part-of-speech (nouns) masked language model."),
60
+ ]
61
+
62
+ DEFAULT_CONFIG_NAME = "MLM" # It's not mandatory to have a default configuration. Just use one if it make sense.
63
+
64
+ def _info(self):
65
+ if self.config.name == "MLM":
66
+ features = datasets.Features({
67
+ "input_ids": datasets.Sequence(feature=datasets.Value("int32")),
68
+ "special_tokens_mask": datasets.Sequence(feature=datasets.Value("int8")),
69
+ })
70
+ elif self.config.name in ["DET", "VERB", "SMALL", "NOUN", "NULL"]:
71
+ features = datasets.Features({
72
+ "input_ids": datasets.Sequence(feature=datasets.Value("int32")),
73
+ "tag_mask": datasets.Sequence(feature=datasets.Value("int8")),
74
+ })
75
+ elif self.config.name == "SEQ2SEQ":
76
+ features = datasets.Features({
77
+ "input_ids": datasets.Sequence(feature=datasets.Value("int32")),
78
+ "labels": datasets.Sequence(feature=datasets.Value("int32"))
79
+ })
80
+
81
+ return datasets.DatasetInfo(
82
+ description=self._DESCRIPTION,
83
+ features=features, # Here we define them above because they are different between the two configurations
84
+ supervised_keys=('input_ids', 'pos_mask'),
85
+ homepage=self._HOMEPAGE,
86
+ license=self._LICENSE,
87
+ citation=self._CITATION,
88
+ )
89
+
90
+ def _split_generators(self, dl_manager):
91
+ """Returns SplitGenerators."""
92
+ if self.config.data_dir:
93
+ data_dir = self.config.data_dir
94
+ else:
95
+ url = self._URLS["biolang"]
96
+ data_dir = dl_manager.download_and_extract(url)
97
+ data_dir += "/oapmc_abstracts_figs"
98
+ return [
99
+ datasets.SplitGenerator(
100
+ name=datasets.Split.TRAIN,
101
+ gen_kwargs={
102
+ "filepath": data_dir + "/train.jsonl",
103
+ "split": "train",
104
+ },
105
+ ),
106
+ datasets.SplitGenerator(
107
+ name=datasets.Split.TEST,
108
+ gen_kwargs={
109
+ "filepath": data_dir + "/test.jsonl",
110
+ "split": "test"
111
+ },
112
+ ),
113
+ datasets.SplitGenerator(
114
+ name=datasets.Split.VALIDATION,
115
+ gen_kwargs={
116
+ "filepath": data_dir + "/eval.jsonl",
117
+ "split": "eval",
118
+ },
119
+ ),
120
+ ]
121
+
122
+ def _generate_examples(self, filepath, split):
123
+ """ Yields examples. """
124
+ with open(filepath, encoding="utf-8") as f:
125
+ for id_, row in enumerate(f):
126
+ data = json.loads(row)
127
+ if self.config.name == "MLM":
128
+ yield id_, {
129
+ "input_ids": data["input_ids"],
130
+ "special_tokens_mask": data['special_tokens_mask']
131
+ }
132
+ # else Part of Speech tags based on
133
+ # Universal POS tags https://universaldependencies.org/u/pos/
134
+ elif self.config.name == "DET":
135
+ pos_mask = [0] * len(data['input_ids'])
136
+ for idx, label in enumerate(data['label_ids']):
137
+ if label == 'DET':
138
+ pos_mask[idx] = 1
139
+ yield id_, {
140
+ "input_ids": data['input_ids'],
141
+ "tag_mask": pos_mask,
142
+ }
143
+ elif self.config.name == "VERB":
144
+ pos_mask = [0] * len(data['input_ids'])
145
+ for idx, label in enumerate(data['label_ids']):
146
+ if label == 'VERB':
147
+ pos_mask[idx] = 1
148
+ yield id_, {
149
+ "input_ids": data['input_ids'],
150
+ "tag_mask": pos_mask,
151
+ }
152
+ elif self.config.name == "SMALL":
153
+ pos_mask = [0] * len(data['input_ids'])
154
+ for idx, label in enumerate(data['label_ids']):
155
+ if label in ['DET', 'CCONJ', 'SCONJ', 'ADP', 'PRON']:
156
+ pos_mask[idx] = 1
157
+ yield id_, {
158
+ "input_ids": data['input_ids'],
159
+ "tag_mask": pos_mask,
160
+ }
161
+ elif self.config.name == "NOUN":
162
+ pos_mask = [0] * len(data['input_ids'])
163
+ for idx, label in enumerate(data['label_ids']):
164
+ if label in ['NOUN']:
165
+ pos_mask[idx] = 1
166
+ yield id_, {
167
+ "input_ids": data['input_ids'],
168
+ "tag_mask": pos_mask,
169
+ }
170
+ elif self.config.name == "SEQ2SEQ":
171
+ "Seq2seq training needs the input_ids as labels, no masking"
172
+ pos_mask = [0] * len(data['input_ids'])
173
+ yield id_, {
174
+ "input_ids": data['input_ids'],
175
+ "labels": data['input_ids']
176
+ }