| """ |
| Dataset Utilities - Helper functions for dataset operations |
| """ |
|
|
| import logging |
| from typing import Dict, Any, List, Optional, Tuple |
| from datasets import load_dataset, Dataset, DatasetDict |
| from transformers import AutoTokenizer |
| import json |
| import os |
|
|
# Module-level logger; handler/level configuration is left to the application.
logger = logging.getLogger(__name__)


# Known HuggingFace Hub datasets mapped to the column names this module uses.
# Keys equal values because these datasets already expose their canonical
# column names; the dict still serves as an explicit per-dataset allow-list.
# "glue" / "super_glue" are empty because their columns depend on the chosen
# sub-config and must be resolved per-config at load time.
DATASET_COLUMN_MAPPINGS = {
    "wikitext": {"text": "text"},
    "squad": {"question": "question", "context": "context", "answers": "answers"},
    "squad_v2": {"question": "question", "context": "context", "answers": "answers"},
    "cnn_dailymail": {"article": "article", "highlights": "highlights"},
    "xsum": {"document": "document", "summary": "summary"},
    "samsum": {"dialogue": "dialogue", "summary": "summary"},
    "billsum": {"text": "text", "summary": "summary"},
    "aeslc": {"email_body": "email_body", "subject_line": "subject_line"},
    "conll2003": {"tokens": "tokens", "ner_tags": "ner_tags"},
    "wnut_17": {"tokens": "tokens", "ner_tags": "ner_tags"},
    "imdb": {"text": "text", "label": "label"},
    "yelp_polarity": {"text": "text", "label": "label"},
    "yelp_review_full": {"text": "text", "label": "label"},
    "sst2": {"sentence": "sentence", "label": "label"},
    "cola": {"sentence": "sentence", "label": "label"},
    "mnli": {"premise": "premise", "hypothesis": "hypothesis", "label": "label"},
    "qnli": {"question": "question", "sentence": "sentence", "label": "label"},
    "qqp": {"question1": "question1", "question2": "question2", "label": "label"},
    "mrpc": {"sentence1": "sentence1", "sentence2": "sentence2", "label": "label"},
    "stsb": {"sentence1": "sentence1", "sentence2": "sentence2", "label": "label"},
    "glue": {},
    "super_glue": {},
    "trec": {"text": "text", "label": "label"},
    "ag_news": {"text": "text", "label": "label"},
    "dbpedia_14": {"content": "content", "label": "label"},
    "20newsgroups": {"text": "text", "label": "label"},
}
|
|
| |
# Per-task defaults: which logical columns each task type expects, the
# preprocessing "format" tag, and example datasets suited to the task.
# A column value of None means there is no universal default and the actual
# column must be resolved per dataset (see get_dataset_columns_for_task).
TASK_DATASET_TEMPLATES = {
    "causal-lm": {
        "text_column": "text",
        "format": "causal",
        "examples": ["wikitext", "openwebtext", "the_pile", "c4", "oscar"],
    },
    "seq2seq": {
        # Input/target columns vary widely across seq2seq datasets.
        "input_column": None,
        "target_column": None,
        "format": "seq2seq",
        "examples": ["cnn_dailymail", "xsum", "samsum", "billsum", "aeslc"],
    },
    "token-classification": {
        "tokens_column": "tokens",
        "labels_column": "ner_tags",
        "format": "token",
        "examples": ["conll2003", "wnut_17", "ontonotes5"],
    },
    "text-classification": {
        "text_column": "text",
        "label_column": "label",
        "format": "classification",
        "examples": ["imdb", "yelp_polarity", "sst2", "ag_news", "dbpedia_14"],
    },
    "question-answering": {
        "context_column": "context",
        "question_column": "question",
        "answers_column": "answers",
        "format": "qa",
        "examples": ["squad", "squad_v2", "natural_questions", "hotpotqa"],
    },
    "reasoning": {
        "input_column": "input",
        "target_column": "target",
        # Reasoning datasets are formatted as causal LM text pairs.
        "format": "causal",
        "examples": ["gsm8k", "strategyqa", "aqua"],
    },
}
|
|
|
|
def get_dataset_info(dataset_name: str) -> Dict[str, Any]:
    """Fetch metadata for a dataset from the HuggingFace Hub.

    Args:
        dataset_name: Hub identifier, e.g. ``"imdb"`` or ``"user/dataset"``.

    Returns:
        A dict with id/author/sha/downloads/tags/description/card_data,
        the list of repo filenames, and the summed file size in bytes.
        On any failure (missing dependency, network error, unknown
        dataset) returns ``{"error": <message>}`` instead of raising.
    """
    try:
        # Imported lazily so the module imports cleanly even when
        # huggingface_hub is absent; the unused `dataset_info` import
        # from the original version has been dropped.
        from huggingface_hub import HfApi

        info = HfApi().dataset_info(dataset_name)
        siblings = info.siblings or []

        return {
            "id": info.id,
            "author": info.author,
            "sha": info.sha,
            "downloads": getattr(info, "downloads", 0),
            "tags": info.tags or [],
            "description": getattr(info, "description", ""),
            "card_data": getattr(info, "card_data", {}),
            "siblings": [s.rfilename for s in siblings],
            # Some siblings report size=None; count those as 0 bytes.
            "size_bytes": sum(getattr(s, "size", 0) or 0 for s in siblings),
        }
    except Exception as e:
        logger.error(f"Error getting dataset info for {dataset_name}: {e}")
        return {"error": str(e)}
