Fix bf16 AMP detection for T4 in train_production.py
Browse files
iris/train_production.py +7 -1
iris/train_production.py
CHANGED
|
@@ -38,7 +38,13 @@ def main():
|
|
| 38 |
torch.manual_seed(args.seed)
|
| 39 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 40 |
use_amp = device.type == "cuda"
|
| 41 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
|
| 43 |
print(f"IRIS Training - {args.config} | Device: {device}, AMP: {amp_dtype}")
|
| 44 |
model_cfg = get_model_config(args.config)
|
|
|
|
| 38 |
torch.manual_seed(args.seed)
|
| 39 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 40 |
use_amp = device.type == "cuda"
|
| 41 |
+
# T4 (compute cap 7.5) reports bf16 supported but cuDNN conv kernels crash.
|
| 42 |
+
# Force fp16 on GPUs below Ampere (compute cap < 8.0).
|
| 43 |
+
if use_amp:
|
| 44 |
+
cc = torch.cuda.get_device_capability(0)
|
| 45 |
+
amp_dtype = torch.float16 if cc[0] < 8 else torch.bfloat16
|
| 46 |
+
else:
|
| 47 |
+
amp_dtype = torch.float32
|
| 48 |
|
| 49 |
print(f"IRIS Training - {args.config} | Device: {device}, AMP: {amp_dtype}")
|
| 50 |
model_cfg = get_model_config(args.config)
|