# GLEN-model/src/tevatron/arguments.py
# (imported from QuanTH02's commit "Commit 15-06-v1", 6534252)
from typing import Optional
from dataclasses import dataclass, field
from transformers import TrainingArguments
@dataclass
class GLENTrainingArguments(TrainingArguments):
    """Base training arguments shared by all GLEN phases.

    Extends ``transformers.TrainingArguments``: most fields below override
    inherited defaults (``do_train``, ``save_steps``, ``evaluation_strategy``,
    ...), while the rest add GLEN-specific switches (wandb bookkeeping,
    encoding loop, GPU memory monitoring).
    """

    do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
    do_eval: bool = field(
        default=False, metadata={"help": "Whether to run eval on the dev set."}
    )
    warmup_ratio: float = field(default=0.0)
    # Share in-batch negatives across devices; presumably consumed by the
    # contrastive trainer for multi-GPU runs -- TODO confirm against trainer.
    negatives_x_device: bool = field(
        default=False, metadata={"help": "share negatives across devices"}
    )
    do_encode: bool = field(default=False, metadata={"help": "run the encoding loop"})
    # Weights & Biases bookkeeping.
    project_name: Optional[str] = field(
        default="EMNLP2023", metadata={"help": "wandb project name"}
    )
    wandb_tag: Optional[str] = field(default=None, metadata={"help": "wandb tag"})
    save_steps: int = field(
        default=20000, metadata={"help": "save model every x steps"}
    )
    save_strategy: str = field(default="steps", metadata={"help": "save strategy"})
    save_total_limit: int = field(default=5, metadata={"help": "save total limit"})
    # Output path for intermediate results; empty string means "not set".
    res1_save_path: str = field(default="")
    # Fraction of an epoch between validation checks (Lightning-style naming).
    val_check_interval: float = field(
        default=0.2, metadata={"help": "validation check interval for each epoch"}
    )
    evaluation_strategy: str = field(
        default="steps", metadata={"help": "evaluation strategy"}
    )
    # GPU Memory Monitoring Arguments
    gpu_memory_threshold: float = field(
        default=0.85, metadata={"help": "GPU memory threshold (0.0-1.0) to stop training"}
    )
    gpu_check_interval: int = field(
        default=50, metadata={"help": "Check GPU memory every N steps"}
    )
@dataclass
class GLENP1TrainingArguments(GLENTrainingArguments):
    """Training arguments for GLEN phase 1.

    Overrides optimizer/schedule defaults of the base class and selects the
    model-selection metric. ``decoder_learning_rate`` is a GLEN-specific
    addition for giving the decoder its own learning rate.
    """

    metric_for_best_model: str = field(
        default="eval_recall@1", metadata={"help": "metric for best model"}
    )
    num_train_epochs: float = field(
        default=500.0, metadata={"help": "number of training epochs"}
    )
    adam_epsilon: float = field(default=1e-8, metadata={"help": "adam epsilon"})
    warmup_steps: int = field(default=0, metadata={"help": "warmup steps"})
    weight_decay: float = field(default=1e-4, metadata={"help": "weight decay"})
    learning_rate: float = field(default=2e-4, metadata={"help": "learning rate"})
    # Separate LR for the decoder; presumably wired into a custom optimizer
    # grouping in the trainer -- TODO confirm against the training loop.
    decoder_learning_rate: float = field(
        default=1e-4, metadata={"help": "decoder learning rate"}
    )
@dataclass
class GLENP2TrainingArguments(GLENTrainingArguments):
    """Training arguments for GLEN phase 2."""

    learning_rate: float = field(default=5e-5, metadata={"help": "learning rate"})
    # Gradient-cache style update (trades compute for memory); the chunk
    # sizes below presumably bound the per-sub-batch query/passage counts --
    # TODO confirm against the grad-cache trainer implementation.
    grad_cache: bool = field(
        default=False, metadata={"help": "Use gradient cache update"}
    )
    gc_q_chunk_size: int = field(default=128)
    gc_p_chunk_size: int = field(default=128)
