gaurv007 committed on
Commit
597978a
·
verified ·
1 Parent(s): 85cf385

Fix: cast labels to float32 for BCEWithLogitsLoss compatibility

Browse files
Files changed (1) hide show
  1. ml/train_classifier.py +7 -2
ml/train_classifier.py CHANGED
@@ -8,7 +8,7 @@ Compatible with: Transformers 5.6.x, Datasets 4.8.x (April 2026)
8
  import os
9
  import numpy as np
10
  import torch
11
- from datasets import load_dataset
12
  from sklearn.metrics import f1_score, precision_score, recall_score
13
  from transformers import (
14
  AutoConfig,
@@ -85,7 +85,12 @@ def preprocess(examples):
85
  return tokenized
86
 
87
  print("Tokenizing dataset...")
88
- tokenized_ds = dataset.map(preprocess, batched=True, remove_columns=["text"])
 
 
 
 
 
89
  tokenized_ds.set_format("torch")
90
 
91
  # ─── 4. Metrics ───
 
8
  import os
9
  import numpy as np
10
  import torch
11
+ from datasets import load_dataset, Sequence, Value
12
  from sklearn.metrics import f1_score, precision_score, recall_score
13
  from transformers import (
14
  AutoConfig,
 
85
  return tokenized
86
 
87
  print("Tokenizing dataset...")
88
+ tokenized_ds = dataset.map(preprocess, batched=True, remove_columns=dataset["train"].column_names)
89
+
90
+ # Critical: cast labels to float32 for BCEWithLogitsLoss (datasets default is int64)
91
+ for split in tokenized_ds:
92
+ tokenized_ds[split] = tokenized_ds[split].cast_column("labels", Sequence(Value("float32")))
93
+
94
  tokenized_ds.set_format("torch")
95
 
96
  # ─── 4. Metrics ───