mini-agent / src /data /mixed_dataset.py
qninhdt
cc
f550456
import os
import json
import torch
import random
from torch.utils.data import Dataset
from transformers import BertTokenizer
class MixedDataset(Dataset):
    """Dataset pairing instructions with tool descriptions from one JSON file.

    The annotation file must contain two top-level keys:

    * ``"tools"``: a list of objects, each with an ``"id"`` and a
      ``"description"``.
    * ``"samples"``: a list of objects, each with an ``"instruction"`` string
      and a ``"tools"`` list holding the ids of the correct tools.

    When ``stage == "train"`` every item pairs the instruction with one
    randomly chosen correct tool.  For any other stage every item carries
    ``tool_capacity`` tool descriptions: the correct tools first, followed by
    distractors drawn reproducibly from ``seed + idx``.
    """

    def __init__(self, bert_model, stage, anno_file, tool_capacity, seed):
        """Load the annotations and the BERT tokenizer.

        Args:
            bert_model: name or path handed to ``BertTokenizer.from_pretrained``.
            stage: ``"train"`` selects the training item layout; any other
                value selects the evaluation layout.
            anno_file: path to the JSON annotation file.
            tool_capacity: number of tools per evaluation item.
            seed: base seed for evaluation-time distractor sampling.
        """
        self.stage = stage
        self.seed = seed
        self.tool_capacity = tool_capacity

        self.tools, self.samples = self.load_data(anno_file)
        self.tool_ids = list(self.tools.keys())

        self.tokenizer = BertTokenizer.from_pretrained(bert_model)

    def load_data(self, anno_file):
        """Read the annotation file.

        Args:
            anno_file: path to a JSON file with ``"tools"`` and ``"samples"``.

        Returns:
            A ``(tools, samples)`` tuple where ``tools`` maps each tool id to
            its full tool record and ``samples`` is the list as stored.
        """
        # JSON text is UTF-8 by specification; do not depend on the locale.
        with open(anno_file, "r", encoding="utf-8") as f:
            data = json.load(f)

        tools = {tool["id"]: tool for tool in data["tools"]}
        return tools, data["samples"]

    def encode_text(self, text, padding=True, max_length=128):
        """Tokenize ``text`` and return ``(input_ids, attention_mask)``.

        Args:
            text: a single string or a list of strings.
            padding: pad every sequence to ``max_length`` when true; leave
                sequences unpadded otherwise.  Unpadded lists of texts with
                differing token counts cannot be packed into one tensor.
            max_length: sequences longer than this are truncated.

        Returns:
            Two ``torch.long`` tensors: token ids and the attention mask.
        """
        inputs = self.tokenizer(
            text,
            max_length=max_length,
            # ``False`` is the tokenizer's own default, i.e. no padding.
            padding="max_length" if padding else False,
            truncation=True,
        )

        ids = torch.tensor(inputs["input_ids"], dtype=torch.long)
        mask = torch.tensor(inputs["attention_mask"], dtype=torch.long)
        return ids, mask

    def __len__(self):
        """Return the number of samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Build the item for sample ``idx``.

        Returns:
            A dict with ``inst_ids``/``inst_mask`` for the instruction and
            ``tool_ids``/``tool_mask`` for the tool description(s).  Outside
            the ``"train"`` stage it also holds ``correct_tool_mask``, a bool
            tensor of length ``tool_capacity`` that is true for the leading
            correct tools.

        Raises:
            ValueError: outside the ``"train"`` stage, when the sample has
                more correct tools than ``tool_capacity`` or there are not
                enough other tools to fill the remaining slots.
        """
        sample = self.samples[idx]
        inst_ids, inst_mask = self.encode_text(sample["instruction"])

        if self.stage == "train":
            # Deliberately unseeded: a different positive may be picked on
            # every visit of the same sample.
            tool_id = random.choice(sample["tools"])
            tool_desc_ids, tool_desc_mask = self.encode_text(
                self.tools[tool_id]["description"]
            )
            return {
                "inst_ids": inst_ids,
                "inst_mask": inst_mask,
                "tool_ids": tool_desc_ids,
                "tool_mask": tool_desc_mask,
            }

        # Evaluation: the correct tools plus randomly drawn wrong tools, for a
        # total of ``tool_capacity`` candidates.
        correct_tools = sample["tools"]
        num_wrong = self.tool_capacity - len(correct_tools)
        if num_wrong < 0:
            raise ValueError(
                f"sample {idx} has {len(correct_tools)} correct tools, "
                f"more than tool_capacity={self.tool_capacity}"
            )

        # A private generator keeps sampling reproducible per index without
        # reseeding the process-wide ``random`` state used by other code.
        rng = random.Random(self.seed + idx)
        correct_set = set(correct_tools)
        wrong_tools = rng.sample(
            [tool for tool in self.tool_ids if tool not in correct_set],
            num_wrong,
        )

        correct_tool_mask = torch.tensor(
            [1] * len(correct_tools) + [0] * num_wrong,
            dtype=torch.bool,
        )

        candidates = [*correct_tools, *wrong_tools]
        tool_ids, tool_mask = self.encode_text(
            [self.tools[tool_id]["description"] for tool_id in candidates]
        )

        return {
            "inst_ids": inst_ids,
            "inst_mask": inst_mask,
            "tool_ids": tool_ids,
            "tool_mask": tool_mask,
            "correct_tool_mask": correct_tool_mask,
        }