diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..473e5188876eafc7234e038dc311d6a4a1879251 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,12 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +build/torch27-cxx11-cu118-x86_64-linux/rwkv/_rwkv_eb0e3e5_dirty.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch27-cxx11-cu126-x86_64-linux/rwkv/_rwkv_eb0e3e5_dirty.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch27-cxx11-cu128-x86_64-linux/rwkv/_rwkv_eb0e3e5_dirty.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch28-cxx11-cu126-x86_64-linux/rwkv/_rwkv_eb0e3e5_dirty.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch28-cxx11-cu128-x86_64-linux/rwkv/_rwkv_eb0e3e5_dirty.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch28-cxx11-cu129-x86_64-linux/rwkv/_rwkv_eb0e3e5_dirty.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch29-cxx11-cu126-x86_64-linux/rwkv/_rwkv_eb0e3e5_dirty.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch29-cxx11-cu128-x86_64-linux/rwkv/_rwkv_eb0e3e5_dirty.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch29-cxx11-cu130-x86_64-linux/rwkv/_rwkv_eb0e3e5_dirty.abi3.so filter=lfs diff=lfs merge=lfs -text diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..f080903665d8a163e89bb402140dfd73f5f86b16 --- /dev/null +++ b/README.md @@ -0,0 +1,6 @@ +--- +tags: + - kernel +--- + +RWKV kernel for transformers \ No newline at end of file diff --git a/build.toml b/build.toml new file mode 100644 index 0000000000000000000000000000000000000000..212d2cdef56c5edd6e5a38954dba357f181e69f8 --- /dev/null +++ b/build.toml @@ -0,0 +1,31 @@ +[general] +name = "rwkv" +universal = false + +[torch] +src = [ + "torch-ext/torch_binding.cpp", +] + +[kernel.rwkv] +depends = ["torch"] +backend = "cuda" 
+cuda-capabilities = [ + "8.0", + "8.9", + "9.0", + "10.0", + "12.0", +] +include = ["."] +src = [ + "rwkv/wkv_cuda.cu", + "rwkv/wkv_cuda_bf16.cu", +] +cuda-flags = [ + "-res-usage", + "--use_fast_math", + "-O3", + "--extra-device-vectorization", + "-DTmax=1024", +] diff --git a/build/torch27-cxx11-cu118-x86_64-linux/rwkv/__init__.py b/build/torch27-cxx11-cu118-x86_64-linux/rwkv/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d4ea22dbc160d2b75b32c3c8c5d882d25c77b494 --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/rwkv/__init__.py @@ -0,0 +1,170 @@ +from ._ops import ops +from typing import Tuple, Any + +# Use a broad Tensor alias to avoid importing torch at import time. +from torch import Tensor + +def forward(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None: + """RWKV WKV forward pass (float32). + + Runs the CUDA kernel and writes the result into ``y`` in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y)) + ops.forward(w, u, k, v, y) + + +def forward_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None: + """RWKV WKV forward pass (bfloat16 inputs/outputs, float32 ``w``). + + Runs the CUDA kernel and writes the result into ``y`` in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. 
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y)) + ops.forward_bf16(w, u, k, v, y) + + +def forward_with_state(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None: + """RWKV WKV forward pass with persistent state (float32). + + Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs. + """ + _validate_device_match((w, u, k, v, y, s)) + ops.forward_with_state(w, u, k, v, y, s) + + +def forward_with_state_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None: + """RWKV WKV forward pass with persistent state (bfloat16 inputs/outputs, float32 ``w`` and ``s``). + + Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place). + + Notes: + - All tensors must be on the same CUDA device. 
+ - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs. + """ + _validate_device_match((w, u, k, v, y, s)) + ops.forward_with_state_bf16(w, u, k, v, y, s) + + +def backward( + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + y: Tensor, + gy: Tensor, + gw: Tensor, + gu: Tensor, + gk: Tensor, + gv: Tensor, +) -> None: + """RWKV WKV backward pass (float32). + + Writes gradients into the provided tensors in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.float32``. + gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.float32``. + gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.float32`` (written in-place). + gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv)) + ops.backward(w, u, k, v, y, gy, gw, gu, gk, gv) + + +def backward_bf16( + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + y: Tensor, + gy: Tensor, + gw: Tensor, + gu: Tensor, + gk: Tensor, + gv: Tensor, +) -> None: + """RWKV WKV backward pass (bfloat16 inputs/outputs/gradients, float32 ``w``). + + Writes gradients into the provided tensors in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.bfloat16`` (written in-place). + gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. 
+ """ + _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv)) + ops.backward_bf16(w, u, k, v, y, gy, gw, gu, gk, gv) + + +def _validate_device_match(tensors: Tuple[Tensor, ...]) -> None: + """Minimal runtime validation that all tensors live on the same CUDA device.""" + if not tensors: + return + device = tensors[0].device + if not device.type == "cuda": + raise RuntimeError("RWKV CUDA ops require CUDA tensors") + for t in tensors[1:]: + if t.device != device: + raise RuntimeError("All tensors must be on the same CUDA device") + + +__all__ = [ + "forward", + "forward_bf16", + "forward_with_state", + "forward_with_state_bf16", + "backward", + "backward_bf16", +] \ No newline at end of file diff --git a/build/torch27-cxx11-cu118-x86_64-linux/rwkv/__pycache__/__init__.cpython-313.pyc b/build/torch27-cxx11-cu118-x86_64-linux/rwkv/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7ac87a592935cf8e08a689c9e7ba18f0969489b5 Binary files /dev/null and b/build/torch27-cxx11-cu118-x86_64-linux/rwkv/__pycache__/__init__.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu118-x86_64-linux/rwkv/__pycache__/_ops.cpython-313.pyc b/build/torch27-cxx11-cu118-x86_64-linux/rwkv/__pycache__/_ops.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9489d4f981fa954d00cd0bf1495187b40c69ad7a Binary files /dev/null and b/build/torch27-cxx11-cu118-x86_64-linux/rwkv/__pycache__/_ops.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu118-x86_64-linux/rwkv/_ops.py b/build/torch27-cxx11-cu118-x86_64-linux/rwkv/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..d5289450131eaf881db14496443314079fd4427d --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/rwkv/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _rwkv_eb0e3e5_dirty +ops = torch.ops._rwkv_eb0e3e5_dirty + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_rwkv_eb0e3e5_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch27-cxx11-cu118-x86_64-linux/rwkv/_rwkv_eb0e3e5_dirty.abi3.so b/build/torch27-cxx11-cu118-x86_64-linux/rwkv/_rwkv_eb0e3e5_dirty.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..ffede8b92cd6bbfb6e37223691c462c3c692fcee --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/rwkv/_rwkv_eb0e3e5_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:25632a613591ab66c83b18aeda5bd01f4ba117e34345efdb6191fe501be170cf +size 2065424 diff --git a/build/torch27-cxx11-cu126-x86_64-linux/rwkv/__init__.py b/build/torch27-cxx11-cu126-x86_64-linux/rwkv/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d4ea22dbc160d2b75b32c3c8c5d882d25c77b494 --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/rwkv/__init__.py @@ -0,0 +1,170 @@ +from ._ops import ops +from typing import Tuple, Any + +# Use a broad Tensor alias to avoid importing torch at import time. +from torch import Tensor + +def forward(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None: + """RWKV WKV forward pass (float32). + + Runs the CUDA kernel and writes the result into ``y`` in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y)) + ops.forward(w, u, k, v, y) + + +def forward_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None: + """RWKV WKV forward pass (bfloat16 inputs/outputs, float32 ``w``). 
+ + Runs the CUDA kernel and writes the result into ``y`` in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y)) + ops.forward_bf16(w, u, k, v, y) + + +def forward_with_state(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None: + """RWKV WKV forward pass with persistent state (float32). + + Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs. + """ + _validate_device_match((w, u, k, v, y, s)) + ops.forward_with_state(w, u, k, v, y, s) + + +def forward_with_state_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None: + """RWKV WKV forward pass with persistent state (bfloat16 inputs/outputs, float32 ``w`` and ``s``). + + Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. 
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs. + """ + _validate_device_match((w, u, k, v, y, s)) + ops.forward_with_state_bf16(w, u, k, v, y, s) + + +def backward( + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + y: Tensor, + gy: Tensor, + gw: Tensor, + gu: Tensor, + gk: Tensor, + gv: Tensor, +) -> None: + """RWKV WKV backward pass (float32). + + Writes gradients into the provided tensors in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.float32``. + gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.float32``. + gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.float32`` (written in-place). + gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv)) + ops.backward(w, u, k, v, y, gy, gw, gu, gk, gv) + + +def backward_bf16( + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + y: Tensor, + gy: Tensor, + gw: Tensor, + gu: Tensor, + gk: Tensor, + gv: Tensor, +) -> None: + """RWKV WKV backward pass (bfloat16 inputs/outputs/gradients, float32 ``w``). + + Writes gradients into the provided tensors in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.bfloat16``. 
+ gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.bfloat16`` (written in-place). + gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv)) + ops.backward_bf16(w, u, k, v, y, gy, gw, gu, gk, gv) + + +def _validate_device_match(tensors: Tuple[Tensor, ...]) -> None: + """Minimal runtime validation that all tensors live on the same CUDA device.""" + if not tensors: + return + device = tensors[0].device + if not device.type == "cuda": + raise RuntimeError("RWKV CUDA ops require CUDA tensors") + for t in tensors[1:]: + if t.device != device: + raise RuntimeError("All tensors must be on the same CUDA device") + + +__all__ = [ + "forward", + "forward_bf16", + "forward_with_state", + "forward_with_state_bf16", + "backward", + "backward_bf16", +] \ No newline at end of file diff --git a/build/torch27-cxx11-cu126-x86_64-linux/rwkv/__pycache__/__init__.cpython-313.pyc b/build/torch27-cxx11-cu126-x86_64-linux/rwkv/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5f6b80a875a0ede7f4d1b513ca07a476f585ee05 Binary files /dev/null and b/build/torch27-cxx11-cu126-x86_64-linux/rwkv/__pycache__/__init__.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu126-x86_64-linux/rwkv/__pycache__/_ops.cpython-313.pyc b/build/torch27-cxx11-cu126-x86_64-linux/rwkv/__pycache__/_ops.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6c7a791bb451d6c6913b4b8534687db9914ab043 Binary files /dev/null and b/build/torch27-cxx11-cu126-x86_64-linux/rwkv/__pycache__/_ops.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu126-x86_64-linux/rwkv/_ops.py b/build/torch27-cxx11-cu126-x86_64-linux/rwkv/_ops.py new file mode 100644 index 
0000000000000000000000000000000000000000..d5289450131eaf881db14496443314079fd4427d --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/rwkv/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _rwkv_eb0e3e5_dirty +ops = torch.ops._rwkv_eb0e3e5_dirty + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_rwkv_eb0e3e5_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch27-cxx11-cu126-x86_64-linux/rwkv/_rwkv_eb0e3e5_dirty.abi3.so b/build/torch27-cxx11-cu126-x86_64-linux/rwkv/_rwkv_eb0e3e5_dirty.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..36d1c57cb98c85c0c681f825476c90d1dac5e330 --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/rwkv/_rwkv_eb0e3e5_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c96331e3a863df3e5fc2eee8fe0916f036a51e5c73a03bba2169556a945db2d7 +size 2106440 diff --git a/build/torch27-cxx11-cu128-x86_64-linux/rwkv/__init__.py b/build/torch27-cxx11-cu128-x86_64-linux/rwkv/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d4ea22dbc160d2b75b32c3c8c5d882d25c77b494 --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/rwkv/__init__.py @@ -0,0 +1,170 @@ +from ._ops import ops +from typing import Tuple, Any + +# Use a broad Tensor alias to avoid importing torch at import time. +from torch import Tensor + +def forward(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None: + """RWKV WKV forward pass (float32). + + Runs the CUDA kernel and writes the result into ``y`` in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. 
+ - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y)) + ops.forward(w, u, k, v, y) + + +def forward_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None: + """RWKV WKV forward pass (bfloat16 inputs/outputs, float32 ``w``). + + Runs the CUDA kernel and writes the result into ``y`` in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y)) + ops.forward_bf16(w, u, k, v, y) + + +def forward_with_state(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None: + """RWKV WKV forward pass with persistent state (float32). + + Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs. 
+ """ + _validate_device_match((w, u, k, v, y, s)) + ops.forward_with_state(w, u, k, v, y, s) + + +def forward_with_state_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None: + """RWKV WKV forward pass with persistent state (bfloat16 inputs/outputs, float32 ``w`` and ``s``). + + Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs. + """ + _validate_device_match((w, u, k, v, y, s)) + ops.forward_with_state_bf16(w, u, k, v, y, s) + + +def backward( + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + y: Tensor, + gy: Tensor, + gw: Tensor, + gu: Tensor, + gk: Tensor, + gv: Tensor, +) -> None: + """RWKV WKV backward pass (float32). + + Writes gradients into the provided tensors in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.float32``. + gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.float32``. + gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.float32`` (written in-place). + gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. 
+ """ + _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv)) + ops.backward(w, u, k, v, y, gy, gw, gu, gk, gv) + + +def backward_bf16( + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + y: Tensor, + gy: Tensor, + gw: Tensor, + gu: Tensor, + gk: Tensor, + gv: Tensor, +) -> None: + """RWKV WKV backward pass (bfloat16 inputs/outputs/gradients, float32 ``w``). + + Writes gradients into the provided tensors in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.bfloat16`` (written in-place). + gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv)) + ops.backward_bf16(w, u, k, v, y, gy, gw, gu, gk, gv) + + +def _validate_device_match(tensors: Tuple[Tensor, ...]) -> None: + """Minimal runtime validation that all tensors live on the same CUDA device.""" + if not tensors: + return + device = tensors[0].device + if not device.type == "cuda": + raise RuntimeError("RWKV CUDA ops require CUDA tensors") + for t in tensors[1:]: + if t.device != device: + raise RuntimeError("All tensors must be on the same CUDA device") + + +__all__ = [ + "forward", + "forward_bf16", + "forward_with_state", + "forward_with_state_bf16", + "backward", + "backward_bf16", +] \ No newline at end of file diff --git a/build/torch27-cxx11-cu128-x86_64-linux/rwkv/__pycache__/__init__.cpython-313.pyc b/build/torch27-cxx11-cu128-x86_64-linux/rwkv/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ce97b6a71dac35e7342e04fefb1193201b1a9a03 Binary files /dev/null and 
b/build/torch27-cxx11-cu128-x86_64-linux/rwkv/__pycache__/__init__.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu128-x86_64-linux/rwkv/__pycache__/_ops.cpython-313.pyc b/build/torch27-cxx11-cu128-x86_64-linux/rwkv/__pycache__/_ops.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..34af42a2d6d8ceac2b85d819e968afda6c568840 Binary files /dev/null and b/build/torch27-cxx11-cu128-x86_64-linux/rwkv/__pycache__/_ops.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu128-x86_64-linux/rwkv/_ops.py b/build/torch27-cxx11-cu128-x86_64-linux/rwkv/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..d5289450131eaf881db14496443314079fd4427d --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/rwkv/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _rwkv_eb0e3e5_dirty +ops = torch.ops._rwkv_eb0e3e5_dirty + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_rwkv_eb0e3e5_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch27-cxx11-cu128-x86_64-linux/rwkv/_rwkv_eb0e3e5_dirty.abi3.so b/build/torch27-cxx11-cu128-x86_64-linux/rwkv/_rwkv_eb0e3e5_dirty.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..c1c287ea4bba97a2726552c02f24a10c9e9a4ef2 --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/rwkv/_rwkv_eb0e3e5_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9b1ec16445438d1b9ec59f7efab20880b58d774f59c06abab5ae964d164d5fa0 +size 2308880 diff --git a/build/torch28-cxx11-cu126-x86_64-linux/rwkv/__init__.py b/build/torch28-cxx11-cu126-x86_64-linux/rwkv/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d4ea22dbc160d2b75b32c3c8c5d882d25c77b494 --- /dev/null +++ b/build/torch28-cxx11-cu126-x86_64-linux/rwkv/__init__.py @@ -0,0 +1,170 @@ +from ._ops import ops +from typing import Tuple, Any + +# Use a broad Tensor alias to avoid importing torch at 
import time. +from torch import Tensor + +def forward(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None: + """RWKV WKV forward pass (float32). + + Runs the CUDA kernel and writes the result into ``y`` in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y)) + ops.forward(w, u, k, v, y) + + +def forward_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None: + """RWKV WKV forward pass (bfloat16 inputs/outputs, float32 ``w``). + + Runs the CUDA kernel and writes the result into ``y`` in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y)) + ops.forward_bf16(w, u, k, v, y) + + +def forward_with_state(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None: + """RWKV WKV forward pass with persistent state (float32). + + Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``. 
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs. + """ + _validate_device_match((w, u, k, v, y, s)) + ops.forward_with_state(w, u, k, v, y, s) + + +def forward_with_state_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None: + """RWKV WKV forward pass with persistent state (bfloat16 inputs/outputs, float32 ``w`` and ``s``). + + Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs. + """ + _validate_device_match((w, u, k, v, y, s)) + ops.forward_with_state_bf16(w, u, k, v, y, s) + + +def backward( + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + y: Tensor, + gy: Tensor, + gw: Tensor, + gu: Tensor, + gk: Tensor, + gv: Tensor, +) -> None: + """RWKV WKV backward pass (float32). + + Writes gradients into the provided tensors in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.float32``. + gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.float32``. 
+ gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.float32`` (written in-place). + gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv)) + ops.backward(w, u, k, v, y, gy, gw, gu, gk, gv) + + +def backward_bf16( + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + y: Tensor, + gy: Tensor, + gw: Tensor, + gu: Tensor, + gk: Tensor, + gv: Tensor, +) -> None: + """RWKV WKV backward pass (bfloat16 inputs/outputs/gradients, float32 ``w``). + + Writes gradients into the provided tensors in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.bfloat16`` (written in-place). + gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. 
+ """ + _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv)) + ops.backward_bf16(w, u, k, v, y, gy, gw, gu, gk, gv) + + +def _validate_device_match(tensors: Tuple[Tensor, ...]) -> None: + """Minimal runtime validation that all tensors live on the same CUDA device.""" + if not tensors: + return + device = tensors[0].device + if not device.type == "cuda": + raise RuntimeError("RWKV CUDA ops require CUDA tensors") + for t in tensors[1:]: + if t.device != device: + raise RuntimeError("All tensors must be on the same CUDA device") + + +__all__ = [ + "forward", + "forward_bf16", + "forward_with_state", + "forward_with_state_bf16", + "backward", + "backward_bf16", +] \ No newline at end of file diff --git a/build/torch28-cxx11-cu126-x86_64-linux/rwkv/__pycache__/__init__.cpython-313.pyc b/build/torch28-cxx11-cu126-x86_64-linux/rwkv/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..329b5b131b38106776ea86e7fe3c0d1e4f483224 Binary files /dev/null and b/build/torch28-cxx11-cu126-x86_64-linux/rwkv/__pycache__/__init__.cpython-313.pyc differ diff --git a/build/torch28-cxx11-cu126-x86_64-linux/rwkv/__pycache__/_ops.cpython-313.pyc b/build/torch28-cxx11-cu126-x86_64-linux/rwkv/__pycache__/_ops.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..243aa6f67725e4bc9d5637a02d4636dd3b886cce Binary files /dev/null and b/build/torch28-cxx11-cu126-x86_64-linux/rwkv/__pycache__/_ops.cpython-313.pyc differ diff --git a/build/torch28-cxx11-cu126-x86_64-linux/rwkv/_ops.py b/build/torch28-cxx11-cu126-x86_64-linux/rwkv/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..d5289450131eaf881db14496443314079fd4427d --- /dev/null +++ b/build/torch28-cxx11-cu126-x86_64-linux/rwkv/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _rwkv_eb0e3e5_dirty +ops = torch.ops._rwkv_eb0e3e5_dirty + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_rwkv_eb0e3e5_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-cu126-x86_64-linux/rwkv/_rwkv_eb0e3e5_dirty.abi3.so b/build/torch28-cxx11-cu126-x86_64-linux/rwkv/_rwkv_eb0e3e5_dirty.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..63a7748f96ea9cdd0193953815cb2e40f95dac9a --- /dev/null +++ b/build/torch28-cxx11-cu126-x86_64-linux/rwkv/_rwkv_eb0e3e5_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:00476bbd1c8b08d8ad3fef06cb12e0a86900830ad4f72dd4ea400131f432d510 +size 2106464 diff --git a/build/torch28-cxx11-cu128-x86_64-linux/rwkv/__init__.py b/build/torch28-cxx11-cu128-x86_64-linux/rwkv/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d4ea22dbc160d2b75b32c3c8c5d882d25c77b494 --- /dev/null +++ b/build/torch28-cxx11-cu128-x86_64-linux/rwkv/__init__.py @@ -0,0 +1,170 @@ +from ._ops import ops +from typing import Tuple, Any + +# Use a broad Tensor alias to avoid importing torch at import time. +from torch import Tensor + +def forward(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None: + """RWKV WKV forward pass (float32). + + Runs the CUDA kernel and writes the result into ``y`` in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y)) + ops.forward(w, u, k, v, y) + + +def forward_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None: + """RWKV WKV forward pass (bfloat16 inputs/outputs, float32 ``w``). 
+ + Runs the CUDA kernel and writes the result into ``y`` in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y)) + ops.forward_bf16(w, u, k, v, y) + + +def forward_with_state(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None: + """RWKV WKV forward pass with persistent state (float32). + + Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs. + """ + _validate_device_match((w, u, k, v, y, s)) + ops.forward_with_state(w, u, k, v, y, s) + + +def forward_with_state_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None: + """RWKV WKV forward pass with persistent state (bfloat16 inputs/outputs, float32 ``w`` and ``s``). + + Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. 
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs. + """ + _validate_device_match((w, u, k, v, y, s)) + ops.forward_with_state_bf16(w, u, k, v, y, s) + + +def backward( + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + y: Tensor, + gy: Tensor, + gw: Tensor, + gu: Tensor, + gk: Tensor, + gv: Tensor, +) -> None: + """RWKV WKV backward pass (float32). + + Writes gradients into the provided tensors in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.float32``. + gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.float32``. + gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.float32`` (written in-place). + gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv)) + ops.backward(w, u, k, v, y, gy, gw, gu, gk, gv) + + +def backward_bf16( + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + y: Tensor, + gy: Tensor, + gw: Tensor, + gu: Tensor, + gk: Tensor, + gv: Tensor, +) -> None: + """RWKV WKV backward pass (bfloat16 inputs/outputs/gradients, float32 ``w``). + + Writes gradients into the provided tensors in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.bfloat16``. 
+ gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.bfloat16`` (written in-place). + gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv)) + ops.backward_bf16(w, u, k, v, y, gy, gw, gu, gk, gv) + + +def _validate_device_match(tensors: Tuple[Tensor, ...]) -> None: + """Minimal runtime validation that all tensors live on the same CUDA device.""" + if not tensors: + return + device = tensors[0].device + if not device.type == "cuda": + raise RuntimeError("RWKV CUDA ops require CUDA tensors") + for t in tensors[1:]: + if t.device != device: + raise RuntimeError("All tensors must be on the same CUDA device") + + +__all__ = [ + "forward", + "forward_bf16", + "forward_with_state", + "forward_with_state_bf16", + "backward", + "backward_bf16", +] \ No newline at end of file diff --git a/build/torch28-cxx11-cu128-x86_64-linux/rwkv/__pycache__/__init__.cpython-313.pyc b/build/torch28-cxx11-cu128-x86_64-linux/rwkv/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d027f8095a699e1546c9ce6c888b30904f797888 Binary files /dev/null and b/build/torch28-cxx11-cu128-x86_64-linux/rwkv/__pycache__/__init__.cpython-313.pyc differ diff --git a/build/torch28-cxx11-cu128-x86_64-linux/rwkv/__pycache__/_ops.cpython-313.pyc b/build/torch28-cxx11-cu128-x86_64-linux/rwkv/__pycache__/_ops.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ca2f20f9ba81107958bc94098c7b387418030ff2 Binary files /dev/null and b/build/torch28-cxx11-cu128-x86_64-linux/rwkv/__pycache__/_ops.cpython-313.pyc differ diff --git a/build/torch28-cxx11-cu128-x86_64-linux/rwkv/_ops.py b/build/torch28-cxx11-cu128-x86_64-linux/rwkv/_ops.py new file mode 100644 index 
0000000000000000000000000000000000000000..d5289450131eaf881db14496443314079fd4427d --- /dev/null +++ b/build/torch28-cxx11-cu128-x86_64-linux/rwkv/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _rwkv_eb0e3e5_dirty +ops = torch.ops._rwkv_eb0e3e5_dirty + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_rwkv_eb0e3e5_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-cu128-x86_64-linux/rwkv/_rwkv_eb0e3e5_dirty.abi3.so b/build/torch28-cxx11-cu128-x86_64-linux/rwkv/_rwkv_eb0e3e5_dirty.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..382e5c41784fbb006a546a64a19499a4279510b9 --- /dev/null +++ b/build/torch28-cxx11-cu128-x86_64-linux/rwkv/_rwkv_eb0e3e5_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:301bf08ad13b2e382ccd49bd7de6ed8931fd1c9a70eb699728999fa454a52723 +size 2308880 diff --git a/build/torch28-cxx11-cu129-x86_64-linux/rwkv/__init__.py b/build/torch28-cxx11-cu129-x86_64-linux/rwkv/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d4ea22dbc160d2b75b32c3c8c5d882d25c77b494 --- /dev/null +++ b/build/torch28-cxx11-cu129-x86_64-linux/rwkv/__init__.py @@ -0,0 +1,170 @@ +from ._ops import ops +from typing import Tuple, Any + +# Use a broad Tensor alias to avoid importing torch at import time. +from torch import Tensor + +def forward(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None: + """RWKV WKV forward pass (float32). + + Runs the CUDA kernel and writes the result into ``y`` in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. 
+ - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y)) + ops.forward(w, u, k, v, y) + + +def forward_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None: + """RWKV WKV forward pass (bfloat16 inputs/outputs, float32 ``w``). + + Runs the CUDA kernel and writes the result into ``y`` in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y)) + ops.forward_bf16(w, u, k, v, y) + + +def forward_with_state(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None: + """RWKV WKV forward pass with persistent state (float32). + + Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs. 
+ """ + _validate_device_match((w, u, k, v, y, s)) + ops.forward_with_state(w, u, k, v, y, s) + + +def forward_with_state_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None: + """RWKV WKV forward pass with persistent state (bfloat16 inputs/outputs, float32 ``w`` and ``s``). + + Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs. + """ + _validate_device_match((w, u, k, v, y, s)) + ops.forward_with_state_bf16(w, u, k, v, y, s) + + +def backward( + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + y: Tensor, + gy: Tensor, + gw: Tensor, + gu: Tensor, + gk: Tensor, + gv: Tensor, +) -> None: + """RWKV WKV backward pass (float32). + + Writes gradients into the provided tensors in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.float32``. + gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.float32``. + gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.float32`` (written in-place). + gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. 
+ """ + _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv)) + ops.backward(w, u, k, v, y, gy, gw, gu, gk, gv) + + +def backward_bf16( + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + y: Tensor, + gy: Tensor, + gw: Tensor, + gu: Tensor, + gk: Tensor, + gv: Tensor, +) -> None: + """RWKV WKV backward pass (bfloat16 inputs/outputs/gradients, float32 ``w``). + + Writes gradients into the provided tensors in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.bfloat16`` (written in-place). + gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv)) + ops.backward_bf16(w, u, k, v, y, gy, gw, gu, gk, gv) + + +def _validate_device_match(tensors: Tuple[Tensor, ...]) -> None: + """Minimal runtime validation that all tensors live on the same CUDA device.""" + if not tensors: + return + device = tensors[0].device + if not device.type == "cuda": + raise RuntimeError("RWKV CUDA ops require CUDA tensors") + for t in tensors[1:]: + if t.device != device: + raise RuntimeError("All tensors must be on the same CUDA device") + + +__all__ = [ + "forward", + "forward_bf16", + "forward_with_state", + "forward_with_state_bf16", + "backward", + "backward_bf16", +] \ No newline at end of file diff --git a/build/torch28-cxx11-cu129-x86_64-linux/rwkv/__pycache__/__init__.cpython-313.pyc b/build/torch28-cxx11-cu129-x86_64-linux/rwkv/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aa524be76c0ee5729ed70e931260bdf0c6e1cec2 Binary files /dev/null and 
b/build/torch28-cxx11-cu129-x86_64-linux/rwkv/__pycache__/__init__.cpython-313.pyc differ diff --git a/build/torch28-cxx11-cu129-x86_64-linux/rwkv/__pycache__/_ops.cpython-313.pyc b/build/torch28-cxx11-cu129-x86_64-linux/rwkv/__pycache__/_ops.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..23cc91e318c59bed440950bd2a4ac549c0d8447b Binary files /dev/null and b/build/torch28-cxx11-cu129-x86_64-linux/rwkv/__pycache__/_ops.cpython-313.pyc differ diff --git a/build/torch28-cxx11-cu129-x86_64-linux/rwkv/_ops.py b/build/torch28-cxx11-cu129-x86_64-linux/rwkv/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..d5289450131eaf881db14496443314079fd4427d --- /dev/null +++ b/build/torch28-cxx11-cu129-x86_64-linux/rwkv/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _rwkv_eb0e3e5_dirty +ops = torch.ops._rwkv_eb0e3e5_dirty + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_rwkv_eb0e3e5_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-cu129-x86_64-linux/rwkv/_rwkv_eb0e3e5_dirty.abi3.so b/build/torch28-cxx11-cu129-x86_64-linux/rwkv/_rwkv_eb0e3e5_dirty.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..08a978aab5659be92aa8d39920231a875e1c67d9 --- /dev/null +++ b/build/torch28-cxx11-cu129-x86_64-linux/rwkv/_rwkv_eb0e3e5_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c6f42c1d67d68e35de1f1140e85011870e61bf78a6c77499e2fc4c8926d3ebe3 +size 2330376 diff --git a/build/torch29-cxx11-cu126-x86_64-linux/rwkv/__init__.py b/build/torch29-cxx11-cu126-x86_64-linux/rwkv/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d4ea22dbc160d2b75b32c3c8c5d882d25c77b494 --- /dev/null +++ b/build/torch29-cxx11-cu126-x86_64-linux/rwkv/__init__.py @@ -0,0 +1,170 @@ +from ._ops import ops +from typing import Tuple, Any + +# Use a broad Tensor alias to avoid importing torch at 
import time. +from torch import Tensor + +def forward(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None: + """RWKV WKV forward pass (float32). + + Runs the CUDA kernel and writes the result into ``y`` in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y)) + ops.forward(w, u, k, v, y) + + +def forward_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None: + """RWKV WKV forward pass (bfloat16 inputs/outputs, float32 ``w``). + + Runs the CUDA kernel and writes the result into ``y`` in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y)) + ops.forward_bf16(w, u, k, v, y) + + +def forward_with_state(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None: + """RWKV WKV forward pass with persistent state (float32). + + Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``. 
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs. + """ + _validate_device_match((w, u, k, v, y, s)) + ops.forward_with_state(w, u, k, v, y, s) + + +def forward_with_state_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None: + """RWKV WKV forward pass with persistent state (bfloat16 inputs/outputs, float32 ``w`` and ``s``). + + Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs. + """ + _validate_device_match((w, u, k, v, y, s)) + ops.forward_with_state_bf16(w, u, k, v, y, s) + + +def backward( + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + y: Tensor, + gy: Tensor, + gw: Tensor, + gu: Tensor, + gk: Tensor, + gv: Tensor, +) -> None: + """RWKV WKV backward pass (float32). + + Writes gradients into the provided tensors in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.float32``. + gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.float32``. 
+ gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.float32`` (written in-place). + gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv)) + ops.backward(w, u, k, v, y, gy, gw, gu, gk, gv) + + +def backward_bf16( + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + y: Tensor, + gy: Tensor, + gw: Tensor, + gu: Tensor, + gk: Tensor, + gv: Tensor, +) -> None: + """RWKV WKV backward pass (bfloat16 inputs/outputs/gradients, float32 ``w``). + + Writes gradients into the provided tensors in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.bfloat16`` (written in-place). + gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. 
+ """ + _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv)) + ops.backward_bf16(w, u, k, v, y, gy, gw, gu, gk, gv) + + +def _validate_device_match(tensors: Tuple[Tensor, ...]) -> None: + """Minimal runtime validation that all tensors live on the same CUDA device.""" + if not tensors: + return + device = tensors[0].device + if not device.type == "cuda": + raise RuntimeError("RWKV CUDA ops require CUDA tensors") + for t in tensors[1:]: + if t.device != device: + raise RuntimeError("All tensors must be on the same CUDA device") + + +__all__ = [ + "forward", + "forward_bf16", + "forward_with_state", + "forward_with_state_bf16", + "backward", + "backward_bf16", +] \ No newline at end of file diff --git a/build/torch29-cxx11-cu126-x86_64-linux/rwkv/__pycache__/__init__.cpython-313.pyc b/build/torch29-cxx11-cu126-x86_64-linux/rwkv/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d7c57d723f5a88ce1844ad8ff14e06c027d96c66 Binary files /dev/null and b/build/torch29-cxx11-cu126-x86_64-linux/rwkv/__pycache__/__init__.cpython-313.pyc differ diff --git a/build/torch29-cxx11-cu126-x86_64-linux/rwkv/__pycache__/_ops.cpython-313.pyc b/build/torch29-cxx11-cu126-x86_64-linux/rwkv/__pycache__/_ops.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..64116501129d9fec7e22bd8165c3bcd87c45fa6f Binary files /dev/null and b/build/torch29-cxx11-cu126-x86_64-linux/rwkv/__pycache__/_ops.cpython-313.pyc differ diff --git a/build/torch29-cxx11-cu126-x86_64-linux/rwkv/_ops.py b/build/torch29-cxx11-cu126-x86_64-linux/rwkv/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..d5289450131eaf881db14496443314079fd4427d --- /dev/null +++ b/build/torch29-cxx11-cu126-x86_64-linux/rwkv/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _rwkv_eb0e3e5_dirty +ops = torch.ops._rwkv_eb0e3e5_dirty + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_rwkv_eb0e3e5_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch29-cxx11-cu126-x86_64-linux/rwkv/_rwkv_eb0e3e5_dirty.abi3.so b/build/torch29-cxx11-cu126-x86_64-linux/rwkv/_rwkv_eb0e3e5_dirty.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..20114e909755f23f56e0b0d5fba56cec37f056b0 --- /dev/null +++ b/build/torch29-cxx11-cu126-x86_64-linux/rwkv/_rwkv_eb0e3e5_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:40d078873830f1833a85082fd6d907fd75c7eb3c94605ee77618981fe8cdee4e +size 2106440 diff --git a/build/torch29-cxx11-cu128-x86_64-linux/rwkv/__init__.py b/build/torch29-cxx11-cu128-x86_64-linux/rwkv/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d4ea22dbc160d2b75b32c3c8c5d882d25c77b494 --- /dev/null +++ b/build/torch29-cxx11-cu128-x86_64-linux/rwkv/__init__.py @@ -0,0 +1,170 @@ +from ._ops import ops +from typing import Tuple, Any + +# Use a broad Tensor alias to avoid importing torch at import time. +from torch import Tensor + +def forward(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None: + """RWKV WKV forward pass (float32). + + Runs the CUDA kernel and writes the result into ``y`` in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y)) + ops.forward(w, u, k, v, y) + + +def forward_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None: + """RWKV WKV forward pass (bfloat16 inputs/outputs, float32 ``w``). 
+ + Runs the CUDA kernel and writes the result into ``y`` in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y)) + ops.forward_bf16(w, u, k, v, y) + + +def forward_with_state(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None: + """RWKV WKV forward pass with persistent state (float32). + + Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs. + """ + _validate_device_match((w, u, k, v, y, s)) + ops.forward_with_state(w, u, k, v, y, s) + + +def forward_with_state_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None: + """RWKV WKV forward pass with persistent state (bfloat16 inputs/outputs, float32 ``w`` and ``s``). + + Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. 
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs. + """ + _validate_device_match((w, u, k, v, y, s)) + ops.forward_with_state_bf16(w, u, k, v, y, s) + + +def backward( + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + y: Tensor, + gy: Tensor, + gw: Tensor, + gu: Tensor, + gk: Tensor, + gv: Tensor, +) -> None: + """RWKV WKV backward pass (float32). + + Writes gradients into the provided tensors in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.float32``. + gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.float32``. + gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.float32`` (written in-place). + gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv)) + ops.backward(w, u, k, v, y, gy, gw, gu, gk, gv) + + +def backward_bf16( + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + y: Tensor, + gy: Tensor, + gw: Tensor, + gu: Tensor, + gk: Tensor, + gv: Tensor, +) -> None: + """RWKV WKV backward pass (bfloat16 inputs/outputs/gradients, float32 ``w``). + + Writes gradients into the provided tensors in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.bfloat16``. 
+ gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.bfloat16`` (written in-place). + gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv)) + ops.backward_bf16(w, u, k, v, y, gy, gw, gu, gk, gv) + + +def _validate_device_match(tensors: Tuple[Tensor, ...]) -> None: + """Minimal runtime validation that all tensors live on the same CUDA device.""" + if not tensors: + return + device = tensors[0].device + if not device.type == "cuda": + raise RuntimeError("RWKV CUDA ops require CUDA tensors") + for t in tensors[1:]: + if t.device != device: + raise RuntimeError("All tensors must be on the same CUDA device") + + +__all__ = [ + "forward", + "forward_bf16", + "forward_with_state", + "forward_with_state_bf16", + "backward", + "backward_bf16", +] \ No newline at end of file diff --git a/build/torch29-cxx11-cu128-x86_64-linux/rwkv/__pycache__/__init__.cpython-313.pyc b/build/torch29-cxx11-cu128-x86_64-linux/rwkv/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..23623a59f46a1a2cdb1261eea61425d29be74e65 Binary files /dev/null and b/build/torch29-cxx11-cu128-x86_64-linux/rwkv/__pycache__/__init__.cpython-313.pyc differ diff --git a/build/torch29-cxx11-cu128-x86_64-linux/rwkv/__pycache__/_ops.cpython-313.pyc b/build/torch29-cxx11-cu128-x86_64-linux/rwkv/__pycache__/_ops.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..53a2660f9be84cff703ad95a48d189fed8cf0e1d Binary files /dev/null and b/build/torch29-cxx11-cu128-x86_64-linux/rwkv/__pycache__/_ops.cpython-313.pyc differ diff --git a/build/torch29-cxx11-cu128-x86_64-linux/rwkv/_ops.py b/build/torch29-cxx11-cu128-x86_64-linux/rwkv/_ops.py new file mode 100644 index 
0000000000000000000000000000000000000000..d5289450131eaf881db14496443314079fd4427d --- /dev/null +++ b/build/torch29-cxx11-cu128-x86_64-linux/rwkv/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _rwkv_eb0e3e5_dirty +ops = torch.ops._rwkv_eb0e3e5_dirty + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_rwkv_eb0e3e5_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch29-cxx11-cu128-x86_64-linux/rwkv/_rwkv_eb0e3e5_dirty.abi3.so b/build/torch29-cxx11-cu128-x86_64-linux/rwkv/_rwkv_eb0e3e5_dirty.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..d419664327d14474dd832155b3e3c5ca7e2202d8 --- /dev/null +++ b/build/torch29-cxx11-cu128-x86_64-linux/rwkv/_rwkv_eb0e3e5_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ad6856c6dd2c9ed38dc4077f21c09cec11d0de47b870ae348d6576637f9816a2 +size 2308848 diff --git a/build/torch29-cxx11-cu130-x86_64-linux/rwkv/__init__.py b/build/torch29-cxx11-cu130-x86_64-linux/rwkv/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d4ea22dbc160d2b75b32c3c8c5d882d25c77b494 --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/rwkv/__init__.py @@ -0,0 +1,170 @@ +from ._ops import ops +from typing import Tuple, Any + +# Use a broad Tensor alias to avoid importing torch at import time. +from torch import Tensor + +def forward(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None: + """RWKV WKV forward pass (float32). + + Runs the CUDA kernel and writes the result into ``y`` in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. 
+ - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y)) + ops.forward(w, u, k, v, y) + + +def forward_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None: + """RWKV WKV forward pass (bfloat16 inputs/outputs, float32 ``w``). + + Runs the CUDA kernel and writes the result into ``y`` in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y)) + ops.forward_bf16(w, u, k, v, y) + + +def forward_with_state(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None: + """RWKV WKV forward pass with persistent state (float32). + + Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs. 
+ """ + _validate_device_match((w, u, k, v, y, s)) + ops.forward_with_state(w, u, k, v, y, s) + + +def forward_with_state_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None: + """RWKV WKV forward pass with persistent state (bfloat16 inputs/outputs, float32 ``w`` and ``s``). + + Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs. + """ + _validate_device_match((w, u, k, v, y, s)) + ops.forward_with_state_bf16(w, u, k, v, y, s) + + +def backward( + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + y: Tensor, + gy: Tensor, + gw: Tensor, + gu: Tensor, + gk: Tensor, + gv: Tensor, +) -> None: + """RWKV WKV backward pass (float32). + + Writes gradients into the provided tensors in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.float32``. + gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.float32``. + gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.float32`` (written in-place). + gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. 
+ """ + _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv)) + ops.backward(w, u, k, v, y, gy, gw, gu, gk, gv) + + +def backward_bf16( + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + y: Tensor, + gy: Tensor, + gw: Tensor, + gu: Tensor, + gk: Tensor, + gv: Tensor, +) -> None: + """RWKV WKV backward pass (bfloat16 inputs/outputs/gradients, float32 ``w``). + + Writes gradients into the provided tensors in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.bfloat16`` (written in-place). + gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv)) + ops.backward_bf16(w, u, k, v, y, gy, gw, gu, gk, gv) + + +def _validate_device_match(tensors: Tuple[Tensor, ...]) -> None: + """Minimal runtime validation that all tensors live on the same CUDA device.""" + if not tensors: + return + device = tensors[0].device + if not device.type == "cuda": + raise RuntimeError("RWKV CUDA ops require CUDA tensors") + for t in tensors[1:]: + if t.device != device: + raise RuntimeError("All tensors must be on the same CUDA device") + + +__all__ = [ + "forward", + "forward_bf16", + "forward_with_state", + "forward_with_state_bf16", + "backward", + "backward_bf16", +] \ No newline at end of file diff --git a/build/torch29-cxx11-cu130-x86_64-linux/rwkv/__pycache__/__init__.cpython-313.pyc b/build/torch29-cxx11-cu130-x86_64-linux/rwkv/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6316b490d04ccd4099ed79523f5b2937338eac9f Binary files /dev/null and 
b/build/torch29-cxx11-cu130-x86_64-linux/rwkv/__pycache__/__init__.cpython-313.pyc differ diff --git a/build/torch29-cxx11-cu130-x86_64-linux/rwkv/__pycache__/_ops.cpython-313.pyc b/build/torch29-cxx11-cu130-x86_64-linux/rwkv/__pycache__/_ops.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..099c4487038fa2b25eeb5aaab0b1c8d683fccfde Binary files /dev/null and b/build/torch29-cxx11-cu130-x86_64-linux/rwkv/__pycache__/_ops.cpython-313.pyc differ diff --git a/build/torch29-cxx11-cu130-x86_64-linux/rwkv/_ops.py b/build/torch29-cxx11-cu130-x86_64-linux/rwkv/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..d5289450131eaf881db14496443314079fd4427d --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/rwkv/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _rwkv_eb0e3e5_dirty +ops = torch.ops._rwkv_eb0e3e5_dirty + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_rwkv_eb0e3e5_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch29-cxx11-cu130-x86_64-linux/rwkv/_rwkv_eb0e3e5_dirty.abi3.so b/build/torch29-cxx11-cu130-x86_64-linux/rwkv/_rwkv_eb0e3e5_dirty.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..72b891a02a8b0447f740e03e1dc9bdf93e0a516a --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/rwkv/_rwkv_eb0e3e5_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dbe5a4e59abcaf208a0b7e0c952821d9585431aec58ca2364881d5239ac5819e +size 2334744 diff --git a/flake.lock b/flake.lock new file mode 100644 index 0000000000000000000000000000000000000000..85b5d60a855bf4c19555cc9b8de8ca88d6fd3ae9 --- /dev/null +++ b/flake.lock @@ -0,0 +1,168 @@ +{ + "nodes": { + "flake-compat": { + "locked": { + "lastModified": 1747046372, + "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=", + "owner": "edolstra", + "repo": "flake-compat", + "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885", + 
"type": "github" + }, + "original": { + "owner": "edolstra", + "repo": "flake-compat", + "type": "github" + } + }, + "flake-compat_2": { + "locked": { + "lastModified": 1747046372, + "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=", + "owner": "edolstra", + "repo": "flake-compat", + "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885", + "type": "github" + }, + "original": { + "owner": "edolstra", + "repo": "flake-compat", + "type": "github" + } + }, + "flake-utils": { + "inputs": { + "systems": "systems" + }, + "locked": { + "lastModified": 1731533236, + "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "flake-utils_2": { + "inputs": { + "systems": "systems_2" + }, + "locked": { + "lastModified": 1731533236, + "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "hf-nix": { + "inputs": { + "flake-compat": "flake-compat_2", + "flake-utils": "flake-utils_2", + "nixpkgs": "nixpkgs" + }, + "locked": { + "lastModified": 1759851564, + "narHash": "sha256-Xybkhm0FM/VzlZ5WndTYq/X/9MAeddd4EQ2Vz8GdkOA=", + "owner": "huggingface", + "repo": "hf-nix", + "rev": "351655d9f124805ed7c1193aa61550ce245f4570", + "type": "github" + }, + "original": { + "owner": "huggingface", + "repo": "hf-nix", + "type": "github" + } + }, + "kernel-builder": { + "inputs": { + "flake-compat": "flake-compat", + "flake-utils": "flake-utils", + "hf-nix": "hf-nix", + "nixpkgs": [ + "kernel-builder", + "hf-nix", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1760035358, + "narHash": 
"sha256-N5vmCrgwcIluPclf/hmnofLK77EJJYh5PR8SRvw++es=", + "owner": "huggingface", + "repo": "kernel-builder", + "rev": "a48cbd19ae7e425dfc1865188ef06dac43ab9244", + "type": "github" + }, + "original": { + "owner": "huggingface", + "repo": "kernel-builder", + "type": "github" + } + }, + "nixpkgs": { + "locked": { + "lastModified": 1755963616, + "narHash": "sha256-6yD0ww/S8n+U2uPYcJZ3DRURP8Kx036GRpR2uPNZroE=", + "owner": "nixos", + "repo": "nixpkgs", + "rev": "73e96df7cff5783f45e21342a75a1540c4eddce4", + "type": "github" + }, + "original": { + "owner": "nixos", + "ref": "nixos-unstable-small", + "repo": "nixpkgs", + "type": "github" + } + }, + "root": { + "inputs": { + "kernel-builder": "kernel-builder" + } + }, + "systems": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + }, + "systems_2": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + } + }, + "root": "root", + "version": 7 +} diff --git a/flake.nix b/flake.nix new file mode 100644 index 0000000000000000000000000000000000000000..f4954b5ca991ee9ab5893205c29997dd1c1112c2 --- /dev/null +++ b/flake.nix @@ -0,0 +1,17 @@ +{ + description = "Flake for rwkv kernels"; + + inputs = { + kernel-builder.url = "github:huggingface/kernel-builder"; + }; + + outputs = + { + self, + kernel-builder, + }: + kernel-builder.lib.genFlakeOutputs { + path = ./.; + rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate; + }; +} diff --git a/rwkv/wkv_cuda.cu b/rwkv/wkv_cuda.cu new file 
mode 100644 index 0000000000000000000000000000000000000000..571d5a8a8307e95aac689eb3c9333d1ad350c7de --- /dev/null +++ b/rwkv/wkv_cuda.cu @@ -0,0 +1,187 @@
// RWKV WKV recurrence kernels (float32). One CUDA thread owns one
// (batch, channel) pair and scans sequentially over the T time steps.
// NOTE(review): the dump this was recovered from had every <...> span
// stripped (empty `#include`, bare `template`, launches reduced to `<<>>`);
// the angle-bracketed tokens below are reconstructed -- confirm against the
// original wkv_cuda.cu.
#include <stdio.h>
#include <assert.h>

#define MIN_VALUE (-1e38)

// Numerically-stable streaming softmax-weighted average:
// aa/bb are the running numerator/denominator, both implicitly scaled by
// exp(pp) so intermediate exponentials never overflow.
template <typename F>
__global__ void kernel_forward(
    const int B, const int T, const int C, const F *__restrict__ const _w, const F *__restrict__ const _u,
    const F *__restrict__ const _k, const F *__restrict__ const _v, F *__restrict__ const _y
) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    const int _b = idx / C;
    const int _c = idx % C;
    const int _offset = _b * T * C + _c;  // start of this thread's (b, c) column

    F u = _u[_c];
    F w = _w[_c];
    const F *__restrict__ const k = _k + _offset;
    const F *__restrict__ const v = _v + _offset;
    F *__restrict__ const y = _y + _offset;

    // aa and bb are running sums divided by exp(pp) (to avoid overflow)
    F aa = 0, bb = 0, pp = MIN_VALUE;
    for (int i = 0; i < T; i++) {
        const int ii = i * C;  // stride C between consecutive time steps
        const F kk = k[ii];
        const F vv = v[ii];

        // Output at step i includes the current token via the bonus u.
        F ww = u + kk;
        F p = max(pp, ww);
        F e1 = exp(pp - p);
        F e2 = exp(ww - p);
        y[ii] = (e1 * aa + e2 * vv) / (e1 * bb + e2);

        // Decay the state by w and fold in the current token for step i+1.
        ww = w + pp;
        p = max(ww, kk);
        e1 = exp(ww - p);
        e2 = exp(kk - p);
        aa = e1 * aa + e2 * vv;
        bb = e1 * bb + e2;
        pp = p;
    }
}

