Kernels
wyldecat Claude Opus 4.6 (1M context) committed on
Commit
3f2678c
·
1 Parent(s): 60615a0

style: fix yapf/isort formatting for CI --all-files check

Browse files

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

tests/test_fused_mul_grouped_poly_norm.py CHANGED
@@ -498,8 +498,14 @@ def test_fused_mul_grouped_poly_norm_hidden_clamp_backward(
498
  PADDING_SIZES = [64, 256]
499
 
500
 
501
- def _make_padded_inputs(num_valid_tokens, num_padding, hidden_dim, num_experts,
502
- dtype, device, seed=42, expert_offset=0):
 
 
 
 
 
 
503
  """Create inputs with extra padding rows (large values) beyond valid tokens."""
504
  torch.manual_seed(seed)
505
  probs = torch.ones(num_experts) / num_experts
@@ -530,7 +536,8 @@ def _make_padded_inputs(num_valid_tokens, num_padding, hidden_dim, num_experts,
530
  @pytest.mark.parametrize("num_experts", [8])
531
  @pytest.mark.parametrize("dtype", DTYPES)
532
  @pytest.mark.parametrize("device", CUDA_DEVICES)
533
- def test_padded_forward(num_tokens, num_padding, d, num_experts, dtype, device):
 
534
  """Forward with padded input: valid rows correct, padding rows zero."""
535
  torch.set_default_device(device)
536
  input_t, mul_t, weight, bias, offsets = _make_padded_inputs(
@@ -547,7 +554,9 @@ def test_padded_forward(num_tokens, num_padding, d, num_experts, dtype, device):
547
  assert out_ref.shape == (M, d)
548
 
549
  atol, rtol = (1e-4, 1e-4) if dtype == torch.float32 else (1e-2, 1e-2)
550
- assert_close(out_cuda[:num_valid], out_ref[:num_valid], atol=atol,
 
 
551
  rtol=rtol)
552
  assert out_cuda[num_valid:].abs().max() == 0, \
553
  f"Padding rows not zero: max={out_cuda[num_valid:].abs().max().item()}"
@@ -569,8 +578,9 @@ def test_padded_backward(num_tokens, num_padding, d, num_experts, dtype,
569
  num_valid = int(offsets[-1].item())
570
 
571
  # Reference backward on valid-only rows
572
- _, ig_ref, mg_ref, wg_ref, bg_ref, _ = _run_ref(
573
- input_t[:num_valid], mul_t[:num_valid], weight, bias, offsets)
 
574
 
575
  # CUDA backward on full padded input
576
  _, ig_cuda, mg_cuda, wg_cuda, bg_cuda, _ = _run_cuda(
@@ -611,13 +621,23 @@ def test_padded_forward_scored(num_tokens, num_padding, d, num_experts, dtype,
611
  M = num_tokens + num_padding
612
  scores = _make_scores(M, device)
613
 
614
- out_ref = fused_mul_grouped_poly_norm_ref(input_t, mul_t, weight, bias,
615
- offsets, scores=scores)
616
- out_cuda = fused_mul_grouped_poly_norm(input_t, mul_t, weight, bias,
617
- offsets, scores=scores)
 
 
 
 
 
 
 
 
618
 
619
  atol, rtol = (1e-4, 1e-4) if dtype == torch.float32 else (1e-2, 1e-2)
620
- assert_close(out_cuda[:num_valid], out_ref[:num_valid], atol=atol,
 
 
621
  rtol=rtol)
622
  assert out_cuda[num_valid:].abs().max() == 0, \
623
  "Padding rows not zero with scores"
@@ -642,12 +662,20 @@ def test_padded_backward_scored(num_tokens, num_padding, d, num_experts, dtype,
642
 
643
  # Reference on valid-only
644
  _, ig_ref, mg_ref, wg_ref, bg_ref, sg_ref = _run_ref(
645
- input_t[:num_valid], mul_t[:num_valid], weight, bias, offsets,
 
 
 
 
646
  scores=scores[:num_valid])
647
 
648
  # CUDA on full padded
649
- _, ig_cuda, mg_cuda, wg_cuda, bg_cuda, sg_cuda = _run_cuda(
650
- input_t, mul_t, weight, bias, offsets, scores=scores)
 
 
 
 
651
 
652
  atol, rtol = (1e-4, 1e-4) if dtype == torch.float32 else (5e-2, 5e-2)
653
  wg_atol = 5e-4 if dtype == torch.float32 else 5e-2
@@ -681,15 +709,25 @@ def test_padded_forward_scored_clamp(num_tokens, num_padding, d, num_experts,
681
  M = num_tokens + num_padding
682
  scores = _make_scores(M, device)
683
 
684
- out_ref = fused_mul_grouped_poly_norm_ref(input_t, mul_t, weight, bias,
685
- offsets, scores=scores,
 
 
 
 
686
  hidden_clamp=hidden_clamp)
687
- out_cuda = fused_mul_grouped_poly_norm(input_t, mul_t, weight, bias,
688
- offsets, scores=scores,
 
 
 
 
689
  hidden_clamp=hidden_clamp)
690
 
691
  atol, rtol = (1e-4, 1e-4) if dtype == torch.float32 else (1e-2, 1e-2)
692
- assert_close(out_cuda[:num_valid], out_ref[:num_valid], atol=atol,
 
 
693
  rtol=rtol)
694
  assert out_cuda[num_valid:].abs().max() == 0, \
695
  "Padding rows not zero with scores+clamp"
@@ -715,12 +753,22 @@ def test_padded_backward_scored_clamp(num_tokens, num_padding, d, num_experts,
715
 
716
  # Reference on valid-only
717
  _, ig_ref, mg_ref, wg_ref, bg_ref, sg_ref = _run_ref(
718
- input_t[:num_valid], mul_t[:num_valid], weight, bias, offsets,
719
- scores=scores[:num_valid], hidden_clamp=hidden_clamp)
 
 
 
 
 
720
 
721
  # CUDA on full padded
722
  _, ig_cuda, mg_cuda, wg_cuda, bg_cuda, sg_cuda = _run_cuda(
723
- input_t, mul_t, weight, bias, offsets, scores=scores,
 
 
 
 
 
724
  hidden_clamp=hidden_clamp)
725
 
726
  atol, rtol = (1e-4, 1e-4) if dtype == torch.float32 else (5e-2, 5e-2)
 
498
  PADDING_SIZES = [64, 256]
499
 
500
 
501
+ def _make_padded_inputs(num_valid_tokens,
502
+ num_padding,
503
+ hidden_dim,
504
+ num_experts,
505
+ dtype,
506
+ device,
507
+ seed=42,
508
+ expert_offset=0):
509
  """Create inputs with extra padding rows (large values) beyond valid tokens."""
510
  torch.manual_seed(seed)
511
  probs = torch.ones(num_experts) / num_experts
 
536
  @pytest.mark.parametrize("num_experts", [8])
537
  @pytest.mark.parametrize("dtype", DTYPES)
538
  @pytest.mark.parametrize("device", CUDA_DEVICES)
539
+ def test_padded_forward(num_tokens, num_padding, d, num_experts, dtype,
540
+ device):
541
  """Forward with padded input: valid rows correct, padding rows zero."""
542
  torch.set_default_device(device)
543
  input_t, mul_t, weight, bias, offsets = _make_padded_inputs(
 
554
  assert out_ref.shape == (M, d)
555
 
556
  atol, rtol = (1e-4, 1e-4) if dtype == torch.float32 else (1e-2, 1e-2)
557
+ assert_close(out_cuda[:num_valid],
558
+ out_ref[:num_valid],
559
+ atol=atol,
560
  rtol=rtol)
561
  assert out_cuda[num_valid:].abs().max() == 0, \
562
  f"Padding rows not zero: max={out_cuda[num_valid:].abs().max().item()}"
 
578
  num_valid = int(offsets[-1].item())
579
 
580
  # Reference backward on valid-only rows
581
+ _, ig_ref, mg_ref, wg_ref, bg_ref, _ = _run_ref(input_t[:num_valid],
582
+ mul_t[:num_valid], weight,
583
+ bias, offsets)
584
 
585
  # CUDA backward on full padded input
586
  _, ig_cuda, mg_cuda, wg_cuda, bg_cuda, _ = _run_cuda(
 
621
  M = num_tokens + num_padding
622
  scores = _make_scores(M, device)
623
 
624
+ out_ref = fused_mul_grouped_poly_norm_ref(input_t,
625
+ mul_t,
626
+ weight,
627
+ bias,
628
+ offsets,
629
+ scores=scores)
630
+ out_cuda = fused_mul_grouped_poly_norm(input_t,
631
+ mul_t,
632
+ weight,
633
+ bias,
634
+ offsets,
635
+ scores=scores)
636
 
637
  atol, rtol = (1e-4, 1e-4) if dtype == torch.float32 else (1e-2, 1e-2)
638
+ assert_close(out_cuda[:num_valid],
639
+ out_ref[:num_valid],
640
+ atol=atol,
641
  rtol=rtol)
642
  assert out_cuda[num_valid:].abs().max() == 0, \
643
  "Padding rows not zero with scores"
 
662
 
663
  # Reference on valid-only
664
  _, ig_ref, mg_ref, wg_ref, bg_ref, sg_ref = _run_ref(
665
+ input_t[:num_valid],
666
+ mul_t[:num_valid],
667
+ weight,
668
+ bias,
669
+ offsets,
670
  scores=scores[:num_valid])
671
 
672
  # CUDA on full padded
673
+ _, ig_cuda, mg_cuda, wg_cuda, bg_cuda, sg_cuda = _run_cuda(input_t,
674
+ mul_t,
675
+ weight,
676
+ bias,
677
+ offsets,
678
+ scores=scores)
679
 
680
  atol, rtol = (1e-4, 1e-4) if dtype == torch.float32 else (5e-2, 5e-2)
681
  wg_atol = 5e-4 if dtype == torch.float32 else 5e-2
 
709
  M = num_tokens + num_padding
710
  scores = _make_scores(M, device)
711
 
712
+ out_ref = fused_mul_grouped_poly_norm_ref(input_t,
713
+ mul_t,
714
+ weight,
715
+ bias,
716
+ offsets,
717
+ scores=scores,
718
  hidden_clamp=hidden_clamp)
719
+ out_cuda = fused_mul_grouped_poly_norm(input_t,
720
+ mul_t,
721
+ weight,
722
+ bias,
723
+ offsets,
724
+ scores=scores,
725
  hidden_clamp=hidden_clamp)
726
 
727
  atol, rtol = (1e-4, 1e-4) if dtype == torch.float32 else (1e-2, 1e-2)
728
+ assert_close(out_cuda[:num_valid],
729
+ out_ref[:num_valid],
730
+ atol=atol,
731
  rtol=rtol)
732
  assert out_cuda[num_valid:].abs().max() == 0, \
733
  "Padding rows not zero with scores+clamp"
 
753
 
754
  # Reference on valid-only
755
  _, ig_ref, mg_ref, wg_ref, bg_ref, sg_ref = _run_ref(
756
+ input_t[:num_valid],
757
+ mul_t[:num_valid],
758
+ weight,
759
+ bias,
760
+ offsets,
761
+ scores=scores[:num_valid],
762
+ hidden_clamp=hidden_clamp)
763
 
764
  # CUDA on full padded
765
  _, ig_cuda, mg_cuda, wg_cuda, bg_cuda, sg_cuda = _run_cuda(
766
+ input_t,
767
+ mul_t,
768
+ weight,
769
+ bias,
770
+ offsets,
771
+ scores=scores,
772
  hidden_clamp=hidden_clamp)
773
 
774
  atol, rtol = (1e-4, 1e-4) if dtype == torch.float32 else (5e-2, 5e-2)
torch-ext/activation/_ops.py CHANGED
@@ -1,14 +1,14 @@
1
  """Op loader — works with both kernel-builder (.abi3.so) and local setup.py builds."""
2
 
3
  import importlib
 
4
  import torch
5
 
6
  # Try kernel-builder build first (namespace like _activation_HASH)
7
  # Fall back to local setup.py build (_activation)
8
  _lib = None
9
- for _name in sorted(
10
- [m for m in dir() if m.startswith("_activation")], reverse=True
11
- ):
12
  _lib = importlib.import_module(f".{_name}", __package__)
13
  break
14
 
@@ -28,8 +28,7 @@ if _lib is None:
28
  ops = getattr(torch.ops, _mod_name)
29
  else:
30
  raise ImportError(
31
- "No activation extension found. Build with: pip install -e ."
32
- )
33
  else:
34
  ops = getattr(torch.ops, _lib.__name__.split(".")[-1])
35
 
 
1
  """Op loader — works with both kernel-builder (.abi3.so) and local setup.py builds."""
2
 
3
  import importlib
4
+
5
  import torch
6
 
7
  # Try kernel-builder build first (namespace like _activation_HASH)
8
  # Fall back to local setup.py build (_activation)
9
  _lib = None
10
+ for _name in sorted([m for m in dir() if m.startswith("_activation")],
11
+ reverse=True):
 
12
  _lib = importlib.import_module(f".{_name}", __package__)
13
  break
14
 
 
28
  ops = getattr(torch.ops, _mod_name)
29
  else:
30
  raise ImportError(
31
+ "No activation extension found. Build with: pip install -e .")
 
32
  else:
33
  ops = getattr(torch.ops, _lib.__name__.split(".")[-1])
34
 
torch-ext/activation/grouped_poly_norm.py CHANGED
@@ -146,8 +146,10 @@ def fused_mul_grouped_poly_norm_ref(
146
  if hidden_clamp is not None:
147
  result = result.clamp(-hidden_clamp, hidden_clamp)
148
  if input.shape[0] > num_valid:
149
- padding = torch.zeros(input.shape[0] - num_valid, input.shape[-1],
150
- dtype=orig_dtype, device=input.device)
 
 
151
  return torch.cat([result.to(orig_dtype), padding], dim=0)
152
  return result.to(orig_dtype)
153
 
 
146
  if hidden_clamp is not None:
147
  result = result.clamp(-hidden_clamp, hidden_clamp)
148
  if input.shape[0] > num_valid:
149
+ padding = torch.zeros(input.shape[0] - num_valid,
150
+ input.shape[-1],
151
+ dtype=orig_dtype,
152
+ device=input.device)
153
  return torch.cat([result.to(orig_dtype), padding], dim=0)
154
  return result.to(orig_dtype)
155