asdf98 committed on
Commit
4aebda2
·
verified ·
1 Parent(s): e90110a

Fix bf16 AMP detection for T4 in train_production.py

Browse files
Files changed (1) hide show
  1. iris/train_production.py +7 -1
iris/train_production.py CHANGED
@@ -38,7 +38,13 @@ def main():
38
  torch.manual_seed(args.seed)
39
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
40
  use_amp = device.type == "cuda"
41
- amp_dtype = torch.bfloat16 if (use_amp and torch.cuda.is_bf16_supported()) else torch.float16 if use_amp else torch.float32
 
 
 
 
 
 
42
 
43
  print(f"IRIS Training - {args.config} | Device: {device}, AMP: {amp_dtype}")
44
  model_cfg = get_model_config(args.config)
 
38
  torch.manual_seed(args.seed)
39
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
40
  use_amp = device.type == "cuda"
41
+ # T4 (compute cap 7.5) reports bf16 supported but cuDNN conv kernels crash.
42
+ # Force fp16 on GPUs below Ampere (compute cap < 8.0).
43
+ if use_amp:
44
+ cc = torch.cuda.get_device_capability(0)
45
+ amp_dtype = torch.float16 if cc[0] < 8 else torch.bfloat16
46
+ else:
47
+ amp_dtype = torch.float32
48
 
49
  print(f"IRIS Training - {args.config} | Device: {device}, AMP: {amp_dtype}")
50
  model_cfg = get_model_config(args.config)