[rocm-libraries] ROCm/rocm-libraries#5094 (commit d4548e6)

[CK] use int64 for ptr offset
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Motivation

When the number of experts (E) is large (e.g., E=257 in DeepSeek-V3),
the `expert_id * expert_stride` calculation in MOE GEMM kernels
overflows `int32` (`index_t`), causing the weight matrix (B) pointer to
wrap to an invalid address and triggering a GPU memory access fault.

For example, with `N=1024, K=7168, IsInputGemm=true`:
- `expert_stride = N * K * 2 = 14,680,064`
- `INT32_MAX / expert_stride ≈ 146`
- Any `expert_id >= 147` causes overflow → negative offset → illegal
memory access → GPU crash

## Technical Details

The stride variables `expert_stride` / `expert_scale_stride` (and the derived
`expert_offset`) are widened from 32-bit `index_t`/`IndexType` to 64-bit
`long_index_t` in five gridwise MOE GEMM kernels: `GridwiseMoeGemm`,
`GridwiseMoeGemmBlockScale`, `GridwiseMoeGemmMX`, `GridwiseMoeGemmMXBNS`, and
`GridwiseMoeGemmMX_BPreshuffle`. In every B-matrix and B-scale pointer offset,
`expert_id` is explicitly cast to `long_index_t` before the multiplication, so
the whole offset expression is evaluated in 64-bit arithmetic instead of
overflowing signed 32-bit.

NOTE(review): `__builtin_amdgcn_readfirstlane` is now invoked with
`long_index_t` (64-bit) operands — confirm the target toolchain's builtin is
type-generic over 64-bit integers (recent LLVM); otherwise the uniform
broadcast would truncate the stride back to 32 bits and reintroduce the
overflow this PR fixes.

## Test Plan

Run the affected MOE GEMM kernels with an expert count large enough that
`expert_id * expert_stride` exceeds `INT32_MAX` — e.g. the DeepSeek-V3
configuration from the Motivation (`N=1024, K=7168, IsInputGemm=true, E=257`,
where any `expert_id >= 147` previously overflowed) — and verify that the
kernels complete without a GPU memory access fault and that results match a
reference implementation. Also re-run existing small-E GEMM tests to confirm
no regression from the wider index arithmetic.

## Test Result

<!-- Briefly summarize test outcomes. -->

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.

Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
Co-authored-by: amd-shiraz <shiraz.ali@amd.com>
This commit is contained in:
lalala-sh
2026-03-06 02:01:03 +00:00
committed by assistant-librarian[bot]
parent 7ee6c44387
commit b0c13f3124
5 changed files with 87 additions and 72 deletions

View File

@@ -1175,9 +1175,10 @@ struct GridwiseMoeGemm : public GridwiseGemm_xdl_cshuffle_base<
}
gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
});
const IndexType expert_stride =
__builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
const IndexType expert_offset = expert_id * expert_stride / BPackedSize;
const long_index_t expert_stride = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(problem.N) * problem.K * (IsInputGemm ? 2 : 1));
const long_index_t expert_offset =
static_cast<long_index_t>(expert_id) * expert_stride / BPackedSize;
// N0, K0, Blocksize*KPack
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
@@ -1640,9 +1641,10 @@ struct GridwiseMoeGemm : public GridwiseGemm_xdl_cshuffle_base<
}
gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
});
const IndexType expert_stride =
__builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
const IndexType expert_offset = expert_id * expert_stride / BPackedSize;
const long_index_t expert_stride = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(problem.N) * problem.K * (IsInputGemm ? 2 : 1));
const long_index_t expert_offset =
static_cast<long_index_t>(expert_id) * expert_stride / BPackedSize;
// N0, K0, Blocksize*KPack
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);

View File

