Spaces:
Sleeping
Sleeping
Upload 46 files
Browse files- matdeeplearn/__init__.py +3 -0
- matdeeplearn/__pycache__/__init__.cpython-311.pyc +0 -0
- matdeeplearn/__pycache__/__init__.cpython-37.pyc +0 -0
- matdeeplearn/__pycache__/config.cpython-37.pyc +0 -0
- matdeeplearn/__pycache__/models.cpython-37.pyc +0 -0
- matdeeplearn/__pycache__/process.cpython-37.pyc +0 -0
- matdeeplearn/__pycache__/process_HEA.cpython-37.pyc +0 -0
- matdeeplearn/__pycache__/training.cpython-37.pyc +0 -0
- matdeeplearn/models/__init__.py +16 -0
- matdeeplearn/models/__pycache__/MLP.cpython-37.pyc +0 -0
- matdeeplearn/models/__pycache__/__init__.cpython-311.pyc +0 -0
- matdeeplearn/models/__pycache__/__init__.cpython-37.pyc +0 -0
- matdeeplearn/models/__pycache__/cgcnn.cpython-37.pyc +0 -0
- matdeeplearn/models/__pycache__/cgcnn_nmr.cpython-37.pyc +0 -0
- matdeeplearn/models/__pycache__/cgcnn_test.cpython-37.pyc +0 -0
- matdeeplearn/models/__pycache__/cnnet.cpython-37.pyc +0 -0
- matdeeplearn/models/__pycache__/descriptor_nn.cpython-37.pyc +0 -0
- matdeeplearn/models/__pycache__/gcn.cpython-311.pyc +0 -0
- matdeeplearn/models/__pycache__/gcn.cpython-37.pyc +0 -0
- matdeeplearn/models/__pycache__/megnet.cpython-37.pyc +0 -0
- matdeeplearn/models/__pycache__/mpnn.cpython-37.pyc +0 -0
- matdeeplearn/models/__pycache__/schnet.cpython-37.pyc +0 -0
- matdeeplearn/models/__pycache__/test_cgcnn2.cpython-37.pyc +0 -0
- matdeeplearn/models/__pycache__/test_dosgnn.cpython-37.pyc +0 -0
- matdeeplearn/models/__pycache__/test_forces.cpython-37.pyc +0 -0
- matdeeplearn/models/__pycache__/test_matgnn.cpython-37.pyc +0 -0
- matdeeplearn/models/__pycache__/test_misc.cpython-37.pyc +0 -0
- matdeeplearn/models/__pycache__/testing.cpython-37.pyc +0 -0
- matdeeplearn/models/__pycache__/utils.cpython-37.pyc +0 -0
- matdeeplearn/models/cgcnn.py +174 -0
- matdeeplearn/models/descriptor_nn.py +69 -0
- matdeeplearn/models/gcn.py +173 -0
- matdeeplearn/models/megnet.py +371 -0
- matdeeplearn/models/mpnn.py +188 -0
- matdeeplearn/models/schnet.py +172 -0
- matdeeplearn/models/utils.py +23 -0
- matdeeplearn/process/__init__.py +1 -0
- matdeeplearn/process/__pycache__/__init__.cpython-37.pyc +0 -0
- matdeeplearn/process/__pycache__/process.cpython-37.pyc +0 -0
- matdeeplearn/process/dictionary_blank.json +1 -0
- matdeeplearn/process/dictionary_default.json +1 -0
- matdeeplearn/process/process.py +703 -0
- matdeeplearn/training/__init__.py +1 -0
- matdeeplearn/training/__pycache__/__init__.cpython-37.pyc +0 -0
- matdeeplearn/training/__pycache__/training.cpython-37.pyc +0 -0
- matdeeplearn/training/training.py +1290 -0
matdeeplearn/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .models import *
|
| 2 |
+
from .training import *
|
| 3 |
+
from .process import *
|
matdeeplearn/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (254 Bytes). View file
|
|
|
matdeeplearn/__pycache__/__init__.cpython-37.pyc
ADDED
|
Binary file (247 Bytes). View file
|
|
|
matdeeplearn/__pycache__/config.cpython-37.pyc
ADDED
|
Binary file (1.44 kB). View file
|
|
|
matdeeplearn/__pycache__/models.cpython-37.pyc
ADDED
|
Binary file (27.2 kB). View file
|
|
|
matdeeplearn/__pycache__/process.cpython-37.pyc
ADDED
|
Binary file (14.7 kB). View file
|
|
|
matdeeplearn/__pycache__/process_HEA.cpython-37.pyc
ADDED
|
Binary file (7.64 kB). View file
|
|
|
matdeeplearn/__pycache__/training.cpython-37.pyc
ADDED
|
Binary file (17.4 kB). View file
|
|
|
matdeeplearn/models/__init__.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .gcn import GCN
|
| 2 |
+
from .mpnn import MPNN
|
| 3 |
+
from .schnet import SchNet
|
| 4 |
+
from .cgcnn import CGCNN
|
| 5 |
+
from .megnet import MEGNet
|
| 6 |
+
from .descriptor_nn import SOAP, SM
|
| 7 |
+
|
| 8 |
+
__all__ = [
|
| 9 |
+
"GCN",
|
| 10 |
+
"MPNN",
|
| 11 |
+
"SchNet",
|
| 12 |
+
"CGCNN",
|
| 13 |
+
"MEGNet",
|
| 14 |
+
"SOAP",
|
| 15 |
+
"SM",
|
| 16 |
+
]
|
matdeeplearn/models/__pycache__/MLP.cpython-37.pyc
ADDED
|
Binary file (5.47 kB). View file
|
|
|
matdeeplearn/models/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (568 Bytes). View file
|
|
|
matdeeplearn/models/__pycache__/__init__.cpython-37.pyc
ADDED
|
Binary file (496 Bytes). View file
|
|
|
matdeeplearn/models/__pycache__/cgcnn.cpython-37.pyc
ADDED
|
Binary file (3.56 kB). View file
|
|
|
matdeeplearn/models/__pycache__/cgcnn_nmr.cpython-37.pyc
ADDED
|
Binary file (1.65 kB). View file
|
|
|
matdeeplearn/models/__pycache__/cgcnn_test.cpython-37.pyc
ADDED
|
Binary file (2.01 kB). View file
|
|
|
matdeeplearn/models/__pycache__/cnnet.cpython-37.pyc
ADDED
|
Binary file (4.94 kB). View file
|
|
|
matdeeplearn/models/__pycache__/descriptor_nn.cpython-37.pyc
ADDED
|
Binary file (2.47 kB). View file
|
|
|
matdeeplearn/models/__pycache__/gcn.cpython-311.pyc
ADDED
|
Binary file (8.74 kB). View file
|
|
|
matdeeplearn/models/__pycache__/gcn.cpython-37.pyc
ADDED
|
Binary file (3.56 kB). View file
|
|
|
matdeeplearn/models/__pycache__/megnet.cpython-37.pyc
ADDED
|
Binary file (8.58 kB). View file
|
|
|
matdeeplearn/models/__pycache__/mpnn.cpython-37.pyc
ADDED
|
Binary file (3.84 kB). View file
|
|
|
matdeeplearn/models/__pycache__/schnet.cpython-37.pyc
ADDED
|
Binary file (3.68 kB). View file
|
|
|
matdeeplearn/models/__pycache__/test_cgcnn2.cpython-37.pyc
ADDED
|
Binary file (1.95 kB). View file
|
|
|
matdeeplearn/models/__pycache__/test_dosgnn.cpython-37.pyc
ADDED
|
Binary file (5.17 kB). View file
|
|
|
matdeeplearn/models/__pycache__/test_forces.cpython-37.pyc
ADDED
|
Binary file (6.1 kB). View file
|
|
|
matdeeplearn/models/__pycache__/test_matgnn.cpython-37.pyc
ADDED
|
Binary file (4.34 kB). View file
|
|
|
matdeeplearn/models/__pycache__/test_misc.cpython-37.pyc
ADDED
|
Binary file (2.67 kB). View file
|
|
|
matdeeplearn/models/__pycache__/testing.cpython-37.pyc
ADDED
|
Binary file (16.2 kB). View file
|
|
|
matdeeplearn/models/__pycache__/utils.cpython-37.pyc
ADDED
|
Binary file (1.29 kB). View file
|
|
|
matdeeplearn/models/cgcnn.py
ADDED
|
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import Tensor
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from torch.nn import Sequential, Linear, BatchNorm1d
|
| 5 |
+
import torch_geometric
|
| 6 |
+
from torch_geometric.nn import (
|
| 7 |
+
Set2Set,
|
| 8 |
+
global_mean_pool,
|
| 9 |
+
global_add_pool,
|
| 10 |
+
global_max_pool,
|
| 11 |
+
CGConv,
|
| 12 |
+
)
|
| 13 |
+
from torch_scatter import scatter_mean, scatter_add, scatter_max, scatter
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
# CGCNN
|
| 17 |
+
class CGCNN(torch.nn.Module):
    """Crystal Graph Convolutional Neural Network.

    Stacks optional pre-GNN dense layers, `gc_count` CGConv message-passing
    layers, a graph pooling step, and optional post-GNN dense layers ending
    in a linear output head.

    Args:
        data: dataset-like object; `data.num_features`, `data.num_edge_features`
            and `data[0].y` are read to size the layers.
        dim1: width of pre-GNN dense layers (and GC layers when pre_fc_count > 0).
        dim2: width of post-GNN dense layers.
        pre_fc_count: number of dense layers before the GNN stack (0 disables).
        gc_count: number of CGConv layers (must be >= 1).
        post_fc_count: number of dense layers after pooling (0 disables).
        pool: name of a torch_geometric.nn pooling function, or "set2set".
        pool_order: "early" (pool before dense head) or "late" (pool after).
        batch_norm: string flag "True"/"False" — note: compared as a string,
            not a bool; any value other than exactly "True" disables BN.
        batch_track_stats: string flag; anything but "False" enables tracking.
        act: name of an activation in torch.nn.functional (e.g. "relu").
        dropout_rate: dropout probability applied after each GC layer.
        **kwargs: ignored; accepted for config-driven construction.
    """

    def __init__(
        self,
        data,
        dim1=64,
        dim2=64,
        pre_fc_count=1,
        gc_count=3,
        post_fc_count=1,
        pool="global_mean_pool",
        pool_order="early",
        batch_norm="True",
        batch_track_stats="True",
        act="relu",
        dropout_rate=0.0,
        **kwargs
    ):
        super(CGCNN, self).__init__()

        # String-valued flag: only the literal "False" disables stat tracking.
        if batch_track_stats == "False":
            self.batch_track_stats = False
        else:
            self.batch_track_stats = True
        self.batch_norm = batch_norm
        self.pool = pool
        self.act = act
        self.pool_order = pool_order
        self.dropout_rate = dropout_rate

        ##Determine gc dimension dimension
        assert gc_count > 0, "Need at least 1 GC layer"
        if pre_fc_count == 0:
            gc_dim = data.num_features
        else:
            gc_dim = dim1
        ##Determine post_fc dimension
        if pre_fc_count == 0:
            post_fc_dim = data.num_features
        else:
            post_fc_dim = dim1
        ##Determine output dimension length
        # Scalar target -> single output; otherwise one output per target column.
        if data[0].y.ndim == 0:
            output_dim = 1
        else:
            output_dim = len(data[0].y[0])

        ##Set up pre-GNN dense layers (NOTE: in v0.1 this is always set to 1 layer)
        if pre_fc_count > 0:
            self.pre_lin_list = torch.nn.ModuleList()
            for i in range(pre_fc_count):
                if i == 0:
                    lin = torch.nn.Linear(data.num_features, dim1)
                    self.pre_lin_list.append(lin)
                else:
                    lin = torch.nn.Linear(dim1, dim1)
                    self.pre_lin_list.append(lin)
        elif pre_fc_count == 0:
            # Empty list keeps forward() uniform (len() == 0 branch).
            self.pre_lin_list = torch.nn.ModuleList()

        ##Set up GNN layers
        self.conv_list = torch.nn.ModuleList()
        self.bn_list = torch.nn.ModuleList()
        for i in range(gc_count):
            # CGConv's internal batch_norm is off; BN is applied externally below.
            conv = CGConv(
                gc_dim, data.num_edge_features, aggr="mean", batch_norm=False
            )
            self.conv_list.append(conv)
            ##Track running stats set to false can prevent some instabilities; this causes other issues with different val/test performance from loader size?
            if self.batch_norm == "True":
                bn = BatchNorm1d(gc_dim, track_running_stats=self.batch_track_stats)
                self.bn_list.append(bn)

        ##Set up post-GNN dense layers (NOTE: in v0.1 there was a minimum of 2 dense layers, and fc_count(now post_fc_count) added to this number. In the current version, the minimum is zero)
        if post_fc_count > 0:
            self.post_lin_list = torch.nn.ModuleList()
            for i in range(post_fc_count):
                if i == 0:
                    ##Set2set pooling has doubled dimension
                    if self.pool_order == "early" and self.pool == "set2set":
                        lin = torch.nn.Linear(post_fc_dim * 2, dim2)
                    else:
                        lin = torch.nn.Linear(post_fc_dim, dim2)
                    self.post_lin_list.append(lin)
                else:
                    lin = torch.nn.Linear(dim2, dim2)
                    self.post_lin_list.append(lin)
            self.lin_out = torch.nn.Linear(dim2, output_dim)

        elif post_fc_count == 0:
            self.post_lin_list = torch.nn.ModuleList()
            if self.pool_order == "early" and self.pool == "set2set":
                self.lin_out = torch.nn.Linear(post_fc_dim*2, output_dim)
            else:
                self.lin_out = torch.nn.Linear(post_fc_dim, output_dim)

        ##Set up set2set pooling (if used)
        ##Should processing_setps be a hypereparameter?
        if self.pool_order == "early" and self.pool == "set2set":
            self.set2set = Set2Set(post_fc_dim, processing_steps=3)
        elif self.pool_order == "late" and self.pool == "set2set":
            self.set2set = Set2Set(output_dim, processing_steps=3, num_layers=1)
            # workaround for doubled dimension by set2set; if late pooling not reccomended to use set2set
            self.lin_out_2 = torch.nn.Linear(output_dim * 2, output_dim)

    def forward(self, data):
        """Run the model on a PyG batch.

        Reads `data.x`, `data.edge_index`, `data.edge_attr`, `data.batch`.
        Returns a 1-D tensor for single-target regression, otherwise 2-D.
        """

        ##Pre-GNN dense layers
        for i in range(0, len(self.pre_lin_list)):
            if i == 0:
                out = self.pre_lin_list[i](data.x)
                out = getattr(F, self.act)(out)
            else:
                out = self.pre_lin_list[i](out)
                out = getattr(F, self.act)(out)

        ##GNN layers
        for i in range(0, len(self.conv_list)):
            # With no pre-FC layers, the first conv consumes raw node features.
            if len(self.pre_lin_list) == 0 and i == 0:
                if self.batch_norm == "True":
                    out = self.conv_list[i](data.x, data.edge_index, data.edge_attr)
                    out = self.bn_list[i](out)
                else:
                    out = self.conv_list[i](data.x, data.edge_index, data.edge_attr)
            else:
                if self.batch_norm == "True":
                    out = self.conv_list[i](out, data.edge_index, data.edge_attr)
                    out = self.bn_list[i](out)
                else:
                    out = self.conv_list[i](out, data.edge_index, data.edge_attr)
            # NOTE(review): the post-conv activation is commented out here but
            # active in the sibling GCN model — confirm this is intentional.
            #out = getattr(F, self.act)(out)
            out = F.dropout(out, p=self.dropout_rate, training=self.training)

        ##Post-GNN dense layers
        if self.pool_order == "early":
            if self.pool == "set2set":
                out = self.set2set(out, data.batch)
            else:
                out = getattr(torch_geometric.nn, self.pool)(out, data.batch)
            for i in range(0, len(self.post_lin_list)):
                out = self.post_lin_list[i](out)
                out = getattr(F, self.act)(out)
            out = self.lin_out(out)

        elif self.pool_order == "late":
            for i in range(0, len(self.post_lin_list)):
                out = self.post_lin_list[i](out)
                out = getattr(F, self.act)(out)
            out = self.lin_out(out)
            if self.pool == "set2set":
                out = self.set2set(out, data.batch)
                # lin_out_2 folds set2set's doubled width back to output_dim.
                out = self.lin_out_2(out)
            else:
                out = getattr(torch_geometric.nn, self.pool)(out, data.batch)

        # Flatten single-column output to shape (num_graphs,).
        if out.shape[1] == 1:
            return out.view(-1)
        else:
            return out
