# universal-model-trainer / app/utils/dataset_utils.py
# Uploaded by vectorplasticity — "Add dataset utilities" (commit 16525fb, verified)
"""
Dataset Utilities - Helper functions for dataset operations
"""
import logging
from typing import Dict, Any, List, Optional, Tuple
from datasets import load_dataset, Dataset, DatasetDict
from transformers import AutoTokenizer
import json
import os
logger = logging.getLogger(__name__)
# Dataset column mappings for common datasets.
# Maps a HuggingFace dataset id to {role -> actual column name}. Most entries
# are identity mappings that document which columns each known dataset exposes.
DATASET_COLUMN_MAPPINGS = {
    # Language modeling
    "wikitext": {"text": "text"},
    # Extractive question answering
    "squad": {"question": "question", "context": "context", "answers": "answers"},
    "squad_v2": {"question": "question", "context": "context", "answers": "answers"},
    # Summarization (source/target pairs)
    "cnn_dailymail": {"article": "article", "highlights": "highlights"},
    "xsum": {"document": "document", "summary": "summary"},
    "samsum": {"dialogue": "dialogue", "summary": "summary"},
    "billsum": {"text": "text", "summary": "summary"},
    "aeslc": {"email_body": "email_body", "subject_line": "subject_line"},
    # Token classification / NER
    "conll2003": {"tokens": "tokens", "ner_tags": "ner_tags"},
    "wnut_17": {"tokens": "tokens", "ner_tags": "ner_tags"},
    # Text / sentence-pair classification
    "imdb": {"text": "text", "label": "label"},
    "yelp_polarity": {"text": "text", "label": "label"},
    "yelp_review_full": {"text": "text", "label": "label"},
    "sst2": {"sentence": "sentence", "label": "label"},
    "cola": {"sentence": "sentence", "label": "label"},
    "mnli": {"premise": "premise", "hypothesis": "hypothesis", "label": "label"},
    "qnli": {"question": "question", "sentence": "sentence", "label": "label"},
    "qqp": {"question1": "question1", "question2": "question2", "label": "label"},
    "mrpc": {"sentence1": "sentence1", "sentence2": "sentence2", "label": "label"},
    "stsb": {"sentence1": "sentence1", "sentence2": "sentence2", "label": "label"},
    # Benchmark collections: columns depend on the chosen config, so no fixed
    # mapping can be given here.
    "glue": {},
    "super_glue": {},
    "trec": {"text": "text", "label": "label"},
    "ag_news": {"text": "text", "label": "label"},
    "dbpedia_14": {"content": "content", "label": "label"},
    "20newsgroups": {"text": "text", "label": "label"},
}
# Task-specific dataset templates.
# For each supported task type: the default column roles expected by
# prepare_dataset_for_training, a coarse "format" tag, and example dataset ids.
# A role set to None means it must be resolved per-dataset (see
# get_dataset_columns_for_task).
TASK_DATASET_TEMPLATES = {
    "causal-lm": {
        "text_column": "text",
        "format": "causal",
        "examples": ["wikitext", "openwebtext", "the_pile", "c4", "oscar"],
    },
    "seq2seq": {
        # Input/target columns vary per dataset (e.g. article/highlights).
        "input_column": None,
        "target_column": None,
        "format": "seq2seq",
        "examples": ["cnn_dailymail", "xsum", "samsum", "billsum", "aeslc"],
    },
    "token-classification": {
        "tokens_column": "tokens",
        "labels_column": "ner_tags",
        "format": "token",
        "examples": ["conll2003", "wnut_17", "ontonotes5"],
    },
    "text-classification": {
        "text_column": "text",
        "label_column": "label",
        "format": "classification",
        "examples": ["imdb", "yelp_polarity", "sst2", "ag_news", "dbpedia_14"],
    },
    "question-answering": {
        "context_column": "context",
        "question_column": "question",
        "answers_column": "answers",
        "format": "qa",
        "examples": ["squad", "squad_v2", "natural_questions", "hotpotqa"],
    },
    "reasoning": {
        # Reasoning datasets are trained causally on input/target pairs.
        "input_column": "input",
        "target_column": "target",
        "format": "causal",
        "examples": ["gsm8k", "strategyqa", "aqua"],
    },
}
def get_dataset_info(dataset_name: str) -> Dict[str, Any]:
    """Fetch metadata for *dataset_name* from the HuggingFace Hub.

    Returns a dict with id/author/sha/downloads/tags/description/card_data,
    the list of repo files, and their summed size in bytes. On any failure
    (missing huggingface_hub package, network error, unknown dataset) a
    dict of the form ``{"error": message}`` is returned instead of raising.
    """
    try:
        # Imported lazily so this module loads even without huggingface_hub
        # installed. (The previously imported `dataset_info` name was unused.)
        from huggingface_hub import HfApi

        info = HfApi().dataset_info(dataset_name)
        # Normalize once: `siblings` may be None on some hub responses.
        siblings = info.siblings or []
        return {
            "id": info.id,
            "author": info.author,
            "sha": info.sha,
            "downloads": getattr(info, "downloads", 0),
            "tags": info.tags or [],
            "description": getattr(info, "description", ""),
            "card_data": getattr(info, "card_data", {}),
            "siblings": [s.rfilename for s in siblings],
            # File sizes may be missing (None); treat those as 0 bytes.
            "size_bytes": sum(getattr(s, "size", 0) or 0 for s in siblings),
        }
    except Exception as e:
        logger.error(f"Error getting dataset info for {dataset_name}: {e}")
        return {"error": str(e)}
