From d2bbca3eca2bd14014e3daae39ae70846ec8218b Mon Sep 17 00:00:00 2001 From: Sami Remes Date: Mon, 13 Oct 2025 13:27:02 +0100 Subject: [PATCH] [CK_TILE] Non-K Major from old CK to CK-Tile (#2442) * Enable the adapted LDS B layout for Row-Major * fix formatting * Implement specialized col-major A LDS block descriptor * Fix formatting * Use VecLoadSize for AK1/BK1 * Fix some thread access pattern values * Use GetVectorSizeA for A * Fix formatting * Add extra condition to avoid division by zero * disable layout for wave32 * remove extra else * fix formatting * Fix formatting * Rename one remaining TileDistributionEncodingPattern2D * Use integer ceil division * revert remod.py changes * also revert utility.hpp * use getA/BTileAccessPattern everywhere * use integer_divide_ceil for AK0 too --------- Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> Co-authored-by: Adam Osewski --- ...emm_universal_pipeline_ag_bg_cr_policy.hpp | 520 +++++++++++------- 1 file changed, 318 insertions(+), 202 deletions(-) diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp index 4030783ecc..89e0346961 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp @@ -73,10 +73,14 @@ struct UniversalGemmBasePolicy template CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() { + using ALayout = remove_cvref_t; + using ADataType = remove_cvref_t; using ADataType = remove_cvref_t; constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t KPack = GetSmemPackA(); + constexpr auto DataTypeSize = sizeof(ADataType); if constexpr(is_a_load_tr) { @@ -90,47 +94,168 @@ struct UniversalGemmBasePolicy } else { - constexpr index_t KPack = GetSmemPackA(); + // Only use this ColumnMajor layout for Wave64 mode (gfx9) + constexpr auto Wave64 = get_warp_size() == 64; + if constexpr(Wave64 && + std::is_same_v) + { + // kfold and mpair dimension is not always required. + // more dimension in merge_transform increase the difficulty of generating immarg + // offset for compiler. + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t VecLoadSize = GetVectorSizeA(); + using TileEncodingPattern = + tile_distribution_encoding_pattern_2d; + // AK1 + constexpr auto AK1 = number{}; + constexpr auto AK0 = number{}; + // How the M dimension is split across threads + constexpr auto M0 = TileEncodingPattern::X0; // # of threads in M dim + constexpr auto M1 = number{}; - constexpr auto DataTypeSize = sizeof(ADataType); - constexpr auto MLdsLayer = - (32 * 4 / KPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / KPerBlock / DataTypeSize); + // Get the warp tile size + using WarpTile = typename Problem::BlockGemmShape::WarpTile; + constexpr auto MPerXdl = number{}; - constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(number{}, - number{}, - number{}), - make_tuple(number{}, number{}, number<1>{}), - number{}, - number<1>{}); + // How many elements we can write by single thread to LDS, + // the transposed / shuffled tile dstr has size: + constexpr auto KThreadWrite = TileEncodingPattern::Y2; + constexpr auto K0PerThreadWrite = integer_divide_ceil(AK0, KThreadWrite); + constexpr auto KThreadRead = get_warp_size() / MPerXdl; + constexpr auto K0PerThreadRead = integer_divide_ceil(AK0, KThreadRead); - constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( - a_lds_block_desc_0, - make_tuple(make_xor_transform(make_tuple(number{}, - number{})), - make_pass_through_transform(number{})), - make_tuple(sequence<1, 0>{}, sequence<2>{}), - make_tuple(sequence<1, 0>{}, sequence<2>{})); + constexpr auto LdsBanksWidth = 128; + constexpr auto kfold = (AK1 * M0 * sizeof(ADataType) > LdsBanksWidth) + ? 1 + : LdsBanksWidth / (AK1 * M0 * sizeof(ADataType)); + constexpr auto KThreadReadPerm = + ((kfold * K0PerThreadWrite / K0PerThreadRead) > 1 && + (kfold * K0PerThreadWrite / K0PerThreadRead) < KThreadRead) + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; - constexpr auto a_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor( - a_lds_block_desc_permuted, - make_tuple(make_unmerge_transform( - make_tuple(number{}, number{})), - make_pass_through_transform(number{}), - make_pass_through_transform(number{})), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), - make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); + // 1<=mpair<=n0 + constexpr auto mpair = + (AK1 * MPerXdl * sizeof(ADataType) > LdsBanksWidth) + ? 1 + : ((LdsBanksWidth / (AK1 * MPerXdl * sizeof(ADataType))) > M0 + ? M0 + : LdsBanksWidth / (AK1 * MPerXdl * sizeof(ADataType))); - constexpr auto a_lds_block_desc = transform_tensor_descriptor( - a_lds_block_desc_xk0_mnldslayer_mn_xk1, - make_tuple(make_merge_transform_v3_division_mod( - make_tuple(number{}, number{})), - make_merge_transform_v3_division_mod( - make_tuple(number{}, number{}))), - make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), - make_tuple(sequence<0>{}, sequence<1>{})); + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(number{}, + number{}, + number{}, + number{}, + number{}, + AK1), + AK1); - return a_lds_block_desc; + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple(make_pass_through_transform( + number{}), + make_pass_through_transform(number{}), + make_xor_transform(make_tuple(number{}, + number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(AK1)), + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<2, 3>{}, + sequence<4>{}, + sequence<5>{}), + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<2, 3>{}, + sequence<4>{}, + sequence<5>{})); + + constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform( + number{}), + make_pass_through_transform(number{}), + make_unmerge_transform(make_tuple(number{}, number{})), + make_unmerge_transform(make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(AK1)), + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<2>{}, + sequence<3>{}, + sequence<4>{}, + sequence<5>{}), + make_tuple(sequence<1>{}, + sequence<2>{}, + sequence<0, 3>{}, + sequence<4, 5>{}, + sequence<6>{}, + sequence<7>{})); + + constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(number{}, + number{}, + number{}, + number{}, + AK1)), + make_merge_transform_v3_division_mod(make_tuple( + number{}, number{}, number{}))), + make_tuple(sequence<0, 1, 4, 2, 7>{}, sequence<5, 6, 3>{}), + make_tuple(sequence<1>{}, sequence<0>{})); + + return a_lds_block_desc_ak0_m_ak1; + } + else // A is in RowMajor + { + constexpr auto MLdsLayer = (32 * 4 / KPerBlock / DataTypeSize) < 1 + ? 1 + : (32 * 4 / KPerBlock / DataTypeSize); + + constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, + number{}, + number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc_0, + make_tuple( + make_xor_transform(make_tuple(number{}, + number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<1, 0>{}, sequence<2>{}), + make_tuple(sequence<1, 0>{}, sequence<2>{})); + + constexpr auto a_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple(make_unmerge_transform( + make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); + + constexpr auto a_lds_block_desc = transform_tensor_descriptor( + a_lds_block_desc_xk0_mnldslayer_mn_xk1, + make_tuple(make_merge_transform_v3_division_mod(make_tuple( + number{}, number{})), + make_merge_transform_v3_division_mod( + make_tuple(number{}, number{}))), + make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return a_lds_block_desc; + } } } @@ -143,12 +268,12 @@ struct UniversalGemmBasePolicy template CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor() { + using BLayout = remove_cvref_t; using BDataType = remove_cvref_t; constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; -#if 1 if constexpr(is_b_load_tr) { // TODO: better lds descriptor for performance @@ -160,178 +285,169 @@ struct UniversalGemmBasePolicy return b_lds_block_desc_0; } else - // else if constexpr(std::is_same_v) { - constexpr index_t KPack = GetSmemPackB(); - constexpr auto BK0 = number{}; - constexpr auto DataTypeSize = sizeof(BDataType); - constexpr auto NLdsLayer = - (32 * 4 / KPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / KPerBlock / DataTypeSize); + // Only use this RowMajor layout for Wave64 mode (gfx9) + constexpr auto Wave64 = get_warp_size() == 64; + if constexpr(Wave64 && std::is_same_v) + { + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t VecLoadSize = GetVectorSizeB(); + using TileEncodingPattern = + tile_distribution_encoding_pattern_2d; + // BK1 + constexpr auto BK1 = number{}; + constexpr auto BK0 = number{}; - constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple( - BK0 * number{}, number{}, number{}), - make_tuple(number{}, number{}, number<1>{}), - number{}, - number<1>{}); + // How threads access data on N dim + constexpr auto N0 = TileEncodingPattern::X0; // # of threads in N dim + constexpr auto N1 = number{}; - constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( - b_lds_block_desc_0, - make_tuple(make_xor_transform(make_tuple(number{}, - BK0 * number{})), - make_pass_through_transform(number{})), - make_tuple(sequence<1, 0>{}, sequence<2>{}), - make_tuple(sequence<1, 0>{}, sequence<2>{})); + // Get NPerXdl, the warp tile size + using WarpTile = typename Problem::BlockGemmShape::WarpTile; + constexpr auto NPerXdl = number{}; - constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor( - b_lds_block_desc_permuted, - make_tuple(make_unmerge_transform(make_tuple(number{}, BK0)), - make_pass_through_transform(number{}), - make_pass_through_transform(number{})), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), - make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); + // How many elements we can write by single thread to LDS, + // the transposed / shuffled tile dstr has size: + constexpr auto KThreadWrite = TileEncodingPattern::Y2; + constexpr auto K0PerThreadWrite = integer_divide_ceil(BK0, KThreadWrite); + constexpr auto KThreadRead = get_warp_size() / NPerXdl; + constexpr auto K0PerThreadRead = integer_divide_ceil(BK0, KThreadRead); - constexpr auto b_lds_block_desc = transform_tensor_descriptor( - b_lds_block_desc_bk0_nldslayer_n_bk1, - make_tuple(make_merge_transform_v3_division_mod( - make_tuple(number{}, number{})), - make_merge_transform_v3_division_mod(make_tuple(BK0, number{}))), - make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - return b_lds_block_desc; + // check if we exceed all 32banks width - (32x4B) + constexpr auto LdsBanksWidth = 128; + constexpr auto kfold = (BK1 * N0 * sizeof(BDataType) > LdsBanksWidth) + ? 1 + : LdsBanksWidth / (BK1 * N0 * sizeof(BDataType)); + constexpr auto KThreadReadPerm = + ((kfold * K0PerThreadWrite / K0PerThreadRead) > 1 && + (kfold * K0PerThreadWrite / K0PerThreadRead) < KThreadRead) + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=npair<=n0 + constexpr auto npair = + (BK1 * NPerXdl * sizeof(BDataType) > LdsBanksWidth) + ? 1 + : ((LdsBanksWidth / (BK1 * NPerXdl * sizeof(BDataType))) > N0 + ? N0 + : LdsBanksWidth / (BK1 * NPerXdl * sizeof(BDataType))); + + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(number{}, + number{}, + number{}, + number{}, + number{}, + BK1), + BK1); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple(make_pass_through_transform( + number{}), + make_pass_through_transform(number{}), + make_xor_transform(make_tuple(number{}, + number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(BK1)), + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<2, 3>{}, + sequence<4>{}, + sequence<5>{}), + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<2, 3>{}, + sequence<4>{}, + sequence<5>{})); + + constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform( + number{}), + make_pass_through_transform(number{}), + make_unmerge_transform(make_tuple(number{}, number{})), + make_unmerge_transform(make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(BK1)), + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<2>{}, + sequence<3>{}, + sequence<4>{}, + sequence<5>{}), + make_tuple( + sequence<1>{}, // 0: K0PerThreadWrite + sequence<2>{}, // 1: KThreadReadPerm + sequence<0, 3>{}, // 2: KThreadWrite / kfold / KThreadReadPerm, 3: N1 + sequence<4, 5>{}, // 4: kfold, 5: N0 / npair + sequence<6>{}, // 6: npair + sequence<7>{})); // 7: BK1 + + constexpr auto b_lds_block_desc_nk = transform_tensor_descriptor( + b_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(number{}, + number{}, + number{}, + number{}, + BK1)), + make_merge_transform_v3_division_mod(make_tuple( + number{}, number{}, number{}))), + make_tuple(sequence<0, 1, 4, 2, 7>{}, sequence<5, 6, 3>{}), + make_tuple(sequence<1>{}, sequence<0>{})); + + return b_lds_block_desc_nk; + } + else // B is Column Major + { + constexpr index_t KPack = GetSmemPackB(); + constexpr auto BK0 = number{}; + constexpr auto DataTypeSize = sizeof(BDataType); + constexpr auto NLdsLayer = (32 * 4 / KPerBlock / DataTypeSize) < 1 + ? 1 + : (32 * 4 / KPerBlock / DataTypeSize); + + constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(BK0 * number{}, + number{}, + number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc_0, + make_tuple(make_xor_transform(make_tuple(number{}, + BK0 * number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<1, 0>{}, sequence<2>{}), + make_tuple(sequence<1, 0>{}, sequence<2>{})); + + constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple(make_unmerge_transform(make_tuple(number{}, BK0)), + make_pass_through_transform(number{}), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); + + constexpr auto b_lds_block_desc = transform_tensor_descriptor( + b_lds_block_desc_bk0_nldslayer_n_bk1, + make_tuple( + make_merge_transform_v3_division_mod( + make_tuple(number{}, number{})), + make_merge_transform_v3_division_mod(make_tuple(BK0, number{}))), + make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + return b_lds_block_desc; + } } -#else - else // B is Row Major - { - constexpr index_t BlockSize = Problem::kBlockSize; - constexpr index_t VecLoadSize = GetVectorSizeB(); - using TileEncodingPattern = - tile_distribution_encoding_pattern_2d; - - constexpr auto BK0 = number{}; - constexpr auto BK1 = number{}; - // constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1); - constexpr auto N0 = TileEncodingPattern::X0; - constexpr auto N1 = NPerBlock / N0; - - using WarpTile = typename Problem::BlockGemmShape::WarpTile; - constexpr auto NPerXdl = number{}; - - // constexpr auto KThreadWrite = - // BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0); - constexpr auto KThreadWrite = TileEncodingPattern::Y2; - constexpr auto K0PerThreadWrite = BK0 / KThreadWrite; - constexpr auto KThreadRead = 64 / NPerXdl; - constexpr auto K0PerThreadRead = BK0 / KThreadRead; - - constexpr auto kfold = - (BK1 * N0 * sizeof(BDataType) > 128) ? 1 : 128 / (BK1 * N0 * sizeof(BDataType)); - constexpr auto KThreadReadPerm = - (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 - ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) - : KThreadRead; - - // 1<=npair<=n0 - constexpr auto npair = (BK1 * NPerXdl * sizeof(BDataType) > 128) - ? 1 - : ((128 / (BK1 * NPerXdl * sizeof(BDataType))) > N0 - ? N0 - : 128 / (BK1 * NPerXdl * sizeof(BDataType))); - - constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed( - make_tuple(number{}, - number{}, - number{}, - number{}, - number{}, - BK1)); - - constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( - b_lds_block_desc, - make_tuple( - make_pass_through_transform(number{}), - make_pass_through_transform(number{}), - make_xor_transform( - make_tuple(number{}, number{})), - make_pass_through_transform(number{}), - make_pass_through_transform(BK1)), - make_tuple( - sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{}), - make_tuple( - sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{})); - - constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor( - b_lds_block_desc_permuted, - make_tuple( - make_pass_through_transform(number{}), - make_pass_through_transform(number{}), - make_unmerge_transform(make_tuple(number{}, number{})), - make_unmerge_transform(make_tuple(number{}, number{})), - make_pass_through_transform(number{}), - make_pass_through_transform(BK1)), - make_tuple(sequence<0>{}, - sequence<1>{}, - sequence<2>{}, - sequence<3>{}, - sequence<4>{}, - sequence<5>{}), - make_tuple(sequence<1>{}, - sequence<2>{}, - sequence<0, 3>{}, - sequence<4, 5>{}, - sequence<6>{}, - sequence<7>{})); - - // constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( - // b_lds_block_desc_unmerged, - // make_tuple(make_merge_transform_v3_division_mod( - // make_tuple(number{}, - // number{}, - // number{}, - // number{})), - // make_merge_transform_v3_division_mod( - // make_tuple(number{}, number{}, number{})), - // make_pass_through_transform(BK1)), - // make_tuple(sequence<0, 1, 4, 2>{}, sequence<5, 6, 3>{}, sequence<7>{}), - // make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); - - constexpr auto b_lds_block_desc_kn = transform_tensor_descriptor( - b_lds_block_desc_unmerged, - make_tuple(make_merge_transform_v3_division_mod( - make_tuple(number{}, - number{}, - number{}, - number{}, - BK1)), - make_merge_transform_v3_division_mod( - make_tuple(number{}, number{}, number{}))), - make_tuple(sequence<0, 1, 4, 2, 7>{}, sequence<5, 6, 3>{}), - make_tuple(sequence<1>{}, sequence<0>{})); - - // return b_lds_block_desc_bk0_n_bk1; - return b_lds_block_desc_kn; - - // constexpr auto b_lds_block_desc_bk0_n_bk1 = make_naive_tensor_descriptor( - // make_tuple(BK0, number{}, number{}), - // make_tuple(number{}, number{}, number<1>{}), - // number{}, - // number<1>{}); - - // constexpr auto b_lds_block_desc = transform_tensor_descriptor( - // b_lds_block_desc_bk0_n_bk1, - // make_tuple(make_pass_through_transform(number{}), - // make_merge_transform_v3_division_mod(make_tuple(BK0, - // number{}))), - // make_tuple(sequence<1>{}, sequence<0, 2>{}), - // make_tuple(sequence<0>{}, sequence<1>{})); - - // return b_lds_block_desc; - } -#endif } /**