|
matdeeplearn/models/descriptor_nn.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
from torch import Tensor
|
| 4 |
+
from torch.nn import (
|
| 5 |
+
Sequential,
|
| 6 |
+
Linear,
|
| 7 |
+
ReLU,
|
| 8 |
+
GRU,
|
| 9 |
+
Embedding,
|
| 10 |
+
BatchNorm1d,
|
| 11 |
+
Dropout,
|
| 12 |
+
LayerNorm,
|
| 13 |
+
)
|
| 14 |
+
from torch_geometric.nn import (
|
| 15 |
+
Set2Set,
|
| 16 |
+
global_mean_pool,
|
| 17 |
+
global_add_pool,
|
| 18 |
+
global_max_pool,
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
# Sine matrix with neural network
|
| 23 |
+
class SM(torch.nn.Module):
    """Feed-forward network over precomputed sine-matrix (SM) descriptors.

    Reads the fixed-length feature vector `extra_features_SM` attached to
    each sample and regresses a single target per structure.
    """

    def __init__(self, data, dim1=64, fc_count=1, **kwargs):
        super(SM, self).__init__()

        # Input width comes from the first sample's descriptor vector.
        in_width = data[0].extra_features_SM.shape[1]
        self.lin1 = torch.nn.Linear(in_width, dim1)

        # Hidden stack built in the same order as the original for parity.
        hidden_layers = []
        for _ in range(fc_count):
            hidden_layers.append(torch.nn.Linear(dim1, dim1))
        self.lin_list = torch.nn.ModuleList(hidden_layers)

        self.lin2 = torch.nn.Linear(dim1, 1)

    def forward(self, data):
        """Return per-structure predictions; 1-D when the head is scalar."""
        out = F.relu(self.lin1(data.extra_features_SM))
        for hidden_layer in self.lin_list:
            out = F.relu(hidden_layer(out))
        out = self.lin2(out)
        return out.view(-1) if out.shape[1] == 1 else out
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
# Smooth Overlap of Atomic Positions with neural network
|
| 48 |
+
class SOAP(torch.nn.Module):
    """Feed-forward network over Smooth Overlap of Atomic Positions features.

    Reads the fixed-length descriptor `extra_features_SOAP` attached to each
    sample and regresses a single target per structure.

    Args:
        data: dataset-like object; `data[0].extra_features_SOAP.shape[1]`
            fixes the input width.
        dim1: hidden layer width. Defaults to 64 for consistency with `SM`
            (previously this argument was required).
        fc_count: number of hidden layers. Defaults to 1, matching `SM`.
        **kwargs: ignored; accepted for config-driven construction.
    """

    def __init__(self, data, dim1=64, fc_count=1, **kwargs):
        super(SOAP, self).__init__()

        self.lin1 = torch.nn.Linear(data[0].extra_features_SOAP.shape[1], dim1)

        self.lin_list = torch.nn.ModuleList(
            [torch.nn.Linear(dim1, dim1) for i in range(fc_count)]
        )

        self.lin2 = torch.nn.Linear(dim1, 1)

    def forward(self, data):
        """Return per-structure predictions; 1-D when the head is scalar."""
        out = F.relu(self.lin1(data.extra_features_SOAP))
        for layer in self.lin_list:
            out = F.relu(layer(out))
        out = self.lin2(out)
        # Flatten the single-column output to shape (batch,).
        if out.shape[1] == 1:
            return out.view(-1)
        else:
            return out
|
matdeeplearn/models/gcn.py
ADDED
|
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import Tensor
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from torch.nn import Sequential, Linear, BatchNorm1d
|
| 5 |
+
import torch_geometric
|
| 6 |
+
from torch_geometric.nn import (
|
| 7 |
+
Set2Set,
|
| 8 |
+
global_mean_pool,
|
| 9 |
+
global_add_pool,
|
| 10 |
+
global_max_pool,
|
| 11 |
+
GCNConv,
|
| 12 |
+
)
|
| 13 |
+
from torch_scatter import scatter_mean, scatter_add, scatter_max, scatter
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
# CGCNN
|
| 17 |
+
# GCN
class GCN(torch.nn.Module):
    """Graph Convolutional Network (GCNConv stack) for property prediction.

    Structurally parallel to CGCNN: optional pre-GNN dense layers,
    `gc_count` GCNConv layers (using scalar `edge_weight` instead of edge
    feature vectors), pooling, and an optional post-GNN dense head.

    Args:
        data: dataset-like object; `data.num_features` and `data[0].y` are
            read to size the layers.
        dim1: width of pre-GNN dense layers / GC layers when pre_fc_count > 0.
        dim2: width of post-GNN dense layers.
        pre_fc_count: number of dense layers before the GNN stack (0 disables).
        gc_count: number of GCNConv layers (must be >= 1).
        post_fc_count: number of dense layers after pooling (0 disables).
        pool: name of a torch_geometric.nn pooling function, or "set2set".
        pool_order: "early" (pool before dense head) or "late" (pool after).
        batch_norm: string flag "True"/"False" — compared as a string.
        batch_track_stats: string flag; anything but "False" enables tracking.
        act: name of an activation in torch.nn.functional.
        dropout_rate: dropout probability applied after each GC layer.
        **kwargs: ignored; accepted for config-driven construction.
    """

    def __init__(
        self,
        data,
        dim1=64,
        dim2=64,
        pre_fc_count=1,
        gc_count=3,
        post_fc_count=1,
        pool="global_mean_pool",
        pool_order="early",
        batch_norm="True",
        batch_track_stats="True",
        act="relu",
        dropout_rate=0.0,
        **kwargs
    ):
        super(GCN, self).__init__()

        # String-valued flag: only the literal "False" disables stat tracking.
        if batch_track_stats == "False":
            self.batch_track_stats = False
        else:
            self.batch_track_stats = True
        self.batch_norm = batch_norm
        self.pool = pool
        self.act = act
        self.pool_order = pool_order
        self.dropout_rate = dropout_rate

        ##Determine gc dimension dimension
        assert gc_count > 0, "Need at least 1 GC layer"
        if pre_fc_count == 0:
            gc_dim = data.num_features
        else:
            gc_dim = dim1
        ##Determine post_fc dimension
        if pre_fc_count == 0:
            post_fc_dim = data.num_features
        else:
            post_fc_dim = dim1
        ##Determine output dimension length
        # Scalar target -> single output; otherwise one output per target column.
        if data[0].y.ndim == 0:
            output_dim = 1
        else:
            output_dim = len(data[0].y[0])

        ##Set up pre-GNN dense layers (NOTE: in v0.1 this is always set to 1 layer)
        if pre_fc_count > 0:
            self.pre_lin_list = torch.nn.ModuleList()
            for i in range(pre_fc_count):
                if i == 0:
                    lin = torch.nn.Linear(data.num_features, dim1)
                    self.pre_lin_list.append(lin)
                else:
                    lin = torch.nn.Linear(dim1, dim1)
                    self.pre_lin_list.append(lin)
        elif pre_fc_count == 0:
            # Empty list keeps forward() uniform (len() == 0 branch).
            self.pre_lin_list = torch.nn.ModuleList()

        ##Set up GNN layers
        self.conv_list = torch.nn.ModuleList()
        self.bn_list = torch.nn.ModuleList()
        for i in range(gc_count):
            # Self-loops are disabled here; presumably loops are already
            # encoded in the processed graph — TODO confirm against process.py.
            conv = GCNConv(
                gc_dim, gc_dim, improved=True, add_self_loops=False
            )
            self.conv_list.append(conv)
            ##Track running stats set to false can prevent some instabilities; this causes other issues with different val/test performance from loader size?
            if self.batch_norm == "True":
                bn = BatchNorm1d(gc_dim, track_running_stats=self.batch_track_stats)
                self.bn_list.append(bn)

        ##Set up post-GNN dense layers (NOTE: in v0.1 there was a minimum of 2 dense layers, and fc_count(now post_fc_count) added to this number. In the current version, the minimum is zero)
        if post_fc_count > 0:
            self.post_lin_list = torch.nn.ModuleList()
            for i in range(post_fc_count):
                if i == 0:
                    ##Set2set pooling has doubled dimension
                    if self.pool_order == "early" and self.pool == "set2set":
                        lin = torch.nn.Linear(post_fc_dim * 2, dim2)
                    else:
                        lin = torch.nn.Linear(post_fc_dim, dim2)
                    self.post_lin_list.append(lin)
                else:
                    lin = torch.nn.Linear(dim2, dim2)
                    self.post_lin_list.append(lin)
            self.lin_out = torch.nn.Linear(dim2, output_dim)

        elif post_fc_count == 0:
            self.post_lin_list = torch.nn.ModuleList()
            if self.pool_order == "early" and self.pool == "set2set":
                self.lin_out = torch.nn.Linear(post_fc_dim*2, output_dim)
            else:
                self.lin_out = torch.nn.Linear(post_fc_dim, output_dim)

        ##Set up set2set pooling (if used)
        if self.pool_order == "early" and self.pool == "set2set":
            self.set2set = Set2Set(post_fc_dim, processing_steps=3)
        elif self.pool_order == "late" and self.pool == "set2set":
            self.set2set = Set2Set(output_dim, processing_steps=3, num_layers=1)
            # workaround for doubled dimension by set2set; if late pooling not reccomended to use set2set
            self.lin_out_2 = torch.nn.Linear(output_dim * 2, output_dim)

    def forward(self, data):
        """Run the model on a PyG batch.

        Reads `data.x`, `data.edge_index`, `data.edge_weight`, `data.batch`.
        Returns a 1-D tensor for single-target regression, otherwise 2-D.
        """

        ##Pre-GNN dense layers
        for i in range(0, len(self.pre_lin_list)):
            if i == 0:
                out = self.pre_lin_list[i](data.x)
                out = getattr(F, self.act)(out)
            else:
                out = self.pre_lin_list[i](out)
                out = getattr(F, self.act)(out)

        ##GNN layers
        for i in range(0, len(self.conv_list)):
            # With no pre-FC layers, the first conv consumes raw node features.
            if len(self.pre_lin_list) == 0 and i == 0:
                if self.batch_norm == "True":
                    out = self.conv_list[i](data.x, data.edge_index, data.edge_weight)
                    out = self.bn_list[i](out)
                else:
                    out = self.conv_list[i](data.x, data.edge_index, data.edge_weight)
            else:
                if self.batch_norm == "True":
                    out = self.conv_list[i](out, data.edge_index, data.edge_weight)
                    out = self.bn_list[i](out)
                else:
                    out = self.conv_list[i](out, data.edge_index, data.edge_weight)
            out = getattr(F, self.act)(out)
            out = F.dropout(out, p=self.dropout_rate, training=self.training)

        ##Post-GNN dense layers
        if self.pool_order == "early":
            if self.pool == "set2set":
                out = self.set2set(out, data.batch)
            else:
                out = getattr(torch_geometric.nn, self.pool)(out, data.batch)
            for i in range(0, len(self.post_lin_list)):
                out = self.post_lin_list[i](out)
                out = getattr(F, self.act)(out)
            out = self.lin_out(out)

        elif self.pool_order == "late":
            for i in range(0, len(self.post_lin_list)):
                out = self.post_lin_list[i](out)
                out = getattr(F, self.act)(out)
            out = self.lin_out(out)
            if self.pool == "set2set":
                out = self.set2set(out, data.batch)
                # lin_out_2 folds set2set's doubled width back to output_dim.
                out = self.lin_out_2(out)
            else:
                out = getattr(torch_geometric.nn, self.pool)(out, data.batch)

        # Flatten single-column output to shape (num_graphs,).
        if out.shape[1] == 1:
            return out.view(-1)
        else:
            return out
|
matdeeplearn/models/megnet.py
ADDED
|
@@ -0,0 +1,371 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import Tensor
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from torch.nn import Sequential, Linear, ReLU, BatchNorm1d
|
| 5 |
+
import torch_geometric
|
| 6 |
+
from torch_geometric.nn import (
|
| 7 |
+
Set2Set,
|
| 8 |
+
global_mean_pool,
|
| 9 |
+
global_add_pool,
|
| 10 |
+
global_max_pool,
|
| 11 |
+
MetaLayer,
|
| 12 |
+
)
|
| 13 |
+
from torch_scatter import scatter_mean, scatter_add, scatter_max, scatter
|
| 14 |
+
|
| 15 |
+
# Megnet
|
| 16 |
+
class Megnet_EdgeModel(torch.nn.Module):
    """Edge-update MLP for the MEGNet MetaLayer block.

    Updates each edge from the concatenation of its two endpoint node
    states, its current edge state, and the graph-level state
    (input width = 4 * dim).

    Args:
        dim: feature width shared by node/edge/global states.
        act: name of an activation in torch.nn.functional.
        batch_norm: string flag "True"/"False" — compared as a string.
        batch_track_stats: string flag; anything but "False" enables tracking.
        dropout_rate: dropout probability applied after each layer.
        fc_layers: number of extra hidden layers; fc_layers + 1 linears total.
    """

    def __init__(self, dim, act, batch_norm, batch_track_stats, dropout_rate, fc_layers=2):
        super(Megnet_EdgeModel, self).__init__()
        self.act=act
        self.fc_layers = fc_layers
        if batch_track_stats == "False":
            self.batch_track_stats = False
        else:
            self.batch_track_stats = True
        self.batch_norm = batch_norm
        self.dropout_rate = dropout_rate

        self.edge_mlp = torch.nn.ModuleList()
        self.bn_list = torch.nn.ModuleList()
        for i in range(self.fc_layers + 1):
            if i == 0:
                # First layer consumes [src, dest, edge_attr, u[batch]].
                lin = torch.nn.Linear(dim * 4, dim)
                self.edge_mlp.append(lin)
            else:
                lin = torch.nn.Linear(dim, dim)
                self.edge_mlp.append(lin)
            if self.batch_norm == "True":
                bn = BatchNorm1d(dim, track_running_stats=self.batch_track_stats)
                self.bn_list.append(bn)

    def forward(self, src, dest, edge_attr, u, batch):
        """MetaLayer edge-model signature; returns updated edge features."""
        # `batch` here maps each edge to its graph, indexing the global state u.
        comb = torch.cat([src, dest, edge_attr, u[batch]], dim=1)
        for i in range(0, len(self.edge_mlp)):
            if i == 0:
                out = self.edge_mlp[i](comb)
                out = getattr(F, self.act)(out)
                if self.batch_norm == "True":
                    out = self.bn_list[i](out)
                out = F.dropout(out, p=self.dropout_rate, training=self.training)
            else:
                out = self.edge_mlp[i](out)
                out = getattr(F, self.act)(out)
                if self.batch_norm == "True":
                    out = self.bn_list[i](out)
                out = F.dropout(out, p=self.dropout_rate, training=self.training)
        return out
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class Megnet_NodeModel(torch.nn.Module):
    """Node-update MLP for the MEGNet MetaLayer block.

    Updates each node from the concatenation of its state, the mean of its
    incident edge states, and the graph-level state (input width = 3 * dim).

    Args:
        dim: feature width shared by node/edge/global states.
        act: name of an activation in torch.nn.functional.
        batch_norm: string flag "True"/"False" — compared as a string.
        batch_track_stats: string flag; anything but "False" enables tracking.
        dropout_rate: dropout probability applied after each layer.
        fc_layers: number of extra hidden layers; fc_layers + 1 linears total.
    """

    def __init__(self, dim, act, batch_norm, batch_track_stats, dropout_rate, fc_layers=2):
        super(Megnet_NodeModel, self).__init__()
        self.act=act
        self.fc_layers = fc_layers
        if batch_track_stats == "False":
            self.batch_track_stats = False
        else:
            self.batch_track_stats = True
        self.batch_norm = batch_norm
        self.dropout_rate = dropout_rate

        self.node_mlp = torch.nn.ModuleList()
        self.bn_list = torch.nn.ModuleList()
        for i in range(self.fc_layers + 1):
            if i == 0:
                # First layer consumes [x, mean(edge_attr per node), u[batch]].
                lin = torch.nn.Linear(dim * 3, dim)
                self.node_mlp.append(lin)
            else:
                lin = torch.nn.Linear(dim, dim)
                self.node_mlp.append(lin)
            if self.batch_norm == "True":
                bn = BatchNorm1d(dim, track_running_stats=self.batch_track_stats)
                self.bn_list.append(bn)

    def forward(self, x, edge_index, edge_attr, u, batch):
        """MetaLayer node-model signature; returns updated node features."""
        # row, col = edge_index
        # Mean-aggregate edge features onto their source nodes (edge_index[0]).
        v_e = scatter_mean(edge_attr, edge_index[0, :], dim=0)
        comb = torch.cat([x, v_e, u[batch]], dim=1)
        for i in range(0, len(self.node_mlp)):
            if i == 0:
                out = self.node_mlp[i](comb)
                out = getattr(F, self.act)(out)
                if self.batch_norm == "True":
                    out = self.bn_list[i](out)
                out = F.dropout(out, p=self.dropout_rate, training=self.training)
            else:
                out = self.node_mlp[i](out)
                out = getattr(F, self.act)(out)
                if self.batch_norm == "True":
                    out = self.bn_list[i](out)
                out = F.dropout(out, p=self.dropout_rate, training=self.training)
        return out
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
class Megnet_GlobalModel(torch.nn.Module):
    """Global-state-update MLP for the MEGNet MetaLayer block.

    Updates each graph's global state from the concatenation of its mean
    edge state, mean node state, and current global state
    (input width = 3 * dim).

    Args:
        dim: feature width shared by node/edge/global states.
        act: name of an activation in torch.nn.functional.
        batch_norm: string flag "True"/"False" — compared as a string.
        batch_track_stats: string flag; anything but "False" enables tracking.
        dropout_rate: dropout probability applied after each layer.
        fc_layers: number of extra hidden layers; fc_layers + 1 linears total.
    """

    def __init__(self, dim, act, batch_norm, batch_track_stats, dropout_rate, fc_layers=2):
        super(Megnet_GlobalModel, self).__init__()
        self.act=act
        self.fc_layers = fc_layers
        if batch_track_stats == "False":
            self.batch_track_stats = False
        else:
            self.batch_track_stats = True
        self.batch_norm = batch_norm
        self.dropout_rate = dropout_rate

        self.global_mlp = torch.nn.ModuleList()
        self.bn_list = torch.nn.ModuleList()
        for i in range(self.fc_layers + 1):
            if i == 0:
                # First layer consumes [mean edges, mean nodes, u].
                lin = torch.nn.Linear(dim * 3, dim)
                self.global_mlp.append(lin)
            else:
                lin = torch.nn.Linear(dim, dim)
                self.global_mlp.append(lin)
            if self.batch_norm == "True":
                bn = BatchNorm1d(dim, track_running_stats=self.batch_track_stats)
                self.bn_list.append(bn)

    def forward(self, x, edge_index, edge_attr, u, batch):
        """MetaLayer global-model signature; returns updated global state."""
        # Two-stage mean: edges -> source nodes, then nodes -> graphs.
        u_e = scatter_mean(edge_attr, edge_index[0, :], dim=0)
        u_e = scatter_mean(u_e, batch, dim=0)
        u_v = scatter_mean(x, batch, dim=0)
        comb = torch.cat([u_e, u_v, u], dim=1)
        for i in range(0, len(self.global_mlp)):
            if i == 0:
                out = self.global_mlp[i](comb)
                out = getattr(F, self.act)(out)
                if self.batch_norm == "True":
                    out = self.bn_list[i](out)
                out = F.dropout(out, p=self.dropout_rate, training=self.training)
            else:
                out = self.global_mlp[i](out)
                out = getattr(F, self.act)(out)
                if self.batch_norm == "True":
                    out = self.bn_list[i](out)
                out = F.dropout(out, p=self.dropout_rate, training=self.training)
        return out
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
class MEGNet(torch.nn.Module):
    """MEGNet graph neural network for materials property prediction.

    Architecture: optional pre-GNN dense layers -> `gc_count` MetaLayer
    blocks (each with Megnet_EdgeModel / Megnet_NodeModel /
    Megnet_GlobalModel and residual connections on all three streams) ->
    pooling (early or late) -> optional post-GNN dense layers -> linear
    output head.
    """

    def __init__(
        self,
        data,
        dim1=64,
        dim2=64,
        dim3=64,
        pre_fc_count=1,
        gc_count=3,
        gc_fc_count=2,
        post_fc_count=1,
        pool="global_mean_pool",
        pool_order="early",
        batch_norm="True",
        batch_track_stats="True",
        act="relu",
        dropout_rate=0.0,
        **kwargs
    ):
        """
        Args:
            data: dataset object; only ``num_features``, ``num_edge_features``,
                ``data[0].y`` and ``data[0].u`` are read here to size layers.
            dim1: width of pre-GNN dense layers.
            dim2: width of post-GNN dense layers.
            dim3: width of the edge/node/global streams inside the GNN blocks.
            pre_fc_count: number of dense layers before the GNN stack.
            gc_count: number of MetaLayer graph-convolution blocks (>= 1).
            gc_fc_count: hidden-layer count inside each edge/node/global MLP.
            post_fc_count: number of dense layers after pooling.
            pool: "global_mean_pool", "global_max_pool", "global_sum_pool"
                or "set2set".
            pool_order: "early" (pool before dense head) or "late".
            batch_norm / batch_track_stats: string flags ("True"/"False").
            act: activation name looked up on ``torch.nn.functional``.
            dropout_rate: dropout probability used inside the GNN sub-models.
        """
        super(MEGNet, self).__init__()

        # Config values arrive as strings from the yaml config.
        if batch_track_stats == "False":
            self.batch_track_stats = False
        else:
            self.batch_track_stats = True
        self.batch_norm = batch_norm
        self.pool = pool
        # Map the pool name to a torch_scatter reduce keyword; only set for
        # the three scatter-based pools (set2set uses dedicated modules).
        if pool == "global_mean_pool":
            self.pool_reduce="mean"
        elif pool== "global_max_pool":
            self.pool_reduce="max"
        elif pool== "global_sum_pool":
            self.pool_reduce="sum"
        self.act = act
        self.pool_order = pool_order
        self.dropout_rate = dropout_rate

        ##Determine gc dimension dimension
        assert gc_count > 0, "Need at least 1 GC layer"
        # GNN input width: raw features if there is no pre-FC stack.
        if pre_fc_count == 0:
            gc_dim = data.num_features
        else:
            gc_dim = dim1
        ##Determine post_fc dimension
        post_fc_dim = dim3
        ##Determine output dimension length
        # Scalar targets give 1 output; otherwise use the target row length.
        if data[0].y.ndim == 0:
            output_dim = 1
        else:
            output_dim = len(data[0].y[0])

        ##Set up pre-GNN dense layers (NOTE: in v0.1 this is always set to 1 layer)
        if pre_fc_count > 0:
            self.pre_lin_list = torch.nn.ModuleList()
            for i in range(pre_fc_count):
                if i == 0:
                    lin = torch.nn.Linear(data.num_features, dim1)
                    self.pre_lin_list.append(lin)
                else:
                    lin = torch.nn.Linear(dim1, dim1)
                    self.pre_lin_list.append(lin)
        elif pre_fc_count == 0:
            # Keep an empty ModuleList so forward() can branch on its length.
            self.pre_lin_list = torch.nn.ModuleList()

        ##Set up GNN layers
        # Each block has its own embedding MLPs for the edge (e), node (x)
        # and global (u) streams, plus a MetaLayer that updates all three.
        self.e_embed_list = torch.nn.ModuleList()
        self.x_embed_list = torch.nn.ModuleList()
        self.u_embed_list = torch.nn.ModuleList()
        self.conv_list = torch.nn.ModuleList()
        self.bn_list = torch.nn.ModuleList()
        for i in range(gc_count):
            if i == 0:
                # First block maps the raw feature widths into dim3.
                e_embed = Sequential(
                    Linear(data.num_edge_features, dim3), ReLU(), Linear(dim3, dim3), ReLU()
                )
                x_embed = Sequential(
                    Linear(gc_dim, dim3), ReLU(), Linear(dim3, dim3), ReLU()
                )
                u_embed = Sequential(
                    Linear((data[0].u.shape[1]), dim3), ReLU(), Linear(dim3, dim3), ReLU()
                )
                self.e_embed_list.append(e_embed)
                self.x_embed_list.append(x_embed)
                self.u_embed_list.append(u_embed)
                self.conv_list.append(
                    MetaLayer(
                        Megnet_EdgeModel(dim3, self.act, self.batch_norm, self.batch_track_stats, self.dropout_rate, gc_fc_count),
                        Megnet_NodeModel(dim3, self.act, self.batch_norm, self.batch_track_stats, self.dropout_rate, gc_fc_count),
                        Megnet_GlobalModel(dim3, self.act, self.batch_norm, self.batch_track_stats, self.dropout_rate, gc_fc_count),
                    )
                )
            elif i > 0:
                # Subsequent blocks are dim3 -> dim3.
                e_embed = Sequential(Linear(dim3, dim3), ReLU(), Linear(dim3, dim3), ReLU())
                x_embed = Sequential(Linear(dim3, dim3), ReLU(), Linear(dim3, dim3), ReLU())
                u_embed = Sequential(Linear(dim3, dim3), ReLU(), Linear(dim3, dim3), ReLU())
                self.e_embed_list.append(e_embed)
                self.x_embed_list.append(x_embed)
                self.u_embed_list.append(u_embed)
                self.conv_list.append(
                    MetaLayer(
                        Megnet_EdgeModel(dim3, self.act, self.batch_norm, self.batch_track_stats, self.dropout_rate, gc_fc_count),
                        Megnet_NodeModel(dim3, self.act, self.batch_norm, self.batch_track_stats, self.dropout_rate, gc_fc_count),
                        Megnet_GlobalModel(dim3, self.act, self.batch_norm, self.batch_track_stats, self.dropout_rate, gc_fc_count),
                    )
                )

        ##Set up post-GNN dense layers (NOTE: in v0.1 there was a minimum of 2 dense layers, and fc_count(now post_fc_count) added to this number. In the current version, the minimum is zero)
        if post_fc_count > 0:
            self.post_lin_list = torch.nn.ModuleList()
            for i in range(post_fc_count):
                if i == 0:
                    ##Set2set pooling has doubled dimension
                    # Early pooling concatenates node + edge + global streams:
                    # set2set doubles node and edge pools (2+2+1 = 5x),
                    # scatter pools keep them at 1x each (1+1+1 = 3x).
                    if self.pool_order == "early" and self.pool == "set2set":
                        lin = torch.nn.Linear(post_fc_dim * 5, dim2)
                    elif self.pool_order == "early" and self.pool != "set2set":
                        lin = torch.nn.Linear(post_fc_dim * 3, dim2)
                    elif self.pool_order == "late":
                        lin = torch.nn.Linear(post_fc_dim, dim2)
                    self.post_lin_list.append(lin)
                else:
                    lin = torch.nn.Linear(dim2, dim2)
                    self.post_lin_list.append(lin)
            self.lin_out = torch.nn.Linear(dim2, output_dim)

        elif post_fc_count == 0:
            self.post_lin_list = torch.nn.ModuleList()
            if self.pool_order == "early" and self.pool == "set2set":
                self.lin_out = torch.nn.Linear(post_fc_dim * 5, output_dim)
            elif self.pool_order == "early" and self.pool != "set2set":
                self.lin_out = torch.nn.Linear(post_fc_dim * 3, output_dim)
            else:
                self.lin_out = torch.nn.Linear(post_fc_dim, output_dim)

        ##Set up set2set pooling (if used)
        if self.pool_order == "early" and self.pool == "set2set":
            # Separate set2set modules for the node and edge streams.
            self.set2set_x = Set2Set(post_fc_dim, processing_steps=3)
            self.set2set_e = Set2Set(post_fc_dim, processing_steps=3)
        elif self.pool_order == "late" and self.pool == "set2set":
            self.set2set_x = Set2Set(output_dim, processing_steps=3, num_layers=1)
            # workaround for doubled dimension by set2set; if late pooling not recommended to use set2set
            self.lin_out_2 = torch.nn.Linear(output_dim * 2, output_dim)

    def forward(self, data):
        """Run the full model on a PyG batch and return predictions.

        Returns a 1-D tensor for scalar targets (squeezed), otherwise a
        (num_graphs, output_dim) tensor.
        """

        ##Pre-GNN dense layers
        for i in range(0, len(self.pre_lin_list)):
            if i == 0:
                out = self.pre_lin_list[i](data.x)
                out = getattr(F, self.act)(out)
            else:
                out = self.pre_lin_list[i](out)
                out = getattr(F, self.act)(out)

        ##GNN layers
        # x/e/u carry the node/edge/global streams; every block output is
        # added residually to its input.
        for i in range(0, len(self.conv_list)):
            if i == 0:
                if len(self.pre_lin_list) == 0:
                    # No pre-FC stack: embed the raw inputs directly.
                    e_temp = self.e_embed_list[i](data.edge_attr)
                    x_temp = self.x_embed_list[i](data.x)
                    u_temp = self.u_embed_list[i](data.u)
                    x_out, e_out, u_out = self.conv_list[i](
                        x_temp, data.edge_index, e_temp, u_temp, data.batch
                    )
                    # Residual connections around the MetaLayer.
                    x = torch.add(x_out, x_temp)
                    e = torch.add(e_out, e_temp)
                    u = torch.add(u_out, u_temp)
                else:
                    e_temp = self.e_embed_list[i](data.edge_attr)
                    x_temp = self.x_embed_list[i](out)
                    u_temp = self.u_embed_list[i](data.u)
                    x_out, e_out, u_out = self.conv_list[i](
                        x_temp, data.edge_index, e_temp, u_temp, data.batch
                    )
                    x = torch.add(x_out, x_temp)
                    e = torch.add(e_out, e_temp)
                    u = torch.add(u_out, u_temp)

            elif i > 0:
                e_temp = self.e_embed_list[i](e)
                x_temp = self.x_embed_list[i](x)
                u_temp = self.u_embed_list[i](u)
                x_out, e_out, u_out = self.conv_list[i](
                    x_temp, data.edge_index, e_temp, u_temp, data.batch
                )
                # NOTE: later blocks add the residual to the pre-embedding
                # stream (x, not x_temp), unlike the first block.
                x = torch.add(x_out, x)
                e = torch.add(e_out, e)
                u = torch.add(u_out, u)

        ##Post-GNN dense layers
        if self.pool_order == "early":
            if self.pool == "set2set":
                x_pool = self.set2set_x(x, data.batch)
                # Reduce edges to nodes first so set2set can pool per graph.
                e = scatter(e, data.edge_index[0, :], dim=0, reduce="mean")
                e_pool = self.set2set_e(e, data.batch)
                out = torch.cat([x_pool, e_pool, u], dim=1)
            else:
                x_pool = scatter(x, data.batch, dim=0, reduce=self.pool_reduce)
                # Edges: pool to source nodes, then to graphs.
                e_pool = scatter(e, data.edge_index[0, :], dim=0, reduce=self.pool_reduce)
                e_pool = scatter(e_pool, data.batch, dim=0, reduce=self.pool_reduce)
                out = torch.cat([x_pool, e_pool, u], dim=1)
            for i in range(0, len(self.post_lin_list)):
                out = self.post_lin_list[i](out)
                out = getattr(F, self.act)(out)
            out = self.lin_out(out)

        ##currently only uses node features for late pooling
        elif self.pool_order == "late":
            out = x
            for i in range(0, len(self.post_lin_list)):
                out = self.post_lin_list[i](out)
                out = getattr(F, self.act)(out)
            out = self.lin_out(out)
            if self.pool == "set2set":
                out = self.set2set_x(out, data.batch)
                # lin_out_2 undoes set2set's dimension doubling.
                out = self.lin_out_2(out)
            else:
                out = getattr(torch_geometric.nn, self.pool)(out, data.batch)

        # Squeeze scalar predictions to a 1-D tensor.
        if out.shape[1] == 1:
            return out.view(-1)
        else:
            return out
|
matdeeplearn/models/mpnn.py
ADDED
|
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import Tensor
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from torch.nn import Sequential, Linear, ReLU, BatchNorm1d, GRU
|
| 5 |
+
import torch_geometric
|
| 6 |
+
from torch_geometric.nn import (
|
| 7 |
+
Set2Set,
|
| 8 |
+
global_mean_pool,
|
| 9 |
+
global_add_pool,
|
| 10 |
+
global_max_pool,
|
| 11 |
+
NNConv,
|
| 12 |
+
)
|
| 13 |
+
from torch_scatter import scatter_mean, scatter_add, scatter_max, scatter
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
# CGCNN
|
| 17 |
+
# CGCNN
class MPNN(torch.nn.Module):
    """Message Passing Neural Network (Gilmer et al. style).

    Architecture: optional pre-GNN dense layers -> `gc_count` NNConv
    blocks, each followed by a GRU that carries a hidden state across
    blocks -> pooling (early or late) -> optional post-GNN dense layers ->
    linear output head.
    """

    def __init__(
        self,
        data,
        dim1=64,
        dim2=64,
        dim3=64,
        pre_fc_count=1,
        gc_count=3,
        post_fc_count=1,
        pool="global_mean_pool",
        pool_order="early",
        batch_norm="True",
        batch_track_stats="True",
        act="relu",
        dropout_rate=0.0,
        **kwargs
    ):
        """
        Args:
            data: dataset object; only ``num_features``, ``num_edge_features``
                and ``data[0].y`` are read here to size layers.
            dim1: width of pre-GNN dense layers (and of the GNN stream when
                a pre-FC stack exists).
            dim2: width of post-GNN dense layers.
            dim3: hidden width of the per-edge filter network inside NNConv.
            pre_fc_count: number of dense layers before the GNN stack.
            gc_count: number of NNConv+GRU blocks (>= 1).
            post_fc_count: number of dense layers after pooling.
            pool: a torch_geometric.nn pool name or "set2set".
            pool_order: "early" (pool before dense head) or "late".
            batch_norm / batch_track_stats: string flags ("True"/"False").
            act: activation name looked up on ``torch.nn.functional``.
            dropout_rate: dropout probability applied after each conv.
        """
        super(MPNN, self).__init__()

        # Config values arrive as strings from the yaml config.
        if batch_track_stats == "False":
            self.batch_track_stats = False
        else:
            self.batch_track_stats = True
        self.batch_norm = batch_norm
        self.pool = pool
        self.act = act
        self.pool_order = pool_order
        self.dropout_rate = dropout_rate

        ##Determine gc dimension dimension
        assert gc_count > 0, "Need at least 1 GC layer"
        # GNN stream width: raw features if there is no pre-FC stack.
        if pre_fc_count == 0:
            gc_dim = data.num_features
        else:
            gc_dim = dim1
        ##Determine post_fc dimension
        if pre_fc_count == 0:
            post_fc_dim = data.num_features
        else:
            post_fc_dim = dim1
        ##Determine output dimension length
        # Scalar targets give 1 output; otherwise use the target row length.
        if data[0].y.ndim == 0:
            output_dim = 1
        else:
            output_dim = len(data[0].y[0])

        ##Set up pre-GNN dense layers (NOTE: in v0.1 this is always set to 1 layer)
        if pre_fc_count > 0:
            self.pre_lin_list = torch.nn.ModuleList()
            for i in range(pre_fc_count):
                if i == 0:
                    lin = torch.nn.Linear(data.num_features, dim1)
                    self.pre_lin_list.append(lin)
                else:
                    lin = torch.nn.Linear(dim1, dim1)
                    self.pre_lin_list.append(lin)
        elif pre_fc_count == 0:
            # Keep an empty ModuleList so forward() can branch on its length.
            self.pre_lin_list = torch.nn.ModuleList()

        ##Set up GNN layers
        self.conv_list = torch.nn.ModuleList()
        self.gru_list = torch.nn.ModuleList()
        self.bn_list = torch.nn.ModuleList()
        for i in range(gc_count):
            # Edge-conditioned filter network: maps edge features to a
            # gc_dim x gc_dim weight matrix per edge (NNConv requirement).
            # NOTE: local name `nn` shadows the common `torch.nn` alias, but
            # this file accesses torch.nn via attribute lookup so it is safe.
            nn = Sequential(
                Linear(data.num_edge_features, dim3), ReLU(), Linear(dim3, gc_dim * gc_dim)
            )
            conv = NNConv(
                gc_dim, gc_dim, nn, aggr="mean"
            )
            self.conv_list.append(conv)
            # One GRU per block; the hidden state threads across blocks.
            gru = GRU(gc_dim, gc_dim)
            self.gru_list.append(gru)

            ##Track running stats set to false can prevent some instabilities; this causes other issues with different val/test performance from loader size?
            if self.batch_norm == "True":
                bn = BatchNorm1d(gc_dim, track_running_stats=self.batch_track_stats)
                self.bn_list.append(bn)

        ##Set up post-GNN dense layers (NOTE: in v0.1 there was a minimum of 2 dense layers, and fc_count(now post_fc_count) added to this number. In the current version, the minimum is zero)
        if post_fc_count > 0:
            self.post_lin_list = torch.nn.ModuleList()
            for i in range(post_fc_count):
                if i == 0:
                    ##Set2set pooling has doubled dimension
                    if self.pool_order == "early" and self.pool == "set2set":
                        lin = torch.nn.Linear(post_fc_dim * 2, dim2)
                    else:
                        lin = torch.nn.Linear(post_fc_dim, dim2)
                    self.post_lin_list.append(lin)
                else:
                    lin = torch.nn.Linear(dim2, dim2)
                    self.post_lin_list.append(lin)
            self.lin_out = torch.nn.Linear(dim2, output_dim)

        elif post_fc_count == 0:
            self.post_lin_list = torch.nn.ModuleList()
            if self.pool_order == "early" and self.pool == "set2set":
                self.lin_out = torch.nn.Linear(post_fc_dim*2, output_dim)
            else:
                self.lin_out = torch.nn.Linear(post_fc_dim, output_dim)

        ##Set up set2set pooling (if used)
        if self.pool_order == "early" and self.pool == "set2set":
            self.set2set = Set2Set(post_fc_dim, processing_steps=3)
        elif self.pool_order == "late" and self.pool == "set2set":
            self.set2set = Set2Set(output_dim, processing_steps=3, num_layers=1)
            # workaround for doubled dimension by set2set; if late pooling not recommended to use set2set
            self.lin_out_2 = torch.nn.Linear(output_dim * 2, output_dim)

    def forward(self, data):
        """Run the full model on a PyG batch and return predictions.

        Returns a 1-D tensor for scalar targets (squeezed), otherwise a
        (num_graphs, output_dim) tensor.
        """

        ##Pre-GNN dense layers
        for i in range(0, len(self.pre_lin_list)):
            if i == 0:
                out = self.pre_lin_list[i](data.x)
                out = getattr(F, self.act)(out)
            else:
                out = self.pre_lin_list[i](out)
                out = getattr(F, self.act)(out)

        ##GNN layers
        # GRU hidden state is seeded from the (possibly pre-FC'd) node
        # features; the leading unsqueeze adds the GRU sequence dimension.
        if len(self.pre_lin_list) == 0:
            h = data.x.unsqueeze(0)
        else:
            h = out.unsqueeze(0)
        for i in range(0, len(self.conv_list)):
            if len(self.pre_lin_list) == 0 and i == 0:
                # First block with no pre-FC stack reads raw node features.
                if self.batch_norm == "True":
                    m = self.conv_list[i](data.x, data.edge_index, data.edge_attr)
                    m = self.bn_list[i](m)
                else:
                    m = self.conv_list[i](data.x, data.edge_index, data.edge_attr)
            else:
                if self.batch_norm == "True":
                    m = self.conv_list[i](out, data.edge_index, data.edge_attr)
                    m = self.bn_list[i](m)
                else:
                    m = self.conv_list[i](out, data.edge_index, data.edge_attr)
            m = getattr(F, self.act)(m)
            m = F.dropout(m, p=self.dropout_rate, training=self.training)
            # GRU treats the node set as a length-1 sequence; `out` becomes
            # the input to the next conv, `h` carries state across blocks.
            out, h = self.gru_list[i](m.unsqueeze(0), h)
            out = out.squeeze(0)

        ##Post-GNN dense layers
        if self.pool_order == "early":
            if self.pool == "set2set":
                out = self.set2set(out, data.batch)
            else:
                out = getattr(torch_geometric.nn, self.pool)(out, data.batch)
            for i in range(0, len(self.post_lin_list)):
                out = self.post_lin_list[i](out)
                out = getattr(F, self.act)(out)
            out = self.lin_out(out)

        elif self.pool_order == "late":
            for i in range(0, len(self.post_lin_list)):
                out = self.post_lin_list[i](out)
                out = getattr(F, self.act)(out)
            out = self.lin_out(out)
            if self.pool == "set2set":
                out = self.set2set(out, data.batch)
                # lin_out_2 undoes set2set's dimension doubling.
                out = self.lin_out_2(out)
            else:
                out = getattr(torch_geometric.nn, self.pool)(out, data.batch)

        # Squeeze scalar predictions to a 1-D tensor.
        if out.shape[1] == 1:
            return out.view(-1)
        else:
            return out
|
matdeeplearn/models/schnet.py
ADDED
|
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import Tensor
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from torch.nn import Sequential, Linear, BatchNorm1d
|
| 5 |
+
import torch_geometric
|
| 6 |
+
from torch_geometric.nn import (
|
| 7 |
+
Set2Set,
|
| 8 |
+
global_mean_pool,
|
| 9 |
+
global_add_pool,
|
| 10 |
+
global_max_pool,
|
| 11 |
+
)
|
| 12 |
+
from torch_scatter import scatter_mean, scatter_add, scatter_max, scatter
|
| 13 |
+
from torch_geometric.nn.models.schnet import InteractionBlock
|
| 14 |
+
|
| 15 |
+
# Schnet
|
| 16 |
+
# Schnet
class SchNet(torch.nn.Module):
    """SchNet-style model built from PyG ``InteractionBlock`` layers.

    Architecture: optional pre-GNN dense layers -> `gc_count` residual
    InteractionBlocks -> pooling (early or late) -> optional post-GNN dense
    layers -> linear output head.
    """

    def __init__(
        self,
        data,
        dim1=64,
        dim2=64,
        dim3=64,
        cutoff=8,
        pre_fc_count=1,
        gc_count=3,
        post_fc_count=1,
        pool="global_mean_pool",
        pool_order="early",
        batch_norm="True",
        batch_track_stats="True",
        act="relu",
        dropout_rate=0.0,
        **kwargs
    ):
        """
        Args:
            data: dataset object; only ``num_features``, ``num_edge_features``
                and ``data[0].y`` are read here to size layers.
            dim1: width of pre-GNN dense layers (and of the GNN stream when
                a pre-FC stack exists).
            dim2: width of post-GNN dense layers.
            dim3: filter width inside each InteractionBlock.
            cutoff: distance cutoff passed to InteractionBlock's cosine cutoff.
            pre_fc_count: number of dense layers before the GNN stack.
            gc_count: number of InteractionBlocks (>= 1).
            post_fc_count: number of dense layers after pooling.
            pool: a torch_geometric.nn pool name or "set2set".
            pool_order: "early" (pool before dense head) or "late".
            batch_norm / batch_track_stats: string flags ("True"/"False").
            act: activation name looked up on ``torch.nn.functional``.
            dropout_rate: dropout probability applied after each block.
        """
        super(SchNet, self).__init__()

        # Config values arrive as strings from the yaml config.
        if batch_track_stats == "False":
            self.batch_track_stats = False
        else:
            self.batch_track_stats = True
        self.batch_norm = batch_norm
        self.pool = pool
        self.act = act
        self.pool_order = pool_order
        self.dropout_rate = dropout_rate

        ##Determine gc dimension dimension
        assert gc_count > 0, "Need at least 1 GC layer"
        # GNN stream width: raw features if there is no pre-FC stack.  The
        # residual additions in forward() rely on this width being constant.
        if pre_fc_count == 0:
            gc_dim = data.num_features
        else:
            gc_dim = dim1
        ##Determine post_fc dimension
        if pre_fc_count == 0:
            post_fc_dim = data.num_features
        else:
            post_fc_dim = dim1
        ##Determine output dimension length
        # Scalar targets give 1 output; otherwise use the target row length.
        if data[0].y.ndim == 0:
            output_dim = 1
        else:
            output_dim = len(data[0].y[0])

        ##Set up pre-GNN dense layers (NOTE: in v0.1 this is always set to 1 layer)
        if pre_fc_count > 0:
            self.pre_lin_list = torch.nn.ModuleList()
            for i in range(pre_fc_count):
                if i == 0:
                    lin = torch.nn.Linear(data.num_features, dim1)
                    self.pre_lin_list.append(lin)
                else:
                    lin = torch.nn.Linear(dim1, dim1)
                    self.pre_lin_list.append(lin)
        elif pre_fc_count == 0:
            # Keep an empty ModuleList so forward() can branch on its length.
            self.pre_lin_list = torch.nn.ModuleList()

        ##Set up GNN layers
        self.conv_list = torch.nn.ModuleList()
        self.bn_list = torch.nn.ModuleList()
        for i in range(gc_count):
            # InteractionBlock(hidden_channels, num_gaussians, num_filters,
            # cutoff): edge features play the role of the gaussian expansion.
            conv = InteractionBlock(gc_dim, data.num_edge_features, dim3, cutoff)
            self.conv_list.append(conv)
            ##Track running stats set to false can prevent some instabilities; this causes other issues with different val/test performance from loader size?
            if self.batch_norm == "True":
                bn = BatchNorm1d(gc_dim, track_running_stats=self.batch_track_stats)
                self.bn_list.append(bn)

        ##Set up post-GNN dense layers (NOTE: in v0.1 there was a minimum of 2 dense layers, and fc_count(now post_fc_count) added to this number. In the current version, the minimum is zero)
        if post_fc_count > 0:
            self.post_lin_list = torch.nn.ModuleList()
            for i in range(post_fc_count):
                if i == 0:
                    ##Set2set pooling has doubled dimension
                    if self.pool_order == "early" and self.pool == "set2set":
                        lin = torch.nn.Linear(post_fc_dim * 2, dim2)
                    else:
                        lin = torch.nn.Linear(post_fc_dim, dim2)
                    self.post_lin_list.append(lin)
                else:
                    lin = torch.nn.Linear(dim2, dim2)
                    self.post_lin_list.append(lin)
            self.lin_out = torch.nn.Linear(dim2, output_dim)

        elif post_fc_count == 0:
            self.post_lin_list = torch.nn.ModuleList()
            if self.pool_order == "early" and self.pool == "set2set":
                self.lin_out = torch.nn.Linear(post_fc_dim*2, output_dim)
            else:
                self.lin_out = torch.nn.Linear(post_fc_dim, output_dim)

        ##Set up set2set pooling (if used)
        if self.pool_order == "early" and self.pool == "set2set":
            self.set2set = Set2Set(post_fc_dim, processing_steps=3)
        elif self.pool_order == "late" and self.pool == "set2set":
            self.set2set = Set2Set(output_dim, processing_steps=3, num_layers=1)
            # workaround for doubled dimension by set2set; if late pooling not recommended to use set2set
            self.lin_out_2 = torch.nn.Linear(output_dim * 2, output_dim)

    def forward(self, data):
        """Run the full model on a PyG batch and return predictions.

        Requires ``data.edge_weight`` (interatomic distances) in addition to
        ``edge_attr``.  Returns a 1-D tensor for scalar targets (squeezed),
        otherwise a (num_graphs, output_dim) tensor.
        """

        ##Pre-GNN dense layers
        for i in range(0, len(self.pre_lin_list)):
            if i == 0:
                out = self.pre_lin_list[i](data.x)
                out = getattr(F, self.act)(out)
            else:
                out = self.pre_lin_list[i](out)
                out = getattr(F, self.act)(out)

        ##GNN layers
        # Each InteractionBlock is applied with a residual connection; the
        # first block reads raw features when there is no pre-FC stack.
        for i in range(0, len(self.conv_list)):
            if len(self.pre_lin_list) == 0 and i == 0:
                if self.batch_norm == "True":
                    out = data.x + self.conv_list[i](data.x, data.edge_index, data.edge_weight, data.edge_attr)
                    out = self.bn_list[i](out)
                else:
                    out = data.x + self.conv_list[i](data.x, data.edge_index, data.edge_weight, data.edge_attr)
            else:
                if self.batch_norm == "True":
                    out = out + self.conv_list[i](out, data.edge_index, data.edge_weight, data.edge_attr)
                    out = self.bn_list[i](out)
                else:
                    out = out + self.conv_list[i](out, data.edge_index, data.edge_weight, data.edge_attr)
            # Activation deliberately disabled: InteractionBlock already
            # applies its own non-linearity internally.
            #out = getattr(F, self.act)(out)
            out = F.dropout(out, p=self.dropout_rate, training=self.training)

        ##Post-GNN dense layers
        if self.pool_order == "early":
            if self.pool == "set2set":
                out = self.set2set(out, data.batch)
            else:
                out = getattr(torch_geometric.nn, self.pool)(out, data.batch)
            for i in range(0, len(self.post_lin_list)):
                out = self.post_lin_list[i](out)
                out = getattr(F, self.act)(out)
            out = self.lin_out(out)

        elif self.pool_order == "late":
            for i in range(0, len(self.post_lin_list)):
                out = self.post_lin_list[i](out)
                out = getattr(F, self.act)(out)
            out = self.lin_out(out)
            if self.pool == "set2set":
                out = self.set2set(out, data.batch)
                # lin_out_2 undoes set2set's dimension doubling.
                out = self.lin_out_2(out)
            else:
                out = getattr(torch_geometric.nn, self.pool)(out, data.batch)

        # Squeeze scalar predictions to a 1-D tensor.
        if out.shape[1] == 1:
            return out.view(-1)
        else:
            return out
|
matdeeplearn/models/utils.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
# Prints model summary
|
| 4 |
+
# Prints model summary
def model_summary(model):
    """Print a parameter summary table for *model* to stdout.

    For every named parameter, prints its name, tensor shape, and element
    count, followed by the total / trainable / non-trainable parameter
    counts.

    Args:
        model: any ``torch.nn.Module``.

    Returns:
        None; all output is printed.
    """
    divider = "--------------------------------------------------------------------------"
    print(divider)
    line_new = "{:>30} {:>20} {:>20}".format(
        "Layer.Parameter", "Param Tensor Shape", "Param #"
    )
    print(line_new)
    print(divider)
    # Iterate (name, tensor) pairs directly instead of materializing a list
    # and indexing tuples.
    for p_name, param in model.named_parameters():
        p_shape = list(param.size())
        # numel() gives the element count without building a throwaway
        # tensor of the sizes just to take its product.
        p_count = param.numel()
        line_new = "{:>30} {:>20} {:>20}".format(p_name, str(p_shape), str(p_count))
        print(line_new)
    print(divider)
    total_params = sum(param.nelement() for param in model.parameters())
    print("Total params:", total_params)
    num_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("Trainable params:", num_trainable_params)
    print("Non-trainable params:", total_params - num_trainable_params)
|
matdeeplearn/process/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .process import *
|
matdeeplearn/process/__pycache__/__init__.cpython-37.pyc
ADDED
|
Binary file (217 Bytes). View file
|
|
|
matdeeplearn/process/__pycache__/process.cpython-37.pyc
ADDED
|
Binary file (15.1 kB). View file
|
|
|
matdeeplearn/process/dictionary_blank.json
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"1": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "2": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "3": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "4": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "5": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "6": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "7": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "8": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "9": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "10": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "11": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "12": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "13": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 
"14": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "15": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "16": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "17": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "18": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "19": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "20": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "21": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "22": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "23": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "24": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "25": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "26": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0], "27": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "28": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "29": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "30": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "31": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "32": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "33": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "34": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "35": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "36": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "37": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "38": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "39": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0], "40": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "41": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "42": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "43": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "44": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "45": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "46": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "47": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "48": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "49": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "50": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "51": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "52": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0], "53": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "54": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "55": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "56": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "57": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "58": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "59": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "60": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "61": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "62": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "63": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "64": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "65": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0], "66": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "67": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "68": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "69": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "70": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "71": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "72": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "73": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "74": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "75": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "76": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "77": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "78": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "79": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "80": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "81": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "82": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "83": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "84": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "85": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "86": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "87": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "88": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "89": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "90": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "91": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "92": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "93": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "94": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "95": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "96": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "97": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "98": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "99": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "100": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]}
|
matdeeplearn/process/dictionary_default.json
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"1": [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "2": [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "3": [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "4": [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "5": [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "6": [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "7": [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "8": [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "9": [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "10": [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "11": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "12": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "13": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 
"14": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "15": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "16": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "17": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "18": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "19": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "20": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "21": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "22": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "23": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "24": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "25": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "26": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0], "27": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "28": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "29": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "30": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "31": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "32": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "33": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "34": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "35": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "36": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "37": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "38": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "39": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0], "40": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "41": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "42": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "43": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "44": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "45": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "46": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "47": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "48": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "49": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "50": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "51": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "52": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0], "53": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "54": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "55": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "56": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "57": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "58": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "59": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "60": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "61": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "62": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "63": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "64": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "65": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0], "66": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "67": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "68": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "69": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "70": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "71": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "72": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "73": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "74": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "75": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "76": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "77": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "78": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "79": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "80": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "81": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "82": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "83": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "84": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "85": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "86": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "87": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "88": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "89": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "90": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "91": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], "92": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0], "93": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0], "94": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0], "95": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], "96": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0], "97": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], "98": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0], "99": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0], "100": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]}
|
matdeeplearn/process/process.py
ADDED
|
@@ -0,0 +1,703 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import time
|
| 4 |
+
import csv
|
| 5 |
+
import json
|
| 6 |
+
import warnings
|
| 7 |
+
import numpy as np
|
| 8 |
+
import ase
|
| 9 |
+
import glob
|
| 10 |
+
from ase import io
|
| 11 |
+
from scipy.stats import rankdata
|
| 12 |
+
from scipy import interpolate
|
| 13 |
+
|
| 14 |
+
##torch imports
|
| 15 |
+
import torch
|
| 16 |
+
import torch.nn.functional as F
|
| 17 |
+
from torch_geometric.data import DataLoader, Dataset, Data, InMemoryDataset
|
| 18 |
+
from torch_geometric.utils import dense_to_sparse, degree, add_self_loops
|
| 19 |
+
import torch_geometric.transforms as T
|
| 20 |
+
from torch_geometric.utils import degree
|
| 21 |
+
|
| 22 |
+
################################################################################
|
| 23 |
+
# Data splitting
|
| 24 |
+
################################################################################
|
| 25 |
+
|
| 26 |
+
##basic train, val, test split
|
| 27 |
+
def split_data(
    dataset,
    train_ratio,
    val_ratio,
    test_ratio,
    seed=None,
    save=False,
):
    """Randomly split a dataset into train/validation/test subsets.

    Args:
        dataset: any object with ``__len__`` and integer indexing
            (e.g. a torch ``Dataset``).
        train_ratio, val_ratio, test_ratio: fractions of the dataset for each
            split; they must sum to at most 1 (the remainder is left unused).
        seed: RNG seed controlling the split.  When None, a fresh random seed
            is drawn on every call.  (The previous default,
            ``seed=np.random.randint(1, 1e6)``, was evaluated once at import
            time, so every call using the default silently shared one seed.)
        save: unused; kept for backward compatibility with existing callers.

    Returns:
        (train_dataset, val_dataset, test_dataset) as
        ``torch.utils.data.Subset`` objects.

    Raises:
        ValueError: if the three ratios sum to more than 1.
    """
    if seed is None:
        # Draw the seed per call so default invocations are actually random.
        seed = np.random.randint(1, int(1e6))
    dataset_size = len(dataset)
    if (train_ratio + val_ratio + test_ratio) <= 1:
        train_length = int(dataset_size * train_ratio)
        val_length = int(dataset_size * val_ratio)
        test_length = int(dataset_size * test_ratio)
        # Anything not covered by the three ratios is deliberately discarded.
        unused_length = dataset_size - train_length - val_length - test_length
        (
            train_dataset,
            val_dataset,
            test_dataset,
            unused_dataset,
        ) = torch.utils.data.random_split(
            dataset,
            [train_length, val_length, test_length, unused_length],
            generator=torch.Generator().manual_seed(seed),
        )
        print(
            "train length:",
            train_length,
            "val length:",
            val_length,
            "test length:",
            test_length,
            "unused length:",
            unused_length,
            "seed :",
            seed,
        )
        return train_dataset, val_dataset, test_dataset
    else:
        # Previously this printed "invalid ratios" and fell through, returning
        # None — which surfaced later as a confusing unpacking TypeError at the
        # call site.  Fail loudly instead.
        raise ValueError("invalid ratios: train + val + test must sum to at most 1")
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
##Basic CV split
|
| 69 |
+
def split_data_CV(dataset, num_folds=5, seed=None, save=False):
    """Randomly partition a dataset into ``num_folds`` equal folds for CV.

    Args:
        dataset: any object with ``__len__`` and integer indexing.
        num_folds: number of equally-sized folds; leftover samples that do not
            divide evenly are dropped.
        seed: RNG seed controlling the partition.  When None, a fresh random
            seed is drawn per call.  (The previous
            ``seed=np.random.randint(1, 1e6)`` default was evaluated once at
            import time, freezing the "random" seed for every default call.)
        save: unused; kept for backward compatibility with existing callers.

    Returns:
        A list of ``num_folds`` ``torch.utils.data.Subset`` objects (the
        remainder subset is excluded).
    """
    if seed is None:
        # Draw the seed at call time so default invocations are actually random.
        seed = np.random.randint(1, int(1e6))
    dataset_size = len(dataset)
    fold_length = int(dataset_size / num_folds)
    unused_length = dataset_size - fold_length * num_folds
    folds = [fold_length for i in range(num_folds)]
    # random_split requires the lengths to sum to len(dataset), so the
    # remainder is put in one extra split that we do not return.
    folds.append(unused_length)
    cv_dataset = torch.utils.data.random_split(
        dataset, folds, generator=torch.Generator().manual_seed(seed)
    )
    print("fold length :", fold_length, "unused length:", unused_length, "seed", seed)
    return cv_dataset[0:num_folds]
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
################################################################################
|
| 83 |
+
# Pytorch datasets
|
| 84 |
+
################################################################################
|
| 85 |
+
|
| 86 |
+
##Fetch dataset; processes the raw data if specified
|
| 87 |
+
def get_dataset(data_path, target_index, reprocess="False", processing_args=None):
    """Fetch (and if necessary process) a structure dataset from ``data_path``.

    Args:
        data_path: root directory containing the raw data.
        target_index: column of ``data.y`` to train on (passed to ``GetY``).
        reprocess: the string "True" forces deletion and re-generation of the
            processed data (string flags match the config-file conventions
            used throughout this module).
        processing_args: processing configuration dict; may carry a
            "processed_path" override for the processed-data subdirectory.

    Returns:
        A ``StructureDataset`` (single collated file) or
        ``StructureDataset_large`` (one file per structure), or None if
        processing produced neither layout.
    """
    import shutil

    if processing_args is None:
        processed_path = "processed"
    else:
        processed_path = processing_args.get("processed_path", "processed")

    transforms = GetY(index=target_index)

    if not os.path.exists(data_path):
        print("Data not found in:", data_path)
        sys.exit()

    def _load_existing():
        # One collated "data.pt" -> in-memory dataset; per-structure files
        # ("data0.pt") -> on-disk dataset.  Returns None when neither exists.
        # NOTE(review): the probe file is "data0.pt" but StructureDataset_large
        # and process_data use "data_<i>.pt" — confirm the intended naming.
        if os.path.exists(os.path.join(data_path, processed_path, "data.pt")):
            return StructureDataset(data_path, processed_path, transforms)
        if os.path.exists(os.path.join(data_path, processed_path, "data0.pt")):
            return StructureDataset_large(data_path, processed_path, transforms)
        return None

    if reprocess == "True":
        # shutil.rmtree is safer and more portable than shelling out "rm -rf".
        shutil.rmtree(os.path.join(data_path, processed_path), ignore_errors=True)
        process_data(data_path, processed_path, processing_args)

    dataset = _load_existing()
    if dataset is None:
        process_data(data_path, processed_path, processing_args)
        dataset = _load_existing()
    return dataset
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
##Dataset class from pytorch/pytorch geometric; inmemory case
|
| 133 |
+
class StructureDataset(InMemoryDataset):
    """In-memory dataset of structure graphs.

    Loads the single collated ``<data_path>/<processed_path>/data.pt`` file
    written by ``process_data`` with ``dataset_type == "inmemory"``.
    """

    def __init__(
        self, data_path, processed_path="processed", transform=None, pre_transform=None
    ):
        self.data_path = data_path
        self.processed_path = processed_path
        super().__init__(data_path, transform, pre_transform)
        # The processed file holds a pre-collated (data, slices) pair.
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        # Raw inputs are handled by process_data, not the Dataset machinery.
        return []

    @property
    def processed_dir(self):
        return os.path.join(self.data_path, self.processed_path)

    @property
    def processed_file_names(self):
        return ["data.pt"]
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
##Dataset class from pytorch/pytorch geometric
|
| 157 |
+
class StructureDataset_large(Dataset):
    """On-disk dataset of structure graphs, one ``data_<i>.pt`` file each.

    Used when the processed data is too large to hold in memory
    (``dataset_type == "large"`` in ``process_data``).
    """

    def __init__(
        self, data_path, processed_path="processed", transform=None, pre_transform=None
    ):
        self.data_path = data_path
        self.processed_path = processed_path
        super().__init__(data_path, transform, pre_transform)

    @property
    def raw_file_names(self):
        # Raw inputs are handled by process_data, not the Dataset machinery.
        return []

    @property
    def processed_dir(self):
        return os.path.join(self.data_path, self.processed_path)

    @property
    def processed_file_names(self):
        # Enumerate every per-structure file written by process_data.
        return [
            os.path.basename(path)
            for path in glob.glob(self.processed_dir + "/data*.pt")
        ]

    def len(self):
        return len(self.processed_file_names)

    def get(self, idx):
        # Each sample lives in its own data_<idx>.pt file.
        return torch.load(os.path.join(self.processed_dir, f"data_{idx}.pt"))
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
################################################################################
|
| 193 |
+
# Processing
|
| 194 |
+
################################################################################
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def process_data(data_path, processed_path, processing_args):
    """Convert raw structure files into PyTorch Geometric graphs on disk.

    Reads structures (individual files or an ASE ``data.db``) plus a CSV of
    target properties from ``data_path``, builds a distance-based graph with
    node/edge features for each structure, and saves the result under
    ``<data_path>/<processed_path>`` either as one collated ``data.pt``
    (``dataset_type == "inmemory"``) or as per-structure ``data_<i>.pt``
    files (``dataset_type == "large"``).

    ``processing_args`` is a dict of string-valued configuration flags
    (boolean options are the strings "True"/"False" per this module's
    config conventions).
    """

    ##Begin processing data
    print("Processing data to: " + os.path.join(data_path, processed_path))
    assert os.path.exists(data_path), "Data path not found in " + data_path

    ##Load dictionary
    # Per-element feature vectors come from a JSON dictionary unless the
    # "generated" mode (one-hot from observed species) is selected below.
    if processing_args["dictionary_source"] != "generated":
        if processing_args["dictionary_source"] == "default":
            print("Using default dictionary.")
            atom_dictionary = get_dictionary(
                os.path.join(
                    os.path.dirname(os.path.realpath(__file__)),
                    "dictionary_default.json",
                )
            )
        elif processing_args["dictionary_source"] == "blank":
            print(
                "Using blank dictionary. Warning: only do this if you know what you are doing"
            )
            atom_dictionary = get_dictionary(
                os.path.join(
                    os.path.dirname(os.path.realpath(__file__)), "dictionary_blank.json"
                )
            )
        else:
            dictionary_file_path = os.path.join(
                data_path, processing_args["dictionary_path"]
            )
            if os.path.exists(dictionary_file_path) == False:
                print("Atom dictionary not found, exiting program...")
                sys.exit()
            else:
                print("Loading atom dictionary from file.")
                atom_dictionary = get_dictionary(dictionary_file_path)

    ##Load targets
    # CSV rows are [structure_id, target_1, target_2, ...].
    target_property_file = os.path.join(data_path, processing_args["target_path"])
    assert os.path.exists(target_property_file), (
        "targets not found in " + target_property_file
    )
    with open(target_property_file) as f:
        reader = csv.reader(f)
        target_data = [row for row in reader]

    ##Read db file if specified
    ase_crystal_list = []
    if processing_args["data_format"] == "db":
        db = ase.db.connect(os.path.join(data_path, "data.db"))
        row_count = 0
        # target_data=[]
        for row in db.select():
            # target_data.append([str(row_count), row.get('target')])
            ase_temp = row.toatoms()
            ase_crystal_list.append(ase_temp)
            row_count = row_count + 1
            if row_count % 500 == 0:
                print("db processed: ", row_count)

    ##Process structure files and create structure graphs
    data_list = []
    for index in range(0, len(target_data)):

        structure_id = target_data[index][0]
        data = Data()

        ##Read in structure file using ase
        if processing_args["data_format"] != "db":
            ase_crystal = ase.io.read(
                os.path.join(
                    data_path, structure_id + "." + processing_args["data_format"]
                )
            )
            data.ase = ase_crystal
        else:
            # db mode: structures were preloaded above, ordered as in the db.
            # NOTE(review): assumes db row order matches the target CSV order —
            # confirm this holds for the datasets used.
            ase_crystal = ase_crystal_list[index]
            data.ase = ase_crystal

        ##Compile structure sizes (# of atoms) and elemental compositions
        if index == 0:
            length = [len(ase_crystal)]
            elements = [list(set(ase_crystal.get_chemical_symbols()))]
        else:
            length.append(len(ase_crystal))
            elements.append(list(set(ase_crystal.get_chemical_symbols())))

        ##Obtain distance matrix with ase
        # mic=True applies the minimum-image convention for periodic cells.
        distance_matrix = ase_crystal.get_all_distances(mic=True)

        ##Create sparse graph from distance matrix
        # Keep only neighbors within graph_max_radius, capped at
        # graph_max_neighbors per atom (see threshold_sort).
        distance_matrix_trimmed = threshold_sort(
            distance_matrix,
            processing_args["graph_max_radius"],
            processing_args["graph_max_neighbors"],
            adj=False,
        )

        distance_matrix_trimmed = torch.Tensor(distance_matrix_trimmed)
        out = dense_to_sparse(distance_matrix_trimmed)
        edge_index = out[0]
        edge_weight = out[1]

        # Self-loops are hard-enabled; the elif branch below is currently dead.
        self_loops = True
        if self_loops == True:
            edge_index, edge_weight = add_self_loops(
                edge_index, edge_weight, num_nodes=len(ase_crystal), fill_value=0
            )
            data.edge_index = edge_index
            data.edge_weight = edge_weight

            # fill_diagonal_ mutates the trimmed matrix so the diagonal
            # (self-loops) counts as connected in the mask.
            distance_matrix_mask = (
                distance_matrix_trimmed.fill_diagonal_(1) != 0
            ).int()
        elif self_loops == False:
            data.edge_index = edge_index
            data.edge_weight = edge_weight

            distance_matrix_mask = (distance_matrix_trimmed != 0).int()

        data.edge_descriptor = {}
        data.edge_descriptor["distance"] = edge_weight
        data.edge_descriptor["mask"] = distance_matrix_mask

        # All target columns after the id become the regression target(s).
        target = target_data[index][1:]
        y = torch.Tensor(np.array([target], dtype=np.float32))
        data.y = y

        # pos = torch.Tensor(ase_crystal.get_positions())
        # data.pos = pos
        z = torch.LongTensor(ase_crystal.get_atomic_numbers())
        data.z = z

        ###placeholder for state feature
        u = np.zeros((3))
        u = torch.Tensor(u[np.newaxis, ...])
        data.u = u

        data.structure_id = [[structure_id] * len(data.y)]

        if processing_args["verbose"] == "True" and (
            (index + 1) % 500 == 0 or (index + 1) == len(target_data)
        ):
            print("Data processed: ", index + 1, "out of", len(target_data))
            # if index == 0:
            # print(data)
            # print(data.edge_weight, data.edge_attr[0])

        data_list.append(data)

    ##
    # Dataset-wide statistics gathered during the loop above.
    n_atoms_max = max(length)
    species = list(set(sum(elements, [])))
    species.sort()
    num_species = len(species)
    if processing_args["verbose"] == "True":
        print(
            "Max structure size: ",
            n_atoms_max,
            "Max number of elements: ",
            num_species,
        )
        print("Unique species:", species)
    # NOTE(review): ase_crystal and data here are leftovers from the last loop
    # iteration, so data.length is only set on the final structure — looks
    # unintentional; confirm whether every Data object should carry .length.
    crystal_length = len(ase_crystal)
    data.length = torch.LongTensor([crystal_length])

    ##Generate node features
    if processing_args["dictionary_source"] != "generated":
        ##Atom features(node features) from atom dictionary file
        for index in range(0, len(data_list)):
            atom_fea = np.vstack(
                [
                    atom_dictionary[str(data_list[index].ase.get_atomic_numbers()[i])]
                    for i in range(len(data_list[index].ase))
                ]
            ).astype(float)
            data_list[index].x = torch.Tensor(atom_fea)
    elif processing_args["dictionary_source"] == "generated":
        ##Generates one-hot node features rather than using dict file
        from sklearn.preprocessing import LabelBinarizer

        lb = LabelBinarizer()
        lb.fit(species)
        for index in range(0, len(data_list)):
            data_list[index].x = torch.Tensor(
                lb.transform(data_list[index].ase.get_chemical_symbols())
            )

    ##Adds node degree to node features (appears to improve performance)
    # +1 because self-loops raise every node's degree by one.
    for index in range(0, len(data_list)):
        data_list[index] = OneHotDegree(
            data_list[index], processing_args["graph_max_neighbors"] + 1
        )

    ##Get graphs based on voronoi connectivity; todo: also get voronoi features
    ##avoid use for the time being until a good approach is found
    # The flag is force-overwritten, so this whole branch is currently disabled.
    processing_args["voronoi"] = "False"
    if processing_args["voronoi"] == "True":
        from pymatgen.core.structure import Structure
        from pymatgen.analysis.structure_analyzer import VoronoiConnectivity
        from pymatgen.io.ase import AseAtomsAdaptor

        Converter = AseAtomsAdaptor()

        for index in range(0, len(data_list)):
            pymatgen_crystal = Converter.get_structure(data_list[index].ase)
            # double check if cutoff distance does anything
            Voronoi = VoronoiConnectivity(
                pymatgen_crystal, cutoff=processing_args["graph_max_radius"]
            )
            connections = Voronoi.max_connectivity

            # reverse=True keeps the *strongest* connections rather than the
            # shortest distances.
            distance_matrix_voronoi = threshold_sort(
                connections,
                9999,
                processing_args["graph_max_neighbors"],
                reverse=True,
                adj=False,
            )
            distance_matrix_voronoi = torch.Tensor(distance_matrix_voronoi)

            out = dense_to_sparse(distance_matrix_voronoi)
            edge_index_voronoi = out[0]
            edge_weight_voronoi = out[1]

            # NOTE(review): distance_gaussian is only defined further down in
            # the edge_features section, so enabling this branch as-is would
            # raise NameError — latent bug kept because the branch is disabled.
            edge_attr_voronoi = distance_gaussian(edge_weight_voronoi)
            edge_attr_voronoi = edge_attr_voronoi.float()

            data_list[index].edge_index_voronoi = edge_index_voronoi
            data_list[index].edge_weight_voronoi = edge_weight_voronoi
            data_list[index].edge_attr_voronoi = edge_attr_voronoi
            if index % 500 == 0:
                print("Voronoi data processed: ", index)

    ##makes SOAP and SM features from dscribe
    # Note: SOAP and SM are mutually exclusive (if/elif) — SM is skipped
    # whenever SOAP is enabled.
    if processing_args["SOAP_descriptor"] == "True":
        if True in data_list[0].ase.pbc:
            periodicity = True
        else:
            periodicity = False

        from dscribe.descriptors import SOAP

        make_feature_SOAP = SOAP(
            species=species,
            rcut=processing_args["SOAP_rcut"],
            nmax=processing_args["SOAP_nmax"],
            lmax=processing_args["SOAP_lmax"],
            sigma=processing_args["SOAP_sigma"],
            periodic=periodicity,
            sparse=False,
            average="inner",
            rbf="gto",
            crossover=False,
        )
        for index in range(0, len(data_list)):
            features_SOAP = make_feature_SOAP.create(data_list[index].ase)
            data_list[index].extra_features_SOAP = torch.Tensor(features_SOAP)
            if processing_args["verbose"] == "True" and index % 500 == 0:
                if index == 0:
                    print(
                        "SOAP length: ",
                        features_SOAP.shape,
                    )
                print("SOAP descriptor processed: ", index)

    elif processing_args["SM_descriptor"] == "True":
        if True in data_list[0].ase.pbc:
            periodicity = True
        else:
            periodicity = False

        from dscribe.descriptors import SineMatrix, CoulombMatrix

        # Sine matrix handles periodic cells; Coulomb matrix is the
        # non-periodic analogue.
        if periodicity == True:
            make_feature_SM = SineMatrix(
                n_atoms_max=n_atoms_max,
                permutation="eigenspectrum",
                sparse=False,
                flatten=True,
            )
        else:
            make_feature_SM = CoulombMatrix(
                n_atoms_max=n_atoms_max,
                permutation="eigenspectrum",
                sparse=False,
                flatten=True,
            )

        for index in range(0, len(data_list)):
            features_SM = make_feature_SM.create(data_list[index].ase)
            data_list[index].extra_features_SM = torch.Tensor(features_SM)
            if processing_args["verbose"] == "True" and index % 500 == 0:
                if index == 0:
                    print(
                        "SM length: ",
                        features_SM.shape,
                    )
                print("SM descriptor processed: ", index)

    ##Generate edge features
    if processing_args["edge_features"] == "True":

        ##Distance descriptor using a Gaussian basis
        # Distances are normalized to [0, 1] first, hence the fixed 0..1 span.
        distance_gaussian = GaussianSmearing(
            0, 1, processing_args["graph_edge_length"], 0.2
        )
        # print(GetRanges(data_list, 'distance'))
        NormalizeEdge(data_list, "distance")
        # print(GetRanges(data_list, 'distance'))
        for index in range(0, len(data_list)):
            data_list[index].edge_attr = distance_gaussian(
                data_list[index].edge_descriptor["distance"]
            )
            if processing_args["verbose"] == "True" and (
                (index + 1) % 500 == 0 or (index + 1) == len(target_data)
            ):
                print("Edge processed: ", index + 1, "out of", len(target_data))

    # Drop the heavy ASE objects and scratch descriptors before serialization
    # (they slow the dataloader and bloat the saved files).
    Cleanup(data_list, ["ase", "edge_descriptor"])

    if os.path.isdir(os.path.join(data_path, processed_path)) == False:
        os.mkdir(os.path.join(data_path, processed_path))

    ##Save processed dataset to file
    if processing_args["dataset_type"] == "inmemory":
        data, slices = InMemoryDataset.collate(data_list)
        torch.save((data, slices), os.path.join(data_path, processed_path, "data.pt"))

    elif processing_args["dataset_type"] == "large":
        for i in range(0, len(data_list)):
            torch.save(
                data_list[i],
                os.path.join(
                    os.path.join(data_path, processed_path), "data_{}.pt".format(i)
                ),
            )
