style: fix yapf/isort formatting for CI --all-files check
Browse files
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
tests/test_fused_mul_grouped_poly_norm.py
CHANGED
|
@@ -498,8 +498,14 @@ def test_fused_mul_grouped_poly_norm_hidden_clamp_backward(
|
|
| 498 |
PADDING_SIZES = [64, 256]
|
| 499 |
|
| 500 |
|
| 501 |
-
def _make_padded_inputs(num_valid_tokens,
|
| 502 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 503 |
"""Create inputs with extra padding rows (large values) beyond valid tokens."""
|
| 504 |
torch.manual_seed(seed)
|
| 505 |
probs = torch.ones(num_experts) / num_experts
|
|
@@ -530,7 +536,8 @@ def _make_padded_inputs(num_valid_tokens, num_padding, hidden_dim, num_experts,
|
|
| 530 |
@pytest.mark.parametrize("num_experts", [8])
|
| 531 |
@pytest.mark.parametrize("dtype", DTYPES)
|
| 532 |
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
| 533 |
-
def test_padded_forward(num_tokens, num_padding, d, num_experts, dtype,
|
|
|
|
| 534 |
"""Forward with padded input: valid rows correct, padding rows zero."""
|
| 535 |
torch.set_default_device(device)
|
| 536 |
input_t, mul_t, weight, bias, offsets = _make_padded_inputs(
|
|
@@ -547,7 +554,9 @@ def test_padded_forward(num_tokens, num_padding, d, num_experts, dtype, device):
|
|
| 547 |
assert out_ref.shape == (M, d)
|
| 548 |
|
| 549 |
atol, rtol = (1e-4, 1e-4) if dtype == torch.float32 else (1e-2, 1e-2)
|
| 550 |
-
assert_close(out_cuda[:num_valid],
|
|
|
|
|
|
|
| 551 |
rtol=rtol)
|
| 552 |
assert out_cuda[num_valid:].abs().max() == 0, \
|
| 553 |
f"Padding rows not zero: max={out_cuda[num_valid:].abs().max().item()}"
|
|
@@ -569,8 +578,9 @@ def test_padded_backward(num_tokens, num_padding, d, num_experts, dtype,
|
|
| 569 |
num_valid = int(offsets[-1].item())
|
| 570 |
|
| 571 |
# Reference backward on valid-only rows
|
| 572 |
-
_, ig_ref, mg_ref, wg_ref, bg_ref, _ = _run_ref(
|
| 573 |
-
|
|
|
|
| 574 |
|
| 575 |
# CUDA backward on full padded input
|
| 576 |
_, ig_cuda, mg_cuda, wg_cuda, bg_cuda, _ = _run_cuda(
|
|
@@ -611,13 +621,23 @@ def test_padded_forward_scored(num_tokens, num_padding, d, num_experts, dtype,
|
|
| 611 |
M = num_tokens + num_padding
|
| 612 |
scores = _make_scores(M, device)
|
| 613 |
|
| 614 |
-
out_ref = fused_mul_grouped_poly_norm_ref(input_t,
|
| 615 |
-
|
| 616 |
-
|
| 617 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 618 |
|
| 619 |
atol, rtol = (1e-4, 1e-4) if dtype == torch.float32 else (1e-2, 1e-2)
|
| 620 |
-
assert_close(out_cuda[:num_valid],
|
|
|
|
|
|
|
| 621 |
rtol=rtol)
|
| 622 |
assert out_cuda[num_valid:].abs().max() == 0, \
|
| 623 |
"Padding rows not zero with scores"
|
|
@@ -642,12 +662,20 @@ def test_padded_backward_scored(num_tokens, num_padding, d, num_experts, dtype,
|
|
| 642 |
|
| 643 |
# Reference on valid-only
|
| 644 |
_, ig_ref, mg_ref, wg_ref, bg_ref, sg_ref = _run_ref(
|
| 645 |
-
input_t[:num_valid],
|
|
|
|
|
|
|
|
|
|
|
|
|
| 646 |
scores=scores[:num_valid])
|
| 647 |
|
| 648 |
# CUDA on full padded
|
| 649 |
-
_, ig_cuda, mg_cuda, wg_cuda, bg_cuda, sg_cuda = _run_cuda(
|
| 650 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 651 |
|
| 652 |
atol, rtol = (1e-4, 1e-4) if dtype == torch.float32 else (5e-2, 5e-2)
|
| 653 |
wg_atol = 5e-4 if dtype == torch.float32 else 5e-2
|
|
@@ -681,15 +709,25 @@ def test_padded_forward_scored_clamp(num_tokens, num_padding, d, num_experts,
|
|
| 681 |
M = num_tokens + num_padding
|
| 682 |
scores = _make_scores(M, device)
|
| 683 |
|
| 684 |
-
out_ref = fused_mul_grouped_poly_norm_ref(input_t,
|
| 685 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 686 |
hidden_clamp=hidden_clamp)
|
| 687 |
-
out_cuda = fused_mul_grouped_poly_norm(input_t,
|
| 688 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 689 |
hidden_clamp=hidden_clamp)
|
| 690 |
|
| 691 |
atol, rtol = (1e-4, 1e-4) if dtype == torch.float32 else (1e-2, 1e-2)
|
| 692 |
-
assert_close(out_cuda[:num_valid],
|
|
|
|
|
|
|
| 693 |
rtol=rtol)
|
| 694 |
assert out_cuda[num_valid:].abs().max() == 0, \
|
| 695 |
"Padding rows not zero with scores+clamp"
|
|
@@ -715,12 +753,22 @@ def test_padded_backward_scored_clamp(num_tokens, num_padding, d, num_experts,
|
|
| 715 |
|
| 716 |
# Reference on valid-only
|
| 717 |
_, ig_ref, mg_ref, wg_ref, bg_ref, sg_ref = _run_ref(
|
| 718 |
-
input_t[:num_valid],
|
| 719 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 720 |
|
| 721 |
# CUDA on full padded
|
| 722 |
_, ig_cuda, mg_cuda, wg_cuda, bg_cuda, sg_cuda = _run_cuda(
|
| 723 |
-
input_t,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 724 |
hidden_clamp=hidden_clamp)
|
| 725 |
|
| 726 |
atol, rtol = (1e-4, 1e-4) if dtype == torch.float32 else (5e-2, 5e-2)
|
|
|
|
| 498 |
PADDING_SIZES = [64, 256]
|
| 499 |
|
| 500 |
|
| 501 |
+
def _make_padded_inputs(num_valid_tokens,
|
| 502 |
+
num_padding,
|
| 503 |
+
hidden_dim,
|
| 504 |
+
num_experts,
|
| 505 |
+
dtype,
|
| 506 |
+
device,
|
| 507 |
+
seed=42,
|
| 508 |
+
expert_offset=0):
|
| 509 |
"""Create inputs with extra padding rows (large values) beyond valid tokens."""
|
| 510 |
torch.manual_seed(seed)
|
| 511 |
probs = torch.ones(num_experts) / num_experts
|
|
|
|
| 536 |
@pytest.mark.parametrize("num_experts", [8])
|
| 537 |
@pytest.mark.parametrize("dtype", DTYPES)
|
| 538 |
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
| 539 |
+
def test_padded_forward(num_tokens, num_padding, d, num_experts, dtype,
|
| 540 |
+
device):
|
| 541 |
"""Forward with padded input: valid rows correct, padding rows zero."""
|
| 542 |
torch.set_default_device(device)
|
| 543 |
input_t, mul_t, weight, bias, offsets = _make_padded_inputs(
|
|
|
|
| 554 |
assert out_ref.shape == (M, d)
|
| 555 |
|
| 556 |
atol, rtol = (1e-4, 1e-4) if dtype == torch.float32 else (1e-2, 1e-2)
|
| 557 |
+
assert_close(out_cuda[:num_valid],
|
| 558 |
+
out_ref[:num_valid],
|
| 559 |
+
atol=atol,
|
| 560 |
rtol=rtol)
|
| 561 |
assert out_cuda[num_valid:].abs().max() == 0, \
|
| 562 |
f"Padding rows not zero: max={out_cuda[num_valid:].abs().max().item()}"
|
|
|
|
| 578 |
num_valid = int(offsets[-1].item())
|
| 579 |
|
| 580 |
# Reference backward on valid-only rows
|
| 581 |
+
_, ig_ref, mg_ref, wg_ref, bg_ref, _ = _run_ref(input_t[:num_valid],
|
| 582 |
+
mul_t[:num_valid], weight,
|
| 583 |
+
bias, offsets)
|
| 584 |
|
| 585 |
# CUDA backward on full padded input
|
| 586 |
_, ig_cuda, mg_cuda, wg_cuda, bg_cuda, _ = _run_cuda(
|
|
|
|
| 621 |
M = num_tokens + num_padding
|
| 622 |
scores = _make_scores(M, device)
|
| 623 |
|
| 624 |
+
out_ref = fused_mul_grouped_poly_norm_ref(input_t,
|
| 625 |
+
mul_t,
|
| 626 |
+
weight,
|
| 627 |
+
bias,
|
| 628 |
+
offsets,
|
| 629 |
+
scores=scores)
|
| 630 |
+
out_cuda = fused_mul_grouped_poly_norm(input_t,
|
| 631 |
+
mul_t,
|
| 632 |
+
weight,
|
| 633 |
+
bias,
|
| 634 |
+
offsets,
|
| 635 |
+
scores=scores)
|
| 636 |
|
| 637 |
atol, rtol = (1e-4, 1e-4) if dtype == torch.float32 else (1e-2, 1e-2)
|
| 638 |
+
assert_close(out_cuda[:num_valid],
|
| 639 |
+
out_ref[:num_valid],
|
| 640 |
+
atol=atol,
|
| 641 |
rtol=rtol)
|
| 642 |
assert out_cuda[num_valid:].abs().max() == 0, \
|
| 643 |
"Padding rows not zero with scores"
|
|
|
|
| 662 |
|
| 663 |
# Reference on valid-only
|
| 664 |
_, ig_ref, mg_ref, wg_ref, bg_ref, sg_ref = _run_ref(
|
| 665 |
+
input_t[:num_valid],
|
| 666 |
+
mul_t[:num_valid],
|
| 667 |
+
weight,
|
| 668 |
+
bias,
|
| 669 |
+
offsets,
|
| 670 |
scores=scores[:num_valid])
|
| 671 |
|
| 672 |
# CUDA on full padded
|
| 673 |
+
_, ig_cuda, mg_cuda, wg_cuda, bg_cuda, sg_cuda = _run_cuda(input_t,
|
| 674 |
+
mul_t,
|
| 675 |
+
weight,
|
| 676 |
+
bias,
|
| 677 |
+
offsets,
|
| 678 |
+
scores=scores)
|
| 679 |
|
| 680 |
atol, rtol = (1e-4, 1e-4) if dtype == torch.float32 else (5e-2, 5e-2)
|
| 681 |
wg_atol = 5e-4 if dtype == torch.float32 else 5e-2
|
|
|
|
| 709 |
M = num_tokens + num_padding
|
| 710 |
scores = _make_scores(M, device)
|
| 711 |
|
| 712 |
+
out_ref = fused_mul_grouped_poly_norm_ref(input_t,
|
| 713 |
+
mul_t,
|
| 714 |
+
weight,
|
| 715 |
+
bias,
|
| 716 |
+
offsets,
|
| 717 |
+
scores=scores,
|
| 718 |
hidden_clamp=hidden_clamp)
|
| 719 |
+
out_cuda = fused_mul_grouped_poly_norm(input_t,
|
| 720 |
+
mul_t,
|
| 721 |
+
weight,
|
| 722 |
+
bias,
|
| 723 |
+
offsets,
|
| 724 |
+
scores=scores,
|
| 725 |
hidden_clamp=hidden_clamp)
|
| 726 |
|
| 727 |
atol, rtol = (1e-4, 1e-4) if dtype == torch.float32 else (1e-2, 1e-2)
|
| 728 |
+
assert_close(out_cuda[:num_valid],
|
| 729 |
+
out_ref[:num_valid],
|
| 730 |
+
atol=atol,
|
| 731 |
rtol=rtol)
|
| 732 |
assert out_cuda[num_valid:].abs().max() == 0, \
|
| 733 |
"Padding rows not zero with scores+clamp"
|
|
|
|
| 753 |
|
| 754 |
# Reference on valid-only
|
| 755 |
_, ig_ref, mg_ref, wg_ref, bg_ref, sg_ref = _run_ref(
|
| 756 |
+
input_t[:num_valid],
|
| 757 |
+
mul_t[:num_valid],
|
| 758 |
+
weight,
|
| 759 |
+
bias,
|
| 760 |
+
offsets,
|
| 761 |
+
scores=scores[:num_valid],
|
| 762 |
+
hidden_clamp=hidden_clamp)
|
| 763 |
|
| 764 |
# CUDA on full padded
|
| 765 |
_, ig_cuda, mg_cuda, wg_cuda, bg_cuda, sg_cuda = _run_cuda(
|
| 766 |
+
input_t,
|
| 767 |
+
mul_t,
|
| 768 |
+
weight,
|
| 769 |
+
bias,
|
| 770 |
+
offsets,
|
| 771 |
+
scores=scores,
|
| 772 |
hidden_clamp=hidden_clamp)
|
| 773 |
|
| 774 |
atol, rtol = (1e-4, 1e-4) if dtype == torch.float32 else (5e-2, 5e-2)
|
torch-ext/activation/_ops.py
CHANGED
|
@@ -1,14 +1,14 @@
|
|
| 1 |
"""Op loader — works with both kernel-builder (.abi3.so) and local setup.py builds."""
|
| 2 |
|
| 3 |
import importlib
|
|
|
|
| 4 |
import torch
|
| 5 |
|
| 6 |
# Try kernel-builder build first (namespace like _activation_HASH)
|
| 7 |
# Fall back to local setup.py build (_activation)
|
| 8 |
_lib = None
|
| 9 |
-
for _name in sorted(
|
| 10 |
-
|
| 11 |
-
):
|
| 12 |
_lib = importlib.import_module(f".{_name}", __package__)
|
| 13 |
break
|
| 14 |
|
|
@@ -28,8 +28,7 @@ if _lib is None:
|
|
| 28 |
ops = getattr(torch.ops, _mod_name)
|
| 29 |
else:
|
| 30 |
raise ImportError(
|
| 31 |
-
"No activation extension found. Build with: pip install -e ."
|
| 32 |
-
)
|
| 33 |
else:
|
| 34 |
ops = getattr(torch.ops, _lib.__name__.split(".")[-1])
|
| 35 |
|
|
|
|
| 1 |
"""Op loader — works with both kernel-builder (.abi3.so) and local setup.py builds."""
|
| 2 |
|
| 3 |
import importlib
|
| 4 |
+
|
| 5 |
import torch
|
| 6 |
|
| 7 |
# Try kernel-builder build first (namespace like _activation_HASH)
|
| 8 |
# Fall back to local setup.py build (_activation)
|
| 9 |
_lib = None
|
| 10 |
+
for _name in sorted([m for m in dir() if m.startswith("_activation")],
|
| 11 |
+
reverse=True):
|
|
|
|
| 12 |
_lib = importlib.import_module(f".{_name}", __package__)
|
| 13 |
break
|
| 14 |
|
|
|
|
| 28 |
ops = getattr(torch.ops, _mod_name)
|
| 29 |
else:
|
| 30 |
raise ImportError(
|
| 31 |
+
"No activation extension found. Build with: pip install -e .")
|
|
|
|
| 32 |
else:
|
| 33 |
ops = getattr(torch.ops, _lib.__name__.split(".")[-1])
|
| 34 |
|
torch-ext/activation/grouped_poly_norm.py
CHANGED
|
@@ -146,8 +146,10 @@ def fused_mul_grouped_poly_norm_ref(
|
|
| 146 |
if hidden_clamp is not None:
|
| 147 |
result = result.clamp(-hidden_clamp, hidden_clamp)
|
| 148 |
if input.shape[0] > num_valid:
|
| 149 |
-
padding = torch.zeros(input.shape[0] - num_valid,
|
| 150 |
-
|
|
|
|
|
|
|
| 151 |
return torch.cat([result.to(orig_dtype), padding], dim=0)
|
| 152 |
return result.to(orig_dtype)
|
| 153 |
|
|
|
|
| 146 |
if hidden_clamp is not None:
|
| 147 |
result = result.clamp(-hidden_clamp, hidden_clamp)
|
| 148 |
if input.shape[0] > num_valid:
|
| 149 |
+
padding = torch.zeros(input.shape[0] - num_valid,
|
| 150 |
+
input.shape[-1],
|
| 151 |
+
dtype=orig_dtype,
|
| 152 |
+
device=input.device)
|
| 153 |
return torch.cat([result.to(orig_dtype), padding], dim=0)
|
| 154 |
return result.to(orig_dtype)
|
| 155 |
|