|
|
|
|
def load_and_validate_dataset(
    dataset_name: str,
    config: Optional[str] = None,
    split: Optional[str] = None,
    trust_remote_code: bool = False,
) -> Tuple[Optional[DatasetDict], Optional[str]]:
    """Load a dataset from the Hub and normalize it to a DatasetDict.

    Args:
        dataset_name: Hub identifier passed to ``load_dataset``.
        config: Optional dataset config name (forwarded as ``name=``).
        split: Optional split selector (e.g. ``"train"``, ``"test"``).
        trust_remote_code: Forwarded to ``load_dataset``.

    Returns:
        ``(dataset, None)`` on success, ``(None, error_message)`` on
        failure — errors are logged, never raised.
    """
    try:
        kwargs: Dict[str, Any] = {"trust_remote_code": trust_remote_code}
        if config:
            kwargs["name"] = config
        if split:
            kwargs["split"] = split

        dataset = load_dataset(dataset_name, **kwargs)

        # Requesting a concrete split yields a bare Dataset. Normalize to
        # a DatasetDict keyed by the requested split name (previously the
        # key was hard-coded to "train", mislabeling e.g. a "test" split).
        if isinstance(dataset, Dataset):
            dataset = DatasetDict({split or "train": dataset})

        return dataset, None

    except Exception as e:
        logger.error(f"Error loading dataset {dataset_name}: {e}")
        return None, str(e)
|
|
|
|
def get_dataset_schema(dataset: DatasetDict) -> Dict[str, Any]:
    """Describe the splits, columns, row counts, and feature types.

    Column/feature details are read from the first split only; row
    counts are collected for every split. Returns {} for an empty or
    missing dataset.
    """
    if not dataset:
        return {}

    reference = dataset[next(iter(dataset.keys()))]

    schema: Dict[str, Any] = {
        "splits": list(dataset.keys()),
        "columns": {},
        "num_rows": {name: len(split) for name, split in dataset.items()},
        "features": {},
    }

    for column in reference.column_names:
        feature = reference.features.get(column)
        details: Dict[str, Any] = {"name": column}
        if feature:
            has_dtype = hasattr(feature, "dtype")
            details["dtype"] = str(feature.dtype) if has_dtype else str(type(feature))
            if hasattr(feature, "names"):
                # ClassLabel-style features expose human-readable names.
                details["label_names"] = list(feature.names)
            details["feature_type"] = type(feature).__name__
        schema["columns"][column] = details
        schema["features"][column] = str(feature) if feature else "unknown"

    return schema
|
|
|
|
def detect_task_type(dataset_name: str, dataset: DatasetDict) -> str:
    """Heuristically infer the task type from a dataset's column names.

    ``dataset_name`` is currently unused but kept for interface
    stability. Rules are ordered; the first match wins, and anything
    unmatched falls back to "causal-lm".
    """
    if not dataset:
        return "unknown"

    cols = set(dataset[next(iter(dataset))].column_names)

    if {"tokens", "ner_tags"} <= cols:
        return "token-classification"
    if {"question", "context"} <= cols:
        return "question-answering"
    if cols & {"article", "document"}:
        return "seq2seq"
    if {"text", "label"} <= cols:
        return "text-classification"
    if "text" in cols and len(cols) <= 3:
        return "causal-lm"
    if cols & {"dialogue", "summary"}:
        return "seq2seq"
    if {"input", "target"} <= cols:
        return "causal-lm"

    # Default when nothing matched.
    return "causal-lm"
