# -*- coding: utf-8 -*-
"""Dataset utilities for ReACC-style generation.
This module mirrors the CodeXGLUE/ReACC style where the generator reads
retrieved code + current context and learns to predict only continuation tokens.
Generator-only baseline is supported by setting `retrieved` to an empty string.
Expected JSONL schema per line:
{"retrieved": "...", "context": "...", "target": "..."}
"""
from __future__ import annotations
import json
from dataclasses import dataclass
from typing import Dict, List, Optional, Sequence
import torch
from torch.utils.data import Dataset
# Sentinel markers that delimit the prompt segments. The concrete strings are
# an assumption; any unused strings work as long as they are registered with
# the tokenizer as special tokens (see the sketch below SPECIAL_TOKENS).
RET_START = "<ret>"
RET_END = "</ret>"
CTX_START = "<ctx>"
CTX_END = "</ctx>"
GEN_START = "<gen>"
SPECIAL_TOKENS = [RET_START, RET_END, CTX_START, CTX_END, GEN_START]
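# Register the markers before training so each maps to a single token id; a
# sketch assuming a Hugging Face `transformers` tokenizer/model pair:
#
#     tokenizer.add_special_tokens({'additional_special_tokens': SPECIAL_TOKENS})
#     model.resize_token_embeddings(len(tokenizer))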
def load_jsonl(path: str) -> List[Dict[str, str]]:
data: List[Dict[str, str]] = []
with open(path, 'r', encoding='utf-8') as f:
for line in f:
line = line.strip()
if not line:
continue
ex = json.loads(line)
ex.setdefault('retrieved', '')
ex.setdefault('context', '')
ex.setdefault('target', '')
data.append(ex)
return data
def save_jsonl(path: str, rows: Sequence[Dict[str, str]]) -> None:
with open(path, 'w', encoding='utf-8') as f:
for row in rows:
f.write(json.dumps(row, ensure_ascii=False) + '\n')
def build_prompt(retrieved: str, context: str) -> str:
"""ReACC prompt: retrieved code first, then unfinished context."""
return (
f"{RET_START}\n{retrieved.strip()}\n{RET_END}\n"
f"{CTX_START}\n{context.rstrip()}\n{CTX_END}\n"
f"{GEN_START}\n"
)
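# With the assumed markers above, build_prompt("x = 1", "y = ") yields:
#     <ret>
#     x = 1
#     </ret>
#     <ctx>
#     y =
#     </ctx>
#     <gen>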
@dataclass
class EncodedSample:
input_ids: List[int]
attention_mask: List[int]
labels: List[int]
prompt_length: int
class ReACCGeneratorDataset(Dataset):
"""Causal-LM dataset with prompt-masked labels.
Labels are -100 on prompt tokens and equal to token ids on target tokens.
Any example whose target tokenizes to zero tokens will be skipped to avoid
all-ignored labels (which can cause NaN loss).
"""
def __init__(
self,
data,
tokenizer,
max_length: int = 384,
max_target_length: int = 96,
):
self.tokenizer = tokenizer
self.max_length = int(max_length)
self.max_target_length = int(max_target_length)
self.examples: List[EncodedSample] = []
for ex in data:
enc = self.encode_example(ex)
if enc is not None:
self.examples.append(enc)
def __len__(self):
return len(self.examples)
def _encode_text(self, text: str, truncation: bool = False, max_length: Optional[int] = None) -> List[int]:
return self.tokenizer.encode(
text,
add_special_tokens=False,
truncation=truncation,
max_length=max_length,
)
def _safe_decode(self, ids: List[int]) -> str:
return self.tokenizer.decode(ids, clean_up_tokenization_spaces=False)
    def _budgeted_fields(self, retrieved: str, context: str, target: str):
        """Trim each field to a token budget and return the decoded strings.

        The target gets up to `max_target_length` tokens; whatever remains of
        `max_length` (minus the empty-prompt overhead) is split between
        retrieved code (which keeps its head) and context (which keeps its
        tail, since the tokens nearest the cursor matter most for completion).
        """
target_ids = self._encode_text(
target, truncation=True, max_length=self.max_target_length)
if len(target_ids) == 0:
return None
empty_prompt_len = len(self._encode_text(build_prompt('', '')))
prompt_budget = max(
self.max_length - len(target_ids) - empty_prompt_len, 32)
retrieved_budget = prompt_budget // 2
context_budget = prompt_budget - retrieved_budget
retrieved_ids = self._encode_text(
retrieved, truncation=True, max_length=retrieved_budget)
context_ids_full = self._encode_text(context, truncation=False)
context_ids = context_ids_full[-context_budget:] if len(
context_ids_full) > context_budget else context_ids_full
return (
self._safe_decode(retrieved_ids),
self._safe_decode(context_ids),
self._safe_decode(target_ids),
)
def encode_example(self, ex: Dict[str, str]) -> Optional[EncodedSample]:
maybe = self._budgeted_fields(ex.get('retrieved', ''), ex.get(
'context', ''), ex.get('target', ''))
if maybe is None:
return None
retrieved, context, target = maybe
prompt = build_prompt(retrieved, context)
prompt_ids = self._encode_text(prompt)
target_ids = self._encode_text(
target, truncation=True, max_length=self.max_target_length)
if len(target_ids) == 0:
return None
input_ids = (prompt_ids + target_ids)[: self.max_length]
prompt_length = min(len(prompt_ids), len(input_ids))
# If all remaining target tokens are truncated away, skip the example.
if len(input_ids) <= prompt_length:
return None
        # labels already has exactly len(input_ids) entries; prompt positions
        # are masked with -100 so the loss covers only target tokens.
        labels = [-100] * prompt_length + input_ids[prompt_length:]
attention_mask = [1] * len(input_ids)
# Safety check: require at least one supervised token.
if all(x == -100 for x in labels):
return None
return EncodedSample(
input_ids=input_ids,
attention_mask=attention_mask,
labels=labels,
prompt_length=prompt_length,
)
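    # Example: with prompt_ids = [p0, p1, p2] and target_ids = [t0, t1], the
    # encoded sample has input_ids = [p0, p1, p2, t0, t1] and
    # labels = [-100, -100, -100, t0, t1]; the loss covers only the target.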
def __getitem__(self, idx: int):
enc = self.examples[idx]
return {
'input_ids': enc.input_ids,
'attention_mask': enc.attention_mask,
'labels': enc.labels,
'prompt_length': enc.prompt_length,
}
class ReACCInferenceDataset(Dataset):
"""Prompt-only dataset for evaluation / generation."""
def __init__(self, data, tokenizer, max_length: int = 384):
self.data = data
self.tokenizer = tokenizer
self.max_length = int(max_length)
def __len__(self):
return len(self.data)
    def __getitem__(self, idx: int):
        ex = self.data[idx]
        prompt = build_prompt(ex.get('retrieved', ''), ex.get('context', ''))
        input_ids = self.tokenizer.encode(prompt, add_special_tokens=False)
        # Keep the *tail* of an over-long prompt: right-truncation would drop
        # the most recent context and the GEN_START marker that tells the
        # model where to start generating.
        input_ids = input_ids[-self.max_length:]
return {
'input_ids': input_ids,
'attention_mask': [1] * len(input_ids),
'meta': ex,
}
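# Generation sketch (assuming a causal LM from `transformers`; with the
# right-padding used by collate_batch, generate one example at a time or
# switch the tokenizer to left padding for batched decoding):
#
#     batch = collate_batch([infer_ds[0]], pad_token_id=tokenizer.pad_token_id)
#     out = model.generate(input_ids=batch['input_ids'],
#                          attention_mask=batch['attention_mask'],
#                          max_new_tokens=96)
#     completion = tokenizer.decode(out[0][batch['input_ids'].shape[1]:])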
def collate_batch(batch, pad_token_id: int):
    """Right-pad a batch to its longest sequence; label padding uses -100."""
    max_len = max(len(x['input_ids']) for x in batch)
    input_ids, attention_mask, labels, prompt_lengths, metas = [], [], [], [], []
    for x in batch:
        pad_len = max_len - len(x['input_ids'])
        input_ids.append(x['input_ids'] + [pad_token_id] * pad_len)
        attention_mask.append(x['attention_mask'] + [0] * pad_len)
        if 'labels' in x:
            labels.append(x['labels'] + [-100] * pad_len)
        if 'meta' in x:
            metas.append(x['meta'])
        prompt_lengths.append(x.get('prompt_length', 0))
out = {
'input_ids': torch.tensor(input_ids, dtype=torch.long),
'attention_mask': torch.tensor(attention_mask, dtype=torch.long),
'prompt_length': torch.tensor(prompt_lengths, dtype=torch.long),
}
    if labels:
        out['labels'] = torch.tensor(labels, dtype=torch.long)
    if metas:
        # Pass raw example dicts through untouched for evaluation code.
        out['meta'] = metas
    return out
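# Typical wiring (a sketch; GPT-2-style tokenizers ship without a pad token,
# so set e.g. `tokenizer.pad_token = tokenizer.eos_token` first):
#
#     from functools import partial
#     from torch.utils.data import DataLoader
#     loader = DataLoader(train_ds, batch_size=8, shuffle=True,
#                         collate_fn=partial(collate_batch,
#                                            pad_token_id=tokenizer.pad_token_id))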