feat: replace triton do_bench with torch.profiler for kernel timing
Switch from triton.testing.do_bench to torch.profiler-based CUDA kernel
time measurement. do_bench includes PyTorch CPU overhead in timing,
leading to inaccurate results. torch.profiler sums only actual CUDA
kernel durations for precise measurement.
- Add profile_bench() with per-kernel breakdown and bandwidth output
- Add GB/s and SpeedUp(ratio) columns to CSV reports
- Fix grouped_mul_poly import path (activation.grouped_poly_norm)
- Add docs/benchmark.md for benchmark system documentation
- Update CLAUDE.md build system description
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- CLAUDE.md +116 -0
- benchmarks/benchmark_profiler.yaml +96 -0
- benchmarks/cases/grouped_mul_poly.py +4 -3
- benchmarks/common/bench_framework.py +133 -10
- benchmarks/run_cases.py +45 -26
- docs/benchmark.md +146 -0
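The measurement strategy the commit describes — sum the per-kernel device times within each profiled repeat, then take the median across repeats — can be restated as a small, dependency-free sketch (hypothetical helper names; the actual implementation lives in `benchmarks/common/bench_framework.py` and reads `torch.profiler` events):

```python
def total_kernel_time_us(breakdown):
    """Sum per-kernel device times (microseconds) for one profiled repeat."""
    return sum(breakdown.values())


def median_across_repeats(per_repeat_totals_us):
    """Median via sort, matching the sorted(xs)[len(xs) // 2] idiom in the patch."""
    ordered = sorted(per_repeat_totals_us)
    return ordered[len(ordered) // 2]


# One breakdown per repeat: kernel name -> device time in microseconds.
repeats = [
    {"poly_norm_fwd": 100.0, "elementwise_mul": 20.0},
    {"poly_norm_fwd": 105.0, "elementwise_mul": 21.0},
    {"poly_norm_fwd": 102.0, "elementwise_mul": 20.0},
]
totals = [total_kernel_time_us(b) for b in repeats]
print(median_across_repeats(totals))  # 122.0
```

Because only device-side event durations are summed, CPU-side launch overhead never enters the number, which is the whole point of moving away from wall-clock `do_bench`.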
CLAUDE.md (ADDED, +116 lines)

````markdown
# CLAUDE.md

This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.

## Project Overview

Custom CUDA/ROCm normalization kernels for LLM training and inference, published as `motif-technologies/activation` on HuggingFace. Implements PolyNorm, RMSNorm, FusedAddRMSNorm, and FusedMulPolyNorm with full autograd support, fake tensor registration, and DTensor sharding strategies for sequence parallelism.

## Build System

### Local development build (primary)

```bash
pip install -e .                      # editable install (build + install)
python setup.py build_ext --inplace   # build only
```

`setup.py` does two things:

1. Compiles the `_activation` C extension from CUDA sources (`activation/*.cu`) + the C++ binding (`torch-ext/torch_binding.cpp`)
2. Installs the `activation` Python package from `torch-ext/activation/` (autograd functions, layers, etc.)

After `pip install -e .`, all imports work directly (`import activation`, `from activation.grouped_poly_norm import ...`). No PYTHONPATH manipulation needed.

NVCC flags: `-O3 --use_fast_math -std=c++17`; targets sm_80 (A100), sm_89 (L40/4090), sm_90 (H100), sm_100 (B200, CUDA 12.8+).

### CI / HuggingFace distribution build

Uses HuggingFace's `kernel-builder` via Nix for cross-compilation of pre-built `.abi3.so` binaries.

```bash
nix run .#build-and-copy
```

Pre-built binaries go to `build/` (tracked via Git LFS). The build config lives in `build.toml`.

## Running Tests

Tests require a GPU. Install first with `pip install -e .`.

```bash
# Run all tests
pytest tests/

# Run a single test file
pytest tests/test_rms_norm.py

# Run a specific test
pytest tests/test_rms_norm.py::test_rms_norm_forward -v

# Sequence parallel tests (require torch>=2.8 and 2+ GPUs)
torchrun --nproc-per-node=2 -m pytest tests/test_rms_norm_sequence_parallel.py
```

Pytest config is in `tests/pytest.ini` (log_cli enabled at INFO level).

## Linting / Formatting

Pre-commit hooks handle all formatting. Install with `pre-commit install`.

- **Python**: yapf (formatter), isort (imports)
- **C++/CUDA**: clang-format (`--style=file`)
- **Markdown**: pymarkdown
- **Spelling**: typos

The `build/` and `result/` directories are excluded from all hooks.

## Architecture

### Layer structure (bottom-up)

1. **CUDA/HIP kernels** (`activation/*.cu`, `activation/*.h`): Hand-written kernels that compile for both NVIDIA (CUB) and AMD (hipcub). Each kernel uses vectorized template dispatch (`width > 0` for coalesced 128-bit loads when `dim % 8 == 0`, scalar fallback otherwise). Accumulation is always in float32.
2. **C++ torch binding** (`torch-ext/torch_binding.cpp`): Registers ops via `TORCH_LIBRARY_EXPAND` under a build-specific namespace (e.g., `_activation_53ed492_dirty`).
3. **Autograd Functions** (`torch-ext/activation/poly_norm.py`, `rms_norm.py`): `torch.autograd.Function` subclasses with `forward`, `setup_context`, and `backward`. Also registers `@torch.library.register_fake` for `torch.compile`/AOT support.
4. **DTensor sharding strategies** (`torch-ext/activation/rms_norm_meta.py`, `fused_add_rms_norm_meta.py`): `@register_op_strategy` definitions. Input can be sharded on any dim except the last (normalization dim); weight is always replicated.
5. **nn.Module wrappers** (`torch-ext/activation/layers.py`): `PolyNorm`, `FusedMulPolyNorm`, `RMSNorm`, `FusedAddRMSNorm`; these are the public user-facing API.
6. **Parallel style** (`torch-ext/activation/parallel_style.py`): `ResidualSequenceParallel` extends PyTorch's `SequenceParallel` for the two-input (x + residual) pattern of `FusedAddRMSNorm`.

### Key files

| Path | Purpose |
|------|---------|
| `build.toml` | kernel-builder manifest (backends, source files, ROCm arch targets) |
| `activation/cuda_compat.h` | CUDA/ROCm compatibility shim (CUB vs hipcub, WARP_SIZE) |
| `activation/dispatch_utils.h` | `MOTIF_DISPATCH_FLOATING_TYPES` type dispatch macro |
| `torch-ext/activation/__init__.py` | Package entry point; exports functional API + layers + parallel_style |
| `torch-ext/activation/_ops.py` | Generated at build time; loads `.abi3.so` and exposes `torch.ops.*` |

### Adding a new kernel

1. Write the CUDA kernel in `activation/new_kernel.cu` (follow the vectorized template pattern from existing kernels)
2. Add C++ declarations in `torch-ext/torch_binding.h` and register in `torch-ext/torch_binding.cpp`
3. Create a `torch.autograd.Function` in `torch-ext/activation/new_kernel.py` with fake tensor registration
4. Add an `nn.Module` wrapper in `torch-ext/activation/layers.py`
5. If distributed support is needed, add a DTensor strategy in `torch-ext/activation/new_kernel_meta.py`
6. Add the `.cu` file to both `[kernel.activation]` and `[kernel.activation_cuda]` in `build.toml`
7. Export from `torch-ext/activation/__init__.py`
8. Add tests in `tests/test_new_kernel.py` (numerical comparison vs PyTorch reference + `torch.library.opcheck`)

### Test conventions

- Tests compare custom ops against PyTorch reference implementations using tolerances from `tests/allclose_default.py`
- Every test runs `torch.library.opcheck` to validate op schema, autograd, fake tensor, and AOT dispatch
- Tests are parametrized over dtypes (float32, float16, bfloat16), sequence lengths, and hidden dimensions
- Sequence parallel tests use `torchrun` with 2 GPUs and require torch>=2.8

### ROCm/CUDA target matrix

- **ROCm architectures**: gfx90a (MI250), gfx942 (MI300X)
- **PyTorch versions**: 2.7 through 2.10
- **CUDA versions**: 11.8, 12.6, 12.8, 12.9, 13.0
- **ROCm versions**: 6.3, 6.4, 7.0, 7.1
````
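The vectorized-dispatch rule described in the layer structure (coalesced 128-bit loads when `dim % 8 == 0` for 16-bit types, scalar fallback otherwise) can be illustrated with a small width-selection helper. This is a sketch of the decision only, not the repository's actual C++ template dispatch, and it assumes 128-bit (16-byte) vector loads:

```python
def pick_vector_width(dim: int, elem_bytes: int) -> int:
    """Elements per coalesced 128-bit load, or 0 for the scalar fallback."""
    lanes = 16 // elem_bytes  # 128 bits = 16 bytes; e.g. 8 lanes for bf16/fp16
    return lanes if dim % lanes == 0 else 0


print(pick_vector_width(4096, 2))  # 8: bf16, dim divisible by 8 lanes
print(pick_vector_width(1282, 2))  # 0: falls back to the scalar path
```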
benchmarks/benchmark_profiler.yaml (ADDED, +96 lines)

```yaml
apiVersion: trainer.kubeflow.org/v1alpha1
kind: TrainJob
metadata:
  name: jeesoo-grouped-polynorm-profiler-bench
  namespace: kbm-g-np-motif
spec:
  managedBy: trainer.kubeflow.org/trainjob-controller
  podTemplateOverrides:
    - spec:
        containers:
          - name: node
            volumeMounts:
              - mountPath: /dev/shm
                name: shm
              - mountPath: /mair
                name: mair
        volumes:
          - emptyDir:
              medium: Memory
              sizeLimit: 64Gi
            name: shm
          - name: mair
            persistentVolumeClaim:
              claimName: mair
      targetJobs:
        - name: node
  runtimeRef:
    apiGroup: trainer.kubeflow.org
    kind: ClusterTrainingRuntime
    name: torch-distributed
  suspend: false
  trainer:
    args:
      - /bin/bash
      - '-c'
      - >
        ACTIVATIONPATH=/mair/team-sys/jeesoo/activation

        BUILDLOG=$ACTIVATIONPATH/benchmarks/results/build.log

        pip install triton matplotlib pandas

        echo "=== Building with setup.py ==="

        cd $ACTIVATIONPATH

        rm -f $ACTIVATIONPATH/_activation*.so

        pip install --no-build-isolation -e . -v 2>&1 | tee $BUILDLOG | tail -200

        python -c "import _activation; print('Build OK:', _activation)" || { echo "BUILD FAILED"; exit 0; }

        echo "=== Build success. Running profiler benchmarks ==="

        cd $ACTIVATIONPATH/benchmarks

        DATESTAMP=$(date +'%y_%m_%d_%H_%M')

        SAVE_PATH=$ACTIVATIONPATH/benchmarks/results/${DATESTAMP}

        mkdir -p $SAVE_PATH/bench/grouped_mul_poly/bf16

        nvidia-smi | tee $SAVE_PATH/nvidia_smi.txt

        python -c "import torch; x=torch.randn(8192,1280,device='cuda',dtype=torch.bfloat16); [torch.mm(x.T,x) for _ in range(100)]; torch.cuda.synchronize(); print('warmup done')"

        echo "=== Benchmark (torch.profiler) ==="

        python run_cases.py --case grouped_mul_poly --dtype bf16 --save-path ${SAVE_PATH}/bench 2>&1 | tee ${SAVE_PATH}/bench_log.txt

        echo "=== Done ==="

        exit 0;
    env:
      - name: PYTHONUNBUFFERED
        value: '1'
      - name: PYTORCH_ALLOC_CONF
        value: expandable_segments:True
      - name: CUDA_LAUNCH_BLOCKING
        value: '0'
      - name: OMP_NUM_THREADS
        value: '1'
      - name: HF_HOME
        value: /mair/llm-dataset/hf_cache
    image: ghcr.io/motiftechnologies/llm-training:v0.1.3
    numNodes: 1
    numProcPerNode: 1
    resourcesPerNode:
      limits:
        cpu: '16'
        memory: 128Gi
        nvidia.com/gpu: '1'
      requests:
        cpu: '16'
        memory: 128Gi
        nvidia.com/gpu: '1'
```
benchmarks/cases/grouped_mul_poly.py (CHANGED, +4 −3)

```diff
@@ -5,8 +5,8 @@ from common.diff_engine import DiffCase
 
 torch._functorch.config.donated_buffer = False
 
-from grouped_poly_norm import (fused_mul_grouped_poly_norm,
-                               fused_mul_grouped_poly_norm_ref)
+from activation.grouped_poly_norm import (fused_mul_grouped_poly_norm,
+                                          fused_mul_grouped_poly_norm_ref)
 
 # 384 / 8 (EP) = 48 experts per rank
 # total_tokens = bs * sl, which equals per-rank tokens
@@ -73,7 +73,8 @@ class GroupedMulPoly(DiffCase):
         probs = torch.ones(num_experts) / num_experts
         assignments = torch.multinomial(probs, total_tokens, replacement=True)
         counts = torch.bincount(assignments, minlength=num_experts).tolist()
-        offsets = torch.cumsum(torch.tensor(counts, dtype=torch.int32), dim=0)
+        offsets = torch.cumsum(torch.tensor(counts, dtype=torch.int32),
+                               dim=0).to(torch.int32)
 
         return {
             "x":
```
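The offsets computed in the hunk above are an inclusive running sum of per-expert token counts, which is exactly what `torch.cumsum(..., dim=0)` produces. A plain-Python restatement for clarity (hypothetical helper name):

```python
def expert_offsets(counts):
    """Inclusive cumulative sum: offsets[i] = total tokens in experts 0..i."""
    offsets, running = [], 0
    for c in counts:
        running += c
        offsets.append(running)
    return offsets


# 3 experts receiving 5, 0, and 7 tokens:
print(expert_offsets([5, 0, 7]))  # [5, 5, 12]
```

The trailing `.to(torch.int32)` in the patch pins the dtype, since `cumsum` on integer inputs can promote to int64.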
benchmarks/common/bench_framework.py (CHANGED, +133 −10)

New import and helpers added near the top of the file:

```python
from torch.profiler import ProfilerActivity, profile


def _get_best_cuda_timing(timings_ms, key):
    """Look up the best CUDA-based timing for speedup calculation."""
    for provider in ("cuda", "compiled_cuda"):
        if provider in timings_ms and key in timings_ms[provider]:
            return timings_ms[provider][key]
    raise KeyError(f"No CUDA timing found for {key}")


def _shorten_kernel_name(name: str) -> str:
    """Strip template args and function params from CUDA kernel names.

    ``void motif::grouped_poly_norm_bwd_kernel<...>(...)``
    -> ``motif::grouped_poly_norm_bwd_kernel``
    """
    # Remove leading 'void '
    s = re.sub(r"^void\s+", "", name)
    # Remove template args <...> (handles nested <>)
    while "<" in s:
        s = re.sub(r"<[^<>]*>", "", s)
    # Remove function params (...)
    s = re.sub(r"\(.*\)$", "", s)
    return s.strip()


def _compute_bytes(inputs, forward_fn, obj):
    """Compute total bytes: all input tensors read + all output tensors written."""
    input_bytes = sum(v.nbytes for v in inputs.values()
                      if isinstance(v, torch.Tensor))
    output = forward_fn()
    if isinstance(output, torch.Tensor):
        output_bytes = output.nbytes
    elif isinstance(output, (tuple, list)):
        output_bytes = sum(
            o.nbytes for o in output if isinstance(o, torch.Tensor))
    else:
        output_bytes = 0
    return input_bytes + output_bytes


def profile_bench(fn, warmup=5, repeat=10, verbose=True, total_bytes=0):
    """Measure CUDA kernel time via torch.profiler.

    Profiles the function, sums all CUDA kernel durations, and returns
    the median across repeats. Also prints a per-kernel breakdown when
    *verbose* is True so the caller can spot unexpected kernels.

    Parameters
    ----------
    total_bytes : int
        Total bytes transferred (inputs read + outputs written).
        If > 0, prints bandwidth in GB/s after the breakdown.

    Returns
    -------
    median_ms : float
        Median total CUDA kernel time in **milliseconds** (same unit as
        ``triton.testing.do_bench``).
    """
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()

    kernel_times_us: list[float] = []
    last_breakdown: list[tuple[str, float]] = []

    for _ in range(repeat):
        with profile(activities=[ProfilerActivity.CUDA]) as prof:
            fn()

        breakdown: dict[str, float] = {}
        for evt in prof.key_averages():
            if evt.device_time_total > 0:
                breakdown[evt.key] = (breakdown.get(evt.key, 0) +
                                      evt.device_time_total)

        total_us = sum(breakdown.values())
        kernel_times_us.append(total_us)
        last_breakdown = sorted(breakdown.items(),
                                key=lambda x: x[1],
                                reverse=True)

    median_us = sorted(kernel_times_us)[len(kernel_times_us) // 2]

    if verbose and last_breakdown:
        total = sum(t for _, t in last_breakdown)
        names = [_shorten_kernel_name(n) for n, _ in last_breakdown]
        col_w = max(len(n) for n in names) + 2
        col_w = max(col_w, len("Total kernel time") + 2)
        for name, (_, t) in zip(names, last_breakdown):
            pct = 100 * t / total if total > 0 else 0
            print(f"  {name:<{col_w}s} {t:>8.1f}us ({pct:4.1f}%)")
        print(f"  {'Total kernel time':<{col_w}s} {total:>8.1f}us")
        if total_bytes > 0 and median_us > 0:
            bw_gbs = total_bytes / (median_us * 1e-6) / 1e9
            print(f"  {'Bandwidth':<{col_w}s} {bw_gbs:>7.1f} GB/s"
                  f" ({total_bytes / 1e6:.1f} MB)")

    return median_us / 1000  # us -> ms
```

Call-site changes in the four benchmark builders (`make_fwd_benchmark_for_case`, `make_fwd_benchmark_plot_for_case`, and their `bwd` counterparts): `triton.testing.do_bench` is replaced by `profile_bench`, `ylabel` now defaults to `""`, each builder keeps a `bytes_map: dict[str, int] = {}` alongside `timings_ms`, speedup lookups go through `_get_best_cuda_timing`, and new `*_bw` providers report bandwidth. The forward variant now reads:

```python
        if provider == "speedup":
            return round(
                timings_ms["naive"][key] /
                _get_best_cuda_timing(timings_ms, key), 2)
        if provider.endswith("_bw"):
            base = provider[:-3]
            ms = timings_ms[base][key]
            return round(bytes_map[key] / (ms * 1e-3) / 1e9, 2)
        if provider == "naive":
            obj = case.make_naive(I)
        elif provider == "compiled" and hasattr(case, "make_compiled"):
            ...
        else:
            obj = case.make_cuda(I)
        run = lambda: case.forward(obj, I)
        nbytes = _compute_bytes(I, run, obj)
        bytes_map[key] = nbytes
        print(f"  [{provider}] {key}")
        ms = profile_bench(run, total_bytes=nbytes)
        timings_ms[provider][key] = ms
        return time_unit_scale * ms
```

The plot variants compute `ratio = timings_ms["naive"][config] / _get_best_cuda_timing(timings_ms, config)` for the `cuda` provider, and the backward variants compute `nbytes` from a forward run (`fwd_run = lambda: case.forward(obj, I)`) before timing the `torch.autograd.grad` call with `profile_bench(run, total_bytes=nbytes)`.
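The GB/s figure added to the reports follows directly from the byte count and the measured kernel time; the arithmetic used in `profile_bench` and the `*_bw` providers reduces to this standalone restatement (assumed helper name):

```python
def bandwidth_gbs(total_bytes: int, kernel_time_us: float) -> float:
    """Bytes moved / seconds elapsed, reported in GB/s (1 GB = 1e9 bytes)."""
    seconds = kernel_time_us * 1e-6
    return total_bytes / seconds / 1e9


# 2 GB of traffic in 1000 us of kernel time -> 2000 GB/s
print(round(bandwidth_gbs(2_000_000_000, 1000.0), 1))  # 2000.0
```

For memory-bound normalization kernels this number is the natural quality metric: how close the kernel gets to the device's peak DRAM bandwidth.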
benchmarks/run_cases.py
CHANGED
|
@@ -12,6 +12,17 @@ from common.bench_framework import (make_bwd_benchmark_for_case,
|
|
| 12 |
from common.diff_engine import DiffCase, calculate_diff
|
| 13 |
|
| 14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
def make_title_tag():
|
| 16 |
if torch.cuda.is_available():
|
| 17 |
dev_name = torch.cuda.get_device_name(0)
|
|
@@ -66,11 +77,6 @@ def main():
|
|
| 66 |
default="bf16",
|
| 67 |
help="Data type for benchmarking (default: bf16)",
|
| 68 |
)
|
| 69 |
-
ap.add_argument(
|
| 70 |
-
"--profile",
|
| 71 |
-
action="store_true",
|
| 72 |
-
help="Export chrome traces for backward benchmarks",
|
| 73 |
-
)
|
| 74 |
args = ap.parse_args()
|
| 75 |
|
| 76 |
dtype_map = {
|
|
@@ -170,27 +176,38 @@ def main():
|
|
| 170 |
itertools.product(dim, batch_size_range, seq_length_range))
|
| 171 |
|
| 172 |
if is_grouped:
|
| 173 |
-
fwd_line_vals = ("naive", "
|
|
|
|
| 174 |
fwd_line_names = {
|
| 175 |
-
"naive": "Naive",
|
| 176 |
-
"
|
| 177 |
-
"
|
| 178 |
-
"
|
|
|
|
|
|
|
|
|
|
| 179 |
}
|
| 180 |
-
bwd_line_vals = ("naive", "
|
| 181 |
-
"
|
|
|
|
| 182 |
bwd_line_names = {
|
| 183 |
-
"naive": "Naive",
|
| 184 |
-
"
|
| 185 |
-
"
|
| 186 |
-
"
|
|
|
|
|
|
|
|
|
|
| 187 |
}
|
| 188 |
else:
|
| 189 |
-
fwd_line_vals = ("naive", "cuda", "
|
|
|
|
| 190 |
fwd_line_names = {
|
| 191 |
-
"naive": "Naive",
|
| 192 |
-
"
|
| 193 |
-
"
|
|
|
|
|
|
|
| 194 |
}
|
| 195 |
bwd_line_vals = fwd_line_vals
|
| 196 |
bwd_line_names = fwd_line_names
|
|
@@ -202,11 +219,12 @@ def main():
|
|
| 202 |
dtype=dtype,
|
| 203 |
line_vals=fwd_line_vals,
|
| 204 |
line_names=fwd_line_names,
|
| 205 |
-
profile=args.profile,
|
| 206 |
-
profile_dir=os.path.join(save_dir, "traces"),
|
| 207 |
)
|
| 208 |
|
| 209 |
-
|
|
|
|
|
|
|
|
|
|
| 210 |
|
| 211 |
bench = make_bwd_benchmark_for_case(
|
| 212 |
case=case,
|
|
@@ -215,11 +233,12 @@ def main():
|
|
| 215 |
dtype=dtype,
|
| 216 |
line_vals=bwd_line_vals,
|
| 217 |
line_names=bwd_line_names,
|
| 218 |
-
profile=args.profile,
|
| 219 |
-
profile_dir=os.path.join(save_dir, "traces"),
|
| 220 |
)
|
| 221 |
|
| 222 |
-
|
|
|
|
|
|
|
|
|
|
| 223 |
for f in glob.glob(os.path.join(save_dir, "*.html")) + \
|
| 224 |
glob.glob(os.path.join(save_dir, "*.png")):
|
| 225 |
os.remove(f)
|
|
|
|
| 12 |
from common.diff_engine import DiffCase, calculate_diff
|
| 13 |
|
| 14 |
|
| 15 |
+
def _clean_and_print_csv(csv_path, title):
|
| 16 |
+
"""Remove trailing ' ()' from CSV column names and print the table."""
|
| 17 |
+
import pandas as pd
|
| 18 |
+
df = pd.read_csv(csv_path)
|
| 19 |
+
df.columns = [c.replace(" ()", "") for c in df.columns]
|
| 20 |
+
df.to_csv(csv_path, index=False)
|
| 21 |
+
print(f"{title}:")
|
| 22 |
+
print(df.to_string(index=False))
|
| 23 |
+
print()
|
| 24 |
+
|
| 25 |
+
|
| 26 |
def make_title_tag():
|
| 27 |
if torch.cuda.is_available():
|
| 28 |
dev_name = torch.cuda.get_device_name(0)
|
|
|
|
| 77 |
default="bf16",
|
| 78 |
help="Data type for benchmarking (default: bf16)",
|
| 79 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
args = ap.parse_args()

    dtype_map = {

        itertools.product(dim, batch_size_range, seq_length_range))

    if is_grouped:
+       fwd_line_vals = ("naive", "naive_bw", "compiled",
+                        "compiled_bw", "cuda", "cuda_bw", "speedup")
        fwd_line_names = {
+           "naive": "Naive (us)",
+           "naive_bw": "Naive (GB/s)",
+           "compiled": "Compiled (us)",
+           "compiled_bw": "Compiled (GB/s)",
+           "cuda": "CUDA (us)",
+           "cuda_bw": "CUDA (GB/s)",
+           "speedup": "SpeedUp (ratio)",
        }
+       bwd_line_vals = ("naive", "naive_bw", "compiled",
+                        "compiled_bw", "compiled_cuda",
+                        "compiled_cuda_bw", "speedup")
        bwd_line_names = {
+           "naive": "Naive (us)",
+           "naive_bw": "Naive (GB/s)",
+           "compiled": "Compiled (us)",
+           "compiled_bw": "Compiled (GB/s)",
+           "compiled_cuda": "CompiledCUDA (us)",
+           "compiled_cuda_bw": "CompiledCUDA (GB/s)",
+           "speedup": "SpeedUp (ratio)",
        }
    else:
+       fwd_line_vals = ("naive", "naive_bw", "cuda", "cuda_bw",
+                        "speedup")
        fwd_line_names = {
+           "naive": "Naive (us)",
+           "naive_bw": "Naive (GB/s)",
+           "cuda": "CUDA (us)",
+           "cuda_bw": "CUDA (GB/s)",
+           "speedup": "SpeedUp (ratio)",
        }
        bwd_line_vals = fwd_line_vals
        bwd_line_names = fwd_line_names

        dtype=dtype,
        line_vals=fwd_line_vals,
        line_names=fwd_line_names,
    )

+   fwd_name = f"{args.case}-{dtype_name}-fwd-perf"
+   bench.run(print_data=False, save_path=save_dir)
+   _clean_and_print_csv(os.path.join(save_dir, fwd_name + ".csv"),
+                        fwd_name)

    bench = make_bwd_benchmark_for_case(
        case=case,
        dtype=dtype,
        line_vals=bwd_line_vals,
        line_names=bwd_line_names,
    )

+   bwd_name = f"{args.case}-{dtype_name}-bwd-perf"
+   bench.run(print_data=False, save_path=save_dir)
+   _clean_and_print_csv(os.path.join(save_dir, bwd_name + ".csv"),
+                        bwd_name)
    for f in glob.glob(os.path.join(save_dir, "*.html")) + \
            glob.glob(os.path.join(save_dir, "*.png")):
        os.remove(f)
docs/benchmark.md
ADDED
@@ -0,0 +1,146 @@
# Benchmark System

## Overview

The benchmark system measures forward/backward performance of the custom CUDA kernels against naive PyTorch implementations. It uses Triton's `do_bench` and runs the performance measurement only after a correctness check passes.

## Directory Structure

```
benchmarks/
├── run_cases.py              # CLI entry point
├── common/
│   ├── bench_framework.py    # Benchmark utilities (built on Triton perf_report)
│   └── diff_engine.py        # Correctness-check engine (DiffCase ABC)
├── cases/                    # Benchmark case implementations
│   ├── rms.py                # RMSNorm
│   ├── add_rms.py            # Fused Add + RMSNorm
│   ├── poly.py               # PolyNorm
│   ├── mul_poly.py           # Fused Mul + PolyNorm
│   └── grouped_mul_poly.py   # Grouped MoE Fused Mul + PolyNorm
├── benchmark.yaml            # Kubeflow benchmark job config
├── test.yaml                 # Kubeflow test job config
├── plots/                    # Generated plots
└── results/                  # Timestamped benchmark results
```

## Usage

```bash
python benchmarks/run_cases.py --case <CASE> [OPTIONS]
```

### Arguments

| Argument | Default | Choices | Description |
|----------|---------|---------|-------------|
| `--case` | (required) | `rms`, `add_rms`, `poly`, `mul_poly`, `grouped_mul_poly` | Benchmark case |
| `--dtype` | `bf16` | `fp16`, `bf16`, `fp32`, `all` | Data type |
| `--save-path` | `./configs/` | path | Output directory for results |
| `--plot` | false | - | Plot generation mode |
| `--profile` | false | - | Export Chrome trace profiles |

### Examples

```bash
# Default bf16 benchmark
python benchmarks/run_cases.py --case grouped_mul_poly

# All dtypes + profiling
python benchmarks/run_cases.py --case mul_poly --dtype all --profile --save-path ./results

# Generate plots only
python benchmarks/run_cases.py --case rms --plot --save-path ./plots
```

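The `--dtype` choices are resolved to torch dtypes before the sweep runs. A plausible reconstruction of that mapping is shown below; the names `DTYPE_MAP` and `resolve_dtypes` are hypothetical, the real `dtype_map` lives in `benchmarks/run_cases.py`.

```python
import torch

# Hypothetical reconstruction of run_cases.py's dtype table.
DTYPE_MAP = {
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
    "fp32": torch.float32,
}

def resolve_dtypes(flag: str) -> dict:
    """Expand the CLI --dtype value into {name: torch.dtype} pairs."""
    if flag == "all":
        return dict(DTYPE_MAP)
    return {flag: DTYPE_MAP[flag]}
```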
## Benchmark Cases

Each case implements the `DiffCase` ABC and compares a naive (PyTorch reference) implementation against the CUDA kernel.

| Case | Naive | CUDA | Inputs |
|------|-------|------|--------|
| `rms` | `torch.nn.RMSNorm` | `activation.layers.RMSNorm` | x, weight, eps |
| `add_rms` | custom `FusedAddRMSNorm` | `activation.layers.FusedAddRMSNorm` | x, residual, weight, eps |
| `poly` | custom `PolyNorm` (combination of x^3, x^2, x) | `activation.layers.PolyNorm` | x, weight(3), bias(1), eps |
| `mul_poly` | custom `FusedMulPolyNorm` | `activation.layers.FusedMulPolyNorm` | x, mul, weight(3), bias, eps |
| `grouped_mul_poly` | `fused_mul_grouped_poly_norm_ref` | `fused_mul_grouped_poly_norm` | x, mul, weight(num_experts, 3), bias, offsets |

`grouped_mul_poly` additionally measures the `compiled` (torch.compile-d naive) and `compiled_cuda` (torch.compile-d CUDA) providers.

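For orientation, a naive reference in the spirit of the table's `custom PolyNorm` row might look like the sketch below. It assumes PolyNorm is a weighted sum of RMS-normalized powers of x; the repo's own naive implementations in `benchmarks/cases/` are authoritative, and `rms_norm`/`poly_norm_ref` here are illustrative names.

```python
import torch

def rms_norm(u, eps=1e-6):
    # RMS-normalize over the hidden (last) dimension.
    return u * torch.rsqrt(u.pow(2).mean(dim=-1, keepdim=True) + eps)

def poly_norm_ref(x, weight, bias, eps=1e-6):
    """Sketch of a naive PolyNorm: weighted RMS-normalized powers.

    `weight` has 3 elements (one per power) and `bias` has 1, matching
    the `weight(3), bias(1)` inputs listed in the table above.
    """
    return (weight[0] * rms_norm(x ** 3, eps)
            + weight[1] * rms_norm(x ** 2, eps)
            + weight[2] * rms_norm(x, eps)
            + bias)
```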
## Execution Flow

1. **Correctness check** - runs `calculate_diff()` for three configs:
   - `(bs=2, sl=128, hidden=4096)`
   - `(bs=8, sl=4096, hidden=1280)`
   - `(bs=1, sl=32768, hidden=1280)`
   - both forward and backward are compared with `atol=1e-2, rtol=1e-2`
2. **Benchmark run** - measures forward/backward performance per dtype
3. **Save results** - CSV files (and optionally plots/traces)

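The correctness step boils down to comparing forward outputs and input gradients of two implementations at the stated tolerances. `check_close` below is an illustrative stand-in for `calculate_diff()`, not its actual signature:

```python
import torch

def check_close(naive_fn, fast_fn, x, atol=1e-2, rtol=1e-2):
    """Assert forward outputs and input gradients of two impls match."""
    x1 = x.clone().requires_grad_(True)
    x2 = x.clone().requires_grad_(True)
    y1, y2 = naive_fn(x1), fast_fn(x2)
    # Forward comparison.
    torch.testing.assert_close(y1, y2, atol=atol, rtol=rtol)
    # Backward comparison with a shared upstream gradient.
    g = torch.randn_like(y1)
    y1.backward(g)
    y2.backward(g)
    torch.testing.assert_close(x1.grad, x2.grad, atol=atol, rtol=rtol)
```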
## Configuration Ranges

**Standard cases** (rms, add_rms, poly, mul_poly):
- Batch sizes: 1, 2, 4, 8
- Sequence lengths: 1024, 2048, 4096, 8192
- Hidden dims: 2048, 4096

**Grouped case** (grouped_mul_poly):
- Total tokens: 1024 ~ 65536 (bs x sl)
- Hidden dim: 1280 (fixed)
- Experts: 48 per rank

In `--plot` mode, `bs` is fixed to 1 and only seq_len is swept.

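The standard-case sweep amounts to a cross product of the three ranges, as in this sketch (mirroring the `itertools.product` call in `run_cases.py`):

```python
import itertools

# Standard-case ranges from the list above.
batch_sizes = [1, 2, 4, 8]
seq_lens = [1024, 2048, 4096, 8192]
hidden_dims = [2048, 4096]

# One benchmark config per (dim, batch_size, seq_len) combination.
configs = list(itertools.product(hidden_dims, batch_sizes, seq_lens))
```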
## Output

### CSV

Saved under `{save_path}/{case}/{dtype}/`:

- `{case}-{dtype}-fwd-perf.csv` - forward results
- `{case}-{dtype}-bwd-perf.csv` - backward results

Columns: `dim`, `batch_size`, `seq_len`, `Naive (us)`, `Naive (GB/s)`, `Compiled (us)`, `Compiled (GB/s)`, `CUDA (us)`, `CUDA (GB/s)`, `SpeedUp (ratio)`

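The GB/s and SpeedUp columns are derived quantities. The arithmetic can be sketched as below, assuming a simple traffic model of one read plus one write per element; the framework's actual byte accounting may differ, and both helper names are illustrative.

```python
def gbps(num_elems: int, bytes_per_elem: int, time_us: float,
         tensors_moved: int = 2) -> float:
    """Effective bandwidth in GB/s.

    Assumes `tensors_moved` full passes over `num_elems` elements,
    e.g. one read of x plus one write of y.
    """
    total_bytes = num_elems * bytes_per_elem * tensors_moved
    return total_bytes / (time_us * 1e-6) / 1e9

def speedup(naive_us: float, cuda_us: float) -> float:
    """SpeedUp (ratio) column: how many times faster the CUDA kernel is."""
    return naive_us / cuda_us
```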
### Chrome Trace (`--profile`)

Saved as JSON under `{save_path}/{case}/{dtype}/traces/`. Load a trace in `chrome://tracing` to inspect the GPU timeline.

Filename pattern: `trace_{fwd|bwd}_{naive|compiled|cuda|compiled_cuda}_N{total_tokens}.json`

### Plot (`--plot`)

Generates speedup comparison plots. The overall speedup is aggregated as a geometric mean.

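The geometric-mean aggregation used for the overall speedup can be sketched as:

```python
import math

def geomean(ratios):
    """Aggregate per-config speedup ratios into one number.

    The geometric mean is the standard choice for ratios: a 2x win and
    a 2x loss cancel out to exactly 1.0.
    """
    logs = [math.log(r) for r in ratios]
    return math.exp(sum(logs) / len(logs))
```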
## Framework Internals

### bench_framework.py

Four factory functions built on Triton's `perf_report`/`Benchmark`:

- `make_fwd_benchmark_for_case()` - forward benchmark (CSV)
- `make_bwd_benchmark_for_case()` - backward benchmark (CSV)
- `make_fwd_benchmark_plot_for_case()` - forward plot
- `make_bwd_benchmark_plot_for_case()` - backward plot

Timing is measured with `triton.testing.do_bench()`; milliseconds are converted to microseconds (`time_unit_scale=1000`).

### diff_engine.py

The `DiffCase` ABC interface:

- `build_inputs(bs, sl, dim)` - creates the input tensors
- `make_naive()` / `make_cuda()` - construct the implementations
- `forward(module, inputs)` - runs the forward pass
- `grad_inputs(inputs)` - returns the gradient target tensors

`calculate_diff()` compares the forward outputs and backward gradients of the naive and CUDA implementations with `torch.testing.assert_close()`.

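The interface above can be sketched as an ABC skeleton; the method names come from this list, but the exact signatures in `diff_engine.py` may differ.

```python
from abc import ABC, abstractmethod

class DiffCase(ABC):
    """Sketch of the benchmark-case interface described above."""

    @abstractmethod
    def build_inputs(self, bs, sl, dim):
        """Create the input tensors for one (bs, sl, dim) config."""

    @abstractmethod
    def make_naive(self):
        """Construct the PyTorch reference implementation."""

    @abstractmethod
    def make_cuda(self):
        """Construct the CUDA-kernel implementation."""

    @abstractmethod
    def forward(self, module, inputs):
        """Run one forward pass of `module` on `inputs`."""

    @abstractmethod
    def grad_inputs(self, inputs):
        """Return the tensors whose gradients should be compared."""
```

A concrete case such as `rms` or `grouped_mul_poly` would subclass this and fill in all five methods.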
## Kubeflow Integration

Benchmarks can be run on the cluster via `benchmark.yaml`, which:

- installs triton, matplotlib, and pandas
- builds the C++ extension (`setup.py`)
- warms up the GPU (100 matmul iterations)
- saves results to `benchmarks/results/{YY_MM_DD_HH_MM}/`