| import random |
| from dataclasses import dataclass, field |
| from functools import partial |
| from pathlib import Path |
|
|
| import jax |
| import jax.numpy as jnp |
| import numpy as np |
| from braceexpand import braceexpand |
| from datasets import Dataset, load_dataset |
|
|
| from .model.text import TextNormalizer |
|
|
|
|
@dataclass
class Dataset:
    """Loads caption/image-encoding data via 🤗 `datasets`, optionally filters
    and normalizes captions, tokenizes them, and serves device-ready batches
    through `dataloader`.

    NOTE(review): this dataclass shadows `datasets.Dataset` imported at the
    top of the file, so the `Dataset` type hints on the fields below resolve
    to this class itself — confirm the shadowing is intentional.
    """

    dataset_repo_or_path: str  # hub repo id, local directory, or builder name
    train_file: str = None  # optional file/pattern (brace expansion supported)
    validation_file: str = None  # optional file/pattern (brace expansion supported)
    streaming: bool = True  # use streaming (iterable) datasets
    use_auth_token: bool = False  # forwarded to load_dataset for private repos
    text_column: str = "caption"  # column holding the raw caption text
    encoding_column: str = "encoding"  # column holding precomputed image encodings
    max_train_samples: int = None  # truncate the train split to this many samples
    max_eval_samples: int = None  # truncate the eval split to this many samples
    preprocessing_num_workers: int = None  # num_proc for non-streaming map/filter
    overwrite_cache: bool = False  # ignore the datasets cache when preprocessing
    do_train: bool = False  # require and prepare a "train" split
    do_eval: bool = True  # require and prepare a "validation" split
    seed_dataset: int = None  # shuffling/blanking seed; drawn at random when None
    shard_by_host: bool = False  # shard training files across JAX hosts
    blank_caption_prob: float = 0.0  # probability of blanking a training caption
    clip_score_column: str = "clip_score"  # column used for clip-score filtering
    min_clip_score: float = None  # keep examples with clip score >= this
    max_clip_score: float = None  # keep examples with clip score <= this
    filter_column: str = None  # optional exact-match filter column
    filter_value: str = None  # value filter_column must equal to keep an example
    multi_eval_ds: bool = False  # treat subdirectories as additional eval splits
    # The fields below are populated by __post_init__ / preprocess, not by
    # the constructor.
    train_dataset: Dataset = field(init=False)
    eval_dataset: Dataset = field(init=False)
    other_eval_datasets: list = field(init=False)
    rng_dataset: jnp.ndarray = field(init=False)
    multi_hosts: bool = field(init=False)

    def __post_init__(self):
        """Resolve the seed and data files, then load the raw dataset splits."""
        if self.seed_dataset is None:
            # draw a fresh seed so every run gets a different shuffling order
            self.seed_dataset = random.randint(0, 2**32 - 1)
        # numpy RNG used for caption blanking
        self.np_rng = np.random.default_rng(self.seed_dataset)
        self.multi_hosts = jax.process_count() > 1
        if self.blank_caption_prob:
            assert (
                self.streaming is True
            ), "blank_caption_prob can only be used in streaming mode"
        if self.train_file is not None or self.validation_file is not None:
            # expand patterns such as "data-{000..128}.parquet" into file lists
            for k in ["train_file", "validation_file"]:
                f = getattr(self, k)
                if isinstance(f, str):
                    setattr(self, k, list(braceexpand(f)))
            # when sharding by host, each JAX process keeps only its own
            # interleaved slice of the training files
            if (
                isinstance(self.train_file, list)
                and self.multi_hosts
                and self.shard_by_host
            ):
                self.train_file = self.train_file[
                    jax.process_index() :: jax.process_count()
                ]
            data_files = {
                "train": self.train_file,
                "validation": self.validation_file,
            }
        else:
            data_files = None

        if self.multi_eval_ds:
            # each subdirectory of the dataset path becomes its own split of
            # parquet files
            assert Path(
                self.dataset_repo_or_path
            ).is_dir(), f"{self.dataset_repo_or_path} is not a directory, required for multi_eval_ds"
            data_files = {
                split.name: [str(f) for f in split.glob("*.parquet")]
                for split in Path(self.dataset_repo_or_path).glob("*")
            }
            # accept "valid" as an alias for the canonical "validation" split
            if "valid" in data_files:
                data_files["validation"] = data_files["valid"]
                del data_files["valid"]
            # the files are then loaded through the generic parquet builder
            self.dataset_repo_or_path = "parquet"

        # load the raw (unprocessed) dataset
        dataset = load_dataset(
            self.dataset_repo_or_path,
            data_files=data_files,
            streaming=self.streaming,
            use_auth_token=self.use_auth_token,
        )
        if self.do_train:
            if "train" not in dataset:
                raise ValueError("Training requires a training dataset")
            self.train_dataset = dataset["train"]
            if self.max_train_samples is not None:
                # streaming datasets truncate with take(); map-style datasets
                # use select()
                self.train_dataset = (
                    self.train_dataset.take(self.max_train_samples)
                    if self.streaming
                    else self.train_dataset.select(range(self.max_train_samples))
                )
        if self.do_eval:
            if "validation" not in dataset:
                raise ValueError("Evaluating requires a validation dataset")
            self.eval_dataset = dataset["validation"]
            if self.max_eval_samples is not None:
                self.eval_dataset = (
                    self.eval_dataset.take(self.max_eval_samples)
                    if self.streaming
                    else self.eval_dataset.select(range(self.max_eval_samples))
                )
            # any remaining splits become additional eval datasets
            other_eval_splits = dataset.keys() - {"train", "validation"}
            self.other_eval_datasets = {
                split: dataset[split] for split in other_eval_splits
            }

    def preprocess(self, tokenizer, config):
        """Filter, normalize, blank and tokenize all loaded splits in place.

        Args:
            tokenizer: callable tokenizer applied to the caption column.
            config: model config providing `decoder_start_token_id`,
                `normalize_text` and `max_text_length`.
        """
        # relevant model configuration parameters
        decoder_start_token_id = config.decoder_start_token_id
        normalize_text = config.normalize_text
        max_length = config.max_text_length

        if self.streaming:
            # streaming datasets shuffle with a bounded buffer
            if hasattr(self, "train_dataset"):
                self.train_dataset = self.train_dataset.shuffle(
                    buffer_size=5000, seed=self.seed_dataset
                )
        else:
            # map-style datasets are reshuffled per epoch in dataloader()
            self.rng_dataset = jax.random.PRNGKey(self.seed_dataset)

        # drop examples outside the clip-score range or failing the
        # exact-match column filter
        partial_filter_function = partial(
            filter_function,
            filter_column=self.filter_column,
            filter_value=self.filter_value,
            clip_score_column=self.clip_score_column,
            min_clip_score=self.min_clip_score,
            max_clip_score=self.max_clip_score,
        )
        for ds in ["train_dataset", "eval_dataset"]:
            if hasattr(self, ds):
                setattr(
                    self,
                    ds,
                    (
                        getattr(self, ds).filter(partial_filter_function)
                        if self.streaming
                        else getattr(self, ds).filter(
                            partial_filter_function,
                            num_proc=self.preprocessing_num_workers,
                            load_from_cache_file=not self.overwrite_cache,
                            desc="Filtering datasets",
                        )
                    ),
                )
        if hasattr(self, "other_eval_datasets"):
            self.other_eval_datasets = {
                split: (
                    ds.filter(partial_filter_function)
                    if self.streaming
                    else ds.filter(
                        partial_filter_function,
                        num_proc=self.preprocessing_num_workers,
                        load_from_cache_file=not self.overwrite_cache,
                        desc="Filtering datasets",
                    )
                )
                for split, ds in self.other_eval_datasets.items()
            }

        # normalize the caption text when the model requires it
        if normalize_text:
            text_normalizer = TextNormalizer()
            partial_normalize_function = partial(
                normalize_function,
                text_column=self.text_column,
                text_normalizer=text_normalizer,
            )
            for ds in ["train_dataset", "eval_dataset"]:
                if hasattr(self, ds):
                    setattr(
                        self,
                        ds,
                        (
                            getattr(self, ds).map(partial_normalize_function)
                            if self.streaming
                            else getattr(self, ds).map(
                                partial_normalize_function,
                                num_proc=self.preprocessing_num_workers,
                                load_from_cache_file=not self.overwrite_cache,
                                desc="Normalizing datasets",
                            )
                        ),
                    )
            if hasattr(self, "other_eval_datasets"):
                self.other_eval_datasets = {
                    split: (
                        ds.map(partial_normalize_function)
                        if self.streaming
                        else ds.map(
                            partial_normalize_function,
                            num_proc=self.preprocessing_num_workers,
                            load_from_cache_file=not self.overwrite_cache,
                            desc="Normalizing datasets",
                        )
                    )
                    for split, ds in self.other_eval_datasets.items()
                }

        # randomly blank some training captions (training only)
        if self.blank_caption_prob:
            partial_blank_caption_function = partial(
                blank_caption_function,
                text_column=self.text_column,
                blank_caption_prob=self.blank_caption_prob,
                rng=self.np_rng,
            )
            if hasattr(self, "train_dataset"):
                self.train_dataset = (
                    self.train_dataset.map(partial_blank_caption_function)
                    if self.streaming
                    else self.train_dataset.map(
                        partial_blank_caption_function,
                        # NOTE(review): seed_dataset is always set in
                        # __post_init__, so num_proc is effectively always
                        # None here (single process presumably keeps the RNG
                        # sequence deterministic) — confirm the intent
                        num_proc=None
                        if self.seed_dataset
                        else self.preprocessing_num_workers,
                        load_from_cache_file=False,
                        desc="Blanking some captions",
                    )
                )

        # tokenize captions and build labels / decoder inputs
        partial_preprocess_function = partial(
            preprocess_function,
            tokenizer=tokenizer,
            text_column=self.text_column,
            encoding_column=self.encoding_column,
            max_length=max_length,
            decoder_start_token_id=decoder_start_token_id,
        )
        for ds in ["train_dataset", "eval_dataset"]:
            if hasattr(self, ds):
                setattr(
                    self,
                    ds,
                    (
                        getattr(self, ds).map(
                            partial_preprocess_function,
                            batched=True,
                            remove_columns=[
                                self.text_column,
                                self.encoding_column,
                            ],
                        )
                        if self.streaming
                        else getattr(self, ds).map(
                            partial_preprocess_function,
                            batched=True,
                            # NOTE(review): `ds` is the attribute *name*
                            # (a str) in this loop, so getattr(ds,
                            # "column_names") would raise AttributeError —
                            # looks like it should be
                            # getattr(self, ds).column_names; confirm the
                            # non-streaming path is exercised
                            remove_columns=getattr(ds, "column_names"),
                            num_proc=self.preprocessing_num_workers,
                            load_from_cache_file=not self.overwrite_cache,
                            desc="Preprocessing datasets",
                        )
                    ),
                )
        if hasattr(self, "other_eval_datasets"):
            self.other_eval_datasets = {
                split: (
                    ds.map(
                        partial_preprocess_function,
                        batched=True,
                        remove_columns=[
                            self.text_column,
                            self.encoding_column,
                        ],
                    )
                    if self.streaming
                    else ds.map(
                        partial_preprocess_function,
                        batched=True,
                        remove_columns=getattr(ds, "column_names"),
                        num_proc=self.preprocessing_num_workers,
                        load_from_cache_file=not self.overwrite_cache,
                        desc="Preprocessing datasets",
                    )
                )
                for split, ds in self.other_eval_datasets.items()
            }

    def dataloader(self, split, batch_size, epoch=None):
        """Return an iterator of batches for `split`.

        `split` is "train", "eval", or the name of one of the extra eval
        splits. Batches are dicts of `jnp` arrays of leading dimension
        `batch_size`.
        """

        def _dataloader_datasets_non_streaming(
            dataset: Dataset,
            rng: jax.random.PRNGKey = None,
        ):
            """
            Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
            Shuffle batches if rng is set.
            """
            steps_per_epoch = len(dataset) // batch_size

            if rng is not None:
                batch_idx = jax.random.permutation(rng, len(dataset))
            else:
                batch_idx = jnp.arange(len(dataset))

            # drop the incomplete last batch, then group indices per batch
            batch_idx = batch_idx[
                : steps_per_epoch * batch_size
            ]
            batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))

            for idx in batch_idx:
                batch = dataset[idx]
                batch = {k: jnp.array(v) for k, v in batch.items()}
                yield batch

        def _dataloader_datasets_streaming(
            dataset: Dataset,
            epoch: int,
        ):
            keys = ["input_ids", "attention_mask", "labels", "decoder_input_ids"]
            batch = {k: [] for k in keys}
            # In multi-host training the stream is looped forever —
            # presumably so no host runs out of data when shards are uneven;
            # the caller is expected to stop at the desired step. Otherwise
            # a single pass is made.
            first_loop = True
            while (self.multi_hosts and split == "train") or first_loop:
                if epoch is not None:
                    assert split == "train"
                    # reshuffle the underlying shards for the new epoch
                    dataset.set_epoch(epoch)
                    epoch += 1
                for item in dataset:
                    for k in keys:
                        batch[k].append(item[k])
                    # yield as soon as a full batch has been accumulated;
                    # a trailing partial batch is discarded
                    if len(batch[keys[0]]) == batch_size:
                        batch = {k: jnp.array(v) for k, v in batch.items()}
                        yield batch
                        batch = {k: [] for k in keys}
                first_loop = False

        if split == "train":
            ds = self.train_dataset
        elif split == "eval":
            ds = self.eval_dataset
        else:
            # any other name refers to one of the extra eval splits
            ds = self.other_eval_datasets[split]

        if self.streaming:
            return _dataloader_datasets_streaming(ds, epoch)
        else:
            if split == "train":
                self.rng_dataset, input_rng = jax.random.split(self.rng_dataset)
            # NOTE(review): for non-train splits `input_rng` is unbound here,
            # which would raise NameError in non-streaming mode — confirm
            # whether eval dataloaders are only used with streaming datasets
            return _dataloader_datasets_non_streaming(ds, input_rng)

    @property
    def length(self):
        """Return `(len_train_dataset, len_eval_dataset)`.

        Entries are None when the length is unknown (unbounded streaming
        split, or the split was not loaded).
        """
        len_train_dataset, len_eval_dataset = None, None
        if self.streaming:
            # streaming lengths are only known when explicitly capped
            if self.max_train_samples is not None:
                len_train_dataset = self.max_train_samples
            if self.max_eval_samples is not None:
                len_eval_dataset = self.max_eval_samples
        else:
            len_train_dataset = (
                len(self.train_dataset) if hasattr(self, "train_dataset") else None
            )
            len_eval_dataset = (
                len(self.eval_dataset) if hasattr(self, "eval_dataset") else None
            )
        return len_train_dataset, len_eval_dataset
