sdadas committed on
Commit 72e75d7 · verified · 1 Parent(s): 35b7aea

Updated for Transformers 5.4
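For reference (not part of the commit itself), a minimal sketch of how a checkpoint that ships this custom code is typically loaded while pinned to this revision. The repository id below is a placeholder, and the sketch assumes the repo's config maps its auto classes to modeling_roberta.py:

from transformers import AutoModelForMaskedLM

# Hypothetical repo id; the actual repository is not named in this diff.
model = AutoModelForMaskedLM.from_pretrained(
    "sdadas/some-roberta-checkpoint",         # placeholder
    revision="72e75d7",                        # pin to the commit shown above
    trust_remote_code=True,                    # load UnpadRobertaForMaskedLM from modeling_roberta.py
    attn_implementation="flash_attention_2",   # the unpadded path only activates with flash attention
)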

Files changed (1)
  1. modeling_roberta.py +194 -197
modeling_roberta.py CHANGED
@@ -1,197 +1,194 @@
- from typing import Unpack
- import torch
- from transformers import (
-     RobertaModel,
-     Cache,
-     EncoderDecoderCache,
-     DynamicCache,
-     DataCollatorWithFlattening,
-     RobertaForMaskedLM,
-     RobertaForSequenceClassification,
-     RobertaForTokenClassification,
-     RobertaForQuestionAnswering,
-     RobertaForMultipleChoice,
-     RobertaForCausalLM
- )
- from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions
- from transformers.utils import TransformersKwargs
-
-
- def _unpad_input(input_ids: torch.Tensor, attention_mask: torch.Tensor):
-     collator = DataCollatorWithFlattening(return_flash_attn_kwargs=True)
-     features = collator([{"input_ids": i[a.bool()].tolist()} for i, a in zip(input_ids, attention_mask)])
-     return features
-
-
- def _pad_output(inputs: torch.Tensor, indices: torch.Tensor, batch: int, seqlen: int,) -> torch.Tensor:
-     if inputs.dim() == 3:
-         inputs = inputs.squeeze()
-     if inputs.dim() == 1:
-         output = torch.zeros(batch * seqlen, dtype=inputs.dtype, device=inputs.device)
-         output[indices] = inputs
-         padded_inputs = output.view(batch, seqlen)
-     else:
-         _, *rest = inputs.shape
-         output = torch.zeros(batch * seqlen, *rest, dtype=inputs.dtype, device=inputs.device)
-         output[indices] = inputs
-         padded_inputs = output.view(batch, seqlen, *rest)
-     return padded_inputs
-
-
- class UnpadRobertaModel(RobertaModel):
-     _no_split_modules = ["RobertaEmbeddings", "RobertaLayer"]
-
-     def __init__(self, config, add_pooling_layer=True):
-         super().__init__(config, add_pooling_layer=add_pooling_layer)
-
-     def forward(
-         self,
-         input_ids: torch.Tensor | None = None,
-         attention_mask: torch.Tensor | None = None,
-         token_type_ids: torch.Tensor | None = None,
-         position_ids: torch.Tensor | None = None,
-         inputs_embeds: torch.Tensor | None = None,
-         encoder_hidden_states: torch.Tensor | None = None,
-         encoder_attention_mask: torch.Tensor | None = None,
-         past_key_values: Cache | None = None,
-         use_cache: bool | None = None,
-         cache_position: torch.Tensor | None = None,
-         **kwargs: Unpack[TransformersKwargs],
-     ) -> tuple[torch.Tensor] | BaseModelOutputWithPoolingAndCrossAttentions:
-         if self.config.is_decoder:
-             use_cache = use_cache if use_cache is not None else self.config.use_cache
-         else:
-             use_cache = False
-
-         if use_cache and past_key_values is None:
-             past_key_values = (
-                 EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
-                 if encoder_hidden_states is not None or self.config.is_encoder_decoder
-                 else DynamicCache(config=self.config)
-             )
-
-         if (input_ids is None) ^ (inputs_embeds is not None):
-             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
-
-         if input_ids is not None:
-             device = input_ids.device
-             seq_length = input_ids.shape[1]
-             batch_size = input_ids.size(0)
-         else:
-             device = inputs_embeds.device
-             seq_length = inputs_embeds.shape[1]
-             batch_size = inputs_embeds.size(0)
-
-         past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
-         if cache_position is None:
-             cache_position = torch.arange(past_key_values_length, past_key_values_length + seq_length, device=device)
-
-         indices = None
-         if self.config._attn_implementation.startswith("flash_attention"):
-             if input_ids is None or attention_mask is None:
-                 raise ValueError("Unpadding requires both input_ids and attention_mask")
-             with torch.no_grad():
-                 indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
-                 features = _unpad_input(input_ids, attention_mask)
-                 input_ids = features["input_ids"].to(device=device)
-                 # roberta requires shifting position_ids by 2
-                 position_ids = (features["position_ids"] + 2).to(device=device)
-                 attention_mask = None
-                 kwargs["cu_seq_lens_k"] = features["cu_seq_lens_k"].to(device=device)
-                 kwargs["cu_seq_lens_q"] = features["cu_seq_lens_q"].to(device=device)
-                 kwargs["max_length_k"] = features["max_length_k"]
-                 kwargs["max_length_q"] = features["max_length_q"]
-
-         embedding_output = self.embeddings(
-             input_ids=input_ids,
-             position_ids=position_ids,
-             token_type_ids=token_type_ids,
-             inputs_embeds=inputs_embeds,
-             past_key_values_length=past_key_values_length,
-         )
-
-         attention_mask, encoder_attention_mask = self._create_attention_masks(
-             attention_mask=attention_mask,
-             encoder_attention_mask=encoder_attention_mask,
-             embedding_output=embedding_output,
-             encoder_hidden_states=encoder_hidden_states,
-             cache_position=cache_position,
-             past_key_values=past_key_values,
-         )
-
-         encoder_outputs = self.encoder(
-             embedding_output,
-             attention_mask=attention_mask,
-             encoder_hidden_states=encoder_hidden_states,
-             encoder_attention_mask=encoder_attention_mask,
-             past_key_values=past_key_values,
-             use_cache=use_cache,
-             cache_position=cache_position,
-             position_ids=position_ids,
-             **kwargs,
-         )
-
-         sequence_output = encoder_outputs.last_hidden_state
-         if self.config._attn_implementation.startswith("flash_attention"):
-             sequence_output = _pad_output(
-                 inputs=sequence_output, indices=indices, batch=batch_size, seqlen=seq_length
-             )
-
-         pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
-         return BaseModelOutputWithPoolingAndCrossAttentions(
-             last_hidden_state=sequence_output,
-             pooler_output=pooled_output,
-             past_key_values=encoder_outputs.past_key_values,
-         )
-
-
- class UnpadRobertaForCausalLM(RobertaForCausalLM):
-
-     def __init__(self, config):
-         super().__init__(config)
-         self.roberta = UnpadRobertaModel(config, add_pooling_layer=False)
-         self.post_init()
-
-
- class UnpadRobertaForMaskedLM(RobertaForMaskedLM):
-
-     def __init__(self, config):
-         super().__init__(config)
-         self.roberta = UnpadRobertaModel(config, add_pooling_layer=False)
-         self.post_init()
-
-
- class UnpadRobertaForSequenceClassification(RobertaForSequenceClassification):
-
-     def __init__(self, config):
-         super().__init__(config)
-         self.roberta = UnpadRobertaModel(config, add_pooling_layer=False)
-         self.post_init()
-
-
- class UnpadRobertaForTokenClassification(RobertaForTokenClassification):
-
-     def __init__(self, config):
-         super().__init__(config)
-         self.roberta = UnpadRobertaModel(config, add_pooling_layer=False)
-         self.post_init()
-
-
- class UnpadRobertaForMultipleChoice(RobertaForMultipleChoice):
-
-     def __init__(self, config):
-         super().__init__(config)
-         self.roberta = UnpadRobertaModel(config)
-         self.post_init()
-
-
- class UnpadRobertaForQuestionAnswering(RobertaForQuestionAnswering):
-
-     def __init__(self, config):
-         super().__init__(config)
-         self.roberta = UnpadRobertaModel(config, add_pooling_layer=False)
-         self.post_init()
-
-
- def enable_roberta_unpadding():
-     RobertaModel.forward = UnpadRobertaModel.forward
 
+ from typing import Unpack
+ import torch
+ from transformers import (
+     RobertaModel,
+     Cache,
+     EncoderDecoderCache,
+     DynamicCache,
+     DataCollatorWithFlattening,
+     RobertaForMaskedLM,
+     RobertaForSequenceClassification,
+     RobertaForTokenClassification,
+     RobertaForQuestionAnswering,
+     RobertaForMultipleChoice,
+     RobertaForCausalLM
+ )
+ from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions
+ from transformers.utils import TransformersKwargs
+
+
+ def _unpad_input(input_ids: torch.Tensor, attention_mask: torch.Tensor):
+     collator = DataCollatorWithFlattening(return_flash_attn_kwargs=True)
+     features = collator([{"input_ids": i[a.bool()].tolist()} for i, a in zip(input_ids, attention_mask)])
+     return features
+
+
+ def _pad_output(inputs: torch.Tensor, indices: torch.Tensor, batch: int, seqlen: int,) -> torch.Tensor:
+     if inputs.dim() == 3:
+         inputs = inputs.squeeze()
+     if inputs.dim() == 1:
+         output = torch.zeros(batch * seqlen, dtype=inputs.dtype, device=inputs.device)
+         output[indices] = inputs
+         padded_inputs = output.view(batch, seqlen)
+     else:
+         _, *rest = inputs.shape
+         output = torch.zeros(batch * seqlen, *rest, dtype=inputs.dtype, device=inputs.device)
+         output[indices] = inputs
+         padded_inputs = output.view(batch, seqlen, *rest)
+     return padded_inputs
+
+
+ class UnpadRobertaModel(RobertaModel):
+     _no_split_modules = ["RobertaEmbeddings", "RobertaLayer"]
+
+     def __init__(self, config, add_pooling_layer=True):
+         super().__init__(config, add_pooling_layer=add_pooling_layer)
+
+     def forward(
+         self,
+         input_ids: torch.Tensor | None = None,
+         attention_mask: torch.Tensor | None = None,
+         token_type_ids: torch.Tensor | None = None,
+         position_ids: torch.Tensor | None = None,
+         inputs_embeds: torch.Tensor | None = None,
+         encoder_hidden_states: torch.Tensor | None = None,
+         encoder_attention_mask: torch.Tensor | None = None,
+         past_key_values: Cache | None = None,
+         use_cache: bool | None = None,
+         cache_position: torch.Tensor | None = None,
+         **kwargs: Unpack[TransformersKwargs],
+     ) -> tuple[torch.Tensor] | BaseModelOutputWithPoolingAndCrossAttentions:
+
+         if (input_ids is None) ^ (inputs_embeds is not None):
+             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+         if self.config.is_decoder:
+             use_cache = use_cache if use_cache is not None else self.config.use_cache
+         else:
+             use_cache = False
+
+         if use_cache and past_key_values is None:
+             past_key_values = (
+                 EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
+                 if encoder_hidden_states is not None or self.config.is_encoder_decoder
+                 else DynamicCache(config=self.config)
+             )
+
+         past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
+
+         if input_ids is not None:
+             device = input_ids.device
+             seq_length = input_ids.shape[1]
+             batch_size = input_ids.size(0)
+         else:
+             device = inputs_embeds.device
+             seq_length = inputs_embeds.shape[1]
+             batch_size = inputs_embeds.size(0)
+
+         indices = None
+         if self.config._attn_implementation.startswith("flash_attention"):
+             if input_ids is None or attention_mask is None:
+                 raise ValueError("Unpadding requires both input_ids and attention_mask")
+             with torch.no_grad():
+                 indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+                 features = _unpad_input(input_ids, attention_mask)
+                 input_ids = features["input_ids"].to(device=device)
+                 # roberta requires shifting position_ids by 2
+                 position_ids = (features["position_ids"] + 2).to(device=device)
+                 attention_mask = None
+                 kwargs["cu_seq_lens_k"] = features["cu_seq_lens_k"].to(device=device)
+                 kwargs["cu_seq_lens_q"] = features["cu_seq_lens_q"].to(device=device)
+                 kwargs["max_length_k"] = features["max_length_k"]
+                 kwargs["max_length_q"] = features["max_length_q"]
+
+         embedding_output = self.embeddings(
+             input_ids=input_ids,
+             position_ids=position_ids,
+             token_type_ids=token_type_ids,
+             inputs_embeds=inputs_embeds,
+             past_key_values_length=past_key_values_length,
+         )
+
+         attention_mask, encoder_attention_mask = self._create_attention_masks(
+             attention_mask=attention_mask,
+             encoder_attention_mask=encoder_attention_mask,
+             embedding_output=embedding_output,
+             encoder_hidden_states=encoder_hidden_states,
+             past_key_values=past_key_values,
+         )
+
+         encoder_outputs = self.encoder(
+             embedding_output,
+             attention_mask=attention_mask,
+             encoder_hidden_states=encoder_hidden_states,
+             encoder_attention_mask=encoder_attention_mask,
+             past_key_values=past_key_values,
+             use_cache=use_cache,
+             position_ids=position_ids,
+             **kwargs,
+         )
+
+         sequence_output = encoder_outputs.last_hidden_state
+         if self.config._attn_implementation.startswith("flash_attention"):
+             sequence_output = _pad_output(
+                 inputs=sequence_output, indices=indices, batch=batch_size, seqlen=seq_length
+             )
+
+         pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+         return BaseModelOutputWithPoolingAndCrossAttentions(
+             last_hidden_state=sequence_output,
+             pooler_output=pooled_output,
+             past_key_values=encoder_outputs.past_key_values,
+         )
+
+
+ class UnpadRobertaForCausalLM(RobertaForCausalLM):
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.roberta = UnpadRobertaModel(config, add_pooling_layer=False)
+         self.post_init()
+
+
+ class UnpadRobertaForMaskedLM(RobertaForMaskedLM):
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.roberta = UnpadRobertaModel(config, add_pooling_layer=False)
+         self.post_init()
+
+
+ class UnpadRobertaForSequenceClassification(RobertaForSequenceClassification):
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.roberta = UnpadRobertaModel(config, add_pooling_layer=False)
+         self.post_init()
+
+
+ class UnpadRobertaForTokenClassification(RobertaForTokenClassification):
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.roberta = UnpadRobertaModel(config, add_pooling_layer=False)
+         self.post_init()
+
+
+ class UnpadRobertaForMultipleChoice(RobertaForMultipleChoice):
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.roberta = UnpadRobertaModel(config)
+         self.post_init()
+
+
+ class UnpadRobertaForQuestionAnswering(RobertaForQuestionAnswering):
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.roberta = UnpadRobertaModel(config, add_pooling_layer=False)
+         self.post_init()
+
+
+ def enable_roberta_unpadding():
+     RobertaModel.forward = UnpadRobertaModel.forward
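
A minimal usage sketch of the updated file, assuming it is available locally as modeling_roberta.py, that flash-attn is installed, and that a CUDA GPU is present. enable_roberta_unpadding() patches RobertaModel.forward in place, so a stock checkpoint loaded afterwards takes the unpadded path whenever a flash_attention implementation is selected:

import torch
from transformers import AutoModel, AutoTokenizer

from modeling_roberta import enable_roberta_unpadding  # assumes this file sits next to the script

enable_roberta_unpadding()  # replaces RobertaModel.forward with UnpadRobertaModel.forward

tokenizer = AutoTokenizer.from_pretrained("roberta-base")
model = AutoModel.from_pretrained(
    "roberta-base",
    attn_implementation="flash_attention_2",  # unpadding only triggers on the flash_attention* path
    dtype=torch.float16,                      # flash attention needs fp16/bf16 ("torch_dtype" on older releases)
).to("cuda").eval()

batch = tokenizer(
    ["a short sequence", "a much longer sequence with many more tokens"],
    padding=True,
    return_tensors="pt",
).to("cuda")

with torch.no_grad():
    out = model(**batch)

# The forward unpads internally but still returns a padded (batch, seq_len, hidden) tensor.
print(out.last_hidden_state.shape)

Patching at the class level avoids converting checkpoints: the weights stay identical, and only the forward pass changes.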