Commit ·
9bb0b6e
1
Parent(s): 7019440
removed index from BharatAIs forward method and replaced index everywhere with input_ids
Browse files
model.py
CHANGED
|
@@ -116,12 +116,13 @@ class BharatAI(PreTrainedModel):
|
|
| 116 |
elif isinstance(module, nn.Embedding):
|
| 117 |
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
| 118 |
|
| 119 |
-
def forward(self, input_ids,
|
| 120 |
-
B, T = index.shape
|
|
|
|
| 121 |
x = input_ids
|
| 122 |
|
| 123 |
# idx and targets are both (B,T) tensor of integers
|
| 124 |
-
tok_emb = self.token_embedding_table(
|
| 125 |
pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C)
|
| 126 |
x = tok_emb + pos_emb # (B,T,C)
|
| 127 |
x = self.blocks(x) # (B,T,C)
|
|
|
|
| 116 |
elif isinstance(module, nn.Embedding):
|
| 117 |
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
| 118 |
|
| 119 |
+
def forward(self, input_ids,labels=None): #, targets # index,
|
| 120 |
+
#B, T = index.shape
|
| 121 |
+
B, T = input_ids.shape
|
| 122 |
x = input_ids
|
| 123 |
|
| 124 |
# idx and targets are both (B,T) tensor of integers
|
| 125 |
+
tok_emb = self.token_embedding_table(input_ids) # (B,T,C)
|
| 126 |
pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C)
|
| 127 |
x = tok_emb + pos_emb # (B,T,C)
|
| 128 |
x = self.blocks(x) # (B,T,C)
|