| from utils import CustomDataset, transform, Convert_ONNX |
| from torch.utils.data import Dataset, DataLoader |
| import torch |
| import numpy as np |
| from resnet_model_mask import ResidualBlock, ResNet |
| import torch |
| import torch.nn as nn |
| import torch.optim as optim |
| from tqdm import tqdm |
| import torch.nn.functional as F |
| from torch.optim.lr_scheduler import ReduceLROnPlateau |
| import pickle |
|
|
# Seed torch's global RNG so that any randomness in this run is reproducible.
torch.manual_seed(1)

# Use the first CUDA device when CUDA is available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Report how many CUDA devices torch can see.
num_gpus = torch.cuda.device_count()
print(num_gpus)
|
|
| |
# Training and validation datasets, both built with the shared `transform`.
# NOTE(review): nothing later in this script reads `dataset`, `valid_dataset`
# or `trainloader` -- they look like leftovers from the training script.
# Confirm before removing them, since constructing CustomDataset may touch
# the directories.
data_dir = '/mnt/buf0/pma/frbnn/train_ready'
valid_data_dir = '/mnt/buf0/pma/frbnn/valid_ready'

# Construct train first, then valid (same order as before).
dataset, valid_dataset = (
    CustomDataset(root, transform=transform) for root in (data_dir, valid_data_dir)
)

# Two output classes for the classifier built below.
num_classes = 2
trainloader = DataLoader(
    dataset,
    batch_size=420,
    shuffle=True,
    num_workers=32,
)
|
|
# Build the classifier and move it to the selected device once.
# NOTE(review): the leading 24 is presumably the number of input channels
# expected by resnet_model_mask.ResNet -- confirm against its signature.
model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)

# Wrap for multi-GPU execution. The checkpoint loaded below must therefore
# carry DataParallel-style ("module."-prefixed) keys -- TODO confirm.
model = nn.DataParallel(model)

params = sum(p.numel() for p in model.parameters())
print("num params ", params)

model_path = 'models/model-47-99.125.pt'

# map_location lets a GPU-saved checkpoint load when `device` fell back to
# the CPU; weights_only restricts unpickling to tensors/plain containers.
model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
model = model.eval()
|
|
| |
| import sigpyproc.readers as r |
| import cv2 |
| import numpy as np |
| import matplotlib.pyplot as plt |
| from scipy.special import softmax |
| from tqdm import tqdm |
|
|
# Windows from every file that pass that file's plot threshold, gathered for
# the combined figure and summary that follow.
all_detections = []

# Threshold used for the per-file "File N detections: ..." count.
COUNT_THRESHOLD = 0.9988
# Samples per classifier window; each window starts half a window before the
# scan position.
WINDOW = 2048

# One entry per filterbank file to scan.
# NOTE(review): file 1 collects plot candidates above 0.9982 and stores raw
# data, while its printed count (like every file's) uses COUNT_THRESHOLD, so
# its count and its plotted entries can disagree. Preserved as-is; confirm
# which threshold is intended.
SCAN_TARGETS = [
    {
        'banner': "Processing first file (SNR 180)...",
        'path': '/mnt/primary/ata/results/p031/FRB20240114a_spliced/fil_60398_67123_110077819_frb20240114a_0001/LoC.C0736/decimated.fil',
        'start': 27085468,
        'stop': 27397968,
        'plot_threshold': 0.9982,
        'normalize': False,
        'file_name': 'fil_60398_67123_110077819_frb20240114a_0001 (SNR 180)',
    },
    {
        'banner': "Processing second file (SNR 60)...",
        'path': '/mnt/primary/ata/results/p031/FRB20240114a_spliced/fil_60428_58167_24730285_frb20240114a_0001/LoC.C1504/decimated.fil',
        'start': 8148984,
        'stop': 8461484,
        'plot_threshold': 0.9988,
        'normalize': True,
        'file_name': 'fil_60428_58167_24730285_frb20240114a_0001 (SNR 60)',
    },
    {
        'banner': "Processing third file...",
        'path': '/mnt/primary/ata/results/p031/FRB20240114a_spliced/fil_60427_42703_18513000_frb20240114a_0001/LoC.C1504/decimated.fil',
        'start': 20343125,
        'stop': 20655625,
        'plot_threshold': 0.9988,
        'normalize': True,
        'file_name': 'fil_60427_42703_18513000_frb20240114a_0001',
    },
    {
        'banner': "Processing fourth file...",
        'path': '/mnt/primary/ata/results/p031/FRB20240114a_spliced/fil_60395_72956_94613525_frb20240114a_0001/LoB.C1312/decimated.fil',
        'start': 8708515,
        'stop': 9021015,
        'plot_threshold': 0.9988,
        'normalize': True,
        'file_name': 'fil_60395_72956_94613525_frb20240114a_0001',
    },
    {
        'banner': "Processing fifth file...",
        'path': '/mnt/primary/ata/results/p031/FRB20240114a_spliced/fil_60429_47342_29343017_frb20240114a_0001/LoB.C1120/decimated.fil',
        'start': 10399062,
        'stop': 10711562,
        'plot_threshold': 0.9988,
        'normalize': True,
        'file_name': 'fil_60429_47342_29343017_frb20240114a_0001',
    },
    {
        'banner': "Processing sixth file...",
        'path': '/mnt/primary/ata/results/p031/FRB20240114a_spliced/fil_60456_42557_118616821_frb20240114a_0001/LoC.C1312/decimated.fil',
        'start': 1250000,
        'stop': 1562500,
        'plot_threshold': 0.9988,
        'normalize': True,
        'file_name': 'fil_60456_42557_118616821_frb20240114a_0001',
    },
]


