dennny123 Claude Sonnet 4.5 (1M context) committed on
Commit
5c8c826
·
1 Parent(s): fb34b62

Fix MIG GPU CUBLAS error in Qwen text encoder frequency computation

Browse files

The previous patch targeted ops.py, but the error is actually raised at llama.py:301.
This commit patches the actual error location instead:
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)

On a CUBLAS error, the matmul falls back to a CPU computation and the result is moved back to the original device.
This should finally fix the MIG GPU incompatibility.

Co-Authored-By: Claude Sonnet 4.5 (1M context) <noreply@anthropic.com>

Files changed (1) hide show
  1. app.py +18 -20
app.py CHANGED
@@ -31,50 +31,48 @@ COMFYUI_DIR = os.path.join(ROOT_DIR, "ComfyUI")
31
  BYPASS_REPO_DIR = os.path.join(ROOT_DIR, "reference_repo")
32
 
33
def _patch_qwen_for_mig_gpu(ops_file=None):
    """Patch ComfyUI's ``comfy/ops.py`` so linear ops retry in float32 on CUBLAS errors.

    The first line that starts with
    ``return torch.nn.functional.linear(input, weight, bias)`` is replaced by a
    ``try``/``except RuntimeError`` block. When the error message contains
    ``"CUBLAS"``, the generated code recomputes the linear op with float32
    copies of the operands and casts the result back to the input dtype; any
    other ``RuntimeError`` is re-raised.

    The patch is idempotent: a file that already contains the marker comment
    ``MIG GPU CUBLAS fix`` is left untouched.

    Args:
        ops_file: Path of the file to patch. Defaults to
            ``<COMFYUI_DIR>/comfy/ops.py``.

    Returns:
        None. Progress is reported on stdout; a missing file is silently
        ignored.
    """
    if ops_file is None:
        ops_file = os.path.join(COMFYUI_DIR, "comfy/ops.py")
    if not os.path.exists(ops_file):
        return

    with open(ops_file, 'r', encoding='utf-8') as f:
        lines = f.readlines()

    # Check if patch already applied
    if any('MIG GPU CUBLAS fix' in line for line in lines):
        print("[OK] Qwen MIG GPU patch already applied")
        return

    target = 'return torch.nn.functional.linear(input, weight, bias)'
    for i, line in enumerate(lines):
        stripped = line.lstrip()
        # Require the statement to start the line so a commented-out or
        # embedded occurrence is never rewritten.
        if not stripped.startswith(target):
            continue

        # Reuse the line's own leading whitespace (tabs included) so the
        # inserted block sits at exactly the same nesting level.
        space = line[:len(line) - len(stripped)]
        lines[i:i + 1] = [
            f'{space}# MIG GPU CUBLAS fix\n',
            f'{space}try:\n',
            f'{space}    return torch.nn.functional.linear(input, weight, bias)\n',
            f'{space}except RuntimeError as e:\n',
            f'{space}    if "CUBLAS" in str(e):\n',
            f'{space}        input_f32 = input.float()\n',
            f'{space}        weight_f32 = weight.float()\n',
            f'{space}        bias_f32 = bias.float() if bias is not None else None\n',
            f'{space}        result = torch.nn.functional.linear(input_f32, weight_f32, bias_f32)\n',
            f'{space}        return result.to(input.dtype)\n',
            f'{space}    raise\n',
        ]
        with open(ops_file, 'w', encoding='utf-8') as f:
            f.writelines(lines)
        print("[OK] Applied MIG GPU CUBLAS fix to linear operations")
        return

    print("[WARN] Could not find patch location in ops.py")
78
 
79
  def setup():
80
  """Environment setup for Hugging Face Space"""
 
31
  BYPASS_REPO_DIR = os.path.join(ROOT_DIR, "reference_repo")
32
 
33
def _patch_qwen_for_mig_gpu(llama_file=None):
    """Patch ComfyUI's Qwen/Llama text encoder with a CPU fallback for one matmul.

    The first line that starts with the ``freqs = (inv_freq_expanded.float() @
    position_ids_expanded.float()).transpose(1, 2)`` assignment is replaced by
    a ``try``/``except RuntimeError`` block. The generated code still attempts
    the matmul on the tensors' current device first; only when the error
    message contains ``"CUBLAS"`` does it redo the matmul on CPU and move the
    result back to the original device. Any other ``RuntimeError`` is
    re-raised.

    The patch is idempotent: a file that already contains the marker
    ``MIG GPU: force CPU`` is left untouched.

    Args:
        llama_file: Path of the file to patch. Defaults to
            ``<COMFYUI_DIR>/comfy/text_encoders/llama.py``.

    Returns:
        None. Progress is reported on stdout; a missing file is silently
        ignored.
    """
    if llama_file is None:
        llama_file = os.path.join(COMFYUI_DIR, "comfy/text_encoders/llama.py")
    if not os.path.exists(llama_file):
        return

    with open(llama_file, 'r', encoding='utf-8') as f:
        lines = f.readlines()

    # Check if patch already applied
    if any('MIG GPU: force CPU' in line for line in lines):
        print("[OK] Qwen CPU fallback already applied")
        return

    target = 'freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)'
    for i, line in enumerate(lines):
        stripped = line.lstrip()
        # Require the assignment to start the line so a commented-out or
        # embedded occurrence is never rewritten.
        if not stripped.startswith(target):
            continue

        # Reuse the line's own leading whitespace (tabs included) so the
        # inserted block sits at exactly the same nesting level.
        space = line[:len(line) - len(stripped)]
        lines[i:i + 1] = [
            f'{space}# MIG GPU: force CPU for matmul to avoid CUBLAS errors\n',
            f'{space}try:\n',
            f'{space}    freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)\n',
            f'{space}except RuntimeError as e:\n',
            f'{space}    if "CUBLAS" in str(e):\n',
            f'{space}        device = inv_freq_expanded.device\n',
            f'{space}        freqs = (inv_freq_expanded.float().cpu() @ position_ids_expanded.float().cpu()).transpose(1, 2).to(device)\n',
            f'{space}    else:\n',
            f'{space}        raise\n',
        ]
        with open(llama_file, 'w', encoding='utf-8') as f:
            f.writelines(lines)
        print("[OK] Applied Qwen CPU fallback for MIG GPU")
        return

    print("[WARN] Could not find freqs computation in llama.py")
76
 
77
  def setup():
78
  """Environment setup for Hugging Face Space"""