| |
| """Dataset utilities for ReACC-style generation. |
| |
| This module mirrors the CodeXGLUE/ReACC style where the generator reads |
| retrieved code + current context and learns to predict only continuation tokens. |
| |
| Generator-only baseline is supported by setting `retrieved` to an empty string. |
| Expected JSONL schema per line: |
| {"retrieved": "...", "context": "...", "target": "..."} |
| """ |
|
|
| from __future__ import annotations |
|
|
| import json |
| from dataclasses import dataclass |
| from typing import Dict, List, Optional, Sequence |
|
|
| import torch |
| from torch.utils.data import Dataset |
|
|
# Markers that delimit the sections of a ReACC prompt (see build_prompt).
RET_START, RET_END = "<RET>", "</RET>"  # wraps the retrieved code
CTX_START, CTX_END = "<CTX>", "</CTX>"  # wraps the unfinished context
GEN_START = "<GEN>"  # the continuation follows this marker

# All markers in prompt order; presumably added to the tokenizer as extra
# special tokens by the training script -- TODO confirm at the call site.
SPECIAL_TOKENS = [
    RET_START,
    RET_END,
    CTX_START,
    CTX_END,
    GEN_START,
]
|
|
|
|
def load_jsonl(path: str) -> List[Dict[str, str]]:
    """Read a JSONL file of ReACC examples.

    Blank lines are skipped. Every other line must hold a JSON object. After
    loading, the ``retrieved``, ``context`` and ``target`` fields are always
    present and default to ``''`` when missing or ``null``, so downstream
    string handling (``build_prompt``, tokenization) never sees ``None``.
    Any other keys are preserved unchanged.

    Args:
        path: Path to the ``.jsonl`` file (UTF-8, with or without a BOM).

    Returns:
        One dict per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: if a line is not valid JSON.
        ValueError: if a line is valid JSON but not an object.
    """
    data: List[Dict[str, str]] = []
    # 'utf-8-sig' silently drops a leading BOM, which json.loads would reject.
    with open(path, 'r', encoding='utf-8-sig') as f:
        for lineno, raw in enumerate(f, start=1):
            line = raw.strip()
            if not line:
                continue
            ex = json.loads(line)
            if not isinstance(ex, dict):
                raise ValueError(
                    f"{path}:{lineno}: expected a JSON object, got {type(ex).__name__}")
            # Missing keys *and* explicit nulls become empty strings
            # (setdefault alone would leave a null as None).
            for key in ('retrieved', 'context', 'target'):
                if ex.get(key) is None:
                    ex[key] = ''
            data.append(ex)
    return data
|
|
|
|
def save_jsonl(path: str, rows: Sequence[Dict[str, str]]) -> None:
    """Write ``rows`` to ``path`` as JSONL (one JSON object per line, UTF-8).

    Non-ASCII characters are written verbatim (``ensure_ascii=False``). An
    existing file at ``path`` is overwritten.
    """
    # Lazily serialized, so each row is dumped while the file is open.
    serialized = (json.dumps(record, ensure_ascii=False) + '\n' for record in rows)
    with open(path, 'w', encoding='utf-8') as out:
        out.writelines(serialized)
|
|
|
|
def build_prompt(retrieved: str, context: str) -> str:
    """Render the ReACC generator prompt: retrieved code first, then context.

    Layout, one piece per line: RET_START, the retrieved code, RET_END,
    CTX_START, the unfinished context, CTX_END, GEN_START. The prompt ends
    with a trailing newline after GEN_START. The retrieved code is stripped
    on both sides. The context is stripped only on the right, so its
    leading indentation is kept.
    """
    sections = (
        RET_START, retrieved.strip(), RET_END,
        CTX_START, context.rstrip(), CTX_END,
        GEN_START,
    )
    return "\n".join(sections) + "\n"
