Zpwang-AI commited on
Commit
0211d7e
·
1 Parent(s): cc27249

Update model/model_v2.py

Browse files
Files changed (1) hide show
  1. model/model_v2.py +6 -0
model/model_v2.py CHANGED
@@ -217,6 +217,12 @@ class Modelv2(lightning.LightningModule):
217
  preds = torch.argmax(output, dim=-1) # cls, bsz or bsz
218
  return preds
219
 
 
 
 
 
 
 
220
  def one_step(self, batch, stage):
221
  xs, ys = batch
222
  if self.rdrop == None:
 
217
  preds = torch.argmax(output, dim=-1) # cls, bsz or bsz
218
  return preds
219
 
220
+ def predict_prob(self, batch_x):
221
+ output = self(batch_x)
222
+ probs = torch.softmax(output, dim=-1)
223
+ probs = probs[..., 1]
224
+ return probs
225
+
226
  def one_step(self, batch, stage):
227
  xs, ys = batch
228
  if self.rdrop == None: