vermouthdky commited on
Commit
8c38a63
·
verified ·
1 Parent(s): c4d739c

Upload 3 files

Browse files
Files changed (2) hide show
  1. modeling_qwen2.py +1 -1
  2. nets.py +191 -0
modeling_qwen2.py CHANGED
@@ -40,7 +40,7 @@ from transformers.utils import (add_start_docstrings,
40
  is_flash_attn_greater_or_equal_2_10, logging,
41
  replace_return_docstrings)
42
 
43
- from ..nets import EnsembleModel
44
  from .configuration_qwen2 import QwenEnPRMConfig as Qwen2Config
45
 
46
  if is_flash_attn_2_available():
 
40
  is_flash_attn_greater_or_equal_2_10, logging,
41
  replace_return_docstrings)
42
 
43
+ from .nets import EnsembleModel
44
  from .configuration_qwen2 import QwenEnPRMConfig as Qwen2Config
45
 
46
  if is_flash_attn_2_available():
nets.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Garena Online Private Limited
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Deep networks."""
16
+
17
+ from copy import deepcopy
18
+
19
+ import numpy as np
20
+ import torch
21
+ import torch.nn.functional as F
22
+ from torch import nn
23
+
24
+
25
def init_weights(m):
    """Initialize Linear / EnsembleFC layers with a truncated normal.

    Weights are drawn from N(0, std^2) with std = 1 / (2 * sqrt(in_features)),
    resampling any entry that falls outside [mean - 2*std, mean + 2*std].
    Biases (when present) are zeroed. Other module types are left untouched,
    so this function can be passed to ``nn.Module.apply``.
    """

    @torch.no_grad()
    def truncated_normal_init(t, mean=0.0, std=0.01):
        # Fill the tensor in place, then resample out-of-range entries until
        # every value lies within two standard deviations of the mean.
        t.data.normal_(mean, std)
        while True:
            cond = torch.logical_or(t < mean - 2 * std, t > mean + 2 * std)
            if not torch.sum(cond):
                break
            w = torch.empty(t.shape, device=t.device, dtype=t.dtype)
            w.data.normal_(mean, std)
            # BUG FIX: the original rebound the local name ``t`` here
            # (``t = torch.where(cond, w, t)``), so the resampled values never
            # reached the parameter passed in — the caller discards the return
            # value. Copy the truncated values back in place instead.
            t.data.copy_(torch.where(cond, w, t))
        return t

    if type(m) is nn.Linear or isinstance(m, EnsembleFC):
        truncated_normal_init(m.weight, std=1 / (2 * np.sqrt(m.in_features)))
        if m.bias is not None:
            m.bias.data.fill_(0.0)
44
+
45
+
46
def init_weights_uniform(m):
    """Initialize a linear-like layer uniformly in [-1/sqrt(fan_in), 1/sqrt(fan_in)].

    ``m`` must expose ``in_features``, ``weight`` and ``bias`` (e.g. nn.Linear
    or EnsembleFC). The bias, when present, is zeroed.
    """
    bound = 1 / np.sqrt(m.in_features)
    # BUG FIX: ``torch.nn.init.uniform`` is the deprecated alias; use the
    # supported in-place ``uniform_`` variant.
    torch.nn.init.uniform_(m.weight, -bound, bound)
    if m.bias is not None:
        m.bias.data.fill_(0.0)
51
+
52
+
53
class Swish(nn.Module):
    """Swish activation: ``x * sigmoid(x)`` (a.k.a. SiLU)."""

    def __init__(self):
        super(Swish, self).__init__()

    def forward(self, x):
        # ``F.sigmoid`` is deprecated; ``torch.sigmoid`` is the supported
        # spelling and is numerically identical.
        return x * torch.sigmoid(x)
60
+
61
+
62
class MLPModel(nn.Module):
    """Two-hidden-layer MLP head mapping an encoding vector to a scalar score."""

    def __init__(self, encoding_dim, hidden_dim=128, activation="relu") -> None:
        super(MLPModel, self).__init__()
        self.hidden_size = hidden_dim
        self.output_dim = 1

        self.nn1 = nn.Linear(encoding_dim, hidden_dim)
        self.nn2 = nn.Linear(hidden_dim, hidden_dim)
        self.nn_out = nn.Linear(hidden_dim, self.output_dim)

        # Truncated-normal init for every Linear layer in this module.
        self.apply(init_weights)

        if activation == "swish":
            self.activation = Swish()
        elif activation == "relu":
            self.activation = nn.ReLU()
        else:
            raise ValueError(f"Unknown activation {activation}")

    def get_params(self) -> torch.Tensor:
        """Return all parameters flattened and concatenated into one 1-D tensor."""
        return torch.cat([p.view(-1) for p in self.parameters()])

    def forward(self, encoding: torch.Tensor) -> torch.Tensor:
        hidden = self.activation(self.nn1(encoding))
        hidden = self.activation(self.nn2(hidden))
        return self.nn_out(hidden)

    def init(self):
        """Snapshot the current parameters; call before ``regularization()``."""
        self.init_params = self.get_params().data.clone()
        if torch.cuda.is_available():
            self.init_params = self.init_params.cuda()

    def regularization(self):
        """Prior towards independent initialization."""
        return ((self.get_params() - self.init_params) ** 2).mean()
101
+
102
+
103
class EnsembleFC(nn.Module):
    """A fully-connected layer applied independently by each ensemble member.

    Holds one ``(in_features, out_features)`` weight matrix per ensemble member
    and contracts it against a per-member batch of inputs.

    NOTE: parameters are allocated with ``torch.empty`` and are expected to be
    initialized externally (e.g. via ``init_weights``).
    """

    __constants__ = ["in_features", "out_features"]
    in_features: int
    out_features: int
    ensemble_size: int
    weight: torch.Tensor

    def __init__(
        self,
        in_features: int,
        out_features: int,
        ensemble_size: int,
        bias: bool = True,
        dtype=torch.float32,
    ) -> None:
        super(EnsembleFC, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.ensemble_size = ensemble_size
        # init immediately to avoid error
        self.weight = nn.Parameter(torch.empty(ensemble_size, in_features, out_features, dtype=dtype))
        if bias:
            self.bias = nn.Parameter(torch.empty(ensemble_size, out_features, dtype=dtype))
        else:
            self.register_parameter("bias", None)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Compute ``input @ weight (+ bias)`` per ensemble member.

        Args:
            input: tensor of shape (ensemble_size, batch, length, in_features);
                cast to the weight dtype before the contraction.

        Returns:
            Tensor of shape (ensemble_size, batch, length, out_features).
        """
        input = input.to(self.weight.dtype)
        wx = torch.einsum("eblh,ehm->eblm", input, self.weight)
        # BUG FIX: the original unconditionally indexed ``self.bias`` and
        # crashed with AttributeError/TypeError when constructed with
        # ``bias=False``; skip the add when there is no bias parameter.
        if self.bias is None:
            return wx
        return torch.add(wx, self.bias[:, None, None, :])  # w times x + b