@dataclass
class GLENDataArguments:
    """Data arguments shared by both GLEN training phases.

    Fixes in this revision:
    - ``dataset_name`` was annotated ``str`` with a ``None`` default; it is now
      ``Optional[str]`` so the annotation matches the default (and so
      ``HfArgumentParser`` treats it as a nullable argument).
    - typo "ground turth" -> "ground truth" in the ``query_type`` help text.
    """

    dataset_name: Optional[str] = field(
        default=None,
        metadata={"help": "huggingface dataset name or custom dataset name"},
    )
    encode_train_qry: bool = field(default=False)
    # Debug switch: 0 = off; non-zero restricts the run to a small subset.
    test100: int = field(
        default=0,
        metadata={"help": "Debug mode. Only use a subset of the data (100 examples)"},
    )
    query_type: str = field(
        default="gtq_doc_aug_qg",
        metadata={
            "help": "gtq: ground truth query, qg: generated query, doc: just use top64 doc tokens, aug: use random doc token"
        },
    )
    small_set: int = field(
        default=0, metadata={"help": "nq320k small set size", "choices": [0, 1, 10]}
    )
    aug_query: bool = field(
        default=True, metadata={"help": "whether to use augmented query"}
    )
    aug_query_type: str = field(
        default="corrupted_query",
        metadata={
            "help": "augmented query type",
            "choices": ["corrupted_query", "aug_query"],
        },
    )
    id_class: str = field(
        default="t5_bm25_truncate_3", metadata={"help": "id class for nq320k"}
    )
@dataclass
class GLENP1DataArguments(GLENDataArguments):
    """Data arguments specific to GLEN phase 1."""

    # Tokenized sequence-length budgets for phase-1 inputs/outputs.
    max_input_length: int = field(
        default=156, metadata={"help": "max input length"}
    )
    max_output_length: int = field(
        default=5, metadata={"help": "max output length"}
    )
@dataclass
class GLENP2DataArguments(GLENDataArguments):
    """Data arguments specific to GLEN phase 2 (query/passage retrieval data)."""

    max_input_length: int = field(
        default=156,
        metadata={"help": "max input length used for making id"},
    )
    train_n_passages: int = field(default=0)
    positive_passage_no_shuffle: bool = field(
        default=True,
        metadata={"help": "always use the first positive passage"},
    )
    negative_passage_no_shuffle: bool = field(
        default=False,
        metadata={"help": "always use the first negative passages"},
    )
    # NOTE(review): the help text describes ibn/hard/random, but `choices`
    # only allows "random" and "self" -- the help string looks stale; confirm
    # the intended options against the data loader and update the text.
    negative_passage_type: str = field(
        default="self",
        metadata={
            "help": "ibn: in batch negative, hard: hard negative, random: random negative",
            "choices": ["random", "self"],
        },
    )
    # Tokenizer-level length budgets for queries and passages.
    q_max_len: int = field(
        default=32,
        metadata={
            "help": "The maximum total input sequence length after tokenization for query. Sequences longer "
            "than this will be truncated, sequences shorter will be padded."
        },
    )
    p_max_len: int = field(
        default=156,
        metadata={
            "help": "The maximum total input sequence length after tokenization for passage. Sequences longer "
            "than this will be truncated, sequences shorter will be padded."
        },
    )