|
| 533 |
+
|
| 534 |
+
|
| 535 |
+
################################################################################
|
| 536 |
+
# Processing sub-functions
|
| 537 |
+
################################################################################
|
| 538 |
+
|
| 539 |
+
##Selects edges with distance threshold and limited number of neighbors
|
| 540 |
+
def threshold_sort(matrix, threshold, neighbors, reverse=False, adj=False):
    """Trim a pairwise distance matrix to at most ``neighbors + 1`` entries per row.

    Entries above ``threshold`` are discarded; the remainder are ranked per
    row (ascending by default, descending when ``reverse``) and only the
    ``neighbors + 1`` best-ranked entries keep their original values — the
    rest become 0.

    When ``adj`` is True, also returns a padded per-row neighbor-index list
    and the matching attribute (distance) values.
    """
    over_cutoff = matrix > threshold
    masked = np.ma.array(matrix, mask=over_cutoff)
    # Rank row-wise; multiplying by -1 ranks the largest values first.
    ranked = rankdata(
        masked * -1 if reverse else masked, method="ordinal", axis=1
    )
    # Discard ranks of masked (over-cutoff) entries, then drop anything ranked
    # beyond the allowed neighbor count.
    trimmed = np.nan_to_num(np.where(over_cutoff, np.nan, ranked))
    trimmed[trimmed > neighbors + 1] = 0

    if not adj:
        # Restore the original matrix values wherever an entry survived.
        return np.where(trimmed == 0, trimmed, matrix)

    n_rows = matrix.shape[0]
    adj_list = np.zeros((n_rows, neighbors + 1))
    adj_attr = np.zeros((n_rows, neighbors + 1))
    for row in range(n_rows):
        kept = np.where(trimmed[row] != 0)[0]
        # Pad each row's neighbor list with zeros up to a fixed width.
        adj_list[row, :] = np.pad(
            kept,
            pad_width=(0, neighbors + 1 - len(kept)),
            mode="constant",
            constant_values=0,
        )
        adj_attr[row, :] = matrix[row, adj_list[row, :].astype(int)]
    result = np.where(trimmed == 0, trimmed, matrix)
    return result, adj_list, adj_attr
