# mini-agent/src/models/miniagent_module.py
# qninhdt · cc · 3f792e3
from typing import Any, Dict, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from lightning import LightningModule
from torchmetrics import MaxMetric, MeanMetric
from torchmetrics.classification.accuracy import Accuracy
from transformers import BertModel
class MiniAgentModule(LightningModule):
    """Lightning module that scores (instruction, tool) pairs.

    Instruction and tool token sequences are both encoded by one shared BERT
    encoder, passed through separate projection heads, and every
    (instruction embedding, tool embedding) pair is handed to ``pred_model``.
    The output of ``pred_model`` is treated as a logit: training applies
    ``binary_cross_entropy_with_logits`` to it and validation applies a
    sigmoid followed by a 0.5 threshold.
    """

    def __init__(
        self,
        bert_model: str,
        inst_proj_model: nn.Module,
        tool_proj_model: nn.Module,
        pred_model: nn.Module,
        lr: float,
    ) -> None:
        """Build the encoder, the heads and the validation metrics.

        Args:
            bert_model: Name or path passed to ``BertModel.from_pretrained``.
            inst_proj_model: Head applied to the BERT output of instructions.
            tool_proj_model: Head applied to the BERT output of tools.
            pred_model: Called as ``pred_model(inst_emb, tool_emb)``; must
                return one logit per pair.
            lr: Learning rate for the two projection heads and ``pred_model``
                (the BERT encoder uses its own fixed rate, see
                ``configure_optimizers``).
        """
        super().__init__()

        # The nn.Module arguments are excluded from the saved hyperparameters.
        self.save_hyperparameters(
            logger=False, ignore=["inst_proj_model", "tool_proj_model", "pred_model"]
        )

        self.bert_model = BertModel.from_pretrained(bert_model)
        self.inst_proj_model = inst_proj_model
        self.tool_proj_model = tool_proj_model
        self.pred_model = pred_model

        # Validation metrics are tracked separately for samples with exactly
        # one correct tool, exactly two correct tools, and every other count.
        self.val_1_acc = Accuracy(task="binary")
        self.val_1_precision = MeanMetric()
        self.val_1_recall = MeanMetric()

        self.val_2_acc = Accuracy(task="binary")
        self.val_2_precision = MeanMetric()
        self.val_2_recall = MeanMetric()

        self.val_other_acc = Accuracy(task="binary")
        self.val_other_precision = MeanMetric()
        self.val_other_recall = MeanMetric()

        self.lr = lr

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Placeholder; not implemented (returns ``None``)."""

    def on_train_start(self) -> None:
        """Hook intentionally left empty."""

    def _encode(self, ids: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Run the shared BERT encoder and return the first element of its output.

        ``mask`` is passed as the second positional argument of the encoder.
        With ``return_dict=False`` the encoder returns a tuple; element 0 is
        used (for Hugging Face ``BertModel`` this is the last hidden state).
        """
        return self.bert_model(ids, mask, return_dict=False)[0]

    def training_step(
        self, batch: Dict[str, torch.Tensor], batch_idx: int
    ) -> torch.Tensor:
        """Compute the in-batch pairwise BCE loss.

        Every instruction in the batch is paired with every tool in the batch,
        giving a ``[B, B]`` logit matrix whose diagonal holds the matching
        pairs (target 1) and whose off-diagonal entries are negatives
        (target 0).

        Args:
            batch: Mapping with keys ``inst_ids``, ``inst_mask``, ``tool_ids``
                and ``tool_mask``; the first dimension of each is the batch.
            batch_idx: Index of the batch (unused).

        Returns:
            The scalar training loss.
        """
        inst_ids = batch["inst_ids"]
        inst_mask = batch["inst_mask"]
        tool_ids = batch["tool_ids"]
        tool_mask = batch["tool_mask"]
        B = inst_ids.shape[0]

        inst_z = self._encode(inst_ids, inst_mask)
        tool_z = self._encode(tool_ids, tool_mask)

        inst_emb = self.inst_proj_model(inst_z)
        tool_emb = self.tool_proj_model(tool_z)

        # Row i * B + j of the flattened tensors pairs instruction i with tool j.
        inst_emb_r = inst_emb.unsqueeze(1).repeat(1, B, 1).view(B * B, -1)
        tool_emb_r = tool_emb.unsqueeze(0).repeat(B, 1, 1).view(B * B, -1)

        pred = self.pred_model(inst_emb_r, tool_emb_r)  # [BxB, 1]
        pred = pred.view(B, B)  # [B, B]

        target = torch.eye(B, device=pred.device).float()

        # Each row has 1 positive and B - 1 negatives, so positives are
        # up-weighted by B - 1. The weight is clamped to at least 1 because a
        # batch of size 1 would otherwise get weight 0 and a loss of exactly 0.
        pos_weight = torch.tensor([float(max(B - 1, 1))], device=pred.device)

        loss = F.binary_cross_entropy_with_logits(pred, target, pos_weight=pos_weight)

        self.log("train/loss", loss, on_step=True, sync_dist=True, prog_bar=True)

        return loss

    def on_train_epoch_end(self) -> None:
        """Hook intentionally left empty."""

    def _update_and_log_group(
        self,
        name: str,
        acc_metric: Any,
        precision_metric: Any,
        recall_metric: Any,
        exact_match: torch.Tensor,
        precision: torch.Tensor,
        recall: torch.Tensor,
    ) -> None:
        """Update and log the three metrics of one validation group.

        Args:
            name: Group tag used in the log keys ``val/{name}_acc``,
                ``val/{name}_precision`` and ``val/{name}_recall``.
            acc_metric: Accuracy metric of the group.
            precision_metric: Mean metric accumulating per-sample precision.
            recall_metric: Mean metric accumulating per-sample recall.
            exact_match: 0/1 tensor, 1 where the predicted tool set equals
                the correct tool set.
            precision: Per-sample precision of the group.
            recall: Per-sample recall of the group.
        """
        # Skip only when the group has no samples in this batch. Samples that
        # were predicted wrongly must still be counted, otherwise batches
        # without any exact match are dropped and the metrics are inflated.
        if exact_match.numel() == 0:
            return

        # Against an all-ones target, binary accuracy equals the fraction of
        # exactly matched samples.
        acc_metric.update(exact_match, torch.ones_like(exact_match))
        precision_metric.update(precision)
        recall_metric.update(recall)

        for suffix, metric in (
            ("acc", acc_metric),
            ("precision", precision_metric),
            ("recall", recall_metric),
        ):
            self.log(
                f"val/{name}_{suffix}",
                metric,
                on_epoch=True,
                sync_dist=True,
                prog_bar=True,
            )

    def validation_step(
        self, batch: Dict[str, torch.Tensor], batch_idx: int
    ) -> None:
        """Score each instruction against its candidate tools and log metrics.

        A tool is predicted as selected when the sigmoid of its logit exceeds
        0.5. Exact-match accuracy, precision and recall are computed per
        sample and accumulated separately for samples with one, two, or any
        other number of correct tools.

        Args:
            batch: Mapping with keys ``inst_ids``, ``inst_mask``, ``tool_ids``,
                ``tool_mask`` and ``correct_tool_mask``. ``correct_tool_mask``
                is ``[B, C]``; ``tool_ids`` / ``tool_mask`` are flattened to
                ``[B*C, L]`` (presumably arriving as ``[B, C, L]`` - confirm
                against the data module).
            batch_idx: Index of the batch (unused).
        """
        inst_ids = batch["inst_ids"]
        inst_mask = batch["inst_mask"]
        tool_ids = batch["tool_ids"]
        tool_mask = batch["tool_mask"]
        # Cast once so the bitwise operations below work for 0/1 masks of any dtype.
        correct_tool_mask = batch["correct_tool_mask"].bool()

        B = inst_ids.shape[0]  # batch size
        C = correct_tool_mask.shape[1]  # tool capacity

        tool_ids = tool_ids.reshape(-1, tool_ids.shape[-1])  # [B*C, L]
        tool_mask = tool_mask.reshape(-1, tool_mask.shape[-1])  # [B*C, L]

        inst_z = self._encode(inst_ids, inst_mask)
        tool_z = self._encode(tool_ids, tool_mask)

        inst_emb = self.inst_proj_model(inst_z)  # [B, D]
        tool_emb = self.tool_proj_model(tool_z)  # [B*C, D]

        # Repeat each instruction embedding C times so that row b * C + c
        # pairs instruction b with its c-th candidate tool.
        inst_emb_r = inst_emb.unsqueeze(1).repeat(1, C, 1).view(B * C, -1)
        tool_emb_r = tool_emb.view(B * C, -1)

        pred = self.pred_model(inst_emb_r, tool_emb_r)  # [B*C, 1]
        pred = torch.sigmoid(pred.view(B, C))

        pred_tool_mask = pred > 0.5
        true_pos_mask = pred_tool_mask & correct_tool_mask

        num_true_pos = true_pos_mask.sum(dim=1)
        num_correct = correct_tool_mask.sum(dim=1)

        # Per-sample statistics; denominators are clamped to avoid 0 / 0.
        exact_match = (pred_tool_mask == correct_tool_mask).all(dim=1).long()
        precision = num_true_pos / torch.clamp(pred_tool_mask.sum(dim=1), min=1)
        recall = num_true_pos / torch.clamp(num_correct, min=1)

        one_tool_mask = num_correct == 1
        two_tool_mask = num_correct == 2
        other_mask = ~(one_tool_mask | two_tool_mask)

        groups = (
            ("1", one_tool_mask, self.val_1_acc, self.val_1_precision, self.val_1_recall),
            ("2", two_tool_mask, self.val_2_acc, self.val_2_precision, self.val_2_recall),
            (
                "other",
                other_mask,
                self.val_other_acc,
                self.val_other_precision,
                self.val_other_recall,
            ),
        )
        for name, group_mask, acc_metric, precision_metric, recall_metric in groups:
            self._update_and_log_group(
                name,
                acc_metric,
                precision_metric,
                recall_metric,
                exact_match[group_mask],
                precision[group_mask],
                recall[group_mask],
            )

    def on_validation_epoch_end(self) -> None:
        """Hook intentionally left empty."""

    def test_step(
        self, batch: Dict[str, torch.Tensor], batch_idx: int
    ) -> None:
        """Test step is not implemented."""

    def on_test_epoch_end(self) -> None:
        """Hook intentionally left empty."""

    def configure_optimizers(self) -> torch.optim.Optimizer:
        """Create an AdamW optimizer with two parameter groups.

        The BERT encoder uses a fixed learning rate of 1e-5; the projection
        heads and ``pred_model`` use ``self.lr``. Both groups share a weight
        decay of 1e-4.
        """
        head_params = (
            list(self.inst_proj_model.parameters())
            + list(self.tool_proj_model.parameters())
            + list(self.pred_model.parameters())
        )
        opt = torch.optim.AdamW(
            [
                {"params": self.bert_model.parameters(), "lr": 1e-5},
                {"params": head_params, "lr": self.lr},
            ],
            weight_decay=1e-4,
        )
        return opt