From b5d6100bbcc06646b287c5a457c90644515739f4 Mon Sep 17 00:00:00 2001 From: letaoqin Date: Fri, 22 Nov 2024 04:24:07 +0000 Subject: [PATCH] change file name --- .../instances/fused_moegemm_api_internal.hpp | 2 +- example/ck_tile/16_fused_moe_general/main.cpp | 2 +- include/ck_tile/ops/fused_moe.hpp | 3 +- ...m_gl.hpp => fused_moegemm_pipeline_gl.hpp} | 33 +- .../fused_moegemm_pipeline_gl_policy.hpp | 887 ++++++++++++++++++ 5 files changed, 911 insertions(+), 16 deletions(-) rename include/ck_tile/ops/fused_moe/pipeline/{fused_moegemm_pipeline_flatmm_gl.hpp => fused_moegemm_pipeline_gl.hpp} (83%) create mode 100644 include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_gl_policy.hpp diff --git a/example/ck_tile/16_fused_moe_general/instances/fused_moegemm_api_internal.hpp b/example/ck_tile/16_fused_moe_general/instances/fused_moegemm_api_internal.hpp index 8b47491485..3cdf98e49e 100644 --- a/example/ck_tile/16_fused_moe_general/instances/fused_moegemm_api_internal.hpp +++ b/example/ck_tile/16_fused_moe_general/instances/fused_moegemm_api_internal.hpp @@ -38,7 +38,7 @@ float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a) f_traits>; // using f_pipeline = ck_tile::FusedMoeGemmPipeline_FlatmmEx; - using f_pipeline = ck_tile::FusedMoeGemmPipeline_FlatmmGl; + using f_pipeline = ck_tile::FusedMoeGemmPipeline_General; using f_partitioner = ck_tile::FusedMoeGemmTilePartitioner_Linear; using f_kernel = ck_tile::FusedMoeGemmGlKernel; diff --git a/example/ck_tile/16_fused_moe_general/main.cpp b/example/ck_tile/16_fused_moe_general/main.cpp index 84ed4f3ff1..0b36090686 100644 --- a/example/ck_tile/16_fused_moe_general/main.cpp +++ b/example/ck_tile/16_fused_moe_general/main.cpp @@ -256,7 +256,7 @@ bool run(const ck_tile::ArgParser& arg_parser) // } // std::cout << std::endl; // } - std::cout << sorted_token_ids_host << std::endl; + // std::cout << sorted_token_ids_host << std::endl; // std::cout << num_sorted_tiles_host << std::endl; // std::cout << sorted_expert_ids_host << std::endl; // std::cout << topk_weight_host << std::endl; diff --git a/include/ck_tile/ops/fused_moe.hpp b/include/ck_tile/ops/fused_moe.hpp index e454770617..1da9f092f7 100644 --- a/include/ck_tile/ops/fused_moe.hpp +++ b/include/ck_tile/ops/fused_moe.hpp @@ -12,7 +12,8 @@ #include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp" #include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp" #include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_problem.hpp" -#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_gl.hpp" +#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_gl.hpp" +#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_gl_policy.hpp" #include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp" #include "ck_tile/ops/fused_moe/pipeline/moe_sorting_pipeline.hpp" #include "ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp" diff --git a/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_gl.hpp b/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_gl.hpp similarity index 83% rename from include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_gl.hpp rename to include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_gl.hpp index 92c1176c1a..425626ec9f 100644 --- a/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_gl.hpp +++ b/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_gl.hpp @@ -5,7 +5,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" -#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp" +#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_gl_policy.hpp" namespace ck_tile { @@ -18,8 +18,8 @@ we need to design the pipeline such that all waves along gemm-N dim (gemm-m only | w0 | w1 | w2 | w3 | gemm-m +----+----+----+----+ */ -template -struct FusedMoeGemmPipeline_FlatmmGl +template +struct FusedMoeGemmPipeline_General { using Problem = remove_cvref_t; using Policy = remove_cvref_t; @@ -70,14 +70,15 @@ struct FusedMoeGemmPipeline_FlatmmGl CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { - // matrix a or tokens smem - constexpr index_t smem_mat_a = - BlockShape::Block_M0 * BlockShape::Block_K0 * sizeof(ADataType); - // shuffle C matrix - constexpr index_t smem_bridge = - BlockShape::Block_M0 * BlockShape::Block_N0 * sizeof(YDataType); + // // matrix a or tokens smem + // constexpr index_t smem_mat_a = + // BlockShape::Block_M0 * BlockShape::Block_K0 * sizeof(ADataType); + // // shuffle C matrix + // constexpr index_t smem_bridge = + // BlockShape::Block_M0 * BlockShape::Block_N0 * sizeof(YDataType); - return max(smem_mat_a, smem_bridge); + // return max(smem_mat_a, smem_bridge); + return Policy::template GetSmemSize(); } // this is the thread-offset along row/col @@ -104,12 +105,19 @@ struct FusedMoeGemmPipeline_FlatmmGl ignore = hidden_size; ignore = intermediate_size; - auto a_copy_dram_window = make_tile_window( + CK_TILE_LDS_ADDR ADataType* smem_0 = reinterpret_cast(smem); + auto a_lds_view = make_tensor_view( + smem_0, Policy::template MakeLdsStoreDesc_A()); + auto a_lds_win = make_tile_window(a_lds_view, make_tuple(number{}, number{}), {0, 0}); + + auto a_global_to_dram_window = make_tile_window( a_window_.get_bottom_tensor_view(), make_tuple(number{}, number{}), a_window_.get_window_origin(), Policy::template MakeGlobalTileDistribution_A()); - auto a_dram = load_tile(a_copy_dram_window); + auto a_dram_block = load_tile(a_global_to_dram_window); + + store_tile(a_lds_win, a_dram_block); #if 0 //check a matrix gather right or not constexpr auto a_spans = decltype(a_dram)::get_distributed_spans(); @@ -126,7 +134,6 @@ struct FusedMoeGemmPipeline_FlatmmGl }); }); #endif - ignore = a_dram; } }; diff --git a/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_gl_policy.hpp b/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_gl_policy.hpp new file mode 100644 index 0000000000..09399d1975 --- /dev/null +++ b/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_gl_policy.hpp @@ -0,0 +1,887 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp" +#include "ck_tile/ops/flatmm.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" + +namespace ck_tile { + +struct FusedMoeGemmPipelineGeneralPolicy +{ + CK_TILE_HOST_DEVICE static constexpr index_t GetAsyncCopyDwords() + { + // TODO: always 1 dword + return 1; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignment_A() + { + // using async + constexpr index_t copy_bytes = 4 * GetAsyncCopyDwords(); + constexpr index_t data_bytes = sizeof(typename Problem::ADataType); + static_assert(copy_bytes % data_bytes == 0); + return copy_bytes / data_bytes; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignment_G() + { + constexpr index_t copy_bytes = [&]() { return 16; }(); + constexpr index_t data_bytes = sizeof(typename Problem::GDataType); + static_assert(copy_bytes % data_bytes == 0); + return copy_bytes / data_bytes; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignment_D() + { + constexpr index_t copy_bytes = [&]() { return 16; }(); + constexpr index_t data_bytes = sizeof(typename Problem::DDataType); + static_assert(copy_bytes % data_bytes == 0); + return copy_bytes / data_bytes; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignment_O() + { + if constexpr(Problem::Traits::OAtomic == 1) + { + // pack fp16/bf16 atomic + static_assert(sizeof(typename Problem::ODataType) == 2); + return 2; + } + else if constexpr(Problem::Traits::OAtomic == 2) + { + // fp32 atomic + return 1; + } + else + { + return 16 / sizeof(typename Problem::ODataType); + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPack() + { + // TODO: this is for 3d layout + return 16 / sizeof(remove_cvref_t); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPack_A() + { + return GetSmemKPack(); + } + + // used for bridge LDS shuffle + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPack_Y() + { + // TODO: this should match mfma layout + return 16 / sizeof(typename Problem::YDataType); + } + +#if 0 + template + CK_TILE_HOST_DEVICE static constexpr auto GetWaveFlattenShape() + { + using WarpGemm = GetWarpGemm0{}; // assume warpgemm0/1 are the same + + constexpr index_t Kv = GetAlignment_G<{Problem}>(); + constexpr index_t Nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane; + constexpr index_t Kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane; + return sequence{}; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetBlockTileNrKr() + { + using WarpGemm = GetWarpGemm0{}; // assume warpgemm0/1 are the same + + constexpr index_t Kv = GetAlignment_G<{Problem}>(); + constexpr index_t Nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane; + constexpr index_t Kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane; + return sequence{}; + } +#endif + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize_A() + { + constexpr auto a_sld_desc = MakeLdsLoadDesc_A(); + constexpr auto a_sst_desc = MakeLdsStoreDesc_A(); + static_assert(a_sld_desc.get_element_space_size() == a_sst_desc.get_element_space_size()); + return a_sld_desc.get_element_space_size(); + } + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize_Bridge() + { + constexpr auto bridge_sld_desc = MakeBridgeLdsLoadDesc(); + constexpr auto bridge_sst_desc = MakeBridgeLdsStoreDesc(); + static_assert(bridge_sld_desc.get_element_space_size() == + bridge_sst_desc.get_element_space_size()); + return bridge_sld_desc.get_element_space_size(); + } + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + constexpr index_t a_lds = GetSmemSize_A(); + constexpr index_t bridge_lds = GetSmemSize_Bridge(); + return max(a_lds, bridge_lds); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_SimpleMxK() + { + constexpr index_t K_vec = Alignment; + constexpr index_t K_rem = KPerBlock / K_vec; + + if constexpr(get_warp_size() < K_rem) + { + static_assert(K_rem % get_warp_size() == 0); + constexpr index_t K_lan = get_warp_size(); // lane within same wave is along gemm-k + constexpr index_t K_wav = K_rem / get_warp_size(); + static_assert(K_wav <= NumWarps, "not not support thread has repeat along K yet"); + constexpr index_t M_wav = NumWarps / K_wav; + static_assert(MPerBlock % M_wav == 0, "this tile size is too small please check"); + constexpr index_t M_rep = MPerBlock / M_wav; + + return make_static_tile_distribution( + tile_distribution_encoding< + sequence<1>, + tuple, sequence>, + tuple, sequence<2>>, + tuple, sequence<1>>, + sequence<1, 2>, + sequence<0, 2>>{}); + } + else + { + constexpr index_t K_lan = K_rem; + constexpr index_t M_lan = get_warp_size() / K_lan; + constexpr index_t M_wav = NumWarps; + static_assert(MPerBlock % (M_lan * M_wav) == 0, + "this tile size is too small please check"); + constexpr index_t M_rep = MPerBlock / (M_lan * M_wav); + return make_static_tile_distribution( + tile_distribution_encoding< + sequence<1>, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + } + + // optimized version for async, not same as simple MXK dist(pay attention!!) + template + CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_SimpleMxK_Async() + { + constexpr index_t K_vec = Alignment; + constexpr index_t K_rem = KPerBlock / K_vec; + + if constexpr(get_warp_size() <= K_rem) + { + static_assert(K_rem % get_warp_size() == 0); + constexpr index_t K_lan = get_warp_size(); // lane within same wave is along gemm-k + constexpr index_t K_wav = K_rem / get_warp_size(); + static_assert(K_wav <= NumWarps, "do not support thread has repeat along K yet"); + constexpr index_t M_wav = NumWarps / K_wav; + static_assert(MPerBlock % M_wav == 0, "this tile size is too small please check"); + constexpr index_t M_rep = MPerBlock / M_wav; + // NOTE: no swap, but hard to avoid LDS bank conflict + return make_static_tile_distribution( + tile_distribution_encoding< + sequence<1>, + tuple, sequence>, + tuple, sequence<2>>, + tuple, sequence<1>>, + sequence<1, 2>, + sequence<0, 2>>{}); + } + else + { + constexpr index_t K_lan = K_rem; + constexpr index_t M_lan = get_warp_size() / K_lan; + constexpr index_t M_wav = NumWarps; + static_assert(MPerBlock % (M_lan * M_wav) == 0, + "this tile size is too small please check"); + constexpr index_t M_rep = MPerBlock / (M_lan * M_wav); + // NOTE: swapped for LDS load bank conflict free + return make_static_tile_distribution( + tile_distribution_encoding< + sequence<1>, + // Note M_wave(num waves) is the fastest dim, different from sipmle 2d + // distribution + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<1, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + } + +#if 0 + // Caution: this will require global memory pre-shuffled to follow the mfma layout + template + CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_MatrixCore_Swizzled() + { + static_assert(Alignment % WarpGemm::WarpGemmAttribute::Impl::kABKPerLane == 0); + + if constexpr(PermuteEnum == FusedMoeGemmWeightPermuteEnum::b_nr_kr_waveflatten) + { + constexpr index_t Kv = Alignment; + constexpr index_t Nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane; + constexpr index_t Kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane; + + static_assert(KPerBlock % (K1 * K2) == 0); + constexpr index_t Nr = NPerBlock / Nw; + constexpr index_t Kr = KPerBlock / (Kv * Kw); + + constexpr index_t Nr_p = WavesPerBlock_N; + constexpr index_t Kr_p = WavesPerBlock_K; + constexpr index_t Nr_y = Nr / Nr_p; + constexpr index_t Kr_y = Kr / Kr_p; + + return make_static_tile_distribution( + tile_distribution_encoding< + sequence<1>, // 0 + // major 1 2 3 + // minor 0 1 0 1 0 1 2 + tuple, sequence, sequence>, + + // Nr_p, Kr_p Kw Nw + tuple, sequence<3, 3>>, + tuple, sequence<0, 1>>, + + // Nr_y Kr_y Kv + sequence<1, 2, 3>, + sequence<0, 0, 2>>{}); + // clang-format on + } + } +#endif + template + CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_Nr_Kr_W() + { + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, + sequence, + sequence>, + tuple, sequence<3>>, + tuple, sequence<0>>, + sequence<1, 2, 3>, + sequence<0, 0, 1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_A() + { + constexpr index_t Block_M_ = Problem::BlockShape::Block_M0; + constexpr index_t Block_K_ = Problem::BlockShape::Block_K0; + constexpr index_t NumWarps_ = Problem::BlockShape::NumWarps; + constexpr index_t Alignment_ = GetAlignment_A(); + return MakeGlobalTileDistribution_SimpleMxK_Async(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_G() + { + constexpr auto PermuteEnum = Problem::Traits::PermuteEnum; + // constexpr index_t hidden_radio_0 = Problem::Traits::IsGateOnly ? 1 : 2; + using S_ = typename Problem::BlockShape; + if constexpr(PermuteEnum == FusedMoeGemmWeightPermuteEnum::b_nr_kr_waveflatten) + { + // number{}.rrr(); + // number{}.eee(); + return MakeGlobalTileDistribution_Nr_Kr_W()>(); + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_D() + { + constexpr auto PermuteEnum = Problem::Traits::PermuteEnum; + using S_ = typename Problem::BlockShape; + if constexpr(PermuteEnum == FusedMoeGemmWeightPermuteEnum::b_nr_kr_waveflatten) + { + return MakeGlobalTileDistribution_Nr_Kr_W()>(); + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_O() + { + using S_ = remove_cvref_t; + using WarpGemm = remove_cvref_t())>; + // using CDataType = typename WarpGemm::CDataType; + + constexpr auto c_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, + sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); + constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); + return c_block_dstr; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeLdsStoreDesc_A() + { + // A async->LDS + constexpr index_t Block_M = Problem::BlockShape::Block_M0; + constexpr index_t Block_K = Problem::BlockShape::Block_K0; + // constexpr index_t BlockSize = Problem::BlockShape::BlockSize; + constexpr index_t warpSize = ck_tile::get_warp_size(); + constexpr index_t NumWarps = Problem::BlockShape::NumWarps; + + constexpr index_t KPack = GetSmemKPack_A(); // LDS + constexpr index_t KVector = GetAlignment_A(); // async copy 1 dword + constexpr index_t KPad = KPack; // pad between warps + + static_assert(Block_K % KVector == 0); + constexpr index_t LanesPerK = Block_K / KVector; // how many thread loading K + if constexpr(LanesPerK >= warpSize) + { + // need multiple waves to load K + static_assert(LanesPerK % warpSize == 0); + constexpr index_t wavesPerK = LanesPerK / warpSize; + if constexpr(wavesPerK > NumWarps) + { + // TODO: need multiple issues along K to load all data + } + else + { + constexpr index_t wavesPerM = NumWarps / wavesPerK; + constexpr index_t NumIssues = Block_M / wavesPerM; + constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, // m0 + number{}, // m1 + number{}, // k0 + number{}, // k1 + number{}), // k2 + make_tuple(number{}, // m0 + number{}, // m1 + number{}, // k0 + number{}, // k1 + number<1>{}), // k2 + number{}, // lds store vector(actually no explicit store) + number<1>{}); + + constexpr auto lds_block_desc_issues_warps_lanes = transform_tensor_descriptor( + lds_block_desc_0, + make_tuple( + make_pass_through_transform(number{}), + make_merge_transform(make_tuple(number{}, number{})), + make_merge_transform(make_tuple(number{}, number{}))), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); + + return lds_block_desc_issues_warps_lanes; + } + } + else + { + // lanes within a wave load different M but same K + static_assert(warpSize % LanesPerK == 0); + constexpr index_t LaneGroups = warpSize / LanesPerK; // along m + constexpr index_t NumIssues = Block_M / (LaneGroups * NumWarps); + + constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, // m0 + number{}, // m1 + number{}, // m2 + number{}, // k0 + number{}), // k1 + make_tuple(number{}, // m0 + number{}, // m1 + number{}, // m2 + number{}, // k0 + number<1>{}), // k1 + number{}, // lds store vector(actually no explicit store) + number<1>{}); + + constexpr auto lds_block_desc_issues_warps_lanes = transform_tensor_descriptor( + lds_block_desc_0, + make_tuple(make_pass_through_transform(number{}), + make_pass_through_transform(number{}), + make_merge_transform(make_tuple( + number{}, number{}, number{}))), + make_tuple(sequence<0>{}, sequence<2>{}, sequence<1, 3, 4>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); + + return lds_block_desc_issues_warps_lanes; + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeLdsLoadDesc_A() + { + // A async->LDS + // Note that, this descriptor is only to construct the layout inside LDS + // in real Gemm pipeline, ds_read may not follow this pattern + // (may follow that in tile_distribution) + // below code is almost the same as SmemStore dist, with difference: + // 1). modify the GuaranteedLastDimensionVectorLength of naive tensor desc + // 2). return discriptor is in NxK 2d layout + constexpr index_t Block_M = Problem::BlockShape::Block_M0; + constexpr index_t Block_K = Problem::BlockShape::Block_K0; + // constexpr index_t BlockSize = Problem::BlockShape::BlockSize; + constexpr index_t warpSize = ck_tile::get_warp_size(); + constexpr index_t NumWarps = Problem::BlockShape::NumWarps; + + constexpr index_t KPack = GetSmemKPack_A(); // LDS + constexpr index_t KVector = GetAlignment_A(); // async copy 1 dword + constexpr index_t KPad = KPack; // pad between warps + + static_assert(Block_K % KVector == 0); + constexpr index_t LanesPerK = Block_K / KVector; // how many thread loading K + if constexpr(LanesPerK >= warpSize) + { + // need multiple waves to load K + static_assert(LanesPerK % warpSize == 0); + constexpr index_t wavesPerK = LanesPerK / warpSize; + if constexpr(wavesPerK >= NumWarps) + { + // TODO: need multiple issues along K to load all data + } + else + { + constexpr index_t wavesPerM = NumWarps / wavesPerK; + constexpr index_t NumIssues = Block_M / wavesPerM; + constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, // m0 + number{}, // m1 + number{}, // k0 + number{}, // k1 + number{}), // k2 + make_tuple(number{}, // m0 + number{}, // m1 + number{}, // k0 + number{}, // k1 + number<1>{}), // k2 + number{}, // lds load vector + number<1>{}); + + constexpr auto lds_desc_m_k = transform_tensor_descriptor( + lds_block_desc_0, + make_tuple( + make_merge_transform(make_tuple(number{}, number{})), + make_merge_transform(make_tuple( + number{}, number{}, number{}))), + make_tuple(sequence<0, 1>{}, sequence<2, 3, 4>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return lds_desc_m_k; + } + } + else + { + // lanes within a wave load different M but same K + static_assert(warpSize % LanesPerK == 0); + constexpr index_t LaneGroups = warpSize / LanesPerK; // along m + constexpr index_t NumIssues = Block_M / (LaneGroups * NumWarps); + + constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, // m0 + number{}, // m1 + number{}, // m2 + number{}, // k0 + number{}), // k1 + make_tuple(number{}, // m0 + number{}, // m1 + number{}, // m2 + number{}, // k0 + number<1>{}), // k1 + number{}, // lds load vector + number<1>{}); + + constexpr auto lds_desc_m_k = transform_tensor_descriptor( + lds_block_desc_0, + make_tuple( + make_merge_transform( + make_tuple(number{}, number{}, number{})), + make_merge_transform(make_tuple(number{}, number{}))), + make_tuple(sequence<0, 1, 2>{}, sequence<3, 4>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return lds_desc_m_k; + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBridgeLdsLoadDesc() + { + constexpr index_t Block_M = Problem::BlockShape::Block_M0; + constexpr index_t Block_N = Problem::BlockShape::Block_N0; + + constexpr index_t KVector = GetSmemKPack_Y(); // async copy 1 dword + constexpr index_t KPad = 0; // pad between warps + + constexpr auto desc = + make_naive_tensor_descriptor(make_tuple(number{}, number{}), + make_tuple(number{}, number<1>{}), + number{}, + number<1>{}); + return desc; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBridgeLdsStoreDesc() + { + constexpr index_t Block_M = Problem::BlockShape::Block_M0; + constexpr index_t Block_N = Problem::BlockShape::Block_N0; + + constexpr index_t KVector = GetSmemKPack_Y(); // async copy 1 dword + constexpr index_t KPad = 0; // KVector; // pad between warps + + constexpr auto desc = + make_naive_tensor_descriptor(make_tuple(number{}, number{}), + make_tuple(number{}, number<1>{}), + number{}, + number<1>{}); + return desc; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBridgeLdsStoreForUKDesc() + { + constexpr index_t WarpPerBlock_N = Problem::BlockShape::WarpPerBlock_N0; + constexpr index_t Repeat_N = Problem::BlockShape::Repeat_N0; + constexpr index_t Repeat_M = Problem::BlockShape::Repeat_M0; + + constexpr index_t kAMLane = 16; + constexpr index_t kABKLane = 4; + constexpr index_t kABKPerLane = 4; + + constexpr index_t KPack = kABKPerLane; + + constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, // m + number{}, // n + number{}, // n + number{}, // n + number{}, // m + number{}), // n + make_tuple(number{}, // m + number{}, // n + number{}, // n + number{}, // n + number{}, // m + number<1>{}), // n + number{}, // lds store vector(actually no explicit store) + number<1>{}); + + constexpr auto desc = transform_tensor_descriptor( + lds_block_desc_0, + make_tuple(make_merge_transform(make_tuple(number{}, number{})), + make_merge_transform(make_tuple(number{}, + number{}, + number{}, + number{}))), + make_tuple(sequence<0, 4>{}, sequence<1, 2, 3, 5>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return desc; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemm0() + { + using S_ = typename Problem::BlockShape; + // A is vgpr, B is agpr. But since we transposed, so also need swap this + // TODO: this is ugly + constexpr auto wg_ctrl = WGAttrCtlEnum::Raw_avv; + // TODO: ugly + if constexpr(std::is_same_v && + std::is_same_v && + S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 16) + { + return WarpGemmImpl, + 2>>{}; + } + else if constexpr(std::is_same_v && + std::is_same_v && + S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 32) + { + return WarpGemmImpl, + 2>>{}; + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSequencer_0() + { + // this function return seq<...> used to identify gld/sld/valu... inside mfma sequence + // the purpose is to hide thoes instructions under mfma + // every value inside seq<...> is a mask, indicating a specific operation + using S_ = typename Problem::BlockShape; + constexpr index_t SLD_A = static_cast(FusedMoeGemmPipelineSequencerEnum::SLD_A); + constexpr index_t GLD_A = static_cast(FusedMoeGemmPipelineSequencerEnum::GLD_A); + constexpr index_t GLD_B = static_cast(FusedMoeGemmPipelineSequencerEnum::GLD_B); + if constexpr(std::is_same_v && + std::is_same_v && + S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 16 && + S_::Block_M0 == 32 && S_::Block_N0 == 512 && S_::Block_K0 == 128 && + S_::Block_N1 == 128) + { + // Total 64 instructions, 32 buffer-load-dwordx4 gld_b, 8x buffer-load-dwordx1-async + // gld_a 8x ds_read_b128 sld_a total 64 slot :) + // clang-format off + constexpr auto seq_all = + // 0 1 2 3 4 5 6 7 + sequence{}; // 7 + return seq_all; + // clang-format on + } + else if constexpr(std::is_same_v && + std::is_same_v && + S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 16 && + S_::Block_M0 == 32 && S_::Block_N0 == 256 && S_::Block_K0 == 128 && + S_::Block_N1 == 128) + { + // Total 32 instructions, 16 buffer-load-dwordx4 gld_b, 8x buffer-load-dwordx1-async + // gld_a 8x ds_read_b128 sld_a total 64 slot :) + // clang-format off + constexpr auto seq_all = + // 0 1 2 3 4 5 6 7 + sequence{}; // 3 + return seq_all; + // clang-format on + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSequencer_1() + { + // this function return seq<...> used to identify gld/sld/valu... inside mfma sequence + // the purpose is to hide thoes instructions under mfma + // every value inside seq<...> is a mask, indicating a specific operation + using S_ = typename Problem::BlockShape; + constexpr index_t GLD_B = static_cast(FusedMoeGemmPipelineSequencerEnum::GLD_B); + constexpr index_t GST_O = static_cast(FusedMoeGemmPipelineSequencerEnum::GST_O); + if constexpr(std::is_same_v && + std::is_same_v && + S_::Warp_M1 == 32 && S_::Warp_N1 == 32 && S_::Warp_K1 == 16 && + S_::Block_M0 == 32 && S_::Block_N0 == 512 && S_::Block_K0 == 128 && + S_::Block_N1 == 128) + { + // Total 64 instructions, 32 buffer-load-dwordx4 gld_b, 8x buffer-load-dwordx1-async + // gld_a 8x ds_read_b128 sld_a total 64 slot :) + // clang-format off + constexpr auto seq_all = + // 0 1 2 3 4 5 6 7 + sequence{}; // 7 + return seq_all; + // clang-format on + } + else if constexpr(std::is_same_v && + std::is_same_v && + S_::Warp_M1 == 32 && S_::Warp_N1 == 32 && S_::Warp_K1 == 16 && + S_::Block_M0 == 32 && S_::Block_N0 == 256 && S_::Block_K0 == 128 && + S_::Block_N1 == 128) + { + // Total 64 instructions, 32 buffer-load-dwordx4 gld_b, 8x buffer-load-dwordx1-async + // gld_a 8x ds_read_b128 sld_a total 64 slot :) + // clang-format off + constexpr auto seq_all = + // 0 1 2 3 4 5 6 7 + sequence{}; // 3 + return seq_all; + // clang-format on + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemm1() + { + using S_ = typename Problem::BlockShape; + constexpr auto wg_ctrl = WGAttrCtlEnum::Raw_avv; + // TODO: ugly + if constexpr(std::is_same_v && + std::is_same_v && + S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 16) + { + return WarpGemmImpl, + 2>>{}; + } + else if constexpr(std::is_same_v && + std::is_same_v && + S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 32) + { + return WarpGemmImpl, + 2>>{}; + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeCBlockTile_Gemm0() + { + using S_ = remove_cvref_t; + using WarpGemm = remove_cvref_t())>; + using CDataType = typename WarpGemm::CDataType; + + constexpr auto c_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, + sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); + constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); + auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); + return c_block_tensor; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeCBlockTile_Gemm1() + { + using S_ = remove_cvref_t; + using WarpGemm = remove_cvref_t())>; + using CDataType = typename WarpGemm::CDataType; + + constexpr auto c_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, + sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); + constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); + auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); + return c_block_tensor; + } + + // this is used as A matrix for 2nd gemm + template + CK_TILE_HOST_DEVICE static constexpr auto MakeYTileDistribution() + { + using S_ = remove_cvref_t; + using WarpGemm = remove_cvref_t())>; + + // TODO: all waves a along different N, but same M + constexpr auto y_outer_dstr_enc = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto y_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + y_outer_dstr_enc, typename WarpGemm::AWarpDstrEncoding{}); + constexpr auto y_block_dstr = make_static_tile_distribution(y_block_dstr_encode); + return y_block_dstr; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeYBlockTile() + { + constexpr auto y_block_dstr = MakeYTileDistribution(); + auto y_block_tensor = + make_static_distributed_tensor(y_block_dstr); + return y_block_tensor; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetUK_0() + { + using S_ = typename Problem::BlockShape; + if constexpr(std::is_same_v && + std::is_same_v && + S_::Block_M0 == 32 && S_::Block_N0 == 512 && S_::Block_K0 == 128 && + S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32) + { + return FlatmmUK_GFX9_32x512x128_1x4x1_16x16x16_BF16{}; + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetUK_1() + { + using S_ = typename Problem::BlockShape; + if constexpr(std::is_same_v && + std::is_same_v && + S_::Block_M1 == 32 && S_::Block_N1 == 128 && S_::Block_K1 == 512 && + S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32) + { + return FlatmmSnUK_GFX9_32x128x512_1x4x1_16x16x16_BF16{}; + } + } +}; +} // namespace ck_tile