// Same recurrence as kernel_forward, but the (aa, bb, pp) state is loaded
// from and persisted back to _s, enabling chunked/incremental inference.
template <typename F>
__global__ void kernel_forward_with_state(
    const int B, const int T, const int C, const F *__restrict__ const _w, const F *__restrict__ const _u,
    const F *__restrict__ const _k, const F *__restrict__ const _v, F *__restrict__ const _y, F *__restrict__ const _s
) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    const int _b = idx / C;
    const int _c = idx % C;
    const int _offset_s = _b * C * 3 + _c * 3;  // 3 state scalars per (b, c)
    const int _offset = _b * T * C + _c;

    F u = _u[_c];
    F w = _w[_c];
    const F *__restrict__ const k = _k + _offset;
    const F *__restrict__ const v = _v + _offset;
    F *__restrict__ const y = _y + _offset;
    F *__restrict__ const s = _s + _offset_s;

    // aa and bb are running sums divided by exp(pp) (to avoid overflow)
    F aa = s[0], bb = s[1], pp = s[2];
    for (int i = 0; i < T; i++) {
        const int ii = i * C;
        const F kk = k[ii];
        const F vv = v[ii];

        F ww = u + kk;
        F p = max(pp, ww);
        F e1 = exp(pp - p);
        F e2 = exp(ww - p);
        y[ii] = (e1 * aa + e2 * vv) / (e1 * bb + e2);

        ww = w + pp;
        p = max(ww, kk);
        e1 = exp(ww - p);
        e2 = exp(kk - p);
        aa = e1 * aa + e2 * vv;
        bb = e1 * bb + e2;
        pp = p;
    }
    // Persist the state so the next chunk can resume the recurrence.
    s[0] = aa;
    s[1] = bb;
    s[2] = pp;
}

