# -*- coding: utf-8 -*- """Dataset utilities for ReACC-style generation. This module mirrors the CodeXGLUE/ReACC style where the generator reads retrieved code + current context and learns to predict only continuation tokens. Generator-only baseline is supported by setting `retrieved` to an empty string. Expected JSONL schema per line: {"retrieved": "...", "context": "...", "target": "..."} """ from __future__ import annotations import json from dataclasses import dataclass from typing import Dict, List, Optional, Sequence import torch from torch.utils.data import Dataset RET_START = "" RET_END = "" CTX_START = "" CTX_END = "" GEN_START = "" SPECIAL_TOKENS = [RET_START, RET_END, CTX_START, CTX_END, GEN_START] def load_jsonl(path: str) -> List[Dict[str, str]]: data: List[Dict[str, str]] = [] with open(path, 'r', encoding='utf-8') as f: for line in f: line = line.strip() if not line: continue ex = json.loads(line) ex.setdefault('retrieved', '') ex.setdefault('context', '') ex.setdefault('target', '') data.append(ex) return data def save_jsonl(path: str, rows: Sequence[Dict[str, str]]) -> None: with open(path, 'w', encoding='utf-8') as f: for row in rows: f.write(json.dumps(row, ensure_ascii=False) + '\n') def build_prompt(retrieved: str, context: str) -> str: """ReACC prompt: retrieved code first, then unfinished context.""" return ( f"{RET_START}\n{retrieved.strip()}\n{RET_END}\n" f"{CTX_START}\n{context.rstrip()}\n{CTX_END}\n" f"{GEN_START}\n" ) @dataclass class EncodedSample: input_ids: List[int] attention_mask: List[int] labels: List[int] prompt_length: int class ReACCGeneratorDataset(Dataset): """Causal-LM dataset with prompt-masked labels. Labels are -100 on prompt tokens and equal to token ids on target tokens. Any example whose target tokenizes to zero tokens will be skipped to avoid all-ignored labels (which can cause NaN loss). """ def __init__( self, data, tokenizer, max_length: int = 384, max_target_length: int = 96, ): self.tokenizer = tokenizer self.max_length = int(max_length) self.max_target_length = int(max_target_length) self.examples: List[EncodedSample] = [] for ex in data: enc = self.encode_example(ex) if enc is not None: self.examples.append(enc) def __len__(self): return len(self.examples) def _encode_text(self, text: str, truncation: bool = False, max_length: Optional[int] = None) -> List[int]: return self.tokenizer.encode( text, add_special_tokens=False, truncation=truncation, max_length=max_length, ) def _safe_decode(self, ids: List[int]) -> str: return self.tokenizer.decode(ids, clean_up_tokenization_spaces=False) def _budgeted_fields(self, retrieved: str, context: str, target: str): target_ids = self._encode_text( target, truncation=True, max_length=self.max_target_length) if len(target_ids) == 0: return None empty_prompt_len = len(self._encode_text(build_prompt('', ''))) prompt_budget = max( self.max_length - len(target_ids) - empty_prompt_len, 32) retrieved_budget = prompt_budget // 2 context_budget = prompt_budget - retrieved_budget retrieved_ids = self._encode_text( retrieved, truncation=True, max_length=retrieved_budget) context_ids_full = self._encode_text(context, truncation=False) context_ids = context_ids_full[-context_budget:] if len( context_ids_full) > context_budget else context_ids_full return ( self._safe_decode(retrieved_ids), self._safe_decode(context_ids), self._safe_decode(target_ids), ) def encode_example(self, ex: Dict[str, str]) -> Optional[EncodedSample]: maybe = self._budgeted_fields(ex.get('retrieved', ''), ex.get( 'context', ''), ex.get('target', '')) if maybe is None: return None retrieved, context, target = maybe prompt = build_prompt(retrieved, context) prompt_ids = self._encode_text(prompt) target_ids = self._encode_text( target, truncation=True, max_length=self.max_target_length) if len(target_ids) == 0: return None input_ids = (prompt_ids + target_ids)[: self.max_length] prompt_length = min(len(prompt_ids), len(input_ids)) # If all remaining target tokens are truncated away, skip the example. if len(input_ids) <= prompt_length: return None labels = [-100] * prompt_length + input_ids[prompt_length:] labels = labels[: len(input_ids)] attention_mask = [1] * len(input_ids) # Safety check: require at least one supervised token. if all(x == -100 for x in labels): return None return EncodedSample( input_ids=input_ids, attention_mask=attention_mask, labels=labels, prompt_length=prompt_length, ) def __getitem__(self, idx: int): enc = self.examples[idx] return { 'input_ids': enc.input_ids, 'attention_mask': enc.attention_mask, 'labels': enc.labels, 'prompt_length': enc.prompt_length, } class ReACCInferenceDataset(Dataset): """Prompt-only dataset for evaluation / generation.""" def __init__(self, data, tokenizer, max_length: int = 384): self.data = data self.tokenizer = tokenizer self.max_length = int(max_length) def __len__(self): return len(self.data) def __getitem__(self, idx: int): ex = self.data[idx] prompt = build_prompt(ex.get('retrieved', ''), ex.get('context', '')) input_ids = self.tokenizer.encode( prompt, add_special_tokens=False, truncation=True, max_length=self.max_length, ) return { 'input_ids': input_ids, 'attention_mask': [1] * len(input_ids), 'meta': ex, } def collate_batch(batch, pad_token_id: int): max_len = max(len(x['input_ids']) for x in batch) input_ids, attention_mask, labels, prompt_lengths = [], [], [], [] for x in batch: pad_len = max_len - len(x['input_ids']) input_ids.append(x['input_ids'] + [pad_token_id] * pad_len) attention_mask.append(x['attention_mask'] + [0] * pad_len) if 'labels' in x: labels.append(x['labels'] + [-100] * pad_len) prompt_lengths.append(x.get('prompt_length', 0)) out = { 'input_ids': torch.tensor(input_ids, dtype=torch.long), 'attention_mask': torch.tensor(attention_mask, dtype=torch.long), 'prompt_length': torch.tensor(prompt_lengths, dtype=torch.long), } if labels: out['labels'] = torch.tensor(labels, dtype=torch.long) return out