From d7c681f2f2a0dcae85b2b200aa704afdafa1a860 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ville=20Pietil=C3=A4?=
Date: Mon, 18 Aug 2025 06:39:04 +0000
Subject: [PATCH] Add more tests for packed cast.

---
 test/data_type/test_bhalf.cpp                 | 154 ++++++++++++++++++
 .../test_grouped_convnd_bwd_weight.cpp        |   8 +-
 2 files changed, 158 insertions(+), 4 deletions(-)

diff --git a/test/data_type/test_bhalf.cpp b/test/data_type/test_bhalf.cpp
index d5707f2f1e..c9d67aa63a 100644
--- a/test/data_type/test_bhalf.cpp
+++ b/test/data_type/test_bhalf.cpp
@@ -8,6 +8,8 @@
 #include "ck/utility/data_type.hpp"
 #include "ck/utility/type_convert.hpp"
 #include "ck/host_utility/hip_check_error.hpp"
+#include "ck/host_utility/kernel_launch.hpp"
+#include "ck/stream_config.hpp"
 
 using ck::bhalf_t;
 using ck::type_convert;
@@ -71,6 +73,158 @@ __global__ void cast(const float input, float* output)
     *output = type_convert<float>(bhalf_val);
 }
 
+/// Selects the float -> bf16 conversion path exercised by test_performance_kernel.
+enum struct CastMode : int
+{
+    Standard = 0, ///< two scalar type_convert calls
+    PackedV1 = 1, ///< bf16x2_convert_rne, result lanes stored one element at a time
+    PackedV2 = 2  ///< bf16x2_convert_rne, result stored through one bhalf2_t write
+};
+
+/// Converts pairs of floats to bf16 NumElements * NumElements times and then copies
+/// the NumElements resulting bf16 values to output.
+///
+/// @tparam NumElements number of values read from input and written to output (even)
+/// @tparam PackedCast  conversion path to run
+/// @param input  NumElements floats to convert
+/// @param output receives NumElements bf16 values
+template <int NumElements, CastMode PackedCast>
+__global__ void test_performance_kernel(float* input, ck::bhalf_t* output)
+{
+    static_assert(NumElements >= 2 && NumElements % 2 == 0,
+                  "NumElements must be even because results are written in pairs");
+
+    // Aligned for bhalf2_t because the PackedV2 path writes through a bhalf2_t pointer.
+    alignas(ck::bhalf2_t) ck::bhalf_t buffer_bf16[NumElements];
+    float buffer_float[NumElements];
+
+    // Initialize input data
+    for(int i = 0; i < NumElements; i++)
+    {
+        buffer_float[i] = input[i];
+    }
+
+    // Do enough work to offset kernel launch overhead and memory transfers.
+    for(int i = 0; i < NumElements; i++)
+    {
+        for(int j = 0; j < NumElements; j++)
+        {
+            // Round the slot down to an even element so that index + 1 stays in bounds
+            // and &buffer_bf16[index] is suitably aligned for the bhalf2_t store below.
+            const int index = ((i + j) % NumElements) & ~1;
+            if constexpr(PackedCast == CastMode::PackedV1)
+            {
+                ck::bhalf2_t packed_value =
+                    ck::bf16x2_convert_rne(buffer_float[i], buffer_float[j]);
+                buffer_bf16[index]     = packed_value[0];
+                buffer_bf16[index + 1] = packed_value[1];
+            }
+            else if constexpr(PackedCast == CastMode::PackedV2)
+            {
+                ck::bhalf2_t* buffer_range =
+                    reinterpret_cast<ck::bhalf2_t*>(&buffer_bf16[index]);
+                *buffer_range = ck::bf16x2_convert_rne(buffer_float[i], buffer_float[j]);
+            }
+            else
+            {
+                buffer_bf16[index]     = type_convert<ck::bhalf_t>(buffer_float[i]);
+                buffer_bf16[index + 1] = type_convert<ck::bhalf_t>(buffer_float[j]);
+            }
+        }
+    }
+
+    // Copy results back to output
+    for(int i = 0; i < NumElements; i++)
+    {
+        output[i] = buffer_bf16[i];
+    }
+}
+
+/// Times test_performance_kernel once per CastMode on NumElements inputs, prints the
+/// three timings, and expects both packed variants to be faster than the baseline.
+///
+/// @tparam NumElements number of floats converted by each kernel (even)
+template <int NumElements>
+void run_performance_test()
+{
+    float* input_dev;
+    ck::bhalf_t* output_dev;
+    std::vector<ck::bhalf_t> output_host(NumElements);
+
+    hip_check_error(hipMalloc(&input_dev, sizeof(float) * NumElements));
+    hip_check_error(hipMalloc(&output_dev, sizeof(ck::bhalf_t) * NumElements));
+
+    // Build the input on the host, then upload it to the device
+    std::vector<float> input_host(NumElements);
+    for(int i = 0; i < NumElements; i++)
+    {
+        input_host[i] = static_cast<float>(i);
+    }
+
+    hip_check_error(hipMemcpy(
+        input_dev, input_host.data(), sizeof(float) * NumElements, hipMemcpyHostToDevice));
+
+    StreamConfig stream_config;
+    stream_config.time_kernel_ = true;
+
+    auto baseline_kernel  = test_performance_kernel<NumElements, CastMode::Standard>;
+    auto packed_kernel_v1 = test_performance_kernel<NumElements, CastMode::PackedV1>;
+    auto packed_kernel_v2 = test_performance_kernel<NumElements, CastMode::PackedV2>;
+
+    // Each kernel runs as one block containing one thread
+    constexpr dim3 grid_size(1);
+    constexpr dim3 block_size(1);
+    constexpr size_t shared_mem_size = 0;
+
+    // Times one kernel variant and copies its output back to the host afterwards
+    const auto time_variant = [&](auto kernel) {
+        const float elapsed_ms = launch_and_time_kernel(
+            stream_config, kernel, grid_size, block_size, shared_mem_size, input_dev, output_dev);
+        hip_check_error(hipMemcpy(output_host.data(),
+                                  output_dev,
+                                  sizeof(ck::bhalf_t) * NumElements,
+                                  hipMemcpyDeviceToHost));
+        return elapsed_ms;
+    };
+
+    const float packed_time_v1 = time_variant(packed_kernel_v1);
+    const float baseline_time  = time_variant(baseline_kernel);
+    const float packed_time_v2 = time_variant(packed_kernel_v2);
+
+    // Cleanup
+    hip_check_error(hipFree(input_dev));
+    hip_check_error(hipFree(output_dev));
+
+    std::cout << "Packed cast V1 time: " << packed_time_v1 << " ms" << '\n';
+    std::cout << "Packed cast V2 time: " << packed_time_v2 << " ms" << '\n';
+    std::cout << "Baseline cast time: " << baseline_time << " ms" << std::endl;
+
+    // Check if packed cast is faster than baseline. EXPECT (not ASSERT) so that both
+    // comparisons are reported even when the first one fails.
+    EXPECT_LT(packed_time_v1, baseline_time);
+    EXPECT_LT(packed_time_v2, baseline_time);
+}
+
+TEST(BHALF_T, Performance_16_elements)
+{
+    run_performance_test<16>();
+}
+
+TEST(BHALF_T, Performance_32_elements)
+{
+    run_performance_test<32>();
+}
+
+TEST(BHALF_T, Performance_64_elements)
+{
+    run_performance_test<64>();
+}
+
+TEST(BHALF_T, Performance_128_elements)
+{
+    run_performance_test<128>();
+}
+
 TEST(BHALF_T, PackedCast)
 {
     // Test packed cast from bhalf2 to float2
diff --git a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp
index 3cfcb652c7..56bb69edd4 100644
--- a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp
+++ b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp
@@ -235,8 +235,7 @@ class TestGroupedConvndBwdWeight2d_bf16_gfx950 : public TestGroupedConvndBwdWeig
 };
 
 using KernelTypes2d_bf16_gfx950 = ::testing::Types<
