Kernels
wyldecat and Claude Opus 4.6 (1M context) committed
Commit 7d51e61 · 1 Parent(s): 972d63b

feat: replace triton do_bench with torch.profiler for kernel timing


Switch from triton.testing.do_bench to torch.profiler-based CUDA kernel
time measurement. do_bench includes PyTorch CPU overhead in timing,
leading to inaccurate results. torch.profiler sums only actual CUDA
kernel durations for precise measurement.

- Add profile_bench() with per-kernel breakdown and bandwidth output
- Add GB/s and SpeedUp(ratio) columns to CSV reports
- Fix grouped_mul_poly import path (activation.grouped_poly_norm)
- Add docs/benchmark.md for benchmark system documentation
- Update CLAUDE.md build system description

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

CLAUDE.md ADDED
@@ -0,0 +1,116 @@
+ # CLAUDE.md
+
+ This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
+
+ ## Project Overview
+
+ Custom CUDA/ROCm normalization kernels for LLM training and inference, published as `motif-technologies/activation` on HuggingFace. Implements PolyNorm, RMSNorm, FusedAddRMSNorm, and FusedMulPolyNorm with full autograd support, fake tensor registration, and DTensor sharding strategies for sequence parallelism.
+
+ ## Build System
+
+ ### Local development build (primary)
+
+ ```bash
+ pip install -e .                        # editable install (build + install)
+ python setup.py build_ext --inplace     # build only
+ ```
+
+ `setup.py` does two things:
+ 1. Compiles the `_activation` C extension from CUDA sources (`activation/*.cu`) + C++ binding (`torch-ext/torch_binding.cpp`)
+ 2. Installs the `activation` Python package from `torch-ext/activation/` (autograd functions, layers, etc.)
+
+ After `pip install -e .`, all imports work directly (`import activation`, `from activation.grouped_poly_norm import ...`). No PYTHONPATH manipulation needed.
+
+ NVCC flags: `-O3 --use_fast_math -std=c++17`, targets sm_80 (A100), sm_89 (L40/4090), sm_90 (H100), sm_100 (B200, CUDA 12.8+).
+
+ ### CI / HuggingFace distribution build
+
+ Uses HuggingFace's `kernel-builder` via Nix for cross-compilation of pre-built `.abi3.so` binaries.
+
+ ```bash
+ nix run .#build-and-copy
+ ```
+
+ Pre-built binaries go to `build/` (tracked via Git LFS). The build config lives in `build.toml`.
+
+ ## Running Tests
+
+ Tests require a GPU. Install first with `pip install -e .`.
+
+ ```bash
+ # Run all tests
+ pytest tests/
+
+ # Run a single test file
+ pytest tests/test_rms_norm.py
+
+ # Run a specific test
+ pytest tests/test_rms_norm.py::test_rms_norm_forward -v
+
+ # Sequence parallel tests (require torch>=2.8 and 2+ GPUs)
+ torchrun --nproc-per-node=2 -m pytest tests/test_rms_norm_sequence_parallel.py
+ ```
+
+ Pytest config is in `tests/pytest.ini` (log_cli enabled at INFO level).
+
+ ## Linting / Formatting
+
+ Pre-commit hooks handle all formatting. Install with `pre-commit install`.
+
+ - **Python**: yapf (formatter), isort (imports)
+ - **C++/CUDA**: clang-format (`--style=file`)
+ - **Markdown**: pymarkdown
+ - **Spelling**: typos
+
+ The `build/` and `result/` directories are excluded from all hooks.
+
+ ## Architecture
+
+ ### Layer structure (bottom-up)
+
+ 1. **CUDA/HIP kernels** (`activation/*.cu`, `activation/*.h`): Hand-written kernels that compile for both NVIDIA (CUB) and AMD (hipcub). Each kernel uses vectorized template dispatch (`width > 0` for coalesced 128-bit loads when `dim % 8 == 0`, scalar fallback otherwise). Accumulation is always in float32.
+
+ 2. **C++ torch binding** (`torch-ext/torch_binding.cpp`): Registers ops via `TORCH_LIBRARY_EXPAND` under a build-specific namespace (e.g., `_activation_53ed492_dirty`).
+
+ 3. **Autograd Functions** (`torch-ext/activation/poly_norm.py`, `rms_norm.py`): `torch.autograd.Function` subclasses with `forward`, `setup_context`, and `backward`. Also registers `@torch.library.register_fake` for `torch.compile`/AOT support.
+
+ 4. **DTensor sharding strategies** (`torch-ext/activation/rms_norm_meta.py`, `fused_add_rms_norm_meta.py`): `@register_op_strategy` definitions. Input can be sharded on any dim except the last (normalization dim); weight is always replicated.
+
+ 5. **nn.Module wrappers** (`torch-ext/activation/layers.py`): `PolyNorm`, `FusedMulPolyNorm`, `RMSNorm`, `FusedAddRMSNorm` — these are the public user-facing API.
+
+ 6. **Parallel style** (`torch-ext/activation/parallel_style.py`): `ResidualSequenceParallel` extends PyTorch's `SequenceParallel` for the two-input (x + residual) pattern of `FusedAddRMSNorm`.
+
+ ### Key files
+
+ | Path | Purpose |
+ |------|---------|
+ | `build.toml` | kernel-builder manifest (backends, source files, ROCm arch targets) |
+ | `activation/cuda_compat.h` | CUDA/ROCm compatibility shim (CUB vs hipcub, WARP_SIZE) |
+ | `activation/dispatch_utils.h` | `MOTIF_DISPATCH_FLOATING_TYPES` type dispatch macro |
+ | `torch-ext/activation/__init__.py` | Package entry point; exports functional API + layers + parallel_style |
+ | `torch-ext/activation/_ops.py` | Generated at build time; loads `.abi3.so` and exposes `torch.ops.*` |
+
+ ### Adding a new kernel
+
+ 1. Write the CUDA kernel in `activation/new_kernel.cu` (follow the vectorized template pattern from existing kernels)
+ 2. Add C++ declarations in `torch-ext/torch_binding.h` and register in `torch-ext/torch_binding.cpp`
+ 3. Create a `torch.autograd.Function` in `torch-ext/activation/new_kernel.py` with fake tensor registration
+ 4. Add an `nn.Module` wrapper in `torch-ext/activation/layers.py`
+ 5. If distributed support is needed, add a DTensor strategy in `torch-ext/activation/new_kernel_meta.py`
+ 6. Add the `.cu` file to both `[kernel.activation]` and `[kernel.activation_cuda]` in `build.toml`
+ 7. Export from `torch-ext/activation/__init__.py`
+ 8. Add tests in `tests/test_new_kernel.py` (numerical comparison vs PyTorch reference + `torch.library.opcheck`)
+
+ ### Test conventions
+
+ - Tests compare custom ops against PyTorch reference implementations using tolerances from `tests/allclose_default.py`
+ - Every test runs `torch.library.opcheck` to validate op schema, autograd, fake tensor, and AOT dispatch
+ - Tests are parametrized over dtypes (float32, float16, bfloat16), sequence lengths, and hidden dimensions
+ - Sequence parallel tests use `torchrun` with 2 GPUs and require torch>=2.8
+
+ ### ROCm/CUDA target matrix
+
+ - **ROCm architectures**: gfx90a (MI250), gfx942 (MI300X)
+ - **PyTorch versions**: 2.7 through 2.10
+ - **CUDA versions**: 11.8, 12.6, 12.8, 12.9, 13.0
+ - **ROCm versions**: 6.3, 6.4, 7.0, 7.1
benchmarks/benchmark_profiler.yaml ADDED
@@ -0,0 +1,96 @@
+ apiVersion: trainer.kubeflow.org/v1alpha1
+ kind: TrainJob
+ metadata:
+   name: jeesoo-grouped-polynorm-profiler-bench
+   namespace: kbm-g-np-motif
+ spec:
+   managedBy: trainer.kubeflow.org/trainjob-controller
+   podTemplateOverrides:
+     - spec:
+         containers:
+           - name: node
+             volumeMounts:
+               - mountPath: /dev/shm
+                 name: shm
+               - mountPath: /mair
+                 name: mair
+         volumes:
+           - emptyDir:
+               medium: Memory
+               sizeLimit: 64Gi
+             name: shm
+           - name: mair
+             persistentVolumeClaim:
+               claimName: mair
+       targetJobs:
+         - name: node
+   runtimeRef:
+     apiGroup: trainer.kubeflow.org
+     kind: ClusterTrainingRuntime
+     name: torch-distributed
+   suspend: false
+   trainer:
+     args:
+       - /bin/bash
+       - '-c'
+       - >
+         ACTIVATIONPATH=/mair/team-sys/jeesoo/activation
+
+         BUILDLOG=$ACTIVATIONPATH/benchmarks/results/build.log
+
+         pip install triton matplotlib pandas
+
+         echo "=== Building with setup.py ==="
+
+         cd $ACTIVATIONPATH
+
+         rm -f $ACTIVATIONPATH/_activation*.so
+
+         pip install --no-build-isolation -e . -v 2>&1 | tee $BUILDLOG | tail -200
+
+         python -c "import _activation; print('Build OK:', _activation)" || { echo "BUILD FAILED"; exit 0; }
+
+         echo "=== Build success. Running profiler benchmarks ==="
+
+         cd $ACTIVATIONPATH/benchmarks
+
+         DATESTAMP=$(date +'%y_%m_%d_%H_%M')
+
+         SAVE_PATH=$ACTIVATIONPATH/benchmarks/results/${DATESTAMP}
+
+         mkdir -p $SAVE_PATH/bench/grouped_mul_poly/bf16
+
+         nvidia-smi | tee $SAVE_PATH/nvidia_smi.txt
+
+         python -c "import torch; x=torch.randn(8192,1280,device='cuda',dtype=torch.bfloat16); [torch.mm(x.T,x) for _ in range(100)]; torch.cuda.synchronize(); print('warmup done')"
+
+         echo "=== Benchmark (torch.profiler) ==="
+
+         python run_cases.py --case grouped_mul_poly --dtype bf16 --save-path ${SAVE_PATH}/bench 2>&1 | tee ${SAVE_PATH}/bench_log.txt
+
+         echo "=== Done ==="
+
+         exit 0;
+     env:
+       - name: PYTHONUNBUFFERED
+         value: '1'
+       - name: PYTORCH_ALLOC_CONF
+         value: expandable_segments:True
+       - name: CUDA_LAUNCH_BLOCKING
+         value: '0'
+       - name: OMP_NUM_THREADS
+         value: '1'
+       - name: HF_HOME
+         value: /mair/llm-dataset/hf_cache
+     image: ghcr.io/motiftechnologies/llm-training:v0.1.3
+     numNodes: 1
+     numProcPerNode: 1
+     resourcesPerNode:
+       limits:
+         cpu: '16'
+         memory: 128Gi
+         nvidia.com/gpu: '1'
+       requests:
+         cpu: '16'
+         memory: 128Gi
+         nvidia.com/gpu: '1'
benchmarks/cases/grouped_mul_poly.py CHANGED
@@ -5,8 +5,8 @@ from common.diff_engine import DiffCase
 
 torch._functorch.config.donated_buffer = False
 
-from grouped_poly_norm import (fused_mul_grouped_poly_norm,
-                               fused_mul_grouped_poly_norm_ref)
+from activation.grouped_poly_norm import (fused_mul_grouped_poly_norm,
+                                          fused_mul_grouped_poly_norm_ref)
 
 # 384 / 8 (EP) = 48 experts per rank
 # total_tokens = bs * sl, which equals per-rank tokens
@@ -73,7 +73,8 @@ class GroupedMulPoly(DiffCase):
         probs = torch.ones(num_experts) / num_experts
         assignments = torch.multinomial(probs, total_tokens, replacement=True)
         counts = torch.bincount(assignments, minlength=num_experts).tolist()
-        offsets = torch.cumsum(torch.tensor(counts, dtype=torch.int32), dim=0)
+        offsets = torch.cumsum(torch.tensor(counts, dtype=torch.int32),
+                               dim=0).to(torch.int32)
 
         return {
             "x":
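The `.to(torch.int32)` appended in the second hunk matters because of PyTorch's integer type promotion; a quick check of the behavior the fix implies (that `torch.cumsum`, like `torch.sum`, promotes integer inputs to int64 by default):

```python
import torch

# Without the explicit cast, the offsets tensor handed to the kernel
# would silently change dtype from int32 to int64.
counts = torch.tensor([3, 1, 4], dtype=torch.int32)
offsets = torch.cumsum(counts, dim=0)
print(offsets.dtype)                  # promoted by cumsum
print(offsets.to(torch.int32).dtype)  # back to torch.int32
```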
benchmarks/common/bench_framework.py CHANGED
@@ -4,11 +4,112 @@ import re
 from typing import Any, Dict, Sequence
 
 import torch
+from torch.profiler import ProfilerActivity, profile
 import triton
 
 from .diff_engine import DiffCase
 
 
+def _get_best_cuda_timing(timings_ms, key):
+    """Look up the best CUDA-based timing for speedup calculation."""
+    for provider in ("cuda", "compiled_cuda"):
+        if provider in timings_ms and key in timings_ms[provider]:
+            return timings_ms[provider][key]
+    raise KeyError(f"No CUDA timing found for {key}")
+
+
+def _shorten_kernel_name(name: str) -> str:
+    """Strip template args and function params from CUDA kernel names.
+
+    ``void motif::grouped_poly_norm_bwd_kernel<...>(...)``
+    → ``motif::grouped_poly_norm_bwd_kernel``
+    """
+    # Remove leading 'void '
+    s = re.sub(r"^void\s+", "", name)
+    # Remove template args <...> (handles nested <>)
+    while "<" in s:
+        s = re.sub(r"<[^<>]*>", "", s)
+    # Remove function params (...)
+    s = re.sub(r"\(.*\)$", "", s)
+    return s.strip()
+
+
+def _compute_bytes(inputs, forward_fn, obj):
+    """Compute total bytes: all input tensors read + all output tensors written."""
+    input_bytes = sum(v.nbytes for v in inputs.values()
+                      if isinstance(v, torch.Tensor))
+    output = forward_fn()
+    if isinstance(output, torch.Tensor):
+        output_bytes = output.nbytes
+    elif isinstance(output, (tuple, list)):
+        output_bytes = sum(
+            o.nbytes for o in output if isinstance(o, torch.Tensor))
+    else:
+        output_bytes = 0
+    return input_bytes + output_bytes
+
+
+def profile_bench(fn, warmup=5, repeat=10, verbose=True, total_bytes=0):
+    """Measure CUDA kernel time via torch.profiler.
+
+    Profiles the function, sums all CUDA kernel durations, and returns
+    the median across repeats. Also prints a per-kernel breakdown when
+    *verbose* is True so the caller can spot unexpected kernels.
+
+    Parameters
+    ----------
+    total_bytes : int
+        Total bytes transferred (inputs read + outputs written).
+        If > 0, prints bandwidth in GB/s after the breakdown.
+
+    Returns
+    -------
+    median_ms : float
+        Median total CUDA kernel time in **milliseconds** (same unit as
+        ``triton.testing.do_bench``).
+    """
+    for _ in range(warmup):
+        fn()
+    torch.cuda.synchronize()
+
+    kernel_times_us: list[float] = []
+    last_breakdown: list[tuple[str, float]] = []
+
+    for _ in range(repeat):
+        with profile(activities=[ProfilerActivity.CUDA]) as prof:
+            fn()
+
+        breakdown: dict[str, float] = {}
+        for evt in prof.key_averages():
+            if evt.device_time_total > 0:
+                breakdown[evt.key] = (breakdown.get(evt.key, 0) +
+                                      evt.device_time_total)
+
+        total_us = sum(breakdown.values())
+        kernel_times_us.append(total_us)
+        last_breakdown = sorted(breakdown.items(),
+                                key=lambda x: x[1],
+                                reverse=True)
+
+    median_us = sorted(kernel_times_us)[len(kernel_times_us) // 2]
+
+    if verbose and last_breakdown:
+        total = sum(t for _, t in last_breakdown)
+        names = [_shorten_kernel_name(n) for n, _ in last_breakdown]
+        col_w = max(len(n) for n in names) + 2
+        col_w = max(col_w, len("Total kernel time") + 2)
+        for name, (_, t) in zip(names, last_breakdown):
+            pct = 100 * t / total if total > 0 else 0
+            print(f"  {name:<{col_w}s} {t:>8.1f}us ({pct:4.1f}%)")
+        print(f"  {'Total kernel time':<{col_w}s} {total:>8.1f}us")
+        if total_bytes > 0 and median_us > 0:
+            bw_gbs = total_bytes / (median_us * 1e-6) / 1e9
+            print(f"  {'Bandwidth':<{col_w}s} {bw_gbs:>7.1f} GB/s"
+                  f" ({total_bytes / 1e6:.1f} MB)")
+
+    return median_us / 1000  # us -> ms
+
+
 def make_fwd_key(batch_size, seq_len, dim):
     return f"forward : ({batch_size}, {seq_len}, {dim})"
 
@@ -31,7 +132,7 @@ make_fwd_benchmark_for_case(
     case: DiffCase,
     configs: Sequence[tuple[int, int, int]],
     plot_name: str,
-    ylabel: str = "us",
+    ylabel: str = "",
     line_vals=("naive", "cuda", "speedup"),
     line_names: Dict[str, str] | None = None,
     dtype=torch.bfloat16,
@@ -39,6 +140,7 @@ make_fwd_benchmark_for_case(
     time_unit_scale: float = 1000,
 ):
     timings_ms = collections.defaultdict(dict)
+    bytes_map: dict[str, int] = {}
     line_vals = list(line_vals)
     line_names = line_names or {v: v.title() for v in line_vals}
     x_vals = [list(_) for _ in configs]
@@ -56,7 +158,11 @@ make_fwd_benchmark_for_case(
         key = make_fwd_key(dim, batch_size, seq_len)
         I = case.build_inputs(batch_size, seq_len, dim, dtype, eps)
         if provider == "speedup":
-            return timings_ms["naive"][key] / timings_ms["cuda"][key]
+            return round(timings_ms["naive"][key] / _get_best_cuda_timing(timings_ms, key), 2)
+        if provider.endswith("_bw"):
+            base = provider[:-3]
+            ms = timings_ms[base][key]
+            return round(bytes_map[key] / (ms * 1e-3) / 1e9, 2)
         if provider == "naive":
             obj = case.make_naive(I)
         elif provider == "compiled" and hasattr(case, "make_compiled"):
@@ -64,7 +170,10 @@ make_fwd_benchmark_for_case(
         else:
            obj = case.make_cuda(I)
         run = lambda: case.forward(obj, I)
-        ms = triton.testing.do_bench(run)
+        nbytes = _compute_bytes(I, run, obj)
+        bytes_map[key] = nbytes
+        print(f"  [{provider}] {key}")
+        ms = profile_bench(run, total_bytes=nbytes)
         timings_ms[provider][key] = ms
         return time_unit_scale * ms
 
@@ -113,10 +222,12 @@ make_fwd_benchmark_plot_for_case(
         else:
             obj = case.make_cuda(I)
         run = lambda: case.forward(obj, I)
-        ms = triton.testing.do_bench(run)
+        nbytes = _compute_bytes(I, run, obj)
+        print(f"  [{provider}] {config}")
+        ms = profile_bench(run, total_bytes=nbytes)
         timings_ms[provider][config] = ms
         if provider == "cuda":
-            ratio = timings_ms["naive"][config] / timings_ms["cuda"][config]
+            ratio = timings_ms["naive"][config] / _get_best_cuda_timing(timings_ms, config)
             spdup_ratio.append(ratio)
             return round(ratio, 2)
         else:
@@ -130,7 +241,7 @@ make_bwd_benchmark_for_case(
     case: DiffCase,
     configs: Sequence[tuple[int, int, int]],
     plot_name: str,
-    ylabel: str = "us",
+    ylabel: str = "",
     line_vals=("naive", "cuda", "speedup"),
     line_names: Dict[str, str] | None = None,
     dtype=torch.bfloat16,
@@ -138,6 +249,7 @@ make_bwd_benchmark_for_case(
     time_unit_scale: float = 1000,
 ):
     timings_ms = collections.defaultdict(dict)
+    bytes_map: dict[str, int] = {}
     line_vals = list(line_vals)
     line_names = line_names or {v: v.title() for v in line_vals}
     x_vals = [list(_) for _ in configs]
@@ -155,7 +267,11 @@ make_bwd_benchmark_for_case(
         key = make_bwd_key(dim, batch_size, seq_len)
         I = case.build_inputs(batch_size, seq_len, dim, dtype, eps)
         if provider == "speedup":
-            return timings_ms["naive"][key] / timings_ms["cuda"][key]
+            return round(timings_ms["naive"][key] / _get_best_cuda_timing(timings_ms, key), 2)
+        if provider.endswith("_bw"):
+            base = provider[:-3]
+            ms = timings_ms[base][key]
+            return round(bytes_map[key] / (ms * 1e-3) / 1e9, 2)
         if provider == "naive":
             obj = case.make_naive(I)
         elif provider == "compiled" and hasattr(case, "make_compiled"):
@@ -174,7 +290,11 @@ make_bwd_benchmark_for_case(
                                   retain_graph=True,
                                   create_graph=False,
                                   allow_unused=False)
-        ms = triton.testing.do_bench(run)
+        fwd_run = lambda: case.forward(obj, I)
+        nbytes = _compute_bytes(I, fwd_run, obj)
+        bytes_map[key] = nbytes
+        print(f"  [{provider}] {key}")
+        ms = profile_bench(run, total_bytes=nbytes)
         timings_ms[provider][key] = ms
         return time_unit_scale * ms
 
@@ -234,10 +354,13 @@ make_bwd_benchmark_plot_for_case(
                                   retain_graph=True,
                                   create_graph=False,
                                   allow_unused=False)
-        ms = triton.testing.do_bench(run)
+        fwd_run = lambda: case.forward(obj, I)
+        nbytes = _compute_bytes(I, fwd_run, obj)
+        print(f"  [{provider}] {config}")
+        ms = profile_bench(run, total_bytes=nbytes)
         timings_ms[provider][config] = ms
         if provider == "cuda":
-            ratio = timings_ms["naive"][config] / timings_ms["cuda"][config]
+            ratio = timings_ms["naive"][config] / _get_best_cuda_timing(timings_ms, config)
             spdup_ratio.append(ratio)
             return round(ratio, 2)
         else:
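The kernel-name shortening added in this diff can be exercised on its own. The function below is a standalone copy of the `_shorten_kernel_name` logic, applied to a made-up mangled name to show the iterative stripping of nested template arguments:

```python
import re


def shorten(name: str) -> str:
    s = re.sub(r"^void\s+", "", name)         # drop leading 'void '
    while "<" in s:                           # peel template args, innermost first
        s = re.sub(r"<[^<>]*>", "", s)
    return re.sub(r"\(.*\)$", "", s).strip()  # drop the parameter list


name = "void motif::poly_norm_kernel<c10::BFloat16, float, 8>(T*, T const*, int)"
print(shorten(name))  # motif::poly_norm_kernel
```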
benchmarks/run_cases.py CHANGED
@@ -12,6 +12,17 @@ from common.bench_framework import (make_bwd_benchmark_for_case,
 from common.diff_engine import DiffCase, calculate_diff
 
 
+def _clean_and_print_csv(csv_path, title):
+    """Remove trailing ' ()' from CSV column names and print the table."""
+    import pandas as pd
+    df = pd.read_csv(csv_path)
+    df.columns = [c.replace(" ()", "") for c in df.columns]
+    df.to_csv(csv_path, index=False)
+    print(f"{title}:")
+    print(df.to_string(index=False))
+    print()
+
+
 def make_title_tag():
     if torch.cuda.is_available():
         dev_name = torch.cuda.get_device_name(0)
@@ -66,11 +77,6 @@ def main():
         default="bf16",
         help="Data type for benchmarking (default: bf16)",
     )
-    ap.add_argument(
-        "--profile",
-        action="store_true",
-        help="Export chrome traces for backward benchmarks",
-    )
     args = ap.parse_args()
 
     dtype_map = {
@@ -170,27 +176,38 @@ def main():
             itertools.product(dim, batch_size_range, seq_length_range))
 
         if is_grouped:
-            fwd_line_vals = ("naive", "compiled", "cuda", "speedup")
+            fwd_line_vals = ("naive", "naive_bw", "compiled",
+                             "compiled_bw", "cuda", "cuda_bw", "speedup")
             fwd_line_names = {
-                "naive": "Naive",
-                "compiled": "Compiled",
-                "cuda": "Triton",
-                "speedup": "SpeedUp",
+                "naive": "Naive (us)",
+                "naive_bw": "Naive (GB/s)",
+                "compiled": "Compiled (us)",
+                "compiled_bw": "Compiled (GB/s)",
+                "cuda": "CUDA (us)",
+                "cuda_bw": "CUDA (GB/s)",
+                "speedup": "SpeedUp (ratio)",
             }
-            bwd_line_vals = ("naive", "compiled", "compiled_cuda",
-                             "speedup")
+            bwd_line_vals = ("naive", "naive_bw", "compiled",
+                             "compiled_bw", "compiled_cuda",
+                             "compiled_cuda_bw", "speedup")
             bwd_line_names = {
-                "naive": "Naive",
-                "compiled": "Compiled",
-                "compiled_cuda": "CompiledCUDA",
-                "speedup": "SpeedUp",
+                "naive": "Naive (us)",
+                "naive_bw": "Naive (GB/s)",
+                "compiled": "Compiled (us)",
+                "compiled_bw": "Compiled (GB/s)",
+                "compiled_cuda": "CompiledCUDA (us)",
+                "compiled_cuda_bw": "CompiledCUDA (GB/s)",
+                "speedup": "SpeedUp (ratio)",
             }
         else:
-            fwd_line_vals = ("naive", "cuda", "speedup")
+            fwd_line_vals = ("naive", "naive_bw", "cuda", "cuda_bw",
+                             "speedup")
             fwd_line_names = {
-                "naive": "Naive",
-                "cuda": "Cuda",
-                "speedup": "SpeedUp",
+                "naive": "Naive (us)",
+                "naive_bw": "Naive (GB/s)",
+                "cuda": "CUDA (us)",
+                "cuda_bw": "CUDA (GB/s)",
+                "speedup": "SpeedUp (ratio)",
             }
             bwd_line_vals = fwd_line_vals
             bwd_line_names = fwd_line_names
@@ -202,11 +219,12 @@ def main():
             dtype=dtype,
             line_vals=fwd_line_vals,
             line_names=fwd_line_names,
-            profile=args.profile,
-            profile_dir=os.path.join(save_dir, "traces"),
         )
 
-        bench.run(print_data=True, save_path=save_dir)
+        fwd_name = f"{args.case}-{dtype_name}-fwd-perf"
+        bench.run(print_data=False, save_path=save_dir)
+        _clean_and_print_csv(os.path.join(save_dir, fwd_name + ".csv"),
+                             fwd_name)
 
         bench = make_bwd_benchmark_for_case(
             case=case,
@@ -215,11 +233,12 @@ def main():
             dtype=dtype,
             line_vals=bwd_line_vals,
             line_names=bwd_line_names,
-            profile=args.profile,
-            profile_dir=os.path.join(save_dir, "traces"),
         )
 
-        bench.run(print_data=True, save_path=save_dir)
+        bwd_name = f"{args.case}-{dtype_name}-bwd-perf"
+        bench.run(print_data=False, save_path=save_dir)
+        _clean_and_print_csv(os.path.join(save_dir, bwd_name + ".csv"),
+                             bwd_name)
         for f in glob.glob(os.path.join(save_dir, "*.html")) + \
                 glob.glob(os.path.join(save_dir, "*.png")):
             os.remove(f)
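The new `_bw` providers derive the `(GB/s)` columns from the formula `bytes / (ms * 1e-3) / 1e9` used in `bench_framework.py`. A worked example with made-up numbers:

```python
# GB/s = total_bytes / (time_ms * 1e-3) / 1e9
total_bytes = 512 * 1024 * 1024   # 512 MiB read + written (hypothetical)
time_ms = 0.5                     # measured kernel time (hypothetical)
gbs = total_bytes / (time_ms * 1e-3) / 1e9
print(round(gbs, 2))  # 1073.74
```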
docs/benchmark.md ADDED
@@ -0,0 +1,146 @@
+ # Benchmark System
+
+ ## Overview
+
+ The benchmark system measures forward/backward performance of the custom CUDA kernels against naive PyTorch implementations. It uses Triton's `do_bench`, running a correctness check before measuring performance.
+
+ ## Directory Structure
+
+ ```
+ benchmarks/
+ ├── run_cases.py              # CLI entry point
+ ├── common/
+ │   ├── bench_framework.py    # benchmark utilities (built on Triton perf_report)
+ │   └── diff_engine.py        # correctness-check engine (DiffCase ABC)
+ ├── cases/                    # benchmark case implementations
+ │   ├── rms.py                # RMSNorm
+ │   ├── add_rms.py            # Fused Add + RMSNorm
+ │   ├── poly.py               # PolyNorm
+ │   ├── mul_poly.py           # Fused Mul + PolyNorm
+ │   └── grouped_mul_poly.py   # Grouped MoE Fused Mul + PolyNorm
+ ├── benchmark.yaml            # Kubeflow benchmark job config
+ ├── test.yaml                 # Kubeflow test job config
+ ├── plots/                    # generated plot output
+ └── results/                  # timestamped benchmark results
+ ```
+
+ ## Usage
+
+ ```bash
+ python benchmarks/run_cases.py --case <CASE> [OPTIONS]
+ ```
+
+ ### Arguments
+
+ | Argument | Default | Choices | Description |
+ |----------|---------|---------|-------------|
+ | `--case` | (required) | `rms`, `add_rms`, `poly`, `mul_poly`, `grouped_mul_poly` | benchmark case |
+ | `--dtype` | `bf16` | `fp16`, `bf16`, `fp32`, `all` | data type |
+ | `--save-path` | `./configs/` | path | output directory for results |
+ | `--plot` | false | - | plot-generation mode |
+ | `--profile` | false | - | export Chrome trace profiles |
+
+ ### Examples
+
+ ```bash
+ # default bf16 benchmark
+ python benchmarks/run_cases.py --case grouped_mul_poly
+
+ # all dtypes + profiling
+ python benchmarks/run_cases.py --case mul_poly --dtype all --profile --save-path ./results
+
+ # plots only
+ python benchmarks/run_cases.py --case rms --plot --save-path ./plots
+ ```
+
+ ## Benchmark Cases
+
+ Each case implements the `DiffCase` ABC and compares a naive (PyTorch reference) implementation against the CUDA kernel.
+
+ | Case | Naive | CUDA | Inputs |
+ |------|-------|------|--------|
+ | `rms` | `torch.nn.RMSNorm` | `activation.layers.RMSNorm` | x, weight, eps |
+ | `add_rms` | custom `FusedAddRMSNorm` | `activation.layers.FusedAddRMSNorm` | x, residual, weight, eps |
+ | `poly` | custom `PolyNorm` (combination of x^3, x^2, x) | `activation.layers.PolyNorm` | x, weight(3), bias(1), eps |
+ | `mul_poly` | custom `FusedMulPolyNorm` | `activation.layers.FusedMulPolyNorm` | x, mul, weight(3), bias, eps |
+ | `grouped_mul_poly` | `fused_mul_grouped_poly_norm_ref` | `fused_mul_grouped_poly_norm` | x, mul, weight(num_experts, 3), bias, offsets |
+
+ `grouped_mul_poly` additionally measures the `compiled` (torch.compile'd naive) and `compiled_cuda` (torch.compile'd CUDA) providers.
+
+ ## Execution Flow
+
+ 1. **Correctness check** - runs `calculate_diff()` for 3 configs
+    - `(bs=2, sl=128, hidden=4096)`
+    - `(bs=8, sl=4096, hidden=1280)`
+    - `(bs=1, sl=32768, hidden=1280)`
+    - forward and backward are both compared with `atol=1e-2, rtol=1e-2`
+ 2. **Benchmark run** - measures forward/backward performance per dtype
+ 3. **Result saving** - CSV files (plus optional plots/traces)
+
+ ## Configuration Ranges
+
+ **Standard cases** (rms, add_rms, poly, mul_poly):
+ - Batch sizes: 1, 2, 4, 8
+ - Sequence lengths: 1024, 2048, 4096, 8192
+ - Hidden dims: 2048, 4096
+
+ **Grouped case** (grouped_mul_poly):
+ - Total tokens: 1024 ~ 65536 (bs x sl)
+ - Hidden dim: 1280 (fixed)
+ - Experts: 48 per rank
+
+ In `--plot` mode, `bs=1` is fixed and only seq_len is swept.
+
+ ## Output
+
+ ### CSV
+
+ Saved under `{save_path}/{case}/{dtype}/`:
+
+ - `{case}-{dtype}-fwd-perf.csv` - forward results
+ - `{case}-{dtype}-bwd-perf.csv` - backward results
+
+ Columns: `dim`, `batch_size`, `seq_len`, `Naive (us)`, `Compiled (us)`, `Cuda (us)`, `SpeedUp (ratio)`
+
+ ### Chrome Trace (`--profile`)
+
+ Saved as JSON under `{save_path}/{case}/{dtype}/traces/`. Load in `chrome://tracing` to inspect the GPU timeline.
+
+ File name pattern: `trace_{fwd|bwd}_{naive|compiled|cuda|compiled_cuda}_N{total_tokens}.json`
+
+ ### Plot (`--plot`)
+
+ Generates speedup comparison plots. The overall speedup is aggregated as a geometric mean.
+
+ ## Framework Internals
+
+ ### bench_framework.py
+
+ Four factory functions built on Triton's `perf_report`/`Benchmark`:
+
+ - `make_fwd_benchmark_for_case()` - forward benchmark (CSV)
+ - `make_bwd_benchmark_for_case()` - backward benchmark (CSV)
+ - `make_fwd_benchmark_plot_for_case()` - forward plot
+ - `make_bwd_benchmark_plot_for_case()` - backward plot
+
+ Timing is measured with `triton.testing.do_bench()` and converted from ms to us (`time_unit_scale=1000`).
+
+ ### diff_engine.py
+
+ The `DiffCase` ABC interface:
+
+ - `build_inputs(bs, sl, dim)` - creates input tensors
+ - `make_naive()` / `make_cuda()` - construct the implementations
+ - `forward(module, inputs)` - runs the forward pass
+ - `grad_inputs(inputs)` - returns the tensors that receive gradients
+
+ `calculate_diff()` compares the forward outputs and backward gradients of the naive and CUDA sides with `torch.testing.assert_close()`.
+
+ ## Kubeflow Integration
+
+ Benchmarks can be run on the cluster with `benchmark.yaml`:
+
+ - installs triton, matplotlib, pandas
+ - builds the C++ extension (`setup.py`)
+ - GPU warmup (100 matmul iterations)
+ - saves results to `benchmarks/results/{YY_MM_DD_HH_MM}/`
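The `DiffCase` flow described above can be mimicked with a toy case. This is hypothetical: the real ABC lives in `benchmarks/common/diff_engine.py`, only the method names here follow the doc, and `ScaleCase` with its toy implementations is invented for illustration:

```python
import torch


class ScaleCase:
    """Toy case following the DiffCase method names described in the doc."""

    def build_inputs(self, bs, sl, dim):
        return {"x": torch.randn(bs, sl, dim)}

    def make_naive(self, inputs):
        return lambda x: x * 2.0        # PyTorch reference implementation

    def make_cuda(self, inputs):
        return lambda x: x + x          # stands in for the custom kernel

    def forward(self, impl, inputs):
        return impl(inputs["x"])


case = ScaleCase()
I = case.build_inputs(2, 4, 8)
naive_out = case.forward(case.make_naive(I), I)
cuda_out = case.forward(case.make_cuda(I), I)
# the correctness check compares the two implementations, as
# calculate_diff() does for the real cases
torch.testing.assert_close(naive_out, cuda_out)
print("match")
```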