diff --git a/example/ck_tile/03_gemm/gemm_basic.cpp b/example/ck_tile/03_gemm/gemm_basic.cpp index e3c8d72590..569afed256 100644 --- a/example/ck_tile/03_gemm/gemm_basic.cpp +++ b/example/ck_tile/03_gemm/gemm_basic.cpp @@ -1,4 +1,3 @@ - // SPDX-License-Identifier: MIT // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. @@ -282,7 +281,11 @@ int main(int argc, char* argv[]) using CodegenPipelineProblem = ck_tile:: GemmPipelineProblem; - using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1; + using CodegenGemmPolicy = ck_tile:: + UniversalGemmPipelineAgBgCrPolicy; + + using CodegenGemmPipeline = + ck_tile::GemmPipelineAGmemBGmemCRegV1; invoke_gemm +struct UniversalGemmPipelineAgBgCrPolicy +{ + using LayoutA = remove_cvref_t; + using LayoutB = remove_cvref_t; + using LayoutC = remove_cvref_t; + + static constexpr auto I0 = number<0>{}; + static constexpr auto I1 = number<1>{}; + static constexpr auto I2 = number<2>{}; + + static constexpr bool TransposeC = true; + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() + { + using WarpGemm = WarpGemmMfmaDispatcher; + + using ADataType = remove_cvref_t; + + constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t K1 = WarpGemm::kK; + constexpr index_t K0 = KPerBlock / K1; + + if constexpr(std::is_same::value) + { + constexpr auto MLdsLayer = 32 * 4 / KPerBlock / sizeof(ADataType) < 1 + ? 1 + : 32 * 4 / KPerBlock / sizeof(ADataType); + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor( + make_tuple(K0 * number{}, number{}, K1), + make_tuple(K1, number{}, I1)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple(make_xor_transform(make_tuple(number{}, + number{})), + make_pass_through_transform(K1)), + make_tuple(sequence<1, 0>{}, sequence<2>{}), + make_tuple(sequence<1, 0>{}, sequence<2>{})); + + constexpr auto a_lds_block_desc_ak0_kMLdsLayer_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple(make_unmerge_transform(make_tuple(K0, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(K1)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); + + constexpr auto a_lds_block_desc_m_k = transform_tensor_descriptor( + a_lds_block_desc_ak0_kMLdsLayer_m_ak1, + make_tuple(make_merge_transform_v3_division_mod(make_tuple(K0, K1)), + make_merge_transform_v3_division_mod( + make_tuple(number{}, number{}))), + make_tuple(sequence<0, 3>{}, sequence<1, 2>{}), + make_tuple(sequence<1>{}, sequence<0>{})); + + return a_lds_block_desc_m_k; + } + else // ColumnMajor A + { + // kfold and mpair dimension is not always required. + // more dimension in merge_transform increase the difficulty of generating immarg offset + // for compiler. + constexpr auto M0 = get_warp_size() * Problem::BlockGemmShape::BlockWarps::at(I0); + constexpr auto M1 = MPerBlock / M0; + + constexpr auto KThreadWrite = Problem::kBlockSize / M0; + constexpr auto K0PerThreadWrite = K0 / KThreadWrite; + constexpr auto KThreadRead = 64 / WarpGemm::kM; + constexpr auto K0PerThreadRead = K0 / KThreadRead; + + constexpr auto kfold = + (K1 * M0 * sizeof(ADataType) > 128) ? 1 : 128 / (K1 * M0 * sizeof(ADataType)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=mpair<=kN0 + constexpr auto mpair = (K1 * WarpGemm::kM * sizeof(ADataType) > 128) + ? 1 + : ((128 / (K1 * WarpGemm::kM * sizeof(ADataType))) > M0 + ? M0 + : 128 / (K1 * WarpGemm::kM * sizeof(ADataType))); + + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(number{}, + number{}, + number{}, + number{}, + number{}, + K1)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple( + make_pass_through_transform(number{}), + make_pass_through_transform(number{}), + make_xor_transform( + make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(K1)), + make_tuple( + sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{}), + make_tuple( + sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{})); + + constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(number{}), + make_pass_through_transform(number{}), + make_unmerge_transform(make_tuple(number{}, number{})), + make_unmerge_transform(make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(K1)), + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<2>{}, + sequence<3>{}, + sequence<4>{}, + sequence<5>{}), + make_tuple(sequence<1>{}, + sequence<2>{}, + sequence<0, 3>{}, + sequence<4, 5>{}, + sequence<6>{}, + sequence<7>{})); + + constexpr auto a_lds_block_desc_m_k = transform_tensor_descriptor( + a_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(number{}, + number{}, + number{}, + number{}, + K1)), + make_merge_transform_v3_division_mod( + make_tuple(number{}, number{}, number{}))), + make_tuple(sequence<0, 1, 4, 2, 7>{}, sequence<5, 6, 3>{}), + make_tuple(sequence<1>{}, sequence<0>{})); + + return a_lds_block_desc_m_k; + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor() + { + using WarpGemm = WarpGemmMfmaDispatcher; + + using BDataType = remove_cvref_t; + + constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + + constexpr index_t K1 = WarpGemm::kK; + constexpr index_t K0 = KPerBlock / K1; + + if constexpr(std::is_same::value) + { + // NLdsLayer * K0 as logical Bank + constexpr auto NLdsLayer = 32 * 4 / KPerBlock / sizeof(BDataType) < 1 + ? 1 + : 32 * 4 / KPerBlock / sizeof(BDataType); + ; + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor( + make_tuple(K0 * number{}, number{}, K1), + make_tuple(K1, number{}, I1)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple(make_xor_transform(make_tuple(number{}, + number{})), + make_pass_through_transform(K1)), + make_tuple(sequence<1, 0>{}, sequence<2>{}), + make_tuple(sequence<1, 0>{}, sequence<2>{})); + + constexpr auto b_lds_block_desc_bk0_kNLdsLayer_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple(make_unmerge_transform(make_tuple(K0, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(K1)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); + + constexpr auto b_lds_block_desc_n_k = transform_tensor_descriptor( + b_lds_block_desc_bk0_kNLdsLayer_n_bk1, + make_tuple(make_merge_transform_v3_division_mod(make_tuple(K0, K1)), + make_merge_transform_v3_division_mod( + make_tuple(number{}, number{}))), + make_tuple(sequence<0, 3>{}, sequence<1, 2>{}), + make_tuple(sequence<1>{}, sequence<0>{})); + + return b_lds_block_desc_n_k; + } + else // RowMajor B + { + constexpr auto N0 = get_warp_size() * Problem::BlockGemmShape::BlockWarps::at(I1); + constexpr auto N1 = NPerBlock / N0; + + constexpr auto KThreadWrite = Problem::kBlockSize / N0; + constexpr auto K0PerThreadWrite = K0 / KThreadWrite; + constexpr auto KThreadRead = 64 / WarpGemm::kN; + constexpr auto K0PerThreadRead = K0 / KThreadRead; + + constexpr auto kfold = + (K1 * N0 * sizeof(BDataType) > 128) ? 1 : 128 / (K1 * N0 * sizeof(BDataType)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=npair<=kN0 + constexpr auto npair = (K1 * WarpGemm::kN * sizeof(BDataType) > 128) + ? 1 + : ((128 / (K1 * WarpGemm::kN * sizeof(BDataType))) > N0 + ? N0 + : 128 / (K1 * WarpGemm::kN * sizeof(BDataType))); + + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(number{}, + number{}, + number{}, + number{}, + number{}, + K1)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple( + make_pass_through_transform(number{}), + make_pass_through_transform(number{}), + make_xor_transform( + make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(K1)), + make_tuple( + sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{}), + make_tuple( + sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{})); + + constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(number{}), + make_pass_through_transform(number{}), + make_unmerge_transform(make_tuple(number{}, number{})), + make_unmerge_transform(make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(K1)), + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<2>{}, + sequence<3>{}, + sequence<4>{}, + sequence<5>{}), + make_tuple(sequence<1>{}, + sequence<2>{}, + sequence<0, 3>{}, + sequence<4, 5>{}, + sequence<6>{}, + sequence<7>{})); + + constexpr auto b_lds_block_desc_n_k = transform_tensor_descriptor( + b_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(number{}, + number{}, + number{}, + number{}, + K1)), + make_merge_transform_v3_division_mod( + make_tuple(number{}, number{}, number{}))), + make_tuple(sequence<0, 1, 4, 2, 7>{}, sequence<5, 6, 3>{}), + make_tuple(sequence<1>{}, sequence<0>{})); + + return b_lds_block_desc_n_k; + } + } + + template + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA() + { + constexpr index_t smem_size_a = sizeof(typename Problem::ADataType) * + MakeALdsBlockDescriptor().get_element_space_size(); + return smem_size_a; + } + + template + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeB() + { + constexpr index_t smem_size_b = sizeof(typename Problem::BDataType) * + MakeBLdsBlockDescriptor().get_element_space_size(); + return smem_size_b; + } + + template + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + constexpr index_t smem_size_a = GetSmemSizeA(); + constexpr index_t smem_size_b = GetSmemSizeB(); + index_t smem_size = 0; + smem_size += smem_size_a + smem_size_b; + + return smem_size; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution() + { + using WarpGemm = WarpGemmMfmaDispatcher; + + constexpr index_t BlockSize = Problem::kBlockSize; + + constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + + constexpr index_t K1 = WarpGemm::kK; + constexpr index_t K0 = KPerBlock / K1; + constexpr index_t M2 = get_warp_size() / K0; + + constexpr index_t M1 = BlockSize / get_warp_size(); + static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error."); + static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error."); + constexpr index_t M0 = MPerBlock / (M2 * M1); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution() + { + using WarpGemm = WarpGemmMfmaDispatcher; + + constexpr index_t BlockSize = Problem::kBlockSize; + + constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + + constexpr index_t K1 = WarpGemm::kK; + constexpr index_t K0 = KPerBlock / K1; + constexpr index_t N2 = get_warp_size() / K0; + + constexpr index_t N1 = BlockSize / get_warp_size(); + static_assert(N2 != 0, "M2 is zero, which will lead to a division by zero error."); + static_assert(N1 != 0, "M1 is zero, which will lead to a division by zero error."); + constexpr index_t N0 = NPerBlock / (N2 * N1); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() + { + using AccDataType = float; + using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; + using WarpTile = typename Problem::BlockGemmShape::WarpTile; + using WarpGemm = WarpGemmMfmaDispatcher; + using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy; + return BlockGemmASmemBSmemCRegV1{}; + } +}; + +} // namespace ck_tile