asdf98 commited on
Commit
e90110a
·
verified ·
1 Parent(s): 654d061

Fix conv2d bf16 crash on T4: colab_train_iris.py

Browse files
Files changed (1) hide show
  1. colab_train_iris.py +14 -3
colab_train_iris.py CHANGED
@@ -72,14 +72,25 @@ import gc
72
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
73
  if device.type == "cuda":
74
  gpu_name = torch.cuda.get_device_name(0)
75
- gpu_mem = torch.cuda.get_device_properties(0).total_mem / 1e9
76
  print(f"GPU: {gpu_name} ({gpu_mem:.1f} GB)")
77
  else:
78
  print("WARNING: No GPU detected. Training will be very slow.")
79
  print("In Colab: Runtime -> Change runtime type -> T4 GPU")
80
 
81
  use_amp = device.type == "cuda"
82
- amp_dtype = torch.bfloat16 if (use_amp and torch.cuda.is_bf16_supported()) else torch.float16 if use_amp else torch.float32
 
 
 
 
 
 
 
 
 
 
 
83
  print(f"AMP dtype: {amp_dtype}")
84
 
85
  # ============================================================
@@ -224,7 +235,7 @@ print(f" Core: {counts['core']:,}")
224
  print(f" Decoder: {counts['tiny_decoder']:,}")
225
 
226
  if device.type == "cuda":
227
- print(f"VRAM used: {torch.cuda.memory_allocated()/1e9:.2f} GB / {torch.cuda.get_device_properties(0).total_mem/1e9:.1f} GB")
228
 
229
  # ============================================================
230
  # CELL 9: Train!
 
72
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
73
  if device.type == "cuda":
74
  gpu_name = torch.cuda.get_device_name(0)
75
+ gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1e9
76
  print(f"GPU: {gpu_name} ({gpu_mem:.1f} GB)")
77
  else:
78
  print("WARNING: No GPU detected. Training will be very slow.")
79
  print("In Colab: Runtime -> Change runtime type -> T4 GPU")
80
 
81
  use_amp = device.type == "cuda"
82
+ # T4 (compute capability 7.5) reports bf16 supported but cuDNN conv2d kernels
83
+ # lack bf16 engines → crashes at runtime. Force fp16 which T4 natively supports.
84
+ if use_amp:
85
+ cc = torch.cuda.get_device_capability(0)
86
+ if cc[0] < 8: # Ampere (8.0+) has native bf16; Turing (7.5) does not
87
+ amp_dtype = torch.float16
88
+ print(f"GPU compute capability {cc[0]}.{cc[1]} — using fp16 (bf16 conv kernels unavailable)")
89
+ else:
90
+ amp_dtype = torch.bfloat16
91
+ print(f"GPU compute capability {cc[0]}.{cc[1]} — using bf16")
92
+ else:
93
+ amp_dtype = torch.float32
94
  print(f"AMP dtype: {amp_dtype}")
95
 
96
  # ============================================================
 
235
  print(f" Decoder: {counts['tiny_decoder']:,}")
236
 
237
  if device.type == "cuda":
238
+ print(f"VRAM used: {torch.cuda.memory_allocated()/1e9:.2f} GB / {torch.cuda.get_device_properties(0).total_memory/1e9:.1f} GB")
239
 
240
  # ============================================================
241
  # CELL 9: Train!