qninhdt committed
Commit 166edf7 · 1 Parent(s): 897fe06
.gitignore CHANGED
@@ -1,4 +1,4 @@
-data
+datasets
 logs
 wandb
 __pycache__
src/data/__init__.py ADDED
(empty file)
src/data/mixed_datamodule.py ADDED
@@ -0,0 +1,62 @@
+import os
+from torch.utils.data import DataLoader
+from lightning import LightningDataModule
+from .mixed_dataset import MixedDataset
+
+
+class MixedDataModule(LightningDataModule):
+    def __init__(
+        self, bert_model, dataset_path, tool_capacity, batch_size, num_workers
+    ):
+        super().__init__()
+        self.bert_model = bert_model
+        self.dataset_path = dataset_path
+        self.tool_capacity = tool_capacity
+        self.batch_size = batch_size
+        self.num_workers = num_workers
+
+    def setup(self, stage=None):
+        if stage == "fit":
+            self.train_dataset = MixedDataset(
+                self.bert_model,
+                "train",
+                os.path.join(self.dataset_path, "train.json"),
+                self.tool_capacity,
+            )
+            self.val_dataset = MixedDataset(
+                self.bert_model,
+                "test",
+                os.path.join(self.dataset_path, "test.json"),
+                self.tool_capacity,
+            )
+        elif stage == "test":
+            self.test_dataset = MixedDataset(
+                self.bert_model,
+                "test",
+                os.path.join(self.dataset_path, "test.json"),
+                self.tool_capacity,
+            )
+
+    def train_dataloader(self):
+        return DataLoader(
+            self.train_dataset,
+            batch_size=self.batch_size,
+            shuffle=True,
+            num_workers=self.num_workers,
+        )
+
+    def val_dataloader(self):
+        return DataLoader(
+            self.val_dataset,
+            batch_size=self.batch_size,
+            shuffle=False,
+            num_workers=self.num_workers,
+        )
+
+    def test_dataloader(self):
+        return DataLoader(
+            self.test_dataset,
+            batch_size=self.batch_size,
+            shuffle=False,
+            num_workers=self.num_workers,
+        )
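
For context, here is a minimal sketch of how this datamodule could be driven by a Lightning Trainer. The checkpoint name, dataset directory, capacity/batch values, and the `model` object are all hypothetical placeholders, not part of this commit:

from lightning import Trainer
from src.data.mixed_datamodule import MixedDataModule

# all argument values below are illustrative assumptions
datamodule = MixedDataModule(
    bert_model="bert-base-uncased",  # any BERT checkpoint name
    dataset_path="datasets/mixed",   # folder containing train.json and test.json
    tool_capacity=16,
    batch_size=32,
    num_workers=4,
)

trainer = Trainer(max_epochs=10)
trainer.fit(model, datamodule=datamodule)  # `model` is a LightningModule defined elsewhere
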
src/data/mixed_dataset.py ADDED
@@ -0,0 +1,78 @@
+import json
+import torch
+import random
+from torch.utils.data import Dataset
+from transformers import BertTokenizer
+
+
+class MixedDataset(Dataset):
+    def __init__(self, bert_model, stage, anno_file, tool_capacity):
+        self.stage = stage
+        self.tool_capacity = tool_capacity
+        self.tools, self.samples = self.load_data(anno_file)
+        self.tool_ids = list(self.tools.keys())
+        self.tokenizer = BertTokenizer.from_pretrained(bert_model)
+
+    def load_data(self, anno_file):
+        with open(anno_file, "r") as f:
+            data = json.load(f)
+        tools = data["tools"]
+        samples = data["samples"]
+
+        tools = {tool["id"]: tool for tool in tools}
+
+        return tools, samples
+
+    def encode_text(self, text):
+        # calling the tokenizer directly handles both a single string and a
+        # list of strings; encode_plus treats a list as one pre-tokenized sequence
+        inputs = self.tokenizer(
+            text,
+            max_length=128,
+            padding="max_length",
+            truncation=True,
+        )
+        ids = torch.tensor(inputs["input_ids"], dtype=torch.long)
+        mask = torch.tensor(inputs["attention_mask"], dtype=torch.long)
+
+        return ids, mask
+
+    def __len__(self):
+        return len(self.samples)
+
+    def __getitem__(self, idx):
+        sample = self.samples[idx]
+        inst = sample["instruction"]
+        inst_ids, inst_mask = self.encode_text(inst)
+
+        if self.stage == "train":
+            # for training, pair the instruction with one randomly chosen correct tool
+            tool_id = random.choice(sample["tools"])
+            tool_desc = self.tools[tool_id]["description"]
+            tool_desc_ids, tool_desc_mask = self.encode_text(tool_desc)
+
+            return {
+                "inst_ids": inst_ids,
+                "inst_mask": inst_mask,
+                "tool_desc_ids": tool_desc_ids,
+                "tool_desc_mask": tool_desc_mask,
+            }
+        else:
+            # for testing, we sample a random set of tools + the correct tools, size = tool_capacity
+            # wrong tools are sampled randomly from self.tools
+            correct_tools = sample["tools"]
+            wrong_tools = random.sample(
+                [tool for tool in self.tool_ids if tool not in correct_tools],
+                self.tool_capacity - len(correct_tools),
+            )
+            tools = correct_tools + wrong_tools
+            tool_ids, tool_ids_mask = self.encode_text(
+                [self.tools[tool_id]["description"] for tool_id in tools]
+            )
+
+            return {
+                "inst_ids": inst_ids,
+                "inst_mask": inst_mask,
+                "tool_ids": tool_ids,
+                "tool_ids_mask": tool_ids_mask,
+            }
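
For reference, `load_data` and `__getitem__` imply the following layout for train.json / test.json: a "tools" list of objects with "id" and "description" fields, and a "samples" list pairing an "instruction" with the ids of its correct tools. The ids, descriptions, and instruction below are invented placeholders:

{
    "tools": [
        {"id": "web_search", "description": "Searches the web for a given query."},
        {"id": "calculator", "description": "Evaluates a basic arithmetic expression."}
    ],
    "samples": [
        {"instruction": "What is 17 * 24?", "tools": ["calculator"]}
    ]
}
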