"""
Preprocessing script for The Vault dataset to adapt it for GLEN training.

This script converts The Vault's code-text format to GLEN's expected format for
generative retrieval training.
"""
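# Example invocation (the script filename here is illustrative):
#   python preprocess_vault.py --input_dir the_vault_dataset/ \
#       --output_dir data/the_vault/ --include_comments --max_samples 1000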
import argparse
import hashlib
import json
import os
from typing import Any, Dict, List

import pandas as pd
from tqdm import tqdm

def clean_text(text: str) -> str:
    """Clean text by collapsing whitespace and removing problematic characters."""
    if not text:
        return ""
    # Flatten newlines and tabs so each record stays on a single TSV line.
    text = text.replace("\n", " ")
    text = text.replace("\t", " ")
    text = text.replace("``", "")
    # Normalize double quotes to single quotes to avoid TSV quoting issues.
    text = text.replace('"', "'")
    # Collapse runs of whitespace into single spaces.
    text = " ".join(text.split())
    return text

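# Illustrative behavior of clean_text (input value is made up):
#   clean_text('foo\n\t"bar"  baz')  ->  "foo 'bar' baz"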

def create_document_id(code: str, identifier: str, repo: str) -> str:
    """Create a unique document ID for the code snippet."""
    # Hash the repo, identifier, and code together so that identical code in
    # different repos still receives a distinct ID.
    content = f"{repo}:{identifier}:{code}"
    hash_obj = hashlib.md5(content.encode())
    return hash_obj.hexdigest()[:10]

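# create_document_id returns the first 10 hex characters of the MD5 digest,
# e.g. "3f2b9c1d4e" (value illustrative); the same (repo, identifier, code)
# triple always maps to the same ID.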

def process_vault_sample(sample: Dict[str, Any], include_comments: bool = True) -> Dict[str, Any]:
    """Process a single sample from The Vault dataset."""
    # Pull out the fields we need; any of them may be missing or empty.
    identifier = sample.get('identifier', '')
    docstring = sample.get('docstring', '')
    short_docstring = sample.get('short_docstring', '')
    code = sample.get('code', '')
    language = sample.get('language', '')
    repo = sample.get('repo', '')
    path = sample.get('path', '')
    comments = sample.get('comment', [])

    # Prefer the full docstring; fall back to the short one.
    description = docstring if docstring else short_docstring

    # Optionally append inline code comments to the description.
    if include_comments and comments:
        comment_text = " ".join([clean_text(c) for c in comments if c])
        if comment_text:
            description = f"{description} {comment_text}" if description else comment_text

    description = clean_text(description)
    code = clean_text(code)

    # Document content: natural-language description followed by the code,
    # separated by a [CODE] marker. Fall back to a synthetic description
    # when none is available.
    if description:
        doc_content = f"{description} [CODE] {code}"
    else:
        doc_content = f"Function {identifier} in {language}. [CODE] {code}"

    doc_id = create_document_id(code, identifier, repo)

    # The query is the description itself, or a synthetic query otherwise.
    if description:
        query = description
    else:
        query = f"Find function {identifier} in {language}"

    return {
        'oldid': doc_id,
        'doc_content': doc_content,
        'query': query,
        'identifier': identifier,
        'language': language,
        'repo': repo,
        'path': path
    }

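# Illustrative input/output for process_vault_sample (all values made up):
#   sample = {'identifier': 'add', 'docstring': 'Add two numbers.',
#             'code': 'def add(a, b):\n    return a + b',
#             'language': 'python', 'repo': 'octocat/demo', 'path': 'demo/math.py'}
#   process_vault_sample(sample)['doc_content']
#   -> "Add two numbers. [CODE] def add(a, b): return a + b"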

def create_query_document_pairs(processed_samples: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Create query-document pairs for training."""
    pairs = []
    for sample in processed_samples:
        # Each query has exactly one relevant document (rank 1); hard
        # negatives and augmented queries are left empty here.
        pairs.append({
            'query': sample['query'],
            'oldid': sample['oldid'],
            'docid': sample['oldid'],
            'rank': 1,
            'neg_docid_list': [],
            'aug_query_list': []
        })
    return pairs

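# A single emitted pair, as a dict (values illustrative):
#   {'query': 'Add two numbers.', 'oldid': 'a1b2c3d4e5', 'docid': 'a1b2c3d4e5',
#    'rank': 1, 'neg_docid_list': [], 'aug_query_list': []}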

def generate_document_ids(processed_samples: List[Dict[str, Any]], id_class: str = "t5_bm25_truncate_3") -> pd.DataFrame:
    """Generate document IDs in GLEN's expected format."""
    id_data = []
    for sample in processed_samples:
        # Take the first 5 characters of the hash-based oldid and join them
        # with dashes to form a short, tokenizable document ID.
        doc_id = "-".join(list(sample['oldid'][:5]))
        id_data.append({
            'oldid': sample['oldid'],
            id_class: doc_id
        })
    return pd.DataFrame(id_data)

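# Example: generate_document_ids turns an oldid of "a1b2c3d4e5" (illustrative)
# into the docid "a-1-b-2-c".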

def main():
    parser = argparse.ArgumentParser(description='Preprocess The Vault dataset for GLEN')
    parser.add_argument('--input_dir', type=str, default='the_vault_dataset/',
                        help='Input directory containing The Vault dataset')
    parser.add_argument('--output_dir', type=str, default='data/the_vault/',
                        help='Output directory for processed data')
    parser.add_argument('--include_comments', action='store_true',
                        help='Include code comments in descriptions')
    parser.add_argument('--max_samples', type=int, default=None,
                        help='Maximum number of samples to process (for testing)')
    parser.add_argument('--create_test_set', action='store_true',
                        help='Create test set for evaluation (currently unused)')

    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    splits = ['train_small', 'validate', 'test']

    for split in splits:
        print(f"Processing {split} split...")

        input_file = os.path.join(args.input_dir, f"{split}.json")
        if not os.path.exists(input_file):
            print(f"Warning: {input_file} not found, skipping...")
            continue

        processed_samples = []

        # The file is parsed as JSON Lines: one JSON object per line.
        with open(input_file, 'r', encoding='utf-8') as f:
            for i, line in enumerate(tqdm(f, desc=f"Processing {split}")):
                if args.max_samples and i >= args.max_samples:
                    break
                try:
                    sample = json.loads(line.strip())
                    processed_sample = process_vault_sample(sample, args.include_comments)
                    # Skip near-empty documents.
                    if len(processed_sample['doc_content']) > 50:
                        processed_samples.append(processed_sample)
                except json.JSONDecodeError as e:
                    print(f"Error parsing line {i}: {e}")
                    continue

        print(f"Processed {len(processed_samples)} valid samples from {split}")

        if not processed_samples:
            print(f"No valid samples in {split}, skipping output files...")
            continue

        # Save document content (train_small is written out as "train").
        doc_split = 'train' if split == 'train_small' else split
        doc_df = pd.DataFrame([{
            'oldid': sample['oldid'],
            'doc_content': sample['doc_content']
        } for sample in processed_samples])

        doc_file = os.path.join(args.output_dir, f"DOC_VAULT_{doc_split}.tsv")
        doc_df.to_csv(doc_file, sep='\t', index=False, encoding='utf-8')
        print(f"Saved document data to {doc_file}")

        # Save query-document pairs. The validate and test splits get separate
        # files so the test pairs do not overwrite the dev pairs.
        pairs = create_query_document_pairs(processed_samples)
        gtq_df = pd.DataFrame(pairs)
        if split == 'train_small':
            gtq_file = os.path.join(args.output_dir, "GTQ_VAULT_train.tsv")
        elif split == 'validate':
            gtq_file = os.path.join(args.output_dir, "GTQ_VAULT_dev.tsv")
        else:  # test
            gtq_file = os.path.join(args.output_dir, "GTQ_VAULT_test.tsv")
        gtq_df.to_csv(gtq_file, sep='\t', index=False, encoding='utf-8')
        print(f"Saved query-document pairs to {gtq_file}")

        # Document IDs are only needed for the training corpus.
        if split == 'train_small':
            id_df = generate_document_ids(processed_samples)
            id_file = os.path.join(args.output_dir, "ID_VAULT_t5_bm25_truncate_3.tsv")
            id_df.to_csv(id_file, sep='\t', index=False, encoding='utf-8')
            print(f"Saved document IDs to {id_file}")


if __name__ == "__main__":
    main()
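# With the default arguments and all three input splits present, a run is
# expected to leave the following files in data/the_vault/ (names follow the
# patterns used above):
#   DOC_VAULT_train.tsv, DOC_VAULT_validate.tsv, DOC_VAULT_test.tsv
#   GTQ_VAULT_train.tsv, GTQ_VAULT_dev.tsv, GTQ_VAULT_test.tsv
#   ID_VAULT_t5_bm25_truncate_3.tsv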