From c0c1c04b5049cd22429a47be86e1f98f2275c089 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 8 Apr 2025 06:26:57 +0000 Subject: [PATCH] fix bugs --- ...dlops_b_preshuffle_gufusion_dequant_v1.hpp | 1 + .../gpu/grid/gridwise_moe_gemm.hpp | 32 +++++++++++++++---- 2 files changed, 26 insertions(+), 7 deletions(-) diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_dequant_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_dequant_v1.hpp index ce102ff1ad..29750b8baa 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_dequant_v1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_dequant_v1.hpp @@ -314,6 +314,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1< // Initialize C c_thread_buf.Clear(); + c_thread_buf_up.Clear(); __builtin_amdgcn_sched_barrier(0); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp index 0930a64b55..7582669e08 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp @@ -1205,6 +1205,7 @@ struct GridwiseMoeGemm return {blockIdx.x, blockIdx.y}; } }(); + const index_t block_n_id = block_mn.first; const index_t block_m_id = block_mn.second; const index_t token0 = @@ -1320,7 +1321,7 @@ struct GridwiseMoeGemm KPerBlock); if constexpr(IsInputGemm) { - const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2; + const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize; const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>( p_b_grid_up + expert_id * expert_stride / BPackedSize, b_grid_desc_bpreshuffled.GetElementSpaceSize()); @@ -1468,8 +1469,18 @@ struct GridwiseMoeGemm } else if(ActivationOperation == Activation::gelu) { - 
tensor_operation::element_wise::Gelu{}(c_thread_buf(cidx), - c_thread_buf(cidx)); + const float scale_up = + p_scale_b[(n0 * NWave * NPerXdl + problem.N) * + PerTokenQuant]; + auto gate = scale_a * scale_b * c_thread_buf[cidx]; + auto up = scale_a * scale_up * c_thread_buf_up[cidx]; + if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>) + { + gate *= 16; + up *= 16; + } + tensor_operation::element_wise::Gelu{}(gate, gate); + c_thread_buf(cidx) = gate * up; } else if(ActivationOperation == Activation::swiglu) { @@ -1478,7 +1489,12 @@ struct GridwiseMoeGemm PerTokenQuant]; auto gate = scale_a * scale_b * c_thread_buf[cidx]; auto up = scale_a * scale_up * c_thread_buf_up[cidx]; - gate = gate * math::rcp(1.0 + math::exp(-gate)); + if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>) + { + gate *= 16; + up *= 16; + } + tensor_operation::element_wise::Silu{}(gate, gate); c_thread_buf(cidx) = gate * up; } } @@ -1524,14 +1540,16 @@ struct GridwiseMoeGemm } else if(ActivationOperation == Activation::gelu) { - tensor_operation::element_wise::Gelu{}(c_thread_buf(cidx), - c_thread_buf(cidx)); + auto gate = c_thread_buf[cidx]; + auto up = c_thread_buf_up[cidx]; + tensor_operation::element_wise::Gelu{}(gate, gate); + c_thread_buf(cidx) = gate * up; } else if(ActivationOperation == Activation::swiglu) { auto gate = c_thread_buf[cidx]; auto up = c_thread_buf_up[cidx]; - gate = gate * math::rcp(1.0 + math::exp(-gate)); + tensor_operation::element_wise::Silu{}(gate, gate); c_thread_buf(cidx) = gate * up; } }