|
| 577 |
+
|
| 578 |
+
|
| 579 |
+
##Slightly edited version from pytorch geometric to create edge from gaussian basis
|
| 580 |
+
class GaussianSmearing(torch.nn.Module):
    """Expand scalar edge distances in a fixed basis of Gaussians.

    ``resolution`` Gaussian centers are spaced evenly over [start, stop];
    ``width`` sets each Gaussian's spread as a fraction of the full span.
    (Adapted from the PyTorch Geometric implementation.)
    """

    def __init__(self, start=0.0, stop=5.0, resolution=50, width=0.05, **kwargs):
        super().__init__()
        centers = torch.linspace(start, stop, resolution)
        # self.coeff = -0.5 / (offset[1] - offset[0]).item() ** 2
        # Exponent coefficient: width is relative to the whole [start, stop] span.
        self.coeff = -0.5 / ((stop - start) * width) ** 2
        # Register as a buffer so the centers follow the module across devices.
        self.register_buffer("offset", centers)

    def forward(self, dist):
        # Broadcast (N,) distances against (resolution,) centers -> (N, resolution).
        diff = dist.unsqueeze(-1) - self.offset.view(1, -1)
        return torch.exp(self.coeff * diff.pow(2))
|
| 591 |
+
|
| 592 |
+
|
| 593 |
+
##Obtain node degree in one-hot representation
|
| 594 |
+
def OneHotDegree(data, max_degree, in_degree=False, cat=True):
    """Add a one-hot encoding of each node's degree to ``data.x``.

    When ``cat`` is True and features already exist, the encoding is
    concatenated onto them; otherwise it replaces ``data.x``.
    ``in_degree`` selects incoming edges instead of outgoing ones.
    """
    endpoint_row = 1 if in_degree else 0
    node_idx = data.edge_index[endpoint_row]
    node_degree = degree(node_idx, data.num_nodes, dtype=torch.long)
    encoded = F.one_hot(node_degree, num_classes=max_degree + 1).to(torch.float)

    existing = data.x
    if existing is not None and cat:
        # Promote 1-D feature vectors to a column before concatenating.
        existing = existing.view(-1, 1) if existing.dim() == 1 else existing
        data.x = torch.cat([existing, encoded.to(existing.dtype)], dim=-1)
    else:
        data.x = encoded

    return data
