"""
Dataset Utilities - Helper functions for dataset operations.

Loading, inspection, task detection, tokenization and validation helpers for
HuggingFace ``datasets`` objects.
"""

import logging
from typing import Dict, Any, List, Optional, Tuple
from datasets import load_dataset, Dataset, DatasetDict
from transformers import AutoTokenizer
import json
import os

logger = logging.getLogger(__name__)

# Dataset column mappings for common datasets
DATASET_COLUMN_MAPPINGS = {
    "wikitext": {"text": "text"},
    "squad": {"question": "question", "context": "context", "answers": "answers"},
    "squad_v2": {"question": "question", "context": "context", "answers": "answers"},
    "cnn_dailymail": {"article": "article", "highlights": "highlights"},
    "xsum": {"document": "document", "summary": "summary"},
    "samsum": {"dialogue": "dialogue", "summary": "summary"},
    "billsum": {"text": "text", "summary": "summary"},
    "aeslc": {"email_body": "email_body", "subject_line": "subject_line"},
    "conll2003": {"tokens": "tokens", "ner_tags": "ner_tags"},
    "wnut_17": {"tokens": "tokens", "ner_tags": "ner_tags"},
    "imdb": {"text": "text", "label": "label"},
    "yelp_polarity": {"text": "text", "label": "label"},
    "yelp_review_full": {"text": "text", "label": "label"},
    "sst2": {"sentence": "sentence", "label": "label"},
    "cola": {"sentence": "sentence", "label": "label"},
    "mnli": {"premise": "premise", "hypothesis": "hypothesis", "label": "label"},
    "qnli": {"question": "question", "sentence": "sentence", "label": "label"},
    "qqp": {"question1": "question1", "question2": "question2", "label": "label"},
    "mrpc": {"sentence1": "sentence1", "sentence2": "sentence2", "label": "label"},
    "stsb": {"sentence1": "sentence1", "sentence2": "sentence2", "label": "label"},
    "glue": {},
    "super_glue": {},
    "trec": {"text": "text", "label": "label"},
    "ag_news": {"text": "text", "label": "label"},
    "dbpedia_14": {"content": "content", "label": "label"},
    "20newsgroups": {"text": "text", "label": "label"},
}

# Task-specific dataset templates
TASK_DATASET_TEMPLATES = {
    "causal-lm": {
        "text_column": "text",
        "format": "causal",
        "examples": ["wikitext", "openwebtext", "the_pile", "c4", "oscar"],
    },
    "seq2seq": {
        "input_column": None,
        "target_column": None,
        "format": "seq2seq",
        "examples": ["cnn_dailymail", "xsum", "samsum", "billsum", "aeslc"],
    },
    "token-classification": {
        "tokens_column": "tokens",
        "labels_column": "ner_tags",
        "format": "token",
        "examples": ["conll2003", "wnut_17", "ontonotes5"],
    },
    "text-classification": {
        "text_column": "text",
        "label_column": "label",
        "format": "classification",
        "examples": ["imdb", "yelp_polarity", "sst2", "ag_news", "dbpedia_14"],
    },
    "question-answering": {
        "context_column": "context",
        "question_column": "question",
        "answers_column": "answers",
        "format": "qa",
        "examples": ["squad", "squad_v2", "natural_questions", "hotpotqa"],
    },
    "reasoning": {
        "input_column": "input",
        "target_column": "target",
        "format": "causal",
        "examples": ["gsm8k", "strategyqa", "aqua"],
    },
}

# For each task: role -> candidate column names, most preferred first.
_TASK_COLUMN_CANDIDATES: Dict[str, Dict[str, List[str]]] = {
    "causal-lm": {
        "text_column": ["text", "content", "document", "article", "input"],
    },
    "seq2seq": {
        "input_column": ["article", "document", "text", "input", "dialogue"],
        "target_column": ["highlights", "summary", "target", "output", "subject_line"],
    },
    "token-classification": {
        "tokens_column": ["tokens", "words"],
        "labels_column": ["ner_tags", "labels", "tags"],
    },
    "text-classification": {
        "text_column": ["text", "sentence", "content", "review"],
        "label_column": ["label", "labels", "class", "category"],
    },
    "question-answering": {
        "context_column": ["context"],
        "question_column": ["question"],
        "answers_column": ["answers", "answer"],
    },
}