// Backward pass. Buffers q/r cache per-step quantities from the forward
// sweep for the reverse sweep; they are statically sized by Tmax
// (-DTmax=1024 in build.toml), so T must not exceed Tmax.
template <typename F>
__global__ void kernel_backward(
    const int B, const int T, const int C, const F *__restrict__ const _w, const F *__restrict__ const _u,
    const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _y,
    const F *__restrict__ const _gy, F *__restrict__ const _gw, F *__restrict__ const _gu, F *__restrict__ const _gk,
    F *__restrict__ const _gv
) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    const int _b = idx / C;
    const int _c = idx % C;
    const int _offset = _b * T * C + _c;

    F u = _u[_c];
    F w = _w[_c];
    const F *__restrict__ const k = _k + _offset;
    const F *__restrict__ const v = _v + _offset;
    const F *__restrict__ const y = _y + _offset;
    const F *__restrict__ const gy = _gy + _offset;
    F *__restrict__ const gk = _gk + _offset;
    F *__restrict__ const gv = _gv + _offset;

    F q[Tmax], r[Tmax];

    F gw = 0, gu = 0, aa = 0, bb = 0, ga = 0, gb = 0, pp = MIN_VALUE;
    for (int i = 0; i < T; i++) {
        const int ii = i * C;
        const F kk = k[ii];
        const F vv = v[ii];
        const F yy = y[ii];

        F ww = u + kk;
        F p = max(pp, ww);
        F e1 = exp(pp - p);
        F e2 = exp(ww - p);
        const F qq = gy[ii] / (e1 * bb + e2);
        gw += (ga - gb * yy) * e1 * qq;
        gu += (vv - yy) * e2 * qq;
        q[i] = qq;
        r[i] = ww - p;

        ww = w + pp;
        p = max(ww, kk);
        e1 = exp(ww - p);
        e2 = exp(kk - p);
        ga = e1 * (aa + ga);
        gb = e1 * (bb + gb);
        aa = e1 * aa + e2 * vv;
        bb = e1 * bb + e2;
        pp = p;
    }
    // gw/gu are written per (batch, channel) pair; the caller reduces over
    // the batch dimension.
    const int _offsetBC = _b * C + _c;
    _gw[_offsetBC] = gw * _w[_c]; // multiply by w because of w -> -exp(w) in python forward()
    _gu[_offsetBC] = gu;

    // Reverse sweep accumulates gk/gv using the cached q/r values.
    aa = 0, bb = 0, pp = MIN_VALUE;
    for (int i = T - 1; i >= 0; i--) {
        const int ii = i * C;
        const F kk = k[ii];
        const F vv = v[ii];
        const F yy = y[ii];
        const F qq = q[i];
        const F rr = r[i];

        F e1 = qq * exp(rr);
        F e2 = exp(kk + pp);
        gk[ii] = e1 * (vv - yy) + e2 * (aa * vv + bb);
        gv[ii] = e1 + e2 * aa;

        const F ww = w + pp;
        const F www = rr - u - kk;
        const F p = max(ww, www);
        e1 = exp(ww - p);
        e2 = qq * exp(www - p);
        aa = e1 * aa + e2;
        bb = e1 * bb - e2 * yy;
        pp = p;
    }
}

