File size: 590 Bytes
f9791fd | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 | import platform
import torch
import first_kernel
def _select_device() -> torch.device:
    """Return the device the kernel test should run on.

    Preference order: Apple MPS (only on macOS *and* when the MPS backend
    reports itself available), Intel XPU, CUDA, then CPU as the fallback.
    """
    # Checking platform alone is not enough: a Mac without a usable MPS
    # backend would otherwise fail inside torch.randn instead of falling
    # back to CPU like every other platform does.
    if platform.system() == "Darwin" and torch.backends.mps.is_available():
        return torch.device("mps")
    # torch.xpu is absent in older PyTorch releases, hence the hasattr guard.
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return torch.device("xpu")
    # Require a CUDA build (torch.version.cuda set), not just a visible
    # torch.cuda device.
    # NOTE(review): ROCm builds also answer torch.cuda.is_available() but leave
    # torch.version.cuda unset, so they fall through to CPU — confirm intended.
    if torch.version.cuda is not None and torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


def test_first_kernel():
    """Check that ``first_kernel.first_kernel(x)`` equals ``x + 1.0``.

    The input is a random 1024x1024 float32 tensor placed on the device
    chosen by ``_select_device()``. ``torch.testing.assert_close`` compares
    values, dtype and device of the kernel output against the reference.
    """
    device = _select_device()
    x = torch.randn(1024, 1024, dtype=torch.float32, device=device)
    # Build the reference before invoking the kernel so that a kernel which
    # (incorrectly) mutates its input in place cannot also change `expected`.
    expected = x + 1.0
    result = first_kernel.first_kernel(x)
    torch.testing.assert_close(result, expected)
|