From b48ae7e447c9707060ad3e358a3a36ceb01e8034 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= Date: Mon, 18 Aug 2025 11:37:14 +0000 Subject: [PATCH] Add perf test. Fix packed bf16 cast implementation. --- include/ck/utility/type_convert.hpp | 42 ++++++------------ test/data_type/test_bhalf.cpp | 66 +++++++++++------------------ 2 files changed, 37 insertions(+), 71 deletions(-) diff --git a/include/ck/utility/type_convert.hpp b/include/ck/utility/type_convert.hpp index d45adeebc7..0ee53a9011 100644 --- a/include/ck/utility/type_convert.hpp +++ b/include/ck/utility/type_convert.hpp @@ -52,20 +52,6 @@ inline __device__ bhalf_t static_cast_float_to_bf16(float x) } #endif -#if defined(__gfx950__) -inline __device__ bhalf2_t static_cast_float_x2_to_bhalf2_rne(float x, float y) -{ - union { - uint32_t u32; - bhalf2_t bf16x2; - } value; - asm volatile("v_cvt_pk_bf16_f32 %0, %1, %2" - : "=v"(value.u32) - : "v"(x), "v"(y)); - return value.bf16x2; -} -#endif - // Declare a template function for conversion of bf16 vector of two values using RNE template __host__ __device__ constexpr T bf16x2_convert_rne(U x, U y); @@ -74,13 +60,8 @@ __host__ __device__ constexpr T bf16x2_convert_rne(U x, U y); template __host__ __device__ constexpr Y bf16_convert_rtn(X x); -// Convert fp32 to bf16 with RTN if higher precision is needed -template <> -inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn(float x) +inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn_base(float x) { -#if defined(__gfx950__) && defined(__HIP_DEVICE_COMPILE__) - return static_cast_float_to_bf16(x); -#else // Nan check if(x != x) { @@ -97,6 +78,16 @@ inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn(fl constexpr uint32_t rounding_bias = uint32_t((1 << 15) - 1); return uint16_t((u.int32 + first_bf16_mantisa_bit + rounding_bias) >> 16); +} + +// Convert fp32 to bf16 with RTN if higher precision is needed +template <> +inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn(float x) +{ +#if defined(__gfx950__) && defined(__HIP_DEVICE_COMPILE__) + return static_cast_float_to_bf16(x); +#else + return bf16_convert_rtn_base(x); #endif } @@ -111,15 +102,8 @@ inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn(fl template<> inline __host__ __device__ constexpr bhalf2_t bf16x2_convert_rne(float x, float y) { -#if defined(__gfx950__) - return static_cast_float_x2_to_bhalf2_rne(x, y); -#else - // TODO: Perform real RNE conversion for bfloat16 - return { - bf16_convert_rtn(x), - bf16_convert_rtn(y) - }; -#endif + // for gfx950, the compiler will use device instruction v_cvt_pk_bf16_f32 to execute packed cast. + return {bf16_convert_rtn(x), bf16_convert_rtn(y)}; } // convert fp16 to bfp16 via fp32 with RTN if higher precision is needed diff --git a/test/data_type/test_bhalf.cpp b/test/data_type/test_bhalf.cpp index c9d67aa63a..b2fc813d33 100644 --- a/test/data_type/test_bhalf.cpp +++ b/test/data_type/test_bhalf.cpp @@ -9,6 +9,7 @@ #include "ck/utility/type_convert.hpp" #include "ck/host_utility/hip_check_error.hpp" #include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/device_prop.hpp" #include "ck/stream_config.hpp" using ck::bhalf_t; @@ -76,8 +77,7 @@ __global__ void cast(const float input, float* output) enum struct CastMode : int { Standard = 0, - PackedV1 = 1, - PackedV2 = 2 + Packed = 1 }; template @@ -99,21 +99,15 @@ __global__ void test_performance_kernel(float* input, ck::bhalf_t* output) { int index = (i + j) % NumElements; index = index < NumElements - 1 ? index : NumElements - 2; - if constexpr (PackedCast == CastMode::PackedV1) - { - ck::bhalf2_t packed_value = ck::bf16x2_convert_rne(buffer_float[i], buffer_float[j]); - buffer_bf16[index] = packed_value[0]; - buffer_bf16[index + 1] = packed_value[1]; - } - else if constexpr (PackedCast == CastMode::PackedV2) + if constexpr (PackedCast == CastMode::Packed) { ck::bhalf2_t* buffer_range = reinterpret_cast(&buffer_bf16[index]); *buffer_range = ck::bf16x2_convert_rne(buffer_float[i], buffer_float[j]); } else { - buffer_bf16[index] = type_convert(buffer_float[i]); - buffer_bf16[index + 1] = type_convert(buffer_float[j]); + buffer_bf16[index] = ck::bf16_convert_rtn_base(buffer_float[i]); + buffer_bf16[index + 1] = ck::bf16_convert_rtn_base(buffer_float[j]); } } } @@ -139,7 +133,7 @@ void run_performance_test() std::vector input_host(NumElements); for (int i = 0; i < NumElements; i++) { - input_host[i] = static_cast(i); + input_host[i] = 3.14f * static_cast(i) - 1.7f; } hip_check_error(hipMemcpy(input_dev, input_host.data(), sizeof(float) * NumElements, hipMemcpyHostToDevice)); @@ -148,56 +142,44 @@ void run_performance_test() stream_config.time_kernel_ = true; auto baseline_kernel = test_performance_kernel; - auto packed_kernel_v1 = test_performance_kernel; - auto packed_kernel_v2 = test_performance_kernel; + auto packed_kernel = test_performance_kernel; constexpr dim3 grid_size(1); constexpr dim3 block_size(1); constexpr size_t shared_mem_size = 0; - const float packed_time_v1 = launch_and_time_kernel(stream_config, packed_kernel_v1, grid_size, block_size, shared_mem_size, input_dev, output_dev); + const float baseline_time = launch_and_time_kernel(stream_config, baseline_kernel, grid_size, block_size, shared_mem_size, input_dev, output_dev); hip_check_error(hipMemcpy(output_host.data(), output_dev, sizeof(ck::bhalf_t) * NumElements, hipMemcpyDeviceToHost)); - const float basline_time = launch_and_time_kernel(stream_config, baseline_kernel, grid_size, block_size, shared_mem_size, input_dev, output_dev); - hip_check_error(hipMemcpy(output_host.data(), output_dev, sizeof(ck::bhalf_t) * NumElements, hipMemcpyDeviceToHost)); - - const float packed_time_v2 = launch_and_time_kernel(stream_config, packed_kernel_v2, grid_size, block_size, shared_mem_size, input_dev, output_dev); + const float packed_time = launch_and_time_kernel(stream_config, packed_kernel, grid_size, block_size, shared_mem_size, input_dev, output_dev); hip_check_error(hipMemcpy(output_host.data(), output_dev, sizeof(ck::bhalf_t) * NumElements, hipMemcpyDeviceToHost)); // Cleanup hip_check_error(hipFree(input_dev)); hip_check_error(hipFree(output_dev)); - std::cout << "Packed cast V1 time: " << packed_time_v1 << " ms" << std::endl; - std::cout << "Packed cast V2 time: " << packed_time_v2 << " ms" << std::endl; - std::cout << "Baseline cast time: " << basline_time << " ms" << std::endl; + std::cout << "Packed cast time ( " << NumElements << " elements): " << packed_time << " ms" << std::endl; + std::cout << "Baseline cast time ( " << NumElements << " elements): " << baseline_time << " ms" << std::endl; // Check if packed cast is faster than baseline - ASSERT_LT(packed_time_v1, basline_time); - ASSERT_LT(packed_time_v2, basline_time); + ASSERT_LT(packed_time, baseline_time); } -TEST(BHALF_T, Performance_16_elements) +TEST(BHALF_T, Performance) { - run_performance_test<16>(); + if (ck::get_device_name() == "gfx950") + { + run_performance_test<32>(); + run_performance_test<64>(); + run_performance_test<128>(); + } + else + { + GTEST_SKIP() << "Packed cast performance test requires gfx950."; + } } -TEST(BHALF_T, Performance_32_elements) -{ - run_performance_test<32>(); -} - -TEST(BHALF_T, Performance_64_elements) -{ - run_performance_test<64>(); -} - -TEST(BHALF_T, Performance_128_elements) -{ - run_performance_test<128>(); -} - -TEST(BHALF_T, PackedCast) +TEST(BHALF_T, PackedCastCorrectness) { // Test packed cast from bhalf2 to float2 // Use values that are representable in bhalf2 as well as values that are not