134
+
135
+
136
def get_params(model):
    """Flatten every parameter of ``model`` into a single 1-D tensor."""
    flat_pieces = []
    for param in model.parameters():
        flat_pieces.append(param.view(-1))
    return torch.cat(flat_pieces)
138
+
139
+
140
class _EnsembleModel(nn.Module):
    """Ensemble of two-hidden-layer MLP scoring heads run in one forward pass."""

    def __init__(self, encoding_dim, num_ensemble, hidden_dim=128, activation="relu", dtype=torch.float32) -> None:
        super(_EnsembleModel, self).__init__()
        self.num_ensemble = num_ensemble
        self.hidden_dim = hidden_dim
        self.output_dim = 1

        self.nn1 = EnsembleFC(encoding_dim, hidden_dim, num_ensemble, dtype=dtype)
        self.nn2 = EnsembleFC(hidden_dim, hidden_dim, num_ensemble, dtype=dtype)
        self.nn_out = EnsembleFC(hidden_dim, self.output_dim, num_ensemble, dtype=dtype)

        # Truncated-normal init for every EnsembleFC layer in this module.
        self.apply(init_weights)

        if activation == "swish":
            self.activation = Swish()
        elif activation == "relu":
            self.activation = nn.ReLU()
        else:
            raise ValueError(f"Unknown activation {activation}")

    def forward(self, encoding: torch.Tensor) -> torch.Tensor:
        hidden = self.activation(self.nn1(encoding))
        hidden = self.activation(self.nn2(hidden))
        return self.nn_out(hidden)

    def regularization(self):
        """Prior towards independent initialization.

        NOTE(review): this references ``self.get_params`` and
        ``self.init_params``, neither of which is defined on this class, so
        calling it raises AttributeError. The outer ``EnsembleModel`` overrides
        it with a working implementation; confirm whether this variant is dead
        code before relying on it.
        """
        return ((self.get_params() - self.init_params) ** 2).mean()
170
+
171
+
172
class EnsembleModel(nn.Module):
    """Wrapper around ``_EnsembleModel`` with an anchored regularizer.

    Keeps a frozen deep copy of the freshly-initialized ensemble; the
    regularization term pulls the trainable parameters back towards that
    snapshot (a prior towards independent initialization).
    """

    def __init__(self, encoding_dim, num_ensemble, hidden_dim=128, activation="relu", dtype=torch.float32) -> None:
        super(EnsembleModel, self).__init__()
        self.encoding_dim = encoding_dim
        self.num_ensemble = num_ensemble
        self.hidden_dim = hidden_dim
        self.model = _EnsembleModel(encoding_dim, num_ensemble, hidden_dim, activation, dtype)
        self.reg_model = deepcopy(self.model)  # only used for regularization
        # freeze the reg model
        for frozen in self.reg_model.parameters():
            frozen.requires_grad = False

    def forward(self, encoding: torch.Tensor) -> torch.Tensor:
        return self.model(encoding)

    def regularization(self):
        """Prior towards independent initialization."""
        anchor = get_params(self.reg_model).detach()
        current = get_params(self.model)
        return ((current - anchor) ** 2).mean()