// Host-side launchers: one thread per (batch, channel) element.
// B * C must be divisible by the block size min(C, 32).
void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y) {
    dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
    assert(B * C % threadsPerBlock.x == 0);
    dim3 numBlocks(B * C / threadsPerBlock.x);
    kernel_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y);
}

void cuda_forward_with_state(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *s) {
    dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
    assert(B * C % threadsPerBlock.x == 0);
    dim3 numBlocks(B * C / threadsPerBlock.x);
    kernel_forward_with_state<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y, s);
}

void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *gy, float *gw, float *gu, float *gk, float *gv) {
    dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
    assert(B * C % threadsPerBlock.x == 0);
    dim3 numBlocks(B * C / threadsPerBlock.x);
    kernel_backward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y, gy, gw, gu, gk, gv);
}
diff --git a/rwkv/wkv_cuda_bf16.cu b/rwkv/wkv_cuda_bf16.cu new file mode 100644 index 0000000000000000000000000000000000000000..042cb4aba1db98be5916aea1de86a7fed0b6510d --- /dev/null +++ b/rwkv/wkv_cuda_bf16.cu @@ -0,0 +1,186 @@
#include <stdio.h>
#include <assert.h>
#include "ATen/ATen.h"
#define MIN_VALUE (-1e38)
typedef at::BFloat16 bf16;

