mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-03 13:48:30 +00:00
Add more tests for packed cast.
This commit is contained in:
@@ -8,6 +8,8 @@
|
||||
#include "ck/utility/data_type.hpp"
|
||||
#include "ck/utility/type_convert.hpp"
|
||||
#include "ck/host_utility/hip_check_error.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/stream_config.hpp"
|
||||
|
||||
using ck::bhalf_t;
|
||||
using ck::type_convert;
|
||||
@@ -71,6 +73,130 @@ __global__ void cast(const float input, float* output)
|
||||
*output = type_convert<float>(bhalf_val);
|
||||
}
|
||||
|
||||
enum struct CastMode : int
|
||||
{
|
||||
Standard = 0,
|
||||
PackedV1 = 1,
|
||||
PackedV2 = 2
|
||||
};
|
||||
|
||||
template <CastMode PackedCast, int NumElements>
|
||||
__global__ void test_performance_kernel(float* input, ck::bhalf_t* output)
|
||||
{
|
||||
ck::bhalf_t buffer_bf16[NumElements];
|
||||
float buffer_float[NumElements];
|
||||
|
||||
// Initialize input data
|
||||
for(int i = 0; i < NumElements; i++)
|
||||
{
|
||||
buffer_float[i] = input[i];
|
||||
}
|
||||
|
||||
// Do enough work to offset kernel launch overhead and memory transfers.
|
||||
for(int i = 0; i < NumElements; i++)
|
||||
{
|
||||
for (int j = 0; j < NumElements; j++)
|
||||
{
|
||||
int index = (i + j) % NumElements;
|
||||
index = index < NumElements - 1 ? index : NumElements - 2;
|
||||
if constexpr (PackedCast == CastMode::PackedV1)
|
||||
{
|
||||
ck::bhalf2_t packed_value = ck::bf16x2_convert_rne<ck::bhalf2_t, float>(buffer_float[i], buffer_float[j]);
|
||||
buffer_bf16[index] = packed_value[0];
|
||||
buffer_bf16[index + 1] = packed_value[1];
|
||||
}
|
||||
else if constexpr (PackedCast == CastMode::PackedV2)
|
||||
{
|
||||
ck::bhalf2_t* buffer_range = reinterpret_cast<ck::bhalf2_t*>(&buffer_bf16[index]);
|
||||
*buffer_range = ck::bf16x2_convert_rne<ck::bhalf2_t, float>(buffer_float[i], buffer_float[j]);
|
||||
}
|
||||
else
|
||||
{
|
||||
buffer_bf16[index] = type_convert<ck::bhalf_t>(buffer_float[i]);
|
||||
buffer_bf16[index + 1] = type_convert<ck::bhalf_t>(buffer_float[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Copy results back to output
|
||||
for(int i = 0; i < NumElements; i++)
|
||||
{
|
||||
output[i] = buffer_bf16[i];
|
||||
}
|
||||
}
|
||||
|
||||
template <int NumElements>
|
||||
void run_performance_test()
|
||||
{
|
||||
float* input_dev;
|
||||
ck::bhalf_t* output_dev;
|
||||
std::vector<ck::bhalf_t> output_host(NumElements);
|
||||
|
||||
hip_check_error(hipMalloc(&input_dev, sizeof(float) * NumElements));
|
||||
hip_check_error(hipMalloc(&output_dev, sizeof(ck::bhalf_t) * NumElements));
|
||||
|
||||
// Initialize input data on the device
|
||||
std::vector<float> input_host(NumElements);
|
||||
for (int i = 0; i < NumElements; i++)
|
||||
{
|
||||
input_host[i] = static_cast<float>(i);
|
||||
}
|
||||
|
||||
hip_check_error(hipMemcpy(input_dev, input_host.data(), sizeof(float) * NumElements, hipMemcpyHostToDevice));
|
||||
|
||||
StreamConfig stream_config;
|
||||
stream_config.time_kernel_ = true;
|
||||
|
||||
auto baseline_kernel = test_performance_kernel<CastMode::Standard, NumElements>;
|
||||
auto packed_kernel_v1 = test_performance_kernel<CastMode::PackedV1, NumElements>;
|
||||
auto packed_kernel_v2 = test_performance_kernel<CastMode::PackedV2, NumElements>;
|
||||
|
||||
constexpr dim3 grid_size(1);
|
||||
constexpr dim3 block_size(1);
|
||||
constexpr size_t shared_mem_size = 0;
|
||||
|
||||
const float packed_time_v1 = launch_and_time_kernel(stream_config, packed_kernel_v1, grid_size, block_size, shared_mem_size, input_dev, output_dev);
|
||||
hip_check_error(hipMemcpy(output_host.data(), output_dev, sizeof(ck::bhalf_t) * NumElements, hipMemcpyDeviceToHost));
|
||||
|
||||
const float basline_time = launch_and_time_kernel(stream_config, baseline_kernel, grid_size, block_size, shared_mem_size, input_dev, output_dev);
|
||||
hip_check_error(hipMemcpy(output_host.data(), output_dev, sizeof(ck::bhalf_t) * NumElements, hipMemcpyDeviceToHost));
|
||||
|
||||
const float packed_time_v2 = launch_and_time_kernel(stream_config, packed_kernel_v2, grid_size, block_size, shared_mem_size, input_dev, output_dev);
|
||||
hip_check_error(hipMemcpy(output_host.data(), output_dev, sizeof(ck::bhalf_t) * NumElements, hipMemcpyDeviceToHost));
|
||||
|
||||
// Cleanup
|
||||
hip_check_error(hipFree(input_dev));
|
||||
hip_check_error(hipFree(output_dev));
|
||||
|
||||
std::cout << "Packed cast V1 time: " << packed_time_v1 << " ms" << std::endl;
|
||||
std::cout << "Packed cast V2 time: " << packed_time_v2 << " ms" << std::endl;
|
||||
std::cout << "Baseline cast time: " << basline_time << " ms" << std::endl;
|
||||
|
||||
// Check if packed cast is faster than baseline
|
||||
ASSERT_LT(packed_time_v1, basline_time);
|
||||
ASSERT_LT(packed_time_v2, basline_time);
|
||||
}
|
||||
|
||||
TEST(BHALF_T, Performance_16_elements)
|
||||
{
|
||||
run_performance_test<16>();
|
||||
}
|
||||
|
||||
TEST(BHALF_T, Performance_32_elements)
|
||||
{
|
||||
run_performance_test<32>();
|
||||
}
|
||||
|
||||
TEST(BHALF_T, Performance_64_elements)
|
||||
{
|
||||
run_performance_test<64>();
|
||||
}
|
||||
|
||||
TEST(BHALF_T, Performance_128_elements)
|
||||
{
|
||||
run_performance_test<128>();
|
||||
}
|
||||
|
||||
TEST(BHALF_T, PackedCast)
|
||||
{
|
||||
// Test packed cast from bhalf2 to float2
|
||||
|
||||
@@ -235,8 +235,7 @@ class TestGroupedConvndBwdWeight2d_bf16_gfx950 : public TestGroupedConvndBwdWeig
|
||||
};
|
||||
|
||||
using KernelTypes2d_bf16_gfx950 = ::testing::Types<
|
||||
// This layout does not yet work.
|
||||
//std::tuple<ck::bhalf_t, ck::bhalf_t, ck::bhalf_t, NHWGC, GKYXC, NHWGK, ck::Number<2>>,
|
||||
std::tuple<ck::bhalf_t, ck::bhalf_t, ck::bhalf_t, NHWGC, GKYXC, NHWGK, ck::Number<2>>,
|
||||
std::tuple<ck::bhalf_t, ck::bhalf_t, ck::bhalf_t, NGCHW, GKYXC, NGKHW, ck::Number<2>>,
|
||||
std::tuple<ck::bhalf_t, ck::bhalf_t, ck::bhalf_t, NGCHW, GKCYX, NGKHW, ck::Number<2>>>;
|
||||
|
||||
@@ -249,12 +248,13 @@ TYPED_TEST(TestGroupedConvndBwdWeight2d_bf16_gfx950, Test2D)
|
||||
|
||||
// n_dim group_count n_batch n_out_channels n_in_channels filter_size input_size strides dilations left_pads right_pads
|
||||
this->conv_params.push_back({2, 32, 64, 4, 4, {1, 1}, {7, 7}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
|
||||
this->conv_params.push_back({2, 4, 64, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}});
|
||||
this->conv_params.push_back({2, 2, 64, 3, 3, {1, 1}, {7, 7}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
|
||||
this->conv_params.push_back({2, 4, 64, 128, 256, {2, 2}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}});
|
||||
this->conv_params.push_back({2, 2, 64, 3, 3, {2, 2}, {7, 7}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
|
||||
this->conv_params.push_back({2, 2, 64, 5, 5, {1, 1}, {7, 7}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
|
||||
this->conv_params.push_back({2, 2, 4, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
|
||||
this->conv_params.push_back({2, 2, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
|
||||
this->conv_params.push_back({2, 1, 1, 1, 32, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
|
||||
this->conv_params.push_back({2, 3, 2, 8, 16, {4, 4}, {64, 64}, {2, 2}, {1, 1}, {1, 1}, {1, 1}});
|
||||
|
||||
this->Run();
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user