sdadas committed · Commit 8ef3829 (verified) · 1 Parent(s): 18d81cd

Upload modeling_eurobert.py

Files changed (1): modeling_eurobert.py (+136, -0)
modeling_eurobert.py ADDED
from typing import Unpack

import torch
from transformers import DataCollatorWithFlattening
from transformers.masking_utils import create_bidirectional_mask
from transformers.modeling_outputs import BaseModelOutput
from transformers.models.eurobert import (
    EuroBertForMaskedLM,
    EuroBertModel,
    EuroBertForSequenceClassification,
    EuroBertForTokenClassification,
)
from transformers.utils import TransformersKwargs

def _unpad_input(input_ids: torch.Tensor, attention_mask: torch.Tensor):
    # Drop padding tokens and flatten the batch into a single packed sequence,
    # collecting the FlashAttention varlen kwargs (cu_seq_lens_*, max_length_*)
    # that describe where each original sequence starts and ends.
    collator = DataCollatorWithFlattening(return_flash_attn_kwargs=True)
    features = collator([{"input_ids": i[a.bool()].tolist()} for i, a in zip(input_ids, attention_mask)])
    return features
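
# Example with hypothetical values: for input_ids [[5, 6, PAD], [7, 8, 9]] and
# the matching attention mask, the collator yields one packed example:
#   input_ids       -> tensor([[5, 6, 7, 8, 9]])
#   position_ids    -> tensor([[0, 1, 0, 1, 2]])   (restart per sequence)
#   cu_seq_lens_q/k -> tensor([0, 2, 5]); max_length_q/k -> 3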


def _pad_output(inputs: torch.Tensor, indices: torch.Tensor, batch: int, seqlen: int) -> torch.Tensor:
    # Scatter the packed (unpadded) outputs back to their original positions
    # in a zero-filled (batch, seqlen, ...) tensor.
    if inputs.dim() == 3:
        # Packed hidden states come back as (1, total_tokens, hidden);
        # squeeze(0) rather than squeeze() so a single-token batch survives.
        inputs = inputs.squeeze(0)
    if inputs.dim() == 1:
        output = torch.zeros(batch * seqlen, dtype=inputs.dtype, device=inputs.device)
        output[indices] = inputs
        padded_inputs = output.view(batch, seqlen)
    else:
        _, *rest = inputs.shape
        output = torch.zeros(batch * seqlen, *rest, dtype=inputs.dtype, device=inputs.device)
        output[indices] = inputs
        padded_inputs = output.view(batch, seqlen, *rest)
    return padded_inputs
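
# Example with hypothetical shapes: for five real tokens in a (2, 4) batch and
# indices == tensor([0, 1, 2, 4, 5]), the five unpadded rows are scattered back
# into a zero-filled (2, 4, hidden_size) tensor.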


class UnpadEuroBertModel(EuroBertModel):

    def __init__(self, config):
        super().__init__(config)

    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | BaseModelOutput:
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if input_ids is not None:
            device = input_ids.device
            batch_size, seq_length = input_ids.shape
        else:
            device = inputs_embeds.device
            batch_size, seq_length = inputs_embeds.shape[:2]

        indices = None
        if self.config._attn_implementation.startswith("flash_attention"):
            if input_ids is None or attention_mask is None:
                raise ValueError("Unpadding requires both input_ids and attention_mask")
            with torch.no_grad():
                # Flat indices of the real tokens, used later to re-pad the outputs.
                indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
                features = _unpad_input(input_ids, attention_mask)
            input_ids = features["input_ids"].to(device=device)
            position_ids = features["position_ids"].to(device=device)
            # The packed sequence carries no padding; the attention mask is
            # replaced by the FlashAttention varlen kwargs below.
            attention_mask = None
            kwargs["cu_seq_lens_k"] = features["cu_seq_lens_k"].to(device=device)
            kwargs["cu_seq_lens_q"] = features["cu_seq_lens_q"].to(device=device)
            kwargs["max_length_k"] = features["max_length_k"]
            kwargs["max_length_q"] = features["max_length_q"]

        if inputs_embeds is None:
            inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)

        if position_ids is None:
            position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device).unsqueeze(0)

        bidirectional_mask = create_bidirectional_mask(
            config=self.config,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
        )

        hidden_states = inputs_embeds
        position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)

        for encoder_layer in self.layers[: self.config.num_hidden_layers]:
            hidden_states = encoder_layer(
                hidden_states,
                attention_mask=bidirectional_mask,
                position_embeddings=position_embeddings,
                position_ids=position_ids,
                **kwargs,
            )

        hidden_states = self.norm(hidden_states)
        if self.config._attn_implementation.startswith("flash_attention"):
            # Restore the padded (batch, seq_len, hidden) layout.
            hidden_states = _pad_output(
                inputs=hidden_states, indices=indices, batch=batch_size, seqlen=seq_length
            )

        return BaseModelOutput(
            last_hidden_state=hidden_states,
        )


class UnpadEuroBertForMaskedLM(EuroBertForMaskedLM):

    def __init__(self, config):
        super().__init__(config)
        self.model = UnpadEuroBertModel(config)
        self.post_init()


class UnpadEuroBertForSequenceClassification(EuroBertForSequenceClassification):

    def __init__(self, config):
        super().__init__(config)
        self.model = UnpadEuroBertModel(config)
        self.post_init()


class UnpadEuroBertForTokenClassification(EuroBertForTokenClassification):

    def __init__(self, config):
        super().__init__(config)
        self.model = UnpadEuroBertModel(config)
        self.post_init()


def enable_eurobert_unpadding():
    # Monkey-patch the stock EuroBertModel so that models loaded through the
    # standard EuroBert* classes use the unpadded forward pass.
    EuroBertModel.forward = UnpadEuroBertModel.forward
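
A minimal usage sketch (not part of the uploaded file): it assumes the EuroBERT/EuroBERT-210m checkpoint id, a CUDA device, and an installed flash-attn package, since the unpadding path only engages under a flash_attention_* implementation:

import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("EuroBERT/EuroBERT-210m")  # assumed checkpoint id
model = UnpadEuroBertForMaskedLM.from_pretrained(
    "EuroBERT/EuroBERT-210m",
    attn_implementation="flash_attention_2",  # required for the unpadded path
    torch_dtype=torch.bfloat16,
).to("cuda")

texts = [f"Paris is the capital of {tokenizer.mask_token}.", "Short text"]
batch = tokenizer(texts, padding=True, return_tensors="pt").to("cuda")
with torch.inference_mode():
    logits = model(**batch).logits  # re-padded to (batch, seq_len, vocab_size) by _pad_output

Calling enable_eurobert_unpadding() instead patches EuroBertModel.forward in place, so models loaded through the stock EuroBert* classes pick up the same unpadded forward.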