# For each task supported by prepare_dataset_for_training: the mapping keys that
# must point at an existing column, with the wording used in issue messages.
_REQUIRED_TASK_COLUMNS: Dict[str, List[Tuple[str, str]]] = {
    "causal-lm": [("text_column", "text column")],
    "seq2seq": [("input_column", "input column"), ("target_column", "target column")],
    "token-classification": [
        ("tokens_column", "tokens column"),
        ("labels_column", "labels column"),
    ],
    "text-classification": [
        ("text_column", "text column"),
        ("label_column", "label column"),
    ],
    "question-answering": [
        ("context_column", "context_column"),
        ("question_column", "question_column"),
        ("answers_column", "answers_column"),
    ],
}

# Upper bound on rows inspected when estimating the average token length.
_AVG_LENGTH_SAMPLE_SIZE = 1000


def get_dataset_info(dataset_name: str) -> Dict[str, Any]:
    """Get information about a dataset from the HuggingFace Hub.

    Args:
        dataset_name: Hub repository id of the dataset.

    Returns:
        A dict with ``id``, ``author``, ``sha``, ``downloads``, ``tags``,
        ``description``, ``card_data``, ``siblings`` (file names) and
        ``size_bytes`` (sum of the sibling file sizes). On any failure the
        error is logged and ``{"error": <message>}`` is returned instead of
        raising.
    """
    try:
        from huggingface_hub import HfApi

        api = HfApi()
        # The Hub only reports per-file sizes when files_metadata=True;
        # without it every sibling.size is None and size_bytes is always 0.
        info = api.dataset_info(dataset_name, files_metadata=True)
        siblings = info.siblings or []

        return {
            "id": info.id,
            "author": info.author,
            "sha": info.sha,
            "downloads": getattr(info, "downloads", 0),
            "tags": info.tags or [],
            "description": getattr(info, "description", ""),
            "card_data": getattr(info, "card_data", {}),
            "siblings": [s.rfilename for s in siblings],
            "size_bytes": sum(getattr(s, "size", 0) or 0 for s in siblings),
        }
    except Exception as e:
        logger.error(f"Error getting dataset info for {dataset_name}: {e}")
        return {"error": str(e)}


def load_and_validate_dataset(
    dataset_name: str,
    config: Optional[str] = None,
    split: Optional[str] = None,
    trust_remote_code: bool = False,
) -> Tuple[Optional[DatasetDict], Optional[str]]:
    """Load a dataset with ``datasets.load_dataset`` and normalise it to a DatasetDict.

    Args:
        dataset_name: Name or path passed to ``load_dataset``.
        config: Optional configuration name (``name=`` in ``load_dataset``).
        split: Optional split expression; when given, ``load_dataset``
            returns a single ``Dataset``.
        trust_remote_code: Forwarded to ``load_dataset``.

    Returns:
        ``(dataset, None)`` on success, ``(None, error_message)`` on failure.
        A single ``Dataset`` result is wrapped under the key ``"train"``
        regardless of which split was requested.
    """
    kwargs: Dict[str, Any] = {"trust_remote_code": trust_remote_code}
    if config:
        kwargs["name"] = config
    if split:
        kwargs["split"] = split

    try:
        dataset = load_dataset(dataset_name, **kwargs)
    except Exception as e:
        logger.error(f"Error loading dataset {dataset_name}: {e}")
        return None, str(e)

    # A single split comes back as a Dataset; callers always get a DatasetDict.
    if isinstance(dataset, Dataset):
        dataset = DatasetDict({"train": dataset})

    return dataset, None


def get_dataset_schema(dataset: DatasetDict) -> Dict[str, Any]:
    """Describe the splits, row counts and column features of a dataset.

    Column information is taken from the first split; row counts cover all
    splits.

    Returns:
        ``{}`` for an empty/None dataset, otherwise a dict with ``splits``,
        ``columns`` (per-column name/dtype/label_names/feature_type),
        ``num_rows`` (per split) and ``features`` (per-column ``str`` of the
        feature, or ``"unknown"``).
    """
    if not dataset:
        return {}

    first_split = next(iter(dataset))
    ds = dataset[first_split]

    schema: Dict[str, Any] = {
        "splits": list(dataset.keys()),
        "columns": {},
        "num_rows": {name: len(split_ds) for name, split_ds in dataset.items()},
        "features": {},
    }

    for col in ds.column_names:
        col_info: Dict[str, Any] = {"name": col}
        feature = ds.features.get(col)

        if feature:
            col_info["dtype"] = (
                str(feature.dtype) if hasattr(feature, "dtype") else str(type(feature))
            )
            if hasattr(feature, "names"):
                col_info["label_names"] = list(feature.names)
            col_info["feature_type"] = type(feature).__name__

        schema["columns"][col] = col_info
        schema["features"][col] = str(feature) if feature else "unknown"

    return schema