def _scan_file(path, start, stop, plot_threshold, normalize, file_name):
    """Classify fixed-width windows of one filterbank file.

    For each ``i`` in ``range(start, stop, WINDOW)`` a block of WINDOW
    samples starting at ``i - WINDOW // 2`` is read, passed through
    ``transform`` and the model, and softmaxed over axis 1. The
    positive-class (index 1) probability of every window is recorded.

    Args:
        path: filterbank file to open with sigpyproc's FilReader.
        start, stop: scan range passed to ``range`` with step WINDOW.
        plot_threshold: windows whose probability is strictly greater than
            this are returned as detections.
        normalize: if True, a detection's stored data has each row divided
            by that row's mean over axis 1; otherwise the raw block is kept.
        file_name: label stored in each detection dict.

    Returns:
        ``(probs, detections)``: ``probs`` is a 1-D array with one
        positive-class probability per window; ``detections`` is a list of
        dicts with keys 'data', 'confidence', 'file_index', 'file_name',
        'normalization' and 'header'.
    """
    fil = r.FilReader(path)
    header = fil.header
    print(header)

    probs = []
    detections = []
    # Pure inference: skip building the autograd graph.
    with torch.no_grad():
        for i in tqdm(range(start, stop, WINDOW)):
            data = torch.tensor(fil.read_block(i - WINDOW // 2, WINDOW)).to(device)
            # Give `transform` its own copy (as the old torch.tensor(data)
            # re-wrap did), so `data` stays raw even if transform is in-place.
            logits = model(transform(data.clone())[None])
            out = softmax(logits.cpu().numpy(), axis=1)
            confidence = out[0, 1]
            probs.append(confidence)
            if confidence > plot_threshold:
                key = data.cpu().numpy()
                if normalize:
                    # Per-row mean normalisation; keepdims broadcasts over
                    # axis 1, so a short final block cannot cause a shape error.
                    key = key / key.mean(axis=1, keepdims=True)
                detections.append({
                    'data': key,
                    'confidence': confidence,
                    'file_index': i,
                    'file_name': file_name,
                    'normalization': 'normalized' if normalize else 'raw',
                    'header': header,
                })
    return np.asarray(probs), detections


for file_number, target in enumerate(SCAN_TARGETS, start=1):
    print(target['banner'])
    probs, detections = _scan_file(
        target['path'],
        target['start'],
        target['stop'],
        target['plot_threshold'],
        target['normalize'],
        target['file_name'],
    )
    all_detections.extend(detections)
    num_pos = np.count_nonzero(probs > COUNT_THRESHOLD)
    print(f"File {file_number} detections: {num_pos}")
|
|
| |
print(f"\nTotal detections found: {len(all_detections)}")

if all_detections:
    # Most confident first, so the top detections fill the grid from the top.
    all_detections.sort(key=lambda x: x['confidence'], reverse=True)

    n_detections = len(all_detections)
    cols = 2
    rows = 5

    fig, axes = plt.subplots(rows, cols, figsize=(10, 12))
    axes_flat = axes.flatten()

    # The grid has rows*cols panels; indexing past it used to raise
    # IndexError. Draw only the most confident ones (the summary below still
    # lists every detection).
    n_plotted = min(n_detections, len(axes_flat))
    if n_plotted < n_detections:
        print(f"Plotting the top {n_plotted} of {n_detections} detections.")

    # Seconds per sample used for the time axis.
    # NOTE(review): hard-coded; the filterbank header may carry the actual
    # sample interval of the decimated data -- confirm.
    time_increment = 6.5e-5
    n_ticks = 5
    # Index of the first panel in the bottom row (gets the x-axis label).
    first_bottom_idx = (rows - 1) * cols

    for idx, detection in enumerate(all_detections[:n_plotted]):
        ax = axes_flat[idx]
        panel = detection['data']

        # Values at or below the panel median all map to the bottom colour.
        ax.imshow(panel, aspect=6, cmap='hot', vmin=np.median(panel))

        # Time axis: columns of the panel, labelled in seconds.
        n_samples = panel.shape[1]
        tick_positions = np.linspace(0, n_samples - 1, n_ticks)
        ax.set_xticks(tick_positions)
        ax.set_xticklabels(
            [f"{pos * time_increment:.2f}" for pos in tick_positions], fontsize=12
        )
        if idx >= first_bottom_idx:
            ax.set_xlabel('Time (seconds)', fontsize=14)

        # Frequency axis: channel index -> fch1 + index * foff.
        hdr = detection['header']
        freq_tick_positions = np.linspace(0, hdr.nchans - 1, n_ticks)
        freq_labels = [f"{hdr.fch1 + pos * hdr.foff:.1f}" for pos in freq_tick_positions]
        ax.set_yticks(freq_tick_positions)
        ax.set_yticklabels(freq_labels, fontsize=12)
        if idx % cols == 0:
            ax.set_ylabel('Freq. (MHz)', fontsize=14)

        ax.tick_params(axis='both', which='major', size=3)

    # Hide any panels left empty.
    for ax in axes_flat[n_plotted:]:
        ax.set_visible(False)

    out_path = 'combined_frb_detections.pdf'
    plt.subplots_adjust(hspace=0.3, wspace=0.2)
    plt.savefig(out_path, dpi=150, bbox_inches='tight', format='pdf')
    plt.show()

    # Report the file that was actually written (previously said ".png").
    print(f"Combined plot saved as '{out_path}'")

    print("\nDetection Summary:")
    for rank, detection in enumerate(all_detections, start=1):
        print(f"{rank}. {detection['file_name'][:50]}... - Confidence: {detection['confidence']:.4f}")
else:
    print("No detections found across all files.")