| #pragma once |
|
|
| #ifdef __HIPCC__ |
| #include <hip/hip_runtime.h> |
| #else |
| #include <type_traits> |
| #include <stdint.h> |
| #include <math.h> |
| #include <iostream> |
| #endif |
|
|
| #include "hip_float8_impl.h" |
|
|
| struct alignas(1) hip_fp8 { |
| struct from_bits_t {}; |
| HIP_FP8_HOST_DEVICE static constexpr from_bits_t from_bits() { |
| return from_bits_t(); |
| } |
| uint8_t data; |
|
|
| hip_fp8() = default; |
| HIP_FP8_HOST_DEVICE constexpr hip_fp8(const hip_fp8&) = default; |
| HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v) = delete; |
| explicit HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v, from_bits_t) |
| : data(v) {} |
|
|
| #ifdef __HIP__MI300__ |
| |
| explicit HIP_FP8_DEVICE hip_fp8(float v) |
| : data(hip_fp8_impl::to_fp8_from_fp32(v)) {} |
|
|
| explicit HIP_FP8_DEVICE hip_fp8(_Float16 v) |
| : hip_fp8(static_cast<float>(v)) {} |
|
|
| |
| explicit HIP_FP8_HOST |
| #else |
| |
| explicit HIP_FP8_HOST_DEVICE |
| #endif |
| hip_fp8(float v) { |
| data = hip_fp8_impl::to_float8<4, 3, float, true , |
| true >(v); |
| } |
|
|
| explicit HIP_FP8_HOST_DEVICE hip_fp8(double v) |
| : hip_fp8(static_cast<float>(v)) {} |
|
|
| #ifdef __HIP__MI300__ |
| |
| explicit inline HIP_FP8_DEVICE operator float() const { |
| float fval; |
| uint32_t i32val = static_cast<uint32_t>(data); |
|
|
| |
| asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" |
| : "=v"(fval) |
| : "v"(i32val)); |
|
|
| return fval; |
| } |
|
|
| explicit inline HIP_FP8_HOST operator float() const |
| #else |
| explicit inline HIP_FP8_HOST_DEVICE operator float() const |
| #endif |
| { |
| return hip_fp8_impl::from_float8<4, 3, float, true >( |
| data); |
| } |
| }; |
|
|
| namespace std { |
| inline hip_fp8 sin(hip_fp8 a) { return hip_fp8(sinf(float(a))); } |
| inline hip_fp8 cos(hip_fp8 a) { return hip_fp8(cosf(float(a))); } |
| HIP_FP8_HOST_DEVICE constexpr hip_fp8 real(const hip_fp8& a) { return a; } |
| } |
|
|
| |
| inline std::ostream& operator<<(std::ostream& os, const hip_fp8& f8) { |
| return os << float(f8); |
| } |
|
|
| |
| |
| |
| inline HIP_FP8_HOST_DEVICE float operator+(const float fa, hip_fp8 b) { |
| return (fa + float(b)); |
| } |
|
|
| inline HIP_FP8_HOST_DEVICE float operator+(hip_fp8 a, const float fb) { |
| return (float(a) + fb); |
| } |
|
|
| inline HIP_FP8_HOST_DEVICE hip_fp8 operator+(hip_fp8 a, hip_fp8 b) { |
| return hip_fp8(float(a) + float(b)); |
| } |
|
|
| inline HIP_FP8_HOST_DEVICE hip_fp8& operator+=(hip_fp8& a, hip_fp8 b) { |
| return a = hip_fp8(float(a) + float(b)); |
| } |
|
|
| |
| inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, hip_fp8 b) { |
| return float(a) * float(b); |
| } |
|
|
| inline HIP_FP8_HOST_DEVICE float operator*(float a, hip_fp8 b) { |
| return (a * float(b)); |
| } |
|
|
| inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, float b) { |
| return (float(a) * b); |
| } |
|
|
| inline HIP_FP8_HOST_DEVICE float operator*(int32_t a, hip_fp8 b) { |
| return ((float)a * float(b)); |
| } |
|
|
| inline HIP_FP8_HOST_DEVICE float operator*(double a, hip_fp8 b) { |
| return ((float)a * float(b)); |
| } |
|
|
| |
| inline HIP_FP8_HOST_DEVICE bool operator==(hip_fp8 a, hip_fp8 b) { |
| return (a.data == b.data); |
| } |
| inline HIP_FP8_HOST_DEVICE bool operator!=(hip_fp8 a, hip_fp8 b) { |
| return (a.data != b.data); |
| } |
|
|
| inline HIP_FP8_HOST_DEVICE bool operator>=(hip_fp8 a, hip_fp8 b) { |
| return static_cast<float>(a) >= static_cast<float>(b); |
| } |
| inline HIP_FP8_HOST_DEVICE bool operator>(hip_fp8 a, hip_fp8 b) { |
| return static_cast<float>(a) > static_cast<float>(b); |
| } |
|
|