From c749ef9da3ea6d95915997215602af662648b854 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Tue, 10 Mar 2026 04:03:04 +0100 Subject: [PATCH] [CK][CK Tile] Add grouped conv backward weight tile test and fix tr load in BASE_V1 pipeline (#5115) ## Motivation Test grouped conv backward weight from ck tile and fix incorrect values. ## Technical Details - Add test for CI - Add daily tests - Fix transpose load in BASE_V1 pipeline ## Test Plan test_grouped_convnd_backward_weight_tile ## Test Result in progress ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. AICK-783 --- Jenkinsfile | 3 +- .../generate_instances.py | 13 + .../gemm_pipeline_agmem_bgmem_creg_v1.hpp | 93 +++---- test/grouped_convnd_bwd_weight/CMakeLists.txt | 11 + .../test_grouped_convnd_bwd_weight_tile.cpp | 237 ++++++++++++++++++ 5 files changed, 312 insertions(+), 45 deletions(-) create mode 100644 test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_tile.cpp diff --git a/Jenkinsfile b/Jenkinsfile index c435c078af..929db4e573 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -1524,7 +1524,8 @@ pipeline { setup_args = "NO_CK_BUILD" execute_args = """ python3 ../experimental/grouped_convolution_tile_instances/generate_instances.py --mode=profiler && \ cmake .. --preset dev-gfx90a -D CK_EXPERIMENTAL_BUILDER=ON && \ - make -j64 test_grouped_convnd_fwd_tile && \ + make -j64 test_grouped_convnd_fwd_tile test_grouped_convnd_bwd_weight_tile && \ + ./bin/test_grouped_convnd_bwd_weight_tile && \ ./bin/test_grouped_convnd_fwd_tile""" } steps{ diff --git a/experimental/grouped_convolution_tile_instances/generate_instances.py b/experimental/grouped_convolution_tile_instances/generate_instances.py index ff48718c89..4bbe12f8c5 100755 --- a/experimental/grouped_convolution_tile_instances/generate_instances.py +++ b/experimental/grouped_convolution_tile_instances/generate_instances.py @@ -288,6 +288,8 @@ def parse_bwd_weight_instances(instances, problem_name): end = instance.rindex('>') params_str = instance[start:end] args = parse_instance_string(params_str) + + direct_load = False is_v3_instance = instance.find("Xdl_CShuffleV3") != -1 is_two_stage_instance = instance.find("TwoStage") != -1 @@ -357,6 +359,7 @@ def parse_bwd_weight_instances(instances, problem_name): if len(args) != 45: raise RuntimeError(f"Wrong number of parameters in the V3 XDL CShuffle instance string: {instance}") + direct_load = int(args[43]) == 1 num_groups_to_merge = int(args[44]) # Block GEMM pipeline parameters @@ -396,6 +399,16 @@ def parse_bwd_weight_instances(instances, problem_name): if pipeline_version == "V5": pipeline_version = "V6" + if direct_load: + if pipeline_version == "V1": + pipeline_version = "ASYNC_V1" + elif pipeline_version == "V4": + pipeline_version = "ASYNC_V4" + else: + raise RuntimeError( + f"Not supported pipeline for direct load: pipeline_version={pipeline_version} in instance: {instance}" + ) + m_warp = int(m_per_block / (m_per_xdl * m_xdl_per_wave)) n_warp = int(n_per_block / (n_per_xdl * n_xdl_per_wave)) warp_size = 64 diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp index 942a496d33..918eb3de26 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp @@ -112,6 +112,9 @@ struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1{}; + static constexpr auto is_b_load_tr_v = bool_constant{}; + static constexpr index_t NumWaveGroups = Problem::NumWaveGroups; static constexpr index_t kLdsAlignmentInBytes = 16; @@ -272,29 +275,27 @@ struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1( Policy::template MakeShuffledARegTileDistribution()); transpose_tile2d(a_shuffle_tmp, elementwise_As_res); - store_tile(a_copy_lds_window, a_shuffle_tmp); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp); } else { - store_tile(a_copy_lds_window, elementwise_As_res); + Base::LocalPrefill(a_copy_lds_window, elementwise_As_res); } - - // LDS write 0 - if constexpr(is_b_row_major) + if constexpr(is_b_row_major && !is_b_load_tr_v()) { auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res); - store_tile(b_copy_lds_window, b_shuffle_tmp); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp); } else { - store_tile(b_copy_lds_window, elementwise_Bs_res); + Base::LocalPrefill(b_copy_lds_window, elementwise_Bs_res); } } @@ -310,7 +311,8 @@ struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1( + auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); - transpose_tile2d(a_shuffle_tmp_loop, elementwise_As_res); - store_tile(a_copy_lds_window, a_shuffle_tmp_loop); + transpose_tile2d(a_shuffle_tmp, elementwise_As_res); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp); } else { - store_tile(a_copy_lds_window, elementwise_As_res); + Base::LocalPrefill(a_copy_lds_window, elementwise_As_res); } - - // LDS write i + 1 - if constexpr(is_b_row_major) + if constexpr(is_b_row_major && !is_b_load_tr_v()) { - auto b_shuffle_tmp_loop = make_static_distributed_tensor( + auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); - transpose_tile2d(b_shuffle_tmp_loop, elementwise_Bs_res); - store_tile(b_copy_lds_window, b_shuffle_tmp_loop); + transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp); } else { - store_tile(b_copy_lds_window, elementwise_Bs_res); + Base::LocalPrefill(b_copy_lds_window, elementwise_Bs_res); } iCounter--; @@ -354,7 +354,8 @@ struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1( Policy::template MakeShuffledARegTileDistribution()); transpose_tile2d(a_shuffle_tmp, elementwise_As_res); - store_tile(a_copy_lds_window, a_shuffle_tmp); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp); } else { - store_tile(a_copy_lds_window, elementwise_As_res); + Base::LocalPrefill(a_copy_lds_window, elementwise_As_res); } - - // LDS write 0 - if constexpr(is_b_row_major) + if constexpr(is_b_row_major && !is_b_load_tr_v()) { auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res); - store_tile(b_copy_lds_window, b_shuffle_tmp); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp); } else { - store_tile(b_copy_lds_window, elementwise_Bs_res); + Base::LocalPrefill(b_copy_lds_window, elementwise_Bs_res); } } @@ -518,36 +517,38 @@ struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1( + auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); - transpose_tile2d(a_shuffle_tmp_loop, elementwise_As_res); - store_tile(a_copy_lds_window, a_shuffle_tmp_loop); + transpose_tile2d(a_shuffle_tmp, elementwise_As_res); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp); } else { - store_tile(a_copy_lds_window, elementwise_As_res); + Base::LocalPrefill(a_copy_lds_window, elementwise_As_res); } - - // LDS write i + 1 - if constexpr(is_b_row_major) + if constexpr(is_b_row_major && !is_b_load_tr_v()) { - auto b_shuffle_tmp_loop = make_static_distributed_tensor( + auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); - transpose_tile2d(b_shuffle_tmp_loop, elementwise_Bs_res); - store_tile(b_copy_lds_window, b_shuffle_tmp_loop); + transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp); } else { - store_tile(b_copy_lds_window, elementwise_Bs_res); + Base::LocalPrefill(b_copy_lds_window, elementwise_Bs_res); } iCounter--; @@ -558,7 +559,11 @@ struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1 +#include +#include +#include +#include + +#include "ck_tile/builder/testing/conv/ck_tile.hpp" +#include "ck_tile/host/device_prop.hpp" +#include "profiler/grouped_convolution_backward_weight_tile_algs.hpp" + +static ck::index_t args_mask = 0xffff; +static ck::index_t instance_index = -1; + +namespace ckb = ck_tile::builder; +namespace ckt = ck_tile::builder::test; +namespace ckp = ck_tile::builder::profiling; + +template +struct SignatureDetails +{ + static constexpr ck_tile::index_t num_spatial_dim = num_spatial_dim_; + static constexpr ckb::DataType data_type = data_type_; + static constexpr ckb::DataType acc_data_type = acc_data_type_; + static constexpr ckb::TensorLayout in_layout = in_layout_; + static constexpr ckb::TensorLayout wei_layout = wei_layout_; + static constexpr ckb::TensorLayout out_layout = out_layout_; +}; + +template +class TestGroupedConvndBwdWeightTile : public ::testing::Test +{ + protected: + static constexpr auto SIGNATURE = + ckt::ConvSignature{.spatial_dim = SignatureDetailsType::num_spatial_dim, + .direction = ckb::ConvDirection::BACKWARD_WEIGHT, + .data_type = SignatureDetailsType::data_type, + .accumulation_data_type = SignatureDetailsType::acc_data_type, + .input = {.config = {.layout = SignatureDetailsType::in_layout}}, + .weight = {.config = {.layout = SignatureDetailsType::wei_layout}}, + .output = {.config = {.layout = SignatureDetailsType::out_layout}}}; + + std::vector> conv_args; + std::vector split_ks{"-1", "1", "2"}; + + template + void Run() + { + ASSERT_FALSE(conv_args.empty()); + bool pass = true; + for(size_t i = 0; i < conv_args.size(); i++) + { + for(auto& split_k : split_ks) + { + if((args_mask & (1 << i)) == 0) + { + continue; + } + auto& args = conv_args[i]; + + auto inputs = alloc_inputs(args); + auto outputs = alloc_outputs(args); + ckt::init_tensor_buffer_uniform_int( + inputs.get().input, args.make_input_descriptor(), -5, 5); + ckt::init_tensor_buffer_uniform_int( + inputs.get().output, args.make_output_descriptor(), -5, 5); + + std::cout << args.make_input_descriptor() << std::endl; + std::cout << args.make_weight_descriptor() << std::endl; + std::cout << args.make_output_descriptor() << std::endl; + [[maybe_unused]] auto&& [case_passed, avg_time, op_name, best_split_k] = + + ckp::run_grouped_conv_backward_weight_tile_algs( + args, + split_k, + inputs.get(), + outputs.get(), + ck_tile::stream_config{nullptr, false /*time_kernel*/}); + + pass = pass && case_passed; + } + } + EXPECT_TRUE(pass); + } + + void conv_args_append(std::size_t, + std::size_t G, + std::size_t N, + std::size_t K, + std::size_t C, + const std::vector& filter_spatial_lengths, + const std::vector& input_spatial_lengths, + const std::vector& conv_filter_strides, + const std::vector& conv_filter_dilations, + const std::vector& input_left_pads, + const std::vector& input_right_pads) + { + ckt::Args args = { + .lengths = + { + .batch_size = N, + .groups = G, + .input_channels = C, + .output_channels = K, + .image = ckt::filter_extent_from_vector( + input_spatial_lengths), + .filter = ckt::filter_extent_from_vector( + filter_spatial_lengths), + }, + .filter_strides = ckt::filter_extent_from_vector( + conv_filter_strides), + .filter_dilation = + ckt::filter_extent_from_vector( + conv_filter_dilations), + .input_left_pad = ckt::filter_extent_from_vector( + input_left_pads), + .input_right_pad = + ckt::filter_extent_from_vector( + input_right_pads), + .a_elementwise_op = {}, + .b_elementwise_op = {}, + .cde_elementwise_op = {}, + }; + conv_args.push_back(args); + } +}; + +using KernelTypes2d = ::testing::Types, + SignatureDetails<2, + ckb::DataType::FP16, + ckb::DataType::FP32, + ckb::TensorLayout::NHWGC, + ckb::TensorLayout::GKYXC, + ckb::TensorLayout::NHWGK>, + SignatureDetails<2, + ckb::DataType::BF16, + ckb::DataType::FP32, + ckb::TensorLayout::NHWGC, + ckb::TensorLayout::GKYXC, + ckb::TensorLayout::NHWGK>>; + +using KernelTypes3d = ::testing::Types, + SignatureDetails<3, + ckb::DataType::FP16, + ckb::DataType::FP32, + ckb::TensorLayout::NDHWGC, + ckb::TensorLayout::GKZYXC, + ckb::TensorLayout::NDHWGK>, + SignatureDetails<3, + ckb::DataType::BF16, + ckb::DataType::FP32, + ckb::TensorLayout::NDHWGC, + ckb::TensorLayout::GKZYXC, + ckb::TensorLayout::NDHWGK>>; + +template +class TestGroupedConvndBwdWeightTile2d : public TestGroupedConvndBwdWeightTile +{ +}; + +template +class TestGroupedConvndBwdWeightTile3d : public TestGroupedConvndBwdWeightTile +{ +}; + +TYPED_TEST_SUITE(TestGroupedConvndBwdWeightTile2d, KernelTypes2d); +TYPED_TEST_SUITE(TestGroupedConvndBwdWeightTile3d, KernelTypes3d); + +TYPED_TEST(TestGroupedConvndBwdWeightTile2d, Test2D) +{ + this->conv_args.clear(); + this->conv_args_append(2, 2, 64, 4, 4, {1, 1}, {7, 7}, {1, 1}, {1, 1}, {0, 0}, {0, 0}); + this->conv_args_append(2, 2, 64, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}); + this->conv_args_append(2, 2, 64, 3, 3, {1, 1}, {7, 7}, {1, 1}, {1, 1}, {0, 0}, {0, 0}); + this->conv_args_append(2, 2, 64, 5, 5, {1, 1}, {7, 7}, {1, 1}, {1, 1}, {0, 0}, {0, 0}); + this->conv_args_append(2, 2, 4, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}); + this->conv_args_append(2, 2, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}); + this->conv_args_append(2, 1, 1, 1, 32, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}); + this->conv_args_append(2, 1, 1, 64, 3, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}); + this->conv_args_append(2, 1, 1, 1, 1, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}); + this->conv_args_append(2, 16, 16, 1, 1, {3, 3}, {28, 28}, {2, 2}, {1, 1}, {1, 1}, {1, 1}); + this->template Run<2>(); +} + +TYPED_TEST(TestGroupedConvndBwdWeightTile3d, Test3D) +{ + this->conv_args.clear(); + this->conv_args_append( + 3, 2, 16, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}); + this->conv_args_append( + 3, 2, 2, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}); + this->conv_args_append( + 3, 2, 32, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}); + this->conv_args_append( + 3, 1, 1, 1, 32, {3, 3, 3}, {16, 16, 16}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}); + this->conv_args_append( + 3, 1, 1, 64, 3, {3, 3, 3}, {14, 14, 14}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}); + this->conv_args_append( + 3, 1, 1, 1, 1, {3, 3, 3}, {18, 18, 18}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}); + this->conv_args_append( + 3, 16, 16, 1, 1, {3, 3, 3}, {28, 28, 28}, {2, 2, 2}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}); + this->template Run<3>(); +} + +int main(int argc, char** argv) +{ + testing::InitGoogleTest(&argc, argv); + if(argc == 1) {} + else if(argc == 3) + { + args_mask = strtol(argv[1], nullptr, 0); + instance_index = atoi(argv[2]); + } + else + { + std::cout << "Usage of " << argv[0] << std::endl; + std::cout << "Arg1,2: args_mask instance_index(-1 means all)" << std::endl; + } + return RUN_ALL_TESTS(); +}