|
|
|
|
def shift_tokens_right(input_ids: np.array, decoder_start_token_id: int):
    """
    Shift input ids one token to the right.

    Args:
        input_ids: 2D array of token ids, shape (batch, seq_len).
        decoder_start_token_id: id inserted at position 0 of every row.

    Returns:
        Array of the same shape and dtype as `input_ids` where row `i` is
        `[decoder_start_token_id, input_ids[i, 0], ..., input_ids[i, -2]]`.
    """
    # zeros_like preserves the integer dtype of the token ids; np.zeros(shape)
    # would silently upcast them to float64
    shifted_input_ids = np.zeros_like(input_ids)
    shifted_input_ids[:, 1:] = input_ids[:, :-1]
    shifted_input_ids[:, 0] = decoder_start_token_id
    return shifted_input_ids
|
|
|
|
def blank_caption_function(example, text_column, blank_caption_prob, rng=None):
    """Randomly replace a caption with the empty string.

    With probability `blank_caption_prob`, the value stored under
    `text_column` is overwritten with "". `rng` is an optional numpy
    Generator; when omitted, the global numpy RNG is used. The example is
    mutated in place and returned.
    """
    if blank_caption_prob:
        # draw only when blanking is enabled, so the RNG sequence matches
        draw = rng.random() if rng is not None else np.random.random()
        if draw < blank_caption_prob:
            example[text_column] = ""
    return example
