qninhdt committed on
Commit ·
2c9e8bc
1
Parent(s): 0c2ae95
cc
Browse files
configs/experiment/miniagent-bert-mlp-abs_diff-mult.yaml
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
defaults:
|
| 4 |
+
- override /data: mixed
|
| 5 |
+
- override /model: miniagent
|
| 6 |
+
- override /callbacks: default
|
| 7 |
+
- override /trainer: gpu
|
| 8 |
+
|
| 9 |
+
seed: 42
|
| 10 |
+
|
| 11 |
+
model:
|
| 12 |
+
lr: 0.001
|
| 13 |
+
bert_model: bert-base-uncased
|
| 14 |
+
|
| 15 |
+
inst_proj_model:
|
| 16 |
+
_target_: src.models.mlp_module.MLPProjection
|
| 17 |
+
input_dim: 768
|
| 18 |
+
hidden_dim: 768
|
| 19 |
+
output_dim: 768
|
| 20 |
+
|
| 21 |
+
tool_proj_model:
|
| 22 |
+
_target_: src.models.mlp_module.MLPProjection
|
| 23 |
+
input_dim: 768
|
| 24 |
+
hidden_dim: 768
|
| 25 |
+
output_dim: 768
|
| 26 |
+
|
| 27 |
+
pred_model:
|
| 28 |
+
_target_: src.models.mlp_module.MLPPrediction
|
| 29 |
+
input_dim: 768
|
| 30 |
+
use_abs_diff: true
|
| 31 |
+
use_mult: true
|
| 32 |
+
|
| 33 |
+
data:
|
| 34 |
+
bert_model: bert-base-uncased
|
| 35 |
+
seed: 42
|
| 36 |
+
batch_size: 128
|
| 37 |
+
tool_capacity: 16
|
configs/trainer/default.yaml
CHANGED
|
@@ -8,6 +8,8 @@ max_epochs: 10
|
|
| 8 |
accelerator: cpu
|
| 9 |
devices: 1
|
| 10 |
|
|
|
|
|
|
|
| 11 |
# mixed precision for extra speed-up
|
| 12 |
# precision: 16
|
| 13 |
|
|
|
|
| 8 |
accelerator: cpu
|
| 9 |
devices: 1
|
| 10 |
|
| 11 |
+
log_every_n_steps: 10
|
| 12 |
+
|
| 13 |
# mixed precision for extra speed-up
|
| 14 |
# precision: 16
|
| 15 |
|
src/data/mixed_datamodule.py
CHANGED
|
@@ -47,6 +47,7 @@ class MixedDataModule(LightningDataModule):
|
|
| 47 |
batch_size=self.batch_size,
|
| 48 |
shuffle=True,
|
| 49 |
num_workers=self.num_workers,
|
|
|
|
| 50 |
)
|
| 51 |
|
| 52 |
def val_dataloader(self):
|
|
|
|
| 47 |
batch_size=self.batch_size,
|
| 48 |
shuffle=True,
|
| 49 |
num_workers=self.num_workers,
|
| 50 |
+
drop_last=True,
|
| 51 |
)
|
| 52 |
|
| 53 |
def val_dataloader(self):
|
src/models/miniagent_module.py
CHANGED
|
@@ -65,10 +65,12 @@ class MiniAgentModule(LightningModule):
|
|
| 65 |
pred = self.pred_model(inst_emb_r, tool_emb_r) # [BxB, 1]
|
| 66 |
pred = pred.view(B, B) # [B, B]
|
| 67 |
|
| 68 |
-
target = torch.eye(B, device=pred.device).float()
|
| 69 |
|
| 70 |
-
pos_weight = torch.tensor([B - 1], device=pred.device)
|
| 71 |
-
loss = F.binary_cross_entropy_with_logits(pred, target, pos_weight=pos_weight)
|
|
|
|
|
|
|
| 72 |
|
| 73 |
self.log("train/loss", loss, on_step=True, sync_dist=True, prog_bar=True)
|
| 74 |
|
|
|
|
| 65 |
pred = self.pred_model(inst_emb_r, tool_emb_r) # [BxB, 1]
|
| 66 |
pred = pred.view(B, B) # [B, B]
|
| 67 |
|
| 68 |
+
# target = torch.eye(B, device=pred.device).float()
|
| 69 |
|
| 70 |
+
# pos_weight = torch.tensor([B - 1], device=pred.device)
|
| 71 |
+
# loss = F.binary_cross_entropy_with_logits(pred, target, pos_weight=pos_weight)
|
| 72 |
+
labels = torch.arange(B, device=pred.device).long()
|
| 73 |
+
loss = (F.cross_entropy(pred, labels) + F.cross_entropy(pred.T, labels)) * 0.5
|
| 74 |
|
| 75 |
self.log("train/loss", loss, on_step=True, sync_dist=True, prog_bar=True)
|
| 76 |
|