Add perf test. Fix packed bf16 cast implementation.

This commit is contained in:
Ville Pietilä
2025-08-18 11:37:14 +00:00
parent 6b2b5e7c7c
commit b48ae7e447
2 changed files with 37 additions and 71 deletions

View File

@@ -52,20 +52,6 @@ inline __device__ bhalf_t static_cast_float_to_bf16(float x)
}
#endif
#if defined(__gfx950__)
// Converts two fp32 values to a packed bhalf2_t with a single v_cvt_pk_bf16_f32
// instruction issued through inline assembly.
//
// x, y : the two fp32 inputs, passed as the instruction's two source operands.
// Returns the instruction's 32-bit result reinterpreted as bhalf2_t.
//
// NOTE(review): the "_rne" suffix suggests round-to-nearest-even, and which half
// of the result holds x versus y is determined by the instruction itself; neither
// is visible in this function - confirm against the ISA documentation.
inline __device__ bhalf2_t static_cast_float_x2_to_bhalf2_rne(float x, float y)
{
    // The asm statement has one output operand, bound to the 32-bit integer member;
    // the same storage is then read back through the bhalf2_t member.
    union
    {
        uint32_t raw_bits;
        bhalf2_t as_bf16_pair;
    } packed;

    asm volatile("v_cvt_pk_bf16_f32 %0, %1, %2" : "=v"(packed.raw_bits) : "v"(x), "v"(y));

    return packed.as_bf16_pair;
}
#endif
// Declare a template function for conversion of bf16 vector of two values using RNE
template <typename T, typename U>
__host__ __device__ constexpr T bf16x2_convert_rne(U x, U y);
@@ -74,13 +60,8 @@ __host__ __device__ constexpr T bf16x2_convert_rne(U x, U y);
template <typename Y, typename X>
__host__ __device__ constexpr Y bf16_convert_rtn(X x);
// Convert fp32 to bf16 with RTN if higher precision is needed
template <>
inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, float>(float x)
inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn_base(float x)
{
#if defined(__gfx950__) && defined(__HIP_DEVICE_COMPILE__)
return static_cast_float_to_bf16(x);
#else
// NaN check
if(x != x)
{
@@ -97,6 +78,16 @@ inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, float>(fl
constexpr uint32_t rounding_bias = uint32_t((1 << 15) - 1);
return uint16_t((u.int32 + first_bf16_mantisa_bit + rounding_bias) >> 16);
}
// fp32 -> bf16 conversion with RTN: the bf16_convert_rtn specialization for a float
// input and a bhalf_t result.
//
// When compiling device code for gfx950 (both __gfx950__ and __HIP_DEVICE_COMPILE__
// defined) the value is converted by static_cast_float_to_bf16; in every other
// compilation the conversion is delegated to bf16_convert_rtn_base.
template <>
inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, float>(float x)
{
#if !defined(__gfx950__) || !defined(__HIP_DEVICE_COMPILE__)
    return bf16_convert_rtn_base(x);
#else
    return static_cast_float_to_bf16(x);
#endif
}
@@ -111,15 +102,8 @@ inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, float>(fl
// Specialization of bf16x2_convert_rne that converts two fp32 values (x, y) into a
// bhalf2_t.
//
// As written, when __gfx950__ is defined the result comes from
// static_cast_float_x2_to_bhalf2_rne(x, y); otherwise each value is converted
// separately with bf16_convert_rtn<bhalf_t> and the pair is returned.
//
// NOTE(review): both preprocessor branches return, so the final return statement
// after #endif can never execute. It looks like two alternative bodies of this
// function ended up side by side - confirm which one is intended and drop the other.
template<>
inline __host__ __device__ constexpr bhalf2_t bf16x2_convert_rne<bhalf2_t, float>(float x, float y)
{
#if defined(__gfx950__)
return static_cast_float_x2_to_bhalf2_rne(x, y);
#else
// TODO: Perform real RNE conversion for bfloat16
return {
bf16_convert_rtn<bhalf_t>(x),
bf16_convert_rtn<bhalf_t>(y)
};
#endif
// Author's note: for gfx950, the compiler will use device instruction v_cvt_pk_bf16_f32
// to execute the packed cast (not verifiable from this block alone).
return {bf16_convert_rtn<bhalf_t>(x), bf16_convert_rtn<bhalf_t>(y)};
}
// Convert fp16 to bf16 via fp32 with RTN if higher precision is needed

View File

@@ -9,6 +9,7 @@
#include "ck/utility/type_convert.hpp"
#include "ck/host_utility/hip_check_error.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/stream_config.hpp"
using ck::bhalf_t;
@@ -76,8 +77,7 @@ __global__ void cast(const float input, float* output)
// Selects which conversion path test_performance_kernel is instantiated with.
//
// NOTE(review): PackedV1/PackedV2 and Packed are all referenced by code in this
// file, so all are kept. Packed shares the value 1 with PackedV1 (duplicate
// enumerator values are legal); confirm whether the V1/V2 names should be retired.
enum struct CastMode : int
{
    Standard = 0, // element-wise conversion of each value
    PackedV1 = 1, // packed conversion, results stored element by element
    PackedV2 = 2, // packed conversion, result stored as one bhalf2_t
    Packed   = 1  // packed conversion, result stored as one bhalf2_t
};
template <CastMode PackedCast, int NumElements>
@@ -99,21 +99,15 @@ __global__ void test_performance_kernel(float* input, ck::bhalf_t* output)
{
int index = (i + j) % NumElements;
index = index < NumElements - 1 ? index : NumElements - 2;
if constexpr (PackedCast == CastMode::PackedV1)
{
ck::bhalf2_t packed_value = ck::bf16x2_convert_rne<ck::bhalf2_t, float>(buffer_float[i], buffer_float[j]);
buffer_bf16[index] = packed_value[0];
buffer_bf16[index + 1] = packed_value[1];
}
else if constexpr (PackedCast == CastMode::PackedV2)
if constexpr (PackedCast == CastMode::Packed)
{
ck::bhalf2_t* buffer_range = reinterpret_cast<ck::bhalf2_t*>(&buffer_bf16[index]);
*buffer_range = ck::bf16x2_convert_rne<ck::bhalf2_t, float>(buffer_float[i], buffer_float[j]);
}
else
{
buffer_bf16[index] = type_convert<ck::bhalf_t>(buffer_float[i]);
buffer_bf16[index + 1] = type_convert<ck::bhalf_t>(buffer_float[j]);
buffer_bf16[index] = ck::bf16_convert_rtn_base(buffer_float[i]);
buffer_bf16[index + 1] = ck::bf16_convert_rtn_base(buffer_float[j]);
}
}
}
@@ -139,7 +133,7 @@ void run_performance_test()
std::vector<float> input_host(NumElements);
for (int i = 0; i < NumElements; i++)
{
input_host[i] = static_cast<float>(i);
input_host[i] = 3.14f * static_cast<float>(i) - 1.7f;
}
hip_check_error(hipMemcpy(input_dev, input_host.data(), sizeof(float) * NumElements, hipMemcpyHostToDevice));
@@ -148,56 +142,44 @@ void run_performance_test()
stream_config.time_kernel_ = true;
auto baseline_kernel = test_performance_kernel<CastMode::Standard, NumElements>;
auto packed_kernel_v1 = test_performance_kernel<CastMode::PackedV1, NumElements>;
auto packed_kernel_v2 = test_performance_kernel<CastMode::PackedV2, NumElements>;
auto packed_kernel = test_performance_kernel<CastMode::Packed, NumElements>;
constexpr dim3 grid_size(1);
constexpr dim3 block_size(1);
constexpr size_t shared_mem_size = 0;
const float packed_time_v1 = launch_and_time_kernel(stream_config, packed_kernel_v1, grid_size, block_size, shared_mem_size, input_dev, output_dev);
const float baseline_time = launch_and_time_kernel(stream_config, baseline_kernel, grid_size, block_size, shared_mem_size, input_dev, output_dev);
hip_check_error(hipMemcpy(output_host.data(), output_dev, sizeof(ck::bhalf_t) * NumElements, hipMemcpyDeviceToHost));
const float basline_time = launch_and_time_kernel(stream_config, baseline_kernel, grid_size, block_size, shared_mem_size, input_dev, output_dev);
hip_check_error(hipMemcpy(output_host.data(), output_dev, sizeof(ck::bhalf_t) * NumElements, hipMemcpyDeviceToHost));
const float packed_time_v2 = launch_and_time_kernel(stream_config, packed_kernel_v2, grid_size, block_size, shared_mem_size, input_dev, output_dev);
const float packed_time = launch_and_time_kernel(stream_config, packed_kernel, grid_size, block_size, shared_mem_size, input_dev, output_dev);
hip_check_error(hipMemcpy(output_host.data(), output_dev, sizeof(ck::bhalf_t) * NumElements, hipMemcpyDeviceToHost));
// Cleanup
hip_check_error(hipFree(input_dev));
hip_check_error(hipFree(output_dev));
std::cout << "Packed cast V1 time: " << packed_time_v1 << " ms" << std::endl;
std::cout << "Packed cast V2 time: " << packed_time_v2 << " ms" << std::endl;
std::cout << "Baseline cast time: " << basline_time << " ms" << std::endl;
std::cout << "Packed cast time ( " << NumElements << " elements): " << packed_time << " ms" << std::endl;
std::cout << "Baseline cast time ( " << NumElements << " elements): " << baseline_time << " ms" << std::endl;
// Check if packed cast is faster than baseline
ASSERT_LT(packed_time_v1, basline_time);
ASSERT_LT(packed_time_v2, basline_time);
ASSERT_LT(packed_time, baseline_time);
}
TEST(BHALF_T, Performance_16_elements)
TEST(BHALF_T, Performance)
{
run_performance_test<16>();
if (ck::get_device_name() == "gfx950")
{
run_performance_test<32>();
run_performance_test<64>();
run_performance_test<128>();
}
else
{
GTEST_SKIP() << "Packed cast performance test requires gfx950.";
}
}
// Times the packed and baseline cast kernels over 32 elements.
// run_performance_test asserts that the packed cast is faster than the baseline,
// so the test is skipped on devices other than gfx950, matching the guard used
// by the combined Performance test.
TEST(BHALF_T, Performance_32_elements)
{
    if (ck::get_device_name() != "gfx950")
    {
        GTEST_SKIP() << "Packed cast performance test requires gfx950.";
    }
    run_performance_test<32>();
}
// Times the packed and baseline cast kernels over 64 elements.
// run_performance_test asserts that the packed cast is faster than the baseline,
// so the test is skipped on devices other than gfx950, matching the guard used
// by the combined Performance test.
TEST(BHALF_T, Performance_64_elements)
{
    if (ck::get_device_name() != "gfx950")
    {
        GTEST_SKIP() << "Packed cast performance test requires gfx950.";
    }
    run_performance_test<64>();
}
// Times the packed and baseline cast kernels over 128 elements.
// run_performance_test asserts that the packed cast is faster than the baseline,
// so the test is skipped on devices other than gfx950, matching the guard used
// by the combined Performance test.
TEST(BHALF_T, Performance_128_elements)
{
    if (ck::get_device_name() != "gfx950")
    {
        GTEST_SKIP() << "Packed cast performance test requires gfx950.";
    }
    run_performance_test<128>();
}
TEST(BHALF_T, PackedCast)
TEST(BHALF_T, PackedCastCorrectness)
{
// Test packed cast from bhalf2 to float2
// Use values that are representable in bhalf2 as well as values that are not