FarhanAK128 commited on
Commit
54415bf
·
verified ·
1 Parent(s): 1690b97

Update model_class.py

Browse files
Files changed (1) hide show
  1. model_class.py +1 -2
model_class.py CHANGED
@@ -1,7 +1,6 @@
1
  import torch
2
  import torch.nn as nn
3
  from transformers import PreTrainedModel, PretrainedConfig
4
- # from huggingface_hub import PyTorchModelHubMixin
5
 
6
 
7
  class MultiheadAttention(nn.Module):
@@ -203,7 +202,7 @@ class TicketGPT(
203
 
204
  # Model inference
205
  with torch.no_grad():
206
- logits = self.forward(input_tensor)[:, -1, :] # Logits of the last output token
207
  predicted_label = torch.argmax(logits, dim=-1).item()
208
 
209
  # Return the classified result
 
1
  import torch
2
  import torch.nn as nn
3
  from transformers import PreTrainedModel, PretrainedConfig
 
4
 
5
 
6
  class MultiheadAttention(nn.Module):
 
202
 
203
  # Model inference
204
  with torch.no_grad():
205
+ logits = self(input_tensor)[:, -1, :] # Logits of the last output token
206
  predicted_label = torch.argmax(logits, dim=-1).item()
207
 
208
  # Return the classified result