From 07d14c9618c43aba9301f104cd6b81bcf7c5d138 Mon Sep 17 00:00:00 2001 From: "assistant-librarian[bot]" Date: Fri, 10 Oct 2025 20:21:35 +0000 Subject: [PATCH] Merge commit '9d060d3e3c7c943a6609a95e11ff48c35b30edef' into develop --- .../gemm_pipeline_ag_bg_cr_comp_async.hpp | 112 ++++++++++-------- ...ine_ag_bg_cr_comp_async_default_policy.hpp | 82 +++++++++---- .../gemm/test_gemm_pipeline_kernel_types.hpp | 5 +- .../gemm/test_gemm_pipeline_ut_cases.inc | 5 + 4 files changed, 128 insertions(+), 76 deletions(-) diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp index 2c8d008127..fa7f9fc788 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp @@ -152,6 +152,9 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync{}; + static constexpr auto is_b_load_tr_v = bool_constant{}; + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return Policy::template GetSmemSize(); @@ -249,11 +252,6 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync; constexpr bool is_b_row_major = std::is_same_v; - // TODO currently only support A matrix row major, B matrix col major; if A matrix is - // col major or B is row major, need to combine with transpose load api - static_assert(!(is_a_col_major || is_b_row_major), - "only support A matrix is row major, B matrix is col major!"); - static_assert(is_a_col_major ? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]) @@ -293,18 +291,29 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync{}, number{}); + else + return make_tuple(number{}, number{}); + }(); + + constexpr auto b_lds_shape = []() { + if constexpr(is_b_load_tr_v) + return make_tuple(number{}, number{}); + else + return make_tuple(number{}, number{}); + }(); + // LDS tile windows for storing, one per LDS buffer - auto a_copy_lds_window0 = make_tile_window( - a_lds_block0, make_tuple(number{}, number{}), {0, 0}); + auto a_copy_lds_window0 = make_tile_window(a_lds_block0, a_lds_shape, {0, 0}); - auto a_copy_lds_window1 = make_tile_window( - a_lds_block1, make_tuple(number{}, number{}), {0, 0}); + auto a_copy_lds_window1 = make_tile_window(a_lds_block1, a_lds_shape, {0, 0}); - auto b_copy_lds_window0 = make_tile_window( - b_lds_block0, make_tuple(number{}, number{}), {0, 0}); + auto b_copy_lds_window0 = make_tile_window(b_lds_block0, b_lds_shape, {0, 0}); - auto b_copy_lds_window1 = make_tile_window( - b_lds_block1, make_tuple(number{}, number{}), {0, 0}); + auto b_copy_lds_window1 = make_tile_window(b_lds_block1, b_lds_shape, {0, 0}); // initialize DRAM window steps, used to advance the DRAM windows using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex; @@ -336,44 +345,49 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync{}], b_dram_tile_window_step); - constexpr auto ALdsTileDistr = decltype(make_static_tile_distribution( - BlockGemm::MakeABlockDistributionEncode())){}; - constexpr auto BLdsTileDistr = decltype(make_static_tile_distribution( - BlockGemm::MakeBBlockDistributionEncode())){}; + // tile distribution for the register tiles + constexpr auto ALdsTileDistr = + make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()); + constexpr auto BLdsTileDistr = + make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode()); using ALdsTile = decltype(make_static_distributed_tensor(ALdsTileDistr)); using BLdsTile = decltype(make_static_distributed_tensor(BLdsTileDistr)); // register tiles; double buffering -> a register tile corresponds to a LDS tile window - ALdsTile a_block_tile0; - ALdsTile a_block_tile1; + ALdsTile a_block_tile0, a_block_tile1; + BLdsTile b_block_tile0, b_block_tile1; - BLdsTile b_block_tile0; - BLdsTile b_block_tile1; + constexpr auto a_lds_input_tile_distr = [ALdsTileDistr]() { + if constexpr(is_a_load_tr_v) + return make_static_tile_distribution( + typename InputTileDistributionTraits< + typename decltype(ALdsTileDistr)::DstrEncode, + typename Problem::ADataType>::TransposedDstrEncode{}); + else + return ALdsTileDistr; + }(); + constexpr auto b_lds_input_tile_distr = [BLdsTileDistr]() { + if constexpr(is_b_load_tr_v) + return make_static_tile_distribution( + typename InputTileDistributionTraits< + typename decltype(BLdsTileDistr)::DstrEncode, + typename Problem::BDataType>::TransposedDstrEncode{}); + else + return BLdsTileDistr; + }(); // LDS tile windows for reading; // they share the data pointer with the LDS windows for storing // but also associate with a distribution to produce a register tile when reading auto a_lds_ld_window0 = - make_tile_window(a_lds_block0, - make_tuple(number{}, number{}), - {0, 0}, - ALdsTileDistr); + make_tile_window(a_lds_block0, a_lds_shape, {0, 0}, a_lds_input_tile_distr); auto a_lds_ld_window1 = - make_tile_window(a_lds_block1, - make_tuple(number{}, number{}), - {0, 0}, - ALdsTileDistr); + make_tile_window(a_lds_block1, a_lds_shape, {0, 0}, a_lds_input_tile_distr); auto b_lds_ld_window0 = - make_tile_window(b_lds_block0, - make_tuple(number{}, number{}), - {0, 0}, - BLdsTileDistr); + make_tile_window(b_lds_block0, b_lds_shape, {0, 0}, b_lds_input_tile_distr); auto b_lds_ld_window1 = - make_tile_window(b_lds_block1, - make_tuple(number{}, number{}), - {0, 0}, - BLdsTileDistr); + make_tile_window(b_lds_block1, b_lds_shape, {0, 0}, b_lds_input_tile_distr); static_assert(!(is_tile_window_linear_v) && !(is_tile_window_linear_v) && @@ -384,8 +398,8 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync(); + if constexpr(is_a_load_tr) + { + // TODO: better LDS descriptor for performance + // This branch is reusing the logic from + // UniversalGemmBasePolicy::MakeALdsBlockDescriptor + constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( // + make_tuple(number{}, number{}), + make_tuple(number{}, number<1>{}), + number{}, + number<1>{}); + return a_lds_block_desc_0; + } + else + { + constexpr index_t KPack = GetSmemPackA(); - constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(number{}, number{}, number{}), - make_tuple(number{}, number{}, number<1>{}), - number{}, - number<1>{}); + constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, number{}, number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); - return transform_tensor_descriptor( - a_lds_block_desc_0, - make_tuple( - make_pass_through_transform(number{}), - make_merge_transform(make_tuple(number{}, number{}))), - make_tuple(sequence<1>{}, sequence<0, 2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); + return transform_tensor_descriptor( + a_lds_block_desc_0, + make_tuple( + make_pass_through_transform(number{}), + make_merge_transform(make_tuple(number{}, number{}))), + make_tuple(sequence<1>{}, sequence<0, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } } template @@ -45,21 +60,36 @@ struct GemmPipelineAgBgCrCompAsyncDefaultPolicy { constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t KPack = GetSmemPackB(); + if constexpr(is_b_load_tr) + { + // TODO: better LDS descriptor for performance + // This branch is reusing the logic from + // UniversalGemmBasePolicy::MakeBLdsBlockDescriptor + constexpr auto b_lds_block_desc_0 = + make_naive_tensor_descriptor(make_tuple(number{}, number{}), + make_tuple(number{}, number<1>{}), + number{}, + number<1>{}); + return b_lds_block_desc_0; + } + else + { + constexpr index_t KPack = GetSmemPackB(); - constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(number{}, number{}, number{}), - make_tuple(number{}, number{}, number<1>{}), - number{}, - number<1>{}); + constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, number{}, number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); - return transform_tensor_descriptor( - b_lds_block_desc_0, - make_tuple( - make_pass_through_transform(number{}), - make_merge_transform(make_tuple(number{}, number{}))), - make_tuple(sequence<1>{}, sequence<0, 2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); + return transform_tensor_descriptor( + b_lds_block_desc_0, + make_tuple( + make_pass_through_transform(number{}), + make_merge_transform(make_tuple(number{}, number{}))), + make_tuple(sequence<1>{}, sequence<0, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } } template diff --git a/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp b/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp index 243a823653..bba106174c 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp @@ -131,7 +131,10 @@ using KernelTypesCompV4 = ::testing::Types< >; using KernelTypesCompAsync = ::testing::Types< - std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompAsync> + std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompAsync>, + std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompAsync>, + std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompAsync>, + std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompAsync> >; using KernelTypesCompV4Wmma = ::testing::Types< diff --git a/test/ck_tile/gemm/test_gemm_pipeline_ut_cases.inc b/test/ck_tile/gemm/test_gemm_pipeline_ut_cases.inc index 66ef05b0ba..ae91631a00 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_ut_cases.inc +++ b/test/ck_tile/gemm/test_gemm_pipeline_ut_cases.inc @@ -33,6 +33,11 @@ TYPED_TEST(TEST_SUITE_NAME, SmallM) } } +TYPED_TEST(TEST_SUITE_NAME, SingleTile) +{ + this->Run(TestFixture::M_Tile, TestFixture::N_Tile, TestFixture::K_Tile); +} + TYPED_TEST(TEST_SUITE_NAME, MidLargeM) { std::vector Ms{127, 255, 312, 799, 1573};