qninhdt commited on
Commit
8db8077
·
1 Parent(s): 8cb2a79
configs/callbacks/default.yaml CHANGED
@@ -8,7 +8,7 @@ defaults:
8
  model_checkpoint:
9
  dirpath: ${paths.output_dir}/checkpoints
10
  filename: "epoch_{epoch:03d}"
11
- monitor: "val/acc"
12
  mode: "max"
13
  save_last: True
14
  auto_insert_metric_name: False
 
8
  model_checkpoint:
9
  dirpath: ${paths.output_dir}/checkpoints
10
  filename: "epoch_{epoch:03d}"
11
+ monitor: "val/1_acc"
12
  mode: "max"
13
  save_last: True
14
  auto_insert_metric_name: False
src/models/miniagent_module.py CHANGED
@@ -31,9 +31,17 @@ class MiniAgentModule(LightningModule):
31
  self.tool_proj_model = tool_proj_model
32
  self.pred_model = pred_model
33
 
34
- self.val_acc = Accuracy(task="binary")
35
- self.val_precision = MeanMetric()
36
- self.val_recall = MeanMetric()
 
 
 
 
 
 
 
 
37
 
38
  self.lr = lr
39
 
@@ -69,8 +77,6 @@ class MiniAgentModule(LightningModule):
69
 
70
  pos_weight = torch.tensor([B - 1], device=pred.device)
71
  loss = F.binary_cross_entropy_with_logits(pred, target, pos_weight=pos_weight)
72
- # labels = torch.arange(B, device=pred.device).long()
73
- # loss = (F.cross_entropy(pred, labels) + F.cross_entropy(pred.T, labels)) * 0.5
74
 
75
  self.log("train/loss", loss, on_step=True, sync_dist=True, prog_bar=True)
76
 
@@ -107,31 +113,151 @@ class MiniAgentModule(LightningModule):
107
  pred = torch.sigmoid(pred)
108
 
109
  pred_tool_mask = pred > 0.5
110
- pos_sample = (pred_tool_mask == correct_tool_mask).all(dim=1).long()
111
 
112
  true_pos_mask = pred_tool_mask & correct_tool_mask
113
 
114
- precision = true_pos_mask.sum(dim=1) / torch.clamp(
115
- pred_tool_mask.sum(dim=1), min=1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
  )
117
- recall = true_pos_mask.sum(dim=1) / correct_tool_mask.sum(dim=1)
118
-
119
- self.val_acc.update(pos_sample, torch.ones_like(pos_sample))
120
- self.val_precision.update(precision)
121
- self.val_recall.update(recall)
122
-
123
- self.log("val/acc", self.val_acc, on_epoch=True, sync_dist=True, prog_bar=True)
124
- self.log(
125
- "val/precision",
126
- self.val_precision,
127
- on_epoch=True,
128
- sync_dist=True,
129
- prog_bar=True,
130
  )
131
- self.log(
132
- "val/recall", self.val_recall, on_epoch=True, sync_dist=True, prog_bar=True
 
133
  )
134
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
  def on_validation_epoch_end(self) -> None:
136
  pass
137
 
 
31
  self.tool_proj_model = tool_proj_model
32
  self.pred_model = pred_model
33
 
34
+ self.val_1_acc = Accuracy(task="binary")
35
+ self.val_1_precision = MeanMetric()
36
+ self.val_1_recall = MeanMetric()
37
+
38
+ self.val_2_acc = Accuracy(task="binary")
39
+ self.val_2_precision = MeanMetric()
40
+ self.val_2_recall = MeanMetric()
41
+
42
+ self.val_other_acc = Accuracy(task="binary")
43
+ self.val_other_precision = MeanMetric()
44
+ self.val_other_recall = MeanMetric()
45
 
46
  self.lr = lr
47
 
 
77
 
78
  pos_weight = torch.tensor([B - 1], device=pred.device)
79
  loss = F.binary_cross_entropy_with_logits(pred, target, pos_weight=pos_weight)
 
 
80
 
81
  self.log("train/loss", loss, on_step=True, sync_dist=True, prog_bar=True)
82
 
 
113
  pred = torch.sigmoid(pred)
114
 
115
  pred_tool_mask = pred > 0.5
 
116
 
117
  true_pos_mask = pred_tool_mask & correct_tool_mask
118
 
119
+ one_tool_mask = correct_tool_mask.sum(dim=1) == 1
120
+ two_tool_mask = correct_tool_mask.sum(dim=1) == 2
121
+ other_mask = ~(one_tool_mask | two_tool_mask)
122
+
123
+ # one tool
124
+ one_tool_pos_sample = (
125
+ (pred_tool_mask[one_tool_mask] == correct_tool_mask[one_tool_mask])
126
+ .all(dim=1)
127
+ .long()
128
+ )
129
+
130
+ one_tool_precision = true_pos_mask[one_tool_mask].sum(dim=1) / torch.clamp(
131
+ pred_tool_mask[one_tool_mask].sum(dim=1), min=1
132
+ )
133
+
134
+ one_tool_recall = true_pos_mask[one_tool_mask].sum(dim=1) / torch.clamp(
135
+ correct_tool_mask[one_tool_mask].sum(dim=1), min=1
136
+ )
137
+
138
+ # two tool
139
+ two_tool_pos_sample = (
140
+ (pred_tool_mask[two_tool_mask] == correct_tool_mask[two_tool_mask])
141
+ .all(dim=1)
142
+ .long()
143
  )
144
+
145
+ two_tool_precision = true_pos_mask[two_tool_mask].sum(dim=1) / torch.clamp(
146
+ pred_tool_mask[two_tool_mask].sum(dim=1), min=1
 
 
 
 
 
 
 
 
 
 
147
  )