def detect_task_type(dataset_name: str, dataset: DatasetDict) -> str:
    """Guess the task type of a dataset from the column names of its first split.

    Args:
        dataset_name: Currently unused; kept for API compatibility.
        dataset: The loaded dataset.

    Returns:
        ``"unknown"`` for an empty dataset, otherwise one of
        ``"token-classification"``, ``"question-answering"``, ``"seq2seq"``,
        ``"text-classification"`` or ``"causal-lm"`` (the fallback).
    """
    if not dataset:
        return "unknown"

    first_split = next(iter(dataset))
    columns = set(dataset[first_split].column_names)

    if "tokens" in columns and "ner_tags" in columns:
        return "token-classification"
    if "question" in columns and "context" in columns:
        return "question-answering"
    if "article" in columns or "document" in columns:
        return "seq2seq"
    # Same text candidates get_dataset_columns_for_task accepts, so sst2/cola
    # ("sentence") and dbpedia_14 ("content") are recognised too.
    if columns & {"text", "sentence", "content", "review"} and "label" in columns:
        return "text-classification"
    # Must precede the generic "text" check: billsum is text/summary/title
    # (3 columns) and is a summarization dataset, not plain LM text.
    if "dialogue" in columns or "summary" in columns:
        return "seq2seq"
    if "text" in columns and len(columns) <= 3:
        return "causal-lm"
    if "input" in columns and "target" in columns:
        return "causal-lm"

    # Default
    return "causal-lm"


def _first_present(columns: Any, candidates: List[str]) -> Optional[str]:
    """Return the first name in *candidates* that is in *columns*, else None."""
    return next((name for name in candidates if name in columns), None)


def get_dataset_columns_for_task(
    dataset: DatasetDict, task_type: str
) -> Dict[str, str]:
    """Map task roles (e.g. ``text_column``) to column names of the first split.

    For each role, the first matching name from a fixed preference list is
    used; roles with no match are omitted. For ``causal-lm`` a dataset with
    exactly one column uses that column as text. Unknown task types give ``{}``.
    """
    if not dataset:
        return {}

    first_split = next(iter(dataset))
    columns = set(dataset[first_split].column_names)

    mapping: Dict[str, str] = {}
    for role, candidates in _TASK_COLUMN_CANDIDATES.get(task_type, {}).items():
        found = _first_present(columns, candidates)
        if found is not None:
            mapping[role] = found

    if task_type == "causal-lm" and not mapping and len(columns) == 1:
        mapping["text_column"] = next(iter(columns))

    return mapping


def _mask_padding_labels(
    label_rows: List[List[int]], attention_rows: Optional[List[List[int]]]
) -> List[List[int]]:
    """Copy *label_rows*, replacing ids at padded positions (attention 0) with -100.

    -100 is the default ``ignore_index`` of PyTorch's cross-entropy loss, so
    padding does not contribute to the loss. Without an attention mask the
    ids are copied unchanged.
    """
    if attention_rows is None:
        return [list(row) for row in label_rows]
    return [
        [token if keep else -100 for token, keep in zip(row, mask)]
        for row, mask in zip(label_rows, attention_rows)
    ]


def _tokenize_targets(tokenizer: Any, targets: Any, **kwargs: Any) -> Any:
    """Tokenize seq2seq targets via ``text_target``, falling back to the legacy API.

    ``as_target_tokenizer`` is deprecated in transformers. Tokenizers that
    predate ``text_target`` require a positional ``text`` argument, so the
    keyword-only call raises TypeError there and the old context manager is
    used instead.
    """
    try:
        return tokenizer(text_target=targets, **kwargs)
    except TypeError:
        with tokenizer.as_target_tokenizer():
            return tokenizer(targets, **kwargs)


