| import os |
| import torch |
| from torch.utils.data import Dataset |
| import json |
|
|
| from torch_geometric.data import HeteroData |
| import networkx as nx |
|
|
class PowerFlowDataset(Dataset):
    """Dataset of heterogeneous power-flow graphs stored as ``torch.save`` files.

    Each line of ``split_txt`` is a JSON object describing one sample. Every
    sample must have a ``'file_path'`` key and may have an optional
    ``'masked_node'`` list. ``__getitem__`` loads the stored graph. It then
    copies the first four feature columns of the ``'PQ'``, ``'PV'`` and
    ``'Slack'`` node types into ``.y`` as targets, and overwrites the inputs
    that the model is expected to predict.
    """

    def __init__(self, data_root, split_txt, pq_len, pv_len, slack_len, mask_num=0):
        """
        Args:
            data_root: Stored on the instance. Note that ``__getitem__``
                currently does not use it to resolve file paths.
            split_txt: Path to a JSON-lines file with one sample description
                per line. Blank lines are ignored.
            pq_len, pv_len, slack_len: Stored as-is. No method of this class
                reads them.
            mask_num: How many PQ nodes get their Q_net input zeroed (and
                ``q_mask`` set to False) per sample. 0 disables masking.
        """
        self.data_root = data_root
        with open(split_txt, 'r', encoding='utf-8') as f:
            # Skip blank or whitespace-only lines (e.g. an extra trailing
            # newline) instead of failing on json.loads('').
            self.file_list = [json.loads(line) for line in f if line.strip()]
        self.pq_len = pq_len
        self.pv_len = pv_len
        self.slack_len = slack_len
        self.mask_num = mask_num

        # Graph-distance bookkeeping. No method here assigns these apart from
        # max_depth; presumably a caller fills them in. TODO confirm.
        self.flag_distance_once_calculated = False
        self.shortest_paths = None
        self.node_type_to_global_index = None
        self.max_depth = 16

    def __len__(self):
        """Return the number of samples listed in the split file."""
        return len(self.file_list)

    def update_max_depth(self):
        """Lower ``max_depth`` to the largest value in ``shortest_paths``.

        ``max_depth`` only ever decreases. When ``shortest_paths`` has not been
        set (``None``) or is empty, this is a no-op instead of raising.
        """
        if not self.shortest_paths:
            return
        # min() keeps the current value on ties, matching a strict '<' update.
        self.max_depth = min(self.max_depth, max(self.shortest_paths.values()))

    def __getitem__(self, idx):
        """Load sample ``idx`` and build model inputs (``x``) and targets (``y``).

        For each node type, ``y`` is a copy of columns [Vm, Va, P_net, Q_net]
        taken from the loaded ``x``. The following ``x`` entries are then
        overwritten in place:

        * PQ: Vm := 1.0, and Va := the first Slack node's Va. When
          ``mask_num > 0``, up to ``mask_num`` nodes get ``q_mask`` False and
          Q_net := 0. The nodes come from ``file_dict['masked_node']`` if that
          key is present; otherwise they are drawn at random from nodes with
          non-zero Q_net.
        * PV: Va := the first Slack node's Va, and Q_net := 0.
        * Slack: P_net := 0, and Q_net := 0.

        Returns:
            The loaded graph object (HeteroData per the file's imports; confirm).
        """
        file_dict = self.file_list[idx]
        # NOTE(review): os.path.join with a single argument returns it
        # unchanged, so self.data_root is ignored. 'file_path' is resolved as
        # given (absolute, or relative to the CWD). Confirm whether it was
        # meant to be joined onto self.data_root.
        # NOTE(review): torch.load unpickles arbitrary objects, so only load
        # trusted files. Newer PyTorch releases default to weights_only=True,
        # which may refuse to load graph objects; verify against the installed
        # version.
        data = torch.load(os.path.join(file_dict['file_path']))

        # Feature column layout of x. Columns 4 and 5 (Gs, Bs) are left untouched.
        Vm, Va, P_net, Q_net = 0, 1, 2, 3
        target_cols = [Vm, Va, P_net, Q_net]

        slack = data['Slack']
        pq = data['PQ']
        pv = data['PV']

        # Every node's angle is initialised from the first Slack node. Slack Va
        # is never modified below, so reading it once up front is safe.
        slack_va = slack.x[0, Va].item()

        # --- PQ nodes ---
        pq.y = pq.x[:, target_cols].clone().detach()
        pq.x[:, Vm] = 1.0
        pq.x[:, Va] = slack_va

        pq.q_mask = torch.ones((pq.x.shape[0],), dtype=torch.bool)
        if self.mask_num > 0:
            masked_node = file_dict.get('masked_node')
            if masked_node is None:
                # Random subset (at most mask_num) of PQ nodes with non-zero Q_net.
                candidates = torch.nonzero(pq.x[:, Q_net])
                mask_indices = candidates[torch.randperm(candidates.shape[0])[:self.mask_num]]
            else:
                mask_indices = masked_node[:self.mask_num]
            pq.q_mask[mask_indices] = False
            pq.x[~pq.q_mask, Q_net] = 0

        # --- PV nodes ---
        pv.y = pv.x[:, target_cols].clone().detach()
        pv.x[:, Va] = slack_va
        pv.x[:, Q_net] = 0

        # --- Slack nodes ---
        slack.y = slack.x[:, target_cols].clone().detach()
        slack.x[:, P_net] = 0
        slack.x[:, Q_net] = 0

        return data
|
|