def load_and_validate_dataset(
    dataset_name: str,
    config: Optional[str] = None,
    split: Optional[str] = None,
    trust_remote_code: bool = False,
) -> Tuple[Optional[DatasetDict], Optional[str]]:
    """Load a dataset from the Hub.

    Returns ``(dataset, None)`` on success or ``(None, error_message)`` on
    failure. A single returned split is wrapped into a one-entry
    ``DatasetDict`` under the "train" key.
    """
    load_kwargs: Dict[str, Any] = {"trust_remote_code": trust_remote_code}
    if config:
        load_kwargs["name"] = config
    if split:
        load_kwargs["split"] = split
    try:
        loaded = load_dataset(dataset_name, **load_kwargs)
        # Normalize a bare split into a dict of splits.
        if isinstance(loaded, Dataset):
            loaded = DatasetDict({"train": loaded})
        return loaded, None
    except Exception as e:
        logger.error(f"Error loading dataset {dataset_name}: {e}")
        return None, str(e)
def get_dataset_schema(dataset: DatasetDict) -> Dict[str, Any]:
    """Describe a dataset: its splits, row counts, columns, and feature types.

    Column details (dtype, label names, feature class) are taken from the
    first split only; row counts cover every split.
    """
    if not dataset:
        return {}
    split_names = list(dataset.keys())
    reference = dataset[split_names[0]]
    schema: Dict[str, Any] = {
        "splits": split_names,
        "columns": {},
        "num_rows": {name: len(ds) for name, ds in dataset.items()},
        "features": {},
    }
    for column in reference.column_names:
        details: Dict[str, Any] = {"name": column}
        feature = reference.features.get(column)
        if feature:
            details["dtype"] = str(feature.dtype) if hasattr(feature, "dtype") else str(type(feature))
            if hasattr(feature, "names"):
                # Class-label style features expose their label vocabulary.
                details["label_names"] = list(feature.names)
            details["feature_type"] = type(feature).__name__
        schema["columns"][column] = details
        schema["features"][column] = str(feature) if feature else "unknown"
    return schema
def detect_task_type(dataset_name: str, dataset: DatasetDict) -> str:
    """Guess the most likely training task from the dataset's column names.

    Rules are checked in priority order; anything unrecognized falls back to
    causal language modeling. Returns "unknown" for an empty/missing dataset.
    """
    if not dataset:
        return "unknown"
    cols = set(dataset[next(iter(dataset.keys()))].column_names)
    if {"tokens", "ner_tags"} <= cols:
        return "token-classification"
    if {"question", "context"} <= cols:
        return "question-answering"
    if cols & {"article", "document"}:
        return "seq2seq"
    if {"text", "label"} <= cols:
        return "text-classification"
    if "text" in cols and len(cols) <= 3:
        return "causal-lm"
    if cols & {"dialogue", "summary"}:
        return "seq2seq"
    # Input/target pairs and everything else default to causal LM.
    return "causal-lm"
def get_dataset_columns_for_task(
    dataset: DatasetDict,
    task_type: str
) -> Dict[str, str]:
    """Decide which dataset columns fill each role required by *task_type*.

    Candidate names are tried in priority order against the first split's
    columns; roles with no match are simply omitted from the result.
    """
    if not dataset:
        return {}
    cols = set(dataset[next(iter(dataset))].column_names)

    def first_match(candidates):
        # First candidate actually present in the dataset, or None.
        return next((name for name in candidates if name in cols), None)

    mapping: Dict[str, str] = {}
    if task_type == "causal-lm":
        text = first_match(["text", "content", "document", "article", "input"])
        if text:
            mapping["text_column"] = text
        elif len(cols) == 1:
            # Single-column dataset: treat that column as the text.
            mapping["text_column"] = next(iter(cols))
    elif task_type == "seq2seq":
        source = first_match(["article", "document", "text", "input", "dialogue"])
        target = first_match(["highlights", "summary", "target", "output", "subject_line"])
        if source:
            mapping["input_column"] = source
        if target:
            mapping["target_column"] = target
    elif task_type == "token-classification":
        tokens = first_match(["tokens", "words"])
        tags = first_match(["ner_tags", "labels", "tags"])
        if tokens:
            mapping["tokens_column"] = tokens
        if tags:
            mapping["labels_column"] = tags
    elif task_type == "text-classification":
        text = first_match(["text", "sentence", "content", "review"])
        label = first_match(["label", "labels", "class", "category"])
        if text:
            mapping["text_column"] = text
        if label:
            mapping["label_column"] = label
    elif task_type == "question-answering":
        if "context" in cols:
            mapping["context_column"] = "context"
        if "question" in cols:
            mapping["question_column"] = "question"
        # Both spellings are scanned without a break, so when both exist the
        # later "answer" wins — kept identical to the original behavior.
        for name in ["answers", "answer"]:
            if name in cols:
                mapping["answers_column"] = name
    return mapping