|
|
|
|
@dataclass
class EncodedSample:
    """One tokenized training example for the causal-LM generator.

    As built by ReACCGeneratorDataset.encode_example, the three lists have
    the same length. The first ``prompt_length`` entries of ``labels`` are
    -100, and the remaining entries repeat ``input_ids``.
    """

    # Prompt token ids followed by target token ids.
    input_ids: List[int]
    # 1 for every stored token (padding is added later, at collate time).
    attention_mask: List[int]
    # -100 over the prompt (ignored by the loss), target token ids after it.
    labels: List[int]
    # Number of leading prompt tokens in input_ids.
    prompt_length: int
|
|
|
|
class ReACCGeneratorDataset(Dataset):
    """Causal-LM dataset with prompt-masked labels.

    Each example becomes ``build_prompt(retrieved, context)`` followed by the
    target tokens. Labels are -100 on prompt tokens and equal to the token
    ids on target tokens, so the loss covers only the continuation.

    Token budget per example:
      * the target is truncated to ``max_target_length`` tokens;
      * the rest of ``max_length`` is shared by the retrieved code (head
        kept) and the context (tail kept, i.e. the code right before the
        completion point). "The rest" means what remains after the target
        and the empty prompt template, but never less than
        ``_MIN_PROMPT_BUDGET``. Each side gets half. Any share a side does
        not need goes to the other side, so the generator-only baseline
        (empty ``retrieved``) can spend the whole budget on context.

    Two kinds of example are skipped to avoid all-ignored labels (which can
    cause NaN loss): those whose target tokenizes to zero tokens, and those
    whose prompt alone fills ``max_length``.
    """

    # Floor on the prompt budget so a long target cannot squeeze the prompt
    # to nothing; when the floor applies, the final max_length cut trims the
    # tail of the target instead.
    _MIN_PROMPT_BUDGET = 32

    def __init__(
        self,
        data,
        tokenizer,
        max_length: int = 384,
        max_target_length: int = 96,
    ):
        self.tokenizer = tokenizer
        self.max_length = int(max_length)
        self.max_target_length = int(max_target_length)
        self._template_len: Optional[int] = None
        self.examples: List[EncodedSample] = []

        for ex in data:
            enc = self.encode_example(ex)
            if enc is not None:
                self.examples.append(enc)

    def __len__(self):
        return len(self.examples)

    def _encode_text(self, text: str, truncation: bool = False, max_length: Optional[int] = None) -> List[int]:
        """Tokenize ``text`` without adding the tokenizer's special tokens."""
        return self.tokenizer.encode(
            text,
            add_special_tokens=False,
            truncation=truncation,
            max_length=max_length,
        )

    def _safe_decode(self, ids: List[int]) -> str:
        """Decode ``ids`` back to text without tokenization-space clean-up."""
        return self.tokenizer.decode(ids, clean_up_tokenization_spaces=False)

    def _template_length(self) -> int:
        """Token count of the prompt scaffolding with empty fields.

        Computed once per dataset, because the tokenizer is fixed at
        construction.
        """
        cached = getattr(self, '_template_len', None)
        if cached is None:
            cached = len(self._encode_text(build_prompt('', '')))
            self._template_len = cached
        return cached

    @staticmethod
    def _split_budget(budget: int, n_retrieved: int, n_context: int):
        """Return ``(retrieved_take, context_take)`` token counts within ``budget``.

        When both fields overflow, retrieved gets ``budget // 2`` and context
        gets the rest. A field that needs less than its half donates the
        leftover to the other field.
        """
        half = budget // 2
        take_retrieved = min(n_retrieved, max(half, budget - n_context))
        take_context = min(n_context, budget - take_retrieved)
        return take_retrieved, take_context

    def _fit_fields(self, retrieved: str, context: str, target: str):
        """Budget the three fields for one example.

        Returns ``(retrieved_text, context_text, target_ids)``, or None when
        the target tokenizes to nothing.
        """
        target_ids = self._encode_text(
            target, truncation=True, max_length=self.max_target_length)
        if not target_ids:
            return None

        prompt_budget = max(
            self.max_length - len(target_ids) - self._template_length(),
            self._MIN_PROMPT_BUDGET,
        )
        retrieved_ids = self._encode_text(retrieved)
        context_ids = self._encode_text(context)
        take_retrieved, take_context = self._split_budget(
            prompt_budget, len(retrieved_ids), len(context_ids))

        # Keep the head of the retrieved snippet and the tail of the context.
        # Slicing from len - take avoids the [-0:] whole-list pitfall.
        retrieved_ids = retrieved_ids[:take_retrieved]
        context_ids = context_ids[len(context_ids) - take_context:]
        return self._safe_decode(retrieved_ids), self._safe_decode(context_ids), target_ids

    def _budgeted_fields(self, retrieved: str, context: str, target: str):
        """Backward-compatible variant of _fit_fields that returns the target as text."""
        fitted = self._fit_fields(retrieved, context, target)
        if fitted is None:
            return None
        retrieved_text, context_text, target_ids = fitted
        return retrieved_text, context_text, self._safe_decode(target_ids)

    def encode_example(self, ex: Dict[str, str]) -> Optional[EncodedSample]:
        """Tokenize one raw example; return None if no target token survives."""
        # `or ''` also covers fields that are present but null.
        fitted = self._fit_fields(
            ex.get('retrieved') or '', ex.get('context') or '', ex.get('target') or '')
        if fitted is None:
            return None
        retrieved, context, target_ids = fitted

        prompt_ids = self._encode_text(build_prompt(retrieved, context))
        # target_ids are reused as-is instead of being decoded and re-encoded:
        # that round trip is not guaranteed to give back the same ids for
        # every tokenizer, and it tokenized each target twice. Re-encoding the
        # decoded prompt can still shift token boundaries, and the prompt
        # budget floor may overshoot, so max_length is enforced last.
        input_ids = (prompt_ids + target_ids)[: self.max_length]
        prompt_length = len(prompt_ids)
        if len(input_ids) <= prompt_length:
            # The prompt alone fills max_length, so no target token survives.
            return None

        labels = [-100] * prompt_length + input_ids[prompt_length:]
        return EncodedSample(
            input_ids=input_ids,
            attention_mask=[1] * len(input_ids),
            labels=labels,
            prompt_length=prompt_length,
        )

    def __getitem__(self, idx: int):
        enc = self.examples[idx]
        return {
            'input_ids': enc.input_ids,
            'attention_mask': enc.attention_mask,
            'labels': enc.labels,
            'prompt_length': enc.prompt_length,
        }
