| from typing import Any, Dict, Tuple |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from lightning import LightningModule |
| from torchmetrics import MaxMetric, MeanMetric |
| from torchmetrics.classification.accuracy import Accuracy |
|
|
| from transformers import BertModel |
|
|
|
|
class MiniAgentModule(LightningModule):
    """Lightning module that scores (instruction, tool) pairs.

    A single pretrained BERT encoder embeds both instructions and tool
    descriptions; ``inst_proj_model`` / ``tool_proj_model`` project the encoder
    output, and ``pred_model`` maps each row-aligned (instruction, tool)
    embedding pair to one logit.

    * Training uses in-batch negatives: instruction ``i`` is paired with every
      tool in the batch and only tool ``i`` is labelled positive.
    * Validation scores each instruction against its own ``C`` candidate tools
      and reports exact-match accuracy, precision and recall separately for
      instructions with exactly one correct tool, exactly two, and any other
      count (including zero).
    """

    def __init__(
        self,
        bert_model: str,
        inst_proj_model: nn.Module,
        tool_proj_model: nn.Module,
        pred_model: nn.Module,
        lr: float,
        bert_lr: float = 1e-5,
        weight_decay: float = 1e-4,
    ) -> None:
        """Create the encoder, prediction heads and validation metrics.

        Args:
            bert_model: Model name or path for ``BertModel.from_pretrained``.
            inst_proj_model: Head applied to the encoder output of instructions.
            tool_proj_model: Head applied to the encoder output of tools.
            pred_model: Called as ``pred_model(inst_emb, tool_emb)`` on
                row-aligned pairs; its output must hold one logit per pair.
            lr: Learning rate for the three heads.
            bert_lr: Learning rate for the BERT encoder.
            weight_decay: AdamW weight decay applied to every parameter group.
        """
        super().__init__()

        # The injected nn.Modules are not stored as hyperparameters; they are
        # registered as child modules below, so their weights go through the
        # state dict instead.
        self.save_hyperparameters(
            logger=False, ignore=["inst_proj_model", "tool_proj_model", "pred_model"]
        )

        self.bert_model = BertModel.from_pretrained(bert_model)

        self.inst_proj_model = inst_proj_model
        self.tool_proj_model = tool_proj_model
        self.pred_model = pred_model

        # Validation metrics: one triple per group of instructions, bucketed by
        # how many correct tools they have. Accuracy is fed
        # (exact_match, ones), so it reports the share of instructions whose
        # predicted tool set equals the correct set; precision/recall average
        # per-instruction values.
        self.val_1_acc = Accuracy(task="binary")
        self.val_1_precision = MeanMetric()
        self.val_1_recall = MeanMetric()

        self.val_2_acc = Accuracy(task="binary")
        self.val_2_precision = MeanMetric()
        self.val_2_recall = MeanMetric()

        self.val_other_acc = Accuracy(task="binary")
        self.val_other_precision = MeanMetric()
        self.val_other_recall = MeanMetric()

        self.lr = lr
        self.bert_lr = bert_lr
        self.weight_decay = weight_decay

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Unused placeholder; scoring happens inside the step methods.

        NOTE(review): currently returns ``None`` despite the annotation.
        """
        pass

    def on_train_start(self) -> None:
        """No-op hook."""
        pass

    def _encode(
        self, ids: torch.Tensor, mask: torch.Tensor, proj: nn.Module
    ) -> torch.Tensor:
        """Run the shared BERT encoder on ``ids``/``mask`` and apply ``proj``."""
        # return_dict=False makes BERT return a tuple whose element 0 is the
        # per-token last hidden state. It is passed to the head unchanged;
        # callers treat the head's output as one vector per row, so the head
        # presumably pools over tokens -- confirm against the head modules.
        hidden = self.bert_model(ids, mask, return_dict=False)[0]
        return proj(hidden)

    def training_step(
        self, batch: Dict[str, torch.Tensor], batch_idx: int
    ) -> torch.Tensor:
        """Compute the in-batch-negatives BCE loss for one batch.

        Expects ``batch`` to contain ``inst_ids``/``inst_mask`` and
        ``tool_ids``/``tool_mask`` where row ``i`` of the tools is the correct
        tool for row ``i`` of the instructions.

        Returns:
            Scalar BCE-with-logits loss over all ``B * B`` pairs.
        """
        B = batch["inst_ids"].shape[0]

        inst_emb = self._encode(batch["inst_ids"], batch["inst_mask"], self.inst_proj_model)
        tool_emb = self._encode(batch["tool_ids"], batch["tool_mask"], self.tool_proj_model)

        # Build all B*B pairs in row-major order: row i*B + j pairs
        # instruction i with tool j.
        inst_pairs = inst_emb.unsqueeze(1).repeat(1, B, 1).view(B * B, -1)
        tool_pairs = tool_emb.unsqueeze(0).repeat(B, 1, 1).view(B * B, -1)

        logits = self.pred_model(inst_pairs, tool_pairs).view(B, B)

        # Matching pairs lie on the diagonal.
        target = torch.eye(B, device=logits.device, dtype=torch.float32)

        # Each row has 1 positive and B - 1 negatives; weighting positives by
        # B - 1 balances the classes. With B == 1 the weight is 0, so a
        # singleton batch yields zero loss (there are no negatives to contrast).
        pos_weight = torch.tensor([B - 1], dtype=target.dtype, device=target.device)

        loss = F.binary_cross_entropy_with_logits(logits, target, pos_weight=pos_weight)

        self.log("train/loss", loss, on_step=True, sync_dist=True, prog_bar=True)

        return loss

    def on_train_epoch_end(self) -> None:
        """No-op hook."""
        pass

    @staticmethod
    def _subset_metrics(
        pred_tool_mask: torch.Tensor,
        correct_tool_mask: torch.Tensor,
        rows: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Per-instruction exact match, precision and recall for selected rows.

        Args:
            pred_tool_mask: ``(B, C)`` boolean mask of predicted candidates.
            correct_tool_mask: ``(B, C)`` boolean mask of correct candidates.
            rows: ``(B,)`` boolean mask selecting the instructions to score.

        Returns:
            ``(exact, precision, recall)``, one entry per selected row.
            ``exact`` is a ``long`` tensor that is 1 where the predicted set
            equals the correct set. Denominators are clamped to 1, so a row
            with no predicted (resp. no correct) tools gets precision
            (resp. recall) 0.
        """
        pred = pred_tool_mask[rows]
        correct = correct_tool_mask[rows]
        hits = (pred & correct).sum(dim=1)

        exact = (pred == correct).all(dim=1).long()
        precision = hits / torch.clamp(pred.sum(dim=1), min=1)
        recall = hits / torch.clamp(correct.sum(dim=1), min=1)
        return exact, precision, recall

    def validation_step(
        self, batch: Dict[str, torch.Tensor], batch_idx: int
    ) -> None:
        """Score each instruction against its candidate tools and update metrics.

        ``correct_tool_mask`` is ``(B, C)``; ``tool_ids``/``tool_mask`` must
        flatten to ``B * C`` rows (presumably shaped ``(B, C, L)`` -- confirm
        against the datamodule). A candidate counts as predicted when its
        sigmoid score exceeds 0.5. Instructions with zero correct tools fall
        into the "other" group, where their recall is always 0.
        """
        # .bool() keeps `&`/`==` well-defined if the mask arrives as 0/1 ints
        # or floats rather than booleans.
        correct_tool_mask = batch["correct_tool_mask"].bool()

        B = batch["inst_ids"].shape[0]
        C = correct_tool_mask.shape[1]

        # Flatten the per-instruction candidate lists into one batch of B*C tools.
        tool_ids = batch["tool_ids"].view(-1, batch["tool_ids"].shape[-1])
        tool_mask = batch["tool_mask"].view(-1, batch["tool_mask"].shape[-1])

        inst_emb = self._encode(batch["inst_ids"], batch["inst_mask"], self.inst_proj_model)
        tool_emb = self._encode(tool_ids, tool_mask, self.tool_proj_model)

        # Row b*C + c pairs instruction b with its c-th candidate.
        inst_pairs = inst_emb.unsqueeze(1).repeat(1, C, 1).view(B * C, -1)
        tool_pairs = tool_emb.view(B * C, -1)

        probs = torch.sigmoid(self.pred_model(inst_pairs, tool_pairs).view(B, C))
        pred_tool_mask = probs > 0.5

        n_correct = correct_tool_mask.sum(dim=1)
        one_tool_rows = n_correct == 1
        two_tool_rows = n_correct == 2
        other_rows = ~(one_tool_rows | two_tool_rows)

        groups = (
            ("1", one_tool_rows, self.val_1_acc, self.val_1_precision, self.val_1_recall),
            ("2", two_tool_rows, self.val_2_acc, self.val_2_precision, self.val_2_recall),
            (
                "other",
                other_rows,
                self.val_other_acc,
                self.val_other_precision,
                self.val_other_recall,
            ),
        )

        for tag, rows, acc_metric, precision_metric, recall_metric in groups:
            # Skip only when this batch has no instructions in the group.
            # Batches with zero exact matches must still be counted; otherwise
            # accuracy/precision/recall are biased upward.
            if not rows.any():
                continue

            exact, precision, recall = self._subset_metrics(
                pred_tool_mask, correct_tool_mask, rows
            )
            acc_metric.update(exact, torch.ones_like(exact))
            precision_metric.update(precision)
            recall_metric.update(recall)

            # NOTE(review): logging is conditional per batch. Under DDP, confirm
            # that every rank logs each metric at least once per epoch, since
            # metric syncing at epoch end is a collective operation.
            for suffix, metric in (
                ("acc", acc_metric),
                ("precision", precision_metric),
                ("recall", recall_metric),
            ):
                self.log(
                    f"val/{tag}_{suffix}",
                    metric,
                    on_epoch=True,
                    sync_dist=True,
                    prog_bar=True,
                )

    def on_validation_epoch_end(self) -> None:
        """No-op hook."""
        pass

    def test_step(
        self, batch: Dict[str, torch.Tensor], batch_idx: int
    ) -> None:
        """Placeholder test step; does nothing."""
        pass

    def on_test_epoch_end(self) -> None:
        """No-op hook."""
        pass

    def configure_optimizers(self) -> torch.optim.Optimizer:
        """AdamW with separate learning rates for BERT and for the heads."""
        head_params = [
            *self.inst_proj_model.parameters(),
            *self.tool_proj_model.parameters(),
            *self.pred_model.parameters(),
        ]
        return torch.optim.AdamW(
            [
                {"params": self.bert_model.parameters(), "lr": self.bert_lr},
                {"params": head_params, "lr": self.lr},
            ],
            weight_decay=self.weight_decay,
        )
|
|