From 70fa98adf8801847cccc0be5832c357dbdfbe394 Mon Sep 17 00:00:00 2001 From: carlushuang Date: Tue, 5 Nov 2024 16:06:52 +0800 Subject: [PATCH] update code --- example/ck_tile/15_fused_moe/CMakeLists.txt | 15 + .../ck_tile/15_fused_moe/fused_moegemm.hpp | 52 +-- .../instances/fused_moegemm_api.cpp | 35 ++ .../instances/fused_moegemm_api_internal.hpp | 46 +++ .../instances/fused_moegemm_api_traits.hpp | 50 +++ example/ck_tile/15_fused_moe/main.cpp | 314 ++++++++++++------ include/ck_tile/host.hpp | 1 + include/ck_tile/host/device_memory.hpp | 30 ++ .../host/reference/reference_moe_sorting.hpp | 78 +++++ .../host/reference/reference_permute.hpp | 7 +- include/ck_tile/ops/fused_moe.hpp | 14 +- .../fused_moe/kernel/fused_moegemm_kernel.hpp | 61 ++-- .../fused_moegemm_pipeline_flatmm.hpp | 19 +- .../fused_moegemm_pipeline_flatmm_policy.hpp | 9 +- .../fused_moegemm_pipeline_problem.hpp | 4 +- .../pipeline/fused_moegemm_traits.hpp | 18 +- 16 files changed, 564 insertions(+), 189 deletions(-) create mode 100644 example/ck_tile/15_fused_moe/CMakeLists.txt create mode 100644 example/ck_tile/15_fused_moe/instances/fused_moegemm_api.cpp create mode 100644 example/ck_tile/15_fused_moe/instances/fused_moegemm_api_internal.hpp create mode 100644 example/ck_tile/15_fused_moe/instances/fused_moegemm_api_traits.hpp create mode 100644 include/ck_tile/host/reference/reference_moe_sorting.hpp diff --git a/example/ck_tile/15_fused_moe/CMakeLists.txt b/example/ck_tile/15_fused_moe/CMakeLists.txt new file mode 100644 index 0000000000..24ba06bc7f --- /dev/null +++ b/example/ck_tile/15_fused_moe/CMakeLists.txt @@ -0,0 +1,15 @@ +set(TILE_EXAPMLE_FUSED_MOE "tile_example_fused_moe") +# not using add_example_executable() to add this target, since we don't want this to have +# to be included in "make all/install/check" +message("adding ${TILE_EXAPMLE_FUSED_MOE}") +file(GLOB INSTANCE_SRCS instances/*.cpp) +add_executable(${TILE_EXAPMLE_FUSED_MOE} EXCLUDE_FROM_ALL main.cpp) +target_include_directories(${TILE_EXAPMLE_FUSED_MOE} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +target_sources(${TILE_EXAPMLE_FUSED_MOE} PRIVATE ${INSTANCE_SRCS}) + +set(TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS) + +# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations +list(APPEND TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) + +target_compile_options(${TILE_EXAPMLE_FUSED_MOE} PRIVATE ${TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS}) diff --git a/example/ck_tile/15_fused_moe/fused_moegemm.hpp b/example/ck_tile/15_fused_moe/fused_moegemm.hpp index c68fdb9c0e..154dfdf1c4 100644 --- a/example/ck_tile/15_fused_moe/fused_moegemm.hpp +++ b/example/ck_tile/15_fused_moe/fused_moegemm.hpp @@ -16,33 +16,33 @@ struct FusedMoeGemmTypeConfig; template struct FusedMoeGemmTypeConfig; { - using ADataType = ck_tile::bf16_t; - using GDataType = ck_tile::bf16_t; - using DDataType = ck_tile::bf16_t; - using AccDataType = float; - using ODataType = ck_tile::bf16_t; - using AScaleDataType = ck_tile::remove_cvref_t; - using W0ScaleDataType = ck_tile::remove_cvref_t; - using W1ScaleDataType = ck_tile::remove_cvref_t; - using YSmoothScaleDataType = ck_tile::remove_cvref_t; + using ADataType = ck_tile::bf16_t; + using GDataType = ck_tile::bf16_t; + using DDataType = ck_tile::bf16_t; + using AccDataType = float; + using ODataType = ck_tile::bf16_t; + using AScaleDataType = ck_tile::remove_cvref_t; + using GScaleDataType = ck_tile::remove_cvref_t; + using DScaleDataType = ck_tile::remove_cvref_t; + using YSmoothScaleDataType = ck_tile::remove_cvref_t; using TopkWeightDataType = ck_tile::remove_cvref_t; - using IndexDataType = ck_tile::index_t; + using IndexDataType = ck_tile::index_t; }; template struct FusedMoeGemmTypeConfig; { - using ADataType = ck_tile::int8_t; - using GDataType = ck_tile::int8_t; - using DDataType = ck_tile::int8_t; - using AccDataType = int32_t; - using ODataType = ck_tile::bf16_t; - using AScaleDataType = ck_tile::remove_cvref_t; - using W0ScaleDataType = ck_tile::remove_cvref_t; - using W1ScaleDataType = ck_tile::remove_cvref_t; - using YSmoothScaleDataType = ck_tile::remove_cvref_t; + using ADataType = ck_tile::int8_t; + using GDataType = ck_tile::int8_t; + using DDataType = ck_tile::int8_t; + using AccDataType = int32_t; + using ODataType = ck_tile::bf16_t; + using AScaleDataType = ck_tile::remove_cvref_t; + using GScaleDataType = ck_tile::remove_cvref_t; + using DScaleDataType = ck_tile::remove_cvref_t; + using YSmoothScaleDataType = ck_tile::remove_cvref_t; using TopkWeightDataType = ck_tile::remove_cvref_t; - using IndexDataType = ck_tile::index_t; + using IndexDataType = ck_tile::index_t; }; // runtime args @@ -53,14 +53,16 @@ struct fused_moegemm_args : public ck_tile::Layernorm2dFwdHostArgs // This is the public API, will be generated by script struct fused_moegemm_traits { - std::string prec_i; // input precision - std::string prec_w; // weight precision - std::string prec_o; // output precision + std::string prec_i; // input precision + std::string prec_w; // weight precision + std::string prec_o; // output precision std::string prec_st; // token scale data type std::string prec_sw; // weight scale data type std::string prec_sq; // smooth quant scale - std::string prec_kw; // topk-weight data type - int fused_quant; // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant + std::string prec_kw; // topk-weight data type + int block_m; + int gate_only; + int fused_quant; // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant }; float fused_moegemm(fused_moegemm_traits, fused_moegemm_args, const ck_tile::stream_config&); diff --git a/example/ck_tile/15_fused_moe/instances/fused_moegemm_api.cpp b/example/ck_tile/15_fused_moe/instances/fused_moegemm_api.cpp new file mode 100644 index 0000000000..bb18180efb --- /dev/null +++ b/example/ck_tile/15_fused_moe/instances/fused_moegemm_api.cpp @@ -0,0 +1,35 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include "fused_moegemm.hpp" + +// Note: this internal API only declare, not define here, otherwise will block `make -j` +template +float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a); + +float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile::stream_config& s) +{ + template + using S = ck_tile::sequence; + float r = -1; + if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" && + t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && block_m == 32 && + gate_only == 1) + { + using t_ = fmoe_, + S<4, 1, 1>, + S<32, 32, 16>, + 1, + 0>; + fused_moegemm_(s, a); + } + return r; +} diff --git a/example/ck_tile/15_fused_moe/instances/fused_moegemm_api_internal.hpp b/example/ck_tile/15_fused_moe/instances/fused_moegemm_api_internal.hpp new file mode 100644 index 0000000000..0ae122ff82 --- /dev/null +++ b/example/ck_tile/15_fused_moe/instances/fused_moegemm_api_internal.hpp @@ -0,0 +1,46 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "fused_moegemm_api_traits.hpp" +#include "ck_tile/ops/fused_moe.hpp" + +template +float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a) +{ + using f_traits = ck_tile::FusedMoeGemmTraits; + using f_shape = ck_tile::FusedMoeGemmShape; + using f_problem = ck_tile::FusedMoeGemmPipelineProblem + + using f_pipeline = ck_tile::FusedMoeGemmPipeline_Flatmm; + using f_partitioner = ck_tile::FusedMoeGemmTilePartitioner_Linear; + using f_kernel = ck_tile::FusedMoeGemmKernel; + + const dim3 grids = f_kernel::GridSize(a); + constexpr dim3 blocks = f_kernel::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = 1; + + auto kargs = f_kernel::MakeKargs(a); + if(s.log_level_ > 0) + std::cout << ", " << f_kernel::GetName() << std::flush; + + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(f_kernel{}, grids, blocks, 0, kargs)); +} diff --git a/example/ck_tile/15_fused_moe/instances/fused_moegemm_api_traits.hpp b/example/ck_tile/15_fused_moe/instances/fused_moegemm_api_traits.hpp new file mode 100644 index 0000000000..d9fca3f26b --- /dev/null +++ b/example/ck_tile/15_fused_moe/instances/fused_moegemm_api_traits.hpp @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include + +// this is used to pattern-match internl kernel implementation, not to instantiate kernel +template + typename WarpPerBlock_, + typename WarpTile_, // seq<*,*,*>, used to select mfma + ck_tile::index_t GateOnly_ = 0, + ck_tile::index_t FusedQuant_ = 0> +struct fmoe_ // traits, ugly name, only used for internal +{ + using TypeConfig = FusedMoeGemmTypeConfig; + + using ADataType = remove_cvref_t; + using GDataType = remove_cvref_t; + using DDataType = remove_cvref_t; + using AccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using AScaleDataType = remove_cvref_t; + using GScaleDataType = remove_cvref_t; + using DScaleDataType = remove_cvref_t; + using YSmoothScaleDataType = remove_cvref_t; + using TopkWeightDataType = remove_cvref_t; + using IndexDataType = remove_cvref_t; + + static constexpr index_t BT_ = BlockTIle_::at(number<0>{}); // block token + static constexpr index_t BI_ = BlockTIle_::at(number<1>{}); // block intermediate + static constexpr index_t BH_ = BlockTIle_::at(number<2>{}); // block hidden + static constexpr index_t BD_ = BlockTIle_::at(number<3>{}); // block down + + using BlockTile_0 = ck_tile::sequence; + using WarpPerBlock_0 = remove_cvref_t; + using WarpTile_0 = remove_cvref_t; + + using BlockTile_1 = ck_tile::sequence; + using WarpPerBlock_1 = remove_cvref_t; + using WarpTile_1 = remove_cvref_t; + + static constexpr ck_tile::index_t GateOnly = GateOnly_; + static constexpr ck_tile::index_t FusedQuant = FusedQuant_; +}; diff --git a/example/ck_tile/15_fused_moe/main.cpp b/example/ck_tile/15_fused_moe/main.cpp index b91f402f42..a446ea3be1 100644 --- a/example/ck_tile/15_fused_moe/main.cpp +++ b/example/ck_tile/15_fused_moe/main.cpp @@ -1,7 +1,10 @@ #include "ck_tile/host.hpp" -#include "layernorm2d_fwd.hpp" +#include "fused_moegemm.hpp" #include #include +#include +#include +#include // different threshold for different dtype template @@ -20,18 +23,64 @@ auto get_elimit() return ck_tile::make_tuple(rtol, atol); } - // mfma_type, 0:32x32, 1:16x16 -template -auto shuffle_moe_weight(const H& t, std::string mfma_dtype, int mfma_type = 0) +// TODO: padding? +template +auto shuffle_moe_weight(const ck_tile::HostTensor& t, std::string mfma_dtype, int mfma_type = 0) { static_assert(t.get_lengths().size() == 3); int b_ = t.get_lengths()[0]; int n_ = t.get_lengths()[1]; int k_ = t.get_lengths()[2]; - if ((mfma_dtype == "bf16" || mfma_dtype == "fp16") && mfma_type == 0) { - std::vector new_lens {b_, n_/32, 32, k_/16, 2, 8}; + if((mfma_dtype == "bf16" || mfma_dtype == "fp16") && mfma_type == 0) + { + ck_tile::HostTensor t_view({b_, n_ / 32, 32, k_ / 16, 2, 8}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 1, 3, 4, 2, 5}); + } + else if((mfma_dtype == "bf16" || mfma_dtype == "fp16") && mfma_type == 1) + { + ck_tile::HostTensor t_view({b_, n_ / 16, 16, k_ / 32, 4, 8}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 1, 3, 4, 2, 5}); + } + else if((mfma_dtype == "int8" || mfma_dtype == "fp8") && mfma_type == 0) + { + ck_tile::HostTensor t_view({b_, n_ / 32, 32, k_ / 32, 2, 16}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 1, 3, 4, 2, 5}); + } + else if((mfma_dtype == "int8" || mfma_dtype == "fp8") && mfma_type == 1) + { + ck_tile::HostTensor t_view({b_, n_ / 16, 16, k_ / 64, 4, 16}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 1, 3, 4, 2, 5}); + } + return t; } + +template +void topid_unique_gen( + std::vector& host_tensor, int tokens, int topk, int num_expert, int seed) +{ + size_t total_size = topk * tokens; + std::srand(seed); + std::set unique_set; + IndexType current_v; + for(size_t i = 0; i < total_size; i++) + { + if(i % topk == 0) + { + unique_set.clear(); + } + current_v = std::rand() % num_expert; + while(unique_set.find(current_v) != unique_set.end()) + { + current_v = std::rand() % num_expert; + } + unique_set.insert(current_v); + host_tensor[i] = current_v; + } } auto create_args(int argc, char* argv[]) @@ -55,8 +104,11 @@ auto create_args(int argc, char* argv[]) .insert("prec_sq", "auto", "(dynamic) smooth quant data type. auto will set to fp32") .insert("prec_kw", "auto", "topk-weight data type. auto will set to fp32") .insert("fquant", "0", "fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant") - .insert("gonly", "0", "w0(gate/up) style, 0:gate+up will double interm size, 1:only gate") - .insert("balance", "1", "if set to 1, will try balance the expert in topk-ids(convenient for testing)") + .insert( + "gate_only", "0", "w0(gate/up) style, 0:gate+up will double interm size, 1:only gate") + .insert("balance", + "1", + "if set to 1, will try balance the expert in topk-ids(convenient for testing)") .insert("warmup", "5", "cold iter") .insert("repeat", "20", "hot iter"); @@ -64,133 +116,178 @@ auto create_args(int argc, char* argv[]) return std::make_tuple(result, arg_parser); } -// I:input-type, W:weight-type, O:output-type, ST:toke-scale-tpye, SW:weight-scale-type, SQ:smooth-quant-type, KW:topk-weight-type +// I:input-type, W:weight-type, O:output-type, ST:toke-scale-tpye, SW:weight-scale-type, +// SQ:smooth-quant-type, KW:topk-weight-type template bool run(const ck_tile::ArgParser& arg_parser) { - ck_tile::index_t tokens = arg_parser.get_int("t"); - ck_tile::index_t experts = arg_parser.get_int("e"); - ck_tile::index_t topk = arg_parser.get_int("k"); - ck_tile::index_t hidden_size = arg_parser.get_int("h"); - ck_tile::index_t intermediate_size = arg_parser.get_int("i"); - ck_tile::index_t stride = arg_parser.get_int("stride"); - ck_tile::index_t block_m = arg_parser.get_int("bm"); + ck_tile::index_t tokens = arg_parser.get_int("t"); + ck_tile::index_t experts = arg_parser.get_int("e"); + ck_tile::index_t topk = arg_parser.get_int("k"); + ck_tile::index_t hidden_size = arg_parser.get_int("h"); + ck_tile::index_t intermediate_size = arg_parser.get_int("i"); + ck_tile::index_t stride = arg_parser.get_int("stride"); + ck_tile::index_t block_m = arg_parser.get_int("bm"); if(stride < 0) stride = hidden_size; std::string prec_i = arg_parser.get_str("prec_i"); - std::string prec_w = arg_parser.get_str("prec_w"); + std::string prec_w = arg_parser.get_str("prec_w"); std::string prec_o = arg_parser.get_str("prec_o"); - std::string prec_st = arg_parser.get_str("prec_st"); - std::string prec_sw = arg_parser.get_str("prec_sw"); - std::string prec_sq = arg_parser.get_str("prec_sq"); + std::string prec_st = arg_parser.get_str("prec_st"); + std::string prec_sw = arg_parser.get_str("prec_sw"); + std::string prec_sq = arg_parser.get_str("prec_sq"); std::string prec_kw = arg_parser.get_str("prec_kw"); - prec_st = (prec_st == "auto") ? "fp32" : prec_st; - prec_sw = (prec_sw == "auto") ? "fp32" : prec_sw; - prec_sq = (prec_sq == "auto") ? "fp32" : prec_sq; - prec_kw = (prec_kw == "auto") ? "fp32" : prec_kw; - int kname = arg_parser.get_int("kname"); - int do_validation = arg_parser.get_int("v"); - int warmup = arg_parser.get_int("warmup"); - int repeat = arg_parser.get_int("repeat"); - int fused_quant = arg_parser.get_int("fquant"); - int gonly = arg_parser.get_int("gonly"); - int balance = arg_parser.get_int("balance"); - int tp = arg_parser.get_int("tp"); - ck_tile::index_t shared_intermediate_size = intermediate_size * (gonly ? 1 : 2) / tp; + prec_st = (prec_st == "auto") ? "fp32" : prec_st; + prec_sw = (prec_sw == "auto") ? "fp32" : prec_sw; + prec_sq = (prec_sq == "auto") ? "fp32" : prec_sq; + prec_kw = (prec_kw == "auto") ? "fp32" : prec_kw; + int kname = arg_parser.get_int("kname"); + int do_validation = arg_parser.get_int("v"); + int warmup = arg_parser.get_int("warmup"); + int repeat = arg_parser.get_int("repeat"); + int fused_quant = arg_parser.get_int("fquant"); + int gate_only = arg_parser.get_int("gate_only"); + int balance = arg_parser.get_int("balance"); + int tp = arg_parser.get_int("tp"); + ck_tile::index_t shared_intermediate_size = intermediate_size * (gate_only ? 1 : 2) / tp; - using TypeConfig = FusedMoeGemmTypeConfig; - using ADataType = typename TypeConfig::ADataType ; - using GDataType = typename TypeConfig::GDataType ; - using DDataType = typename TypeConfig::DDataType ; - using AccDataType = typename TypeConfig::AccDataType ; - using ODataType = typename TypeConfig::ODataType ; - using AScaleDataType = typename TypeConfig::AScaleDataType ; - using W0ScaleDataType = typename TypeConfig::W0ScaleDataType ; - using W1ScaleDataType = typename TypeConfig::W1ScaleDataType ; - using YSmoothScaleDataType = typename TypeConfig::YSmoothScaleDataType; - using TopkWeightDataType = typename TypeConfig::TopkWeightDataType ; - using IndexDataType = typename TypeConfig::IndexDataType ; + using TypeConfig = FusedMoeGemmTypeConfig; + using ADataType = typename TypeConfig::ADataType; + using GDataType = typename TypeConfig::GDataType; + using DDataType = typename TypeConfig::DDataType; + using AccDataType = typename TypeConfig::AccDataType; + using ODataType = typename TypeConfig::ODataType; + using AScaleDataType = typename TypeConfig::AScaleDataType; + using GScaleDataType = typename TypeConfig::GScaleDataType; + using DScaleDataType = typename TypeConfig::DScaleDataType; + using YSmoothScaleDataType = typename TypeConfig::YSmoothScaleDataType; + using TopkWeightDataType = typename TypeConfig::TopkWeightDataType; + using IndexDataType = typename TypeConfig::IndexDataType; // host verify ck_tile::HostTensor a_host({tokens, hidden_size}, {stride, 1}); - ck_tile::HostTensor g_host({e, shared_intermediate_size, hidden_size}); - ck_tile::HostTensor d_host({e, intermediate_size, hidden_size}); + ck_tile::HostTensor g_host({e, shared_intermediate_size, hidden_size}); + ck_tile::HostTensor d_host({e, intermediate_size, hidden_size}); + ck_tile::HostTensor o_host({tokens, hidden_size}, {stride, 1}); + ck_tile::HostTensor sa_host({tokens}); + ck_tile::HostTensor sg_host({shared_intermediate_size}); + ck_tile::HostTensor sd_host({intermediate_size}); + ck_tile::HostTensor sy_host({intermediate_size}); // smooth-quant + ck_tile::HostTensor topk_ids_host({tokens, topk}); // to be sort + ck_tile::HostTensor topk_weight_host({tokens, topk}); // to be sort + int max_num_tokens_padded = topk * tokens + experts * (block_m - 1); + ck_tile::HostTensor sorted_token_ids_host({max_num_tokens_padded}); + ck_tile::HostTensor sorted_weight_host({max_num_tokens_padded}); + ck_tile::HostTensor sorted_expert_ids_host( + {(max_num_tokens_padded + block_m - 1) / block_m}); + ck_tile::HostTensor num_sorted_tiles_host({1}); - ck_tile::HostTensor x_residual_host({m, n}, {stride, 1}); - ck_tile::HostTensor y_residual_host({m, n}, {stride, 1}); - - ck_tile::HostTensor y_host_ref({m, n}, {stride, 1}); - ck_tile::HostTensor y_host_dev({m, n}, {stride, 1}); - - ck_tile::HostTensor mean_host_ref({m}); - ck_tile::HostTensor invStd_host_ref({m}); - ck_tile::HostTensor y_scale_host_ref({m}); - ck_tile::HostTensor y_scale_host_dev({m}); - - ck_tile::HostTensor x_scale_host({n}); - ck_tile::HostTensor x_scale_host_dev({n}); + // permute weight + ck_tile::HostTensor g_perm_host = shuffle_moe_weight(g_host, prec_w); + ck_tile::HostTensor d_perm_host = shuffle_moe_weight(d_host, prec_w); ck_tile::FillUniformDistribution{-.5f, .5f}(a_host); - ck_tile::FillUniformDistribution{-.5f, .5f}(x_residual_host); - ck_tile::FillUniformDistribution{-1.f, 1.f}(x_scale_host); - ck_tile::FillUniformDistribution{-.5f, .5f}(gamma_host); - ck_tile::FillUniformDistribution{-.5f, .5f}(beta_host); + ck_tile::FillUniformDistribution{-.5f, .5f}(g_perm_host); + ck_tile::FillUniformDistribution{-.5f, .5f}(d_perm_host); + ck_tile::FillUniformDistribution{-.5f, .5f}(sa_host); + ck_tile::FillUniformDistribution{-.5f, .5f}(sg_host); + ck_tile::FillUniformDistribution{-.5f, .5f}(sd_host); + ck_tile::FillUniformDistribution{-.5f, .5f}(sy_host); + ck_tile::FillUniformDistribution{-.5f, .5f}(topk_weight_host); - ck_tile::DeviceMem x_buf(a_host.get_element_space_size_in_bytes()); - ck_tile::DeviceMem gamma_buf(gamma_host.get_element_space_size_in_bytes()); - ck_tile::DeviceMem beta_buf(beta_host.get_element_space_size_in_bytes()); - ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes()); - ck_tile::DeviceMem y_scale_buf(y_scale_host_dev.get_element_space_size_in_bytes()); - ck_tile::DeviceMem x_scale_buf(x_scale_host_dev.get_element_space_size_in_bytes()); + // do moe sorting + if(balance) + { + int e_cnt = 0 for(int i = 0; i < static_cast(topk_ids_host.mData.size()); i++) + { + topk_ids_host.mData[i] = e_cnt; + e_cnt++; + if(e_cnt >= experts) + e_cnt = 0; + } + } + else + { + topid_unique_gen(topk_ids_host.mData, tokens, topk, experts, 11913); + } - ck_tile::DeviceMem x_residual_buf(x_residual_host.get_element_space_size_in_bytes()); - ck_tile::DeviceMem y_residual_buf(y_residual_host.get_element_space_size_in_bytes()); + ck_tile::reference_moe_sorting( + topk_ids_host, + topk_weight_host, + sorted_token_ids_host, + sorted_weight_host, + sorted_expert_ids_host, + num_sorted_tiles_host.mData[0], + experts, + block_m); + // done, preparing GPU buffer + ck_tile::DeviceMem a_buf(a_host); + ck_tile::DeviceMem g_perm_buf(g_perm_host); + ck_tile::DeviceMem d_perm_buf(d_perm_host); + ck_tile::DeviceMem sa_buf(sa_host); + ck_tile::DeviceMem sg_buf(sg_host); + ck_tile::DeviceMem sd_buf(sd_host); + ck_tile::DeviceMem sy_buf(sy_host); + ck_tile::DeviceMem o_buf(o_host); - x_buf.ToDevice(a_host.data()); - gamma_buf.ToDevice(gamma_host.data()); - beta_buf.ToDevice(beta_host.data()); - x_residual_buf.ToDevice(x_residual_host.data()); - x_scale_buf.ToDevice(x_scale_host.data()); + ck_tile::DeviceMem sorted_token_ids_buf(sorted_token_ids_host); + ck_tile::DeviceMem sorted_weight_buf(sorted_weight_host); + ck_tile::DeviceMem sorted_expert_ids_buf(sorted_expert_ids_host); + ck_tile::DeviceMem num_sorted_tiles_buf(num_sorted_tiles_host); auto prec_str = [&]() { auto base_str = prec_i; + if(prec_i != prec_w) + base_str += "x" + prec_w; if(prec_i != prec_o) + base_str += "=" + prec_o; + if(fused_quant != 0) { - base_str += "|" + prec_o; - } - if(fused_quant == 1) - { - base_str += std::string("(") + prec_sy + ")"; + base_str += std::string("(") + prec_sa + "|" + prec_sg + "|" + prec_sq + ")"; } return base_str; }(); std::cout << "[" << prec_str << "]" - << " m:" << m << ", n:" << n << ", stride:" << stride << std::flush; + << " t:" << tokens << ", e:" << experts << ", k:" << topk << ", st:" << stride + << ", hidden:" << hidden_size << ", interm:" << intermediate_size << ", tp:" << tp + << ", go:" << gate_only << ", q:" << fused_quant << std::flush; - layernorm2d_fwd_traits traits{ - prec_i, prec_o, prec_sx, prec_sy, SaveMeanVar, fused_add, fused_quant}; + fused_moegemm_traits traits{prec_i, + prec_w, + prec_o, + prec_st, + prec_sw, + prec_sq, + prec_kw, + block_m, + gate_only, + fused_quant}; - layernorm2d_fwd_args args{x_buf.GetDeviceBuffer(), - fused_add != 0 ? x_residual_buf.GetDeviceBuffer() : nullptr, - fused_quant == 1 ? x_scale_buf.GetDeviceBuffer() : nullptr, - gamma_buf.GetDeviceBuffer(), - beta_buf.GetDeviceBuffer(), + fused_moegemm_args args{a_buf.GetDeviceBuffer(), + fused_quant != 0 ? sa_buf.GetDeviceBuffer() : nullptr, + g_buf.GetDeviceBuffer(), + d_buf.GetDeviceBuffer(), + fused_quant != 0 + ? sg_buf.GetDeviceBuffer(), + fused_quant != 0 + ? sd_buf.GetDeviceBuffer(), + fused_quant == 1 + ? sy_buf.GetDeviceBuffer(), + o_buf.GetDeviceBuffer(), + sorted_token_ids_buf.GetDeviceBuffer(), + sorted_weight_buf.GetDeviceBuffer(), + sorted_expert_ids_buf.GetDeviceBuffer(), + num_sorted_tiles_buf.GetDeviceBuffer(), + hidden_size, + intermediate_size, + num_tokens, + experts, + stride }; - y_buf.GetDeviceBuffer(), - fused_add == 1 ? y_residual_buf.GetDeviceBuffer() : nullptr, - fused_quant != 0 ? y_scale_buf.GetDeviceBuffer() : nullptr, - nullptr, // p_mean, unsupported yet - nullptr, // p_invStd, unsupported yet - - epsilon, - m, - n, - stride}; - - float ave_time = layernorm2d_fwd( + float ave_time = fused_moegemm( traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat}); if(ave_time < 0) @@ -199,22 +296,30 @@ bool run(const ck_tile::ArgParser& arg_parser) return false; } +#if 0 std::size_t num_byte = sizeof(ADataType) * m * n + sizeof(GammaDataType) * n + sizeof(BetaDataType) * n + sizeof(YDataType) * m * n; float gb_per_sec = num_byte / 1.E6 / ave_time; std::cout << ", " << ave_time * 1.E3 << " us, " << gb_per_sec << " GB/s" << std::flush; +#else + std::size_t flop_gemm_0 = 2 * tokens * topk * shared_intermediate_size * hidden_size; + std::size_t flop_gemm_1 = 2 * tokens * topk * hidden_size * hidden_size; + double tflops = (flop_gemm_0 + flop_gemm_1) / (static_cast(ave_time) * 1e-3) / 1e12; + // float gb_per_sec = num_byte / 1.E6 / ave_time; + std::cout << ", " << ave_time * 1.E3 << " us, " << tflops << " tflops" << std::flush; +#endif bool pass = true; if(do_validation) { +#if 0 // reference if(fused_add != 0) { // fused pre_add/pre_add_store // TODO we accumulate directly to a_host for simplcity here... - std::transform(a_host.mData.cbegin(), a_host.mData.cend(), x_residual_host.mData.cbegin(), @@ -353,6 +458,9 @@ bool run(const ck_tile::ArgParser& arg_parser) } std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl; +#else + std::cout << std::flush << std::endl; +#endif } return pass; diff --git a/include/ck_tile/host.hpp b/include/ck_tile/host.hpp index c0ab13ce3d..2e96009ace 100644 --- a/include/ck_tile/host.hpp +++ b/include/ck_tile/host.hpp @@ -23,6 +23,7 @@ #include "ck_tile/host/reference/reference_gemm.hpp" #include "ck_tile/host/reference/reference_im2col.hpp" #include "ck_tile/host/reference/reference_layernorm2d_fwd.hpp" +#include "ck_tile/host/reference/reference_moe_sorting.hpp" #include "ck_tile/host/reference/reference_permute.hpp" #include "ck_tile/host/reference/reference_reduce.hpp" #include "ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp" diff --git a/include/ck_tile/host/device_memory.hpp b/include/ck_tile/host/device_memory.hpp index 7c8549f74f..7d85ae3c42 100644 --- a/include/ck_tile/host/device_memory.hpp +++ b/include/ck_tile/host/device_memory.hpp @@ -7,6 +7,7 @@ #include #include #include "ck_tile/host/hip_check_error.hpp" +#include "ck_tile/host/host_tensor.hpp" namespace ck_tile { template @@ -36,6 +37,19 @@ struct DeviceMem mpDeviceBuf = nullptr; } } + template + DeviceMem(const HostTensor& t) : mMemSize(t.get_element_space_size_in_bytes()) + { + if(mMemSize != 0) + { + HIP_CHECK_ERROR(hipMalloc(static_cast(&mpDeviceBuf), mMemSize)); + } + else + { + mpDeviceBuf = nullptr; + } + ToDevice(t.data()); + } void Realloc(std::size_t mem_size) { if(mpDeviceBuf) @@ -92,6 +106,22 @@ struct DeviceMem HIP_CHECK_ERROR(hipMemcpy(p, mpDeviceBuf, cpySize, hipMemcpyDeviceToHost)); } } + + // construct a host tensor with type T + template + HostTensor ToHost(std::size_t cpySize = mMemSize) + { + // TODO: host tensor could be slightly larger than the device tensor + // we just copy all data from GPU buffer + std::size_t host_elements = + (cpySize + sizeof(T) - 1) / sizeof(T) HostTensor h_({host_elements}); + if(mpDeviceBuf) + { + HIP_CHECK_ERROR(hipMemcpy(h_.data(), mpDeviceBuf, cpySize, hipMemcpyDeviceToHost)); + } + return h_; + } + void SetZero() const { if(mpDeviceBuf) diff --git a/include/ck_tile/host/reference/reference_moe_sorting.hpp b/include/ck_tile/host/reference/reference_moe_sorting.hpp new file mode 100644 index 0000000000..78e3393994 --- /dev/null +++ b/include/ck_tile/host/reference/reference_moe_sorting.hpp @@ -0,0 +1,78 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/host_tensor.hpp" + +namespace ck_tile { + +template +CK_TILE_HOST void reference_moe_sorting(const HostTensor& topk_ids, + const HostTensor& weights, + HostTensor& sorted_token_ids, + HostTensor& sorted_weight, + HostTensor& sorted_expert_ids, + index_t& unit_cnt, + const index_t experts, + const index_t unit_size) +{ + const index_t num_token = topk_ids.mDesc.get_lengths()[0]; + const index_t topk = topk_ids.mDesc.get_lengths()[1]; + std::vector> expert_tokens(experts, + std::vector(unit_size, num_token)); + std::vector> expert_token_weights( + experts, std::vector(unit_size, 0)); + std::vector expert_slices(experts, 1); + std::vector expert_slice_idxs(experts, 0); + + for(index_t t = 0; t < num_token; t++) + { + for(index_t k = 0; k < topk; k++) + { + IndexType e = topk_ids(t, k); + WeightType w = weights(t, k); + index_t idx = expert_slice_idxs[e]; + if(idx > expert_slices[e] * unit_size - 1) + { + expert_slices[e]++; + index_t new_size = expert_slices[e] * unit_size; + expert_tokens[e].resize(new_size); + expert_token_weights[e].resize(new_size); + for(index_t i = (expert_slices[e] - 1) * unit_size; i < new_size; i++) + { + expert_tokens[e][i] = num_token; + expert_token_weights[e][i] = 0; + } + } + + expert_tokens[e][idx] = t; + expert_token_weights[e][idx] = w; + expert_slice_idxs[e]++; + } + } + + IndexType* out_tokens = sorted_token_ids.data(); + WeightType* out_weights = sorted_weight.data(); + IndexType* out_expert_id = sorted_expert_ids.data(); + for(index_t e = 0; e < experts; e++) + { + memcpy(out_tokens, expert_tokens[e].data(), sizeof(index_t) * expert_slices[e] * unit_size); + out_tokens += expert_slices[e] * unit_size; + memcpy(out_weights, + expert_token_weights[e].data(), + sizeof(WeightType) * expert_slices[e] * unit_size); + out_weights += expert_slices[e] * unit_size; + + for(index_t s = 0; s < expert_slices[e]; s++) + { + out_expert_id[s] = e; + unit_cnt++; + } + out_expert_id += expert_slices[e]; + } + + return; +} +} // namespace ck_tile diff --git a/include/ck_tile/host/reference/reference_permute.hpp b/include/ck_tile/host/reference/reference_permute.hpp index 836d65bd76..4e0f1a877e 100644 --- a/include/ck_tile/host/reference/reference_permute.hpp +++ b/include/ck_tile/host/reference/reference_permute.hpp @@ -56,11 +56,10 @@ reference_permute(const HostTensor& x, HostTensor& y, std::v } template -CK_TILE_HOST auto -reference_permute(const HostTensor& x, std::vector perm) +CK_TILE_HOST auto reference_permute(const HostTensor& x, std::vector perm) { - auto x_shape = x.get_lengths(); - ck_tile::index_t rank = perm.size(); + auto x_shape = x.get_lengths(); + ck_tile::index_t rank = perm.size(); std::vector y_shape = [&]() { std::vector tmp(rank, 0); for(int i = 0; i < static_cast(rank); i++) diff --git a/include/ck_tile/ops/fused_moe.hpp b/include/ck_tile/ops/fused_moe.hpp index 66ac95754c..d896f3ab30 100644 --- a/include/ck_tile/ops/fused_moe.hpp +++ b/include/ck_tile/ops/fused_moe.hpp @@ -3,12 +3,12 @@ #pragma once -#include "ck_tile/ops/fused_moe/kernel/fused_moe_kernel.hpp" -#include "ck_tile/ops/fused_moe/kernel/fused_moe_shape.hpp" -#include "ck_tile/ops/fused_moe/kernel/fused_moe_tile_partitioner.hpp" -#include "ck_tile/ops/fused_moe/pipeline/fused_moe_pipeline_flatmm.hpp" -#include "ck_tile/ops/fused_moe/pipeline/fused_moe_pipeline_flatmm_policy.hpp" -#include "ck_tile/ops/fused_moe/pipeline/fused_moe_pipeline_problem.hpp" -#include "ck_tile/ops/fused_moe/pipeline/fused_moe_traits.hpp" +#include "ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp" +#include "ck_tile/ops/fused_moe/kernel/fused_moegemm_shape.hpp" +#include "ck_tile/ops/fused_moe/kernel/fused_moegemm_tile_partitioner.hpp" +#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm.hpp" +#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp" +#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_problem.hpp" +#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" diff --git a/include/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp b/include/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp index d68be66f85..4c2b5c8246 100644 --- a/include/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp +++ b/include/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp @@ -22,17 +22,17 @@ // (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5 // weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]] // -// max_tokens_post_padded : top_k * input_tokens + num_experts * (M_a - 1) +// max_num_tokens_padded : top_k * input_tokens + num_experts * (M_a - 1) // * this could be larger than actual, since actual tokens are on GPU // // sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5] // |- exp-0 -|- exp-1 -|- exp-2 -|- exp-3 -|- exp-4 -|- exp-5 -| // sorted_weight_ptr : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *, c, f, i, o] // -// * length is max_tokens_post_padded, actual size is num_tokens_post_padded_ptr +// * length is max_num_tokens_padded, actual size is num_tokens_post_padded_ptr // // sorted_expert_ids_ptr : [0, 1, 2, 3, 3, 4, 5] -// * length is (max_tokens_post_padded + block_size - 1) / block_size +// * length is (max_num_tokens_padded + block_size - 1) / block_size // // num_tokens_post_padded_ptr : [28] // num_sorted_tiles_ptr : [7] @@ -43,11 +43,12 @@ // 3) use num_sorted_tiles_ptr, already divided by M_a // // * below used for indexing -// 1) sorted_token_ids_ptr +// 1) sorted_token_ids_ptr [max_num_tokens_padded] // 2) sorted_weight_ptr // 3) sorted_expert_ids_ptr // 4)num_tokens_post_padded_ptr/num_sorted_tiles_ptr (select one) // +// max_num_tokens_padded: opk_ids.numel() + num_experts * (block_size - 1) // // [indexing implementation-2] // before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]] @@ -92,15 +93,15 @@ struct FusedMoeGemmHostArgs const void* y_smooth_scale_ptr; // [e, 1, n], smooth-quant-scale for 2nd gemm input void* o_ptr; // [m, k], output token - const void* sorted_token_ids_ptr; - const void* sorted_weight_ptr; - const void* sorted_expert_ids_ptr; - const void* num_sorted_tiles_ptr; + const void* sorted_token_ids_ptr; // [max_num_tokens_padded] + const void* sorted_weight_ptr; // [max_num_tokens_padded] + const void* sorted_expert_ids_ptr; // [(max_num_tokens_padded + block_size - 1) / block_size] + const void* num_sorted_tiles_ptr; // [1] - index_t hidden_size; // k + index_t hidden_size; // k index_t intermediate_size; // n (TP slice this) - index_t num_tokens; // input number of tokens for current iteration - index_t num_experts; // number of groups + index_t num_tokens; // input number of tokens for current iteration + index_t num_experts; // number of groups // index_t top_k; // need this? index_t stride_token; // for input/output, stride for each row, should >= hidden_size @@ -134,10 +135,10 @@ struct FusedMoeGemmKernel using Traits = typename Pipeline::Problem::Traits; - static constexpr bool IsGateOnly = Traits::IsGateOnly; - static constexpr bool UseSmoothQuant = Traits::UseSmoothQuant; - static constexpr bool PadHiddenSize = Traits::PadHiddenSize; - static constexpr bool PadIntermediateSize = Traits::PadIntermediateSize; + static constexpr bool IsGateOnly = Traits::IsGateOnly; + static constexpr bool UseSmoothQuant = Traits::UseSmoothQuant; + static constexpr bool PadHiddenSize = Traits::PadHiddenSize; + static constexpr bool PadIntermediateSize = Traits::PadIntermediateSize; // clang-format off template struct t2s; @@ -173,10 +174,10 @@ struct FusedMoeGemmKernel const void* sorted_expert_ids_ptr; const void* num_sorted_tiles_ptr; - index_t hidden_size; // k + index_t hidden_size; // k index_t intermediate_size; // n (TP slice this) - index_t num_tokens; // input number of tokens for current iteration - index_t num_experts; // number of groups + index_t num_tokens; // input number of tokens for current iteration + index_t num_experts; // number of groups // index_t top_k; // need this? index_t stride_token; // for input/output, stride for each row, should >= hidden_size @@ -214,7 +215,7 @@ struct FusedMoeGemmKernel index_t nr_0 = kargs.intermediate_size / Pipeline::Block_Nr0; index_t kr_0 = kargs.hidden_size / Pipeline::Block_Kr0; - index_t nr_1 = kargs.hidden_size / Pipeline::Block_Nr1; // should be same as kr_0 + index_t nr_1 = kargs.hidden_size / Pipeline::Block_Nr1; // should be same as kr_0 index_t kr_1 = kargs.intermediate_size / Pipeline::Block_Kr1; // should be same as nr_0 index_t expert_stride_0 = kargs.intermediate_size * hidden_radio_0 * kargs.hidden_size; @@ -280,11 +281,12 @@ struct FusedMoeGemmKernel make_tuple(kr_0 * BlockShape::Block_W0, number{}, 1), number{}, number<1>{}); - const auto g_view_1_ = pad_tensor_view(g_view_, - make_tuple(number{}, - number{}, - number{}), - sequence{}); + const auto g_view_1_ = + pad_tensor_view(g_view_, + make_tuple(number{}, + number{}, + number{}), + sequence{}); const auto g_window_ = make_tile_window(g_view_1_, make_tuple(number{}, @@ -308,11 +310,12 @@ struct FusedMoeGemmKernel make_tuple(kr_1 * Pipeline::Block_W1, Pipeline::Block_W1, 1), number{}, number<1>{}); - const auto d_view_1_ = pad_tensor_view(d_view_, - make_tuple(number{}, - number{}, - number{}), - sequence{}); + const auto d_view_1_ = + pad_tensor_view(d_view_, + make_tuple(number{}, + number{}, + number{}), + sequence{}); const auto d_window_ = make_tile_window(d_view_1_, make_tuple(number{}, diff --git a/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm.hpp b/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm.hpp index f11aee2036..04bd9881c8 100644 --- a/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm.hpp +++ b/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm.hpp @@ -44,10 +44,10 @@ struct FusedMoeGemmPipeline_Flatmm using Traits = typename Pipeline::Problem::Traits; - static constexpr bool IsGateOnly = Traits::IsGateOnly; - static constexpr bool UseSmoothQuant = Traits::UseSmoothQuant; - static constexpr bool PadHiddenSize = Traits::PadHiddenSize; - static constexpr bool PadIntermediateSize = Traits::PadIntermediateSize; + static constexpr bool IsGateOnly = Traits::IsGateOnly; + static constexpr bool UseSmoothQuant = Traits::UseSmoothQuant; + static constexpr bool PadHiddenSize = Traits::PadHiddenSize; + static constexpr bool PadIntermediateSize = Traits::PadIntermediateSize; static constexpr index_t kAlignmentA = Policy::GetAlignment_A(); static constexpr index_t kAlignmentG = Policy::GetAlignment_G(); @@ -133,11 +133,12 @@ struct FusedMoeGemmPipeline_Flatmm make_tuple(kr_0 * BlockShape::Block_W0, number{}, 1), number{}, number<1>{}); - const auto u_view_1_ = pad_tensor_view(u_view_, - make_tuple(number{}, - number{}, - number{}), - sequence{}); + const auto u_view_1_ = + pad_tensor_view(u_view_, + make_tuple(number{}, + number{}, + number{}), + sequence{}); return u_view_1_; } }(); diff --git a/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp b/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp index 857484efbe..3db2c72259 100644 --- a/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp +++ b/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp @@ -225,7 +225,8 @@ struct FusedMoeGemmPipelineFlatmmPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetMatrixCoreSwizzledBlockTIle_0() { - if constexpr(Problem::Traits::PermuteEnum == FusedMoeGemmWeightPermuteEnum::b_nr_kr_waveflatten) + if constexpr(Problem::Traits::PermuteEnum == + FusedMoeGemmWeightPermuteEnum::b_nr_kr_waveflatten) { using WarpGemm = GetWarpGemm0{}; // assume warpgemm0/1 are the same constexpr index_t NPerBlock = Problem::BlockShape::Block_N0; @@ -703,7 +704,8 @@ struct FusedMoeGemmPipelineFlatmmPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetMatrixCoreSwizzledBlockTIle_0() { - if constexpr(Problem::Traits::PermuteEnum == FusedMoeGemmWeightPermuteEnum::b_nr_kr_waveflatten) + if constexpr(Problem::Traits::PermuteEnum == + FusedMoeGemmWeightPermuteEnum::b_nr_kr_waveflatten) { using WarpGemm = GetWarpGemm0{}; // assume warpgemm0/1 are the same constexpr index_t NPerBlock = Problem::BlockShape::Block_N0; @@ -723,7 +725,8 @@ struct FusedMoeGemmPipelineFlatmmPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetMatrixCoreSwizzledBlockTIle_1() { - if constexpr(Problem::Traits::PermuteEnum == FusedMoeGemmWeightPermuteEnum::b_nr_kr_waveflatten) + if constexpr(Problem::Traits::PermuteEnum == + FusedMoeGemmWeightPermuteEnum::b_nr_kr_waveflatten) { using WarpGemm = GetWarpGemm1{}; // assume warpgemm0/1 are the same constexpr index_t NPerBlock = Problem::BlockShape::kBlockN_1; diff --git a/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_problem.hpp b/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_problem.hpp index 510103e3f0..6089c2558f 100644 --- a/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_problem.hpp +++ b/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_problem.hpp @@ -14,8 +14,8 @@ template struct FusedMoeGemmTraits + FusedMoeGemmWeightPermuteEnum PermuteEnum_ = + FusedMoeGemmWeightPermuteEnum::b_nr_kr_waveflatten, + bool PadHiddenSize_ = false, + bool PadIntermediateSize_ = false> +struct FusedMoeGemmTraits { // Gate+Up or Gate only - static constexpr bool IsGateOnly = IsGateOnly_; - static constexpr bool UseSmoothQuant = UseSmoothQuant_; - static constexpr index_t OAtomic = OAtomic_; - static constexpr bool PadHiddenSize = PadHiddenSize_; - static constexpr bool PadIntermediateSize = PadIntermediateSize_; + static constexpr bool IsGateOnly = IsGateOnly_; + static constexpr bool UseSmoothQuant = UseSmoothQuant_; + static constexpr index_t OAtomic = OAtomic_; + static constexpr FusedMoeGemmWeightPermuteEnum PermuteEnum = PermuteEnum_; + static constexpr bool PadHiddenSize = PadHiddenSize_; + static constexpr bool PadIntermediateSize = PadIntermediateSize_; }; } // namespace ck_tile