Update model_class.py
Browse files- model_class.py +1 -2
model_class.py
CHANGED
|
@@ -1,7 +1,6 @@
|
|
| 1 |
import torch
|
| 2 |
import torch.nn as nn
|
| 3 |
from transformers import PreTrainedModel, PretrainedConfig
|
| 4 |
-
# from huggingface_hub import PyTorchModelHubMixin
|
| 5 |
|
| 6 |
|
| 7 |
class MultiheadAttention(nn.Module):
|
|
@@ -203,7 +202,7 @@ class TicketGPT(
|
|
| 203 |
|
| 204 |
# Model inference
|
| 205 |
with torch.no_grad():
|
| 206 |
-
logits = self
|
| 207 |
predicted_label = torch.argmax(logits, dim=-1).item()
|
| 208 |
|
| 209 |
# Return the classified result
|
|
|
|
| 1 |
import torch
|
| 2 |
import torch.nn as nn
|
| 3 |
from transformers import PreTrainedModel, PretrainedConfig
|
|
|
|
| 4 |
|
| 5 |
|
| 6 |
class MultiheadAttention(nn.Module):
|
|
|
|
| 202 |
|
| 203 |
# Model inference
|
| 204 |
with torch.no_grad():
|
| 205 |
+
logits = self(input_tensor)[:, -1, :] # Logits of the last output token
|
| 206 |
predicted_label = torch.argmax(logits, dim=-1).item()
|
| 207 |
|
| 208 |
# Return the classified result
|