// bfloat16 variants of the RWKV WKV kernels: tensors are bf16 in memory but
// all recurrence arithmetic runs in float32, and results are rounded to bf16
// only on store. `w` (and the state `s`) stay float32 throughout.
// NOTE(review): launch configurations below were reduced to `<<>>` in the
// recovered dump; the `<<<numBlocks, threadsPerBlock>>>` forms are
// reconstructed -- confirm against the original wkv_cuda_bf16.cu.
__global__ void kernel_forward_bf16(
    const int B, const int T, const int C, const float *__restrict__ const _w, const bf16 *__restrict__ const _u,
    const bf16 *__restrict__ const _k, const bf16 *__restrict__ const _v, bf16 *__restrict__ const _y
) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    const int _b = idx / C;
    const int _c = idx % C;
    const int _offset = _b * T * C + _c;  // start of this thread's (b, c) column

    float u = float(_u[_c]);
    float w = _w[_c];
    const bf16 *__restrict__ const k = _k + _offset;
    const bf16 *__restrict__ const v = _v + _offset;
    bf16 *__restrict__ const y = _y + _offset;

    // aa and bb are running sums divided by exp(pp) (to avoid overflow)
    float aa = 0, bb = 0, pp = MIN_VALUE;
    for (int i = 0; i < T; i++) {
        const int ii = i * C;
        const float kk = float(k[ii]);
        const float vv = float(v[ii]);

        float ww = u + kk;
        float p = max(pp, ww);
        float e1 = exp(pp - p);
        float e2 = exp(ww - p);
        y[ii] = bf16((e1 * aa + e2 * vv) / (e1 * bb + e2));

        ww = w + pp;
        p = max(ww, kk);
        e1 = exp(ww - p);
        e2 = exp(kk - p);
        aa = e1 * aa + e2 * vv;
        bb = e1 * bb + e2;
        pp = p;
    }
}

