diff --git a/example/ck_tile/17_fused_moe_general/instances/fused_moegemm_api.cpp b/example/ck_tile/17_fused_moe_general/instances/fused_moegemm_api.cpp index c1a4c495c3..93fd916592 100644 --- a/example/ck_tile/17_fused_moe_general/instances/fused_moegemm_api.cpp +++ b/example/ck_tile/17_fused_moe_general/instances/fused_moegemm_api.cpp @@ -19,13 +19,13 @@ float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile: if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" && t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1) { - using t_ = fmoe_, S<1, 4, 1>, S<16, 16, 32>, 1, 0>; + using t_ = fmoe_, S<1, 4, 1>, S<32, 32, 8>, 1, 0>; r = fused_moegemm_(s, a); } else if(t.prec_i == "fp16" && t.prec_w == "fp16" && t.prec_o == "fp16" && t.prec_st == "fp32" && t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1) { - using t_ = fmoe_, S<1, 4, 1>, S<16, 16, 32>, 1, 0>; + using t_ = fmoe_, S<1, 4, 1>, S<32, 32, 8>, 1, 0>; r = fused_moegemm_(s, a); } // clang-format on diff --git a/example/ck_tile/17_fused_moe_general/instances/fused_moegemm_api_internal.hpp b/example/ck_tile/17_fused_moe_general/instances/fused_moegemm_api_internal.hpp index 5872179ef7..f89b8f8f71 100644 --- a/example/ck_tile/17_fused_moe_general/instances/fused_moegemm_api_internal.hpp +++ b/example/ck_tile/17_fused_moe_general/instances/fused_moegemm_api_internal.hpp @@ -19,8 +19,8 @@ float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a) typename Ts_::WarpPerBlock_0, typename Ts_::WarpTile_0, typename Ts_::BlockTile_1, - typename Ts_::WarpPerBlock_0, - typename Ts_::WarpTile_0>; + typename Ts_::WarpPerBlock_1, + typename Ts_::WarpTile_1>; using f_problem = ck_tile::FusedMoeGemmPipelineProblem; // using f_pipeline = ck_tile::FusedMoeGemmPipeline_FlatmmEx; - using f_pipeline = ck_tile::FusedMoeGemmPipeline_FlatmmUk; + using f_pipeline = ck_tile::FusedMoeGemmPipeline_General; using f_partitioner = ck_tile::FusedMoeGemmTilePartitioner_Linear; - using f_kernel = ck_tile::FusedMoeGemmKernel; + using f_kernel = ck_tile::FusedMoeGemmGlKernel; const dim3 grids = f_kernel::GridSize(a); constexpr dim3 blocks = f_kernel::BlockSize(); diff --git a/example/ck_tile/17_fused_moe_general/instances/fused_moegemm_api_traits.hpp b/example/ck_tile/17_fused_moe_general/instances/fused_moegemm_api_traits.hpp index cc476685de..d23f30587b 100644 --- a/example/ck_tile/17_fused_moe_general/instances/fused_moegemm_api_traits.hpp +++ b/example/ck_tile/17_fused_moe_general/instances/fused_moegemm_api_traits.hpp @@ -44,8 +44,8 @@ struct fmoe_ // traits, ugly name, only used for internal using WarpPerBlock_0 = ck_tile::remove_cvref_t; using WarpTile_0 = ck_tile::remove_cvref_t; - using BlockTile_1 = ck_tile::sequence; - using WarpPerBlock_1 = ck_tile::remove_cvref_t; + using BlockTile_1 = ck_tile::sequence; + using WarpPerBlock_1 = ck_tile::sequence<1, 1, 4>;//ck_tile::remove_cvref_t; using WarpTile_1 = ck_tile::remove_cvref_t; static constexpr ck_tile::index_t GateOnly = GateOnly_; diff --git a/example/ck_tile/17_fused_moe_general/instances/fused_moegemm_bf16_m32.cpp b/example/ck_tile/17_fused_moe_general/instances/fused_moegemm_bf16_m32.cpp index 93f9c77869..fcea85eeb0 100644 --- a/example/ck_tile/17_fused_moe_general/instances/fused_moegemm_bf16_m32.cpp +++ b/example/ck_tile/17_fused_moe_general/instances/fused_moegemm_bf16_m32.cpp @@ -8,7 +8,7 @@ // clang-format off template float fused_moegemm_< - fmoe_, S<1, 4, 1>, S<16, 16, 32>, 1, 0> + fmoe_, S<1, 4, 1>, S<32, 32, 8>, 1, 0> >(const ck_tile::stream_config& s, fused_moegemm_args a); // clang-format on diff --git a/example/ck_tile/17_fused_moe_general/instances/fused_moegemm_fp16_m32.cpp b/example/ck_tile/17_fused_moe_general/instances/fused_moegemm_fp16_m32.cpp index b8a823e8ed..744c5771ed 100644 --- a/example/ck_tile/17_fused_moe_general/instances/fused_moegemm_fp16_m32.cpp +++ b/example/ck_tile/17_fused_moe_general/instances/fused_moegemm_fp16_m32.cpp @@ -8,7 +8,7 @@ // clang-format off template float fused_moegemm_< - fmoe_, S<1, 4, 1>, S<16, 16, 32>, 1, 0> + fmoe_, S<1, 4, 1>, S<32, 32, 8>, 1, 0> >(const ck_tile::stream_config& s, fused_moegemm_args a); // clang-format on diff --git a/example/ck_tile/17_fused_moe_general/main.cpp b/example/ck_tile/17_fused_moe_general/main.cpp index 78fa09e78c..562d0f0bfa 100644 --- a/example/ck_tile/17_fused_moe_general/main.cpp +++ b/example/ck_tile/17_fused_moe_general/main.cpp @@ -261,8 +261,8 @@ bool run(const ck_tile::ArgParser& arg_parser) } // permute weight - ck_tile::HostTensor g_perm_host = shuffle_moe_weight(g_host, prec_w, 1); - ck_tile::HostTensor d_perm_host = shuffle_moe_weight(d_host, prec_w, 1); + // ck_tile::HostTensor g_perm_host = shuffle_moe_weight(g_host, prec_w, 1); + // ck_tile::HostTensor d_perm_host = shuffle_moe_weight(d_host, prec_w, 1); // do moe sorting if(balance) @@ -345,8 +345,8 @@ bool run(const ck_tile::ArgParser& arg_parser) // done, preparing GPU buffer ck_tile::DeviceMem a_buf(a_host); - ck_tile::DeviceMem g_perm_buf(g_perm_host); - ck_tile::DeviceMem d_perm_buf(d_perm_host); + ck_tile::DeviceMem g_perm_buf(g_host); + ck_tile::DeviceMem d_perm_buf(d_host); ck_tile::DeviceMem sa_buf(sa_host); ck_tile::DeviceMem sg_buf(sg_host); ck_tile::DeviceMem sd_buf(sd_host); @@ -390,7 +390,8 @@ bool run(const ck_tile::ArgParser& arg_parser) tokens, experts, topk, - stride}; + stride, + max_num_tokens_padded}; float ave_time = fused_moegemm( traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat}); diff --git a/example/ck_tile/17_fused_moe_general/misc/moe-0.png b/example/ck_tile/17_fused_moe_general/misc/moe-0.png deleted file mode 100644 index aed1964f28..0000000000 Binary files a/example/ck_tile/17_fused_moe_general/misc/moe-0.png and /dev/null differ diff --git a/example/ck_tile/17_fused_moe_general/misc/moe-1.png b/example/ck_tile/17_fused_moe_general/misc/moe-1.png deleted file mode 100644 index 91a1f2d9dd..0000000000 Binary files a/example/ck_tile/17_fused_moe_general/misc/moe-1.png and /dev/null differ diff --git a/example/ck_tile/17_fused_moe_general/misc/moe-2.png b/example/ck_tile/17_fused_moe_general/misc/moe-2.png deleted file mode 100644 index 98d83866fa..0000000000 Binary files a/example/ck_tile/17_fused_moe_general/misc/moe-2.png and /dev/null differ diff --git a/example/ck_tile/17_fused_moe_general/misc/moe-3.png b/example/ck_tile/17_fused_moe_general/misc/moe-3.png deleted file mode 100644 index 77c6d9b6e4..0000000000 Binary files a/example/ck_tile/17_fused_moe_general/misc/moe-3.png and /dev/null differ diff --git a/include/ck_tile/core/algorithm/indexing_adaptor.hpp b/include/ck_tile/core/algorithm/indexing_adaptor.hpp index ef59abdc99..c1d993125e 100644 --- a/include/ck_tile/core/algorithm/indexing_adaptor.hpp +++ b/include/ck_tile/core/algorithm/indexing_adaptor.hpp @@ -57,4 +57,76 @@ struct indexing_adaptor_onshot_cached return ck_tile::is_known_at_compile_time::value; } }; +#define Using_Gather 1 +template +struct indexing_adaptor +{ + + CK_TILE_HOST_DEVICE constexpr indexing_adaptor() = default; + CK_TILE_HOST_DEVICE constexpr indexing_adaptor(const IndexingType* idx) : cached_idx_(idx) {} + const IndexingType* cached_idx_; +#if Using_Gather + mutable index_t pre_up_index_ = 0; + mutable index_t pre_low_index_ = 0; +#endif + + template + CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low, + const UpIdx& idx_up) const + { + static_assert(LowIdx::size() == 1 && UpIdx::size() == 1, + "wrong! inconsistent # of dimension"); + + idx_low(number<0>{}) = *(cached_idx_ + idx_up[number<0>{}]); +#if Using_Gather + pre_up_index_ = idx_up[number<0>{}]; + pre_low_index_ = idx_low(number<0>{}); +#if 0 + if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) + { + printf("\n first index from %d to %d \n", idx_up[number<0>{}], idx_low(number<0>{})); + } +#endif +#endif + } + + template + CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& /*idx_low*/, + const UpIdx& /*idx_up*/) const + { + // TODO: nonthing changed here + static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 && + UpIdx::size() == 1, + "wrong! inconsistent # of dimension"); +#if !Using_Gather + idx_diff_low(number<0>{}) = idx_diff_up[number<0>{}]; +#else + int up_index = idx_diff_up[number<0>{}] + pre_up_index_; + int low_index = *(cached_idx_ + up_index); + idx_diff_low(number<0>{}) = low_index - pre_low_index_; + + pre_up_index_ = up_index; + pre_low_index_ = low_index; +#if 0 + if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) + { + printf("\n index form %d to %d, diff from %d to %d \n", + up_index, + low_index, + idx_diff_up[number<0>{}], + idx_diff_low(number<0>{})); + } +#endif +#endif + + // pass the diff to lower, but not changing the actually index + } + + CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time() + { + return ck_tile::is_known_at_compile_time::value; + } +}; } // namespace ck_tile diff --git a/include/ck_tile/ops/fused_moe/kernel/fused_moegemm_general_kernel.hpp b/include/ck_tile/ops/fused_moe/kernel/fused_moegemm_general_kernel.hpp index d29e070fc0..e57b23adea 100644 --- a/include/ck_tile/ops/fused_moe/kernel/fused_moegemm_general_kernel.hpp +++ b/include/ck_tile/ops/fused_moe/kernel/fused_moegemm_general_kernel.hpp @@ -213,9 +213,9 @@ struct FusedMoeGemmGlKernel CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs) { - constexpr index_t block_m = BlockShape::Block_M0; - int max_num_tokens_padded = - hargs.topk * hargs.num_tokens + hargs.num_experts * block_m - hargs.topk; + //constexpr index_t block_m = BlockShape::Block_M0; + int max_num_tokens_padded = hargs.max_num_tokens_padded; + //hargs.topk * hargs.num_tokens + hargs.num_experts * block_m - hargs.topk; // printf("xxx max_num_tokens_padded:%d\n", max_num_tokens_padded); return Partitioner::GridSize(max_num_tokens_padded, hargs.intermediate_size); } diff --git a/include/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp b/include/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp index 2d25d44f3c..a0f5142d8b 100644 --- a/include/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp +++ b/include/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp @@ -117,6 +117,7 @@ struct FusedMoeGemmHostArgs index_t topk; // need this? index_t stride_token; // for input/output, stride for each row, should >= hidden_size + index_t max_num_tokens_padded; // size of sorted_token_ids_ptr }; // This is scatter/gather b2b group-gemm