|
| 606 |
+
|
| 607 |
+
|
| 608 |
+
##Obtain dictionary file for elemental features
|
| 609 |
+
def get_dictionary(dictionary_file):
    """Load the per-element feature dictionary from a JSON file.

    Keys are atomic numbers as strings; values are feature vectors.
    """
    with open(dictionary_file) as handle:
        return json.load(handle)
|
| 613 |
+
|
| 614 |
+
|
| 615 |
+
##Deletes unnecessary data due to slow dataloader
|
| 616 |
+
def Cleanup(data_list, entries):
    """Best-effort removal of the named attributes from every data object.

    Used to strip heavy scratch data (e.g. ASE objects) before serialization,
    since keeping it slows the dataloader.
    """
    for item in data_list:
        for attr_name in entries:
            try:
                delattr(item, attr_name)
            except Exception:
                # Attribute may be missing or undeletable; removal is best-effort.
                pass
|
| 623 |
+
|
| 624 |
+
|
| 625 |
+
##Get min/max ranges for normalized edges
|
| 626 |
+
def GetRanges(dataset, descriptor_label):
    """Return (mean, std, min, max) of one edge descriptor over a dataset.

    Entries whose descriptor is empty are skipped for the statistics.  The
    mean/std sums are still divided by the full dataset length (kept for
    backward compatibility with the original behavior).

    Returns:
        (mean, std, feature_min, feature_max); the extrema are None when no
        entry has a non-empty descriptor.
    """
    mean = 0.0
    std = 0.0
    feature_min = None
    feature_max = None
    for index in range(0, len(dataset)):
        descriptor = dataset[index].edge_descriptor[descriptor_label]
        if len(descriptor) > 0:
            mean += descriptor.mean()
            std += descriptor.std()
            # Initialize the extrema on the first *non-empty* entry.  The
            # original code only initialized on index 0 and raised NameError
            # whenever the first entry's descriptor happened to be empty.
            if feature_max is None or descriptor.max() > feature_max:
                feature_max = descriptor.max()
            if feature_min is None or descriptor.min() < feature_min:
                feature_min = descriptor.min()

    mean = mean / len(dataset)
    std = std / len(dataset)
    return mean, std, feature_min, feature_max
