| |
| import torch |
| import torch.nn as nn |
| import torch.optim as optim |
| from torch.optim.lr_scheduler import StepLR |
| from inference import prepare_for_lwm |
| from input_preprocess import tokenizer |
| from lwm_model import lwm |
| import numpy as np |
| import DeepMIMOv3 |
|
|
| |
def get_parameters(scenario):
    """Assemble DeepMIMOv3 dataset-generation parameters for one scenario.

    Args:
        scenario: scenario folder name under './scenarios'.

    Returns:
        Tuple of (DeepMIMOv3 parameter dict, per-scenario user-grid layout
        table, n_ant_bs, n_ant_ue, n_subcarriers).
    """
    n_ant_bs = 32
    n_ant_ue = 1
    n_subcarriers = 32
    scs = 30e3  # subcarrier spacing in Hz — presumably; confirm against DeepMIMO docs

    # (n_rows, n_per_row) for each known scenario. 'Boston5G_3p5' stores a
    # [first_row, last_row) span instead of a plain row count.
    grid_layout = {
        'asu_campus1': (321, 411),
        'Boston5G_3p5': ([812, 1622], 595),
        'city_0_newyork': (44, 117),
        'city_1_losangeles': (57, 81),
        'city_2_chicago': (56, 80),
        'city_3_houston': (62, 81),
        'city_4_phoenix': (79, 86),
        'city_5_philadelphia': (96, 66),
        'city_6_miami': (80, 87),
        'city_8_dallas': (83, 76),
        'city_9_sanfrancisco': (79, 83),
        'city_10_austin': (102, 55),
        'city_13_columbus': (71, 96),
        'city_17_seattle': (74, 82),
        'O1_3p5': (5203, 181),
        'city_18_denver': (85, 82),
        'city_15_indianapolis': (80, 79),
        'city_19_oklahoma': (82, 75),
        'city_12_fortworth': (86, 72),
        'city_11_santaclara': (47, 114),
        'city_7_sandiego': (71, 83),
    }
    row_column_users = {
        name: {'n_rows': rows, 'n_per_row': per_row}
        for name, (rows, per_row) in grid_layout.items()
    }

    parameters = DeepMIMOv3.default_params()
    parameters['dataset_folder'] = './scenarios'
    parameters['scenario'] = scenario

    # Each ray-tracing scenario exposes a different usable base station.
    # NOTE(review): 'city_14_charlotte' is handled here but has no entry in
    # the layout table above — passing it would raise KeyError below; confirm.
    if scenario == 'O1_3p5':
        active_bs = 4
    elif scenario in ('city_14_charlotte', 'city_18_denver', 'city_15_indianapolis'):
        active_bs = 3
    else:
        active_bs = 1
    parameters['active_BS'] = np.array([active_bs])

    n_rows = row_column_users[scenario]['n_rows']
    if scenario == 'Boston5G_3p5':
        parameters['user_rows'] = np.arange(n_rows[0], n_rows[1])
    else:
        parameters['user_rows'] = np.arange(n_rows)

    parameters['bs_antenna']['shape'] = np.array([n_ant_bs, 1])
    parameters['bs_antenna']['rotation'] = np.array([0, 0, -135])
    parameters['ue_antenna']['shape'] = np.array([n_ant_ue, 1])
    parameters['enable_BS2BS'] = False
    parameters['OFDM']['subcarriers'] = n_subcarriers
    parameters['OFDM']['selected_subcarriers'] = np.arange(n_subcarriers)
    parameters['OFDM']['bandwidth'] = scs * n_subcarriers / 1e9
    parameters['num_paths'] = 20

    return parameters, row_column_users, n_ant_bs, n_ant_ue, n_subcarriers
| |
| |
# ---- Model and training hyperparameters ----
n_epochs = 100            # pretraining epochs
n_layers = 12             # transformer encoder depth — presumably consumed by lwm(); confirm
n_heads = 12              # attention heads per layer
d_model = 64              # embedding width
d_ff = d_model * 4        # feed-forward hidden width (standard 4x expansion)
d_k = d_model // n_heads  # per-head key dim: 64 // 12 = 5
d_v = d_model // n_heads  # per-head value dim: 64 // 12 = 5
# NOTE(review): d_model (64) is not divisible by n_heads (12), so
# n_heads * d_k = 60 != d_model — confirm this is intentional.
dropout = 0.1
max_len = 129             # maximum token sequence length
element_length = 16       # per-token element size — presumably patch length; confirm
batch_size = 64
train_ratio = 0.7         # 70% train; with val_ratio, leaves ~10% for test
val_ratio = 0.2
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
|
| |
| |
| |
| |
| |
# DeepMIMO city scenarios used as the pretraining corpus.
scenario_names = np.array([
    "city_18_denver", "city_15_indianapolis", "city_19_oklahoma",
    "city_12_fortworth", "city_11_santaclara", "city_7_sandiego"
])

# Select all six scenarios (indices into scenario_names).
scenario_idxs = np.array([0, 1, 2, 3, 4, 5])
selected_scenario_names = scenario_names[scenario_idxs]

