|
|
| #ifndef _data_types_cuh |
| #define _data_types_cuh |
| #include "marlin.cuh" |
| #include <cuda_fp16.h> |
| #include <cuda_bf16.h> |
|
|
| #ifndef MARLIN_NAMESPACE_NAME |
| #define MARLIN_NAMESPACE_NAME marlin |
| #endif |
|
|
| namespace MARLIN_NAMESPACE_NAME { |
|
|
// Primary ScalarType trait template.
//
// The primary template has no members at all. Element/fragment types and
// conversion helpers come only from explicit specializations of this
// template. Naming a member such as ScalarType<T>::FragA for a type T with
// no specialization is therefore a compile-time error.
template <typename T>
class ScalarType {};
|
|
// ScalarType specialization for FP16 (`half`).
//
// Bundles the following:
//   * the scalar type and its packed two-element form (half / half2),
//   * register fragment vector types built on Vec<> (from marlin.cuh),
//   * small conversion helpers wrapping the cuda_fp16.h intrinsics.
template <>
class ScalarType<half> {
 public:
  using scalar_t = half;    // one FP16 element
  using scalar_t2 = half2;  // two FP16 elements packed together

  // Register fragment types.
  // FragA/FragB/FragS/FragZP store packed half2 pairs; FragC stores floats.
  // NOTE(review): these appear to be the operand/accumulator fragments for
  // the tensor-core mma path (see PTX ISA "matrix fragments for mma"); the
  // exact layout is determined by the kernels that consume them.
  using FragA = Vec<scalar_t2, 4>;
  using FragB = Vec<scalar_t2, 2>;
  using FragC = Vec<float, 4>;
  using FragS = Vec<scalar_t2, 1>;
  using FragZP = Vec<scalar_t2, 4>;

  // Converts a single half to float via __half2float.
  static __device__ inline float num2float(const scalar_t x) {
    return __half2float(x);
  }

  // Replicates one half into both lanes of a half2 via __half2half2.
  static __device__ inline scalar_t2 num2num2(const scalar_t x) {
    return __half2half2(x);
  }

  // Packs two halves (x1, x2) into one half2 via __halves2half2(x1, x2).
  static __device__ inline scalar_t2 nums2num2(const scalar_t x1,
                                               const scalar_t x2) {
    return __halves2half2(x1, x2);
  }

  // Converts a float to half via __float2half.
  // Declared __host__ __device__, so it is usable from both sides.
  static __host__ __device__ inline scalar_t float2num(const float x) {
    return __float2half(x);
  }
};
|
|
// ScalarType specialization for BF16 (`nv_bfloat16`).
//
// Mirrors the FP16 specialization:
//   * the scalar type and its packed two-element form
//     (nv_bfloat16 / nv_bfloat162),
//   * register fragment vector types built on Vec<> (from marlin.cuh),
//   * conversion helpers wrapping the cuda_bf16.h intrinsics.
template <>
class ScalarType<nv_bfloat16> {
 public:
  using scalar_t = nv_bfloat16;    // one BF16 element
  using scalar_t2 = nv_bfloat162;  // two BF16 elements packed together

  // Register fragment types.
  // FragA/FragB/FragS/FragZP store packed nv_bfloat162 pairs; FragC stores
  // floats. NOTE(review): presumably the operand/accumulator fragments for
  // the tensor-core mma path; the layout is fixed by the consuming kernels.
  using FragA = Vec<scalar_t2, 4>;
  using FragB = Vec<scalar_t2, 2>;
  using FragC = Vec<float, 4>;
  using FragS = Vec<scalar_t2, 1>;
  using FragZP = Vec<scalar_t2, 4>;

  // The conversion helpers below are declared only in two cases:
  //   * the host compilation pass (__CUDA_ARCH__ undefined), or
  //   * a device pass targeting __CUDA_ARCH__ >= 800.
  // Device code built for an older architecture that calls them fails to
  // compile. NOTE(review): presumably this is because the BF16 kernel path
  // requires sm_80+; confirm against the kernel dispatch.
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
  // Converts a single bfloat16 to float via __bfloat162float.
  static __device__ inline float num2float(const scalar_t x) {
    return __bfloat162float(x);
  }

  // Replicates one bfloat16 into both lanes via __bfloat162bfloat162.
  static __device__ inline scalar_t2 num2num2(const scalar_t x) {
    return __bfloat162bfloat162(x);
  }

  // Packs two bfloat16 values (x1, x2) via __halves2bfloat162(x1, x2).
  static __device__ inline scalar_t2 nums2num2(const scalar_t x1,
                                               const scalar_t x2) {
    return __halves2bfloat162(x1, x2);
  }

  // Converts a float to bfloat16 via __float2bfloat16.
  // Declared __host__ __device__, so it is usable from both sides.
  static __host__ __device__ inline scalar_t float2num(const float x) {
    return __float2bfloat16(x);
  }
#endif
};
|
|
| } |
|
|
| #endif |
|
|