cc
Browse files- .gitignore +1 -1
- src/data/__init__.py +0 -0
- src/data/mixed_datamodule.py +62 -0
- src/data/mixed_dataset.py +76 -0
.gitignore
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
|
| 2 |
logs
|
| 3 |
wandb
|
| 4 |
__pycache__
|
|
|
|
| 1 |
+
datasets
|
| 2 |
logs
|
| 3 |
wandb
|
| 4 |
__pycache__
|
src/data/__init__.py
ADDED
|
File without changes
|
src/data/mixed_datamodule.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from torch.utils.data import DataLoader
|
| 3 |
+
from lightning import LightningDataModule
|
| 4 |
+
from .mixed_dataset import MixedDataset
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class MixedDataModule(LightningDataModule):
    """Lightning data module wrapping the train/val/test `MixedDataset` splits.

    The validation and test splits both read ``test.json`` from
    ``dataset_path``; only the training split reads ``train.json``.
    """

    def __init__(
        self, bert_model, dataset_path, tool_capacity, batch_size, num_workers
    ):
        """
        Args:
            bert_model: model name/path forwarded to `MixedDataset` (used for
                its tokenizer).
            dataset_path: directory containing ``train.json`` and ``test.json``.
            tool_capacity: candidate-tool set size, forwarded to `MixedDataset`.
            batch_size: batch size for every dataloader.
            num_workers: worker processes for every dataloader.
        """
        super().__init__()
        self.bert_model = bert_model
        self.dataset_path = dataset_path
        self.tool_capacity = tool_capacity
        self.batch_size = batch_size
        self.num_workers = num_workers

    def _build_dataset(self, split, filename):
        # Build one MixedDataset split from an annotation file in dataset_path.
        return MixedDataset(
            self.bert_model,
            split,
            os.path.join(self.dataset_path, filename),
            self.tool_capacity,
        )

    def setup(self, stage=None):
        """Create the datasets needed for the given trainer stage.

        Fix: the original handled only ``stage == "fit"`` and
        ``stage == "test"``.  Lightning also invokes ``setup`` with
        ``stage == "validate"`` (e.g. ``trainer.validate()``), and
        ``stage=None`` conventionally means "set up everything"; both
        previously left the datasets undefined, so the corresponding
        dataloader call crashed with AttributeError.
        """
        if stage in (None, "fit"):
            self.train_dataset = self._build_dataset("train", "train.json")
        if stage in (None, "fit", "validate"):
            self.val_dataset = self._build_dataset("test", "test.json")
        if stage in (None, "test"):
            self.test_dataset = self._build_dataset("test", "test.json")

    def train_dataloader(self):
        """Shuffled dataloader over the training split."""
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
        )

    def val_dataloader(self):
        """Deterministic (unshuffled) dataloader over the validation split."""
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
        )

    def test_dataloader(self):
        """Deterministic (unshuffled) dataloader over the test split."""
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
        )
|
src/data/mixed_dataset.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import torch
|
| 4 |
+
import random
|
| 5 |
+
from torch.utils.data import Dataset
|
| 6 |
+
from transformers import BertTokenizer
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class MixedDataset(Dataset):
    """Instruction / tool-description dataset for tool-retrieval training.

    In the ``"train"`` stage each item pairs an instruction with ONE
    randomly chosen ground-truth tool description.  In any other stage
    (e.g. ``"test"``) each item carries the instruction plus a candidate
    set of ``tool_capacity`` tool descriptions: all correct tools for the
    sample, padded with randomly sampled distractor tools.
    """

    def __init__(self, bert_model, stage, anno_file, tool_capacity):
        """
        Args:
            bert_model: name/path passed to `BertTokenizer.from_pretrained`.
            stage: ``"train"`` selects the single-tool item format; any other
                value selects the candidate-set item format.
            anno_file: path to a JSON file with ``"tools"`` (list of dicts,
                each with at least ``"id"`` and ``"description"``) and
                ``"samples"`` (list of dicts with ``"instruction"`` and
                ``"tools"``, the latter a list of correct tool ids).
            tool_capacity: total candidate-set size for non-train stages.
        """
        self.stage = stage
        self.tool_capacity = tool_capacity
        self.tools, self.samples = self.load_data(anno_file)
        self.tool_ids = list(self.tools.keys())
        self.tokenizer = BertTokenizer.from_pretrained(bert_model)

    def load_data(self, anno_file):
        """Read the annotation JSON and return ``({tool_id: tool}, samples)``."""
        with open(anno_file, "r") as f:
            data = json.load(f)
        # Index tools by id for O(1) lookup in __getitem__.
        tools = {tool["id"]: tool for tool in data["tools"]}
        return tools, data["samples"]

    def encode_text(self, text):
        """Tokenize ``text`` into fixed-length (128) id and mask tensors.

        Accepts a single string — returning tensors of shape ``(128,)`` —
        or a list of strings, returning ``(len(text), 128)`` tensors with
        one row per string.
        """
        # Bug fix: the original used `encode_plus`, which interprets a
        # list of strings as ONE pre-tokenized sequence and flattens it.
        # Calling the tokenizer directly batch-encodes each string
        # separately (and behaves identically to encode_plus for a single
        # string).
        inputs = self.tokenizer(
            text,
            max_length=128,
            padding="max_length",
            truncation=True,
        )
        ids = torch.tensor(inputs["input_ids"], dtype=torch.long)
        mask = torch.tensor(inputs["attention_mask"], dtype=torch.long)
        return ids, mask

    def __len__(self):
        """Number of annotated samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return one encoded item; format depends on ``self.stage``."""
        sample = self.samples[idx]
        inst_ids, inst_mask = self.encode_text(sample["instruction"])

        if self.stage == "train":
            # One random positive tool per visit; resampling across epochs
            # varies the positives seen for multi-tool samples.
            tool_id = random.choice(sample["tools"])
            tool_desc = self.tools[tool_id]["description"]
            tool_desc_ids, tool_desc_mask = self.encode_text(tool_desc)

            return {
                "inst_ids": inst_ids,
                "inst_mask": inst_mask,
                "tool_desc_ids": tool_desc_ids,
                "tool_desc_mask": tool_desc_mask,
            }
        else:
            # For testing, build a candidate set of size tool_capacity:
            # all correct tools + random wrong tools from the catalogue.
            correct_tools = sample["tools"]
            # Bug fix: clamp to >= 0 so random.sample does not raise
            # ValueError when a sample has more correct tools than
            # tool_capacity.
            num_wrong = max(0, self.tool_capacity - len(correct_tools))
            wrong_tools = random.sample(
                [tool for tool in self.tool_ids if tool not in correct_tools],
                num_wrong,
            )
            tools = correct_tools + wrong_tools
            # Batch-encode all candidate descriptions: (len(tools), 128).
            tool_ids, tool_ids_mask = self.encode_text(
                [self.tools[tool_id]["description"] for tool_id in tools]
            )

            return {
                "inst_ids": inst_ids,
                "inst_mask": inst_mask,
                "tool_ids": tool_ids,
                "tool_ids_mask": tool_ids_mask,
            }
|