|
| 644 |
+
|
| 645 |
+
|
| 646 |
+
##Normalizes edges
|
| 647 |
+
def NormalizeEdge(dataset, descriptor_label):
    """Min-max normalize the chosen edge descriptor in place across a dataset.

    After this call every ``edge_descriptor[descriptor_label]`` lies in
    [0, 1] relative to the dataset-wide minimum and maximum.
    """
    _, _, feature_min, feature_max = GetRanges(dataset, descriptor_label)
    span = feature_max - feature_min
    for entry in dataset:
        entry.edge_descriptor[descriptor_label] = (
            entry.edge_descriptor[descriptor_label] - feature_min
        ) / span
|
| 654 |
+
|
| 655 |
+
|
| 656 |
+
# WIP
|
| 657 |
+
def SM_Edge(dataset):
    """Attach Sine-matrix edge weights to each data object (work in progress).

    For every structure, builds the dscribe Sine matrix, zeroes the entries
    that were masked out of the distance graph, and stores the surviving
    values under ``edge_descriptor["SM"]``.  Returns the dataset for chaining.
    """
    from dscribe.descriptors import (
        CoulombMatrix,
        SOAP,
        MBTR,
        EwaldSumMatrix,
        SineMatrix,
    )

    for count, data in enumerate(dataset):
        featurizer = SineMatrix(
            n_atoms_max=len(data.ase),
            permutation="none",
            sparse=False,
            flatten=False,
        )
        sine_matrix = featurizer.create(data.ase)
        # Keep only entries where the distance-graph mask is non-zero.
        sine_trimmed = torch.Tensor(
            np.where(data.mask == 0, data.mask, sine_matrix)
        )
        _, edge_weight = dense_to_sparse(sine_trimmed)
        data.edge_descriptor["SM"] = edge_weight

        if count % 500 == 0:
            print("SM data processed: ", count)

    return dataset
|
| 688 |
+
|
| 689 |
+
|
| 690 |
+
################################################################################
|
| 691 |
+
# Transforms
|
| 692 |
+
################################################################################
|
| 693 |
+
|
| 694 |
+
##Get specified y index from data.y
|
| 695 |
+
class GetY(object):
    """Dataset transform that selects a single target column from ``data.y``.

    An index of -1 leaves ``data.y`` untouched (all targets kept).
    """

    def __init__(self, index=0):
        self.index = index

    def __call__(self, data):
        # Specify target.
        selected = self.index
        if selected != -1:
            data.y = data.y[0][selected]
        return data
