| import torch |
| from torch.utils.data import Dataset |
| import json |
| import random |
|
|
|
|
class MTPDataset(Dataset):
    """Instruction/response dataset with optional light text augmentation.

    The corpus is a JSONL file: one JSON object per line. Only entries
    containing both an 'instruction' and a 'response' key are kept.
    Each item is rendered into a prompt template, tokenized, wrapped in
    BOS/EOS, truncated to `max_seq_len`, and returned as a next-token
    prediction pair (inputs, targets shifted by one).
    """

    def __init__(self, corpus_path, tokenizer, max_seq_len=512,
                 use_augmentation=False, augmentation_prob=0.3):
        """
        Args:
            corpus_path: Path to a JSONL corpus file (UTF-8).
            tokenizer: Object exposing encode(str) -> list[int], plus
                bos_id() and eos_id() special-token accessors.
            max_seq_len: Hard cap on the token sequence length
                (including BOS and EOS).
            use_augmentation: Enable stochastic text augmentation.
            augmentation_prob: Probability that a given text field is
                considered for augmentation at all.
        """
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len
        self.use_augmentation = use_augmentation
        self.augmentation_prob = augmentation_prob
        self.data = []

        # Read one JSON object per line. Skip blank/whitespace-only lines
        # so a trailing newline or spacer line does not crash json.loads.
        with open(corpus_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                entry = json.loads(line)
                if 'instruction' in entry and 'response' in entry:
                    self.data.append(entry)

        print(f"✓ Loaded {len(self.data)} examples from corpus")
        if use_augmentation:
            print(f"✓ Data augmentation enabled (prob={augmentation_prob})")

    def __len__(self):
        return len(self.data)

    def augment_text(self, text):
        """Apply light stochastic augmentation to `text`.

        With probability `augmentation_prob` the text is eligible; then,
        independently: 30% chance of stripping surrounding whitespace,
        20% chance of toggling the trailing period.
        """
        if not self.use_augmentation or random.random() > self.augmentation_prob:
            return text

        # 30% chance: normalize surrounding whitespace.
        if random.random() < 0.3:
            text = text.strip()

        # 20% chance: toggle terminal punctuation — drop a trailing '.'
        # or append one if the text ends with no sentence terminator.
        if random.random() < 0.2:
            if text.endswith('.'):
                text = text[:-1]
            elif not text.endswith(('.', '!', '?')):
                text = text + '.'

        return text

    def __getitem__(self, idx):
        """Return (input_ids, target_ids): 1-D LongTensors of equal
        length, where targets are inputs shifted left by one token."""
        entry = self.data[idx]

        instruction = self.augment_text(entry['instruction'])
        response = self.augment_text(entry['response'])

        # Prompt template (runtime string — must stay exactly as-is).
        full_text = f"### Instrucción:\n{instruction}\n\n### Respuesta:\n{response}"

        tokens = self.tokenizer.encode(full_text)
        tokens = [self.tokenizer.bos_id()] + tokens + [self.tokenizer.eos_id()]

        # Truncate to max_seq_len while keeping BOS at the front and
        # re-appending EOS so the sequence always ends with a terminator.
        if len(tokens) > self.max_seq_len:
            tokens = [tokens[0]] + tokens[1:self.max_seq_len - 1] + [self.tokenizer.eos_id()]

        # Shift by one for next-token prediction.
        input_ids = torch.tensor(tokens[:-1], dtype=torch.long)
        target_ids = torch.tensor(tokens[1:], dtype=torch.long)

        return input_ids, target_ids
|
|
|
|
def collate_fn(batch, pad_id=0):
    """Collate (input_ids, target_ids) pairs into padded batch tensors.

    Every sequence is right-padded with `pad_id` up to the length of the
    longest *input* in the batch (inputs and targets are expected to be
    the same length per item), then stacked into (B, max_len) tensors.
    """
    inputs = [pair[0] for pair in batch]
    targets = [pair[1] for pair in batch]

    widest = max(seq.size(0) for seq in inputs)

    def _pad(seq, extra):
        # Append `extra` copies of pad_id; new_full keeps seq's dtype (long).
        return torch.cat([seq, seq.new_full((extra,), pad_id)])

    padded_inputs, padded_targets = [], []
    for inp, tgt in zip(inputs, targets):
        extra = widest - inp.size(0)
        padded_inputs.append(_pad(inp, extra))
        padded_targets.append(_pad(tgt, extra))

    return torch.stack(padded_inputs), torch.stack(padded_targets)