| import logging |
| import sys |
| import argparse |
| import os |
| import inspect |
| from typing import Optional, Any |
| from dataclasses import dataclass, field, make_dataclass |
| from transformers import Trainer, TrainingArguments, AutoTokenizer, HfArgumentParser |
| from datasets import load_from_disk |
|
|
| from funnel_vae.src.funnel_vae import FunnelVae |
| from funnel_vae.src.config import FunnelVaeConfig |
|
|
|
|
@dataclass
class BaseArgs:
    """Training hyperparameters and I/O locations for this entry point.

    The location fields default to SageMaker environment variables, read once
    when this class is defined:

    - ``output_data_dir``: ``SM_OUTPUT_DATA_DIR``
    - ``model_dir``: ``SM_MODEL_DIR``
    - ``n_gpus``: ``SM_NUM_GPUS``
    - ``training_dir``: ``SM_CHANNEL_TRAIN``
    - ``test_dir``: ``SM_CHANNEL_TEST``

    When a variable is unset, the field defaults to ``None`` rather than
    raising ``KeyError`` at import time. Outside a SageMaker container the
    value must then be supplied on the command line.
    """

    model_name: str
    epochs: int = 3
    per_device_train_batch_size: int = 32
    per_device_eval_batch_size: int = 64
    warmup_steps: int = 500
    # Annotated as float to match the default value (it was declared ``str``
    # while defaulting to a float).
    learning_rate: float = 5e-5

    # SageMaker-provided locations. ``os.environ.get`` keeps the module
    # importable when these variables are absent.
    output_data_dir: Optional[str] = os.environ.get("SM_OUTPUT_DATA_DIR")
    model_dir: Optional[str] = os.environ.get("SM_MODEL_DIR")
    n_gpus: Optional[str] = os.environ.get("SM_NUM_GPUS")
    training_dir: Optional[str] = os.environ.get("SM_CHANNEL_TRAIN")
    test_dir: Optional[str] = os.environ.get("SM_CHANNEL_TEST")
|
|
|
|
| |
def _config_argument_fields(config_cls, excluded=('self', 'kwargs', 'use_extra_logs', 'cache_dir')):
    """Build ``make_dataclass`` field specs mirroring ``config_cls.__init__``.

    Each spec is a ``(name, type, dataclasses.Field)`` tuple.

    Args:
        config_cls: class whose ``__init__`` signature is mirrored.
        excluded: parameter names that must not become fields.

    Returns:
        ``(required, optional)`` lists. ``required`` holds parameters with no
        default, which become fields without a default. ``make_dataclass``
        requires those to precede defaulted fields, hence the split.
        ``*args``/``**kwargs``-style parameters are always skipped.
    """
    variadic = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
    required, optional = [], []
    for name, param in inspect.signature(config_cls.__init__).parameters.items():
        if name in excluded or param.kind in variadic:
            continue
        if param.default is inspect.Parameter.empty:
            # No default value to infer a type from: use the annotation if any.
            annotation = Any if param.annotation is inspect.Parameter.empty else param.annotation
            required.append((
                name, annotation, field(
                    metadata={"help": f"Required, see {config_cls.__name__} docstring for more info."}
                )
            ))
        else:
            # The field type is taken from the default's type; ``None`` defaults map to Any.
            annotation = type(param.default) if param.default is not None else Any
            optional.append((
                name, annotation, field(
                    default=param.default,
                    metadata={"help": f"Has default {param.default}, see {config_cls.__name__} docstring for more info."},
                )
            ))
    return required, optional


_required_config_fields, _optional_config_fields = _config_argument_fields(FunnelVaeConfig)

# Required config parameters first (dataclass ordering rule), then the extra
# tokenizer option, then the defaulted config parameters in signature order.
fields = _required_config_fields + [
    (
        'tokenizer_name', Optional[str], field(
            default='t5-base', metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
        )
    ),
] + _optional_config_fields

