Spaces:
Build error
Build error
Update model/model_v2.py
Browse files- model/model_v2.py +6 -0
model/model_v2.py
CHANGED
|
@@ -217,6 +217,12 @@ class Modelv2(lightning.LightningModule):
|
|
| 217 |
preds = torch.argmax(output, dim=-1) # cls, bsz or bsz
|
| 218 |
return preds
|
| 219 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 220 |
def one_step(self, batch, stage):
|
| 221 |
xs, ys = batch
|
| 222 |
if self.rdrop == None:
|
|
|
|
| 217 |
preds = torch.argmax(output, dim=-1) # cls, bsz or bsz
|
| 218 |
return preds
|
| 219 |
|
| 220 |
+
def predict_prob(self, batch_x):
|
| 221 |
+
output = self(batch_x)
|
| 222 |
+
probs = torch.softmax(output, dim=-1)
|
| 223 |
+
probs = probs[..., 1]
|
| 224 |
+
return probs
|
| 225 |
+
|
| 226 |
def one_step(self, batch, stage):
|
| 227 |
xs, ys = batch
|
| 228 |
if self.rdrop == None:
|