def prepare_dataset_for_training(
    dataset: DatasetDict,
    tokenizer: Any,
    task_type: str,
    column_mapping: Dict[str, str],
    max_length: int = 512,
    padding: str = "max_length",
    truncation: bool = True,
) -> Tuple[DatasetDict, Dict[str, Any]]:
    """Tokenize every split of *dataset* according to *task_type*.

    Args:
        dataset: Splits to tokenize.
        tokenizer: HuggingFace tokenizer (callable on text or token lists).
        task_type: One of "causal-lm", "seq2seq", "token-classification",
            "text-classification", "question-answering".
        column_mapping: Role -> column-name mapping, as produced by
            ``get_dataset_columns_for_task``.
        max_length / padding / truncation: Forwarded to the tokenizer.

    Returns:
        ``(tokenized_datasets, stats)`` where ``stats`` carries per-split
        sample counts.

    NOTE(review): the ``avg_length`` and ``removed_samples`` stats keys are
    initialized below but never populated anywhere in this function.
    """
    stats = {
        "original_samples": {},
        "processed_samples": {},
        "avg_length": {},
        "removed_samples": {},
    }
    def tokenize_function(examples, text_col=None, target_col=None):
        """Tokenize one batch; dispatches on the enclosing ``task_type``.

        The ``text_col``/``target_col`` parameters are never supplied by the
        ``map()`` call below — each branch re-reads its columns from
        ``column_mapping`` instead (the parameters are shadowed or ignored).
        """
        if task_type == "causal-lm":
            text_col = column_mapping.get("text_column", "text")
            if text_col not in examples:
                # Column missing: pass the batch through untouched.
                return examples
            outputs = tokenizer(
                examples[text_col],
                padding=padding,
                truncation=truncation,
                max_length=max_length,
                return_tensors=None,
            )
            # Standard causal-LM objective: labels are a copy of the inputs.
            outputs["labels"] = outputs["input_ids"].copy()
            return outputs
        elif task_type == "seq2seq":
            input_col = column_mapping.get("input_column")
            target_col = column_mapping.get("target_column")
            if not input_col or not target_col:
                raise ValueError(f"Missing columns for seq2seq: {column_mapping}")
            model_inputs = tokenizer(
                examples[input_col],
                padding=padding,
                truncation=truncation,
                max_length=max_length,
            )
            # NOTE(review): ``as_target_tokenizer()`` is deprecated in recent
            # transformers releases (``text_target=`` is the replacement);
            # kept as-is here to preserve behavior with the installed version.
            with tokenizer.as_target_tokenizer():
                labels = tokenizer(
                    examples[target_col],
                    padding=padding,
                    truncation=truncation,
                    max_length=max_length,
                )
            model_inputs["labels"] = labels["input_ids"]
            return model_inputs
        elif task_type == "token-classification":
            tokens_col = column_mapping.get("tokens_column", "tokens")
            labels_col = column_mapping.get("labels_column", "ner_tags")
            if tokens_col not in examples or labels_col not in examples:
                return examples
            tokenized_inputs = tokenizer(
                examples[tokens_col],
                padding=padding,
                truncation=truncation,
                max_length=max_length,
                is_split_into_words=True,
            )
            labels = []
            for i, label in enumerate(examples[labels_col]):
                # Map each subword token back to its source word so the
                # word-level tags can be aligned with the tokenized sequence.
                word_ids = tokenized_inputs.word_ids(batch_index=i)
                previous_word_idx = None
                label_ids = []
                for word_idx in word_ids:
                    if word_idx is None:
                        # Special tokens: -100 is ignored by the loss.
                        label_ids.append(-100)
                    elif word_idx != previous_word_idx:
                        # Only the first subword of a word carries its tag.
                        label_ids.append(label[word_idx])
                    else:
                        # Remaining subwords are masked out of the loss.
                        label_ids.append(-100)
                    previous_word_idx = word_idx
                labels.append(label_ids)
            tokenized_inputs["labels"] = labels
            return tokenized_inputs
        elif task_type == "text-classification":
            text_col = column_mapping.get("text_column", "text")
            if text_col not in examples:
                return examples
            tokenized = tokenizer(
                examples[text_col],
                padding=padding,
                truncation=truncation,
                max_length=max_length,
            )
            # Add labels if present
            label_col = column_mapping.get("label_column", "label")
            if label_col in examples:
                tokenized["labels"] = examples[label_col]
            return tokenized
        elif task_type == "question-answering":
            context_col = column_mapping.get("context_column", "context")
            question_col = column_mapping.get("question_column", "question")
            answers_col = column_mapping.get("answers_column", "answers")
            # Question and context are encoded as a pair.
            tokenized = tokenizer(
                examples[question_col],
                examples[context_col],
                padding=padding,
                truncation=truncation,
                max_length=max_length,
            )
            # Process answers
            if answers_col in examples:
                # Simplified - full implementation would compute token positions
                # NOTE(review): every example gets placeholder [0, 0] labels;
                # real QA training needs answer start/end token indices.
                tokenized["labels"] = [[0, 0] for _ in examples[answers_col]]
            return tokenized
        # Unknown task type: return the batch unmodified.
        return examples
    # Tokenize each split
    tokenized_datasets = DatasetDict()
    for split_name, split_ds in dataset.items():
        stats["original_samples"][split_name] = len(split_ds)
        # Remove columns that aren't needed (keep label-related columns)
        # NOTE(review): columns listed in column_mapping.values() are also
        # kept, so the raw source columns survive into the tokenized output.
        remove_columns = []
        for col in split_ds.column_names:
            if col not in ["labels", "label", "input_ids", "attention_mask"]:
                if col not in column_mapping.values():
                    remove_columns.append(col)
        tokenized = split_ds.map(
            tokenize_function,
            batched=True,
            remove_columns=remove_columns,
            desc=f"Tokenizing {split_name}",
        )
        tokenized_datasets[split_name] = tokenized
        stats["processed_samples"][split_name] = len(tokenized)
    return tokenized_datasets, stats