|
|
|
|
def normalize_function(example, text_column, text_normalizer):
    """Apply `text_normalizer` to the caption stored under `text_column`.

    The example is mutated in place and returned.
    """
    raw_caption = example[text_column]
    example[text_column] = text_normalizer(raw_caption)
    return example
|
|
|
|
def filter_function(
    example,
    min_clip_score,
    max_clip_score,
    clip_score_column,
    filter_column,
    filter_value,
):
    """Return True when `example` passes all configured filters.

    Each bound of the clip-score range and the exact-match column filter is
    optional (None disables it). Columns are only accessed when the
    corresponding filter is enabled.
    """
    # clip-score range checks (each bound independently optional)
    if min_clip_score is not None:
        if example[clip_score_column] < min_clip_score:
            return False
    if max_clip_score is not None:
        if example[clip_score_column] > max_clip_score:
            return False
    # exact-match metadata filter
    if filter_column is not None:
        return example[filter_column] == filter_value
    return True
|
|
|
|
def preprocess_function(
    examples,
    tokenizer,
    text_column,
    encoding_column,
    max_length,
    decoder_start_token_id,
):
    """Tokenize a batch of captions and attach labels / decoder inputs.

    Args:
        examples: batch mapping with caption texts under `text_column` and
            precomputed image encodings under `encoding_column`.
        tokenizer: callable tokenizer producing numpy tensors.
        text_column: key of the caption texts.
        encoding_column: key of the image encodings.
        max_length: fixed tokenized caption length (padded/truncated).
        decoder_start_token_id: id prepended when shifting labels right.

    Returns:
        The tokenizer output extended with "labels" and "decoder_input_ids".
    """
    captions = examples[text_column]
    # tokenize to fixed-length numpy arrays so batches stack cleanly
    model_inputs = tokenizer(
        captions,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="np",
    )

    # the precomputed image encodings serve as the labels
    labels = np.asarray(examples[encoding_column])
    model_inputs["labels"] = labels

    # teacher forcing: the decoder sees the labels shifted right by one
    model_inputs["decoder_input_ids"] = shift_tokens_right(
        labels, decoder_start_token_id
    )

    return model_inputs
|
|