From d10e451e7eede877a36c16b999555758faf39bee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Wed, 13 Nov 2024 11:46:18 +0100 Subject: [PATCH] [CK TILE] Update gemm universal pipeline (#1644) * [CK TILE] Update gemm universal pipeline * Fixes * fix * Rebase [ROCm/composable_kernel commit: d20735691ccb9429ed66f42f831385c709707d62] --- ...emm_universal_pipeline_ag_bg_cr_policy.hpp | 377 +++++------------- 1 file changed, 105 insertions(+), 272 deletions(-) diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp index 207f1f9e4b..94b0faf039 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp @@ -18,289 +18,136 @@ struct UniversalGemmPipelineAgBgCrPolicy static constexpr bool TransposeC = true; + template + CK_TILE_HOST_DEVICE static constexpr auto GetVectorLoadSize() + { + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t elements_per_thread = MNPerBlock * KPerBlock / BlockSize; + + if constexpr(elements_per_thread % (16 / sizeof(DataType)) == 0) + { + return (16 / sizeof(DataType)); + } + else if constexpr(elements_per_thread % (8 / sizeof(DataType)) == 0) + { + return (8 / sizeof(DataType)); + } + else if constexpr(elements_per_thread % (4 / sizeof(DataType)) == 0 && + sizeof(DataType) >= 4) + { + return (4 / sizeof(DataType)); + } + else if constexpr(elements_per_thread % (2 / sizeof(DataType)) == 0 && + sizeof(DataType) >= 2) + { + return (2 / sizeof(DataType)); + } + else + { + return 1; + } + } + template CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() { - using WarpGemm = WarpGemmMfmaDispatcher; using ADataType = remove_cvref_t; - using ALayout = remove_cvref_t; constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t K1 = WarpGemm::kK; - constexpr index_t K0 = KPerBlock / K1; + constexpr index_t KPack = GetVectorLoadSize(); - if constexpr(std::is_same::value) - { - constexpr auto MLdsLayer = 32 * 4 / KPerBlock / sizeof(ADataType) < 1 - ? 1 - : 32 * 4 / KPerBlock / sizeof(ADataType); - constexpr auto a_lds_block_desc = make_naive_tensor_descriptor( - make_tuple(K0 * number{}, number{}, K1), - make_tuple(K1, number{}, I1)); + constexpr auto DataTypeSize = sizeof(ADataType); + constexpr auto MLdsLayer = + (32 * 4 / KPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / KPerBlock / DataTypeSize); - constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( - a_lds_block_desc, - make_tuple(make_xor_transform(make_tuple(number{}, - number{})), - make_pass_through_transform(K1)), - make_tuple(sequence<1, 0>{}, sequence<2>{}), - make_tuple(sequence<1, 0>{}, sequence<2>{})); + constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, + number{}, + number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); - constexpr auto a_lds_block_desc_ak0_kMLdsLayer_m_ak1 = transform_tensor_descriptor( - a_lds_block_desc_permuted, - make_tuple(make_unmerge_transform(make_tuple(K0, number{})), - make_pass_through_transform(number{}), - make_pass_through_transform(K1)), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), - make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc_0, + make_tuple(make_xor_transform(make_tuple(number{}, + number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<1, 0>{}, sequence<2>{}), + make_tuple(sequence<1, 0>{}, sequence<2>{})); - constexpr auto a_lds_block_desc_m_k = transform_tensor_descriptor( - a_lds_block_desc_ak0_kMLdsLayer_m_ak1, - make_tuple(make_merge_transform_v3_division_mod(make_tuple(K0, K1)), - make_merge_transform_v3_division_mod( - make_tuple(number{}, number{}))), - make_tuple(sequence<0, 3>{}, sequence<1, 2>{}), - make_tuple(sequence<1>{}, sequence<0>{})); + constexpr auto a_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple(make_unmerge_transform( + make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); - return a_lds_block_desc_m_k; - } - else // ColumnMajor A - { - // kfold and mpair dimension is not always required. - // more dimension in merge_transform increase the difficulty of generating immarg offset - // for compiler. - constexpr auto M0 = get_warp_size() * Problem::BlockGemmShape::BlockWarps::at(I0); - constexpr auto M1 = MPerBlock / M0; + constexpr auto a_lds_block_desc = transform_tensor_descriptor( + a_lds_block_desc_xk0_mnldslayer_mn_xk1, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(number{}, number{})), + make_merge_transform_v3_division_mod( + make_tuple(number{}, number{}))), + make_tuple(sequence<1, 2>{}, sequence<0, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); - constexpr auto KThreadWrite = Problem::kBlockSize / M0; - constexpr auto K0PerThreadWrite = K0 / KThreadWrite; - constexpr auto KThreadRead = 64 / WarpGemm::kM; - constexpr auto K0PerThreadRead = K0 / KThreadRead; - - constexpr auto kfold = - (K1 * M0 * sizeof(ADataType) > 128) ? 1 : 128 / (K1 * M0 * sizeof(ADataType)); - constexpr auto KThreadReadPerm = - (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 - ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) - : KThreadRead; - - // 1<=mpair<=kN0 - constexpr auto mpair = (K1 * WarpGemm::kM * sizeof(ADataType) > 128) - ? 1 - : ((128 / (K1 * WarpGemm::kM * sizeof(ADataType))) > M0 - ? M0 - : 128 / (K1 * WarpGemm::kM * sizeof(ADataType))); - - constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed( - make_tuple(number{}, - number{}, - number{}, - number{}, - number{}, - K1)); - - constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( - a_lds_block_desc, - make_tuple( - make_pass_through_transform(number{}), - make_pass_through_transform(number{}), - make_xor_transform( - make_tuple(number{}, number{})), - make_pass_through_transform(number{}), - make_pass_through_transform(K1)), - make_tuple( - sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{}), - make_tuple( - sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{})); - - constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor( - a_lds_block_desc_permuted, - make_tuple( - make_pass_through_transform(number{}), - make_pass_through_transform(number{}), - make_unmerge_transform(make_tuple(number{}, number{})), - make_unmerge_transform(make_tuple(number{}, number{})), - make_pass_through_transform(number{}), - make_pass_through_transform(K1)), - make_tuple(sequence<0>{}, - sequence<1>{}, - sequence<2>{}, - sequence<3>{}, - sequence<4>{}, - sequence<5>{}), - make_tuple(sequence<1>{}, - sequence<2>{}, - sequence<0, 3>{}, - sequence<4, 5>{}, - sequence<6>{}, - sequence<7>{})); - - constexpr auto a_lds_block_desc_m_k = transform_tensor_descriptor( - a_lds_block_desc_unmerged, - make_tuple(make_merge_transform_v3_division_mod( - make_tuple(number{}, - number{}, - number{}, - number{}, - K1)), - make_merge_transform_v3_division_mod( - make_tuple(number{}, number{}, number{}))), - make_tuple(sequence<0, 1, 4, 2, 7>{}, sequence<5, 6, 3>{}), - make_tuple(sequence<1>{}, sequence<0>{})); - - return a_lds_block_desc_m_k; - } + return a_lds_block_desc; } template CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor() { - using WarpGemm = WarpGemmMfmaDispatcher; using BDataType = remove_cvref_t; - using BLayout = remove_cvref_t; - constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t KPack = GetVectorLoadSize(); - constexpr index_t K1 = WarpGemm::kK; - constexpr index_t K0 = KPerBlock / K1; + constexpr auto DataTypeSize = sizeof(BDataType); + constexpr auto NLdsLayer = + (32 * 4 / KPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / KPerBlock / DataTypeSize); - if constexpr(std::is_same::value) - { - // NLdsLayer * K0 as logical Bank - constexpr auto NLdsLayer = 32 * 4 / KPerBlock / sizeof(BDataType) < 1 - ? 1 - : 32 * 4 / KPerBlock / sizeof(BDataType); - ; - constexpr auto b_lds_block_desc = make_naive_tensor_descriptor( - make_tuple(K0 * number{}, number{}, K1), - make_tuple(K1, number{}, I1)); + constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, + number{}, + number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); - constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( - b_lds_block_desc, - make_tuple(make_xor_transform(make_tuple(number{}, - number{})), - make_pass_through_transform(K1)), - make_tuple(sequence<1, 0>{}, sequence<2>{}), - make_tuple(sequence<1, 0>{}, sequence<2>{})); + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc_0, + make_tuple(make_xor_transform(make_tuple(number{}, + number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<1, 0>{}, sequence<2>{}), + make_tuple(sequence<1, 0>{}, sequence<2>{})); - constexpr auto b_lds_block_desc_bk0_kNLdsLayer_n_bk1 = transform_tensor_descriptor( - b_lds_block_desc_permuted, - make_tuple(make_unmerge_transform(make_tuple(K0, number{})), - make_pass_through_transform(number{}), - make_pass_through_transform(K1)), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), - make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); + constexpr auto b_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple(make_unmerge_transform( + make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); - constexpr auto b_lds_block_desc_n_k = transform_tensor_descriptor( - b_lds_block_desc_bk0_kNLdsLayer_n_bk1, - make_tuple(make_merge_transform_v3_division_mod(make_tuple(K0, K1)), - make_merge_transform_v3_division_mod( - make_tuple(number{}, number{}))), - make_tuple(sequence<0, 3>{}, sequence<1, 2>{}), - make_tuple(sequence<1>{}, sequence<0>{})); - - return b_lds_block_desc_n_k; - } - else // RowMajor B - { - constexpr auto N0 = get_warp_size() * Problem::BlockGemmShape::BlockWarps::at(I1); - constexpr auto N1 = NPerBlock / N0; - - constexpr auto KThreadWrite = Problem::kBlockSize / N0; - constexpr auto K0PerThreadWrite = K0 / KThreadWrite; - constexpr auto KThreadRead = 64 / WarpGemm::kN; - constexpr auto K0PerThreadRead = K0 / KThreadRead; - - constexpr auto kfold = - (K1 * N0 * sizeof(BDataType) > 128) ? 1 : 128 / (K1 * N0 * sizeof(BDataType)); - constexpr auto KThreadReadPerm = - (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 - ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) - : KThreadRead; - - // 1<=npair<=kN0 - constexpr auto npair = (K1 * WarpGemm::kN * sizeof(BDataType) > 128) - ? 1 - : ((128 / (K1 * WarpGemm::kN * sizeof(BDataType))) > N0 - ? N0 - : 128 / (K1 * WarpGemm::kN * sizeof(BDataType))); - - constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed( - make_tuple(number{}, - number{}, - number{}, - number{}, - number{}, - K1)); - - constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( - b_lds_block_desc, - make_tuple( - make_pass_through_transform(number{}), - make_pass_through_transform(number{}), - make_xor_transform( - make_tuple(number{}, number{})), - make_pass_through_transform(number{}), - make_pass_through_transform(K1)), - make_tuple( - sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{}), - make_tuple( - sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{})); - - constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor( - b_lds_block_desc_permuted, - make_tuple( - make_pass_through_transform(number{}), - make_pass_through_transform(number{}), - make_unmerge_transform(make_tuple(number{}, number{})), - make_unmerge_transform(make_tuple(number{}, number{})), - make_pass_through_transform(number{}), - make_pass_through_transform(K1)), - make_tuple(sequence<0>{}, - sequence<1>{}, - sequence<2>{}, - sequence<3>{}, - sequence<4>{}, - sequence<5>{}), - make_tuple(sequence<1>{}, - sequence<2>{}, - sequence<0, 3>{}, - sequence<4, 5>{}, - sequence<6>{}, - sequence<7>{})); - - constexpr auto b_lds_block_desc_n_k = transform_tensor_descriptor( - b_lds_block_desc_unmerged, - make_tuple(make_merge_transform_v3_division_mod( - make_tuple(number{}, - number{}, - number{}, - number{}, - K1)), - make_merge_transform_v3_division_mod( - make_tuple(number{}, number{}, number{}))), - make_tuple(sequence<0, 1, 4, 2, 7>{}, sequence<5, 6, 3>{}), - make_tuple(sequence<1>{}, sequence<0>{})); - - return b_lds_block_desc_n_k; - } + constexpr auto b_lds_block_desc = transform_tensor_descriptor( + b_lds_block_desc_xk0_mnldslayer_mn_xk1, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(number{}, number{})), + make_merge_transform_v3_division_mod( + make_tuple(number{}, number{}))), + make_tuple(sequence<1, 2>{}, sequence<0, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + return b_lds_block_desc; } template @@ -330,20 +177,6 @@ struct UniversalGemmPipelineAgBgCrPolicy return smem_size; } - template - CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackA() - { - using ADataType = remove_cvref_t; - return Problem::VectorLoadSize / sizeof(ADataType); - } - - template - CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackB() - { - using BDataType = remove_cvref_t; - return Problem::VectorLoadSize / sizeof(BDataType); - } - template CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution() { @@ -362,7 +195,7 @@ struct UniversalGemmPipelineAgBgCrPolicy constexpr index_t total_pixels = MPerBlock * KPerBlock / BlockSize; static_assert(total_pixels % M1 == 0); constexpr index_t K3 = total_pixels / M1; - constexpr index_t KPack = GetSmemPackA(); + constexpr index_t KPack = GetVectorLoadSize(); static_assert(KPack % K3 == 0); constexpr index_t K2 = KPack / K3; if constexpr(get_warp_size() % (K2 * M0) == 0) @@ -445,7 +278,7 @@ struct UniversalGemmPipelineAgBgCrPolicy constexpr index_t total_pixels = NPerBlock * KPerBlock / BlockSize; static_assert(total_pixels % N1 == 0); constexpr index_t K3 = total_pixels / N1; - constexpr index_t KPack = GetSmemPackB(); + constexpr index_t KPack = GetVectorLoadSize(); static_assert(KPack % K3 == 0); constexpr index_t K2 = KPack / K3; if constexpr(get_warp_size() % (K2 * N0) == 0) @@ -530,7 +363,7 @@ struct UniversalGemmPipelineAgBgCrPolicy constexpr index_t total_pixels = MPerBlock * KPerBlock / BlockSize; static_assert(total_pixels % M1 == 0); constexpr index_t K3 = total_pixels / M1; - constexpr index_t kKPack = GetSmemPackB(); + constexpr index_t kKPack = GetVectorLoadSize(); static_assert(kKPack % K3 == 0); constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave constexpr index_t warp_size = get_warp_size(); @@ -578,7 +411,7 @@ struct UniversalGemmPipelineAgBgCrPolicy constexpr index_t total_pixels = NPerBlock * KPerBlock / BlockSize; static_assert(total_pixels % N1 == 0); constexpr index_t K3 = total_pixels / N1; - constexpr index_t kKPack = GetSmemPackB(); + constexpr index_t kKPack = GetVectorLoadSize(); static_assert(kKPack % K3 == 0); constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave constexpr index_t warp_size = get_warp_size();