#!/bin/bash
#
# Launches GLEN phase-2 training (examples/glen_phase2/train_glen.py) on the
# `the_vault` dataset, starting from the phase-1 checkpoint configured below.
#
# Two launch modes, selected by USE_DDP:
#   false      -> one python process with CUDA_VISIBLE_DEVICES=0
#   otherwise  -> two processes via torch.distributed.launch with
#                 CUDA_VISIBLE_DEVICES=0,1
# Both modes receive the same training arguments; the distributed mode
# additionally passes `--ddp_find_unused_parameters False`.
#
# All paths are relative, so run this from the directory that contains
# `examples/` and `logs/`.

set -euo pipefail

# Launch-mode switch (see header).
readonly USE_DDP=false

# Checkpoint handed to the trainer as --model_name_or_path.
readonly PHASE1_CKPT="logs/model_glen_vault/GLEN_P1_base"

# Forwarded verbatim as --gpu_memory_threshold / --gpu_check_interval.
# NOTE(review): presumably a memory-usage fraction and a step interval;
# confirm against train_glen.py.
readonly GPU_MEMORY_THRESHOLD=0.85
readonly GPU_CHECK_INTERVAL=50

readonly TRAIN_SCRIPT="examples/glen_phase2/train_glen.py"

# Arguments shared by both launch modes. Kept in an array so each word is
# passed as exactly one argv entry and the two modes cannot drift apart.
train_args=(
  --output_dir logs/model_glen_vault/GLEN_P2_base
  --model_name_or_path "${PHASE1_CKPT}"
  --load_best_model_at_end True
  --per_device_train_batch_size 4
  --per_device_eval_batch_size 2
  --gradient_accumulation_steps 32
  --dropout_rate 0.1
  --warmup_ratio 0.1
  --id_class t5_bm25_truncate_3
  --dataset_name the_vault
  --test100 1
  --tree 1
  --q_max_len 32
  --p_max_len 256
  --negative_passage_type self
  --positive_passage_no_shuffle True
  --tie_word_embeddings True
  --num_return_sequences 10
  --logging_steps 100
  --overwrite_output_dir
  --wandb_tag glen_vault_p2
  --do_eval
  --seed 42
  --gpu_memory_threshold "${GPU_MEMORY_THRESHOLD}"
  --gpu_check_interval "${GPU_CHECK_INTERVAL}"
  --fp16 True
)

if [[ "${USE_DDP}" == false ]]; then
  # Single-process run restricted to GPU 0.
  CUDA_VISIBLE_DEVICES=0 \
    python "${TRAIN_SCRIPT}" "${train_args[@]}"
else
  # Two worker processes, one per visible GPU.
  # NOTE(review): torch.distributed.launch is deprecated upstream in favour of
  # torchrun; kept as-is because switching changes how the local rank is
  # passed to the training script.
  CUDA_VISIBLE_DEVICES=0,1 \
    python -m torch.distributed.launch --nproc_per_node=2 "${TRAIN_SCRIPT}" \
    --ddp_find_unused_parameters False \
    "${train_args[@]}"
fi