From 12480d9dc94fd13526ebcd85e56e5b00c3276fea Mon Sep 17 00:00:00 2001 From: huizzhan Date: Wed, 13 Aug 2025 07:14:23 +0000 Subject: [PATCH] grouped topk debug --- ...grouped_topk_pipeline_agmem_bgmem_creg.hpp | 192 +++++++++++------- .../gemm_softmax_grouped_topk.cpp | 16 +- .../grid_gemm_softmax_grouped_topk.hpp | 1 + 3 files changed, 129 insertions(+), 80 deletions(-) diff --git a/example/ck_tile/39_gemm_softmax_grouped_topk/block_gemm_softmax_grouped_topk_pipeline_agmem_bgmem_creg.hpp b/example/ck_tile/39_gemm_softmax_grouped_topk/block_gemm_softmax_grouped_topk_pipeline_agmem_bgmem_creg.hpp index 501039b192..5c3aa6b55d 100755 --- a/example/ck_tile/39_gemm_softmax_grouped_topk/block_gemm_softmax_grouped_topk_pipeline_agmem_bgmem_creg.hpp +++ b/example/ck_tile/39_gemm_softmax_grouped_topk/block_gemm_softmax_grouped_topk_pipeline_agmem_bgmem_creg.hpp @@ -441,8 +441,8 @@ struct BlockGemmSoftmaxGroupedTopkPipelineAGmemBGmemCReg auto p_compute = make_static_distributed_tensor(c_block_tile.get_tile_distribution()); - auto debug_block_tile = - make_static_distributed_tensor(p_compute.get_tile_distribution()); + // auto debug_block_tile = + // make_static_distributed_tensor(p_compute.get_tile_distribution()); constexpr auto p_spans = decltype(p_compute)::get_distributed_spans(); @@ -476,88 +476,136 @@ struct BlockGemmSoftmaxGroupedTopkPipelineAGmemBGmemCReg // apply topk for softmax output auto x_tmp = p_compute; - - // // // initialize x_tmp for topk debug - // // printf("===============on device debug input=====================\n"); - // // std::mt19937 rng(123); - // // std::uniform_int_distribution dist_debug_input(1, 100); - - // constexpr auto x_tmp_spans = decltype(x_tmp)::get_distributed_spans(); - // sweep_tile_span(x_tmp_spans[number<0>{}], [&](auto idx0) { - // sweep_tile_span(x_tmp_spans[number<1>{}], [&](auto idx1) { - // const auto tile_idx = get_x_indices_from_distributed_indices( - // x_tmp.get_tile_distribution(), make_tuple(idx0, idx1)); - // auto 
row_id = tile_idx.at(number<0>{}); - // auto col_id = tile_idx.at(number<1>{}); - // constexpr auto i_j_idx = make_tuple(idx0, idx1); - // x_tmp(i_j_idx) = sin(float(row_id + col_id)) * 100; - // // x_tmp(i_j_idx) = sin(float(row_id)) * cos(float(col_id)); - // // x_tmp(i_j_idx) = float(dist_debug_input(rng)); - // }); - // }); - - // argmax for topk - const auto f_argmax = [](ArgmaxPacket e0, ArgmaxPacket e1) { - return e0.value > e1.value ? e0 : e1; - }; - - for(index_t i_k = 0; i_k < topk; i_k++) - { - constexpr auto p_compute_spans = decltype(p_compute)::get_distributed_spans(); - auto packet = [&]() { - auto tmp = make_static_distributed_tensor(p_compute.get_tile_distribution()); - sweep_tile_span(p_compute_spans[number<0>{}], [&](auto idx0) { - sweep_tile_span(p_compute_spans[number<1>{}], [&](auto idx1) { - const auto tile_idx = get_x_indices_from_distributed_indices( - tmp.get_tile_distribution(), make_tuple(idx0, idx1)); - constexpr auto i_j_idx = make_tuple(idx0, idx1); - ArgmaxPacket t; - t.value = x_tmp(i_j_idx); // !!! 
we reference p_compute here - t.arg = tile_idx.at(number<1>{}); - tmp(i_j_idx) = t; - }); + // calculate group score, need to create group scores tensor + int num_expert_group = 16; + // int topk_group = 2; + int expert_per_group = kNPerBlock / num_expert_group; + constexpr auto p_compute_spans = decltype(p_compute)::get_distributed_spans(); + auto group_scores = x_tmp; + // init group_scores to -inf + sweep_tile_span(p_compute_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(p_compute_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + group_scores(i_j_idx) = -numeric::infinity(); + }); + }); + for (index_t n_group = 0; n_group < num_expert_group; n_group++) { + // get group value matrix (other groups masked out) + auto group_tmp = x_tmp; + sweep_tile_span(p_compute_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(p_compute_spans[number<1>{}], [&](auto idx1) { + const auto tile_idx = get_x_indices_from_distributed_indices( + group_tmp.get_tile_distribution(), make_tuple(idx0, idx1)); + auto col_id = tile_idx.at(number<1>{}); + constexpr auto i_j_idx = make_tuple(idx0, idx1); + group_tmp(i_j_idx) = ((col_id >= (n_group * expert_per_group)) && (col_id < ((n_group + 1) * expert_per_group))) ? 
x_tmp(i_j_idx) : -numeric::infinity(); }); - return tmp; - }(); - - auto argmax_init = ArgmaxPacket{-numeric::infinity(), 0}; - auto r = block_tile_reduce(packet, sequence<1>{}, f_argmax, argmax_init); - - block_tile_reduce_xor_sync(r, f_argmax); - - // constexpr auto value_spans = decltype(value_block_tile)::get_distributed_spans(); + }); + // get one column for group scores = rowmax(group_tmp) + auto group_scores_col = block_tile_reduce( + group_tmp, sequence<1>{}, f_max, std::numeric_limits::lowest()); + block_tile_reduce_sync(group_scores_col, f_max); + // get all group scores sweep_tile_span(p_compute_spans[number<0>{}], [&](auto idx0) { constexpr auto i_idx = make_tuple(idx0); sweep_tile_span(p_compute_spans[number<1>{}], [&](auto idx1) { const auto tile_idx = get_x_indices_from_distributed_indices( p_compute.get_tile_distribution(), make_tuple(idx0, idx1)); - // auto row_id = tile_idx.at(number<0>{}); auto col_id = tile_idx.at(number<1>{}); constexpr auto i_j_idx = make_tuple(idx0, idx1); - ArgmaxPacket tmp = r(i_idx); - // debug_block_tile(i_j_idx) = (col_id == i_k) ? tmp.value: debug_block_tile(i_j_idx); - debug_block_tile(i_j_idx) = (col_id == i_k) ? tmp.arg: debug_block_tile(i_j_idx); - // value_block_tile(i_j_idx) = tmp.value; - // index_block_tile(i_j_idx) = tmp.arg; - }); - }); - - // update value - sweep_tile_span(p_compute_spans[number<0>{}], [&](auto idx0) { - sweep_tile_span(p_compute_spans[number<1>{}], [&](auto idx1) { - const auto tile_idx = get_x_indices_from_distributed_indices( - p_compute.get_tile_distribution(), make_tuple(idx0, idx1)); - auto col_id = tile_idx.at(number<1>{}); - - constexpr auto i_j_idx = make_tuple(idx0, idx1); - - x_tmp(i_j_idx) = (col_id == r(i_j_idx).arg) ? -numeric::infinity() - : x_tmp(i_j_idx); + group_scores(i_j_idx) = (col_id == n_group) ? 
group_scores_col(i_idx): group_scores(i_j_idx); }); }); } - return debug_block_tile; + + // // another scheme to reshape x_tmp from 2d to 3d + // const auto x_tmp_3d = x_tmp.block_tile_reshape((kM, num_expert_group, kN/num_expert_group)); + // auto group_scores = block_tile_reduce( + // x_tmp_3d, sequence<2>{}, f_max, std::numeric_limits::lowest()); + // block_tile_reduce_sync(group_scores, f_max); + + // // Step2: select group values and group_indices + // // argmax for topk + // const auto f_argmax = [](ArgmaxPacket e0, ArgmaxPacket e1) { + // return e0.value > e1.value ? e0 : e1; + // }; + // auto group_packet = topk(group_scores, topk_group) + + // // Step3: mask score matrix + // sweep_tile_span(p_compute_spans[number<0>{}], [&](auto idx0) { + // sweep_tile_span(p_compute_spans[number<1>{}], [&](auto idx1) { + // const auto tile_idx = get_x_indices_from_distributed_indices( + // p_compute.get_tile_distribution(), make_tuple(idx0, idx1)); + // auto col_id = tile_idx.at(number<1>{}); + + // constexpr auto i_j_idx = make_tuple(idx0, idx1); + + // x_tmp(i_j_idx) = (col_id != group_packet(i_j_idx).arg) ? -numeric::infinity() + // : x_tmp(i_j_idx); + // }); + // }); + + // // Step4: select topk values from masked scores + // for(index_t i_k = 0; i_k < topk; i_k++) + // { + // constexpr auto p_compute_spans = decltype(p_compute)::get_distributed_spans(); + // auto packet = [&]() { + // auto tmp = make_static_distributed_tensor(p_compute.get_tile_distribution()); + // sweep_tile_span(p_compute_spans[number<0>{}], [&](auto idx0) { + // sweep_tile_span(p_compute_spans[number<1>{}], [&](auto idx1) { + // const auto tile_idx = get_x_indices_from_distributed_indices( + // tmp.get_tile_distribution(), make_tuple(idx0, idx1)); + // constexpr auto i_j_idx = make_tuple(idx0, idx1); + // ArgmaxPacket t; + // t.value = x_tmp(i_j_idx); // !!! 
we reference p_compute here + // t.arg = tile_idx.at(number<1>{}); + // tmp(i_j_idx) = t; + // }); + // }); + // return tmp; + // }(); + + // auto argmax_init = ArgmaxPacket{-numeric::infinity(), 0}; + // auto r = block_tile_reduce(packet, sequence<1>{}, f_argmax, argmax_init); + + // block_tile_reduce_xor_sync(r, f_argmax); + + // // constexpr auto value_spans = decltype(value_block_tile)::get_distributed_spans(); + + // sweep_tile_span(p_compute_spans[number<0>{}], [&](auto idx0) { + // constexpr auto i_idx = make_tuple(idx0); + // sweep_tile_span(p_compute_spans[number<1>{}], [&](auto idx1) { + // const auto tile_idx = get_x_indices_from_distributed_indices( + // p_compute.get_tile_distribution(), make_tuple(idx0, idx1)); + // // auto row_id = tile_idx.at(number<0>{}); + // auto col_id = tile_idx.at(number<1>{}); + // constexpr auto i_j_idx = make_tuple(idx0, idx1); + // ArgmaxPacket tmp = r(i_idx); + // debug_block_tile(i_j_idx) = (col_id == i_k) ? tmp.value: debug_block_tile(i_j_idx); + // // debug_block_tile(i_j_idx) = (col_id == i_k) ? tmp.arg: debug_block_tile(i_j_idx); + // // value_block_tile(i_j_idx) = tmp.value; + // // index_block_tile(i_j_idx) = tmp.arg; + // }); + // }); + + // // update value + // sweep_tile_span(p_compute_spans[number<0>{}], [&](auto idx0) { + // constexpr auto i_idx = make_tuple(idx0); + // sweep_tile_span(p_compute_spans[number<1>{}], [&](auto idx1) { + // const auto tile_idx = get_x_indices_from_distributed_indices( + // p_compute.get_tile_distribution(), make_tuple(idx0, idx1)); + // auto col_id = tile_idx.at(number<1>{}); + + // constexpr auto i_j_idx = make_tuple(idx0, idx1); + + // x_tmp(i_j_idx) = (col_id == r(i_idx).arg) ? 
-numeric::infinity() + // : x_tmp(i_j_idx); + // }); + // }); + // } + // return debug_block_tile; + return group_scores; } }; diff --git a/example/ck_tile/39_gemm_softmax_grouped_topk/gemm_softmax_grouped_topk.cpp b/example/ck_tile/39_gemm_softmax_grouped_topk/gemm_softmax_grouped_topk.cpp index 0a810f1f05..80b8dcd114 100755 --- a/example/ck_tile/39_gemm_softmax_grouped_topk/gemm_softmax_grouped_topk.cpp +++ b/example/ck_tile/39_gemm_softmax_grouped_topk/gemm_softmax_grouped_topk.cpp @@ -286,9 +286,9 @@ int main(int argc, char* argv[]) // reference_topk(debug_host_input, value_ref, index_ref, topk); // debug_ref = reference_basic_gemm(a_host, b_host); - // debug_ref = reference_basic_gemm_softmax(a_host, b_host); - reference_basic_gemm_softmax_grouped_topk( - a_host, b_host, value_ref, index_ref, topk); + debug_ref = reference_basic_gemm_softmax(a_host, b_host); + // reference_basic_gemm_softmax_grouped_topk( + // a_host, b_host, value_ref, index_ref, topk); debug_buf.FromDevice(debug_host_dev.mData.data()); value_buf.FromDevice(value_host_dev.mData.data()); index_buf.FromDevice(index_host_dev.mData.data()); @@ -306,14 +306,14 @@ int main(int argc, char* argv[]) for(int i_t = 0; i_t < tokens; i_t++) { auto s_begin = std::vector{static_cast(i_t), static_cast(0)}; - auto s_end = - std::vector{static_cast(i_t + 1), static_cast(topk)}; // auto s_end = - // std::vector{static_cast(i_t + 1), static_cast(N)}; + // std::vector{static_cast(i_t + 1), static_cast(topk)}; + auto s_end = + std::vector{static_cast(i_t + 1), static_cast(N)}; auto s_debug_host = debug_host_dev.slice(s_begin, s_end); - // auto s_debug_ref = debug_ref.slice(s_begin, s_end); + auto s_debug_ref = debug_ref.slice(s_begin, s_end); // auto s_debug_ref = value_ref.slice(s_begin, s_end); - auto s_debug_ref = index_ref.slice(s_begin, s_end); + // auto s_debug_ref = index_ref.slice(s_begin, s_end); rtn &= ck_tile::check_err(s_debug_host, s_debug_ref, std::string("[") + std::to_string(i_t) + diff --git 
a/example/ck_tile/39_gemm_softmax_grouped_topk/grid_gemm_softmax_grouped_topk.hpp b/example/ck_tile/39_gemm_softmax_grouped_topk/grid_gemm_softmax_grouped_topk.hpp index a744abf954..5ec14481e7 100755 --- a/example/ck_tile/39_gemm_softmax_grouped_topk/grid_gemm_softmax_grouped_topk.hpp +++ b/example/ck_tile/39_gemm_softmax_grouped_topk/grid_gemm_softmax_grouped_topk.hpp @@ -137,6 +137,7 @@ struct GridGemm // block_gemm_pipeline(a_block_window, b_block_window, debug_block_tile, value_block_tile, index_block_tile, K / kKPerBlock, p_smem_char); const auto debug_block_tile = block_gemm_pipeline(a_block_window, b_block_window, K / kKPerBlock, p_smem_char); + // block_gemm_pipeline(a_block_window, b_block_window, debug_block_tile, K / kKPerBlock, p_smem_char); // cast DataType and apply CElementFunction const auto debug_cast_block_tile = tile_elementwise_in(