def _answer_token_span(
    answer: Any, offsets: List[Tuple[int, int]], sequence_ids: List[Optional[int]]
) -> Tuple[int, int]:
    """Map a SQuAD-style answer to (start_token, end_token) indices.

    *answer* is expected to be ``{"text": [...], "answer_start": [...]}``
    (the first answer is used). Returns ``(0, 0)`` when there is no answer
    (e.g. unanswerable squad_v2 rows), the format is unrecognised, or the
    answer was cut off by truncation.
    """
    if not isinstance(answer, dict):
        return 0, 0
    starts = answer.get("answer_start") or []
    texts = answer.get("text") or []
    if not starts or not texts:
        return 0, 0

    start_char = starts[0]
    end_char = start_char + len(texts[0])

    # Tokens belonging to the context (second sequence) have sequence id 1.
    context_tokens = [idx for idx, seq_id in enumerate(sequence_ids) if seq_id == 1]
    if not context_tokens:
        return 0, 0
    first, last = context_tokens[0], context_tokens[-1]

    if offsets[first][0] > start_char or offsets[last][1] < end_char:
        return 0, 0

    token_start = first
    while token_start <= last and offsets[token_start][0] <= start_char:
        token_start += 1
    token_end = last
    while token_end >= first and offsets[token_end][1] >= end_char:
        token_end -= 1
    return token_start - 1, token_end + 1


def _tokenize_question_answering(
    examples: Any,
    tokenizer: Any,
    column_mapping: Dict[str, str],
    max_length: int,
    padding: Any,
    truncation: Any,
) -> Any:
    """Tokenize (question, context) pairs and add start/end answer positions.

    Requires a fast tokenizer (``return_offsets_mapping`` / ``sequence_ids``)
    when an answers column is present.
    """
    context_col = column_mapping.get("context_column", "context")
    question_col = column_mapping.get("question_column", "question")
    answers_col = column_mapping.get("answers_column", "answers")
    has_answers = answers_col in examples

    tokenized = tokenizer(
        examples[question_col],
        examples[context_col],
        padding=padding,
        # Truncate only the context so the question stays intact and the
        # answer span can be located inside the context segment.
        truncation="only_second" if truncation else False,
        max_length=max_length,
        return_offsets_mapping=has_answers,
    )
    if not has_answers:
        return tokenized

    offset_rows = tokenized.pop("offset_mapping")
    start_positions: List[int] = []
    end_positions: List[int] = []
    for i, answer in enumerate(examples[answers_col]):
        start, end = _answer_token_span(answer, offset_rows[i], tokenized.sequence_ids(i))
        start_positions.append(start)
        end_positions.append(end)

    # Names expected by HuggingFace *ForQuestionAnswering models.
    tokenized["start_positions"] = start_positions
    tokenized["end_positions"] = end_positions
    return tokenized


def _average_token_length(
    tokenized: Dataset, sample_size: int = _AVG_LENGTH_SAMPLE_SIZE
) -> float:
    """Mean non-padding token count over the first *sample_size* rows.

    Uses ``attention_mask`` when present, otherwise ``len(input_ids)``.
    Returns 0.0 when the split is empty or has neither column.
    """
    n_rows = min(sample_size, len(tokenized))
    if n_rows == 0:
        return 0.0

    head = tokenized.select(range(n_rows))
    if "attention_mask" in head.column_names:
        lengths = [sum(mask) for mask in head["attention_mask"]]
    elif "input_ids" in head.column_names:
        lengths = [len(ids) for ids in head["input_ids"]]
    else:
        return 0.0
    return round(sum(lengths) / n_rows, 2)