148
+
149
+ two_tool_recall = true_pos_mask[two_tool_mask].sum(dim=1) / torch.clamp(
150
+ correct_tool_mask[two_tool_mask].sum(dim=1), min=1
151
  )
152
 
153
+ # other
154
+ other_pos_sample = (
155
+ (pred_tool_mask[other_mask] == correct_tool_mask[other_mask])
156
+ .all(dim=1)
157
+ .long()
158
+ )
159
+
160
+ other_precision = true_pos_mask[other_mask].sum(dim=1) / torch.clamp(
161
+ pred_tool_mask[other_mask].sum(dim=1), min=1
162
+ )
163
+
164
+ other_recall = true_pos_mask[other_mask].sum(dim=1) / torch.clamp(
165
+ correct_tool_mask[other_mask].sum(dim=1), min=1
166
+ )
167
+
168
+ if one_tool_pos_sample.sum().item() > 0:
169
+ self.val_1_acc.update(
170
+ one_tool_pos_sample, torch.ones_like(one_tool_pos_sample)
171
+ )
172
+ self.val_1_precision.update(one_tool_precision)
173
+ self.val_1_recall.update(one_tool_recall)
174
+
175
+ self.log(
176
+ "val/1_acc",
177
+ self.val_1_acc,
178
+ on_epoch=True,
179
+ sync_dist=True,
180
+ prog_bar=True,
181
+ )
182
+
183
+ self.log(
184
+ "val/1_precision",
185
+ self.val_1_precision,
186
+ on_epoch=True,
187
+ sync_dist=True,
188
+ prog_bar=True,
189
+ )
190
+
191
+ self.log(
192
+ "val/1_recall",
193
+ self.val_1_recall,
194
+ on_epoch=True,
195
+ sync_dist=True,
196
+ prog_bar=True,
197
+ )
198
+
199
+ if two_tool_pos_sample.sum().item() > 0:
200
+ self.val_2_acc.update(
201
+ two_tool_pos_sample, torch.ones_like(two_tool_pos_sample)
202
+ )
203
+ self.val_2_precision.update(two_tool_precision)
204
+ self.val_2_recall.update(two_tool_recall)
205
+
206
+ self.log(
207
+ "val/2_acc",
208
+ self.val_2_acc,
209
+ on_epoch=True,
210
+ sync_dist=True,
211
+ prog_bar=True,
212
+ )
213
+
214
+ self.log(
215
+ "val/2_precision",
216
+ self.val_2_precision,
217
+ on_epoch=True,
218
+ sync_dist=True,
219
+ prog_bar=True,
220
+ )
221
+
222
+ self.log(
223
+ "val/2_recall",
224
+ self.val_2_recall,
225
+ on_epoch=True,
226
+ sync_dist=True,
227
+ prog_bar=True,
228
+ )
229
+
230
+ if other_pos_sample.sum().item() > 0:
231
+ self.val_other_acc.update(
232
+ other_pos_sample, torch.ones_like(other_pos_sample)
233
+ )
234
+ self.val_other_precision.update(other_precision)
235
+ self.val_other_recall.update(other_recall)
236
+
237
+ self.log(
238
+ "val/other_acc",
239
+ self.val_other_acc,
240
+ on_epoch=True,
241
+ sync_dist=True,
242
+ prog_bar=True,
243
+ )
244
+
245
+ self.log(
246
+ "val/other_precision",
247
+ self.val_other_precision,
248
+ on_epoch=True,
249
+ sync_dist=True,
250
+ prog_bar=True,
251
+ )
252
+
253
+ self.log(
254
+ "val/other_recall",
255
+ self.val_other_recall,
256
+ on_epoch=True,
257
+ sync_dist=True,
258
+ prog_bar=True,
259
+ )
260
+
261
  def on_validation_epoch_end(self) -> None:
262
  pass
263
 
src/models/mlp_module.py CHANGED
@@ -7,7 +7,7 @@ class MLPProjection(nn.Module):
7
  def __init__(self, input_dim, hidden_dim, output_dim):
8
  super().__init__()
9
  self.linear1 = nn.Linear(input_dim, hidden_dim)
10
- self.dropout = nn.Dropout(0.3)
11
  self.linear2 = nn.Linear(hidden_dim, output_dim)
12
 
13
  def forward(self, x_output):
@@ -34,10 +34,10 @@ class MLPPrediction(nn.Module):
34
  self.mlp = nn.Sequential(
35
  nn.Linear(real_input_dim, 512),
36
  nn.SiLU(),
37
- nn.Dropout(0.3),
38
  nn.Linear(512, 256),
39
  nn.SiLU(),
40
- nn.Dropout(0.3),
41
  nn.Linear(256, 128),
42
  nn.SiLU(),
43
  nn.Linear(128, 1),
 
7
  def __init__(self, input_dim, hidden_dim, output_dim):
8
  super().__init__()
9
  self.linear1 = nn.Linear(input_dim, hidden_dim)
10
+ self.dropout = nn.Dropout(0.5)
11
  self.linear2 = nn.Linear(hidden_dim, output_dim)
12
 
13
  def forward(self, x_output):
 
34
  self.mlp = nn.Sequential(
35
  nn.Linear(real_input_dim, 512),
36
  nn.SiLU(),
37
+ nn.Dropout(0.5),
38
  nn.Linear(512, 256),
39
  nn.SiLU(),
40
+ nn.Dropout(0.5),
41
  nn.Linear(256, 128),
42
  nn.SiLU(),
43
  nn.Linear(128, 1),