|
|
|
|
def get_dataset_columns_for_task(
    dataset: DatasetDict,
    task_type: str
) -> Dict[str, str]:
    """Resolve the column mapping a task needs from a dataset's columns.

    Searches the first split's columns against an ordered list of common
    names per role. Returns an empty dict for an empty dataset or an
    unrecognized task type; roles that cannot be resolved are omitted.
    """
    if not dataset:
        return {}

    cols = set(dataset[next(iter(dataset))].column_names)
    mapping: Dict[str, str] = {}

    def first_present(candidates):
        # First candidate column actually present in the dataset, else None.
        for candidate in candidates:
            if candidate in cols:
                return candidate
        return None

    if task_type == "causal-lm":
        text = first_present(["text", "content", "document", "article", "input"])
        if text is not None:
            mapping["text_column"] = text
        elif len(cols) == 1:
            # Single-column dataset: treat the lone column as the text.
            mapping["text_column"] = next(iter(cols))

    elif task_type == "seq2seq":
        source = first_present(["article", "document", "text", "input", "dialogue"])
        if source is not None:
            mapping["input_column"] = source
        target = first_present(["highlights", "summary", "target", "output", "subject_line"])
        if target is not None:
            mapping["target_column"] = target

    elif task_type == "token-classification":
        tokens = first_present(["tokens", "words"])
        if tokens is not None:
            mapping["tokens_column"] = tokens
        labels = first_present(["ner_tags", "labels", "tags"])
        if labels is not None:
            mapping["labels_column"] = labels

    elif task_type == "text-classification":
        text = first_present(["text", "sentence", "content", "review"])
        if text is not None:
            mapping["text_column"] = text
        label = first_present(["label", "labels", "class", "category"])
        if label is not None:
            mapping["label_column"] = label

    elif task_type == "question-answering":
        if "context" in cols:
            mapping["context_column"] = "context"
        if "question" in cols:
            mapping["question_column"] = "question"
        # Deliberately no break: when both exist, the LAST match wins,
        # mirroring the original loop's behavior.
        for candidate in ["answers", "answer"]:
            if candidate in cols:
                mapping["answers_column"] = candidate

    return mapping
