qninhdt committed on
Commit ·
f550456
1
Parent(s): 9a9a2f5
cc
Browse files
configs/experiment/miniagent-bert-mlp-abs_diff.yaml
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
defaults:
|
| 4 |
+
- override /data: mixed
|
| 5 |
+
- override /model: miniagent
|
| 6 |
+
- override /callbacks: default
|
| 7 |
+
- override /trainer: gpu
|
| 8 |
+
|
| 9 |
+
seed: 42
|
| 10 |
+
|
| 11 |
+
model:
|
| 12 |
+
lr: 0.001
|
| 13 |
+
bert_model: bert-base-uncased
|
| 14 |
+
|
| 15 |
+
inst_proj_model:
|
| 16 |
+
_target_: src.models.mlp_module.MLPProjection
|
| 17 |
+
input_dim: 768
|
| 18 |
+
hidden_dim: 768
|
| 19 |
+
output_dim: 768
|
| 20 |
+
|
| 21 |
+
tool_proj_model:
|
| 22 |
+
_target_: src.models.mlp_module.MLPProjection
|
| 23 |
+
input_dim: 768
|
| 24 |
+
hidden_dim: 768
|
| 25 |
+
output_dim: 768
|
| 26 |
+
|
| 27 |
+
pred_model:
|
| 28 |
+
_target_: src.models.mlp_module.MLPPrediction
|
| 29 |
+
input_dim: 768
|
| 30 |
+
use_abs_diff: true
|
| 31 |
+
use_mult: true
|
| 32 |
+
|
| 33 |
+
data:
|
| 34 |
+
bert_model: bert-base-uncased
|
| 35 |
+
seed: 42
|
| 36 |
+
batch_size: 128
|
| 37 |
+
tool_capacity: 16
|
configs/experiment/miniagent-bert-mlp-mult.yaml
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
defaults:
|
| 4 |
+
- override /data: mixed
|
| 5 |
+
- override /model: miniagent
|
| 6 |
+
- override /callbacks: default
|
| 7 |
+
- override /trainer: gpu
|
| 8 |
+
|
| 9 |
+
seed: 42
|
| 10 |
+
|
| 11 |
+
model:
|
| 12 |
+
lr: 0.001
|
| 13 |
+
bert_model: bert-base-uncased
|
| 14 |
+
|
| 15 |
+
inst_proj_model:
|
| 16 |
+
_target_: src.models.mlp_module.MLPProjection
|
| 17 |
+
input_dim: 768
|
| 18 |
+
hidden_dim: 768
|
| 19 |
+
output_dim: 768
|
| 20 |
+
|
| 21 |
+
tool_proj_model:
|
| 22 |
+
_target_: src.models.mlp_module.MLPProjection
|
| 23 |
+
input_dim: 768
|
| 24 |
+
hidden_dim: 768
|
| 25 |
+
output_dim: 768
|
| 26 |
+
|
| 27 |
+
pred_model:
|
| 28 |
+
_target_: src.models.mlp_module.MLPPrediction
|
| 29 |
+
input_dim: 768
|
| 30 |
+
use_abs_diff: true
|
| 31 |
+
use_mult: true
|
| 32 |
+
|
| 33 |
+
data:
|
| 34 |
+
bert_model: bert-base-uncased
|
| 35 |
+
seed: 42
|
| 36 |
+
batch_size: 128
|
| 37 |
+
tool_capacity: 16
|
configs/trainer/default.yaml
CHANGED
|
@@ -3,7 +3,7 @@ _target_: lightning.pytorch.trainer.Trainer
|
|
| 3 |
default_root_dir: ${paths.output_dir}
|
| 4 |
|
| 5 |
min_epochs: 1 # prevents early stopping
|
| 6 |
-
max_epochs:
|
| 7 |
|
| 8 |
accelerator: cpu
|
| 9 |
devices: 1
|
|
@@ -18,4 +18,4 @@ check_val_every_n_epoch: 1
|
|
| 18 |
|
| 19 |
# set True to ensure deterministic results
|
| 20 |
# makes training slower but gives more reproducibility than just setting seeds
|
| 21 |
-
deterministic:
|
|
|
|
| 3 |
default_root_dir: ${paths.output_dir}
|
| 4 |
|
| 5 |
min_epochs: 1 # prevents early stopping
|
| 6 |
+
max_epochs: 50
|
| 7 |
|
| 8 |
accelerator: cpu
|
| 9 |
devices: 1
|
|
|
|
| 18 |
|
| 19 |
# set True to ensure deterministic results
|
| 20 |
# makes training slower but gives more reproducibility than just setting seeds
|
| 21 |
+
deterministic: True
|
src/data/mixed_dataset.py
CHANGED
|
@@ -25,13 +25,20 @@ class MixedDataset(Dataset):
|
|
| 25 |
|
| 26 |
return tools, samples
|
| 27 |
|
| 28 |
-
def encode_text(self, text):
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
ids = torch.tensor(inputs["input_ids"], dtype=torch.long)
|
| 36 |
mask = torch.tensor(inputs["attention_mask"], dtype=torch.long)
|
| 37 |
|
|
|
|
| 25 |
|
| 26 |
return tools, samples
|
| 27 |
|
| 28 |
+
def encode_text(self, text, padding=True):
|
| 29 |
+
if padding:
|
| 30 |
+
inputs = self.tokenizer(
|
| 31 |
+
text,
|
| 32 |
+
max_length=128,
|
| 33 |
+
padding="max_length",
|
| 34 |
+
truncation=True,
|
| 35 |
+
)
|
| 36 |
+
else:
|
| 37 |
+
inputs = self.tokenizer(
|
| 38 |
+
text,
|
| 39 |
+
max_length=128,
|
| 40 |
+
truncation=True,
|
| 41 |
+
)
|
| 42 |
ids = torch.tensor(inputs["input_ids"], dtype=torch.long)
|
| 43 |
mask = torch.tensor(inputs["attention_mask"], dtype=torch.long)
|
| 44 |
|
src/models/miniagent_module.py
CHANGED
|
@@ -26,8 +26,8 @@ class MiniAgentModule(LightningModule):
|
|
| 26 |
)
|
| 27 |
|
| 28 |
self.bert_model = BertModel.from_pretrained(bert_model)
|
| 29 |
-
self.bert_model.eval()
|
| 30 |
-
self.bert_model.requires_grad_(False)
|
| 31 |
|
| 32 |
self.inst_proj_model = inst_proj_model
|
| 33 |
self.tool_proj_model = tool_proj_model
|
|
@@ -67,8 +67,12 @@ class MiniAgentModule(LightningModule):
|
|
| 67 |
pred = self.pred_model(inst_emb_r, tool_emb_r) # [BxB, 1]
|
| 68 |
pred = pred.view(B, B) # [B, B]
|
| 69 |
|
| 70 |
-
|
| 71 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
|
| 73 |
self.log("train/loss", loss, on_step=True, sync_dist=True, prog_bar=True)
|
| 74 |
|
|
@@ -142,5 +146,16 @@ class MiniAgentModule(LightningModule):
|
|
| 142 |
pass
|
| 143 |
|
| 144 |
def configure_optimizers(self):
|
| 145 |
-
opt = torch.optim.AdamW(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 146 |
return opt
|
|
|
|
| 26 |
)
|
| 27 |
|
| 28 |
self.bert_model = BertModel.from_pretrained(bert_model)
|
| 29 |
+
# self.bert_model.eval()
|
| 30 |
+
# self.bert_model.requires_grad_(False)
|
| 31 |
|
| 32 |
self.inst_proj_model = inst_proj_model
|
| 33 |
self.tool_proj_model = tool_proj_model
|
|
|
|
| 67 |
pred = self.pred_model(inst_emb_r, tool_emb_r) # [BxB, 1]
|
| 68 |
pred = pred.view(B, B) # [B, B]
|
| 69 |
|
| 70 |
+
target = torch.eye(B, device=pred.device).float()
|
| 71 |
+
|
| 72 |
+
pos_weight = torch.tensor([B - 1], device=pred.device)
|
| 73 |
+
loss = F.binary_cross_entropy_with_logits(pred, target, pos_weight=pos_weight)
|
| 74 |
+
# labels = torch.arange(B, device=pred.device).long()
|
| 75 |
+
# loss = (F.cross_entropy(pred, labels) + F.cross_entropy(pred.T, labels)) * 0.5
|
| 76 |
|
| 77 |
self.log("train/loss", loss, on_step=True, sync_dist=True, prog_bar=True)
|
| 78 |
|
|
|
|
| 146 |
pass
|
| 147 |
|
| 148 |
def configure_optimizers(self):
|
| 149 |
+
opt = torch.optim.AdamW(
|
| 150 |
+
[
|
| 151 |
+
{"params": self.bert_model.parameters(), "lr": 1e-5},
|
| 152 |
+
{
|
| 153 |
+
"params": list(self.inst_proj_model.parameters())
|
| 154 |
+
+ list(self.tool_proj_model.parameters())
|
| 155 |
+
+ list(self.pred_model.parameters()),
|
| 156 |
+
"lr": self.lr,
|
| 157 |
+
},
|
| 158 |
+
],
|
| 159 |
+
weight_decay=1e-4,
|
| 160 |
+
)
|
| 161 |
return opt
|