// Same recurrence with (aa, bb, pp) loaded from / persisted to _s so the
// sequence can be processed in chunks.
__global__ void kernel_forward_with_state_bf16(
    const int B, const int T, const int C, const float *__restrict__ const _w, const bf16 *__restrict__ const _u,
    const bf16 *__restrict__ const _k, const bf16 *__restrict__ const _v, bf16 *__restrict__ const _y,
    float *__restrict__ const _s
) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    const int _b = idx / C;
    const int _c = idx % C;
    const int _offset_s = _b * C * 3 + _c * 3;  // 3 state scalars per (b, c)
    const int _offset = _b * T * C + _c;

    float u = float(_u[_c]);
    float w = _w[_c];
    const bf16 *__restrict__ const k = _k + _offset;
    const bf16 *__restrict__ const v = _v + _offset;
    bf16 *__restrict__ const y = _y + _offset;
    float *__restrict__ const s = _s + _offset_s;

    // aa and bb are running sums divided by exp(pp) (to avoid overflow)
    float aa = s[0], bb = s[1], pp = s[2];
    for (int i = 0; i < T; i++) {
        const int ii = i * C;
        const float kk = float(k[ii]);
        const float vv = float(v[ii]);

        float ww = u + kk;
        float p = max(pp, ww);
        float e1 = exp(pp - p);
        float e2 = exp(ww - p);
        // FIX: cast the full quotient to bf16 (was `bf16(num) / den`, which
        // rounded the numerator before dividing) -- now matches
        // kernel_forward_bf16 above.
        y[ii] = bf16((e1 * aa + e2 * vv) / (e1 * bb + e2));

        ww = w + pp;
        p = max(ww, kk);
        e1 = exp(ww - p);
        e2 = exp(kk - p);
        aa = e1 * aa + e2 * vv;
        bb = e1 * bb + e2;
        pp = p;
    }
    // Persist the state so the next chunk can resume the recurrence.
    s[0] = aa;
    s[1] = bb;
    s[2] = pp;
}

