[rocm-libraries] ROCm/rocm-libraries#5086 (commit f4880d7)

[CK] Fix MOE FP8 SplitK buffer descriptor OOB

When SplitK is enabled, kernel entry shifts A/B/AScale/BScale base
pointers by SplitKBatchOffset, but make_dynamic_buffer element spaces
are still based on full K dimension. This causes hardware buffer
resource descriptors to extend beyond the actual tensor allocation,
leading to GPU memory access faults when the tensor happens to be placed
at the end of an allocated memory pool region.

Fix by subtracting the split offset from each buffer's element space in
both Run() (v1 pipeline) and Run_2Lds() (v2/v3 pipeline), so the buffer
descriptor range [shifted_base, shifted_base + reduced_space) exactly
covers the valid allocation.

Also refactor SplitKBatchOffset to accept const Problem& (instead of
Argument&) and add a default constructor, enabling direct reuse in
Run/Run_2Lds without duplicating offset calculation logic.

Made-with: Cursor

## Motivation

<!-- Explain the purpose of this PR and the goals it aims to achieve.
-->

## Technical Details

<!-- Explain the changes along with any relevant GitHub links. -->

## Test Plan

<!-- Explain any relevant testing done to verify this PR. -->

## Test Result

<!-- Briefly summarize test outcomes. -->

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
lalala-sh
2026-03-19 02:43:30 +00:00
committed by assistant-librarian[bot]
parent e5683e2290
commit 345a56c55e

View File

@@ -827,7 +827,15 @@ struct GridwiseMoeGemmBlockScale
struct SplitKBatchOffset
{
__device__ SplitKBatchOffset(Argument& karg, index_t k_id)
__device__ SplitKBatchOffset()
: a_k_split_offset(0),
b_k_split_offset(0),
ascale_k_split_offset(0),
bscale_k_split_offset(0)
{
}
__device__ SplitKBatchOffset(const Problem& karg, index_t k_id)
{
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
{
@@ -847,19 +855,9 @@ struct GridwiseMoeGemmBlockScale
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
{
// KPack * NLane * KLane * K0 * N0
b_k_split_offset = k_id * karg.KRead * NLane / BPackedSize;
bscale_k_split_offset = k_id * karg.KRead / ScaleBlockK;
}
// if(k_id < karg.KBatch - 1)
// {
// karg.K = karg.KRead;
// }
// else
// {
// karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
// }
}
index_t a_k_split_offset;
@@ -1234,18 +1232,43 @@ struct GridwiseMoeGemmBlockScale
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
// When SplitK is enabled, base pointers have been shifted by
// SplitKBatchOffset in the kernel entry, but buffer descriptor element
// spaces are still based on full K. Subtract the pointer shift from
// each element space so the hardware buffer resource doesn't extend
// beyond the actual tensor allocation.
const auto splitk_offset = [&]() -> SplitKBatchOffset {
if constexpr(IsSplitK)
{
return SplitKBatchOffset(problem, blockIdx.z);
}
else
{
return SplitKBatchOffset();
}
}();
assert(a_grid_desc_ak0_m_ak1.GetElementSpaceSize() >= splitk_offset.a_k_split_offset);
assert(b_grid_desc_bpreshuffled.GetElementSpaceSize() >= splitk_offset.b_k_split_offset);
assert(a_scale_grid_desc_am_ak.GetElementSpaceSize() >=
splitk_offset.ascale_k_split_offset);
assert(b_scale_grid_desc_bn_ak.GetElementSpaceSize() >=
splitk_offset.bscale_k_split_offset);
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize() - splitk_offset.a_k_split_offset);
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global, b_coherence_flag>(
p_b_grid + static_cast<long_index_t>(expert_id) * expert_stride / BPackedSize,
b_grid_desc_bpreshuffled.GetElementSpaceSize());
b_grid_desc_bpreshuffled.GetElementSpaceSize() - splitk_offset.b_k_split_offset);
const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
p_a_scale_grid,
a_scale_grid_desc_am_ak.GetElementSpaceSize() - splitk_offset.ascale_k_split_offset);
const auto b_scale_grid_buf =
make_dynamic_buffer<AddressSpaceEnum::Global, b_coherence_flag>(
p_b_scale_grid + static_cast<long_index_t>(expert_id) * expert_scale_stride,
b_scale_grid_desc_bn_ak.GetElementSpaceSize());
b_scale_grid_desc_bn_ak.GetElementSpaceSize() -
splitk_offset.bscale_k_split_offset);
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_ak0_m_ak1 =
@@ -1742,18 +1765,39 @@ struct GridwiseMoeGemmBlockScale
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
// Same fix as Run(): reduce buffer element spaces by split offset
const auto splitk_offset = [&]() -> SplitKBatchOffset {
if constexpr(IsSplitK)
{
return SplitKBatchOffset(problem, blockIdx.z);
}
else
{
return SplitKBatchOffset();
}
}();
assert(a_grid_desc_ak0_m_ak1.GetElementSpaceSize() >= splitk_offset.a_k_split_offset);
assert(b_grid_desc_bpreshuffled.GetElementSpaceSize() >= splitk_offset.b_k_split_offset);
assert(a_scale_grid_desc_am_ak.GetElementSpaceSize() >=
splitk_offset.ascale_k_split_offset);
assert(b_scale_grid_desc_bn_ak.GetElementSpaceSize() >=
splitk_offset.bscale_k_split_offset);
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize() - splitk_offset.a_k_split_offset);
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global, b_coherence_flag>(
p_b_grid + static_cast<long_index_t>(expert_id) * expert_stride / BPackedSize,
b_grid_desc_bpreshuffled.GetElementSpaceSize());
b_grid_desc_bpreshuffled.GetElementSpaceSize() - splitk_offset.b_k_split_offset);
const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
p_a_scale_grid,
a_scale_grid_desc_am_ak.GetElementSpaceSize() - splitk_offset.ascale_k_split_offset);
const auto b_scale_grid_buf =
make_dynamic_buffer<AddressSpaceEnum::Global, b_coherence_flag>(
p_b_scale_grid + static_cast<long_index_t>(expert_id) * expert_scale_stride,
b_scale_grid_desc_bn_ak.GetElementSpaceSize());
b_scale_grid_desc_bn_ak.GetElementSpaceSize() -
splitk_offset.bscale_k_split_offset);
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_ak0_m_ak1 =