diff --git a/include/ck/utility/type_convert.hpp b/include/ck/utility/type_convert.hpp
index c859cfba3d..e9fd1ea88f 100644
--- a/include/ck/utility/type_convert.hpp
+++ b/include/ck/utility/type_convert.hpp
@@ -39,6 +39,22 @@ namespace details {
 } // namespace details
 } // namespace
 
+#if defined(__gfx950__)
+// Converts fp32 to bf16 with the compiler's native __bf16 cast and returns the raw
+// 16-bit pattern as bhalf_t. Device-only, and only compiled when targeting gfx950.
+// NOTE(review): bf16_convert_rtn relies on this cast rounding to nearest - confirm.
+inline __device__ bhalf_t static_cast_float_to_bf16(float x)
+{
+    union
+    {
+        uint16_t uint16;
+        __bf16 bf16;
+    } out;
+    out.bf16 = static_cast<__bf16>(x);
+    return out.uint16;
+}
+#endif
+
 // Declare a template function for bf16 conversion using RTN
 template <typename Y, typename X>
 __host__ __device__ constexpr Y bf16_convert_rtn(X x);
@@ -47,6 +63,10 @@ __host__ __device__ constexpr Y bf16_convert_rtn(X x);
 template <>
 inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, float>(float x)
 {
+#if defined(__gfx950__)
+    // gfx950 device pass: use the native cast instead of the software rounding below
+    return static_cast_float_to_bf16(x);
+#else
     // Nan check
     if(x != x)
     {
@@ -63,6 +83,7 @@ inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, float>(fl
     constexpr uint32_t rounding_bias      = uint32_t((1 << 15) - 1);
 
     return uint16_t((u.int32 + first_bf16_mantisa_bit + rounding_bias) >> 16);
+#endif
 }
 
 // convert fp16 to bfp16 via fp32 with RTN if higher precision is needed
diff --git a/test/data_type/test_bhalf.cpp b/test/data_type/test_bhalf.cpp
index cadd8c70cf..ad31e194b8 100644
--- a/test/data_type/test_bhalf.cpp
+++ b/test/data_type/test_bhalf.cpp
@@ -2,8 +2,12 @@
 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
 
 #include "gtest/gtest.h"
+
+#include <cmath>
+
 #include "ck/utility/data_type.hpp"
 #include "ck/utility/type_convert.hpp"
+#include "ck/host_utility/hip_check_error.hpp"
 
 using ck::bhalf_t;
 using ck::type_convert;
@@ -46,3 +50,63 @@ TEST(BHALF_T, MantisaExpOverflow)
     ASSERT_TRUE(std::isnan(float_val));
     ASSERT_TRUE(std::isnan(type_convert<float>(type_convert<bhalf_t>(float_val))));
 }
+
+// Device kernel, launched with a single thread: converts `input` to bhalf_t and
+// back to float, storing the round-tripped value in `*output`.
+__global__ void cast(const float input, float* output)
+{
+    const bhalf_t bhalf_val = type_convert<bhalf_t>(input);
+    *output                 = type_convert<float>(bhalf_val);
+}
+
+// Round-trips +/- each reference value through bhalf_t on the device and checks
+// that the result stays within abs_tol of the input.
+TEST(BHALF_T, CastOnDevice)
+{
+    constexpr int num_vals     = 11;
+    const float abs_tol        = std::pow(2, -7);
+    float float_vals[num_vals] = {0.5, 0.875, 1.5, 1, 2, 4, 8, 16, 32, 64, 128};
+
+    float* float_val_after_cast_dev = nullptr;
+    hip_check_error(hipMalloc(&float_val_after_cast_dev, sizeof(float)));
+
+    // Runs the kernel for one input and returns the value read back from the device.
+    const auto cast_on_device = [&](float input) {
+        float float_val_after_cast_host = 0.f;
+
+        cast<<<1, 1>>>(input, float_val_after_cast_dev);
+        // A kernel launch returns no status; query it explicitly so a failed
+        // launch is reported instead of comparing a stale device value.
+        hip_check_error(hipGetLastError());
+
+        // Blocking copy on the default stream, issued after the kernel launch.
+        hip_check_error(hipMemcpy(&float_val_after_cast_host,
+                                  float_val_after_cast_dev,
+                                  sizeof(float),
+                                  hipMemcpyDeviceToHost));
+        return float_val_after_cast_host;
+    };
+
+    // Collect every result before asserting: ASSERT_* returns from the test body
+    // on failure, which would otherwise skip the hipFree below.
+    float positive_results[num_vals];
+    float negative_results[num_vals];
+    for(int idx = 0; idx < num_vals; idx++)
+    {
+        positive_results[idx] = cast_on_device(float_vals[idx]);
+        negative_results[idx] = cast_on_device(-float_vals[idx]);
+    }
+
+    hip_check_error(hipFree(float_val_after_cast_dev));
+
+    // Positive
+    for(int idx = 0; idx < num_vals; idx++)
+    {
+        ASSERT_NEAR(positive_results[idx], float_vals[idx], abs_tol);
+    }
+    // Negative
+    for(int idx = 0; idx < num_vals; idx++)
+    {
+        ASSERT_NEAR(negative_results[idx], -float_vals[idx], abs_tol);
+    }
+}