|
|
|
|
class ReACCInferenceDataset(Dataset):
    """Prompt-only dataset for evaluation / generation.

    Each item is ``build_prompt(retrieved, context)``, tokenized without
    special tokens. Prompts that fit in ``max_length`` are returned
    unchanged. Longer prompts are shortened field by field, keeping the head
    of the retrieved code and the tail of the context (as
    ReACCGeneratorDataset does), and ``GEN_START`` always stays at the end.
    Plain right-truncation would drop both the marker and the code right
    before the completion point, which is what the model conditions on.
    """

    def __init__(self, data, tokenizer, max_length: int = 384):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = int(max_length)

    def __len__(self):
        return len(self.data)

    def _encode(self, text: str) -> List[int]:
        """Tokenize ``text`` without adding the tokenizer's special tokens."""
        return self.tokenizer.encode(text, add_special_tokens=False)

    def _decode(self, ids: List[int]) -> str:
        """Decode ``ids`` back to text without tokenization-space clean-up."""
        return self.tokenizer.decode(ids, clean_up_tokenization_spaces=False)

    @staticmethod
    def _split_budget(budget: int, n_retrieved: int, n_context: int):
        """Return ``(retrieved_take, context_take)`` token counts within ``budget``.

        When both fields overflow, retrieved gets ``budget // 2`` and context
        gets the rest. A field that needs less than its half donates the
        leftover to the other field.
        """
        half = budget // 2
        take_retrieved = min(n_retrieved, max(half, budget - n_context))
        take_context = min(n_context, budget - take_retrieved)
        return take_retrieved, take_context

    def _fit_prompt(self, retrieved: str, context: str) -> List[int]:
        """Token ids of the prompt for (retrieved, context), at most max_length long."""
        prompt_ids = self._encode(build_prompt(retrieved, context))
        if len(prompt_ids) <= self.max_length:
            return prompt_ids

        template_len = len(self._encode(build_prompt('', '')))
        budget = max(self.max_length - template_len, 0)
        retrieved_ids = self._encode(retrieved)
        context_ids = self._encode(context)
        take_retrieved, take_context = self._split_budget(
            budget, len(retrieved_ids), len(context_ids))
        prompt_ids = self._encode(build_prompt(
            self._decode(retrieved_ids[:take_retrieved]),
            self._decode(context_ids[len(context_ids) - take_context:]),
        ))
        # Re-tokenizing decoded text can shift boundaries slightly; trim from
        # the front so the prompt still ends with GEN_START.
        return prompt_ids[max(len(prompt_ids) - self.max_length, 0):]

    def __getitem__(self, idx: int):
        ex = self.data[idx]
        # `or ''` also covers fields that are present but null.
        input_ids = self._fit_prompt(ex.get('retrieved') or '', ex.get('context') or '')
        return {
            'input_ids': input_ids,
            'attention_mask': [1] * len(input_ids),
            'meta': ex,
        }
