| import json |
| import logging |
| import os |
| import re |
| import shutil |
| import sys |
| from dataclasses import dataclass, field |
| from pathlib import Path |
| from typing import Any, Dict, List, Optional, Tuple, Union |
|
|
| import numpy as np |
| import torch |
| from accelerate import Accelerator, skip_first_batches |
| from accelerate.logging import get_logger |
| from datasets import DatasetDict, load_dataset |
| from torch.utils.data import DataLoader |
| from tqdm import tqdm |
| from transformers import ( |
| AutoModelForCausalLM, |
| AutoTokenizer, |
| BitsAndBytesConfig, |
| HfArgumentParser, |
| ) |
|
|
|
|
# Module-level logger obtained through accelerate's `get_logger`, created with the "INFO" log level.
logger = get_logger(__name__, log_level="INFO")
|
|
|
|
@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model and tokenizer are loaded for prompt annotation, how the weights
    are loaded (dtype, attention implementation, quantization) and how text is generated.

    `model_name_or_path` and `per_device_eval_batch_size` have no default and must always be provided.
    """

    model_name_or_path: str = field(
        metadata={"help": "The name of the model to use (via the transformers library) for the prompt annotation."},
    )
    per_device_eval_batch_size: int = field(
        metadata={"help": "The per-device batch size to use for inference."},
    )
    # Defaults to None, hence Optional.
    model_variant: Optional[str] = field(
        default=None,
        metadata={"help": "If specified load weights from `variant` filename, *e.g.* pytorch_model.<variant>.bin. "},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
    )
    torch_dtype: Optional[str] = field(
        default="float16",
        metadata={
            "help": (
                "Floating-point format in which the model weights should be initialized"
                " and the computations run. Choose one of `[float32, float16, bfloat16]`."
            )
        },
    )
    attn_implementation: Optional[str] = field(
        default="sdpa",
        metadata={"help": "Which attn type to use: ['eager', 'sdpa', 'flash_attention_2']"},
    )
    load_in_8bit: Optional[bool] = field(
        default=False, metadata={"help": "Whether to use 8-bit precision for inference."}
    )
    load_in_4bit: Optional[bool] = field(
        default=False, metadata={"help": "Whether to use 4-bit precision for inference."}
    )
    bnb_4bit_quant_type: Optional[str] = field(
        default="nf4", metadata={"help": "precise the quantization type (fp4 or nf4)"}
    )
    use_bnb_nested_quant: Optional[bool] = field(default=False, metadata={"help": "use nested quantization"})
    trust_remote_code: Optional[bool] = field(
        default=False,
        metadata={
            "help": (
                "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option "
                "should only be set to `True` for repositories you trust and in which you have read the code, as it will "
                "execute code present on the Hub on your local machine."
            )
        },
    )
    use_fast_tokenizer: Optional[bool] = field(
        default=True, metadata={"help": "Use fast tokenizer for encoding/decoding input ids"}
    )
    token: Optional[bool] = field(
        default=True,
        metadata={
            "help": "Whether or not to use an authentication token when loading/uploading from the Hugging Face Hub"
        },
    )
    do_sample: Optional[bool] = field(default=True, metadata={"help": "Whether to use sampling mode for generation"})
    temperature: Optional[float] = field(default=0.6, metadata={"help": "Temperature for sampling-based generation"})
    max_new_tokens: Optional[int] = field(
        default=256, metadata={"help": "Maximum number of new tokens during generation"}
    )
    torch_compile: Optional[bool] = field(
        default=False,
        metadata={
            "help": "Whether to compile the forward pass (not sampling) in generate. Only compatible with Gemma and LlaMA."
        },
    )
|
|
|
|
@dataclass
class DataArguments:
    """
    Arguments pertaining to the dataset that is loaded, how it is pre-processed, and where the generated
    results are checkpointed, saved and optionally pushed.

    Raises:
        ValueError: from `__post_init__` when `push_to_hub` is set without a `hub_dataset_id`.
    """

    # Required: this field has no default value.
    output_dir: str = field(
        metadata={
            "help": "Where to save the processed dataset to disk. If unspecified, uses a 'pretty' version of the "
            "original dataset name. E.g. 'facebook/voxpopuli' will be saved under 'voxpopuli'."
        },
    )
    # Defaults to None, hence Optional.
    dataset_name: Optional[str] = field(
        default=None,
        metadata={"help": "The name of the dataset to use (via the datasets library)"},
    )
    dataset_config_name: Optional[str] = field(
        default=None,
        metadata={"help": "The configuration name of the dataset to use (via the datasets library)."},
    )
    dataset_split_name: Optional[str] = field(
        default=None,
        metadata={"help": "The split name of the dataset to use (via the datasets library)."},
    )
    dataset_cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Path to cache directory for saving and loading datasets"},
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={"help": "Maximum number of samples for generation - use for debugging purposes."},
    )
    overwrite_cache: bool = field(
        default=False,
        metadata={"help": "Overwrite the cached training and evaluation sets"},
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    dataloader_num_workers: Optional[int] = field(
        default=0,
        metadata={"help": "The number of processes to use for the dataloader."},
    )
    push_to_hub: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether or not to push the processed dataset to the Hub."},
    )
    hub_dataset_id: Optional[str] = field(
        default=None,
        metadata={"help": "Repository namespace if pushing to the Hugging Face Hub."},
    )
    overwrite_output_dir: Optional[bool] = field(
        default=False,
        metadata={"help": "Overwrite the content of the output directory each time the script is run."},
    )
    save_steps: Optional[int] = field(
        default=500,
        metadata={"help": "Save the generated prompts every save_steps."},
    )
    save_total_limit: Optional[int] = field(
        default=1, metadata={"help": ("If a value is passed, will limit the total number of saved checkpoints")}
    )

    def __post_init__(self):
        """Reject configurations that ask for a Hub upload without naming the target repository."""
        if self.push_to_hub and self.hub_dataset_id is None:
            raise ValueError("You must specify the `hub_dataset_id` when setting `--push_to_hub=True`")
|
|
|
|
def get_quantization_config(model_args: ModelArguments) -> Union[BitsAndBytesConfig, None]:
    """
    Build the bitsandbytes quantization config requested by `model_args`.

    4-bit loading takes precedence over 8-bit loading. Returns `None` when neither is requested.
    """
    if model_args.load_in_4bit:
        requested_dtype = model_args.torch_dtype
        # "auto" / None cannot be resolved to a torch dtype, so float16 is used for the 4-bit compute dtype.
        if requested_dtype in {"auto", None}:
            compute_dtype = torch.float16
        else:
            compute_dtype = getattr(torch, requested_dtype)

        return BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=compute_dtype,
            bnb_4bit_quant_type=model_args.bnb_4bit_quant_type,
            bnb_4bit_use_double_quant=model_args.use_bnb_nested_quant,
        )

    if model_args.load_in_8bit:
        return BitsAndBytesConfig(
            load_in_8bit=True,
        )

    return None
|
|
|
|
def get_current_device() -> Union[int, str]:
    """
    Get the current device.

    Returns the accelerate local process index when CUDA is available (so that each process
    targets its own GPU), and the string "cpu" otherwise.
    """
    return Accelerator().local_process_index if torch.cuda.is_available() else "cpu"
|
|
|
|
def get_kbit_device_map() -> Union[Dict[str, int], None]:
    """
    Device map for loading a quantized model, meant to be passed as `device_map=get_kbit_device_map()`.

    With CUDA available, maps the empty-string key to the current device; otherwise returns `None`.
    """
    if not torch.cuda.is_available():
        return None
    return {"": get_current_device()}
|
|
|
|
# File-name prefix used for every checkpoint file ("checkpoint-<step>.json").
CHECKPOINT_PREFIX = "checkpoint"
# Matches a checkpoint file name and captures its step number. The dot is escaped so that only a
# literal ".json" suffix matches (an unescaped "." would accept any character in that position).
_RE_CHECKPOINT = re.compile(r"^checkpoint-(\d+)\.json$")
|
|
|
|
def save_checkpoint(output_dir, all_generated_ids, step):
    """
    Serialize the generated token ids to `<output_dir>/checkpoint-<step>.json`.

    Args:
        output_dir: Directory that receives the checkpoint file; created if it does not exist.
        all_generated_ids: Iterable of id sequences. Items exposing `.tolist()` (e.g. numpy arrays)
            are converted with it, any other iterable is converted with `list()`.
        step: Step number embedded in the checkpoint file name.

    The JSON is first written to a hidden temporary file and then moved into place with `os.replace`,
    so an interrupted run never leaves a truncated checkpoint under the final name.
    """
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    checkpoint_name = f"{CHECKPOINT_PREFIX}-{step}.json"
    output_path = os.path.join(output_dir, checkpoint_name)
    serializable_ids = [ids.tolist() if hasattr(ids, "tolist") else list(ids) for ids in all_generated_ids]

    # The temporary name starts with "." so it is never picked up as a checkpoint, and carries the
    # process id so concurrent writers do not share the same temporary file.
    tmp_path = os.path.join(output_dir, f".{checkpoint_name}.{os.getpid()}.tmp")
    try:
        with open(tmp_path, "w") as file:
            json.dump(serializable_ids, file)
        os.replace(tmp_path, output_path)
    except BaseException:
        # Do not leave a partial temporary file behind, then propagate the original error.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
|
|
|
|
def load_checkpoint(checkpoint_path):
    """Read a JSON checkpoint and return its sequences as a list of numpy arrays."""
    with open(checkpoint_path, "r") as handle:
        stored_sequences = json.load(handle)
    return [np.array(sequence) for sequence in stored_sequences]
|
|
|
|
def sorted_checkpoints(output_dir=None) -> List[str]:
    """
    Helper function to sort saved checkpoints from oldest to newest.

    Only regular files named exactly `checkpoint-<step>.json` directly inside `output_dir` are
    returned, ordered by increasing step. The step is parsed from the file name rather than the whole
    path, so digits in a parent directory name can never be mistaken for it, and unrelated
    `checkpoint-*` entries (directories, backups) are ignored.

    Args:
        output_dir: Directory to scan. `None` yields an empty list.

    Returns:
        Checkpoint paths as strings, oldest first.
    """
    if output_dir is None:
        return []

    name_pattern = re.compile(rf"{re.escape(CHECKPOINT_PREFIX)}-([0-9]+)\.json")

    step_and_path = []
    for candidate in Path(output_dir).glob(f"{CHECKPOINT_PREFIX}-*.json"):
        name_match = name_pattern.fullmatch(candidate.name)
        if name_match is not None and candidate.is_file():
            step_and_path.append((int(name_match.group(1)), str(candidate)))

    return [path for _, path in sorted(step_and_path)]
|
|
|
|
def rotate_checkpoints(save_total_limit=None, output_dir=None) -> None:
    """
    Helper function to delete old checkpoints.

    Does nothing when `save_total_limit` is `None` or not positive, or when the number of checkpoints
    found by `sorted_checkpoints` does not exceed the limit. Otherwise the oldest checkpoints are
    removed until `save_total_limit` remain.
    """
    limit_disabled = save_total_limit is None or save_total_limit <= 0
    if limit_disabled:
        return

    existing_checkpoints = sorted_checkpoints(output_dir=output_dir)
    surplus = len(existing_checkpoints) - save_total_limit
    if surplus <= 0:
        return

    # The list is ordered oldest first, so the leading `surplus` entries are the ones to drop.
    for stale_checkpoint in existing_checkpoints[:surplus]:
        logger.info(f"Deleting older checkpoint [{stale_checkpoint}] due to args.save_total_limit")
        os.remove(stale_checkpoint)
|
|
|
|
def get_last_checkpoint(folder) -> Tuple[List, int]:
    """
    Load the most recent checkpoint stored in `folder`.

    Args:
        folder: Directory holding the checkpoint files. It is created when it does not exist yet.

    Returns:
        A tuple `(all_generated_ids, cur_step)`: the sequences loaded from the checkpoint with the
        highest step number and that step number, or `([], 0)` when no checkpoint is found.
    """
    if not os.path.isdir(folder):
        os.makedirs(folder, exist_ok=True)
        return [], 0

    # Parse each file name once, keeping the step of every entry that is a checkpoint. The step is
    # taken from the file name itself, never from the joined path, so a look-alike parent directory
    # name cannot be picked up instead.
    step_by_name = {}
    for name in os.listdir(folder):
        name_match = _RE_CHECKPOINT.search(name)
        if name_match is not None:
            step_by_name[name] = int(name_match.group(1))

    if not step_by_name:
        return [], 0

    last_name = max(step_by_name, key=step_by_name.get)
    cur_step = step_by_name[last_name]

    all_generated_ids = load_checkpoint(os.path.join(folder, last_name))
    return all_generated_ids, cur_step
|
|
|
|
@dataclass
class DataCollatorWithPadding:
    """
    Collate a list of features into one batch by padding their `input_ids` to the longest sequence
    of the batch through `tokenizer.pad`.
    """

    # Object exposing a `pad(...)` method; it is called with the collected `input_ids`.
    tokenizer: Any

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # Only the `input_ids` entry of every feature is forwarded; all other keys are dropped.
        token_id_sequences = [example["input_ids"] for example in features]
        return self.tokenizer.pad(
            {"input_ids": token_id_sequences},
            return_tensors="pt",
            padding="longest",
            return_attention_mask=True,
        )
|
|
# Maps the dataset's raw `speaker_id` values to the speaker names used in the prompt.
id_to_name = {
    "ex01": "Jerry",
    "ex02": "Elisabeth",
    "ex03": "Thomas",
    "ex04": "Talia",
}

# Instruction template sent to the model. The placeholders `[speaker_id]` and `[style]` on the last
# line are replaced with the values of each sample before tokenization.
PROMPT = """You will be given a name and an enunciation style related to an audio sample of a person's speech.
1. The name will be one of those: Jerry, Elisabeth, Thomas, Talia.
2. The enunciation style will be one of those: 'enunciated', 'happy', 'confused', 'default', 'laughing', 'sad', 'whisper', 'emphasis'.
The enunciation style 'default' can be associated to 'with no particular emotion conveyed'.