@@ -1224,11 +1224,11 @@ struct GridwiseMoeGemmBlockScale
}
gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
});
const index_t expert_stride =
__builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
math::integer_divide_ceil(problem.N, ScaleBlockN) * (IsInputGemm ? 2 : 1) *
math::integer_divide_ceil(problem.K, ScaleBlockK));
const long_index_t expert_stride = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(problem.N) * problem.K * (IsInputGemm ? 2 : 1));
const long_index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(math::integer_divide_ceil(problem.N, ScaleBlockN)) *
(IsInputGemm ? 2 : 1) * math::integer_divide_ceil(problem.K, ScaleBlockK));
// N0, K0, Blocksize*KPack
const index_t n_block_data_idx_on_grid =
@@ -1237,14 +1237,14 @@ struct GridwiseMoeGemmBlockScale
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global, b_coherence_flag>(
p_b_grid + expert_id * static_cast<long_index_t>(expert_stride) / BPackedSize,
p_b_grid + static_cast<long_index_t>(expert_id) * expert_stride / BPackedSize,
b_grid_desc_bpreshuffled.GetElementSpaceSize());
const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
const auto b_scale_grid_buf =
make_dynamic_buffer<AddressSpaceEnum::Global, b_coherence_flag>(
p_b_scale_grid + expert_id * expert_scale_stride,
p_b_scale_grid + static_cast<long_index_t>(expert_id) * expert_scale_stride,
b_scale_grid_desc_bn_ak.GetElementSpaceSize());
// A matrix in LDS memory, dst of blockwise copy
@@ -1406,7 +1406,7 @@ struct GridwiseMoeGemmBlockScale
const auto b_grid_buf_up =
make_dynamic_buffer<AddressSpaceEnum::Global, b_coherence_flag>(
p_b_grid_up +
expert_id * static_cast<long_index_t>(expert_stride) / BPackedSize,
static_cast<long_index_t>(expert_id) * expert_stride / BPackedSize,
b_grid_desc_bpreshuffled.GetElementSpaceSize());
auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2<
BDataType,
@@ -1427,7 +1427,7 @@ struct GridwiseMoeGemmBlockScale
p_b_scale_grid + expert_scale_stride / 2 / BPackedSize;
const auto b_scale_grid_buf_up =
make_dynamic_buffer<AddressSpaceEnum::Global, b_coherence_flag>(
p_b_scale_grid_up + expert_id * expert_scale_stride,
p_b_scale_grid_up + static_cast<long_index_t>(expert_id) * expert_scale_stride,
b_scale_grid_desc_bn_ak.GetElementSpaceSize());
auto b_scale_thread_copy_up =
ThreadwiseTensorSliceTransfer_v2<BScaleType,
@@ -1733,11 +1733,11 @@ struct GridwiseMoeGemmBlockScale
}
gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
});
const index_t expert_stride =
__builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
math::integer_divide_ceil(problem.N, ScaleBlockN) * (IsInputGemm ? 2 : 1) *
math::integer_divide_ceil(problem.K, ScaleBlockK));
const long_index_t expert_stride = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(problem.N) * problem.K * (IsInputGemm ? 2 : 1));
const long_index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(math::integer_divide_ceil(problem.N, ScaleBlockN)) *
(IsInputGemm ? 2 : 1) * math::integer_divide_ceil(problem.K, ScaleBlockK));
// N0, K0, Blocksize*KPack
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
@@ -1745,14 +1745,14 @@ struct GridwiseMoeGemmBlockScale
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global, b_coherence_flag>(
p_b_grid + expert_id * static_cast<long_index_t>(expert_stride) / BPackedSize,
p_b_grid + static_cast<long_index_t>(expert_id) * expert_stride / BPackedSize,
b_grid_desc_bpreshuffled.GetElementSpaceSize());
const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
const auto b_scale_grid_buf =
make_dynamic_buffer<AddressSpaceEnum::Global, b_coherence_flag>(
p_b_scale_grid + expert_id * expert_scale_stride,
p_b_scale_grid + static_cast<long_index_t>(expert_id) * expert_scale_stride,
b_scale_grid_desc_bn_ak.GetElementSpaceSize());
// A matrix in LDS memory, dst of blockwise copy
@@ -1922,7 +1922,7 @@ struct GridwiseMoeGemmBlockScale
const auto b_grid_buf_up =
make_dynamic_buffer<AddressSpaceEnum::Global, b_coherence_flag>(
p_b_grid_up +
expert_id * static_cast<long_index_t>(expert_stride) / BPackedSize,
static_cast<long_index_t>(expert_id) * expert_stride / BPackedSize,
b_grid_desc_bpreshuffled.GetElementSpaceSize());
auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2<
BDataType,
@@ -1943,7 +1943,8 @@ struct GridwiseMoeGemmBlockScale
p_b_scale_grid + expert_scale_stride / 2 / BPackedSize;
const auto b_scale_grid_buf_up =
make_dynamic_buffer<AddressSpaceEnum::Global, b_coherence_flag>(
p_b_scale_grid_up + expert_id * expert_scale_stride / BPackedSize,
p_b_scale_grid_up +
static_cast<long_index_t>(expert_id) * expert_scale_stride / BPackedSize,
b_scale_grid_desc_bn_ak.GetElementSpaceSize());
auto b_scale_thread_copy_up =
ThreadwiseTensorSliceTransfer_v2<BScaleType,