def split_dataset(
    dataset: DatasetDict,
    train_split: float = 0.9,
    val_split: float = 0.1,
    seed: int = 42,
) -> DatasetDict:
    """Ensure the dataset has both train and validation splits.

    If a "validation" split already exists the dataset is returned unchanged.
    Otherwise the "train" split is divided, with ``val_split`` of it held out
    as "validation".

    Args:
        dataset: Dataset to (possibly) split.
        train_split: Unused; retained for backward compatibility. The train
            portion is simply whatever remains after ``val_split``.
        val_split: Fraction of the train split to hold out for validation.
        seed: RNG seed passed to ``train_test_split`` for reproducibility.

    Returns:
        A DatasetDict with "train" and "validation" splits when possible,
        otherwise the input unchanged.
    """
    if "validation" in dataset:
        return dataset
    if "train" in dataset:
        # Renamed local from `split_dataset`, which shadowed this function.
        parts = dataset["train"].train_test_split(
            test_size=val_split,
            seed=seed,
        )
        return DatasetDict({
            "train": parts["train"],
            "validation": parts["test"],
        })
    # No train split to divide; return as-is.
    return dataset
def sample_dataset(
    dataset: DatasetDict,
    n_samples: int,
    split: str = "train",
    seed: int = 42,
) -> DatasetDict:
    """Return a copy of *dataset* whose *split* is shuffled and capped.

    The chosen split is shuffled with *seed* and truncated to at most
    ``n_samples`` rows; other splits are carried over unchanged. If the split
    does not exist the dataset is returned as-is.
    """
    if split not in dataset:
        return dataset
    source = dataset[split]
    keep = min(n_samples, len(source))
    trimmed = source.shuffle(seed=seed).select(range(keep))
    updated = dict(dataset)
    updated[split] = trimmed
    return DatasetDict(updated)
def get_label_list(dataset: DatasetDict, label_column: str = "label") -> List[str]:
    """Collect the label names for *label_column* from the first split that has them.

    Prefers explicit label names on the feature; otherwise, for class-label
    style features, falls back to stringified indices covering the observed
    label range. Returns [] when nothing can be determined.
    """
    if not dataset:
        return []
    for _, split_ds in dataset.items():
        if label_column not in split_ds.column_names:
            continue
        feature = split_ds.features.get(label_column)
        if not feature:
            continue
        if hasattr(feature, "names"):
            return list(feature.names)
        if hasattr(feature, "int2str"):
            # No names attribute: infer the label count from the data itself.
            observed = set(split_ds[label_column])
            return [str(i) for i in range(max(observed) + 1)]
    return []
def estimate_dataset_size(dataset: DatasetDict) -> Dict[str, Any]:
    """Roughly estimate the dataset's in-memory footprint.

    Uses a crude heuristic of ~1KB per sample, so the size is indicative
    only. An empty/missing dataset yields zero counts (and no "splits" key).
    """
    if not dataset:
        return {"total_samples": 0, "estimated_size_mb": 0}
    per_split = {name: len(split) for name, split in dataset.items()}
    total = sum(per_split.values())
    return {
        "total_samples": total,
        "estimated_size_mb": round(total * 0.001, 2),
        "splits": per_split,
    }
def validate_dataset_for_task(
    dataset: DatasetDict,
    task_type: str,
    column_mapping: Dict[str, str],
) -> Tuple[bool, List[str]]:
    """Check that every column required by *task_type* is mapped and present.

    Returns ``(ok, issues)`` where *issues* lists a message for each missing
    or unmapped column. Columns are checked against the first split only.
    """
    if not dataset:
        return False, ["Dataset is empty or could not be loaded"]
    columns = set(dataset[next(iter(dataset))].column_names)
    issues: List[str] = []

    def require(key: str, message: str) -> None:
        # Record an issue when the role is unmapped or its column is absent.
        col = column_mapping.get(key)
        if not col or col not in columns:
            issues.append(message)

    if task_type == "causal-lm":
        require("text_column", f"Missing text column. Found: {columns}")
    elif task_type == "seq2seq":
        require("input_column", f"Missing input column. Found: {columns}")
        require("target_column", f"Missing target column. Found: {columns}")
    elif task_type == "token-classification":
        require("tokens_column", f"Missing tokens column. Found: {columns}")
        require("labels_column", f"Missing labels column. Found: {columns}")
    elif task_type == "text-classification":
        require("text_column", f"Missing text column. Found: {columns}")
        require("label_column", f"Missing label column. Found: {columns}")
    elif task_type == "question-answering":
        for col_key in ("context_column", "question_column", "answers_column"):
            require(col_key, f"Missing {col_key}. Found: {columns}")
    return len(issues) == 0, issues