|
matdeeplearn/training/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .training import *
|
matdeeplearn/training/__pycache__/__init__.cpython-37.pyc
ADDED
|
Binary file (219 Bytes). View file
|
|
|
matdeeplearn/training/__pycache__/training.cpython-37.pyc
ADDED
|
Binary file (20.4 kB). View file
|
|
|
matdeeplearn/training/training.py
ADDED
|
@@ -0,0 +1,1290 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
##General imports
|
| 2 |
+
import csv
|
| 3 |
+
import os
|
| 4 |
+
import time
|
| 5 |
+
from datetime import datetime
|
| 6 |
+
import shutil
|
| 7 |
+
import copy
|
| 8 |
+
import numpy as np
|
| 9 |
+
from functools import partial
|
| 10 |
+
import platform
|
| 11 |
+
|
| 12 |
+
##Torch imports
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
import torch
|
| 15 |
+
from torch_geometric.data import DataLoader, Dataset
|
| 16 |
+
from torch_geometric.nn import DataParallel
|
| 17 |
+
import torch_geometric.transforms as T
|
| 18 |
+
from torch.utils.data.distributed import DistributedSampler
|
| 19 |
+
from torch.nn.parallel import DistributedDataParallel
|
| 20 |
+
import torch.distributed as dist
|
| 21 |
+
import torch.multiprocessing as mp
|
| 22 |
+
|
| 23 |
+
##Matdeeplearn imports
|
| 24 |
+
from matdeeplearn import models
|
| 25 |
+
import matdeeplearn.process as process
|
| 26 |
+
import matdeeplearn.training as training
|
| 27 |
+
from matdeeplearn.models.utils import model_summary
|
| 28 |
+
|
| 29 |
+
################################################################################
|
| 30 |
+
# Training functions
|
| 31 |
+
################################################################################
|
| 32 |
+
|
| 33 |
+
##Train step, runs model in train mode
|
| 34 |
+
def train(model, optimizer, loader, loss_method, rank):
    """Run one optimization epoch and return the sample-weighted mean loss.

    ``loss_method`` names a function in torch.nn.functional (e.g. "l1_loss");
    ``rank`` is the device (or DDP rank) that batches are moved to.
    """
    model.train()
    loss_fn = getattr(F, loss_method)
    total_loss = 0
    n_samples = 0
    for batch in loader:
        batch = batch.to(rank)
        optimizer.zero_grad()
        prediction = model(batch)
        loss = loss_fn(prediction, batch.y)
        loss.backward()
        optimizer.step()
        # Weight each batch by its size so the epoch loss is a true mean.
        batch_size = prediction.size(0)
        total_loss += loss.detach() * batch_size
        n_samples += batch_size
    return total_loss / n_samples
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
##Evaluation step, runs model in eval mode
|
| 58 |
+
def evaluate(loader, model, loss_method, rank, out=False):
    """Evaluate ``model`` on ``loader`` and return the sample-weighted mean loss.

    When ``out`` is True, also return a 2-D array stacking structure ids,
    targets, and predictions for every sample (same column layout that
    write_results expects), accumulated batch by batch.
    """
    model.eval()
    loss_all = 0  # running sum of loss * batch_size
    count = 0  # running sample count
    for data in loader:
        data = data.to(rank)
        with torch.no_grad():
            output = model(data)
            loss = getattr(F, loss_method)(output, data.y)
            # Weight by batch size so the final value is a true per-sample mean.
            loss_all += loss * output.size(0)
            if out == True:
                if count == 0:
                    # structure_id is nested two levels deep; flatten twice.
                    ids = [item for sublist in data.structure_id for item in sublist]
                    ids = [item for sublist in ids for item in sublist]
                    predict = output.data.cpu().numpy()
                    target = data.y.cpu().numpy()
                else:
                    # Subsequent batches: append to the running collections.
                    ids_temp = [
                        item for sublist in data.structure_id for item in sublist
                    ]
                    ids_temp = [item for sublist in ids_temp for item in sublist]
                    ids = ids + ids_temp
                    predict = np.concatenate(
                        (predict, output.data.cpu().numpy()), axis=0
                    )
                    target = np.concatenate((target, data.y.cpu().numpy()), axis=0)
            count = count + output.size(0)

    loss_all = loss_all / count

    if out == True:
        # Columns: ids | target(s) | prediction(s).
        test_out = np.column_stack((ids, target, predict))
        return loss_all, test_out
    elif out == False:
        return loss_all
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
##Model trainer
|
| 96 |
+
def trainer(
    rank,
    world_size,
    model,
    optimizer,
    scheduler,
    loss,
    train_loader,
    val_loader,
    train_sampler,
    epochs,
    verbosity,
    filename = "my_model_temp.pth",
):
    """Train ``model`` for ``epochs`` epochs, checkpointing the best model.

    Works both single-process (rank "cpu"/"cuda") and as one DDP worker
    (integer rank), in which case errors are all-reduced and only rank 0
    evaluates/saves.  Returns a deep copy of the best model found
    (unwrapped from DDP when applicable).
    """

    def _unwrap(m):
        # Under DDP the actual model lives on model.module.
        return m.module if rank not in ("cpu", "cuda") else m

    def _save_checkpoint():
        # Persist everything needed to resume or reuse the model.
        torch.save(
            {
                "state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "scheduler_state_dict": scheduler.state_dict(),
                "full_model": model,
            },
            filename,
        )

    train_error = val_error = test_error = epoch_time = float("NaN")
    train_start = time.time()
    best_val_error = 1e10
    model_best = model
    ##Start training over epochs loop
    for epoch in range(1, epochs + 1):

        lr = scheduler.optimizer.param_groups[0]["lr"]
        if rank not in ("cpu", "cuda"):
            # Reshuffle the distributed sampler each epoch.
            train_sampler.set_epoch(epoch)
        ##Train model
        train_error = train(model, optimizer, train_loader, loss, rank=rank)
        if rank not in ("cpu", "cuda"):
            # Average the training error across DDP workers onto rank 0.
            torch.distributed.reduce(train_error, dst=0)
            train_error = train_error / world_size

        ##Get validation performance (rank 0 / single process only)
        if rank not in ("cpu", "cuda"):
            dist.barrier()
        if val_loader != None and rank in (0, "cpu", "cuda"):
            if rank not in ("cpu", "cuda"):
                val_error = evaluate(
                    val_loader, model.module, loss, rank=rank, out=False
                )
            else:
                val_error = evaluate(val_loader, model, loss, rank=rank, out=False)

        ##Train loop timings
        epoch_time = time.time() - train_start
        train_start = time.time()

        ##Remember the best val error and save model and checkpoint.
        if val_loader != None and rank in (0, "cpu", "cuda"):
            # BUG FIX: the original test was `val_error == float("NaN")`,
            # which is always False (NaN never compares equal to anything);
            # `val_error != val_error` is the standard NaN test.
            if val_error != val_error or val_error < best_val_error:
                model_best = copy.deepcopy(_unwrap(model))
                _save_checkpoint()
            best_val_error = min(val_error, best_val_error)
        elif val_loader == None and rank in (0, "cpu", "cuda"):
            # No validation set: keep (and checkpoint) the latest model.
            model_best = copy.deepcopy(_unwrap(model))
            _save_checkpoint()

        ##Scheduler steps on the training error.
        scheduler.step(train_error)

        ##Print performance
        if epoch % verbosity == 0:
            if rank in (0, "cpu", "cuda"):
                print(
                    "Epoch: {:04d}, Learning Rate: {:.6f}, Training Error: {:.5f}, Val Error: {:.5f}, Time per epoch (s): {:.5f}".format(
                        epoch, lr, train_error, val_error, epoch_time
                    )
                )

    if rank not in ("cpu", "cuda"):
        dist.barrier()

    return model_best
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
##Write results to csv file
|
| 211 |
+
def write_results(output, filename):
    """Write an evaluation output array to ``filename`` as CSV.

    ``output`` is a 2-D array whose columns are [ids, target..., prediction...]
    (as produced by evaluate(..., out=True)); a header row is written first,
    followed by every data row.
    """
    n_targets = int((output.shape[1] - 1) / 2)
    header = ["ids"] + ["target"] * n_targets + ["prediction"] * n_targets
    # newline="" stops the csv module from inserting blank lines on Windows.
    with open(filename, "w", newline="") as f:
        csvwriter = csv.writer(f)
        csvwriter.writerow(header)
        # BUG FIX: the original loop wrote output[i - 1] for i in 1..len-1,
        # silently dropping the final row; iterate the rows directly.
        for row in output:
            csvwriter.writerow(row)
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
##Pytorch ddp setup
|
| 227 |
+
def ddp_setup(rank, world_size):
    """Initialise a single-node process group for DistributedDataParallel.

    A plain "cpu"/"cuda" rank means DDP is not in use, so this is a no-op.
    """
    if rank in ("cpu", "cuda"):
        return
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    # NCCL is unavailable on Windows, so fall back to gloo there.
    backend = "gloo" if platform.system() == "Windows" else "nccl"
    dist.init_process_group(backend, rank=rank, world_size=world_size)
    torch.backends.cudnn.enabled = False
    torch.backends.cudnn.benchmark = True
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
##Pytorch model setup
|
| 241 |
+
def model_setup(
    rank,
    model_name,
    model_params,
    dataset,
    load_model=False,
    model_path=None,
    print_model=True,
):
    """Instantiate a model by name, optionally restoring saved weights and wrapping in DDP.

    ``model_name`` is looked up in matdeeplearn.models; ``load_model`` follows
    the project's string-flag convention ("True"/"False").
    """
    model = getattr(models, model_name)(
        data=dataset, **(model_params if model_params is not None else {})
    ).to(rank)
    if load_model == "True":
        assert os.path.exists(model_path), "Saved model not found"
        # BUG FIX: the original used `str(rank) in ("cpu")`, which is substring
        # membership in the string "cpu" (true for "c", "p", "u", "cp", ...);
        # compare for equality instead.
        if str(rank) == "cpu":
            saved = torch.load(model_path, map_location=torch.device("cpu"))
        else:
            saved = torch.load(model_path)
        model.load_state_dict(saved["model_state_dict"])
        # optimizer.load_state_dict(saved['optimizer_state_dict'])

    # DDP: a numeric rank indicates a distributed worker.
    if rank not in ("cpu", "cuda"):
        model = DistributedDataParallel(
            model, device_ids=[rank], find_unused_parameters=True
        )
    if print_model == True and rank in (0, "cpu", "cuda"):
        model_summary(model)
    return model
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
##Pytorch loader setup
|
| 274 |
+
def loader_setup(
    train_ratio,
    val_ratio,
    test_ratio,
    batch_size,
    dataset,
    rank,
    seed,
    world_size=0,
    num_workers=0,
):
    """Split ``dataset`` into train/val/test and wrap each split in a DataLoader.

    Returns (train_loader, val_loader, test_loader, train_sampler,
    train_dataset, val_dataset, test_dataset).  The val/test loaders are only
    built on rank 0 (or in single-process "cpu"/"cuda" mode) and stay None
    when their split is empty.
    """
    ##Split datasets
    train_dataset, val_dataset, test_dataset = process.split_data(
        dataset, train_ratio, val_ratio, test_ratio, seed
    )

    ##DDP: a numeric rank means distributed training, so shard the train set.
    if rank not in ("cpu", "cuda"):
        train_sampler = DistributedSampler(
            train_dataset, num_replicas=world_size, rank=rank
        )
    elif rank in ("cpu", "cuda"):
        train_sampler = None

    ##Load data
    train_loader = val_loader = test_loader = None
    # shuffle must be off when a sampler is supplied (the sampler shuffles).
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=(train_sampler is None),
        num_workers=num_workers,
        pin_memory=True,
        sampler=train_sampler,
    )
    # may scale down batch size if memory is an issue
    if rank in (0, "cpu", "cuda"):
        if len(val_dataset) > 0:
            val_loader = DataLoader(
                val_dataset,
                batch_size=batch_size,
                shuffle=False,
                num_workers=num_workers,
                pin_memory=True,
            )
        if len(test_dataset) > 0:
            test_loader = DataLoader(
                test_dataset,
                batch_size=batch_size,
                shuffle=False,
                num_workers=num_workers,
                pin_memory=True,
            )
    return (
        train_loader,
        val_loader,
        test_loader,
        train_sampler,
        train_dataset,
        val_dataset,
        test_dataset,
    )
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
def loader_setup_CV(index, batch_size, dataset, rank, world_size=0, num_workers=0):
    """Build train/test loaders for one cross-validation fold.

    Fold ``index`` of ``dataset`` (a list of folds) becomes the test set and
    the remaining folds are concatenated into the training set.
    """
    held_out = dataset[index]
    remaining = [fold for i, fold in enumerate(dataset) if i != index]
    train_dataset = torch.utils.data.ConcatDataset(remaining)
    test_dataset = held_out

    # A numeric rank means DDP: shard the training folds across workers.
    if rank in ("cpu", "cuda"):
        train_sampler = None
    else:
        train_sampler = DistributedSampler(
            train_dataset, num_replicas=world_size, rank=rank
        )

    test_loader = None
    # shuffle must be off when a sampler is supplied (the sampler shuffles).
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=(train_sampler is None),
        num_workers=num_workers,
        pin_memory=True,
        sampler=train_sampler,
    )
    # Only rank 0 (or a non-DDP run) evaluates, so only it needs a test loader.
    if rank in (0, "cpu", "cuda"):
        test_loader = DataLoader(
            test_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=True,
        )

    return train_loader, test_loader, train_sampler, train_dataset, test_dataset
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
################################################################################
|
| 373 |
+
# Trainers
|
| 374 |
+
################################################################################
|
| 375 |
+
|
| 376 |
+
###Regular training with train, val, test split
|
| 377 |
+
def train_regular(
    rank,
    world_size,
    data_path,
    job_parameters=None,
    training_parameters=None,
    model_parameters=None,
):
    """End-to-end training job with a train/val/test split.

    Runs either single-process (rank "cpu"/"cuda") or as one DDP worker
    (integer rank).  Trains the model, evaluates all splits on rank 0,
    optionally saves the model and writes per-split prediction CSVs, and
    returns a numpy array of (train, val, test) errors.
    """
    ##DDP
    ddp_setup(rank, world_size)
    ##some issues with DDP learning rate: scale lr by the number of workers
    if rank not in ("cpu", "cuda"):
        model_parameters["lr"] = model_parameters["lr"] * world_size

    ##Get dataset
    dataset = process.get_dataset(data_path, training_parameters["target_index"], False)

    if rank not in ("cpu", "cuda"):
        dist.barrier()

    ##Set up loader (val/test datasets themselves are unused, hence the underscores)
    (
        train_loader,
        val_loader,
        test_loader,
        train_sampler,
        train_dataset,
        _,
        _,
    ) = loader_setup(
        training_parameters["train_ratio"],
        training_parameters["val_ratio"],
        training_parameters["test_ratio"],
        model_parameters["batch_size"],
        dataset,
        rank,
        job_parameters["seed"],
        world_size,
    )

    ##Set up model
    model = model_setup(
        rank,
        model_parameters["model"],
        model_parameters,
        dataset,
        job_parameters["load_model"],
        job_parameters["model_path"],
        model_parameters.get("print_model", True),
    )

    ##Set-up optimizer & scheduler (both resolved by name from torch.optim)
    optimizer = getattr(torch.optim, model_parameters["optimizer"])(
        model.parameters(),
        lr=model_parameters["lr"],
        **model_parameters["optimizer_args"]
    )
    scheduler = getattr(torch.optim.lr_scheduler, model_parameters["scheduler"])(
        optimizer, **model_parameters["scheduler_args"]
    )

    ##Start training
    model = trainer(
        rank,
        world_size,
        model,
        optimizer,
        scheduler,
        training_parameters["loss"],
        train_loader,
        val_loader,
        train_sampler,
        model_parameters["epochs"],
        training_parameters["verbosity"],
        "my_model_temp.pth",
    )

    if rank in (0, "cpu", "cuda"):

        train_error = val_error = test_error = float("NaN")

        ##workaround to get training output in DDP mode
        ##outputs are slightly different, could be due to dropout or batchnorm?
        train_loader = DataLoader(
            train_dataset,
            batch_size=model_parameters["batch_size"],
            shuffle=False,
            num_workers=0,
            pin_memory=True,
        )

        ##Get train error in eval mode
        train_error, train_out = evaluate(
            train_loader, model, training_parameters["loss"], rank, out=True
        )
        print("Train Error: {:.5f}".format(train_error))

        ##Get val error
        if val_loader != None:
            val_error, val_out = evaluate(
                val_loader, model, training_parameters["loss"], rank, out=True
            )
            print("Val Error: {:.5f}".format(val_error))

        ##Get test error
        if test_loader != None:
            test_error, test_out = evaluate(
                test_loader, model, training_parameters["loss"], rank, out=True
            )
            print("Test Error: {:.5f}".format(test_error))

        ##Save model
        # NOTE(review): checkpoints here use the key "model_state_dict",
        # which is what model_setup() loads; trainer() saves "state_dict" —
        # confirm which key downstream tooling expects.
        if job_parameters["save_model"] == "True":

            if rank not in ("cpu", "cuda"):
                torch.save(
                    {
                        "model_state_dict": model.state_dict(),
                        "optimizer_state_dict": optimizer.state_dict(),
                        "scheduler_state_dict": scheduler.state_dict(),
                        "full_model": model,
                    },
                    job_parameters["model_path"],
                )
            else:
                torch.save(
                    {
                        "model_state_dict": model.state_dict(),
                        "optimizer_state_dict": optimizer.state_dict(),
                        "scheduler_state_dict": scheduler.state_dict(),
                        "full_model": model,
                    },
                    job_parameters["model_path"],
                )

        ##Write outputs
        if job_parameters["write_output"] == "True":

            write_results(
                train_out, str(job_parameters["job_name"]) + "_train_outputs.csv"
            )
            if val_loader != None:
                write_results(
                    val_out, str(job_parameters["job_name"]) + "_val_outputs.csv"
                )
            if test_loader != None:
                write_results(
                    test_out, str(job_parameters["job_name"]) + "_test_outputs.csv"
                )

        if rank not in ("cpu", "cuda"):
            dist.destroy_process_group()

        ##Write out model performance to file
        # NOTE(review): if a split is missing, its error stays the plain
        # float NaN, and float has no .cpu() — this line assumes all three
        # errors are tensors; confirm for runs without val/test data.
        error_values = np.array((train_error.cpu(), val_error.cpu(), test_error.cpu()))
        if job_parameters.get("write_error") == "True":
            np.savetxt(
                job_parameters["job_name"] + "_errorvalues.csv",
                error_values[np.newaxis, ...],
                delimiter=",",
            )

        return error_values
|
| 540 |
+
|
| 541 |
+
|
| 542 |
+
###Predict using a saved model
|
| 543 |
+
def predict(dataset, loss, job_parameters=None):
    """Run a saved model over ``dataset`` and return the evaluation error.

    The model is loaded from job_parameters["model_path"] (the pickled
    "full_model" checkpoint entry), moved to GPU when available, evaluated,
    and the per-structure predictions optionally written to CSV.
    """
    rank = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    ##Loads the predict dataset in one go; care needed for large datasets
    loader = DataLoader(
        dataset,
        batch_size=128,
        shuffle=False,
        num_workers=0,
        pin_memory=True,
    )

    ##Load saved model
    assert os.path.exists(job_parameters["model_path"]), "Saved model not found"
    if str(rank) == "cpu":
        saved = torch.load(
            job_parameters["model_path"], map_location=torch.device("cpu")
        )
    else:
        saved = torch.load(
            job_parameters["model_path"], map_location=torch.device("cuda")
        )
    model = saved["full_model"]
    model = model.to(rank)
    model_summary(model)

    ##Get predictions (timed)
    time_start = time.time()
    test_error, test_out = evaluate(loader, model, loss, rank, out=True)
    elapsed_time = time.time() - time_start

    print("Evaluation time (s): {:.5f}".format(elapsed_time))

    ##Write output
    if job_parameters["write_output"] == "True":
        write_results(
            test_out, str(job_parameters["job_name"]) + "_predicted_outputs.csv"
        )

    return test_error
|
| 584 |
+
|
| 585 |
+
|
| 586 |
+
###n-fold cross validation
|
| 587 |
+
def train_CV(
    rank,
    world_size,
    data_path,
    job_parameters=None,
    training_parameters=None,
    model_parameters=None,
):
    """Run n-fold cross validation and return the mean test error over folds.

    A fresh model/optimizer/scheduler is built per fold; fold predictions are
    stacked and optionally written to a single CSV.  Model loading/saving is
    forcibly disabled (string flags per project convention).
    """
    job_parameters["load_model"] = "False"
    job_parameters["save_model"] = "False"
    job_parameters["model_path"] = None
    ##DDP
    ddp_setup(rank, world_size)
    ##some issues with DDP learning rate: scale lr by the number of workers
    if rank not in ("cpu", "cuda"):
        model_parameters["lr"] = model_parameters["lr"] * world_size

    ##Get dataset
    dataset = process.get_dataset(data_path, training_parameters["target_index"], False)

    ##Split datasets into cv_folds folds
    cv_dataset = process.split_data_CV(
        dataset, num_folds=job_parameters["cv_folds"], seed=job_parameters["seed"]
    )
    cv_error = 0  # running sum of per-fold test errors

    for index in range(0, len(cv_dataset)):

        ##Set up a fresh model per fold; only print the summary once.
        if index == 0:
            model = model_setup(
                rank,
                model_parameters["model"],
                model_parameters,
                dataset,
                job_parameters["load_model"],
                job_parameters["model_path"],
                print_model=True,
            )
        else:
            model = model_setup(
                rank,
                model_parameters["model"],
                model_parameters,
                dataset,
                job_parameters["load_model"],
                job_parameters["model_path"],
                print_model=False,
            )

        ##Set-up optimizer & scheduler (resolved by name from torch.optim)
        optimizer = getattr(torch.optim, model_parameters["optimizer"])(
            model.parameters(),
            lr=model_parameters["lr"],
            **model_parameters["optimizer_args"]
        )
        scheduler = getattr(torch.optim.lr_scheduler, model_parameters["scheduler"])(
            optimizer, **model_parameters["scheduler_args"]
        )

        ##Set up loader for this fold (fold ``index`` is held out for test)
        train_loader, test_loader, train_sampler, train_dataset, _ = loader_setup_CV(
            index, model_parameters["batch_size"], cv_dataset, rank, world_size
        )

        ##Start training (no validation loader in CV mode)
        model = trainer(
            rank,
            world_size,
            model,
            optimizer,
            scheduler,
            training_parameters["loss"],
            train_loader,
            None,
            train_sampler,
            model_parameters["epochs"],
            training_parameters["verbosity"],
            "my_model_temp.pth",
        )

        if rank not in ("cpu", "cuda"):
            dist.barrier()

        if rank in (0, "cpu", "cuda"):

            # Rebuild an unshuffled train loader for deterministic evaluation.
            train_loader = DataLoader(
                train_dataset,
                batch_size=model_parameters["batch_size"],
                shuffle=False,
                num_workers=0,
                pin_memory=True,
            )

            ##Get train error
            train_error, train_out = evaluate(
                train_loader, model, training_parameters["loss"], rank, out=True
            )
            print("Train Error: {:.5f}".format(train_error))

            ##Get test error for the held-out fold
            test_error, test_out = evaluate(
                test_loader, model, training_parameters["loss"], rank, out=True
            )
            print("Test Error: {:.5f}".format(test_error))

            cv_error = cv_error + test_error

            # Accumulate per-fold prediction rows for a single output CSV.
            if index == 0:
                total_rows = test_out
            else:
                total_rows = np.vstack((total_rows, test_out))

    ##Write output
    if rank in (0, "cpu", "cuda"):
        if job_parameters["write_output"] == "True":
            if test_loader != None:
                write_results(
                    total_rows, str(job_parameters["job_name"]) + "_CV_outputs.csv"
                )

        cv_error = cv_error / len(cv_dataset)
        print("CV Error: {:.5f}".format(cv_error))

    if rank not in ("cpu", "cuda"):
        dist.destroy_process_group()

    return cv_error
|
| 716 |
+
|
| 717 |
+
|
| 718 |
+
### Repeat training for n times
def train_repeat(
    data_path,
    job_parameters=None,
    training_parameters=None,
    model_parameters=None,
):
    """Run regular training ``repeat_trials`` times, each trial with a fresh
    random seed (and therefore a fresh data split), then aggregate the
    per-trial error metrics.

    Args:
        data_path: Dataset path handed through to ``training.train_regular``.
        job_parameters: Job configuration dict; must contain "job_name",
            "model_path", "repeat_trials", "parallel" and "write_output".
            (Default of None is signature symmetry only — a dict is required.)
        training_parameters: Training configuration dict, passed through.
        model_parameters: Model configuration dict; "print_model" is toggled
            so the model summary prints only on the first trial.

    Side effects:
        - Mutates ``job_parameters`` and ``model_parameters`` in place.
        - Reads per-trial "<job_name><i>_errorvalues.csv" files (presumably
          written by train_regular because "write_error" is forced to
          "True" — TODO confirm against training.train_regular).
        - Depending on "write_output", writes
          "<job_name>_all_errorvalues.csv" or deletes the per-trial files.
    """

    # One worker per visible GPU; 0 means CPU-only execution.
    world_size = torch.cuda.device_count()
    job_name = job_parameters["job_name"]
    model_path = job_parameters["model_path"]
    # Force per-trial error files so they can be aggregated below; individual
    # trial models are neither loaded nor saved.
    job_parameters["write_error"] = "True"
    job_parameters["load_model"] = "False"
    job_parameters["save_model"] = "False"
    ##Loop over number of repeated trials
    for i in range(0, job_parameters["repeat_trials"]):

        ##new seed each time for different data split
        job_parameters["seed"] = np.random.randint(1, 1e6)

        # Print the model summary only once, on the first trial.
        if i == 0:
            model_parameters["print_model"] = True
        else:
            model_parameters["print_model"] = False

        # Give every trial a unique job name and model path.
        job_parameters["job_name"] = job_name + str(i)
        job_parameters["model_path"] = str(i) + "_" + model_path

        if world_size == 0:
            print("Running on CPU - this will be slow")
            training.train_regular(
                "cpu",
                world_size,
                data_path,
                job_parameters,
                training_parameters,
                model_parameters,
            )
        elif world_size > 0:
            if job_parameters["parallel"] == "True":
                print("Running on", world_size, "GPUs")
                # One spawned process per GPU; the process rank is injected
                # by mp.spawn as the first argument of train_regular.
                mp.spawn(
                    training.train_regular,
                    args=(
                        world_size,
                        data_path,
                        job_parameters,
                        training_parameters,
                        model_parameters,
                    ),
                    nprocs=world_size,
                    join=True,
                )
            if job_parameters["parallel"] == "False":
                print("Running on one GPU")
                training.train_regular(
                    "cuda",
                    world_size,
                    data_path,
                    job_parameters,
                    training_parameters,
                    model_parameters,
                )

    ##Compile error metrics from individual trials
    print("Individual training finished.")
    print("Compiling metrics from individual trials...")
    # One row per trial: [train_error, val_error, test_error].
    error_values = np.zeros((job_parameters["repeat_trials"], 3))
    for i in range(0, job_parameters["repeat_trials"]):
        filename = job_name + str(i) + "_errorvalues.csv"
        error_values[i] = np.genfromtxt(filename, delimiter=",")
    mean_values = [
        np.mean(error_values[:, 0]),
        np.mean(error_values[:, 1]),
        np.mean(error_values[:, 2]),
    ]
    std_values = [
        np.std(error_values[:, 0]),
        np.std(error_values[:, 1]),
        np.std(error_values[:, 2]),
    ]

    ##Print error
    print(
        "Training Error Avg: {:.3f}, Training Standard Dev: {:.3f}".format(
            mean_values[0], std_values[0]
        )
    )
    print(
        "Val Error Avg: {:.3f}, Val Standard Dev: {:.3f}".format(
            mean_values[1], std_values[1]
        )
    )
    print(
        "Test Error Avg: {:.3f}, Test Standard Dev: {:.3f}".format(
            mean_values[2], std_values[2]
        )
    )

    ##Write error metrics
    if job_parameters["write_output"] == "True":
        with open(job_name + "_all_errorvalues.csv", "w") as f:
            csvwriter = csv.writer(f)
            csvwriter.writerow(
                [
                    "",
                    "Training",
                    "Validation",
                    "Test",
                ]
            )
            for i in range(0, len(error_values)):
                csvwriter.writerow(
                    [
                        "Trial " + str(i),
                        error_values[i, 0],
                        error_values[i, 1],
                        error_values[i, 2],
                    ]
                )
            csvwriter.writerow(["Mean", mean_values[0], mean_values[1], mean_values[2]])
            csvwriter.writerow(["Std", std_values[0], std_values[1], std_values[2]])
    elif job_parameters["write_output"] == "False":
        # No aggregated report requested: remove the per-trial files.
        for i in range(0, job_parameters["repeat_trials"]):
            filename = job_name + str(i) + "_errorvalues.csv"
            os.remove(filename)