View File

@@ -1342,10 +1342,10 @@ struct GridwiseMoeGemmMX
gather_offsets(m0) = static_cast<IndexType>(token_offset);
});
const index_t expert_stride =
__builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
problem.N * (IsInputGemm ? 2 : 1) *
const long_index_t expert_stride =
__builtin_amdgcn_readfirstlane(static_cast<long_index_t>(problem.N) * problem.K * (IsInputGemm ? 2 : 1));
const long_index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(problem.N) * (IsInputGemm ? 2 : 1) *
math::integer_divide_ceil(problem.K, ScaleBlockSize / BPackedSize));
// N0, K0, Blocksize*KPack
@@ -1356,13 +1356,13 @@ struct GridwiseMoeGemmMX
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid + expert_id * expert_stride, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
p_b_grid + static_cast<long_index_t>(expert_id) * expert_stride, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
// A, B scale buffer
const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_scale_grid + (expert_id * expert_scale_stride) / sizeof(BScaleDataType),
p_b_scale_grid + (static_cast<long_index_t>(expert_id) * expert_scale_stride) / sizeof(BScaleDataType),
b_scale_grid_desc_bn_ak.GetElementSpaceSize());
// lds max alignment
@@ -1502,7 +1502,7 @@ struct GridwiseMoeGemmMX
const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2;
const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid_up + expert_id * expert_stride,
p_b_grid_up + static_cast<long_index_t>(expert_id) * expert_stride,
b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
auto b_blockwise_copy_up = ThreadGroupTensorSliceTransfer_DirectLoad<
@@ -1525,7 +1525,7 @@ struct GridwiseMoeGemmMX
const BScaleDataType* p_b_scale_grid_up =
p_b_scale_grid + expert_scale_stride / 2 / sizeof(BScaleDataType);
const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_scale_grid_up + expert_id * expert_scale_stride / sizeof(BScaleDataType),
p_b_scale_grid_up + static_cast<long_index_t>(expert_id) * expert_scale_stride / sizeof(BScaleDataType),
b_scale_grid_desc_bn_ak.GetElementSpaceSize());
auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2<
@@ -2105,10 +2105,10 @@ struct GridwiseMoeGemmMX
gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
});
const index_t expert_stride =
__builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
problem.N * (IsInputGemm ? 2 : 1) *
const long_index_t expert_stride = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(problem.N) * problem.K * (IsInputGemm ? 2 : 1));
const long_index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(problem.N) * (IsInputGemm ? 2 : 1) *
math::integer_divide_ceil(problem.K, ScaleBlockSize / BPackedSize));
// N0, K0, Blocksize*KPack
@@ -2119,13 +2119,15 @@ struct GridwiseMoeGemmMX
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid + expert_id * expert_stride, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
p_b_grid + static_cast<long_index_t>(expert_id) * expert_stride,
b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
// A, B scale buffer
const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_scale_grid + (expert_id * expert_scale_stride) / sizeof(BScaleDataType),
p_b_scale_grid + (static_cast<long_index_t>(expert_id) * expert_scale_stride) /
sizeof(BScaleDataType),
b_scale_grid_desc_bn_ak.GetElementSpaceSize());
// lds max alignment
@@ -2274,7 +2276,7 @@ struct GridwiseMoeGemmMX
{
const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2;
const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid_up + expert_id * expert_stride,
p_b_grid_up + static_cast<long_index_t>(expert_id) * expert_stride,
b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
// lds ping pong buffers for up
@@ -2313,7 +2315,8 @@ struct GridwiseMoeGemmMX
const BScaleDataType* p_b_scale_grid_up =
p_b_scale_grid + expert_scale_stride / 2 / sizeof(BScaleDataType);
const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_scale_grid_up + expert_id * expert_scale_stride / sizeof(BScaleDataType),
p_b_scale_grid_up + static_cast<long_index_t>(expert_id) * expert_scale_stride /
sizeof(BScaleDataType),
b_scale_grid_desc_bn_ak.GetElementSpaceSize());
auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2<

