skatzR committed
Commit 5d2179c · verified · 1 Parent(s): 67d909f

Update modeling_rqa.py

Files changed (1)
  1. modeling_rqa.py +61 -87
modeling_rqa.py CHANGED
@@ -1,8 +1,12 @@
-from typing import Any, Dict, List, Optional
-
 import torch
 import torch.nn as nn
-from transformers import AutoConfig, AutoModel, PreTrainedModel, PretrainedConfig
+from typing import List, Optional
+from transformers import (
+    AutoConfig,
+    AutoModel,
+    PreTrainedModel,
+    PretrainedConfig,
+)
 
 
 class RQAModelConfig(PretrainedConfig):
@@ -11,7 +15,7 @@ class RQAModelConfig(PretrainedConfig):
     def __init__(
         self,
         base_model_name: str = "FacebookAI/xlm-roberta-large",
-        encoder_config: Optional[Dict[str, Any]] = None,
+        encoder_config: Optional[dict] = None,
         error_types: Optional[List[str]] = None,
         schema_version: str = "rqa.v2.2",
         has_issue_projection_dim: int = 256,
@@ -19,7 +23,7 @@ class RQAModelConfig(PretrainedConfig):
         errors_projection_dim: int = 512,
         has_issue_dropout: float = 0.25,
         hidden_dropout: float = 0.25,
-        errors_dropout: float = 0.30,
+        errors_dropout: float = 0.3,
         temperature_has_issue: float = 1.0,
         temperature_is_hidden: float = 1.0,
         temperature_errors: Optional[List[float]] = None,
@@ -27,28 +31,35 @@ class RQAModelConfig(PretrainedConfig):
         threshold_is_hidden: float = 0.5,
         threshold_error: float = 0.5,
         threshold_errors: Optional[List[float]] = None,
-        **kwargs,
+        **kwargs
     ):
         super().__init__(**kwargs)
 
-        self.schema_version = str(schema_version)
         self.base_model_name = base_model_name
         self.encoder_config = encoder_config
-        self.error_types = list(error_types or [])
+        self.error_types = error_types or [
+            "false_causality",
+            "unsupported_claim",
+            "overgeneralization",
+            "missing_premise",
+            "contradiction",
+            "circular_reasoning",
+        ]
         self.num_error_types = len(self.error_types)
+        self.schema_version = schema_version
 
-        self.has_issue_projection_dim = int(has_issue_projection_dim)
-        self.hidden_projection_dim = int(hidden_projection_dim)
-        self.errors_projection_dim = int(errors_projection_dim)
+        self.has_issue_projection_dim = has_issue_projection_dim
+        self.hidden_projection_dim = hidden_projection_dim
+        self.errors_projection_dim = errors_projection_dim
 
-        self.has_issue_dropout = float(has_issue_dropout)
-        self.hidden_dropout = float(hidden_dropout)
-        self.errors_dropout = float(errors_dropout)
+        self.has_issue_dropout = has_issue_dropout
+        self.hidden_dropout = hidden_dropout
+        self.errors_dropout = errors_dropout
 
         self.temperature_has_issue = float(temperature_has_issue)
         self.temperature_is_hidden = float(temperature_is_hidden)
         self.temperature_errors = (
-            list(temperature_errors)
+            temperature_errors
             if temperature_errors is not None
             else [1.0] * self.num_error_types
         )
@@ -57,9 +68,9 @@ class RQAModelConfig(PretrainedConfig):
         self.threshold_is_hidden = float(threshold_is_hidden)
         self.threshold_error = float(threshold_error)
         self.threshold_errors = (
-            list(threshold_errors)
+            threshold_errors
             if threshold_errors is not None
-            else [self.threshold_error] * self.num_error_types
+            else [float(threshold_error)] * self.num_error_types
         )
 
         try:
@@ -69,32 +80,8 @@ class RQAModelConfig(PretrainedConfig):
             pass
 
 
-def build_encoder_config_from_saved_dict(
-    encoder_config: Optional[Dict[str, Any]],
-    base_model_name: str,
-):
-    if encoder_config is None:
-        return AutoConfig.from_pretrained(base_model_name)
-
-    cfg_dict = dict(encoder_config)
-    model_type = cfg_dict.pop("model_type", None)
-    cfg_dict.pop("_name_or_path", None)
-
-    if model_type is not None:
-        try:
-            return AutoConfig.for_model(model_type, **cfg_dict)
-        except Exception:
-            pass
-
-    return AutoConfig.from_pretrained(base_model_name)
-
-
 class MeanPooling(nn.Module):
-    def forward(
-        self,
-        last_hidden_state: torch.Tensor,
-        attention_mask: torch.Tensor,
-    ) -> torch.Tensor:
+    def forward(self, last_hidden_state, attention_mask):
         mask = attention_mask.unsqueeze(-1).float()
         summed = torch.sum(last_hidden_state * mask, dim=1)
         denom = torch.clamp(mask.sum(dim=1), min=1e-9)
@@ -106,24 +93,17 @@ class RQAModelHF(PreTrainedModel):
     _supports_grouped_mm = False
 
     def __init__(self, config: RQAModelConfig):
+        super().__init__(config)
+
         try:
             config._experts_implementation = "eager"
             config._experts_implementation_internal = "eager"
         except Exception:
             pass
-        super().__init__(config)
-
-        if config.encoder_config is None:
-            base_cfg = AutoConfig.from_pretrained(config.base_model_name)
-            config.encoder_config = base_cfg.to_dict()
-
-        enc_cfg = build_encoder_config_from_saved_dict(
-            encoder_config=config.encoder_config,
-            base_model_name=config.base_model_name,
-        )
-        self.encoder = AutoModel.from_config(enc_cfg)
 
+        self.encoder = AutoModel.from_pretrained(config.base_model_name)
         hidden_size = self.encoder.config.hidden_size
+
         self.pooler = MeanPooling()
 
         self.has_issue_projection = nn.Sequential(
@@ -132,12 +112,14 @@ class RQAModelHF(PreTrainedModel):
             nn.GELU(),
             nn.Dropout(config.has_issue_dropout),
         )
+
         self.hidden_projection = nn.Sequential(
             nn.Linear(hidden_size, config.hidden_projection_dim),
             nn.LayerNorm(config.hidden_projection_dim),
             nn.GELU(),
             nn.Dropout(config.hidden_dropout),
         )
+
         self.errors_projection = nn.Sequential(
             nn.Linear(hidden_size, config.errors_projection_dim),
             nn.LayerNorm(config.errors_projection_dim),
@@ -155,11 +137,10 @@ class RQAModelHF(PreTrainedModel):
         self.log_var_has_issue = nn.Parameter(torch.zeros(1))
         self.log_var_is_hidden = nn.Parameter(torch.zeros(1))
         self.log_var_errors = nn.Parameter(torch.zeros(1))
-        with torch.no_grad():
-            self.log_var_has_issue.clamp_(-5, 5)
-            self.log_var_is_hidden.clamp_(-5, 5)
-            self.log_var_errors.clamp_(-5, 5)
 
+        self._init_custom_weights()
+
+    def _init_custom_weights(self):
         for module in [
             self.has_issue_projection[0],
             self.hidden_projection[0],
@@ -168,47 +149,40 @@ class RQAModelHF(PreTrainedModel):
             self.is_hidden_head,
             self.errors_head,
         ]:
-            setattr(module, "_rqa_custom_init", True)
-
-        self.post_init()
+            if isinstance(module, nn.Linear):
+                nn.init.xavier_uniform_(module.weight)
+                if module.bias is not None:
+                    nn.init.zeros_(module.bias)
 
-    def _init_weights(self, module):
-        if isinstance(module, nn.Linear) and getattr(module, "_rqa_custom_init", False):
-            nn.init.xavier_uniform_(module.weight)
-            if module.bias is not None:
-                nn.init.zeros_(module.bias)
-
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        attention_mask: torch.Tensor,
-        **kwargs,
-    ) -> Dict[str, torch.Tensor]:
+    def forward(self, input_ids=None, attention_mask=None, **kwargs):
         outputs = self.encoder(
             input_ids=input_ids,
             attention_mask=attention_mask,
             return_dict=True,
-            **kwargs,
         )
+
         pooled = self.pooler(outputs.last_hidden_state, attention_mask)
 
-        issue_features = self.has_issue_projection(pooled)
-        hidden_features = self.hidden_projection(pooled)
-        error_features = self.errors_projection(pooled)
+        has_issue_logits = self.has_issue_head(
+            self.has_issue_projection(pooled)
+        ).squeeze(-1)
+
+        is_hidden_logits = self.is_hidden_head(
+            self.hidden_projection(pooled)
+        ).squeeze(-1)
+
+        errors_logits = self.errors_head(
+            self.errors_projection(pooled)
+        )
 
         return {
-            "has_issue_logits": self.has_issue_head(issue_features).squeeze(-1),
-            "is_hidden_logits": self.is_hidden_head(hidden_features).squeeze(-1),
-            "errors_logits": self.errors_head(error_features),
+            "has_issue_logits": has_issue_logits,
+            "is_hidden_logits": is_hidden_logits,
+            "errors_logits": errors_logits,
         }
 
 
-try:
-    AutoConfig.register("rqa_v2_2", RQAModelConfig)
-except ValueError:
-    pass
+AutoConfig.register("rqa_v2_2", RQAModelConfig)
+AutoModel.register(RQAModelConfig, RQAModelHF)
 
-try:
-    AutoModel.register(RQAModelConfig, RQAModelHF)
-except ValueError:
-    pass
+print("✅ RQA-R2 registered in Transformers")
 
 
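For orientation, here is a minimal usage sketch of the module as committed. Assumptions are flagged in the comments: the file is importable as modeling_rqa, and no fine-tuned checkpoint is named in this commit, so the three heads carry fresh, untrained weights.

# Usage sketch (assumptions: the committed file is importable as
# `modeling_rqa`; this commit names no fine-tuned checkpoint, so the
# classification heads below are freshly initialized).
import torch
from transformers import AutoTokenizer

from modeling_rqa import RQAModelConfig, RQAModelHF  # import also runs the Auto* registration

config = RQAModelConfig()          # defaults: XLM-RoBERTa-large base, six error types
model = RQAModelHF(config).eval()  # __init__ downloads the base encoder weights

tokenizer = AutoTokenizer.from_pretrained(config.base_model_name)
batch = tokenizer(
    ["Sales rose after the ad ran, so the ad caused the rise."],
    return_tensors="pt", padding=True, truncation=True,
)

with torch.no_grad():
    out = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])

# The config stores per-head temperatures and thresholds; applying them at
# inference time like this is an assumption, not code from the commit.
p_issue = torch.sigmoid(out["has_issue_logits"] / config.temperature_has_issue)
p_errors = torch.sigmoid(out["errors_logits"])  # shape: (batch, num_error_types)
flagged = [
    name
    for name, p, t in zip(config.error_types, p_errors[0].tolist(), config.threshold_errors)
    if p > t
]
print(p_issue.item(), flagged)

Note that after this commit `__init__` always pulls the base encoder with `AutoModel.from_pretrained`, so constructing the model requires network access or a local cache of the base weights.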
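The three `log_var_*` parameters (whose in-place clamp this commit removes) are the usual machinery for homoscedastic uncertainty weighting across the three heads. The training objective is not part of this file, so the following is only a plausible sketch of how such parameters are typically combined with per-head BCE losses; the label keys are hypothetical.

# Hypothetical training objective -- not part of this commit. Sketch of
# Kendall-style homoscedastic uncertainty weighting using the model's
# log_var_* parameters; label keys ("has_issue", "is_hidden", "errors")
# are assumed, not taken from the repository.
import torch
import torch.nn.functional as F

def multitask_loss(outputs, labels, model):
    per_head = [
        (F.binary_cross_entropy_with_logits(
            outputs["has_issue_logits"], labels["has_issue"].float()),
         model.log_var_has_issue),
        (F.binary_cross_entropy_with_logits(
            outputs["is_hidden_logits"], labels["is_hidden"].float()),
         model.log_var_is_hidden),
        (F.binary_cross_entropy_with_logits(
            outputs["errors_logits"], labels["errors"].float()),
         model.log_var_errors),
    ]
    # exp(-log_var) down-weights heads the model is uncertain about, while
    # the +log_var/2 term stops log_var from growing without bound -- which
    # is also why a hard clamp is not strictly required.
    return sum(torch.exp(-lv) * loss + 0.5 * lv for loss, lv in per_head).squeeze()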