From 7e4c8a56edd9870421ac0734e9be2d680239d469 Mon Sep 17 00:00:00 2001 From: felix Date: Wed, 19 Mar 2025 22:58:27 +0800 Subject: [PATCH] Ck moe hot fix (#1979) * fix useless code and remove usless oob * clang format * fix coredump in e2e test * fix2 * fix clang format * fix output oob * clang format * rm useless comments --------- Co-authored-by: coderfeli Co-authored-by: illsilin [ROCm/composable_kernel commit: 7eaedeb36cc1dabd739ab59339afc007970b5393] --- .../gpu/grid/gridwise_moe_gemm.hpp | 17 +++++++---------- ...dwise_tensor_slice_transfer_v7r3_scatter.hpp | 8 ++++---- 2 files changed, 11 insertions(+), 14 deletions(-) diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp index 5337fd5e2c..1924c27b2b 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp @@ -1563,12 +1563,8 @@ struct GridwiseMoeGemm const float* p_sorted_weights_0 = p_ds_grid[I0]; static_for<0, num_access, 1>{}([&](auto access_id) { // make sure it's safe to write to LDS - StaticallyIndexedArray - scatter_offsets; //= p_sorted_token_ids[c_token_pos]; + StaticallyIndexedArray scatter_offsets; StaticallyIndexedArray scatter_weights; //= for topk - // too hack here, 2 specific for topk weights, fixme - // const index_t topk_id[EMRepeats];// = (p_sorted_token_ids[block_m_id * MPerBlock] - // & 0xff000000) >> 24; auto dstidx = sfc_cde_block.GetIndex(access_id); const index_t c_token_pos = @@ -1576,7 +1572,9 @@ struct GridwiseMoeGemm static_for<0, EMRepeats, 1>{}([&](auto m0) { const index_t fused_token = p_sorted_token_ids[c_token_pos + m0]; index_t token_offset = fused_token & 0xffffff; - float weight = p_sorted_weights_0[token_offset * problem.StrideDs[0]]; + float weight = token_offset < problem.NumTokens + ? p_sorted_weights_0[token_offset * problem.StrideDs[0]] + : 0.0; if constexpr(IsInputGemm) { token_offset = token_offset * problem.TopK + (fused_token >> 24); @@ -2074,9 +2072,6 @@ struct GridwiseMoeGemm StaticallyIndexedArray scatter_offsets; //= p_sorted_token_ids[c_token_pos]; StaticallyIndexedArray scatter_weights; //= for topk - // too hack here, 2 specific for topk weights, fixme - // const index_t topk_id[EMRepeats];// = (p_sorted_token_ids[block_m_id * MPerBlock] - // & 0xff000000) >> 24; auto dstidx = sfc_cde_block.GetIndex(access_id); const index_t c_token_pos = @@ -2084,7 +2079,9 @@ struct GridwiseMoeGemm static_for<0, EMRepeats, 1>{}([&](auto m0) { const index_t fused_token = p_sorted_token_ids[c_token_pos + m0]; index_t token_offset = fused_token & 0xffffff; - float weight = p_sorted_weights_0[token_offset * problem.StrideDs[0]]; + float weight = token_offset < problem.NumTokens + ? p_sorted_weights_0[token_offset * problem.StrideDs[0]] + : 0.0; if constexpr(IsInputGemm) { token_offset = token_offset * problem.TopK + (fused_token >> 24); diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp index 29570c94e3..6a1c195dc1 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp @@ -430,13 +430,13 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter } // copy data from buf_vectors into dst_bufs static_for<0, nDst, 1>{}([&](auto i) { - using dst_vector_t = typename remove_cvref_t::type; - auto dst_offset = scatter_offset + dst_coords_[i].GetOffset(); + using dst_vector_t = typename remove_cvref_t::type; + auto dst_offset = scatter_offset + dst_coords_[i].GetOffset(); + const bool is_dst_valid = dst_offset < dst_descs[i].GetElementSpaceSize(); constexpr InMemoryDataOperationEnum DstInMemOp = static_cast(DstInMemOps::At(i.value)); - dst_bufs(i).template Update( - dst_offset, true, dst_vectors[i].template AsType()[I0]); + dst_offset, is_dst_valid, dst_vectors[i].template AsType()[I0]); }); // move coordinate