View File

@@ -1235,10 +1235,10 @@ struct GridwiseMoeGemmMXBNS
gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
});
const index_t expert_stride =
__builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
problem.N * (IsInputGemm ? 2 : 1) *
const long_index_t expert_stride = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(problem.N) * problem.K * (IsInputGemm ? 2 : 1));
const long_index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(problem.N) * (IsInputGemm ? 2 : 1) *
math::integer_divide_ceil(problem.K, ScaleBlockSize / BPackedSize));
// N0, K0, Blocksize*KPack
@@ -1249,13 +1249,15 @@ struct GridwiseMoeGemmMXBNS
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid + expert_id * expert_stride, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
p_b_grid + static_cast<long_index_t>(expert_id) * expert_stride,
b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
// A, B scale buffer
const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_scale_grid + (expert_id * expert_scale_stride) / sizeof(BScaleDataType),
p_b_scale_grid + (static_cast<long_index_t>(expert_id) * expert_scale_stride) /
sizeof(BScaleDataType),
b_scale_grid_desc_bn_ak.GetElementSpaceSize());
// lds max alignment
@@ -1421,7 +1423,7 @@ struct GridwiseMoeGemmMXBNS
const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2;
const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid_up + expert_id * expert_stride,
p_b_grid_up + static_cast<long_index_t>(expert_id) * expert_stride,
b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
auto b_blockwise_copy_up =
@@ -1457,7 +1459,8 @@ struct GridwiseMoeGemmMXBNS
const BScaleDataType* p_b_scale_grid_up =
p_b_scale_grid + expert_scale_stride / 2 / sizeof(BScaleDataType);
const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_scale_grid_up + expert_id * expert_scale_stride / sizeof(BScaleDataType),
p_b_scale_grid_up + static_cast<long_index_t>(expert_id) * expert_scale_stride /
sizeof(BScaleDataType),
b_scale_grid_desc_bn_ak.GetElementSpaceSize());
auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2<
@@ -1763,10 +1766,10 @@ struct GridwiseMoeGemmMXBNS
gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
});
const index_t expert_stride =
__builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
problem.N * math::integer_divide_ceil(problem.K, ScaleBlockSize / BPackedSize));
const long_index_t expert_stride =
__builtin_amdgcn_readfirstlane(static_cast<long_index_t>(problem.N) * problem.K * (IsInputGemm ? 2 : 1));
const long_index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(problem.N) * math::integer_divide_ceil(problem.K, ScaleBlockSize / BPackedSize));
// N0, K0, Blocksize*KPack
const index_t n_block_data_idx_on_grid =
@@ -1776,12 +1779,12 @@ struct GridwiseMoeGemmMXBNS
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid + expert_id * expert_stride, b_grid_desc_bpreshuffled.GetElementSpaceSize());
p_b_grid + static_cast<long_index_t>(expert_id) * expert_stride, b_grid_desc_bpreshuffled.GetElementSpaceSize());
const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_scale_grid + (expert_id * expert_scale_stride) / sizeof(BScaleDataType),
p_b_scale_grid + (static_cast<long_index_t>(expert_id) * expert_scale_stride) / sizeof(BScaleDataType),
b_scale_grid_desc_bn_ak.GetElementSpaceSize());
// A matrix in LDS memory, dst of blockwise copy
@@ -1933,7 +1936,7 @@ struct GridwiseMoeGemmMXBNS
{
const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize;
const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid_up + expert_id * expert_stride / BPackedSize,
p_b_grid_up + static_cast<long_index_t>(expert_id) * expert_stride / BPackedSize,
b_grid_desc_bpreshuffled.GetElementSpaceSize());
auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2<
BDataType,
@@ -1952,7 +1955,7 @@ struct GridwiseMoeGemmMXBNS
KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
const BScaleDataType* p_b_scale_grid_up = p_b_scale_grid + expert_scale_stride / 2;
const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_scale_grid_up + expert_id * expert_scale_stride,
p_b_scale_grid_up + static_cast<long_index_t>(expert_id) * expert_scale_stride,
b_scale_grid_desc_bn_ak.GetElementSpaceSize());
auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2<
BScaleDataType,

View File