ModelArguments = make_dataclass('ModelArguments', fields)
|
|
|
|
@dataclass
class DataArguments:
    """Options selecting the dataset: a named dataset or local train/validation files.

    Raises:
        ValueError: from ``__post_init__`` when no data source is given, or
            when a given file's extension is not ``csv``, ``json`` or ``txt``.
    """

    dataset_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    text_column: Optional[str] = field(default=None, metadata={"help": "Use this dataset column as 'text'."})
    train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
    validation_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
    )
    overwrite_cache: bool = field(default=False, metadata={"help": "Overwrite the cached training and evaluation sets"})
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    mlm_probability: float = field(
        default=0.0, metadata={"help": "Ratio of tokens to mask for masked language modeling loss"}
    )
    validation_name: str = field(
        default="validation",
        metadata={"help": "Name of the set to run evaluation on."},
    )

    def __post_init__(self):
        """Validate that a data source is given and that file extensions are supported."""
        if self.dataset_name is None and self.train_file is None and self.validation_file is None:
            raise ValueError("Need either a dataset name or a training/validation file.")
        # Explicit raises instead of ``assert``: asserts are stripped under ``python -O``.
        for attr in ("train_file", "validation_file"):
            path = getattr(self, attr)
            if path is not None and path.split(".")[-1] not in ("csv", "json", "txt"):
                raise ValueError(f"`{attr}` should be a csv, json or txt file.")
|
|
|
|
if __name__ == "__main__":
    # Configure logging first so that messages emitted while parsing are visible.
    logger = logging.getLogger(__name__)
    logging.basicConfig(
        level=logging.getLevelName("INFO"),
        handlers=[logging.StreamHandler(sys.stdout)],
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )

    # Only BaseArgs and ModelArguments are parsed:
    # * TrainingArguments also declares per_device_train_batch_size,
    #   per_device_eval_batch_size, warmup_steps and learning_rate. Registering
    #   it next to BaseArgs makes argparse fail with a conflicting option
    #   string, so it is built from BaseArgs further down instead.
    # * DataArguments is not consumed here (data is loaded from training_dir /
    #   test_dir), and its __post_init__ would demand a dataset name or file.
    parser = HfArgumentParser((BaseArgs, ModelArguments))
    args, model_args, unknown_args = parser.parse_args_into_dataclasses(return_remaining_strings=True)
    if unknown_args:
        # Tolerate extra CLI flags rather than aborting.
        logger.warning(f"Ignoring unrecognised arguments: {unknown_args}")

    for name in ("training_dir", "test_dir", "model_dir", "output_data_dir"):
        if getattr(args, name) is None:
            parser.error(f"--{name} must be given (its SageMaker environment variable is not set)")

    train_dataset = load_from_disk(args.training_dir)
    test_dataset = load_from_disk(args.test_dir)

    logger.info(f" loaded train_dataset length is: {len(train_dataset)}")
    logger.info(f" loaded test_dataset length is: {len(test_dataset)}")

    # ModelArguments mirrors FunnelVaeConfig.__init__ plus `tokenizer_name`,
    # which is not a config parameter, so split it off before building the
    # config. PretrainedConfig.from_pretrained requires a checkpoint path as
    # its first argument, so the config is constructed directly instead.
    config_kwargs = dict(vars(model_args))
    tokenizer_name = config_kwargs.pop("tokenizer_name")
    config = FunnelVaeConfig(**config_kwargs)
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=True)

    # Set the funnel, t5 and top-level vocab sizes to the tokenizer's size.
    vocab_size = len(tokenizer)
    config.funnel.vocab_size = vocab_size
    config.t5.vocab_size = vocab_size
    config.vocab_size = vocab_size
    # A fresh model built from the config above. The tokenizer must stay the
    # one used to size the vocabulary.
    model = FunnelVae(config)

    training_args = TrainingArguments(
        output_dir=args.model_dir,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        warmup_steps=args.warmup_steps,
        evaluation_strategy="epoch",
        logging_dir=os.path.join(args.output_data_dir, "logs"),
        # float() also accepts string values for the learning rate.
        learning_rate=float(args.learning_rate),
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        tokenizer=tokenizer,
    )

    trainer.train()

    eval_result = trainer.evaluate(eval_dataset=test_dataset)

    # Write evaluation metrics as "key = value" lines (sorted by key) and log them.
    os.makedirs(args.output_data_dir, exist_ok=True)
    logger.info("***** Eval results *****")
    with open(os.path.join(args.output_data_dir, "eval_results.txt"), "w", encoding="utf-8") as writer:
        for key, value in sorted(eval_result.items()):
            logger.info(f"  {key} = {value}")
            writer.write(f"{key} = {value}\n")

    trainer.save_model(args.model_dir)
|
|