From 3a9bd7c1ff5a5fa2d09cce670a2f24490575acc7 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Tue, 14 Oct 2025 08:43:14 -0700 Subject: [PATCH] Revert "[CK_TILE] Non-K Major from old CK to CK-Tile (#2442)" (#3017) This reverts commit 3653a0d01edd715f5c9759eaf547a7644c762b8e. [ROCm/composable_kernel commit: e4298e55c75ddb3931aec2053d41a0099e3a4549] --- ...emm_universal_pipeline_ag_bg_cr_policy.hpp | 520 +++++++----------- 1 file changed, 202 insertions(+), 318 deletions(-) diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp index 89e0346961..4030783ecc 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp @@ -73,14 +73,10 @@ struct UniversalGemmBasePolicy template CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() { - using ALayout = remove_cvref_t; - using ADataType = remove_cvref_t; using ADataType = remove_cvref_t; constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t KPack = GetSmemPackA(); - constexpr auto DataTypeSize = sizeof(ADataType); if constexpr(is_a_load_tr) { @@ -94,168 +90,47 @@ struct UniversalGemmBasePolicy } else { - // Only use this ColumnMajor layout for Wave64 mode (gfx9) - constexpr auto Wave64 = get_warp_size() == 64; - if constexpr(Wave64 && - std::is_same_v) - { - // kfold and mpair dimension is not always required. - // more dimension in merge_transform increase the difficulty of generating immarg - // offset for compiler. - constexpr index_t BlockSize = Problem::kBlockSize; - constexpr index_t VecLoadSize = GetVectorSizeA(); - using TileEncodingPattern = - tile_distribution_encoding_pattern_2d; - // AK1 - constexpr auto AK1 = number{}; - constexpr auto AK0 = number{}; - // How the M dimension is split across threads - constexpr auto M0 = TileEncodingPattern::X0; // # of threads in M dim - constexpr auto M1 = number{}; + constexpr index_t KPack = GetSmemPackA(); - // Get the warp tile size - using WarpTile = typename Problem::BlockGemmShape::WarpTile; - constexpr auto MPerXdl = number{}; + constexpr auto DataTypeSize = sizeof(ADataType); + constexpr auto MLdsLayer = + (32 * 4 / KPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / KPerBlock / DataTypeSize); - // How many elements we can write by single thread to LDS, - // the transposed / shuffled tile dstr has size: - constexpr auto KThreadWrite = TileEncodingPattern::Y2; - constexpr auto K0PerThreadWrite = integer_divide_ceil(AK0, KThreadWrite); - constexpr auto KThreadRead = get_warp_size() / MPerXdl; - constexpr auto K0PerThreadRead = integer_divide_ceil(AK0, KThreadRead); + constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, + number{}, + number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); - constexpr auto LdsBanksWidth = 128; - constexpr auto kfold = (AK1 * M0 * sizeof(ADataType) > LdsBanksWidth) - ? 1 - : LdsBanksWidth / (AK1 * M0 * sizeof(ADataType)); - constexpr auto KThreadReadPerm = - ((kfold * K0PerThreadWrite / K0PerThreadRead) > 1 && - (kfold * K0PerThreadWrite / K0PerThreadRead) < KThreadRead) - ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) - : KThreadRead; + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc_0, + make_tuple(make_xor_transform(make_tuple(number{}, + number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<1, 0>{}, sequence<2>{}), + make_tuple(sequence<1, 0>{}, sequence<2>{})); - // 1<=mpair<=n0 - constexpr auto mpair = - (AK1 * MPerXdl * sizeof(ADataType) > LdsBanksWidth) - ? 1 - : ((LdsBanksWidth / (AK1 * MPerXdl * sizeof(ADataType))) > M0 - ? M0 - : LdsBanksWidth / (AK1 * MPerXdl * sizeof(ADataType))); + constexpr auto a_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple(make_unmerge_transform( + make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); - constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed( - make_tuple(number{}, - number{}, - number{}, - number{}, - number{}, - AK1), - AK1); + constexpr auto a_lds_block_desc = transform_tensor_descriptor( + a_lds_block_desc_xk0_mnldslayer_mn_xk1, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(number{}, number{})), + make_merge_transform_v3_division_mod( + make_tuple(number{}, number{}))), + make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); - constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( - a_lds_block_desc, - make_tuple(make_pass_through_transform( - number{}), - make_pass_through_transform(number{}), - make_xor_transform(make_tuple(number{}, - number{})), - make_pass_through_transform(number{}), - make_pass_through_transform(AK1)), - make_tuple(sequence<0>{}, - sequence<1>{}, - sequence<2, 3>{}, - sequence<4>{}, - sequence<5>{}), - make_tuple(sequence<0>{}, - sequence<1>{}, - sequence<2, 3>{}, - sequence<4>{}, - sequence<5>{})); - - constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor( - a_lds_block_desc_permuted, - make_tuple( - make_pass_through_transform( - number{}), - make_pass_through_transform(number{}), - make_unmerge_transform(make_tuple(number{}, number{})), - make_unmerge_transform(make_tuple(number{}, number{})), - make_pass_through_transform(number{}), - make_pass_through_transform(AK1)), - make_tuple(sequence<0>{}, - sequence<1>{}, - sequence<2>{}, - sequence<3>{}, - sequence<4>{}, - sequence<5>{}), - make_tuple(sequence<1>{}, - sequence<2>{}, - sequence<0, 3>{}, - sequence<4, 5>{}, - sequence<6>{}, - sequence<7>{})); - - constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( - a_lds_block_desc_unmerged, - make_tuple(make_merge_transform_v3_division_mod( - make_tuple(number{}, - number{}, - number{}, - number{}, - AK1)), - make_merge_transform_v3_division_mod(make_tuple( - number{}, number{}, number{}))), - make_tuple(sequence<0, 1, 4, 2, 7>{}, sequence<5, 6, 3>{}), - make_tuple(sequence<1>{}, sequence<0>{})); - - return a_lds_block_desc_ak0_m_ak1; - } - else // A is in RowMajor - { - constexpr auto MLdsLayer = (32 * 4 / KPerBlock / DataTypeSize) < 1 - ? 1 - : (32 * 4 / KPerBlock / DataTypeSize); - - constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(number{}, - number{}, - number{}), - make_tuple(number{}, number{}, number<1>{}), - number{}, - number<1>{}); - - constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( - a_lds_block_desc_0, - make_tuple( - make_xor_transform(make_tuple(number{}, - number{})), - make_pass_through_transform(number{})), - make_tuple(sequence<1, 0>{}, sequence<2>{}), - make_tuple(sequence<1, 0>{}, sequence<2>{})); - - constexpr auto a_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor( - a_lds_block_desc_permuted, - make_tuple(make_unmerge_transform( - make_tuple(number{}, number{})), - make_pass_through_transform(number{}), - make_pass_through_transform(number{})), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), - make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); - - constexpr auto a_lds_block_desc = transform_tensor_descriptor( - a_lds_block_desc_xk0_mnldslayer_mn_xk1, - make_tuple(make_merge_transform_v3_division_mod(make_tuple( - number{}, number{})), - make_merge_transform_v3_division_mod( - make_tuple(number{}, number{}))), - make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - return a_lds_block_desc; - } + return a_lds_block_desc; } } @@ -268,12 +143,12 @@ struct UniversalGemmBasePolicy template CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor() { - using BLayout = remove_cvref_t; using BDataType = remove_cvref_t; constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; +#if 1 if constexpr(is_b_load_tr) { // TODO: better lds descriptor for performance @@ -285,169 +160,178 @@ struct UniversalGemmBasePolicy return b_lds_block_desc_0; } else + // else if constexpr(std::is_same_v) { - // Only use this RowMajor layout for Wave64 mode (gfx9) - constexpr auto Wave64 = get_warp_size() == 64; - if constexpr(Wave64 && std::is_same_v) - { - constexpr index_t BlockSize = Problem::kBlockSize; - constexpr index_t VecLoadSize = GetVectorSizeB(); - using TileEncodingPattern = - tile_distribution_encoding_pattern_2d; - // BK1 - constexpr auto BK1 = number{}; - constexpr auto BK0 = number{}; + constexpr index_t KPack = GetSmemPackB(); + constexpr auto BK0 = number{}; + constexpr auto DataTypeSize = sizeof(BDataType); + constexpr auto NLdsLayer = + (32 * 4 / KPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / KPerBlock / DataTypeSize); - // How threads access data on N dim - constexpr auto N0 = TileEncodingPattern::X0; // # of threads in N dim - constexpr auto N1 = number{}; + constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple( + BK0 * number{}, number{}, number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); - // Get NPerXdl, the warp tile size - using WarpTile = typename Problem::BlockGemmShape::WarpTile; - constexpr auto NPerXdl = number{}; + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc_0, + make_tuple(make_xor_transform(make_tuple(number{}, + BK0 * number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<1, 0>{}, sequence<2>{}), + make_tuple(sequence<1, 0>{}, sequence<2>{})); - // How many elements we can write by single thread to LDS, - // the transposed / shuffled tile dstr has size: - constexpr auto KThreadWrite = TileEncodingPattern::Y2; - constexpr auto K0PerThreadWrite = integer_divide_ceil(BK0, KThreadWrite); - constexpr auto KThreadRead = get_warp_size() / NPerXdl; - constexpr auto K0PerThreadRead = integer_divide_ceil(BK0, KThreadRead); + constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple(make_unmerge_transform(make_tuple(number{}, BK0)), + make_pass_through_transform(number{}), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); - // check if we exceed all 32banks width - (32x4B) - constexpr auto LdsBanksWidth = 128; - constexpr auto kfold = (BK1 * N0 * sizeof(BDataType) > LdsBanksWidth) - ? 1 - : LdsBanksWidth / (BK1 * N0 * sizeof(BDataType)); - constexpr auto KThreadReadPerm = - ((kfold * K0PerThreadWrite / K0PerThreadRead) > 1 && - (kfold * K0PerThreadWrite / K0PerThreadRead) < KThreadRead) - ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) - : KThreadRead; - - // 1<=npair<=n0 - constexpr auto npair = - (BK1 * NPerXdl * sizeof(BDataType) > LdsBanksWidth) - ? 1 - : ((LdsBanksWidth / (BK1 * NPerXdl * sizeof(BDataType))) > N0 - ? N0 - : LdsBanksWidth / (BK1 * NPerXdl * sizeof(BDataType))); - - constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed( - make_tuple(number{}, - number{}, - number{}, - number{}, - number{}, - BK1), - BK1); - - constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( - b_lds_block_desc, - make_tuple(make_pass_through_transform( - number{}), - make_pass_through_transform(number{}), - make_xor_transform(make_tuple(number{}, - number{})), - make_pass_through_transform(number{}), - make_pass_through_transform(BK1)), - make_tuple(sequence<0>{}, - sequence<1>{}, - sequence<2, 3>{}, - sequence<4>{}, - sequence<5>{}), - make_tuple(sequence<0>{}, - sequence<1>{}, - sequence<2, 3>{}, - sequence<4>{}, - sequence<5>{})); - - constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor( - b_lds_block_desc_permuted, - make_tuple( - make_pass_through_transform( - number{}), - make_pass_through_transform(number{}), - make_unmerge_transform(make_tuple(number{}, number{})), - make_unmerge_transform(make_tuple(number{}, number{})), - make_pass_through_transform(number{}), - make_pass_through_transform(BK1)), - make_tuple(sequence<0>{}, - sequence<1>{}, - sequence<2>{}, - sequence<3>{}, - sequence<4>{}, - sequence<5>{}), - make_tuple( - sequence<1>{}, // 0: K0PerThreadWrite - sequence<2>{}, // 1: KThreadReadPerm - sequence<0, 3>{}, // 2: KThreadWrite / kfold / KThreadReadPerm, 3: N1 - sequence<4, 5>{}, // 4: kfold, 5: N0 / npair - sequence<6>{}, // 6: npair - sequence<7>{})); // 7: BK1 - - constexpr auto b_lds_block_desc_nk = transform_tensor_descriptor( - b_lds_block_desc_unmerged, - make_tuple(make_merge_transform_v3_division_mod( - make_tuple(number{}, - number{}, - number{}, - number{}, - BK1)), - make_merge_transform_v3_division_mod(make_tuple( - number{}, number{}, number{}))), - make_tuple(sequence<0, 1, 4, 2, 7>{}, sequence<5, 6, 3>{}), - make_tuple(sequence<1>{}, sequence<0>{})); - - return b_lds_block_desc_nk; - } - else // B is Column Major - { - constexpr index_t KPack = GetSmemPackB(); - constexpr auto BK0 = number{}; - constexpr auto DataTypeSize = sizeof(BDataType); - constexpr auto NLdsLayer = (32 * 4 / KPerBlock / DataTypeSize) < 1 - ? 1 - : (32 * 4 / KPerBlock / DataTypeSize); - - constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(BK0 * number{}, - number{}, - number{}), - make_tuple(number{}, number{}, number<1>{}), - number{}, - number<1>{}); - - constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( - b_lds_block_desc_0, - make_tuple(make_xor_transform(make_tuple(number{}, - BK0 * number{})), - make_pass_through_transform(number{})), - make_tuple(sequence<1, 0>{}, sequence<2>{}), - make_tuple(sequence<1, 0>{}, sequence<2>{})); - - constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor( - b_lds_block_desc_permuted, - make_tuple(make_unmerge_transform(make_tuple(number{}, BK0)), - make_pass_through_transform(number{}), - make_pass_through_transform(number{})), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), - make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); - - constexpr auto b_lds_block_desc = transform_tensor_descriptor( - b_lds_block_desc_bk0_nldslayer_n_bk1, - make_tuple( - make_merge_transform_v3_division_mod( - make_tuple(number{}, number{})), - make_merge_transform_v3_division_mod(make_tuple(BK0, number{}))), - make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - return b_lds_block_desc; - } + constexpr auto b_lds_block_desc = transform_tensor_descriptor( + b_lds_block_desc_bk0_nldslayer_n_bk1, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(number{}, number{})), + make_merge_transform_v3_division_mod(make_tuple(BK0, number{}))), + make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + return b_lds_block_desc; } +#else + else // B is Row Major + { + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t VecLoadSize = GetVectorSizeB(); + using TileEncodingPattern = + tile_distribution_encoding_pattern_2d; + + constexpr auto BK0 = number{}; + constexpr auto BK1 = number{}; + // constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1); + constexpr auto N0 = TileEncodingPattern::X0; + constexpr auto N1 = NPerBlock / N0; + + using WarpTile = typename Problem::BlockGemmShape::WarpTile; + constexpr auto NPerXdl = number{}; + + // constexpr auto KThreadWrite = + // BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0); + constexpr auto KThreadWrite = TileEncodingPattern::Y2; + constexpr auto K0PerThreadWrite = BK0 / KThreadWrite; + constexpr auto KThreadRead = 64 / NPerXdl; + constexpr auto K0PerThreadRead = BK0 / KThreadRead; + + constexpr auto kfold = + (BK1 * N0 * sizeof(BDataType) > 128) ? 1 : 128 / (BK1 * N0 * sizeof(BDataType)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=npair<=n0 + constexpr auto npair = (BK1 * NPerXdl * sizeof(BDataType) > 128) + ? 1 + : ((128 / (BK1 * NPerXdl * sizeof(BDataType))) > N0 + ? N0 + : 128 / (BK1 * NPerXdl * sizeof(BDataType))); + + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(number{}, + number{}, + number{}, + number{}, + number{}, + BK1)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple( + make_pass_through_transform(number{}), + make_pass_through_transform(number{}), + make_xor_transform( + make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(BK1)), + make_tuple( + sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{}), + make_tuple( + sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{})); + + constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(number{}), + make_pass_through_transform(number{}), + make_unmerge_transform(make_tuple(number{}, number{})), + make_unmerge_transform(make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(BK1)), + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<2>{}, + sequence<3>{}, + sequence<4>{}, + sequence<5>{}), + make_tuple(sequence<1>{}, + sequence<2>{}, + sequence<0, 3>{}, + sequence<4, 5>{}, + sequence<6>{}, + sequence<7>{})); + + // constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( + // b_lds_block_desc_unmerged, + // make_tuple(make_merge_transform_v3_division_mod( + // make_tuple(number{}, + // number{}, + // number{}, + // number{})), + // make_merge_transform_v3_division_mod( + // make_tuple(number{}, number{}, number{})), + // make_pass_through_transform(BK1)), + // make_tuple(sequence<0, 1, 4, 2>{}, sequence<5, 6, 3>{}, sequence<7>{}), + // make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); + + constexpr auto b_lds_block_desc_kn = transform_tensor_descriptor( + b_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(number{}, + number{}, + number{}, + number{}, + BK1)), + make_merge_transform_v3_division_mod( + make_tuple(number{}, number{}, number{}))), + make_tuple(sequence<0, 1, 4, 2, 7>{}, sequence<5, 6, 3>{}), + make_tuple(sequence<1>{}, sequence<0>{})); + + // return b_lds_block_desc_bk0_n_bk1; + return b_lds_block_desc_kn; + + // constexpr auto b_lds_block_desc_bk0_n_bk1 = make_naive_tensor_descriptor( + // make_tuple(BK0, number{}, number{}), + // make_tuple(number{}, number{}, number<1>{}), + // number{}, + // number<1>{}); + + // constexpr auto b_lds_block_desc = transform_tensor_descriptor( + // b_lds_block_desc_bk0_n_bk1, + // make_tuple(make_pass_through_transform(number{}), + // make_merge_transform_v3_division_mod(make_tuple(BK0, + // number{}))), + // make_tuple(sequence<1>{}, sequence<0, 2>{}), + // make_tuple(sequence<0>{}, sequence<1>{})); + + // return b_lds_block_desc; + } +#endif } /**