|
|
|
|
def prepare_dataset_for_training(
    dataset: DatasetDict,
    tokenizer: Any,
    task_type: str,
    column_mapping: Dict[str, str],
    max_length: int = 512,
    padding: str = "max_length",
    truncation: bool = True,
) -> Tuple[DatasetDict, Dict[str, Any]]:
    """Prepare dataset for training by tokenizing.

    Tokenizes every split of *dataset* according to *task_type*, using
    the column names in *column_mapping*, and drops raw columns that are
    not part of the tokenized output.

    Args:
        dataset: Splits to tokenize.
        tokenizer: A HuggingFace tokenizer. For token-classification it
            must support ``word_ids`` (i.e. a fast tokenizer).
        task_type: One of "causal-lm", "seq2seq", "token-classification",
            "text-classification", "question-answering". Any other value
            leaves examples untouched.
        column_mapping: Task-specific column names (see
            ``get_dataset_columns_for_task``).
        max_length: Forwarded to the tokenizer.
        padding: Forwarded to the tokenizer (default "max_length").
        truncation: Forwarded to the tokenizer.

    Returns:
        The tokenized DatasetDict and a stats dict.
        NOTE(review): the "avg_length" and "removed_samples" stats are
        initialized below but never populated.

    Raises:
        ValueError: For seq2seq when input/target columns are missing
            from *column_mapping*.
    """

    stats: Dict[str, Any] = {
        "original_samples": {},
        "processed_samples": {},
        "avg_length": {},
        "removed_samples": {},
    }

    # NOTE(review): the text_col/target_col parameters are never passed by
    # the .map() call below and are immediately overwritten where used.
    def tokenize_function(examples, text_col=None, target_col=None):
        """Tokenize function based on task type."""
        if task_type == "causal-lm":
            text_col = column_mapping.get("text_column", "text")
            # Pass batches through unchanged when the text column is absent.
            if text_col not in examples:
                return examples

            outputs = tokenizer(
                examples[text_col],
                padding=padding,
                truncation=truncation,
                max_length=max_length,
                return_tensors=None,
            )
            # Causal LM: labels are a copy of the inputs.
            outputs["labels"] = outputs["input_ids"].copy()
            return outputs

        elif task_type == "seq2seq":
            input_col = column_mapping.get("input_column")
            target_col = column_mapping.get("target_column")

            if not input_col or not target_col:
                raise ValueError(f"Missing columns for seq2seq: {column_mapping}")

            model_inputs = tokenizer(
                examples[input_col],
                padding=padding,
                truncation=truncation,
                max_length=max_length,
            )

            # NOTE(review): as_target_tokenizer() is deprecated in newer
            # transformers releases (replaced by text_target=) — confirm
            # the pinned library version still supports it.
            with tokenizer.as_target_tokenizer():
                labels = tokenizer(
                    examples[target_col],
                    padding=padding,
                    truncation=truncation,
                    max_length=max_length,
                )

            model_inputs["labels"] = labels["input_ids"]
            return model_inputs

        elif task_type == "token-classification":
            tokens_col = column_mapping.get("tokens_column", "tokens")
            labels_col = column_mapping.get("labels_column", "ner_tags")

            if tokens_col not in examples or labels_col not in examples:
                return examples

            tokenized_inputs = tokenizer(
                examples[tokens_col],
                padding=padding,
                truncation=truncation,
                max_length=max_length,
                is_split_into_words=True,
            )

            # Align word-level tags to subword tokens: only the first
            # subword of each word keeps its tag; special tokens and
            # continuation subwords get -100 (conventionally ignored by
            # the HF loss).
            labels = []
            for i, label in enumerate(examples[labels_col]):
                word_ids = tokenized_inputs.word_ids(batch_index=i)
                previous_word_idx = None
                label_ids = []
                for word_idx in word_ids:
                    if word_idx is None:
                        # Special token (CLS/SEP/pad): no label.
                        label_ids.append(-100)
                    elif word_idx != previous_word_idx:
                        # First subword of a new word: keep the word's tag.
                        label_ids.append(label[word_idx])
                    else:
                        # Continuation subword: ignored.
                        label_ids.append(-100)
                    previous_word_idx = word_idx
                labels.append(label_ids)

            tokenized_inputs["labels"] = labels
            return tokenized_inputs

        elif task_type == "text-classification":
            text_col = column_mapping.get("text_column", "text")
            if text_col not in examples:
                return examples

            tokenized = tokenizer(
                examples[text_col],
                padding=padding,
                truncation=truncation,
                max_length=max_length,
            )

            # Carry the class labels through when present.
            label_col = column_mapping.get("label_column", "label")
            if label_col in examples:
                tokenized["labels"] = examples[label_col]

            return tokenized

        elif task_type == "question-answering":
            context_col = column_mapping.get("context_column", "context")
            question_col = column_mapping.get("question_column", "question")
            answers_col = column_mapping.get("answers_column", "answers")

            # Standard QA encoding: question first, context second.
            tokenized = tokenizer(
                examples[question_col],
                examples[context_col],
                padding=padding,
                truncation=truncation,
                max_length=max_length,
            )

            if answers_col in examples:
                # NOTE(review): placeholder [start, end] = [0, 0] labels;
                # real answer-span positions are not computed here.
                tokenized["labels"] = [[0, 0] for _ in examples[answers_col]]

            return tokenized

        # Unknown task type: pass examples through unchanged.
        return examples

    tokenized_datasets = DatasetDict()
    for split_name, split_ds in dataset.items():
        stats["original_samples"][split_name] = len(split_ds)

        # Drop every raw column except tokenizer outputs, labels, and the
        # columns named in the mapping (which the tokenize step consumes).
        remove_columns = []
        for col in split_ds.column_names:
            if col not in ["labels", "label", "input_ids", "attention_mask"]:
                if col not in column_mapping.values():
                    remove_columns.append(col)

        tokenized = split_ds.map(
            tokenize_function,
            batched=True,
            remove_columns=remove_columns,
            desc=f"Tokenizing {split_name}",
        )

        tokenized_datasets[split_name] = tokenized
        stats["processed_samples"][split_name] = len(tokenized)

    return tokenized_datasets, stats