# Tokenize the channel data for the selected scenarios; gen_raw=False
# presumably enables the masking pipeline — confirm in input_preprocess.
preprocessed_chs = tokenizer(
    selected_scenario_names=selected_scenario_names,
    manual_data=None,
    gen_raw=False)

# 70/20/10 train/val/test split; test gets the remainder so the three
# sizes always sum to len(preprocessed_chs).
train_size = int(train_ratio * len(preprocessed_chs))
val_size = int(val_ratio * len(preprocessed_chs))
test_size = len(preprocessed_chs) - val_size - train_size

train_data, val_data, test_data = torch.utils.data.random_split(
    preprocessed_chs, [train_size, val_size, test_size]
)

# NOTE(review): val and test loaders are also shuffled — harmless for the
# averaged losses computed below, but confirm it is intentional.
train_loader = prepare_for_lwm(train_data, device, batch_size=batch_size, shuffle=True)
val_loader = prepare_for_lwm(val_data, device, batch_size=batch_size, shuffle=True)
test_loader = prepare_for_lwm(test_data, device, batch_size=batch_size, shuffle=True)
|
|
| |
# Set to True to resume from a saved checkpoint.
load_model = False

model = lwm()
model.to(device)

if load_model:
    model_name = 'models/pretrained_model.pth'
    # map_location=device lets a checkpoint saved on GPU load on a
    # CPU-only host (and vice versa); without it torch.load tries to
    # restore tensors onto the device they were saved from.
    model.load_state_dict(torch.load(model_name, map_location=device))
    print(f"Model loaded from {model_name}")

# NOTE(review): criterionMLM is never used below — train()/validate()
# construct their own MSELoss — confirm before removing.
criterionMLM = nn.MSELoss()

# True -> ReduceLROnPlateau (needs the val metric passed to step());
# False -> fixed StepLR decay.
adaptive_lr = False

optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)
scheduler = (
    optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min')
    if adaptive_lr
    else StepLR(optimizer, step_size=10, gamma=0.9)
)

# Per-epoch loss histories.
training_loss = []
validation_loss = []
|
|
def train(model, dataloader, optimizer, scheduler=None, device="cuda"):
    """Run one epoch of masked-channel-modeling training.

    Args:
        model: network whose forward(input_ids, masked_pos) returns
            (predicted masked tokens, auxiliary output).
        dataloader: yields batches of (input_ids, masked_tokens, masked_pos).
        optimizer: stepped once per batch.
        scheduler: optional LR scheduler, stepped once per batch.
            ReduceLROnPlateau is skipped here because its step() requires a
            validation metric — calling it without one raises TypeError;
            the caller must step it with the validation loss instead.
        device: device string batches are moved to.

    Returns:
        Mean variance-normalized MSE loss over all batches.
    """
    model.train()
    running_loss = 0.0
    criterionMCM = nn.MSELoss()

    for idx, batch in enumerate(dataloader):
        input_ids = batch[0].to(device)
        masked_tokens = batch[1].to(device)
        masked_pos = batch[2].to(device)

        optimizer.zero_grad()

        logits_lm, _ = model(input_ids, masked_pos)
        loss_lm = criterionMCM(logits_lm, masked_tokens)
        # Normalize by the batch's token variance so the objective is
        # scale-invariant (NMSE-style).
        loss = loss_lm / torch.var(masked_tokens)

        loss.backward()
        optimizer.step()

        # NOTE(review): stepping per batch means StepLR(step_size=10)
        # decays every 10 *batches*, not epochs — confirm intent.
        if scheduler is not None and not isinstance(
                scheduler, optim.lr_scheduler.ReduceLROnPlateau):
            scheduler.step()

        running_loss += loss.item()

    average_loss = running_loss / len(dataloader)

    return average_loss
|
|
def validate(model, dataloader, device="cuda"):
    """Evaluate the model on a held-out split.

    Computes the same variance-normalized MSE used during training, with
    gradients disabled, and returns its mean over all batches.
    """
    model.eval()
    mse = nn.MSELoss()
    total = 0.0

    with torch.no_grad():
        for step, batch in enumerate(dataloader):
            inputs = batch[0].to(device)
            targets = batch[1].to(device)
            positions = batch[2].to(device)

            predictions, _ = model(inputs, positions)

            # Scale-invariant (NMSE-style) loss, matching train().
            batch_loss = mse(predictions, targets) / torch.var(targets)
            total += batch_loss.item()

    return total / len(dataloader)
|
|
| |
# Main pretraining loop: one training pass and one validation pass per epoch.
for epoch in range(n_epochs):
    print(f"Epoch {epoch + 1}/{n_epochs}")

    # train() also steps the scheduler once per batch.
    train_loss = train(model, train_loader, optimizer, scheduler, device)
    training_loss.append(train_loss)
    print(f"Training Loss: {train_loss:.4f}")

    if val_loader is not None:
        val_loss = validate(model, val_loader, device)
        validation_loss.append(val_loss)
        print(f"Validation Loss: {val_loss:.4f}")
        # NOTE(review): if adaptive_lr were enabled, ReduceLROnPlateau
        # would need scheduler.step(val_loss) here — it never receives
        # the validation metric in the current flow; confirm.
|
|