[rocm-libraries] ROCm/rocm-libraries#5225 (commit 880166b)

[CK] fix moe memset size which is bigger than alloc

## Motivation
Fix an out-of-bounds hipMemsetAsync in DeviceMoeGemmBlockScale that
crashes split-K MOE GEMM with "HIP runtime error: invalid argument".
When KBatch > 1, the invoker zeroes the output buffer with a byte count
derived from arg.M * arg.N elements. However, arg.M is the padded
sorted-token-id length from MOE routing, which can be much larger than
the row count of the actual output allocation (NumTokens * TopK * N
elements). hipMemsetAsync is therefore asked to write beyond the buffer.
Its return value is only passed to hipGetErrorString and then discarded,
so the failure first surfaces at the subsequent kernel launch via
hipGetLastError().
This patch replaces arg.M with arg.NumTokens * arg.TopK so the memset
matches the actual output size.
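
To make the mismatch concrete, here is a minimal host-only sketch (no HIP calls) that compares the two byte counts. Only the two size formulas are taken from the patch. The `MoeSizes` struct, the helper names, and the example numbers are hypothetical, and the padded length `M` is treated as a given input because the padding rule is not described here.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical stand-in for the fields the invoker reads from its argument.
struct MoeSizes
{
    std::size_t M;         // padded sorted-token-id length (assumed input)
    std::size_t N;         // output columns
    std::size_t NumTokens; // number of tokens
    std::size_t TopK;      // experts selected per token
};

// Byte count before the patch: sized from the padded length M.
// `scale` stands in for (IsInputGemm && IsSplitK ? 2 : 1).
constexpr std::size_t OldMemsetBytes(const MoeSizes& s, std::size_t elem_size, std::size_t scale)
{
    return s.M * s.N * elem_size * scale;
}

// Byte count after the patch: sized from the actual output rows.
constexpr std::size_t NewMemsetBytes(const MoeSizes& s, std::size_t elem_size, std::size_t scale)
{
    return s.NumTokens * s.TopK * s.N * elem_size * scale;
}

// True when a memset of `memset_bytes` stays inside an allocation of `alloc_bytes`.
constexpr bool MemsetFits(std::size_t memset_bytes, std::size_t alloc_bytes)
{
    return memset_bytes <= alloc_bytes;
}

// Debug-time guard a caller could place before issuing the memset.
inline void AssertMemsetFits(std::size_t memset_bytes, std::size_t alloc_bytes)
{
    assert(MemsetFits(memset_bytes, alloc_bytes) && "memset would exceed the allocation");
}
```

With the illustrative values `MoeSizes{128, 8, 4, 2}`, a 2-byte element type, and `scale = 1`, the old formula asks for 2048 bytes while the new one asks for 128 bytes, which is the size of a NumTokens * TopK * N output buffer.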

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
lalala-sh
2026-03-16 09:30:57 +00:00
committed by assistant-librarian[bot]
parent eb033ef208
commit a3ccd5dca1

@@ -256,7 +256,8 @@ struct DeviceMoeGemmBlockScale
                 if(arg_.KBatch > 1)
                     hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
                                                      0,
-                                                     arg_.M * arg_.N * sizeof(CDataType) *
+                                                     arg_.NumTokens * arg_.TopK * arg_.N *
+                                                         sizeof(CDataType) *
                                                          (IsInputGemm && IsSplitK ? 2 : 1),
                                                      stream_config.stream_id_));
             };
@@ -273,12 +274,12 @@ struct DeviceMoeGemmBlockScale
             else
             {
                 if(arg.KBatch > 1)
-                    hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
-                                                     0,
-                                                     arg.M * arg.N * sizeof(CDataType) *
-                                                         (IsInputGemm && IsSplitK ? 2 : 1),
-                                                     stream_config.stream_id_));
+                    hipGetErrorString(
+                        hipMemsetAsync(arg.p_c_grid,
+                                       0,
+                                       arg.NumTokens * arg.TopK * arg.N * sizeof(CDataType) *
+                                           (IsInputGemm && IsSplitK ? 2 : 1),
+                                       stream_config.stream_id_));
                 ave_time = launch_and_time_kernel(
                     stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
             }