From debb95d35ac00458dd856bd2af3cee5f0db474b2 Mon Sep 17 00:00:00 2001 From: lalala-sh Date: Mon, 16 Mar 2026 17:30:07 +0800 Subject: [PATCH] [CK] fix moe memset size which is bigger than alloc (#5225) ## Motivation Fix an out-of-bounds hipMemsetAsync in DeviceMoeGemmBlockScale that crashes split-K MOE GEMM with "HIP runtime error: invalid argument". When KBatch > 1, the invoker zeroes the output buffer using arg.M * arg.N as the byte count. However, arg.M is the padded sorted-token-id length from MOE routing, which can be much larger than the actual output allocation (NumTokens * TopK * N). This causes hipMemsetAsync to write beyond the buffer, and the silently-swallowed HIP error propagates to the subsequent kernel launch via hipGetLastError(). This patch replaces arg.M with arg.NumTokens * arg.TopK so the memset matches the actual output size. ## Technical Details ## Test Plan ## Test Result ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. --- .../device/impl/device_moe_gemm_blockscale.hpp | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp b/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp index 12d28f572c..684219b584 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp @@ -256,7 +256,8 @@ struct DeviceMoeGemmBlockScale if(arg_.KBatch > 1) hipGetErrorString(hipMemsetAsync(arg_.p_c_grid, 0, - arg_.M * arg_.N * sizeof(CDataType) * + arg_.NumTokens * arg_.TopK * arg_.N * + sizeof(CDataType) * (IsInputGemm && IsSplitK ? 2 : 1), stream_config.stream_id_)); }; @@ -273,12 +274,12 @@ struct DeviceMoeGemmBlockScale else { if(arg.KBatch > 1) - hipGetErrorString(hipMemsetAsync(arg.p_c_grid, - 0, - arg.M * arg.N * sizeof(CDataType) * - (IsInputGemm && IsSplitK ? 2 : 1), - stream_config.stream_id_)); - + hipGetErrorString( + hipMemsetAsync(arg.p_c_grid, + 0, + arg.NumTokens * arg.TopK * arg.N * sizeof(CDataType) * + (IsInputGemm && IsSplitK ? 2 : 1), + stream_config.stream_id_)); ave_time = launch_and_time_kernel( stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg); }