Your task is to create a text description using these information that accurately describes the speech sample. Ensure that the generated description is grammatically correct, easy to understand, and most importantly, concise.

For example, given the following keywords: 'Talia', 'happy', a valid description would be: 'In an excellent recording, Talia speaks happily.'.
Another valid description would be: 'Talia delivers her words happily.'
Another example, given the following keywords: 'Jerry', 'emphasis': 'Jerry speaks with emphasis on certain words.'

You are free to change the order of the information, and replace synonymous terms.
You must give one and only one description and nothing else. Remember, I only want one description and nothing else.

For the information: '[speaker_id]', '[style]', the corresponding description is:"""
|
|
|
|
def main():
    """
    Generate a text description for every sample of a dataset with a causal language model.

    Steps: parse `ModelArguments` / `DataArguments` (from a single `.json` file or from command-line
    flags), load the dataset split(s), check that the `speaker_id` and `style` columns exist, load the
    model and tokenizer, turn every sample into a tokenized chat prompt built from `PROMPT`, run
    batched generation with periodic JSON checkpoints (resuming from the last one when present),
    decode the generations into a `text_description` column, then save the result to
    `data_args.output_dir` and optionally push it to the Hub.

    Raises:
        ValueError: when an expected dataset column is missing, or when `--torch_compile` is set for a
            model without a callable `_setup_cache`.
    """
    # 1. Parse input arguments
    parser = HfArgumentParser((ModelArguments, DataArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # A single argument ending in ".json" is treated as a file holding all the arguments.
        model_args, data_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args = parser.parse_args_into_dataclasses()

    # 2. Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )

    accelerator = Accelerator()

    if data_args.overwrite_output_dir and os.path.exists(data_args.output_dir) and os.path.isdir(data_args.output_dir):
        logger.info("Cleaning output dir from previous run...")
        shutil.rmtree(data_args.output_dir)

    # 3. Load annotated dataset
    logger.info("*** Load annotated dataset ***")
    # Honour `--dataset_cache_dir`; fall back to `--cache_dir` so earlier invocations that relied on
    # it for datasets keep working.
    dataset_cache_dir = (
        data_args.dataset_cache_dir if data_args.dataset_cache_dir is not None else model_args.cache_dir
    )
    if data_args.dataset_split_name is not None:
        raw_datasets = DatasetDict()
        # Several splits can be requested at once, separated by "+".
        data_splits = data_args.dataset_split_name.split("+")

        for split in data_splits:
            with accelerator.local_main_process_first():
                raw_datasets[split] = load_dataset(
                    data_args.dataset_name,
                    data_args.dataset_config_name,
                    split=split,
                    cache_dir=dataset_cache_dir,
                    token=model_args.token,
                    num_proc=data_args.preprocessing_num_workers,
                )
    else:
        with accelerator.local_main_process_first():
            # No split given: load every split of the dataset.
            raw_datasets = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                cache_dir=dataset_cache_dir,
                token=model_args.token,
                num_proc=data_args.preprocessing_num_workers,
            )

    # Column names are read from the first split.
    raw_datasets_features = set(raw_datasets[next(iter(raw_datasets))].features.keys())

    if data_args.max_eval_samples is not None:
        for split in raw_datasets:
            # Clamp to the split size: selecting indices past the end would raise.
            num_samples = min(data_args.max_eval_samples, len(raw_datasets[split]))
            raw_datasets[split] = raw_datasets[split].select(range(num_samples))

    EXPECTED_COLUMNS = {"speaker_id", "style"}
    if not EXPECTED_COLUMNS.issubset(raw_datasets_features):
        missing_columns = EXPECTED_COLUMNS - raw_datasets_features
        raise ValueError(
            f"Missing columns {missing_columns} from the dataset features. Got dataset features {raw_datasets_features}"
        )

    # 4. Load pre-trained model
    logger.info("*** Load pretrained model ***")
    torch_dtype = (
        model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
    )
    quantization_config = get_quantization_config(model_args)

    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        revision=model_args.model_revision,
        variant=model_args.model_variant,
        trust_remote_code=model_args.trust_remote_code,
        attn_implementation=model_args.attn_implementation,
        torch_dtype=torch_dtype,
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
        low_cpu_mem_usage=True,
        token=model_args.token,
        # `--cache_dir` is documented as the location for downloaded pretrained models.
        cache_dir=model_args.cache_dir,
    ).eval()

    if model_args.torch_compile:
        # The static k/v cache is only enabled for models exposing a callable `_setup_cache`.
        if not callable(getattr(model, "_setup_cache", None)):
            raise ValueError(
                f"Static k/v cache is not compatible with the model {model.__class__.__name__}. Set `--torch_compile=False` "
                "for dynamic k/v cache"
            )
        model.generation_config.cache_implementation = "static"
        model = torch.compile(model, mode="reduce-overhead", fullgraph=True)

    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        revision=model_args.model_revision,
        trust_remote_code=model_args.trust_remote_code,
        use_fast=model_args.use_fast_tokenizer,
        padding_side="left",
        # Same Hub authentication and cache location as the model.
        token=model_args.token,
        cache_dir=model_args.cache_dir,
    )
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.bos_token_id
        model.generation_config.pad_token_id = model.generation_config.eos_token_id

    def prepare_dataset(sample):
        """Fill the prompt template with the sample's values and tokenize it with the chat template."""
        sample_prompt = PROMPT
        # Replace the raw speaker identifier by the speaker's name before filling the template.
        sample["speaker_id"] = id_to_name[sample["speaker_id"]]
        for key in EXPECTED_COLUMNS:
            sample_prompt = sample_prompt.replace(f"[{key}]", sample[key])
        sample_prompt = [{"role": "user", "content": sample_prompt}]
        token_ids = tokenizer.apply_chat_template(sample_prompt)
        sample["input_ids"] = token_ids
        return sample

    with accelerator.local_main_process_first():
        vectorized_datasets = raw_datasets.map(
            prepare_dataset, num_proc=data_args.preprocessing_num_workers, desc="Preparing prompts"
        )

    # 5. Prepare the model and define the generation / post-processing steps
    model = accelerator.prepare(model)
    data_collator = DataCollatorWithPadding(tokenizer)

    def generate_step(batch):
        """Generate for one batch and pad the output ids to a common length across processes."""
        output_ids = accelerator.unwrap_model(model).generate(
            batch["input_ids"],
            attention_mask=batch["attention_mask"],
            do_sample=model_args.do_sample,
            temperature=model_args.temperature,
            max_new_tokens=model_args.max_new_tokens,
        )
        output_ids = accelerator.pad_across_processes(output_ids, dim=1, pad_index=tokenizer.pad_token_id)
        return output_ids

    def postprocess_dataset(sample):
        """Decode prompt and generation, keeping only the text that follows the decoded prompt."""
        prompt_text = tokenizer.decode(sample["input_ids"], skip_special_tokens=True)
        generated_text = tokenizer.decode(sample["generated_ids"], skip_special_tokens=True)
        sample["text_description"] = generated_text[len(prompt_text) :]
        return sample

    # 6. Run generation split by split
    for split in vectorized_datasets:
        data_loader = DataLoader(
            vectorized_datasets[split],
            batch_size=model_args.per_device_eval_batch_size,
            collate_fn=data_collator,
            num_workers=data_args.dataloader_num_workers,
            pin_memory=True,
        )
        data_loader = accelerator.prepare(data_loader)
        total_inference_steps = len(data_loader)
        progress_bar = tqdm(
            range(total_inference_steps), desc=" ... ", position=0, disable=not accelerator.is_local_main_process
        )

        split_output_dir = os.path.join(data_args.output_dir, split)
        all_generated_ids, cur_step = get_last_checkpoint(split_output_dir)

        if cur_step > 0:
            logger.info(f"Resuming {split} from step {cur_step}")
            # Skip the batches whose generations were restored from the checkpoint.
            data_loader = skip_first_batches(data_loader, cur_step)
            progress_bar.update(cur_step)

        while cur_step < total_inference_steps:
            for batch in data_loader:
                generated_ids = generate_step(batch)
                generated_ids = accelerator.gather_for_metrics(generated_ids)
                all_generated_ids.extend(generated_ids.cpu().numpy())

                cur_step += 1
                progress_bar.update(1)

                # NOTE(review): checkpoint saving and rotation are not restricted to the main process
                # here, so every process runs them on the same directory — confirm this is intended.
                if (cur_step % data_args.save_steps == 0) or (cur_step == total_inference_steps):
                    save_checkpoint(split_output_dir, all_generated_ids, cur_step)
                    rotate_checkpoints(data_args.save_total_limit, output_dir=split_output_dir)

        vectorized_datasets[split] = vectorized_datasets[split].add_column("generated_ids", all_generated_ids)

        if accelerator.is_main_process:
            vectorized_datasets[split] = vectorized_datasets[split].map(
                postprocess_dataset,
                num_proc=data_args.preprocessing_num_workers,
                desc="Postprocessing dataset",
                remove_columns=["input_ids", "generated_ids"],
            )
        accelerator.wait_for_everyone()

    # 7. Save (and optionally push) the result from the main process only
    if accelerator.is_main_process:
        vectorized_datasets.save_to_disk(data_args.output_dir)
        if data_args.push_to_hub:
            vectorized_datasets.push_to_hub(
                data_args.hub_dataset_id,
                config_name=data_args.dataset_config_name if data_args.dataset_config_name is not None else "default",
                token=model_args.token,
            )
    accelerator.wait_for_everyone()
    accelerator.end_training()
|
|
|
|
# Run the pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()