@@ -1329,10 +1329,10 @@ struct GridwiseMoeGemmMX_BPreshuffle
gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
});
const index_t expert_stride =
__builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
problem.N * (IsInputGemm ? 2 : 1) *
const long_index_t expert_stride = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(problem.N) * problem.K * (IsInputGemm ? 2 : 1));
const long_index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(problem.N) * (IsInputGemm ? 2 : 1) *
math::integer_divide_ceil(problem.K, ScaleBlockSize / BPackedSize));
// N0, K0, Blocksize*KPack
@@ -1343,13 +1343,15 @@ struct GridwiseMoeGemmMX_BPreshuffle
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid + expert_id * expert_stride, b_grid_desc_bpreshuffled.GetElementSpaceSize());
p_b_grid + static_cast<long_index_t>(expert_id) * expert_stride,
b_grid_desc_bpreshuffled.GetElementSpaceSize());
// A, B scale buffer
const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_scale_grid + (expert_id * expert_scale_stride) / sizeof(BScaleDataType),
p_b_scale_grid + (static_cast<long_index_t>(expert_id) * expert_scale_stride) /
sizeof(BScaleDataType),
b_scale_grid_desc_bn_ak.GetElementSpaceSize());
// A matrix in LDS memory, dst of blockwise copy
@@ -1487,7 +1489,7 @@ struct GridwiseMoeGemmMX_BPreshuffle
{
const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2;
const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid_up + expert_id * expert_stride,
p_b_grid_up + static_cast<long_index_t>(expert_id) * expert_stride,
b_grid_desc_bpreshuffled.GetElementSpaceSize());
auto b_blockwise_copy_up =
ThreadwiseTensorSliceTransfer_v2<BDataType,
@@ -1513,7 +1515,8 @@ struct GridwiseMoeGemmMX_BPreshuffle
const BScaleDataType* p_b_scale_grid_up =
p_b_scale_grid + expert_scale_stride / 2 / sizeof(BScaleDataType);
const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_scale_grid_up + expert_id * expert_scale_stride / sizeof(BScaleDataType),
p_b_scale_grid_up + static_cast<long_index_t>(expert_id) * expert_scale_stride /
sizeof(BScaleDataType),
b_scale_grid_desc_bn_ak.GetElementSpaceSize());
auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2<
@@ -1819,10 +1822,10 @@ struct GridwiseMoeGemmMX_BPreshuffle
gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
});
const index_t expert_stride =
__builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
problem.N * (IsInputGemm ? 2 : 1) *
const long_index_t expert_stride = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(problem.N) * problem.K * (IsInputGemm ? 2 : 1));
const long_index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(problem.N) * (IsInputGemm ? 2 : 1) *
math::integer_divide_ceil(problem.K, ScaleBlockSize / BPackedSize));
// N0, K0, Blocksize*KPack
@@ -1833,13 +1836,15 @@ struct GridwiseMoeGemmMX_BPreshuffle
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid + expert_id * expert_stride, b_grid_desc_bpreshuffled.GetElementSpaceSize());
p_b_grid + static_cast<long_index_t>(expert_id) * expert_stride,
b_grid_desc_bpreshuffled.GetElementSpaceSize());
// A, B scale buffer
const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_scale_grid + (expert_id * expert_scale_stride) / sizeof(BScaleDataType),
p_b_scale_grid + (static_cast<long_index_t>(expert_id) * expert_scale_stride) /
sizeof(BScaleDataType),
b_scale_grid_desc_bn_ak.GetElementSpaceSize());
// A matrix in LDS memory, dst of blockwise copy
@@ -1980,7 +1985,7 @@ struct GridwiseMoeGemmMX_BPreshuffle
{
const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2;
const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid_up + expert_id * expert_stride,
p_b_grid_up + static_cast<long_index_t>(expert_id) * expert_stride,
b_grid_desc_bpreshuffled.GetElementSpaceSize());
auto b_blockwise_copy_up =
ThreadwiseTensorSliceTransfer_v2<BDataType,
@@ -2006,7 +2011,8 @@ struct GridwiseMoeGemmMX_BPreshuffle
const BScaleDataType* p_b_scale_grid_up =
p_b_scale_grid + expert_scale_stride / 2 / sizeof(BScaleDataType);
const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_scale_grid_up + expert_id * expert_scale_stride / sizeof(BScaleDataType),
p_b_scale_grid_up + static_cast<long_index_t>(expert_id) * expert_scale_stride /
sizeof(BScaleDataType),
b_scale_grid_desc_bn_ak.GetElementSpaceSize());
auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2<