def prepare_dataset_for_training(
    dataset: DatasetDict,
    tokenizer: Any,
    task_type: str,
    column_mapping: Dict[str, str],
    max_length: int = 512,
    padding: str = "max_length",
    truncation: bool = True,
) -> Tuple[DatasetDict, Dict[str, Any]]:
    """Tokenize every split of *dataset* for *task_type*.

    Args:
        dataset: Splits to tokenize.
        tokenizer: A HuggingFace tokenizer. Token classification (``word_ids``)
            and question answering (offset mapping) need a *fast* tokenizer.
        task_type: ``"causal-lm"``, ``"seq2seq"``, ``"token-classification"``,
            ``"text-classification"`` or ``"question-answering"``; any other
            value leaves batches unchanged.
        column_mapping: Role -> column name, e.g. from
            ``get_dataset_columns_for_task``.
        max_length, padding, truncation: Forwarded to the tokenizer.

    Returns:
        ``(tokenized_datasets, stats)``. ``stats`` maps each split name to its
        original, processed and removed sample counts, plus an average
        non-padding token length estimated on at most the first
        ``_AVG_LENGTH_SAMPLE_SIZE`` rows.

    Raises:
        ValueError: For ``seq2seq`` when the input or target column is not mapped.
    """
    stats: Dict[str, Dict[str, Any]] = {
        "original_samples": {},
        "processed_samples": {},
        "avg_length": {},
        "removed_samples": {},
    }

    def tokenize_function(examples):
        """Tokenize one batch according to task_type."""
        if task_type == "causal-lm":
            text_col = column_mapping.get("text_column", "text")
            if text_col not in examples:
                return examples
            outputs = tokenizer(
                examples[text_col],
                padding=padding,
                truncation=truncation,
                max_length=max_length,
                return_tensors=None,
            )
            # Labels are the inputs themselves, with padding ignored by the loss.
            outputs["labels"] = _mask_padding_labels(
                outputs["input_ids"], outputs.get("attention_mask")
            )
            return outputs

        if task_type == "seq2seq":
            input_col = column_mapping.get("input_column")
            target_col = column_mapping.get("target_column")
            if not input_col or not target_col:
                raise ValueError(f"Missing columns for seq2seq: {column_mapping}")
            model_inputs = tokenizer(
                examples[input_col],
                padding=padding,
                truncation=truncation,
                max_length=max_length,
            )
            labels = _tokenize_targets(
                tokenizer,
                examples[target_col],
                padding=padding,
                truncation=truncation,
                max_length=max_length,
            )
            model_inputs["labels"] = _mask_padding_labels(
                labels["input_ids"], labels.get("attention_mask")
            )
            return model_inputs

        if task_type == "token-classification":
            tokens_col = column_mapping.get("tokens_column", "tokens")
            labels_col = column_mapping.get("labels_column", "ner_tags")
            if tokens_col not in examples or labels_col not in examples:
                return examples
            tokenized_inputs = tokenizer(
                examples[tokens_col],
                padding=padding,
                truncation=truncation,
                max_length=max_length,
                is_split_into_words=True,
            )
            labels = []
            for i, label in enumerate(examples[labels_col]):
                word_ids = tokenized_inputs.word_ids(batch_index=i)
                previous_word_idx = None
                label_ids = []
                for word_idx in word_ids:
                    if word_idx is None:
                        # Special / padding token.
                        label_ids.append(-100)
                    elif word_idx != previous_word_idx:
                        # First sub-token of a word carries the word's tag.
                        label_ids.append(label[word_idx])
                    else:
                        # Later sub-tokens of the same word are ignored.
                        label_ids.append(-100)
                    previous_word_idx = word_idx
                labels.append(label_ids)
            tokenized_inputs["labels"] = labels
            return tokenized_inputs

        if task_type == "text-classification":
            text_col = column_mapping.get("text_column", "text")
            if text_col not in examples:
                return examples
            tokenized = tokenizer(
                examples[text_col],
                padding=padding,
                truncation=truncation,
                max_length=max_length,
            )
            label_col = column_mapping.get("label_column", "label")
            if label_col in examples:
                tokenized["labels"] = examples[label_col]
            return tokenized

        if task_type == "question-answering":
            return _tokenize_question_answering(
                examples, tokenizer, column_mapping, max_length, padding, truncation
            )

        return examples

    # Columns kept through map(): label columns, tokenizer outputs and every
    # column named in the mapping; everything else is dropped.
    keep_columns = {"labels", "label", "input_ids", "attention_mask"}
    keep_columns.update(column_mapping.values())

    tokenized_datasets = DatasetDict()
    for split_name, split_ds in dataset.items():
        stats["original_samples"][split_name] = len(split_ds)

        remove_columns = [col for col in split_ds.column_names if col not in keep_columns]
        tokenized = split_ds.map(
            tokenize_function,
            batched=True,
            remove_columns=remove_columns,
            desc=f"Tokenizing {split_name}",
        )

        tokenized_datasets[split_name] = tokenized
        stats["processed_samples"][split_name] = len(tokenized)
        stats["removed_samples"][split_name] = len(split_ds) - len(tokenized)
        stats["avg_length"][split_name] = _average_token_length(tokenized)

    return tokenized_datasets, stats


def split_dataset(
    dataset: DatasetDict,
    train_split: float = 0.9,
    val_split: float = 0.1,
    seed: int = 42,
) -> DatasetDict:
    """Carve a validation split out of ``"train"`` when none exists.

    Args:
        dataset: Input dataset.
        train_split: Currently unused; the train portion is ``1 - val_split``.
        val_split: Fraction of ``"train"`` moved to ``"validation"``.
        seed: Shuffle seed for ``train_test_split``.

    Returns:
        *dataset* unchanged if it already has ``"validation"`` or has no
        ``"train"``; otherwise a new DatasetDict with only ``"train"`` and
        ``"validation"`` (other splits are not carried over).
    """
    if "validation" in dataset:
        return dataset

    if "train" not in dataset:
        return dataset

    parts = dataset["train"].train_test_split(test_size=val_split, seed=seed)
    return DatasetDict({
        "train": parts["train"],
        "validation": parts["test"],
    })