@dataclass
class GLENModelArguments:
    """Model construction and inference arguments shared by both GLEN phases.

    Fix in this revision: ``load_pretrained_st5_checkpoint`` was annotated
    ``str`` with a ``None`` default; it is now ``Optional[str]`` so the
    annotation matches the default (and so ``HfArgumentParser`` treats it as
    a nullable argument).
    """

    model_name_or_path: str = field(
        default="t5-base",
        metadata={
            "help": "Path to pretrained model or model identifier from huggingface.co/models"
        },
    )
    config_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "Pretrained config name or path if not the same as model_name"
        },
    )
    tokenizer_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "Pretrained tokenizer name or path if not the same as model_name"
        },
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={
            "help": "Where do you want to store the pretrained models downloaded from s3"
        },
    )
    # Transformer architecture hyperparameters (t5-base shaped defaults).
    num_layers: int = field(default=12)
    num_decoder_layers: int = field(default=12)
    d_ff: int = field(default=3072)
    d_model: int = field(default=768)
    num_heads: int = field(default=12)
    d_kv: int = field(default=64)
    use_past_key_values: bool = field(default=True)
    load_pretrained_st5_checkpoint: Optional[str] = field(default=None)
    mask_special_tokens_for_decoding: bool = field(default=True)
    tie_decode_embeddings: bool = field(default=True)
    tie_word_embeddings: bool = field(default=True)
    dropout_rate: float = field(default=0.1)
    # Inference Arguments
    length_penalty: float = field(default=0.8)
    num_return_sequences: int = field(
        default=100, metadata={"help": "number of return sequences."}
    )
    early_stopping: bool = field(default=False)
    tree: int = field(default=1)
    reranking: str = field(
        default="cosine",
        metadata={
            "help": "random, cosine, mse",
            "choices": ["random", "cosine", "mse"],
        },
    )
    gen_method: str = field(
        default="greedy", metadata={"help": "Only used when decoder_input is docid"}
    )  # greedy, beam_search, top_k, top_p
    # Checkpoint selection for inference: either a single .bin file ...
    infer_ckpt: str = field(
        default="",
        metadata={
            "help": "Path to checkpoint file (e.g., logs/GLEN-6700/pytorch_model.bin). Model args will not be loaded from model_args.json"
        },
    )
    # ... or a directory containing .bin files; empty string means "not set".
    infer_dir: str = field(
        default="",
        metadata={
            "help": "Path to directory that contains .bin files (e.g., logs/GLEN-6700)"
        },
    )
    logs_dir: str = field(
        default="logs", metadata={"help": "Path to save inference results"}
    )
    docid_file_name: str = field(default="")
@dataclass
class GLENP1ModelArguments(GLENModelArguments):
    """Model arguments specific to GLEN phase 1."""

    # Validation-time logging verbosity: 0 = silent, 1 = a sample of queries,
    # 2 = every query.
    verbose_valid_query: int = field(
        default=1,
        metadata={
            "help": "0: no verbose, 1: verbose with 10^1 queries, 2: verbose with all queries",
            "choices": [0, 1, 2],
        },
    )
    # Freezing / pretrained-initialization switches.
    freeze_encoder: bool = field(default=False)
    freeze_embeds: bool = field(default=False)
    pretrain_encoder: bool = field(default=True)
    pretrain_decoder: bool = field(default=True)
    output_vocab_size: int = field(default=10)
    # R-Drop regularization weight.
    Rdrop: float = field(default=0.15)
    input_dropout: int = field(default=1)
    # Decoder input mode; one of: doc_rep, doc_id.
    decoder_input: str = field(default="doc_rep")
@dataclass
class GLENP2ModelArguments(GLENModelArguments):
    """Model arguments specific to GLEN phase 2."""

    softmax_temperature: float = field(default=1.0)
    num_multi_vectors: int = field(default=3)
    untie_encoder: bool = field(
        default=False,
        metadata={"help": "no weight sharing between qry passage encoders"},
    )
    # Loss-term weights.
    infonce_loss: float = field(default=1.0)  # pairwise ranking loss
    q_to_docid_loss: float = field(default=0.5)  # pointwise retrieval loss (first term)
    cosine_point_loss: float = field(default=0.25)  # pointwise retrieval loss (second term)
    # Temperature annealing schedule for the docid softmax.
    do_docid_temperature_annealing: bool = field(default=True)
    docid_temperature: float = field(default=1.0)
    docid_temperature_min: float = field(default=1e-5)