| import torch
|
| from torch.utils.data import Dataset
|
| from typing import List, Tuple
|
|
|
class MathDataset(Dataset):
    """
    A custom PyTorch Dataset to handle the encoded math problem sequences.

    It performs the crucial language model shift (X is the input, Y is X
    shifted by one token), truncates over-long sequences, and right-pads
    short ones so every item is exactly ``max_len`` tokens long.
    """

    def __init__(self, data: List[str], tokenizer, max_len: int):
        """
        Args:
            data: Raw text problems, one string per example.
            tokenizer: Any object exposing ``encode(text) -> List[int]``
                and a ``pad_token_id`` attribute.
            max_len: Fixed length of the returned input/target tensors.
        """
        self.data = data
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.pad_token_id = tokenizer.pad_token_id

    def __len__(self) -> int:
        """Return the number of examples in the dataset."""
        return len(self.data)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Return ``(x, y)`` where ``y`` is ``x`` shifted left by one token.

        Both tensors are ``torch.long`` with shape ``(max_len,)``.
        """
        raw_text = self.data[idx]
        sequence_ids = self.tokenizer.encode(raw_text)

        # BUGFIX: keep at most max_len + 1 tokens so that after the
        # one-token shift both x and y are at most max_len long.
        # Previously, sequences longer than max_len + 1 made
        # padding_length negative, so no padding was applied and the
        # returned tensors exceeded max_len, breaking default batching.
        sequence_ids = sequence_ids[: self.max_len + 1]

        # Language-model shift: at every position, y holds the token
        # that follows the token in x.
        x = sequence_ids[:-1]
        y = sequence_ids[1:]

        # Right-pad both sequences to the fixed length. NOTE: targets
        # are padded with pad_token_id as well — the training loss
        # should ignore pad positions (e.g. via ignore_index).
        padding_length = self.max_len - len(x)
        x_padded = x + [self.pad_token_id] * padding_length
        y_padded = y + [self.pad_token_id] * padding_length

        return (
            torch.tensor(x_padded, dtype=torch.long),
            torch.tensor(y_padded, dtype=torch.long),
        )