| 846 |
+
###Hyperparameter optimization
# trainable function for ray tune (no parallel, max 1 GPU per job)
def tune_trainable(config, checkpoint_dir=None, data_path=None):
    """Ray Tune trainable: train one model with a sampled hyperparameter set.

    Args:
        config: Dict assembled by ``tune_setup`` with keys "hyper_args",
            "job_parameters", "processing_parameters", "training_parameters"
            and "model_parameters".
        checkpoint_dir: Tune-provided directory to resume from, if any.
        data_path: Placeholder; the real path is rebuilt below from
            ``processing_parameters["data_path"]``.

    Every "hyper_iter" epochs this saves a (model, optimizer, scheduler)
    checkpoint and reports the validation error to Tune as "loss".
    """

    # imports
    from ray import tune

    print("Hyperparameter trial start")
    hyper_args = config["hyper_args"]
    job_parameters = config["job_parameters"]
    processing_parameters = config["processing_parameters"]
    training_parameters = config["training_parameters"]
    model_parameters = config["model_parameters"]

    ##Merge the sampled hyperparameters into the constant parameters; the
    ##later dict in {**a, **b} wins, so hyper_args override the constants.
    ##Training and job parameters are omitted as they should not be part of
    ##hyperparameter opt, in theory.
    model_parameters = {**model_parameters, **hyper_args}
    processing_parameters = {**processing_parameters, **hyper_args}

    ##Assume 1 gpu or 1 cpu per trial, no functionality for parallel yet
    world_size = 1
    rank = "cpu"
    if torch.cuda.is_available():
        rank = "cuda"

    ##Reprocess data in a separate directory to prevent conflict
    # NOTE(review): local name ``time`` shadows the stdlib time module if it
    # is imported elsewhere in this file.
    if job_parameters["reprocess"] == "True":
        time = datetime.now()
        processing_parameters["processed_path"] = time.strftime("%H%M%S%f")
        processing_parameters["verbose"] = "False"
    # Rebuild the absolute dataset path relative to this file's grandparent
    # directory (workers may start with a different working directory).
    data_path = os.path.dirname(
        os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
    )
    data_path = os.path.join(data_path, processing_parameters["data_path"])
    data_path = os.path.normpath(data_path)
    print("Data path", data_path)

    ##Set up dataset
    dataset = process.get_dataset(
        data_path,
        training_parameters["target_index"],
        job_parameters["reprocess"],
        processing_parameters,
    )

    ##Set up loader
    (
        train_loader,
        val_loader,
        test_loader,
        train_sampler,
        train_dataset,
        _,
        _,
    ) = loader_setup(
        training_parameters["train_ratio"],
        training_parameters["val_ratio"],
        training_parameters["test_ratio"],
        model_parameters["batch_size"],
        dataset,
        rank,
        job_parameters["seed"],
        world_size,
    )

    ##Set up model
    model = model_setup(
        rank,
        model_parameters["model"],
        model_parameters,
        dataset,
        False,
        None,
        False,
    )

    ##Set-up optimizer & scheduler
    optimizer = getattr(torch.optim, model_parameters["optimizer"])(
        model.parameters(),
        lr=model_parameters["lr"],
        **model_parameters["optimizer_args"]
    )
    scheduler = getattr(torch.optim.lr_scheduler, model_parameters["scheduler"])(
        optimizer, **model_parameters["scheduler_args"]
    )

    ##Load checkpoint (resume a previously interrupted trial)
    if checkpoint_dir:
        model_state, optimizer_state, scheduler_state = torch.load(
            os.path.join(checkpoint_dir, "checkpoint")
        )
        model.load_state_dict(model_state)
        optimizer.load_state_dict(optimizer_state)
        scheduler.load_state_dict(scheduler_state)

    ##Training loop
    for epoch in range(1, model_parameters["epochs"] + 1):
        lr = scheduler.optimizer.param_groups[0]["lr"]
        train_error = train(
            model, optimizer, train_loader, training_parameters["loss"], rank=rank
        )

        val_error = evaluate(
            val_loader, model, training_parameters["loss"], rank=rank, out=False
        )

        ##Delete processed data after the final epoch, if requested
        if epoch == model_parameters["epochs"]:
            if (
                job_parameters["reprocess"] == "True"
                and job_parameters["hyper_delete_processed"] == "True"
            ):
                shutil.rmtree(
                    os.path.join(data_path, processing_parameters["processed_path"])
                )
            print("Finished Training")

        ##Update to tune: checkpoint and report every hyper_iter epochs
        if epoch % job_parameters["hyper_iter"] == 0:
            with tune.checkpoint_dir(step=epoch) as checkpoint_dir:
                path = os.path.join(checkpoint_dir, "checkpoint")
                torch.save(
                    (
                        model.state_dict(),
                        optimizer.state_dict(),
                        scheduler.state_dict(),
                    ),
                    path,
                )
            ##Somehow tune does not recognize value without *1
            # val_error is a torch tensor here; convert to a numpy scalar.
            tune.report(loss=val_error.cpu().numpy() * 1)
            # tune.report(loss=val_error)
| 980 |
+
# Tune setup
def tune_setup(
    hyper_args,
    job_parameters,
    processing_parameters,
    training_parameters,
    model_parameters,
):
    """Configure and launch a Ray Tune hyperparameter search.

    Uses HyperOpt search (capped by "hyper_concurrency"), resumes an
    existing run directory when "hyper_resume" is "True", and runs
    "hyper_trials" samples of ``tune_trainable``.

    Returns:
        The best trial by minimum reported "loss" over all reports.
    """

    # imports
    import ray
    from ray import tune
    from ray.tune.schedulers import ASHAScheduler
    from ray.tune.suggest.hyperopt import HyperOptSearch
    from ray.tune.suggest import ConcurrencyLimiter
    from ray.tune import CLIReporter

    ray.init()

    dummy_data_path = "_"
    results_dir = "ray_results"
    # currently no support for parallelization per trial
    trial_gpus = 1

    ##Set up search algo: HyperOpt, limited to hyper_concurrency trials at once
    algo = ConcurrencyLimiter(
        HyperOptSearch(metric="loss", mode="min", n_initial_points=5),
        max_concurrent=job_parameters["hyper_concurrency"],
    )

    ##Resume run only if the experiment directory already exists
    run_dir = results_dir + "/" + job_parameters["job_name"]
    if os.path.exists(run_dir) and os.path.isdir(run_dir):
        if job_parameters["hyper_resume"] == "False":
            resume = False
        elif job_parameters["hyper_resume"] == "True":
            resume = True
        # else:
        #     resume = "PROMPT"
    else:
        resume = False

    ##Columns shown in the CLI progress report.
    # The filtered list is immediately replaced by the single raw
    # "hyper_args" column, as in the original configuration.
    parameter_columns = [key for key in hyper_args.keys() if key not in "global"]
    parameter_columns = ["hyper_args"]
    reporter = CLIReporter(
        max_progress_rows=20, max_error_rows=5, parameter_columns=parameter_columns
    )

    trial_config = {
        "hyper_args": hyper_args,
        "job_parameters": job_parameters,
        "processing_parameters": processing_parameters,
        "training_parameters": training_parameters,
        "model_parameters": model_parameters,
    }

    ##Run tune
    tune_result = tune.run(
        partial(tune_trainable, data_path=dummy_data_path),
        resources_per_trial={"cpu": 1, "gpu": trial_gpus},
        config=trial_config,
        num_samples=job_parameters["hyper_trials"],
        search_alg=algo,
        local_dir=results_dir,
        progress_reporter=reporter,
        verbose=job_parameters["hyper_verbosity"],
        resume=resume,
        log_to_file=True,
        name=job_parameters["job_name"],
        max_failures=4,
        raise_on_failed_trial=False,
        # Stop after epochs // hyper_iter reports (one report per hyper_iter
        # epochs inside tune_trainable).
        stop={
            "training_iteration": model_parameters["epochs"]
            // job_parameters["hyper_iter"]
        },
    )

    ##Get best trial over the whole history of reports
    return tune_result.get_best_trial("loss", "min", "all")
| 1068 |
+
###Simple ensemble using averages
def train_ensemble(
    data_path,
    job_parameters=None,
    training_parameters=None,
    model_parameters=None,
):
    """Train every model named in job_parameters["ensemble_list"]
    independently, then ensemble their test-set predictions by averaging.

    Args:
        data_path: Dataset path passed through to ``training.train_regular``.
        job_parameters: Job dict; must contain "ensemble_list" (model names),
            "job_name", "model_path", "parallel" and "write_output".
        training_parameters: Training configuration dict, passed through.
        model_parameters: Dict mapping each ensemble model name to its own
            model-parameter dict.

    Side effects:
        Mutates ``job_parameters``; reads the per-model
        "<job_name><i>_errorvalues.csv" / "<job_name><i>_test_outputs.csv"
        files produced by the individual runs; optionally writes
        "<job_name>_test_ensemble_outputs.csv" and/or deletes the
        per-model files, depending on the original "write_output" value.
    """

    # One worker per visible GPU; 0 means CPU-only execution.
    world_size = torch.cuda.device_count()
    job_name = job_parameters["job_name"]
    # Remember the caller's choice; per-model runs are forced to write output
    # below because the ensembling step needs their test files.
    write_output = job_parameters["write_output"]
    model_path = job_parameters["model_path"]
    job_parameters["write_error"] = "True"
    job_parameters["write_output"] = "True"
    job_parameters["load_model"] = "False"
    ##Loop over the models in the ensemble
    for i in range(0, len(job_parameters["ensemble_list"])):
        job_parameters["job_name"] = job_name + str(i)
        job_parameters["model_path"] = (
            str(i) + "_" + job_parameters["ensemble_list"][i] + "_" + model_path
        )

        if world_size == 0:
            print("Running on CPU - this will be slow")
            training.train_regular(
                "cpu",
                world_size,
                data_path,
                job_parameters,
                training_parameters,
                model_parameters[job_parameters["ensemble_list"][i]],
            )
        elif world_size > 0:
            if job_parameters["parallel"] == "True":
                print("Running on", world_size, "GPUs")
                mp.spawn(
                    training.train_regular,
                    args=(
                        world_size,
                        data_path,
                        job_parameters,
                        training_parameters,
                        model_parameters[job_parameters["ensemble_list"][i]],
                    ),
                    nprocs=world_size,
                    join=True,
                )
            if job_parameters["parallel"] == "False":
                print("Running on one GPU")
                training.train_regular(
                    "cuda",
                    world_size,
                    data_path,
                    job_parameters,
                    training_parameters,
                    model_parameters[job_parameters["ensemble_list"][i]],
                )

    ##Compile error metrics from individual models
    print("Individual training finished.")
    print("Compiling metrics from individual models...")
    # One row per model: [train_error, val_error, test_error].
    error_values = np.zeros((len(job_parameters["ensemble_list"]), 3))
    for i in range(0, len(job_parameters["ensemble_list"])):
        filename = job_name + str(i) + "_errorvalues.csv"
        error_values[i] = np.genfromtxt(filename, delimiter=",")
    mean_values = [
        np.mean(error_values[:, 0]),
        np.mean(error_values[:, 1]),
        np.mean(error_values[:, 2]),
    ]
    std_values = [
        np.std(error_values[:, 0]),
        np.std(error_values[:, 1]),
        np.std(error_values[:, 2]),
    ]

    # average ensembling, takes the mean of the predictions
    # test_total columns: [id, target, model0_pred, model1_pred, ...]
    # (assumes column 2 of each test_outputs file is the prediction —
    # TODO confirm against the writer in train_regular)
    for i in range(0, len(job_parameters["ensemble_list"])):
        filename = job_name + str(i) + "_test_outputs.csv"
        test_out = np.genfromtxt(filename, delimiter=",", skip_header=1)
        if i == 0:
            test_total = test_out
        elif i > 0:
            test_total = np.column_stack((test_total, test_out[:, 2]))

    # FIX: ``np.float`` was deprecated in NumPy 1.20 and removed in 1.24,
    # raising AttributeError on current NumPy; the builtin float is the
    # documented replacement and behaves identically here (float64).
    ensemble_test = np.mean(np.array(test_total[:, 2:]).astype(float), axis=1)
    ensemble_test_error = getattr(F, training_parameters["loss"])(
        torch.tensor(ensemble_test),
        torch.tensor(test_total[:, 1].astype(float)),
    )
    test_total = np.column_stack((test_total, ensemble_test))

    ##Print performance
    for i in range(0, len(job_parameters["ensemble_list"])):
        print(
            job_parameters["ensemble_list"][i]
            + " Test Error: {:.5f}".format(error_values[i, 2])
        )
    print(
        "Test Error Avg: {:.3f}, Test Standard Dev: {:.3f}".format(
            mean_values[2], std_values[2]
        )
    )
    print("Ensemble Error: {:.5f}".format(ensemble_test_error))

    ##Write output (honoring the caller's original write_output setting)
    if write_output == "True" or write_output == "Partial":
        with open(
            str(job_name) + "_test_ensemble_outputs.csv", "w"
        ) as f:
            csvwriter = csv.writer(f)
            for i in range(0, len(test_total) + 1):
                if i == 0:
                    csvwriter.writerow(
                        [
                            "ids",
                            "target",
                        ]
                        + job_parameters["ensemble_list"]
                        + ["ensemble"]
                    )
                elif i > 0:
                    csvwriter.writerow(test_total[i - 1, :])
    if write_output == "False" or write_output == "Partial":
        # Clean up the per-model intermediate files.
        for i in range(0, len(job_parameters["ensemble_list"])):
            filename = job_name + str(i) + "_errorvalues.csv"
            os.remove(filename)
            filename = job_name + str(i) + "_test_outputs.csv"
            os.remove(filename)
| 1198 |
+
##Obtains features from graph in a trained model and analysis with tsne
def analysis(
    dataset,
    model_path,
    tsne_args,
):
    """Extract learned features from a trained model and visualize with t-SNE.

    A forward hook captures the input of the first post-GNN linear layer
    (``model.post_lin_list[0]``) for every batch; the collected features are
    embedded with t-SNE, written to "tsne_output.csv", and plotted to
    "tsne_output.png" colored by target value.

    Args:
        dataset: Dataset to featurize; must expose ``.data.y`` (single-index
            targets) and ``.data.structure_id``, and be compatible with the
            saved model.
        model_path: Path to a checkpoint containing a "full_model" entry.
        tsne_args: Keyword arguments forwarded to sklearn's TSNE.

    Raises:
        AssertionError: If ``model_path`` does not exist.
    """

    # imports (local so sklearn/matplotlib are only needed for analysis jobs)
    from sklearn.manifold import TSNE
    import matplotlib.pyplot as plt

    rank = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    inputs = []

    def hook(module, input, output):
        # Collect the tuple of tensors fed into the hooked layer.
        inputs.append(input)

    assert os.path.exists(model_path), "saved model not found"
    if str(rank) == "cpu":
        saved = torch.load(model_path, map_location=torch.device("cpu"))
    else:
        saved = torch.load(model_path, map_location=torch.device("cuda"))
    model = saved["full_model"]
    model_summary(model)

    print(dataset)

    loader = DataLoader(
        dataset,
        batch_size=512,
        shuffle=False,
        num_workers=0,
        pin_memory=True,
    )

    model.eval()
    ##Grabs the input of the first linear layer after the GNN
    # (assumes the model exposes ``post_lin_list`` — TODO confirm for all
    # supported architectures)
    model.post_lin_list[0].register_forward_hook(hook)
    for data in loader:
        with torch.no_grad():
            data = data.to(rank)
            # Forward pass only to trigger the hook; the output is unused.
            model(data)

    # Flatten the list of per-batch hook-input tuples into one feature matrix.
    inputs = [i for sub in inputs for i in sub]
    inputs = torch.cat(inputs)
    inputs = inputs.cpu().numpy()
    print("Number of samples: ", inputs.shape[0])
    print("Number of features: ", inputs.shape[1])

    # only works for when targets has one index
    targets = dataset.data.y.numpy()

    ##Start t-SNE analysis
    tsne = TSNE(**tsne_args)
    tsne_out = tsne.fit_transform(inputs)
    rows = zip(
        dataset.data.structure_id,
        list(dataset.data.y.numpy()),
        list(tsne_out[:, 0]),
        list(tsne_out[:, 1]),
    )

    with open("tsne_output.csv", "w") as csv_file:
        writer = csv.writer(csv_file, delimiter=",")
        for row in rows:
            writer.writerow(row)

    fig, ax = plt.subplots()
    main = plt.scatter(tsne_out[:, 1], tsne_out[:, 0], c=targets, s=3)
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.set_xticks([])
    ax.set_yticks([])
    cbar = plt.colorbar(main, ax=ax)
    # Clip the colorbar to mean +/- 2 std of the targets. Compute the
    # statistics once (the original computed ``stdev`` and then never used
    # it, recomputing np.mean/np.std inside set_clim).
    target_mean = np.mean(targets)
    target_std = np.std(targets)
    cbar.mappable.set_clim(
        target_mean - 2 * target_std, target_mean + 2 * target_std
    )
    plt.savefig("tsne_output.png", format="png", dpi=600)
    plt.show()
|