import json
import random

import torch
from torch.utils.data import Dataset
from transformers import BertTokenizer


class MixedDataset(Dataset):
    """Instruction/tool-description dataset for retrieval training.

    In the "train" stage each item pairs an instruction with one of its
    ground-truth tools; in every other stage each item carries the
    instruction plus a fixed pool of `tool_capacity` candidate tools
    (ground-truth tools first, then sampled distractors) and a boolean
    mask marking the correct positions.
    """

    def __init__(self, bert_model, stage, anno_file, tool_capacity, seed):
        self.stage = stage
        self.seed = seed
        self.tool_capacity = tool_capacity
        self.tools, self.samples = self.load_data(anno_file)
        self.tool_ids = list(self.tools.keys())
        self.tokenizer = BertTokenizer.from_pretrained(bert_model)

    def load_data(self, anno_file):
        with open(anno_file, "r") as f:
            data = json.load(f)
        tools = data["tools"]
        samples = data["samples"]

        # Re-index the tool list by id for O(1) lookup in __getitem__.
        tools = {tool["id"]: tool for tool in tools}

        return tools, samples
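    # A sketch of the annotation-file layout that load_data assumes, inferred
    # from the fields read above; the concrete ids, descriptions, and
    # instructions are illustrative placeholders, not real data:
    #
    # {
    #   "tools": [
    #     {"id": "web_search", "description": "Searches the web for a query."},
    #     {"id": "calculator", "description": "Evaluates arithmetic expressions."}
    #   ],
    #   "samples": [
    #     {"instruction": "What is 17 * 24?", "tools": ["calculator"]}
    #   ]
    # }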

    def encode_text(self, text, padding=True):
        """Tokenize a string or a list of strings with BERT.

        Returns (input_ids, attention_mask) as long tensors. With
        padding enabled (the default), a single string yields 1-D
        tensors of length 128 and a list yields 2-D tensors of shape
        (num_texts, 128); without padding, only single strings produce
        a rectangular result.
        """
        inputs = self.tokenizer(
            text,
            max_length=128,
            padding="max_length" if padding else False,
            truncation=True,
        )
        ids = torch.tensor(inputs["input_ids"], dtype=torch.long)
        mask = torch.tensor(inputs["attention_mask"], dtype=torch.long)

        return ids, mask

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        inst = sample["instruction"]
        inst_ids, inst_mask = self.encode_text(inst)

        if self.stage == "train":
            # Training: pair the instruction with one randomly chosen
            # ground-truth tool. All fields are 1-D tensors of length 128.
            tool_id = random.choice(sample["tools"])
            tool_desc = self.tools[tool_id]["description"]
            tool_desc_ids, tool_desc_mask = self.encode_text(tool_desc)

            return {
                "inst_ids": inst_ids,
                "inst_mask": inst_mask,
                "tool_ids": tool_desc_ids,
                "tool_mask": tool_desc_mask,
            }
        else:
            # Evaluation: build a fixed-size candidate pool of
            # ground-truth tools followed by sampled distractors.
            correct_tools = sample["tools"]

            # A dedicated, per-sample-seeded RNG keeps evaluation
            # deterministic without perturbing the global random state
            # used by the training branch.
            rng = random.Random(self.seed + idx)
            wrong_tools = rng.sample(
                [tool for tool in self.tool_ids if tool not in correct_tools],
                self.tool_capacity - len(correct_tools),
            )

            # Correct tools occupy the leading positions of the pool.
            correct_tool_mask = torch.tensor(
                [1] * len(correct_tools)
                + [0] * (self.tool_capacity - len(correct_tools)),
                dtype=torch.bool,
            )

            tools = correct_tools + wrong_tools
            tool_ids, tool_mask = self.encode_text(
                [self.tools[tool_id]["description"] for tool_id in tools]
            )

            # tool_ids/tool_mask: (tool_capacity, 128);
            # correct_tool_mask: (tool_capacity,).
            return {
                "inst_ids": inst_ids,
                "inst_mask": inst_mask,
                "tool_ids": tool_ids,
                "tool_mask": tool_mask,
                "correct_tool_mask": correct_tool_mask,
            }
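
# A minimal usage sketch (not part of the original module): the checkpoint
# name, annotation path, and capacity below are illustrative assumptions.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    dataset = MixedDataset(
        bert_model="bert-base-uncased",  # any BERT checkpoint name or path
        stage="train",
        anno_file="annotations.json",  # hypothetical file in the layout sketched above
        tool_capacity=16,
        seed=42,
    )
    loader = DataLoader(dataset, batch_size=8, shuffle=True)
    batch = next(iter(loader))
    # In the train stage every field collates to shape (batch_size, 128).
    print({k: tuple(v.shape) for k, v in batch.items()})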