def sample_dataset(
    dataset: DatasetDict,
    n_samples: int,
    split: str = "train",
    seed: int = 42,
) -> DatasetDict:
    """Replace *split* with a random subset of at most *n_samples* rows.

    Other splits are left untouched. If *split* is missing, *dataset* is
    returned unchanged.
    """
    if split not in dataset:
        return dataset

    source = dataset[split]
    sampled = source.shuffle(seed=seed).select(range(min(n_samples, len(source))))
    result = dict(dataset)
    result[split] = sampled
    return DatasetDict(result)


def _infer_labels_from_values(values: Any) -> List[str]:
    """Derive label names from raw label values of a column.

    Nested lists (per-token labels) are flattened. Non-negative integer
    labels give ``["0", ..., str(max)]``; string labels give their sorted
    unique values. Anything else (e.g. float regression targets) gives ``[]``.
    """
    flat: List[Any] = []
    for value in values:
        if isinstance(value, (list, tuple)):
            flat.extend(value)
        elif value is not None:
            flat.append(value)

    if not flat:
        return []
    if all(isinstance(v, int) and not isinstance(v, bool) for v in flat):
        # Negative ids (e.g. -1 for unlabeled GLUE test rows) are not classes;
        # range() is empty when every label is negative.
        return [str(i) for i in range(max(flat) + 1)]
    if all(isinstance(v, str) for v in flat):
        return sorted(set(flat))
    return []


def get_label_list(dataset: DatasetDict, label_column: str = "label") -> List[str]:
    """Return the label names of *label_column* from the first split that has it.

    ``ClassLabel`` names are used directly, including when wrapped in a
    ``Sequence`` (e.g. conll2003 ``ner_tags``). Otherwise names are inferred
    from the column values. Returns ``[]`` when no split has the column or no
    labels can be derived.
    """
    if not dataset:
        return []

    for split_ds in dataset.values():
        if label_column not in split_ds.column_names:
            continue

        feature = split_ds.features.get(label_column)
        # Sequence/List features expose their element feature as ``.feature``.
        for candidate in (feature, getattr(feature, "feature", None)):
            names = getattr(candidate, "names", None)
            if names:
                return list(names)

        return _infer_labels_from_values(split_ds[label_column])

    return []


def estimate_dataset_size(dataset: DatasetDict) -> Dict[str, Any]:
    """Roughly estimate dataset size, assuming ~1 KB per sample.

    Returns:
        ``total_samples``, ``estimated_size_mb`` (rounded to 2 decimals) and
        per-split row counts under ``splits`` (empty for an empty dataset).
    """
    if not dataset:
        return {"total_samples": 0, "estimated_size_mb": 0, "splits": {}}

    split_sizes = {name: len(split) for name, split in dataset.items()}
    total_samples = sum(split_sizes.values())
    # Rough estimation: ~1KB per sample for text
    estimated_size_mb = total_samples * 0.001

    return {
        "total_samples": total_samples,
        "estimated_size_mb": round(estimated_size_mb, 2),
        "splits": split_sizes,
    }


def validate_dataset_for_task(
    dataset: DatasetDict,
    task_type: str,
    column_mapping: Dict[str, str],
) -> Tuple[bool, List[str]]:
    """Check that *column_mapping* points at existing columns for *task_type*.

    Only the first split's columns are checked. Task types that
    ``prepare_dataset_for_training`` cannot handle are reported as issues.

    Returns:
        ``(is_valid, issues)`` where ``issues`` lists human-readable problems.
    """
    if not dataset:
        return False, ["Dataset is empty or could not be loaded"]

    first_split = next(iter(dataset))
    columns = set(dataset[first_split].column_names)

    required = _REQUIRED_TASK_COLUMNS.get(task_type)
    if required is None:
        return False, [f"Unsupported task type: {task_type}"]

    issues: List[str] = []
    for mapping_key, description in required:
        col = column_mapping.get(mapping_key)
        if not col or col not in columns:
            issues.append(f"Missing {description}. Found: {columns}")

    return len(issues) == 0, issues