|
|
|
|
def collate_batch(batch, pad_token_id: int, padding_side: str = 'right'):
    """Pad a list of dataset items into a batch of ``torch.long`` tensors.

    Args:
        batch: Items as returned by ReACCGeneratorDataset or
            ReACCInferenceDataset. Each is a dict with ``input_ids`` and
            ``attention_mask`` lists, and optionally ``labels`` and
            ``prompt_length``.
        pad_token_id: Id written into padded ``input_ids`` positions.
        padding_side: ``'right'`` (default, suited to training) or ``'left'``.
            Left padding keeps every prompt flush against the tokens to be
            generated, as batched generation with decoder-only models expects.
            ``prompt_length`` is a token count either way, not a position.

    Returns:
        Dict with ``input_ids``, ``attention_mask`` (0 on padding) and
        ``prompt_length`` (0 for items without one), plus ``labels`` (-100
        on padding) when the items carry labels. Other item keys, such as
        ``meta``, are dropped.

    Raises:
        ValueError: if ``batch`` is empty, ``padding_side`` is not
            ``'left'``/``'right'``, or only some items carry ``labels``.
    """
    if not batch:
        raise ValueError('collate_batch() got an empty batch')
    if padding_side not in ('left', 'right'):
        raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side!r}")
    n_labelled = sum('labels' in x for x in batch)
    if n_labelled not in (0, len(batch)):
        # A partial labels list would silently misalign rows with input_ids.
        raise ValueError("either all or none of the batch items must have 'labels'")

    def pad(seq, fill, n):
        # Add n copies of `fill` on the configured side.
        filler = [fill] * n
        return filler + list(seq) if padding_side == 'left' else list(seq) + filler

    max_len = max(len(x['input_ids']) for x in batch)
    input_ids, attention_mask, labels, prompt_lengths = [], [], [], []
    for x in batch:
        pad_len = max_len - len(x['input_ids'])
        input_ids.append(pad(x['input_ids'], pad_token_id, pad_len))
        attention_mask.append(pad(x['attention_mask'], 0, pad_len))
        if n_labelled:
            labels.append(pad(x['labels'], -100, pad_len))
        prompt_lengths.append(x.get('prompt_length', 0))

    out = {
        'input_ids': torch.tensor(input_ids, dtype=torch.long),
        'attention_mask': torch.tensor(attention_mask, dtype=torch.long),
        'prompt_length': torch.tensor(prompt_lengths, dtype=torch.long),
    }
    if n_labelled:
        out['labels'] = torch.tensor(labels, dtype=torch.long)
    return out
|
|