diff --git a/example/ck_tile/16_fused_moe_general/fused_moegemm.hpp b/example/ck_tile/16_fused_moe_general/fused_moegemm.hpp index 0abfb28d2b..86cfb3d25d 100644 --- a/example/ck_tile/16_fused_moe_general/fused_moegemm.hpp +++ b/example/ck_tile/16_fused_moe_general/fused_moegemm.hpp @@ -13,6 +13,22 @@ template struct FusedMoeGemmTypeConfig; +template +struct FusedMoeGemmTypeConfig +{ + using ADataType = ck_tile::fp16_t; + using GDataType = ck_tile::fp16_t; + using DDataType = ck_tile::fp16_t; + using AccDataType = float; + using ODataType = ck_tile::fp16_t; + using AScaleDataType = ck_tile::remove_cvref_t; + using GScaleDataType = ck_tile::remove_cvref_t; + using DScaleDataType = ck_tile::remove_cvref_t; + using YSmoothScaleDataType = ck_tile::remove_cvref_t; + using TopkWeightDataType = ck_tile::remove_cvref_t; + using IndexDataType = ck_tile::index_t; +}; + template struct FusedMoeGemmTypeConfig { diff --git a/example/ck_tile/16_fused_moe_general/instances/fused_moegemm_api.cpp b/example/ck_tile/16_fused_moe_general/instances/fused_moegemm_api.cpp index 7aa6185b8c..43dc93b4d9 100644 --- a/example/ck_tile/16_fused_moe_general/instances/fused_moegemm_api.cpp +++ b/example/ck_tile/16_fused_moe_general/instances/fused_moegemm_api.cpp @@ -19,7 +19,7 @@ float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile: if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" && t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1) { - using t_ = fmoe_, S<1, 4, 1>, S<32, 32, 8>, 1, 0>; + using t_ = fmoe_, S<1, 4, 1>, S<32, 32, 8>, 1, 0>; r = fused_moegemm_(s, a); } // clang-format on diff --git a/example/ck_tile/16_fused_moe_general/instances/fused_moegemm_api_internal.hpp b/example/ck_tile/16_fused_moe_general/instances/fused_moegemm_api_internal.hpp index 3cdf98e49e..b66df7dc93 100644 --- a/example/ck_tile/16_fused_moe_general/instances/fused_moegemm_api_internal.hpp +++ b/example/ck_tile/16_fused_moe_general/instances/fused_moegemm_api_internal.hpp @@ -19,8 +19,8 @@ float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a) typename Ts_::WarpPerBlock_0, typename Ts_::WarpTile_0, typename Ts_::BlockTile_1, - typename Ts_::WarpPerBlock_0, - typename Ts_::WarpTile_0>; + typename Ts_::WarpPerBlock_1, + typename Ts_::WarpTile_1>; using f_problem = ck_tile::FusedMoeGemmPipelineProblem; - using WarpPerBlock_1 = ck_tile::remove_cvref_t; + using WarpPerBlock_1 = ck_tile::sequence<1, 1, 4>;//ck_tile::remove_cvref_t; using WarpTile_1 = ck_tile::remove_cvref_t; static constexpr ck_tile::index_t GateOnly = GateOnly_; diff --git a/example/ck_tile/16_fused_moe_general/instances/fused_moegemm_bf16_m32.cpp b/example/ck_tile/16_fused_moe_general/instances/fused_moegemm_bf16_m32.cpp index 5fd456c1c8..744c5771ed 100644 --- a/example/ck_tile/16_fused_moe_general/instances/fused_moegemm_bf16_m32.cpp +++ b/example/ck_tile/16_fused_moe_general/instances/fused_moegemm_bf16_m32.cpp @@ -8,7 +8,7 @@ // clang-format off template float fused_moegemm_< - fmoe_, S<1, 4, 1>, S<32, 32, 8>, 1, 0> + fmoe_, S<1, 4, 1>, S<32, 32, 8>, 1, 0> >(const ck_tile::stream_config& s, fused_moegemm_args a); // clang-format on diff --git a/example/ck_tile/16_fused_moe_general/main.cpp b/example/ck_tile/16_fused_moe_general/main.cpp index dae6e36a43..867c7e0905 100644 --- a/example/ck_tile/16_fused_moe_general/main.cpp +++ b/example/ck_tile/16_fused_moe_general/main.cpp @@ -252,8 +252,8 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::FillUniformDistribution{0.0f, 1.0f}(topk_weight_host); // permute weight - ck_tile::HostTensor g_perm_host = shuffle_moe_weight(g_host, prec_w, 1); - ck_tile::HostTensor d_perm_host = shuffle_moe_weight(d_host, prec_w, 1); + // ck_tile::HostTensor g_perm_host = shuffle_moe_weight(g_host, prec_w, 1); + // ck_tile::HostTensor d_perm_host = shuffle_moe_weight(d_host, prec_w, 1); // do moe sorting if(balance) @@ -287,7 +287,7 @@ bool run(const ck_tile::ArgParser& arg_parser) // std::cout << num_sorted_tiles_host << std::endl; // output_matrix_3d(g_host, experts, shared_intermediate_size_0, hidden_size); std::cout << sorted_expert_ids_host << std::endl; - // std::cout << topk_weight_host << std::endl; + std::cout << topk_weight_host << std::endl; // std::cout << sorted_weight_host << std::endl; @@ -431,7 +431,7 @@ int main(int argc, char* argv[]) // no dynamic quant case if(prec_i == "bf16" && prec_w == "bf16" && prec_o == "bf16" && prec_kw == "fp32") { - return run( + return run( arg_parser) ? 0 : -2; diff --git a/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp b/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp index 45e73ff96e..61167209eb 100644 --- a/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp +++ b/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp @@ -88,6 +88,32 @@ struct FusedMoeGemmPipeline_General const auto a_coord = a_dist.calculate_index(); return a_coord; } + + template + CK_TILE_HOST_DEVICE static void PrintMem(T& tensor) + { + constexpr auto spans = T::get_distributed_spans(); + int counter = 0; + sweep_tile_span(spans[number<0>{}], [&](auto idxn) { + sweep_tile_span(spans[number<1>{}], [&](auto idxk) { + constexpr auto i_j_idx = make_tuple(idxn, idxk); + const auto tile_idx = + get_x_indices_from_distributed_indices(tensor.get_tile_distribution(), i_j_idx); + if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) + { + const auto row = tile_idx.at(number<0>{}); + const auto col = tile_idx.at(number<1>{}); + printf("in G row is %d , col is %d, counter is %d, value is: %f" + " \n", + row, + col, + counter, + ck_tile::type_convert(tensor(i_j_idx))); + counter = counter + 1; + } + }); + }); + } template CK_TILE_DEVICE auto operator()(const AWindow& a_window_, const GWindow& g_window_, @@ -131,56 +157,13 @@ struct FusedMoeGemmPipeline_General auto a_dram_block = load_tile(a_global_to_dram_window); store_tile(a_lds_win, a_dram_block); #if 0 - { - // check a matrix gather right or not - constexpr auto a_spans = decltype(a_dram_block)::get_distributed_spans(); - int counter = 0; - sweep_tile_span(a_spans[number<0>{}], [&](auto idxm) { - sweep_tile_span(a_spans[number<1>{}], [&](auto idxk) { - constexpr auto i_j_idx = make_tuple(idxm, idxk); - if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) - { - counter = counter + 1; - index_t idm_0 = idxm.impl_.at(0); - index_t idk_0 = idxk.impl_.at(0); - printf("in A idm is %d , idk_ is %d , counter is %d, value is: %f \n", - idm_0, - idk_0, - counter, - ck_tile::type_convert(a_dram_block(i_j_idx))); - } - }); - }); - } + PrintMem(a_dram_block); #endif auto g_dram_block = load_tile(g_global_to_dram_window); #if 0 - { - constexpr auto g_spans = decltype(g_dram_block)::get_distributed_spans(); - int counter = 0; - sweep_tile_span(g_spans[number<0>{}], [&](auto idxn) { - sweep_tile_span(g_spans[number<1>{}], [&](auto idxk) { - constexpr auto i_j_idx = make_tuple(idxn, idxk); - const auto tile_idx = get_x_indices_from_distributed_indices( - g_dram_block.get_tile_distribution(), i_j_idx); - if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) - { - counter = counter + 1; - const auto row = tile_idx.at(number<0>{}); - const auto col = tile_idx.at(number<1>{}); - printf("in G row is %d , col is %d, counter is %d, value is: %f" - " \n", - row, - col, - counter, - ck_tile::type_convert(g_dram_block(i_j_idx))); - } - }); - }); - } - + PrintMem(g_dram_block); #endif clear_tile(s_acc); // initialize C @@ -215,32 +198,8 @@ struct FusedMoeGemmPipeline_General // activation(s_acc.get_thread_buffer()(i),s_acc.get_thread_buffer()[i]); // }); tile_elementwise_inout(activation, s_acc, s_acc); -#if 1 - { - constexpr auto a_spans = decltype(s_acc)::get_distributed_spans(); - int counter = 0; - // a_spans[0] = 1; - sweep_tile_span(a_spans[number<0>{}], [&](auto idxm) { - sweep_tile_span(a_spans[number<1>{}], [&](auto idxn) { - constexpr auto i_j_idx = make_tuple(idxm, idxn); - const auto tile_idx = get_x_indices_from_distributed_indices( - g_dram_block.get_tile_distribution(), i_j_idx); - if(threadIdx.x == 1 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) - { - counter = counter + 1; - const auto row = tile_idx.at(number<0>{}); - const auto col = tile_idx.at(number<1>{}); - printf("in c row is %d , col is %d, counter is %d, value is: " - "%f \n", - row, - col, - counter, - ck_tile::type_convert(s_acc(i_j_idx))); - } - }); - }); - } - +#if 0 + PrintMem(s_acc); #endif // move sacc to LDS auto bridge_lds_view = make_tensor_view( @@ -249,15 +208,30 @@ struct FusedMoeGemmPipeline_General make_tile_window(bridge_lds_view, Policy::template MakeBridgeLdsBlockDesc().get_lengths(), {0, 0}); - + // cast data to YDataType auto y_pre = cast_tile(s_acc); + // constexpr index_t thread_buffer_size = SaccBlockTileType::get_thread_buffer_size(); + // static_for<0, thread_buffer_size, 1>{}([&](auto i) { + // //y_pre.get_thread_buffer()(i) = type_convert(s_acc.get_thread_buffer()[i]); + // if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) + // { + // printf("soure value: %f to value: %f\n", + // s_acc.get_thread_buffer()[i], + // type_convert(y_pre.get_thread_buffer()[i])); + // } + // }); + +#if 1 + PrintMem(y_pre); +#endif + // save to lds store_tile(bridge_slds_win, y_pre); block_sync_lds(); // gemm down constexpr auto gemm_1 = Policy::template GetBlockGemm1(); - using SaccBlockTileType = decltype(gemm_1.MakeCBlockTile()); - auto o_acc = SaccBlockTileType{}; + using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile()); + auto o_acc = OaccBlockTileType{}; // y data auto bridge_llds_win = make_tile_window(bridge_lds_view, @@ -265,6 +239,7 @@ struct FusedMoeGemmPipeline_General {0, 0}, Policy::template MakeYTileDistribution()); auto y = load_tile(bridge_llds_win); + // d data auto d_global_to_dram_window = make_tile_window( d_window_.get_bottom_tensor_view(), @@ -278,6 +253,7 @@ struct FusedMoeGemmPipeline_General index_t iCounter1 = n1_loops - 1; while(iCounter1 > 0) { + clear_tile(o_acc); block_sync_lds(); gemm_1(o_acc, y, d); block_sync_lds(); @@ -292,9 +268,16 @@ struct FusedMoeGemmPipeline_General } // tail { + clear_tile(o_acc); block_sync_lds(); gemm_1(o_acc, y, d); + + auto o = cast_tile(o_acc); + store_tile(o_window_, o); } +#if 0 + PrintMem(o_acc); +#endif // store_tile(o_window_, a_dram_block); } }; diff --git a/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general_policy.hpp b/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general_policy.hpp index a9ddaa00fc..e909552deb 100644 --- a/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general_policy.hpp +++ b/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general_policy.hpp @@ -13,7 +13,7 @@ #include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_custom_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp" -#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp" namespace ck_tile { @@ -230,7 +230,28 @@ struct FusedMoeGemmPipelineGeneralPolicy typename S_::WarpPerBlock_1, decltype(warp_gemm)>; - return BlockGemmARegBRegCRegV1{}; + return BlockGemmARegBRegCRegV2{}; + } + + // this is used as A matrix for 2nd gemm + template + CK_TILE_HOST_DEVICE static constexpr auto MakeYTileDistribution() + { + using S_ = remove_cvref_t; + using WarpGemm = remove_cvref_t())>; + + constexpr auto y_outer_dstr_enc = tile_distribution_encoding< + sequence<1>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 1>>{}; + + constexpr auto y_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + y_outer_dstr_enc, typename WarpGemm::AWarpDstrEncoding{}); + constexpr auto y_block_dstr = make_static_tile_distribution(y_block_dstr_encode); + return y_block_dstr; } template @@ -240,12 +261,12 @@ struct FusedMoeGemmPipelineGeneralPolicy using WarpGemm = remove_cvref_t())>; constexpr auto d_outer_dstr_enc = tile_distribution_encoding< - sequence, - tuple, sequence>, - tuple>, - tuple>, + sequence<1>, + tuple, sequence>, + tuple>, + tuple>, sequence<1, 2>, - sequence<0, 0>>{}; + sequence<0, 1>>{}; constexpr auto d_block_dstr_encode = detail::make_embed_tile_distribution_encoding( d_outer_dstr_enc, typename WarpGemm::BWarpDstrEncoding{}); @@ -326,7 +347,15 @@ struct FusedMoeGemmPipelineGeneralPolicy // TODO: this is ugly constexpr auto wg_ctrl = WGAttrCtlEnum::Raw_avv; // TODO: ugly - if constexpr(std::is_same_v && + if constexpr(std::is_same_v && + std::is_same_v && + S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 8) + { + return WarpGemmImpl, + 1>>{}; + } + else if constexpr(std::is_same_v && std::is_same_v && S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 8) { @@ -358,7 +387,15 @@ struct FusedMoeGemmPipelineGeneralPolicy using S_ = typename Problem::BlockShape; constexpr auto wg_ctrl = WGAttrCtlEnum::Raw_avv; // TODO: ugly - if constexpr(std::is_same_v && + if constexpr(std::is_same_v && + std::is_same_v && + S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 8) + { + return WarpGemmImpl, + 1>>{}; + } + else if constexpr(std::is_same_v && std::is_same_v && S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 8) { @@ -383,27 +420,5 @@ struct FusedMoeGemmPipelineGeneralPolicy 2>>{}; } } - - // this is used as A matrix for 2nd gemm - template - CK_TILE_HOST_DEVICE static constexpr auto MakeYTileDistribution() - { - using S_ = remove_cvref_t; - using WarpGemm = remove_cvref_t())>; - - // TODO: all waves a along different N, but same M - constexpr auto y_outer_dstr_enc = tile_distribution_encoding< - sequence, - tuple, sequence>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 0>>{}; - - constexpr auto y_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - y_outer_dstr_enc, typename WarpGemm::AWarpDstrEncoding{}); - constexpr auto y_block_dstr = make_static_tile_distribution(y_block_dstr_encode); - return y_block_dstr; - } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp new file mode 100644 index 0000000000..da6441650f --- /dev/null +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp @@ -0,0 +1,202 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_default_policy.hpp" + +namespace ck_tile { + +// A is block distributed tensor +// B is block distributed tensor +// C is block distributed tensor +template +struct BlockGemmARegBRegCRegV2 +{ + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + // C += A * B + template + CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, + const ABlockTensor& a_block_tensor, + const BBlockTensor& b_block_tensor) const + { + static_assert(std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + constexpr index_t MPerBlock = BlockGemmShape::kM; + constexpr index_t NPerBlock = BlockGemmShape::kN; + constexpr index_t KPerBlock = BlockGemmShape::kK; + + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); + constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + + // M->N Warp + // constexpr auto a_block_outer_dstr_encoding = + // tile_distribution_encoding, + // tuple, sequence>, + // tuple>, + // tuple>, + // sequence<1, 2>, + // sequence<0, 0>>{}; + + // constexpr auto b_block_outer_dstr_encoding = + // tile_distribution_encoding, + // tuple, sequence>, + // tuple>, + // tuple>, + // sequence<1, 2>, + // sequence<0, 0>>{}; + + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + // constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + // a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{}); + + // constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + // b_block_outer_dstr_encoding, typename WG::BWarpDstrEncoding{}); + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); + + // check ABC-block-distribution + // static_assert( + // std::is_same_v, + // remove_cvref_t>, + // "A distribution is wrong!"); + // static_assert( + // std::is_same_v, + // remove_cvref_t>, + // "B distribution is wrong!"); + static_assert( + std::is_same_v, + remove_cvref_t>, + "C distribution is wrong!"); + + using AWarpDstr = typename WG::AWarpDstr; + using BWarpDstr = typename WG::BWarpDstr; + using CWarpDstr = typename WG::CWarpDstr; + + using AWarpTensor = typename WG::AWarpTensor; + using BWarpTensor = typename WG::BWarpTensor; + using CWarpTensor = typename WG::CWarpTensor; + + constexpr auto a_warp_y_lengths = + to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto b_warp_y_lengths = + to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto c_warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + + constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t{}; + constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t{}; + constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + + // hot loop: + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + // read A warp tensor from A Block window + AWarpTensor a_warp_tensor; + + a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read B warp tensor from B block tensor + BWarpTensor b_warp_tensor; + + b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, b_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); + + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM + WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + }); + }); + }); + } + + CK_TILE_DEVICE static constexpr auto MakeCBlockTile() + { + constexpr index_t MPerBlock = BlockGemmShape::kM; + constexpr index_t NPerBlock = BlockGemmShape::kN; + + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); + // constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); + constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); + auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); + return c_block_tensor; + } + + // C = A * B + template + CK_TILE_DEVICE auto operator()(const ABlockTensor& a_block_tensor, + const BBlockTensor& b_block_tensor) const + { + auto c_block_tensor = MakeCBlockTile(); + operator()(c_block_tensor, a_block_tensor, b_block_tensor); + return c_block_tensor; + } +}; + +} // namespace ck_tile