// Backward pass; q/r cache per-step values for the reverse sweep and are
// statically sized by Tmax (-DTmax=1024 in build.toml), so T <= Tmax.
__global__ void kernel_backward_bf16(
    const int B, const int T, const int C, const float *__restrict__ const _w, const bf16 *__restrict__ const _u,
    const bf16 *__restrict__ const _k, const bf16 *__restrict__ const _v, const bf16 *__restrict__ const _y,
    const bf16 *__restrict__ const _gy, bf16 *__restrict__ const _gw, bf16 *__restrict__ const _gu,
    bf16 *__restrict__ const _gk, bf16 *__restrict__ const _gv
) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    const int _b = idx / C;
    const int _c = idx % C;
    const int _offset = _b * T * C + _c;

    float u = float(_u[_c]);
    float w = _w[_c];
    const bf16 *__restrict__ const k = _k + _offset;
    const bf16 *__restrict__ const v = _v + _offset;
    const bf16 *__restrict__ const y = _y + _offset;
    const bf16 *__restrict__ const gy = _gy + _offset;
    bf16 *__restrict__ const gk = _gk + _offset;
    bf16 *__restrict__ const gv = _gv + _offset;

    float q[Tmax], r[Tmax];

    float gw = 0, gu = 0, aa = 0, bb = 0, ga = 0, gb = 0, pp = MIN_VALUE;
    for (int i = 0; i < T; i++) {
        const int ii = i * C;
        const float kk = float(k[ii]);
        const float vv = float(v[ii]);
        const float yy = float(y[ii]);

        float ww = u + kk;
        float p = max(pp, ww);
        float e1 = exp(pp - p);
        float e2 = exp(ww - p);
        const float qq = float(gy[ii]) / (e1 * bb + e2);
        gw += (ga - gb * yy) * e1 * qq;
        gu += (vv - yy) * e2 * qq;
        q[i] = qq;
        r[i] = ww - p;

        ww = w + pp;
        p = max(ww, kk);
        e1 = exp(ww - p);
        e2 = exp(kk - p);
        ga = e1 * (aa + ga);
        gb = e1 * (bb + gb);
        aa = e1 * aa + e2 * vv;
        bb = e1 * bb + e2;
        pp = p;
    }
    // gw/gu are written per (batch, channel) pair; the caller reduces over
    // the batch dimension.
    const int _offsetBC = _b * C + _c;
    _gw[_offsetBC] = bf16(gw * _w[_c]); // multiply by w because of w -> -exp(w) in python forward()
    _gu[_offsetBC] = bf16(gu);

    // Reverse sweep accumulates gk/gv using the cached q/r values.
    aa = 0, bb = 0, pp = MIN_VALUE;
    for (int i = T - 1; i >= 0; i--) {
        const int ii = i * C;
        const float kk = float(k[ii]);
        const float vv = float(v[ii]);
        const float yy = float(y[ii]);
        const float qq = q[i];
        const float rr = r[i];

        float e1 = qq * exp(rr);
        float e2 = exp(kk + pp);
        gk[ii] = bf16(e1 * (vv - yy) + e2 * (aa * vv + bb));
        gv[ii] = bf16(e1 + e2 * aa);

        const float ww = w + pp;
        const float www = rr - u - kk;
        const float p = max(ww, www);
        e1 = exp(ww - p);
        e2 = qq * exp(www - p);
        aa = e1 * aa + e2;
        bb = e1 * bb - e2 * yy;
        pp = p;
    }
}

