From 234b8d415c413fa4da3adee84d9305febe3161b0 Mon Sep 17 00:00:00 2001 From: coderfeli Date: Tue, 25 Mar 2025 09:44:32 +0000 Subject: [PATCH] change code --- .../gpu/grid/gridwise_moe_gemm.hpp | 52 ++++++++++--------- 1 file changed, 27 insertions(+), 25 deletions(-) diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp index d203fc40fa..a9dafc62cd 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp @@ -1347,7 +1347,7 @@ struct GridwiseMoeGemm c_thread_buf_up, num_k_block_main_loop); - static_assert(NXdlPerWave == 1, "ONLY 1 now"); + // static_assert(NXdlPerWave == 1, "ONLY 1 now"); // const float scale_gate = scale_b[0]; // const float scale_up = scale_b[problem.N * perTokenQuantStride]; // static_for<0, c_thread_buf.Size(), 1>{}([&](auto i) { @@ -1397,13 +1397,17 @@ struct GridwiseMoeGemm const index_t m1 = get_warp_local_1d_id() / M1; const index_t m3 = threadIdx.x % get_warp_size() / MPerXdl; vector_type scale_token_ids; + vector_type topk_weights; // for gemm2 only static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave static_for<0, NXdlPerWave, 1>{}([&](auto n0) { static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk - + const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 + m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4; if constexpr(perTokenQuantStride) { - const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 + m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4; - scale_token_ids = *c_style_pointer_cast *>(p_sorted_token_ids + m_pos); + scale_token_ids = *c_style_pointer_cast *>(p_sorted_token_ids + m_pos); + } + if constexpr (!IsInputGemm) + { + topk_weights = *c_style_pointer_cast *>(p_ds_grid[I2] + m_pos); } static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size float scale_a = [&]() { @@ -1421,10 +1425,17 @@ struct GridwiseMoeGemm constexpr index_t c_offset = blockwise_gemm_pipeline.c_thread_desc_.CalculateOffset(make_tuple(m0, n0, m2 * M4 + m4)); constexpr auto cidx = Number{}; - auto gate = scale_a * scale_gate * c_thread_buf[cidx]; - auto up = scale_a * scale_up * c_thread_buf_up[cidx]; - gate = gate * math::rcp(1.0 + math::exp(-gate)); - c_thread_buf(cidx) = gate * up; + if constexpr (IsInputGemm) // gu fusion + { + auto gate = scale_a * scale_gate * c_thread_buf[cidx]; + auto up = scale_a * scale_up * c_thread_buf_up[cidx]; + gate = gate * math::rcp(1.0 + math::exp(-gate)); + c_thread_buf(cidx) = gate * up; + } + else + { + c_thread_buf(cidx) = scale_a * scale_gate * c_thread_buf[cidx]; + } }); }); }); @@ -1527,17 +1538,8 @@ struct GridwiseMoeGemm const auto ds_grid_buf = generate_tuple( [&](auto i) { - using DDataType = remove_cvref_t>; - const DDataType* ptr_ = p_ds_grid[i]; - // hack logic here to support different kind of strides. todo fix it. - // ascale t, 1; bscale E, N, 1, move ptr to E - // if(i.value == 1) - // { - // ptr_ += - // expert_id * (problem.StrideDs[1] ? problem.StrideDs[1] * problem.N : 1); - // } return make_dynamic_buffer( - ptr_, ds_grid_desc_m_n[i].GetElementSpaceSize()); + p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize()); }, Number{}); @@ -1667,18 +1669,18 @@ struct GridwiseMoeGemm static_for<0, EMRepeats, 1>{}([&](auto m0) { const index_t fused_token = p_sorted_token_ids[c_token_pos + m0]; IndexType token_offset = fused_token & 0xffffff; - float weight = 1.0f; + // float weight = 1.0f; if constexpr(IsInputGemm) { token_offset = token_offset * problem.TopK + (fused_token >> 24); } - else - { - const float* p_sorted_weights_2 = p_ds_grid[I2]; - weight = weight * p_sorted_weights_2[c_token_pos + m0]; - } + // else + // { + // const float* p_sorted_weights_2 = p_ds_grid[I2]; + // weight = weight * p_sorted_weights_2[c_token_pos + m0]; + // } scatter_offsets(m0) = static_cast(token_offset) * problem.N; - scatter_weights(m0) = weight; + // scatter_weights(m0) = weight; }); block_sync_lds();