qninhdt commited on
Commit
f550456
·
1 Parent(s): 9a9a2f5
configs/experiment/miniagent-bert-mlp-abs_diff.yaml ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ defaults:
4
+ - override /data: mixed
5
+ - override /model: miniagent
6
+ - override /callbacks: default
7
+ - override /trainer: gpu
8
+
9
+ seed: 42
10
+
11
+ model:
12
+ lr: 0.001
13
+ bert_model: bert-base-uncased
14
+
15
+ inst_proj_model:
16
+ _target_: src.models.mlp_module.MLPProjection
17
+ input_dim: 768
18
+ hidden_dim: 768
19
+ output_dim: 768
20
+
21
+ tool_proj_model:
22
+ _target_: src.models.mlp_module.MLPProjection
23
+ input_dim: 768
24
+ hidden_dim: 768
25
+ output_dim: 768
26
+
27
+ pred_model:
28
+ _target_: src.models.mlp_module.MLPPrediction
29
+ input_dim: 768
30
+ use_abs_diff: true
31
+ use_mult: false
32
+
33
+ data:
34
+ bert_model: bert-base-uncased
35
+ seed: 42
36
+ batch_size: 128
37
+ tool_capacity: 16
configs/experiment/miniagent-bert-mlp-mult.yaml ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ defaults:
4
+ - override /data: mixed
5
+ - override /model: miniagent
6
+ - override /callbacks: default
7
+ - override /trainer: gpu
8
+
9
+ seed: 42
10
+
11
+ model:
12
+ lr: 0.001
13
+ bert_model: bert-base-uncased
14
+
15
+ inst_proj_model:
16
+ _target_: src.models.mlp_module.MLPProjection
17
+ input_dim: 768
18
+ hidden_dim: 768
19
+ output_dim: 768
20
+
21
+ tool_proj_model:
22
+ _target_: src.models.mlp_module.MLPProjection
23
+ input_dim: 768
24
+ hidden_dim: 768
25
+ output_dim: 768
26
+
27
+ pred_model:
28
+ _target_: src.models.mlp_module.MLPPrediction
29
+ input_dim: 768
30
+ use_abs_diff: false
31
+ use_mult: true
32
+
33
+ data:
34
+ bert_model: bert-base-uncased
35
+ seed: 42
36
+ batch_size: 128
37
+ tool_capacity: 16
configs/trainer/default.yaml CHANGED
@@ -3,7 +3,7 @@ _target_: lightning.pytorch.trainer.Trainer
3
  default_root_dir: ${paths.output_dir}
4
 
5
  min_epochs: 1 # prevents early stopping
6
- max_epochs: 20
7
 
8
  accelerator: cpu
9
  devices: 1
@@ -18,4 +18,4 @@ check_val_every_n_epoch: 1
18
 
19
  # set True to ensure deterministic results
20
  # makes training slower but gives more reproducibility than just setting seeds
21
- deterministic: False
 
3
  default_root_dir: ${paths.output_dir}
4
 
5
  min_epochs: 1 # prevents early stopping
6
+ max_epochs: 50
7
 
8
  accelerator: cpu
9
  devices: 1
 
18
 
19
  # set True to ensure deterministic results
20
  # makes training slower but gives more reproducibility than just setting seeds
21
+ deterministic: True
src/data/mixed_dataset.py CHANGED
@@ -25,13 +25,20 @@ class MixedDataset(Dataset):
25
 
26
  return tools, samples
27
 
28
- def encode_text(self, text):
29
- inputs = self.tokenizer(
30
- text,
31
- max_length=128,
32
- padding="max_length",
33
- truncation=True,
34
- )
 
 
 
 
 
 
 
35
  ids = torch.tensor(inputs["input_ids"], dtype=torch.long)
36
  mask = torch.tensor(inputs["attention_mask"], dtype=torch.long)
37
 
 
25
 
26
  return tools, samples
27
 
28
+ def encode_text(self, text, padding=True):
29
+ if padding:
30
+ inputs = self.tokenizer(
31
+ text,
32
+ max_length=128,
33
+ padding="max_length",
34
+ truncation=True,
35
+ )
36
+ else:
37
+ inputs = self.tokenizer(
38
+ text,
39
+ max_length=128,
40
+ truncation=True,
41
+ )
42
  ids = torch.tensor(inputs["input_ids"], dtype=torch.long)
43
  mask = torch.tensor(inputs["attention_mask"], dtype=torch.long)
44
 
src/models/miniagent_module.py CHANGED
@@ -26,8 +26,8 @@ class MiniAgentModule(LightningModule):
26
  )
27
 
28
  self.bert_model = BertModel.from_pretrained(bert_model)
29
- self.bert_model.eval()
30
- self.bert_model.requires_grad_(False)
31
 
32
  self.inst_proj_model = inst_proj_model
33
  self.tool_proj_model = tool_proj_model
@@ -67,8 +67,12 @@ class MiniAgentModule(LightningModule):
67
  pred = self.pred_model(inst_emb_r, tool_emb_r) # [BxB, 1]
68
  pred = pred.view(B, B) # [B, B]
69
 
70
- labels = torch.arange(B, device=pred.device).long()
71
- loss = (F.cross_entropy(pred, labels) + F.cross_entropy(pred.T, labels)) * 0.5
 
 
 
 
72
 
73
  self.log("train/loss", loss, on_step=True, sync_dist=True, prog_bar=True)
74
 
@@ -142,5 +146,16 @@ class MiniAgentModule(LightningModule):
142
  pass
143
 
144
  def configure_optimizers(self):
145
- opt = torch.optim.AdamW(self.parameters(), lr=self.lr, weight_decay=1e-4)
 
 
 
 
 
 
 
 
 
 
 
146
  return opt
 
26
  )
27
 
28
  self.bert_model = BertModel.from_pretrained(bert_model)
29
+ # self.bert_model.eval()
30
+ # self.bert_model.requires_grad_(False)
31
 
32
  self.inst_proj_model = inst_proj_model
33
  self.tool_proj_model = tool_proj_model
 
67
  pred = self.pred_model(inst_emb_r, tool_emb_r) # [BxB, 1]
68
  pred = pred.view(B, B) # [B, B]
69
 
70
+ target = torch.eye(B, device=pred.device).float()
71
+
72
+ pos_weight = torch.tensor([B - 1], device=pred.device)
73
+ loss = F.binary_cross_entropy_with_logits(pred, target, pos_weight=pos_weight)
74
+ # labels = torch.arange(B, device=pred.device).long()
75
+ # loss = (F.cross_entropy(pred, labels) + F.cross_entropy(pred.T, labels)) * 0.5
76
 
77
  self.log("train/loss", loss, on_step=True, sync_dist=True, prog_bar=True)
78
 
 
146
  pass
147
 
148
  def configure_optimizers(self):
149
+ opt = torch.optim.AdamW(
150
+ [
151
+ {"params": self.bert_model.parameters(), "lr": 1e-5},
152
+ {
153
+ "params": list(self.inst_proj_model.parameters())
154
+ + list(self.tool_proj_model.parameters())
155
+ + list(self.pred_model.parameters()),
156
+ "lr": self.lr,
157
+ },
158
+ ],
159
+ weight_decay=1e-4,
160
+ )
161
  return opt