KrushiJethe commited on
Commit
9bb0b6e
·
1 Parent(s): 7019440

removed index from BharatAIs forward method and replaced index everywhere with input_ids

Browse files
Files changed (1) hide show
  1. model.py +4 -3
model.py CHANGED
@@ -116,12 +116,13 @@ class BharatAI(PreTrainedModel):
116
  elif isinstance(module, nn.Embedding):
117
  torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
118
 
119
- def forward(self, input_ids, index, labels=None): #, targets
120
- B, T = index.shape
 
121
  x = input_ids
122
 
123
  # idx and targets are both (B,T) tensor of integers
124
- tok_emb = self.token_embedding_table(index) # (B,T,C)
125
  pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C)
126
  x = tok_emb + pos_emb # (B,T,C)
127
  x = self.blocks(x) # (B,T,C)
 
116
  elif isinstance(module, nn.Embedding):
117
  torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
118
 
119
+ def forward(self, input_ids,labels=None): #, targets # index,
120
+ #B, T = index.shape
121
+ B, T = input_ids.shape
122
  x = input_ids
123
 
124
  # idx and targets are both (B,T) tensor of integers
125
+ tok_emb = self.token_embedding_table(input_ids) # (B,T,C)
126
  pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C)
127
  x = tok_emb + pos_emb # (B,T,C)
128
  x = self.blocks(x) # (B,T,C)