|
|
| import numpy as np
|
| import torch
|
| from input_preprocess import DeepMIMO_data_gen, deepmimo_data_cleaning, tokenizer
|
| from inference import lwm_inference, create_raw_dataset
|
| from lwm_model import lwm
|
|
|
|
|
# Candidate scenario names passed to DeepMIMO_data_gen. Kept as a NumPy array
# so it can be fancy-indexed with an array of positions below.
scenario_names = np.array([
    "city_18_denver", "city_15_indianapolis", "city_19_oklahoma",
    "city_12_fortworth", "city_11_santaclara", "city_7_sandiego",
])

# Position in `scenario_names` of the scenario used for the beamforming task.
bf_scenario_idx = 3
scenario_idxs = np.array([bf_scenario_idx])
selected_scenario_names = scenario_names[scenario_idxs]

# Run the project's generation and cleaning helpers once per selected scenario.
# Both lists are in the same order as `selected_scenario_names`.
deepmimo_data = [DeepMIMO_data_gen(scenario_name) for scenario_name in selected_scenario_names]
cleaned_deepmimo_data = [deepmimo_data_cleaning(data) for data in deepmimo_data]
|
|
|
|
|
def compute_mrt_beamforming(channel_data, snr_db=None, scenario_idx=0):
    """Compute one unit-norm MRT beamforming vector per user.

    Each user's channel is optionally corrupted with complex Gaussian noise.
    It is then averaged over its last axis, conjugated and normalized to unit
    norm. All users are processed in one vectorized pass.

    Args:
        channel_data: Sequence of per-scenario channel arrays (NumPy arrays or
            torch tensors). Only ``channel_data[scenario_idx]`` is used.
            - Layout: indexed as a 4-D array ``(users, A, B, C)``.
            - Axis 1: only index 0 is used.
            - Axis 3: averaged over.
            - Axis 2: the axis the beamformer spans.
            NOTE(review): presumably ``(users, rx antennas, tx antennas,
            subcarriers)`` as produced by the DeepMIMO generator; confirm.
        snr_db: If given, the SNR in dB of additive noise applied per user.
            Each user's noise power is its own mean channel power divided by
            the linear SNR. ``None`` (default) uses the channels as-is.
        scenario_idx: Which entry of ``channel_data`` to use. Defaults to 0,
            which matches the previous hard-coded behaviour.

    Returns:
        A tensor of shape ``(users, B, 1)`` holding each user's conjugated,
        averaged channel scaled to unit norm. Users whose averaged channel is
        all zeros get an all-zero vector instead of NaNs. An empty user axis
        yields an empty tensor instead of an error.
    """
    # as_tensor avoids a copy (and the copy-construct warning for tensor input).
    channels = torch.as_tensor(channel_data[scenario_idx])[:, 0, :, :]

    if snr_db is not None:
        snr_linear = 10 ** (snr_db / 10)
        # Noise is scaled per user, relative to that user's own mean channel power.
        signal_power = torch.mean(torch.abs(channels) ** 2, dim=(1, 2), keepdim=True)
        noise_power = signal_power / snr_linear
        # randn_like(x) + 1j * randn_like(x) has mean power 2, so scaling by
        # sqrt(P / 2) yields noise of mean power P.
        noise = torch.sqrt(noise_power / 2) * (
            torch.randn_like(channels) + 1j * torch.randn_like(channels)
        )
        channels = channels + noise

    h_conj = torch.conj(torch.mean(channels, dim=2, keepdim=True))
    norms = torch.linalg.vector_norm(h_conj, dim=1, keepdim=True)
    # Guard against 0/0 = NaN when a user's averaged channel vanishes.
    norms = torch.where(norms > 0, norms, torch.ones_like(norms))
    return h_conj / norms
|
|
|
|
|
# MRT beamformers computed from the cleaned channels (no noise: snr_db left unset).
beamforming_vectors = compute_mrt_beamforming(cleaned_deepmimo_data)

# Tokenize the selected scenarios by name (no manual data, gen_raw disabled).
preprocessed_chs = tokenizer(
    selected_scenario_names=selected_scenario_names,
    manual_data=None,
    gen_raw=False,
)

# Load the pretrained LWM model, preferring the GPU when one is available.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"Loading the LWM model on {device} ...")
model = lwm.from_pretrained(device=device)

# Choose which representation feeds the downstream task.
input_types = ['cls_emb', 'channel_emb', 'raw']
selected_input_type = input_types[1]

# Embedding types go through the model; anything else falls back to raw channels.
if selected_input_type not in ('cls_emb', 'channel_emb'):
    dataset = create_raw_dataset(preprocessed_chs, device)
else:
    dataset = lwm_inference(preprocessed_chs, selected_input_type, model, device)
|
|
|