-    // This layout does not yet work.
-    //std::tuple>,
+    std::tuple>,
     std::tuple>,
     std::tuple>>;
 
@@ -249,12 +248,13 @@ TYPED_TEST(TestGroupedConvndBwdWeight2d_bf16_gfx950, Test2D)
 
     // n_dim group_count n_batch n_out_channels n_in_channels filter_size input_size strides dilations left_pads right_pads
     this->conv_params.push_back({2, 32, 64, 4, 4, {1, 1}, {7, 7}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
-    this->conv_params.push_back({2, 4, 64, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}});
-    this->conv_params.push_back({2, 2, 64, 3, 3, {1, 1}, {7, 7}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
+    this->conv_params.push_back({2, 4, 64, 128, 256, {2, 2}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}});
+    this->conv_params.push_back({2, 2, 64, 3, 3, {2, 2}, {7, 7}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
     this->conv_params.push_back({2, 2, 64, 5, 5, {1, 1}, {7, 7}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
     this->conv_params.push_back({2, 2, 4, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
     this->conv_params.push_back({2, 2, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
     this->conv_params.push_back({2, 1, 1, 1, 32, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
+    this->conv_params.push_back({2, 3, 2, 8, 16, {4, 4}, {64, 64}, {2, 2}, {1, 1}, {1, 1}, {1, 1}});
     this->Run();
 }
 