Charlie81 commited on
Commit
580eff8
·
1 Parent(s): 1f3825f

train aga

Browse files
Files changed (1) hide show
  1. scripts/train.py +6 -2
scripts/train.py CHANGED
@@ -137,8 +137,12 @@ def main():
137
 
138
  # Test forward/backward pass before training
139
  print("Testing gradient flow...")
140
- test_batch = next(iter(DataLoader(tokenized_dataset, batch_size=1)))
141
- test_batch = {k: v.to(model.device) for k, v in test_batch.items()}
 
 
 
 
142
 
143
  model.train()
144
  outputs = model(**test_batch)
 
137
 
138
  # Test forward/backward pass before training
139
  print("Testing gradient flow...")
140
+ test_loader = DataLoader(tokenized_dataset, batch_size=1, collate_fn=data_collator)
141
+ test_batch = next(iter(test_loader))
142
+
143
+ # Move batch to model's device
144
+ device = next(model.parameters()).device
145
+ test_batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in test_batch.items()}
146
 
147
  model.train()
148
  outputs = model(**test_batch)