dennny123 Claude Sonnet 4.5 (1M context) committed on
Commit
2cdf689
·
1 Parent(s): e90ff32

Fix ops.py patch - actually move to CPU not just convert dtype

Browse files

Previous patch did input.float() which kept tensors on GPU.
Now does input.cpu().float() to actually run on CPU.
Then moves result back: x.to(device).to(dtype)

Co-Authored-By: Claude Sonnet 4.5 (1M context) <noreply@anthropic.com>

Files changed (1) hide show
  1. app.py +3 -2
app.py CHANGED
@@ -84,8 +84,9 @@ def _patch_qwen_for_mig_gpu():
84
  f'{space} x = torch.nn.functional.linear(input, weight, bias)\n',
85
  f'{space}except RuntimeError as e:\n',
86
  f'{space} if "CUBLAS" in str(e):\n',
87
- f'{space} x = torch.nn.functional.linear(input.float(), weight.float(), bias.float() if bias is not None else None)\n',
88
- f'{space} x = x.to(input.dtype)\n',
 
89
  f'{space} else:\n',
90
  f'{space} raise\n'
91
  ]
 
84
  f'{space} x = torch.nn.functional.linear(input, weight, bias)\n',
85
  f'{space}except RuntimeError as e:\n',
86
  f'{space} if "CUBLAS" in str(e):\n',
87
+ f'{space} device = input.device\n',
88
+ f'{space} x = torch.nn.functional.linear(input.cpu().float(), weight.cpu().float(), bias.cpu().float() if bias is not None else None)\n',
89
+ f'{space} x = x.to(device).to(input.dtype)\n',
90
  f'{space} else:\n',
91
  f'{space} raise\n'
92
  ]