Add more tests for packed cast.

This commit is contained in:
Ville Pietilä
2025-08-18 06:39:04 +00:00
parent c0b8f66674
commit d7c681f2f2
2 changed files with 130 additions and 4 deletions

View File

@@ -8,6 +8,8 @@
#include "ck/utility/data_type.hpp"
#include "ck/utility/type_convert.hpp"
#include "ck/host_utility/hip_check_error.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/stream_config.hpp"
using ck::bhalf_t;
using ck::type_convert;
@@ -71,6 +73,130 @@ __global__ void cast(const float input, float* output)
*output = type_convert<float>(bhalf_val);
}
enum struct CastMode : int
{
Standard = 0,
PackedV1 = 1,
PackedV2 = 2
};
template <CastMode PackedCast, int NumElements>
__global__ void test_performance_kernel(float* input, ck::bhalf_t* output)
{
ck::bhalf_t buffer_bf16[NumElements];
float buffer_float[NumElements];
// Initialize input data
for(int i = 0; i < NumElements; i++)
{
buffer_float[i] = input[i];
}
// Do enough work to offset kernel launch overhead and memory transfers.
for(int i = 0; i < NumElements; i++)
{
for (int j = 0; j < NumElements; j++)
{
int index = (i + j) % NumElements;
index = index < NumElements - 1 ? index : NumElements - 2;
if constexpr (PackedCast == CastMode::PackedV1)
{
ck::bhalf2_t packed_value = ck::bf16x2_convert_rne<ck::bhalf2_t, float>(buffer_float[i], buffer_float[j]);
buffer_bf16[index] = packed_value[0];
buffer_bf16[index + 1] = packed_value[1];
}
else if constexpr (PackedCast == CastMode::PackedV2)
{
ck::bhalf2_t* buffer_range = reinterpret_cast<ck::bhalf2_t*>(&buffer_bf16[index]);
*buffer_range = ck::bf16x2_convert_rne<ck::bhalf2_t, float>(buffer_float[i], buffer_float[j]);
}
else
{
buffer_bf16[index] = type_convert<ck::bhalf_t>(buffer_float[i]);
buffer_bf16[index + 1] = type_convert<ck::bhalf_t>(buffer_float[j]);
}
}
}
// Copy results back to output
for(int i = 0; i < NumElements; i++)
{
output[i] = buffer_bf16[i];
}
}
template <int NumElements>
void run_performance_test()
{
float* input_dev;
ck::bhalf_t* output_dev;
std::vector<ck::bhalf_t> output_host(NumElements);
hip_check_error(hipMalloc(&input_dev, sizeof(float) * NumElements));
hip_check_error(hipMalloc(&output_dev, sizeof(ck::bhalf_t) * NumElements));
// Initialize input data on the device
std::vector<float> input_host(NumElements);
for (int i = 0; i < NumElements; i++)
{
input_host[i] = static_cast<float>(i);
}
hip_check_error(hipMemcpy(input_dev, input_host.data(), sizeof(float) * NumElements, hipMemcpyHostToDevice));
StreamConfig stream_config;
stream_config.time_kernel_ = true;
auto baseline_kernel = test_performance_kernel<CastMode::Standard, NumElements>;
auto packed_kernel_v1 = test_performance_kernel<CastMode::PackedV1, NumElements>;
auto packed_kernel_v2 = test_performance_kernel<CastMode::PackedV2, NumElements>;
constexpr dim3 grid_size(1);
constexpr dim3 block_size(1);
constexpr size_t shared_mem_size = 0;
const float packed_time_v1 = launch_and_time_kernel(stream_config, packed_kernel_v1, grid_size, block_size, shared_mem_size, input_dev, output_dev);
hip_check_error(hipMemcpy(output_host.data(), output_dev, sizeof(ck::bhalf_t) * NumElements, hipMemcpyDeviceToHost));
const float basline_time = launch_and_time_kernel(stream_config, baseline_kernel, grid_size, block_size, shared_mem_size, input_dev, output_dev);
hip_check_error(hipMemcpy(output_host.data(), output_dev, sizeof(ck::bhalf_t) * NumElements, hipMemcpyDeviceToHost));
const float packed_time_v2 = launch_and_time_kernel(stream_config, packed_kernel_v2, grid_size, block_size, shared_mem_size, input_dev, output_dev);
hip_check_error(hipMemcpy(output_host.data(), output_dev, sizeof(ck::bhalf_t) * NumElements, hipMemcpyDeviceToHost));
// Cleanup
hip_check_error(hipFree(input_dev));
hip_check_error(hipFree(output_dev));
std::cout << "Packed cast V1 time: " << packed_time_v1 << " ms" << std::endl;
std::cout << "Packed cast V2 time: " << packed_time_v2 << " ms" << std::endl;
std::cout << "Baseline cast time: " << basline_time << " ms" << std::endl;
// Check if packed cast is faster than baseline
ASSERT_LT(packed_time_v1, basline_time);
ASSERT_LT(packed_time_v2, basline_time);
}
TEST(BHALF_T, Performance_16_elements)
{
run_performance_test<16>();
}
TEST(BHALF_T, Performance_32_elements)
{
run_performance_test<32>();
}
TEST(BHALF_T, Performance_64_elements)
{
run_performance_test<64>();
}
TEST(BHALF_T, Performance_128_elements)
{
run_performance_test<128>();
}
TEST(BHALF_T, PackedCast)
{
// Test packed cast from bhalf2 to float2

View File

@@ -235,8 +235,7 @@ class TestGroupedConvndBwdWeight2d_bf16_gfx950 : public TestGroupedConvndBwdWeig
};
using KernelTypes2d_bf16_gfx950 = ::testing::Types<
// This layout does not yet work.
//std::tuple<ck::bhalf_t, ck::bhalf_t, ck::bhalf_t, NHWGC, GKYXC, NHWGK, ck::Number<2>>,
std::tuple<ck::bhalf_t, ck::bhalf_t, ck::bhalf_t, NHWGC, GKYXC, NHWGK, ck::Number<2>>,
std::tuple<ck::bhalf_t, ck::bhalf_t, ck::bhalf_t, NGCHW, GKYXC, NGKHW, ck::Number<2>>,
std::tuple<ck::bhalf_t, ck::bhalf_t, ck::bhalf_t, NGCHW, GKCYX, NGKHW, ck::Number<2>>>;
@@ -249,12 +248,13 @@ TYPED_TEST(TestGroupedConvndBwdWeight2d_bf16_gfx950, Test2D)
// n_dim group_count n_batch n_out_channels n_in_channels filter_size input_size strides dilations left_pads right_pads
this->conv_params.push_back({2, 32, 64, 4, 4, {1, 1}, {7, 7}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back({2, 4, 64, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back({2, 2, 64, 3, 3, {1, 1}, {7, 7}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back({2, 4, 64, 128, 256, {2, 2}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back({2, 2, 64, 3, 3, {2, 2}, {7, 7}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back({2, 2, 64, 5, 5, {1, 1}, {7, 7}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back({2, 2, 4, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->conv_params.push_back({2, 2, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back({2, 1, 1, 1, 32, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->conv_params.push_back({2, 3, 2, 8, 16, {4, 4}, {64, 64}, {2, 2}, {1, 1}, {1, 1}, {1, 1}});
this->Run();
}