mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[rocm-libraries] ROCm/rocm-libraries#5225 (commit 880166b)
[CK] fix moe memset size which is bigger than alloc ## Motivation Fix an out-of-bounds hipMemsetAsync in DeviceMoeGemmBlockScale that crashes split-K MOE GEMM with "HIP runtime error: invalid argument". When KBatch > 1, the invoker zeroes the output buffer using arg.M * arg.N as the byte count. However, arg.M is the padded sorted-token-id length from MOE routing, which can be much larger than the actual output allocation (NumTokens * TopK * N). This causes hipMemsetAsync to write beyond the buffer, and the silently-swallowed HIP error propagates to the subsequent kernel launch via hipGetLastError(). This patch replaces arg.M with arg.NumTokens * arg.TopK so the memset matches the actual output size. ## Technical Details <!-- Explain the changes along with any relevant GitHub links. --> ## Test Plan <!-- Explain any relevant testing done to verify this PR. --> ## Test Result <!-- Briefly summarize test outcomes. --> ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
eb033ef208
commit
a3ccd5dca1
@@ -256,7 +256,8 @@ struct DeviceMoeGemmBlockScale
|
||||
if(arg_.KBatch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
|
||||
0,
|
||||
arg_.M * arg_.N * sizeof(CDataType) *
|
||||
arg_.NumTokens * arg_.TopK * arg_.N *
|
||||
sizeof(CDataType) *
|
||||
(IsInputGemm && IsSplitK ? 2 : 1),
|
||||
stream_config.stream_id_));
|
||||
};
|
||||
@@ -273,12 +274,12 @@ struct DeviceMoeGemmBlockScale
|
||||
else
|
||||
{
|
||||
if(arg.KBatch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
|
||||
0,
|
||||
arg.M * arg.N * sizeof(CDataType) *
|
||||
(IsInputGemm && IsSplitK ? 2 : 1),
|
||||
stream_config.stream_id_));
|
||||
|
||||
hipGetErrorString(
|
||||
hipMemsetAsync(arg.p_c_grid,
|
||||
0,
|
||||
arg.NumTokens * arg.TopK * arg.N * sizeof(CDataType) *
|
||||
(IsInputGemm && IsSplitK ? 2 : 1),
|
||||
stream_config.stream_id_));
|
||||
ave_time = launch_and_time_kernel(
|
||||
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user