dennny123 Claude Sonnet 4.5 (1M context) commited on
Commit
f7908e9
·
1 Parent(s): c223ae8

Fix MIG GPU CUBLAS error - patch linear operations not freq computation

Browse files

Previous patch was in wrong location. Error happens in:
comfy/ops.py line 157: torch.nn.functional.linear()

Not in freqs computation. This patch:
- Catches CUBLAS_STATUS_INVALID_VALUE in linear operations
- Falls back to float32 computation on error
- Converts result back to original dtype
- Should fix the Qwen text encoder CUDA crashes on MIG GPUs

Co-Authored-By: Claude Sonnet 4.5 (1M context) <noreply@anthropic.com>

Files changed (1) hide show
  1. app.py +63 -17
app.py CHANGED
@@ -32,38 +32,84 @@ BYPASS_REPO_DIR = os.path.join(ROOT_DIR, "reference_repo")
32
 
33
  def _patch_qwen_for_mig_gpu():
34
  """Patch Qwen/Llama text encoder for MIG GPU compatibility"""
35
- llama_file = os.path.join(COMFYUI_DIR, "comfy/text_encoders/llama.py")
36
- if not os.path.exists(llama_file):
37
  return
38
 
39
- with open(llama_file, 'r') as f:
40
  content = f.read()
41
 
42
  # Check if patch already applied
43
- if 'MIG GPU compatibility' in content:
44
  print("[OK] Qwen MIG GPU patch already applied")
45
  return
46
 
47
- # Patch the problematic matmul operation to use CPU fallback on error
48
- original_code = ''' freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)'''
49
-
50
- patched_code = ''' # MIG GPU compatibility: fallback to CPU on CUBLAS errors
51
- try:
52
- freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
53
- except RuntimeError as e:
54
- if 'CUBLAS' in str(e):
55
- # Fallback to CPU for this operation
56
- freqs = (inv_freq_expanded.float().cpu() @ position_ids_expanded.float().cpu()).transpose(1, 2).to(inv_freq_expanded.device)
 
 
 
 
 
 
 
57
  else:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  raise'''
59
 
60
  if original_code in content:
61
  patched_content = content.replace(original_code, patched_code)
62
- with open(llama_file, 'w') as f:
63
  f.write(patched_content)
64
- print("[OK] Applied MIG GPU compatibility patch to Qwen text encoder")
65
  else:
66
- print("[WARN] Qwen patch pattern not found (file may have changed)")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
68
  def setup():
69
  """Environment setup for Hugging Face Space"""
 
32
 
33
  def _patch_qwen_for_mig_gpu():
34
  """Patch Qwen/Llama text encoder for MIG GPU compatibility"""
35
+ ops_file = os.path.join(COMFYUI_DIR, "comfy/ops.py")
36
+ if not os.path.exists(ops_file):
37
  return
38
 
39
+ with open(ops_file, 'r') as f:
40
  content = f.read()
41
 
42
  # Check if patch already applied
43
+ if 'MIG GPU CUBLAS fix' in content:
44
  print("[OK] Qwen MIG GPU patch already applied")
45
  return
46
 
47
+ # Patch the linear operation to use float32 on MIG GPUs
48
+ original_code = ''' def forward_comfy_cast_weights(self, input, weight, bias=None, weight_dtype=None, bias_dtype=None):
49
+ if weight_dtype is not None:
50
+ weight = comfy.model_management.cast_to_device(weight, input.device, weight_dtype)
51
+ else:
52
+ weight = comfy.model_management.cast_to_device(weight, input.device, torch.float32)
53
+
54
+ if bias is not None:
55
+ if bias_dtype is not None:
56
+ bias = comfy.model_management.cast_to_device(bias, input.device, bias_dtype)
57
+ else:
58
+ bias = comfy.model_management.cast_to_device(bias, input.device, torch.float32)
59
+ return torch.nn.functional.linear(input, weight, bias)'''
60
+
61
+ patched_code = ''' def forward_comfy_cast_weights(self, input, weight, bias=None, weight_dtype=None, bias_dtype=None):
62
+ if weight_dtype is not None:
63
+ weight = comfy.model_management.cast_to_device(weight, input.device, weight_dtype)
64
  else:
65
+ weight = comfy.model_management.cast_to_device(weight, input.device, torch.float32)
66
+
67
+ if bias is not None:
68
+ if bias_dtype is not None:
69
+ bias = comfy.model_management.cast_to_device(bias, input.device, bias_dtype)
70
+ else:
71
+ bias = comfy.model_management.cast_to_device(bias, input.device, torch.float32)
72
+
73
+ # MIG GPU CUBLAS fix: Force float32 for linear ops to avoid CUBLAS errors
74
+ try:
75
+ return torch.nn.functional.linear(input, weight, bias)
76
+ except RuntimeError as e:
77
+ if 'CUBLAS' in str(e):
78
+ # Force everything to float32 and retry
79
+ input_f32 = input.float()
80
+ weight_f32 = weight.float()
81
+ bias_f32 = bias.float() if bias is not None else None
82
+ result = torch.nn.functional.linear(input_f32, weight_f32, bias_f32)
83
+ return result.to(input.dtype)
84
  raise'''
85
 
86
  if original_code in content:
87
  patched_content = content.replace(original_code, patched_code)
88
+ with open(ops_file, 'w') as f:
89
  f.write(patched_content)
90
+ print("[OK] Applied MIG GPU CUBLAS fix to linear operations")
91
  else:
92
+ # Try a simpler pattern match
93
+ if 'return torch.nn.functional.linear(input, weight, bias)' in content:
94
+ patched_content = content.replace(
95
+ ' return torch.nn.functional.linear(input, weight, bias)',
96
+ ''' # MIG GPU CUBLAS fix
97
+ try:
98
+ return torch.nn.functional.linear(input, weight, bias)
99
+ except RuntimeError as e:
100
+ if 'CUBLAS' in str(e):
101
+ input_f32 = input.float()
102
+ weight_f32 = weight.float()
103
+ bias_f32 = bias.float() if bias is not None else None
104
+ result = torch.nn.functional.linear(input_f32, weight_f32, bias_f32)
105
+ return result.to(input.dtype)
106
+ raise'''
107
+ )
108
+ with open(ops_file, 'w') as f:
109
+ f.write(patched_content)
110
+ print("[OK] Applied MIG GPU CUBLAS fix (fallback pattern)")
111
+ else:
112
+ print("[WARN] Could not find patch location in ops.py")
113
 
114
  def setup():
115
  """Environment setup for Hugging Face Space"""