| import torch |
| import torch.nn as nn |
| import numpy as np |
| import argparse |
| import random |
| import os |
| from HybridTensor.routers.mlp.trainer_mlp import train |
| from HybridTensor.routers.router_utils import DATA, MODEL_CHOICES, DATA_CHOICES, CONFIG, BasicDataset |
| from HybridTensor.routers.router_utils import get_data, create_dataset, create_log_path, augment_data, get_data_ |
| from HybridTensor.routers.router_utils import generate_experiment_id, create_date_directory, generate_log_filename, setup_logging, generate_model_filename |
| import logging |
| from HybridTensor.routers.mlp.MLP_router import Router |
|
|
| from HybridTensor.utils.utils import _get_device, extract_model_name |
| from HybridTensor.utils.activations import OPT_CONFIGS, OPT_MODELS |
|
|
| def arg_parser(): |
| parser = argparse.ArgumentParser(description="PyTorch OPT Full Model") |
| parser.add_argument("--model", type=str, default="6_7b", choices=MODEL_CHOICES) |
| parser.add_argument("--model_name", type=str, default="opt", help="model name") |
| parser.add_argument("--model_index", type=int, default=8, help="model index") |
| parser.add_argument("--dataset", type=str, default="c4", choices=DATA_CHOICES) |
| parser.add_argument("--L", type=int, default=0, help="which layer") |
| parser.add_argument("--D", type=int, default=1024, help="low rank dimension") |
| parser.add_argument("--batch_size", type=int, default=64, help="batch size") |
| parser.add_argument("--norm", type=str, default="none", choices=["batch", "layer", "none"]) |
| parser.add_argument("--epochs", type=int, default=20, help="epochs") |
| parser.add_argument("--lr", type=float, default=0.0001, help="learning rate") |
| parser.add_argument("--ckpt_dir", type=str, default="<PATH_TO_CHECKPOINT_DIR>", help="checkpoint directory") |
| parser.add_argument("--data_dir", type=str, default="<PATH_TO_DATA_DIR>", help="data directory") |
| parser.add_argument("--gpu", type=int, default=0, help="which gpu to use") |
| parser.add_argument("--max_samples", type=int, default=0, help="total samples to process") |
| parser.add_argument("--dropout", type=float, default=0) |
| parser.add_argument("--alpha", type=float, default=1) |
| parser.add_argument("--gamma", type=float, default=1) |
| parser.add_argument("--threshold", type=float, default=0.5) |
| parser.add_argument("--data_augmentation", type=int, default=0) |
| parser.add_argument("--loss_fn", type=str, default="bce", choices=["bce", "focal", "mse"]) |
| parser.add_argument("--hidden_activation", type=str, default="none", choices=["relu", "gelu", "relu_squared", "leaky_relu", "swish", "swiglu", "elu", "selu", "mish", "none"]) |
| parser.add_argument("--debug", type=bool, default=False) |
|
|
| args = parser.parse_args() |
| return args |
|
|
|
|
| if __name__ == "__main__": |
| |
| args = arg_parser() |
| model_name = OPT_MODELS[args.model_index -1] |
| model_config = OPT_CONFIGS[model_name] |
| model_dim = model_config['d'] |
| inner_dim = args.D |
| device = _get_device(args.gpu) |
| |
| |
| if args.debug: |
| experiment_id = generate_experiment_id() |
| date_dir = create_date_directory(args.ckpt_dir, args.model) |
| log_file_path = generate_log_filename(experiment_id, args.model_name, args.L, date_dir) |
| else: |
| date_dir = "None" |
| experiment_id = "None" |
| model_name_clean = extract_model_name(model_name) |
| log_file_path = f"{args.ckpt_dir}/{model_name_clean}-routers/mlp/logs/{args.model_name}-{args.L}.log" |
| ckpt_path = f"{args.ckpt_dir}/{model_name_clean}-routers/mlp" |
| |
| |
| log_dir = os.path.dirname(log_file_path) |
| os.makedirs(log_dir, exist_ok=True) |
| |
| |
| setup_logging(log_file_path) |
| |
| |
| logging.info(f"Experiment ID: {experiment_id}") |
| logging.info(f"Model Name: {args.model_name}") |
| logging.info(f"Layer Index: {args.L}") |
| logging.info(f"Checkpoint Directory: {args.ckpt_dir}") |
| logging.info(f"Date Directory: {date_dir}") |
| logging.info(f"Using Dropout: {args.dropout}") |
| logging.info(f"loss_fn: {args.loss_fn}") |
| logging.info(f"hidden_activation: {args.hidden_activation}") |
| logging.info(f"Learning Rate: {args.lr}") |
| |
| random.seed(0) |
| torch.manual_seed(0) |
| np.random.seed(0) |
|
|
| device = torch.device(f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu") |
|
|
| print("=" * 40, "Layer", args.L, "=" * 40) |
| |
| if args.max_samples == 0: |
| total_samples = 400000 |
| else: |
| total_samples = args.max_samples |
| |
| query, labels = get_data_(args.data_dir, args.L, data_type = 'mlp_activations', total_samples=total_samples) |
| |
| |
| if args.data_augmentation: |
| logging.info("Using Data Augmentation") |
| index_metadata, cold_labels = augment_data(labels, device=device) |
| hot_index, cold_index, pruned_index = index_metadata |
| cold_neurons = len(cold_index) |
| train_loader, test_loader = create_dataset(query, cold_labels, args) |
| else: |
| logging.info("Not Using Data Augmentation") |
| train_loader, test_loader = create_dataset(query, labels, args) |
| cold_neurons = 0 |
| |
| |
| |
| mlp_router = Router(model_dim, inner_dim, norm=args.norm, hidden_activation=args.hidden_activation) |
| |
| |
| logging.info("Model Architecture:\n{}".format(mlp_router)) |
| logging.info("Training Arguments:\n{}".format(args)) |
| |
| best_model, eval_result = train( |
| mlp_router, train_loader, test_loader, args, device, verbal=True |
| ) |
| |
| if args.debug: |
| model_filename = generate_model_filename(experiment_id, args.model_name, args.L, eval_result) |
| model_file_path = date_dir / model_filename |
| else: |
| model_file_path = f"{ckpt_path}/mlp_router_{args.L}-{eval_result['Recall']:.2f}-{eval_result['Precision']:.2f}-{eval_result['Classifier Sparsity']:.2f}.pt" |
| |
| |
| torch.save(best_model, model_file_path) |
| logging.info(f"Model saved at {model_file_path}") |
|
|