mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[rocm-libraries] ROCm/rocm-libraries#5086 (commit f4880d7)
[CK] Fix MOE FP8 SplitK buffer descriptor OOB. When SplitK is enabled, the kernel entry shifts the A/B/AScale/BScale base pointers by SplitKBatchOffset, but the make_dynamic_buffer element spaces are still based on the full K dimension. This causes the hardware buffer resource descriptors to extend beyond the actual tensor allocation, leading to GPU memory access faults when the tensor happens to be placed at the end of an allocated memory pool region. Fix by subtracting the split offset from each buffer's element space in both Run() (v1 pipeline) and Run_2Lds() (v2/v3 pipeline), so that the buffer descriptor range [shifted_base, shifted_base + reduced_space) exactly covers the valid allocation. Also refactor SplitKBatchOffset to accept const Problem& (instead of Argument&) and add a default constructor, enabling direct reuse in Run/Run_2Lds without duplicating the offset-calculation logic. Made-with: Cursor ## Motivation <!-- Explain the purpose of this PR and the goals it aims to achieve. --> ## Technical Details <!-- Explain the changes along with any relevant GitHub links. --> ## Test Plan <!-- Explain any relevant testing done to verify this PR. --> ## Test Result <!-- Briefly summarize test outcomes. --> ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
e5683e2290
commit
345a56c55e
@@ -827,7 +827,15 @@ struct GridwiseMoeGemmBlockScale
|
||||
|
||||
struct SplitKBatchOffset
|
||||
{
|
||||
__device__ SplitKBatchOffset(Argument& karg, index_t k_id)
|
||||
__device__ SplitKBatchOffset()
|
||||
: a_k_split_offset(0),
|
||||
b_k_split_offset(0),
|
||||
ascale_k_split_offset(0),
|
||||
bscale_k_split_offset(0)
|
||||
{
|
||||
}
|
||||
|
||||
__device__ SplitKBatchOffset(const Problem& karg, index_t k_id)
|
||||
{
|
||||
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
|
||||
{
|
||||
@@ -847,19 +855,9 @@ struct GridwiseMoeGemmBlockScale
|
||||
}
|
||||
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
|
||||
{
|
||||
// KPack * NLane * KLane * K0 * N0
|
||||
b_k_split_offset = k_id * karg.KRead * NLane / BPackedSize;
|
||||
bscale_k_split_offset = k_id * karg.KRead / ScaleBlockK;
|
||||
}
|
||||
|
||||
// if(k_id < karg.KBatch - 1)
|
||||
// {
|
||||
// karg.K = karg.KRead;
|
||||
// }
|
||||
// else
|
||||
// {
|
||||
// karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
|
||||
// }
|
||||
}
|
||||
|
||||
index_t a_k_split_offset;
|
||||
@@ -1234,18 +1232,43 @@ struct GridwiseMoeGemmBlockScale
|
||||
const index_t n_block_data_idx_on_grid =
|
||||
__builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
|
||||
|
||||
// When SplitK is enabled, base pointers have been shifted by
|
||||
// SplitKBatchOffset in the kernel entry, but buffer descriptor element
|
||||
// spaces are still based on full K. Subtract the pointer shift from
|
||||
// each element space so the hardware buffer resource doesn't extend
|
||||
// beyond the actual tensor allocation.
|
||||
const auto splitk_offset = [&]() -> SplitKBatchOffset {
|
||||
if constexpr(IsSplitK)
|
||||
{
|
||||
return SplitKBatchOffset(problem, blockIdx.z);
|
||||
}
|
||||
else
|
||||
{
|
||||
return SplitKBatchOffset();
|
||||
}
|
||||
}();
|
||||
|
||||
assert(a_grid_desc_ak0_m_ak1.GetElementSpaceSize() >= splitk_offset.a_k_split_offset);
|
||||
assert(b_grid_desc_bpreshuffled.GetElementSpaceSize() >= splitk_offset.b_k_split_offset);
|
||||
assert(a_scale_grid_desc_am_ak.GetElementSpaceSize() >=
|
||||
splitk_offset.ascale_k_split_offset);
|
||||
assert(b_scale_grid_desc_bn_ak.GetElementSpaceSize() >=
|
||||
splitk_offset.bscale_k_split_offset);
|
||||
|
||||
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize() - splitk_offset.a_k_split_offset);
|
||||
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global, b_coherence_flag>(
|
||||
p_b_grid + static_cast<long_index_t>(expert_id) * expert_stride / BPackedSize,
|
||||
b_grid_desc_bpreshuffled.GetElementSpaceSize());
|
||||
b_grid_desc_bpreshuffled.GetElementSpaceSize() - splitk_offset.b_k_split_offset);
|
||||
|
||||
const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
|
||||
p_a_scale_grid,
|
||||
a_scale_grid_desc_am_ak.GetElementSpaceSize() - splitk_offset.ascale_k_split_offset);
|
||||
const auto b_scale_grid_buf =
|
||||
make_dynamic_buffer<AddressSpaceEnum::Global, b_coherence_flag>(
|
||||
p_b_scale_grid + static_cast<long_index_t>(expert_id) * expert_scale_stride,
|
||||
b_scale_grid_desc_bn_ak.GetElementSpaceSize());
|
||||
b_scale_grid_desc_bn_ak.GetElementSpaceSize() -
|
||||
splitk_offset.bscale_k_split_offset);
|
||||
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
constexpr auto a_block_desc_ak0_m_ak1 =
|
||||
@@ -1742,18 +1765,39 @@ struct GridwiseMoeGemmBlockScale
|
||||
const index_t n_block_data_idx_on_grid =
|
||||
__builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
|
||||
|
||||
// Same fix as Run(): reduce buffer element spaces by split offset
|
||||
const auto splitk_offset = [&]() -> SplitKBatchOffset {
|
||||
if constexpr(IsSplitK)
|
||||
{
|
||||
return SplitKBatchOffset(problem, blockIdx.z);
|
||||
}
|
||||
else
|
||||
{
|
||||
return SplitKBatchOffset();
|
||||
}
|
||||
}();
|
||||
|
||||
assert(a_grid_desc_ak0_m_ak1.GetElementSpaceSize() >= splitk_offset.a_k_split_offset);
|
||||
assert(b_grid_desc_bpreshuffled.GetElementSpaceSize() >= splitk_offset.b_k_split_offset);
|
||||
assert(a_scale_grid_desc_am_ak.GetElementSpaceSize() >=
|
||||
splitk_offset.ascale_k_split_offset);
|
||||
assert(b_scale_grid_desc_bn_ak.GetElementSpaceSize() >=
|
||||
splitk_offset.bscale_k_split_offset);
|
||||
|
||||
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize() - splitk_offset.a_k_split_offset);
|
||||
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global, b_coherence_flag>(
|
||||
p_b_grid + static_cast<long_index_t>(expert_id) * expert_stride / BPackedSize,
|
||||
b_grid_desc_bpreshuffled.GetElementSpaceSize());
|
||||
b_grid_desc_bpreshuffled.GetElementSpaceSize() - splitk_offset.b_k_split_offset);
|
||||
|
||||
const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
|
||||
p_a_scale_grid,
|
||||
a_scale_grid_desc_am_ak.GetElementSpaceSize() - splitk_offset.ascale_k_split_offset);
|
||||
const auto b_scale_grid_buf =
|
||||
make_dynamic_buffer<AddressSpaceEnum::Global, b_coherence_flag>(
|
||||
p_b_scale_grid + static_cast<long_index_t>(expert_id) * expert_scale_stride,
|
||||
b_scale_grid_desc_bn_ak.GetElementSpaceSize());
|
||||
b_scale_grid_desc_bn_ak.GetElementSpaceSize() -
|
||||
splitk_offset.bscale_k_split_offset);
|
||||
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
constexpr auto a_block_desc_ak0_m_ak1 =
|
||||
|
||||
Reference in New Issue
Block a user