// Host-side launchers: one thread per (batch, channel) element.
// B * C must be divisible by the block size min(C, 32).
void cuda_forward_bf16(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y) {
    dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
    assert(B * C % threadsPerBlock.x == 0);
    dim3 numBlocks(B * C / threadsPerBlock.x);
    kernel_forward_bf16<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y);
}

void cuda_forward_with_state_bf16(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y, float *s) {
    dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
    assert(B * C % threadsPerBlock.x == 0);
    dim3 numBlocks(B * C / threadsPerBlock.x);
    kernel_forward_with_state_bf16<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y, s);
}

void cuda_backward_bf16(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y, bf16 *gy, bf16 *gw, bf16 *gu, bf16 *gk, bf16 *gv) {
    dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
    assert(B * C % threadsPerBlock.x == 0);
    dim3 numBlocks(B * C / threadsPerBlock.x);
    kernel_backward_bf16<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y, gy, gw, gu, gk, gv);
}
diff --git a/torch-ext/rwkv/__init__.py b/torch-ext/rwkv/__init__.py new file
mode 100644 index 0000000000000000000000000000000000000000..d4ea22dbc160d2b75b32c3c8c5d882d25c77b494 --- /dev/null +++ b/torch-ext/rwkv/__init__.py @@ -0,0 +1,170 @@ +from ._ops import ops +from typing import Tuple, Any + +# Use a broad Tensor alias to avoid importing torch at import time. +from torch import Tensor + +def forward(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None: + """RWKV WKV forward pass (float32). + + Runs the CUDA kernel and writes the result into ``y`` in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y)) + ops.forward(w, u, k, v, y) + + +def forward_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None: + """RWKV WKV forward pass (bfloat16 inputs/outputs, float32 ``w``). + + Runs the CUDA kernel and writes the result into ``y`` in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y)) + ops.forward_bf16(w, u, k, v, y) + + +def forward_with_state(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None: + """RWKV WKV forward pass with persistent state (float32). 
+ + Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs. + """ + _validate_device_match((w, u, k, v, y, s)) + ops.forward_with_state(w, u, k, v, y, s) + + +def forward_with_state_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None: + """RWKV WKV forward pass with persistent state (bfloat16 inputs/outputs, float32 ``w`` and ``s``). + + Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs. + """ + _validate_device_match((w, u, k, v, y, s)) + ops.forward_with_state_bf16(w, u, k, v, y, s) + + +def backward( + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + y: Tensor, + gy: Tensor, + gw: Tensor, + gu: Tensor, + gk: Tensor, + gv: Tensor, +) -> None: + """RWKV WKV backward pass (float32). + + Writes gradients into the provided tensors in-place. 
+ + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.float32``. + gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.float32``. + gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.float32`` (written in-place). + gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv)) + ops.backward(w, u, k, v, y, gy, gw, gu, gk, gv) + + +def backward_bf16( + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + y: Tensor, + gy: Tensor, + gw: Tensor, + gu: Tensor, + gk: Tensor, + gv: Tensor, +) -> None: + """RWKV WKV backward pass (bfloat16 inputs/outputs/gradients, float32 ``w``). + + Writes gradients into the provided tensors in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.bfloat16`` (written in-place). + gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. 
+ """ + _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv)) + ops.backward_bf16(w, u, k, v, y, gy, gw, gu, gk, gv) + + +def _validate_device_match(tensors: Tuple[Tensor, ...]) -> None: + """Minimal runtime validation that all tensors live on the same CUDA device.""" + if not tensors: + return + device = tensors[0].device + if not device.type == "cuda": + raise RuntimeError("RWKV CUDA ops require CUDA tensors") + for t in tensors[1:]: + if t.device != device: + raise RuntimeError("All tensors must be on the same CUDA device") + + +__all__ = [ + "forward", + "forward_bf16", + "forward_with_state", + "forward_with_state_bf16", + "backward", + "backward_bf16", +] \ No newline at end of file diff --git a/torch-ext/torch_binding.cpp b/torch-ext/torch_binding.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c42395d0aa9bb14566a58d73f959c9803fc73cde --- /dev/null +++ b/torch-ext/torch_binding.cpp @@ -0,0 +1,74 @@ +#include +#include "ATen/ATen.h" +#include + +#include "registration.h" + +typedef at::BFloat16 bf16; + +void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y); +void cuda_forward_bf16(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y); +void cuda_forward_with_state(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *s); +void cuda_forward_with_state_bf16(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y, float *s); +void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *gy, float *gw, float *gu, float *gk, float *gv); +void cuda_backward_bf16(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y, bf16 *gy, bf16 *gw, bf16 *gu, bf16 *gk, bf16 *gv); + +void forward(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) { + const int B = k.size(0); + const int T = k.size(1); + const int C = k.size(2); + cuda_forward(B, T, C, w.data_ptr(), u.data_ptr(), 
k.data_ptr(), v.data_ptr(), y.data_ptr());
+}
+// NOTE(review): every data_ptr() call in this file almost certainly carried an
+// explicit template argument (data_ptr<float>() / data_ptr<bf16>()) that was
+// stripped by text extraction — confirm against the original torch_binding.cpp.
+// All six wrappers below follow one pattern: read B, T, C from k's [B, T, C]
+// layout, then forward raw device pointers to the matching cuda_* launcher.
+void forward_bf16(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) {
+    // bf16 forward: y is written in-place by the kernel.
+    const int B = k.size(0);
+    const int T = k.size(1);
+    const int C = k.size(2);
+    cuda_forward_bf16(B, T, C, w.data_ptr(), u.data_ptr(), k.data_ptr(), v.data_ptr(), y.data_ptr());
+}
+void forward_with_state(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &s) {
+    // fp32 forward with persistent state s (read and updated by the kernel).
+    const int B = k.size(0);
+    const int T = k.size(1);
+    const int C = k.size(2);
+    cuda_forward_with_state(B, T, C, w.data_ptr(), u.data_ptr(), k.data_ptr(), v.data_ptr(), y.data_ptr(), s.data_ptr());
+}
+void forward_with_state_bf16(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &s) {
+    // bf16 forward with persistent state s (read and updated by the kernel).
+    const int B = k.size(0);
+    const int T = k.size(1);
+    const int C = k.size(2);
+    cuda_forward_with_state_bf16(B, T, C, w.data_ptr(), u.data_ptr(), k.data_ptr(), v.data_ptr(), y.data_ptr(), s.data_ptr());
+}
+void backward(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) {
+    // fp32 backward: gradients gw, gu, gk, gv are written in-place.
+    const int B = k.size(0);
+    const int T = k.size(1);
+    const int C = k.size(2);
+    cuda_backward(B, T, C, w.data_ptr(), u.data_ptr(), k.data_ptr(), v.data_ptr(), y.data_ptr(), gy.data_ptr(), gw.data_ptr(), gu.data_ptr(), gk.data_ptr(), gv.data_ptr());
+}
+void backward_bf16(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) {
+    // bf16 backward: gradients gw, gu, gk, gv are written in-place.
+    const int B = k.size(0);
+    const int T = k.size(1);
+    const int C = k.size(2);
+    cuda_backward_bf16(B, T, C, w.data_ptr(), u.data_ptr(), k.data_ptr(), v.data_ptr(), y.data_ptr(),
+        gy.data_ptr(), gw.data_ptr(), gu.data_ptr(), gk.data_ptr(), gv.data_ptr());
+}
+
+TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
+    // Operator registration: for each op, def() declares it on the library and
+    // impl() binds the C++ wrapper above as its CUDA-dispatch-key implementation.
+    // The registered names mirror the Python wrappers in torch-ext/rwkv/__init__.py.
+    ops.def("forward", forward);
+    ops.impl("forward", torch::kCUDA, &forward);
+
+    ops.def("forward_bf16", forward_bf16);
+    ops.impl("forward_bf16", torch::kCUDA, &forward_bf16);
+
+    ops.def("forward_with_state", forward_with_state);
+    ops.impl("forward_with_state", torch::kCUDA, &forward_with_state);
+
+    ops.def("forward_with_state_bf16", forward_with_state_bf16);
+    ops.impl("forward_with_state_bf16", torch::kCUDA, &forward_with_state_bf16);
+
+    ops.def("backward", backward);
+    ops.impl("backward", torch::kCUDA, &backward);
+
+    ops.def("backward_bf16", backward_bf16);
+    ops.impl("backward_bf16", torch::kCUDA, &backward_bf16);
+}
+
+// REGISTER_EXTENSION / TORCH_LIBRARY_EXPAND are presumably provided by the
+// kernel-builder's "registration.h" included at the top of this file — confirm there.
+REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
\ No newline at end of file