james-joobs committed on
Commit
9974a90
·
1 Parent(s): cdb9a92

add trainer with ner example

Browse files
Files changed (1) hide show
  1. trainer.py +136 -0
trainer.py CHANGED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datasets import load_dataset, load_metric
2
+
3
+ import numpy as np
4
+
5
+ from transformers import AutoTokenizer
6
+ from transformers import AutoModelForTokenClassification, TrainingArguments, Trainer
7
+ from transformers import DataCollatorForTokenClassification
8
+
9
# Ordered BIO tag set for the NER task; a tag's position in this list is its id.
label_list = ['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC']

# tag -> id and id -> tag lookup tables, derived from label_list so the
# three structures can never drift out of sync.
labels_vocab = {tag: idx for idx, tag in enumerate(label_list)}
labels_vocab_reverse = dict(enumerate(label_list))
13
+
14
+ metric = load_metric("seqeval")
15
+
16
def load_datasets(tokenizer):
    """Download the Babelscape/wikineural corpus and tokenize its English splits.

    Returns a (train, val, test) tuple of datasets whose "labels" column is
    aligned with the sub-word tokenization produced by *tokenizer*.
    """
    # When False, only the first sub-word of each word keeps the word's tag;
    # every other sub-word gets -100 so the loss function ignores it.
    label_all_tokens = False

    def tokenize_and_align_labels(examples):
        encoded = tokenizer(examples["tokens"], truncation=True, is_split_into_words=True)

        aligned = []
        for batch_idx, word_labels in enumerate(examples["ner_tags"]):
            token_labels = []
            last_word = None
            for word in encoded.word_ids(batch_index=batch_idx):
                if word is None:
                    # Special tokens have no word id; -100 is skipped by the loss.
                    token_labels.append(-100)
                elif word != last_word:
                    # First sub-word of a word: keep the word-level tag.
                    token_labels.append(word_labels[word])
                else:
                    # Continuation sub-word: tag or -100, per the flag above.
                    token_labels.append(word_labels[word] if label_all_tokens else -100)
                last_word = word
            aligned.append(token_labels)

        encoded["labels"] = aligned
        return encoded

    corpus = load_dataset("Babelscape/wikineural")

    train_tokenized, val_tokenized, test_tokenized = (
        corpus[split].map(tokenize_and_align_labels, batched=True)
        for split in ('train_en', 'val_en', 'test_en')
    )
    return train_tokenized, val_tokenized, test_tokenized
55
+
56
def compute_metrics(p):
    """Turn Trainer eval output into seqeval NER scores.

    Parameters
    ----------
    p : tuple
        (logits of shape (batch, seq_len, num_labels), label ids of shape
        (batch, seq_len) with -100 marking positions to ignore).

    Returns
    -------
    dict with overall "precision", "recall", "f1" and "accuracy".
    """
    predictions, labels = p
    predictions = np.argmax(predictions, axis=2)

    # Strip ignored positions (label == -100: special tokens and non-first
    # sub-words) and map ids back to tag strings. One pass per sequence
    # instead of the original's two duplicated filtering comprehensions.
    true_predictions = []
    true_labels = []
    for prediction, label in zip(predictions, labels):
        kept = [(p_i, l_i) for p_i, l_i in zip(prediction, label) if l_i != -100]
        true_predictions.append([label_list[p_i] for p_i, _ in kept])
        true_labels.append([label_list[l_i] for _, l_i in kept])

    results = metric.compute(predictions=true_predictions, references=true_labels)
    return {
        "precision": results["overall_precision"],
        "recall": results["overall_recall"],
        "f1": results["overall_f1"],
        "accuracy": results["overall_accuracy"],
    }
77
+
78
def main():
    """Fine-tune bert-base-cased for English NER on wikineural and report test scores."""
    MODEL_NAME = "bert-base-cased"

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    # BUG FIX: the original called load_dataset(tokenizer) — the Hugging Face
    # hub loader — instead of the local load_datasets() helper, which would
    # fail (a tokenizer is not a dataset path).
    train_tokenized, val_tokenized, test_tokenized = load_datasets(tokenizer)

    model = AutoModelForTokenClassification.from_pretrained(
        MODEL_NAME,
        num_labels=len(label_list),
        label2id=labels_vocab,
        id2label=labels_vocab_reverse,
    )
    # Pads input ids and label ids (with -100) to the longest sequence per batch.
    data_collator = DataCollatorForTokenClassification(tokenizer)

    args = TrainingArguments(
        "wikineural-multilingual-ner",
        evaluation_strategy="steps",
        learning_rate=2e-5,
        per_device_train_batch_size=32,
        per_device_eval_batch_size=32,
        num_train_epochs=1,
        do_train=True,
        do_eval=True,
        weight_decay=0.01,
        eval_steps=10000,
        save_steps=10000,
    )

    trainer = Trainer(
        model,
        args,
        train_dataset=train_tokenized,
        # BUG FIX: evaluate on the validation split during training; the
        # original used the test split here (leaving val unused), leaking the
        # test set into training-time evaluation.
        eval_dataset=val_tokenized,
        data_collator=data_collator,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
    )

    trainer.train()
    trainer.evaluate()

    # Final held-out evaluation on the test split.
    predictions, labels, _ = trainer.predict(test_tokenized)
    predictions = np.argmax(predictions, axis=2)

    # Remove ignored index (special tokens / non-first sub-words) before scoring.
    true_predictions = [
        [label_list[p] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]
    true_labels = [
        [label_list[l] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]

    results = metric.compute(predictions=true_predictions, references=true_labels)
    # BUG FIX: the original ended with a bare `results` expression — a no-op
    # outside a notebook; report the scores instead.
    print(results)

    return 0


if __name__ == "__main__":
    main()