|
|
|
|
def split_dataset(
    dataset: DatasetDict,
    train_split: float = 0.9,
    val_split: float = 0.1,
    seed: int = 42,
) -> DatasetDict:
    """Ensure the dataset has a validation split.

    If "validation" already exists the dataset is returned unchanged.
    Otherwise *val_split* of the "train" split is carved off into a new
    "validation" split; any other splits (e.g. "test") are preserved —
    the previous version silently dropped them. If there is no "train"
    split either, the dataset is returned as-is.

    Args:
        dataset: Splits to (possibly) re-partition.
        train_split: Unused; kept for backward compatibility. The train
            fraction is implied by ``1 - val_split``.
        val_split: Fraction of "train" moved to "validation".
        seed: Shuffling seed for the split.

    Returns:
        A DatasetDict containing "train" and "validation" (plus any
        pre-existing other splits), or the input unchanged.
    """
    if "validation" in dataset:
        return dataset

    if "train" in dataset:
        # Named `parts` — the original local shadowed this function's name.
        parts = dataset["train"].train_test_split(
            test_size=val_split,
            seed=seed,
        )
        result = {name: split for name, split in dataset.items() if name != "train"}
        result["train"] = parts["train"]
        result["validation"] = parts["test"]
        return DatasetDict(result)

    return dataset
|
|
|
|
def sample_dataset(
    dataset: DatasetDict,
    n_samples: int,
    split: str = "train",
    seed: int = 42,
) -> DatasetDict:
    """Down-sample one split for quick experiments.

    Returns a new DatasetDict in which *split* holds at most
    *n_samples* rows (seeded shuffle, then prefix selection). If the
    split is absent the dataset is returned unchanged.
    """
    if split not in dataset:
        return dataset

    target = dataset[split]
    take = min(n_samples, len(target))
    shrunk = target.shuffle(seed=seed).select(range(take))

    splits = dict(dataset)
    splits[split] = shrunk
    return DatasetDict(splits)
|
|
|
|
def get_label_list(dataset: DatasetDict, label_column: str = "label") -> List[str]:
    """Return label names from the first split exposing *label_column*.

    Prefers ClassLabel-style features (``.names``); failing that, for
    features that support ``int2str``, derives "0".."max" strings from
    the observed integer labels. Returns [] when no labels can be
    determined.
    """
    if not dataset:
        return []

    for split in dataset.values():
        if label_column not in split.column_names:
            continue
        feature = split.features.get(label_column)
        if feature is None:
            continue
        if hasattr(feature, "names"):
            return list(feature.names)
        if hasattr(feature, "int2str"):
            # No explicit names: synthesize indices up to the largest
            # label value seen in this split.
            observed = set(split[label_column])
            return [str(i) for i in range(max(observed) + 1)]

    return []
|
|
|
|
def estimate_dataset_size(dataset: DatasetDict) -> Dict[str, Any]:
    """Rough in-memory size estimate for a dataset.

    Uses a flat heuristic of ~1 KB (0.001 MB) per sample — it does not
    inspect actual row contents. Returns zeros for an empty dataset.
    """
    if not dataset:
        return {"total_samples": 0, "estimated_size_mb": 0}

    per_split = {name: len(split) for name, split in dataset.items()}
    total = sum(per_split.values())

    return {
        "total_samples": total,
        # Heuristic: ~0.001 MB per sample, rounded to 2 decimals.
        "estimated_size_mb": round(total * 0.001, 2),
        "splits": per_split,
    }
|
|
|
|
def validate_dataset_for_task(
    dataset: DatasetDict,
    task_type: str,
    column_mapping: Dict[str, str],
) -> Tuple[bool, List[str]]:
    """Check that the mapped columns for *task_type* exist in *dataset*.

    Columns are checked against the first split. Returns ``(ok, issues)``
    where *issues* is a human-readable list of missing-column messages;
    an unrecognized task type produces no issues.
    """
    if not dataset:
        return False, ["Dataset is empty or could not be loaded"]

    columns = set(dataset[next(iter(dataset))].column_names)
    issues: List[str] = []

    def require(mapping_key: str, label: str) -> None:
        # Record an issue when the mapping lacks the key or the mapped
        # column is not present in the dataset.
        col = column_mapping.get(mapping_key)
        if not col or col not in columns:
            issues.append(f"Missing {label}. Found: {columns}")

    if task_type == "causal-lm":
        require("text_column", "text column")
    elif task_type == "seq2seq":
        require("input_column", "input column")
        require("target_column", "target column")
    elif task_type == "token-classification":
        require("tokens_column", "tokens column")
        require("labels_column", "labels column")
    elif task_type == "text-classification":
        require("text_column", "text column")
        require("label_column", "label column")
    elif task_type == "question-answering":
        for key in ("context_column", "question_column", "answers_column"):
            require(key, key)

    return not issues, issues