qninhdt committed on
Commit
2c9e8bc
·
1 Parent(s): 0c2ae95
configs/experiment/miniagent-bert-mlp-abs_diff-mult.yaml ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ defaults:
4
+ - override /data: mixed
5
+ - override /model: miniagent
6
+ - override /callbacks: default
7
+ - override /trainer: gpu
8
+
9
+ seed: 42
10
+
11
+ model:
12
+ lr: 0.001
13
+ bert_model: bert-base-uncased
14
+
15
+ inst_proj_model:
16
+ _target_: src.models.mlp_module.MLPProjection
17
+ input_dim: 768
18
+ hidden_dim: 768
19
+ output_dim: 768
20
+
21
+ tool_proj_model:
22
+ _target_: src.models.mlp_module.MLPProjection
23
+ input_dim: 768
24
+ hidden_dim: 768
25
+ output_dim: 768
26
+
27
+ pred_model:
28
+ _target_: src.models.mlp_module.MLPPrediction
29
+ input_dim: 768
30
+ use_abs_diff: true
31
+ use_mult: true
32
+
33
+ data:
34
+ bert_model: bert-base-uncased
35
+ seed: 42
36
+ batch_size: 128
37
+ tool_capacity: 16
configs/trainer/default.yaml CHANGED
@@ -8,6 +8,8 @@ max_epochs: 10
8
  accelerator: cpu
9
  devices: 1
10
 
 
 
11
  # mixed precision for extra speed-up
12
  # precision: 16
13
 
 
8
  accelerator: cpu
9
  devices: 1
10
 
11
+ log_every_n_steps: 10
12
+
13
  # mixed precision for extra speed-up
14
  # precision: 16
15
 
src/data/mixed_datamodule.py CHANGED
@@ -47,6 +47,7 @@ class MixedDataModule(LightningDataModule):
47
  batch_size=self.batch_size,
48
  shuffle=True,
49
  num_workers=self.num_workers,
 
50
  )
51
 
52
  def val_dataloader(self):
 
47
  batch_size=self.batch_size,
48
  shuffle=True,
49
  num_workers=self.num_workers,
50
+ drop_last=True,
51
  )
52
 
53
  def val_dataloader(self):
src/models/miniagent_module.py CHANGED
@@ -65,10 +65,12 @@ class MiniAgentModule(LightningModule):
65
  pred = self.pred_model(inst_emb_r, tool_emb_r) # [BxB, 1]
66
  pred = pred.view(B, B) # [B, B]
67
 
68
- target = torch.eye(B, device=pred.device).float()
69
 
70
- pos_weight = torch.tensor([B - 1], device=pred.device)
71
- loss = F.binary_cross_entropy_with_logits(pred, target, pos_weight=pos_weight)
 
 
72
 
73
  self.log("train/loss", loss, on_step=True, sync_dist=True, prog_bar=True)
74
 
 
65
  pred = self.pred_model(inst_emb_r, tool_emb_r) # [BxB, 1]
66
  pred = pred.view(B, B) # [B, B]
67
 
68
+ # target = torch.eye(B, device=pred.device).float()
69
 
70
+ # pos_weight = torch.tensor([B - 1], device=pred.device)
71
+ # loss = F.binary_cross_entropy_with_logits(pred, target, pos_weight=pos_weight)
72
+ labels = torch.arange(B, device=pred.device).long()
73
+ loss = (F.cross_entropy(pred, labels) + F.cross_entropy(pred.T, labels)) * 0.5
74
 
75
  self.log("train/loss", loss, on_step=True, sync_dist=True, prog_bar=True)
76