From 820ac195a0f5630c555fe90f5347a5e25abffb3a Mon Sep 17 00:00:00 2001
From: lalala-sh
Date: Fri, 6 Mar 2026 10:00:01 +0800
Subject: [PATCH] [CK] use int64 for ptr offset (#5094)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Motivation

When the number of experts (E) is large (e.g., E=257 in DeepSeek-V3), the
`expert_id * expert_stride` calculation in MOE GEMM kernels overflows `int32`
(`index_t`), causing the weight matrix (B) pointer to wrap to an invalid
address and triggering a GPU memory access fault.

For example, with `N=1024, K=7168, IsInputGemm=true`:

- `expert_stride = N * K * 2 = 14,680,064`
- `INT32_MAX / expert_stride ≈ 146`
- Any `expert_id >= 147` causes overflow → negative offset → illegal memory access → GPU crash

(A standalone sketch reproducing the overflow is appended after the diff.)

## Technical Details

## Test Plan

## Test Result

## Submission Checklist

- [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.

Co-authored-by: Thomas Ning
Co-authored-by: amd-shiraz
---
 .../gpu/grid/gridwise_moe_gemm.hpp            | 14 ++++---
 .../gpu/grid/gridwise_moe_gemm_blockscale.hpp | 37 +++++++++---------
 .../gpu/grid/gridwise_moe_mx_gemm.hpp         | 35 +++++++++--------
 .../gpu/grid/gridwise_moe_mx_gemm_bns.hpp     | 35 +++++++++--------
 .../grid/gridwise_moe_mx_gemm_bpreshuffle.hpp | 38 +++++++++++--------
 5 files changed, 87 insertions(+), 72 deletions(-)

diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp
index 5e95d3c55b..90e6021dcb 100644
--- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp
+++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp
@@ -1175,9 +1175,10 @@ struct GridwiseMoeGemm : public GridwiseGemm_xdl_cshuffle_base<
             }
             gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
         });
-        const IndexType expert_stride =
-            __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
-        const IndexType expert_offset = expert_id * expert_stride / BPackedSize;
+        const long_index_t expert_stride = __builtin_amdgcn_readfirstlane(
+            static_cast<long_index_t>(problem.N) * problem.K * (IsInputGemm ? 2 : 1));
+        const long_index_t expert_offset =
+            static_cast<long_index_t>(expert_id) * expert_stride / BPackedSize;
         // N0, K0, Blocksize*KPack
         const index_t n_block_data_idx_on_grid =
             __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
@@ -1640,9 +1641,10 @@ struct GridwiseMoeGemm : public GridwiseGemm_xdl_cshuffle_base<
             }
             gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
         });
-        const IndexType expert_stride =
-            __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
-        const IndexType expert_offset = expert_id * expert_stride / BPackedSize;
+        const long_index_t expert_stride = __builtin_amdgcn_readfirstlane(
+            static_cast<long_index_t>(problem.N) * problem.K * (IsInputGemm ? 2 : 1));
+        const long_index_t expert_offset =
+            static_cast<long_index_t>(expert_id) * expert_stride / BPackedSize;
         // N0, K0, Blocksize*KPack
         const index_t n_block_data_idx_on_grid =
             __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);

diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp
index 6c7b77476a..673b6b2f21 100644
--- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp
+++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp
@@ -1224,11 +1224,11 @@ struct GridwiseMoeGemmBlockScale
             }
             gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
         });
-        const index_t expert_stride =
-            __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
-        const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
-            math::integer_divide_ceil(problem.N, ScaleBlockN) * (IsInputGemm ? 2 : 1) *
-            math::integer_divide_ceil(problem.K, ScaleBlockK));
+        const long_index_t expert_stride = __builtin_amdgcn_readfirstlane(
+            static_cast<long_index_t>(problem.N) * problem.K * (IsInputGemm ? 2 : 1));
+        const long_index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
+            static_cast<long_index_t>(math::integer_divide_ceil(problem.N, ScaleBlockN)) *
+            (IsInputGemm ? 2 : 1) * math::integer_divide_ceil(problem.K, ScaleBlockK));

         // N0, K0, Blocksize*KPack
         const index_t n_block_data_idx_on_grid =
@@ -1237,14 +1237,14 @@ struct GridwiseMoeGemmBlockScale
         const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
         const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_b_grid + expert_id * static_cast<long_index_t>(expert_stride) / BPackedSize,
+            p_b_grid + static_cast<long_index_t>(expert_id) * expert_stride / BPackedSize,
             b_grid_desc_bpreshuffled.GetElementSpaceSize());
         const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
         const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_b_scale_grid + expert_id * expert_scale_stride,
+            p_b_scale_grid + static_cast<long_index_t>(expert_id) * expert_scale_stride,
             b_scale_grid_desc_bn_ak.GetElementSpaceSize());

         // A matrix in LDS memory, dst of blockwise copy
@@ -1406,7 +1406,7 @@ struct GridwiseMoeGemmBlockScale
                 const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
                     p_b_grid_up +
-                        expert_id * static_cast<long_index_t>(expert_stride) / BPackedSize,
+                        static_cast<long_index_t>(expert_id) * expert_stride / BPackedSize,
                     b_grid_desc_bpreshuffled.GetElementSpaceSize());
                 auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2<
                     BDataType,
@@ -1427,7 +1427,7 @@ struct GridwiseMoeGemmBlockScale
                     p_b_scale_grid + expert_scale_stride / 2 / BPackedSize;
                 const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
-                    p_b_scale_grid_up + expert_id * expert_scale_stride,
+                    p_b_scale_grid_up + static_cast<long_index_t>(expert_id) * expert_scale_stride,
                     b_scale_grid_desc_bn_ak.GetElementSpaceSize());
                 auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2<
@@ ... @@ struct GridwiseMoeGemmBlockScale
             gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
         });
-        const index_t expert_stride =
-            __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
-        const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
-            math::integer_divide_ceil(problem.N, ScaleBlockN) * (IsInputGemm ? 2 : 1) *
-            math::integer_divide_ceil(problem.K, ScaleBlockK));
+        const long_index_t expert_stride = __builtin_amdgcn_readfirstlane(
+            static_cast<long_index_t>(problem.N) * problem.K * (IsInputGemm ? 2 : 1));
+        const long_index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
+            static_cast<long_index_t>(math::integer_divide_ceil(problem.N, ScaleBlockN)) *
+            (IsInputGemm ? 2 : 1) * math::integer_divide_ceil(problem.K, ScaleBlockK));

         // N0, K0, Blocksize*KPack
         const index_t n_block_data_idx_on_grid =
             __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
@@ -1745,14 +1745,14 @@ struct GridwiseMoeGemmBlockScale
         const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
         const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_b_grid + expert_id * static_cast<long_index_t>(expert_stride) / BPackedSize,
+            p_b_grid + static_cast<long_index_t>(expert_id) * expert_stride / BPackedSize,
             b_grid_desc_bpreshuffled.GetElementSpaceSize());
         const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
         const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_b_scale_grid + expert_id * expert_scale_stride,
+            p_b_scale_grid + static_cast<long_index_t>(expert_id) * expert_scale_stride,
             b_scale_grid_desc_bn_ak.GetElementSpaceSize());

         // A matrix in LDS memory, dst of blockwise copy
@@ -1922,7 +1922,7 @@ struct GridwiseMoeGemmBlockScale
                 const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
                     p_b_grid_up +
-                        expert_id * static_cast<long_index_t>(expert_stride) / BPackedSize,
+                        static_cast<long_index_t>(expert_id) * expert_stride / BPackedSize,
                     b_grid_desc_bpreshuffled.GetElementSpaceSize());
                 auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2<
                     BDataType,
@@ -1943,7 +1943,8 @@ struct GridwiseMoeGemmBlockScale
                     p_b_scale_grid + expert_scale_stride / 2 / BPackedSize;
                 const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
-                    p_b_scale_grid_up + expert_id * expert_scale_stride / BPackedSize,
+                    p_b_scale_grid_up +
+                        static_cast<long_index_t>(expert_id) * expert_scale_stride / BPackedSize,
                     b_scale_grid_desc_bn_ak.GetElementSpaceSize());
                 auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2<

diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp
--- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp
+++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp
@@ ... @@ struct GridwiseMoeGemmMX
             gather_offsets(m0) = static_cast<IndexType>(token_offset);
         });
-        const index_t expert_stride =
-            __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
-        const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
-            problem.N * (IsInputGemm ? 2 : 1) *
+        const long_index_t expert_stride =
+            __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(problem.N) * problem.K * (IsInputGemm ? 2 : 1));
+        const long_index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
+            static_cast<long_index_t>(problem.N) * (IsInputGemm ? 2 : 1) *
             math::integer_divide_ceil(problem.K, ScaleBlockSize / BPackedSize));

         // N0, K0, Blocksize*KPack
@@ -1356,13 +1356,13 @@ struct GridwiseMoeGemmMX
         const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
         const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_b_grid + expert_id * expert_stride, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
+            p_b_grid + static_cast<long_index_t>(expert_id) * expert_stride, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());

         // A, B scale buffer
         const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
         const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_b_scale_grid + (expert_id * expert_scale_stride) / sizeof(BScaleDataType),
+            p_b_scale_grid + (static_cast<long_index_t>(expert_id) * expert_scale_stride) / sizeof(BScaleDataType),
             b_scale_grid_desc_bn_ak.GetElementSpaceSize());

         // lds max alignment
@@ -1502,7 +1502,7 @@ struct GridwiseMoeGemmMX
                 const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2;
                 const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
-                    p_b_grid_up + expert_id * expert_stride,
+                    p_b_grid_up + static_cast<long_index_t>(expert_id) * expert_stride,
                     b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
                 auto b_blockwise_copy_up = ThreadGroupTensorSliceTransfer_DirectLoad<
@@ -1525,7 +1525,7 @@ struct GridwiseMoeGemmMX
                 const BScaleDataType* p_b_scale_grid_up =
                     p_b_scale_grid + expert_scale_stride / 2 / sizeof(BScaleDataType);
                 const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
-                    p_b_scale_grid_up + expert_id * expert_scale_stride / sizeof(BScaleDataType),
+                    p_b_scale_grid_up + static_cast<long_index_t>(expert_id) * expert_scale_stride / sizeof(BScaleDataType),
                     b_scale_grid_desc_bn_ak.GetElementSpaceSize());
                 auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2<
@@ -2105,10 +2105,10 @@ struct GridwiseMoeGemmMX
             gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
         });
-        const index_t expert_stride =
-            __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
-        const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
-            problem.N * (IsInputGemm ? 2 : 1) *
+        const long_index_t expert_stride = __builtin_amdgcn_readfirstlane(
+            static_cast<long_index_t>(problem.N) * problem.K * (IsInputGemm ? 2 : 1));
+        const long_index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
+            static_cast<long_index_t>(problem.N) * (IsInputGemm ? 2 : 1) *
             math::integer_divide_ceil(problem.K, ScaleBlockSize / BPackedSize));

         // N0, K0, Blocksize*KPack
@@ -2119,13 +2119,15 @@ struct GridwiseMoeGemmMX
         const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
         const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_b_grid + expert_id * expert_stride, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
+            p_b_grid + static_cast<long_index_t>(expert_id) * expert_stride,
+            b_grid_desc_bk0_n_bk1.GetElementSpaceSize());

         // A, B scale buffer
         const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
         const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_b_scale_grid + (expert_id * expert_scale_stride) / sizeof(BScaleDataType),
+            p_b_scale_grid + (static_cast<long_index_t>(expert_id) * expert_scale_stride) /
+                sizeof(BScaleDataType),
             b_scale_grid_desc_bn_ak.GetElementSpaceSize());

         // lds max alignment
@@ -2274,7 +2276,7 @@ struct GridwiseMoeGemmMX
             {
                 const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2;
                 const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
-                    p_b_grid_up + expert_id * expert_stride,
+                    p_b_grid_up + static_cast<long_index_t>(expert_id) * expert_stride,
                     b_grid_desc_bk0_n_bk1.GetElementSpaceSize());

                 // lds ping pong buffers for up
@@ -2313,7 +2315,8 @@ struct GridwiseMoeGemmMX
                 const BScaleDataType* p_b_scale_grid_up =
                     p_b_scale_grid + expert_scale_stride / 2 / sizeof(BScaleDataType);
                 const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
-                    p_b_scale_grid_up + expert_id * expert_scale_stride / sizeof(BScaleDataType),
+                    p_b_scale_grid_up + static_cast<long_index_t>(expert_id) * expert_scale_stride /
+                        sizeof(BScaleDataType),
                     b_scale_grid_desc_bn_ak.GetElementSpaceSize());
                 auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2<

diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bns.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bns.hpp
index 8cf03e3f5c..8559b78fe0 100644
--- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bns.hpp
+++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bns.hpp
@@ -1235,10 +1235,10 @@ struct GridwiseMoeGemmMXBNS
             gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
         });
-        const index_t expert_stride =
-            __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
-        const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
-            problem.N * (IsInputGemm ? 2 : 1) *
+        const long_index_t expert_stride = __builtin_amdgcn_readfirstlane(
+            static_cast<long_index_t>(problem.N) * problem.K * (IsInputGemm ? 2 : 1));
+        const long_index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
+            static_cast<long_index_t>(problem.N) * (IsInputGemm ? 2 : 1) *
             math::integer_divide_ceil(problem.K, ScaleBlockSize / BPackedSize));

         // N0, K0, Blocksize*KPack
@@ -1249,13 +1249,15 @@ struct GridwiseMoeGemmMXBNS
         const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
         const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_b_grid + expert_id * expert_stride, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
+            p_b_grid + static_cast<long_index_t>(expert_id) * expert_stride,
+            b_grid_desc_bk0_n_bk1.GetElementSpaceSize());

         // A, B scale buffer
         const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
         const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_b_scale_grid + (expert_id * expert_scale_stride) / sizeof(BScaleDataType),
+            p_b_scale_grid + (static_cast<long_index_t>(expert_id) * expert_scale_stride) /
+                sizeof(BScaleDataType),
             b_scale_grid_desc_bn_ak.GetElementSpaceSize());

         // lds max alignment
@@ -1421,7 +1423,7 @@ struct GridwiseMoeGemmMXBNS
                 const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2;
                 const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
-                    p_b_grid_up + expert_id * expert_stride,
+                    p_b_grid_up + static_cast<long_index_t>(expert_id) * expert_stride,
                     b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
                 auto b_blockwise_copy_up =
@@ -1457,7 +1459,8 @@ struct GridwiseMoeGemmMXBNS
                 const BScaleDataType* p_b_scale_grid_up =
                     p_b_scale_grid + expert_scale_stride / 2 / sizeof(BScaleDataType);
                 const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
-                    p_b_scale_grid_up + expert_id * expert_scale_stride / sizeof(BScaleDataType),
+                    p_b_scale_grid_up + static_cast<long_index_t>(expert_id) * expert_scale_stride /
+                        sizeof(BScaleDataType),
                     b_scale_grid_desc_bn_ak.GetElementSpaceSize());
                 auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2<
@@ -1763,10 +1766,10 @@ struct GridwiseMoeGemmMXBNS
             gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
         });
-        const index_t expert_stride =
-            __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
-        const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
-            problem.N * math::integer_divide_ceil(problem.K, ScaleBlockSize / BPackedSize));
+        const long_index_t expert_stride =
+            __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(problem.N) * problem.K * (IsInputGemm ? 2 : 1));
+        const long_index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
+            static_cast<long_index_t>(problem.N) * math::integer_divide_ceil(problem.K, ScaleBlockSize / BPackedSize));

         // N0, K0, Blocksize*KPack
         const index_t n_block_data_idx_on_grid =
@@ -1776,12 +1779,12 @@ struct GridwiseMoeGemmMXBNS
             p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());

         const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_b_grid + expert_id * expert_stride, b_grid_desc_bpreshuffled.GetElementSpaceSize());
+            p_b_grid + static_cast<long_index_t>(expert_id) * expert_stride, b_grid_desc_bpreshuffled.GetElementSpaceSize());
         const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
         const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_b_scale_grid + (expert_id * expert_scale_stride) / sizeof(BScaleDataType),
+            p_b_scale_grid + (static_cast<long_index_t>(expert_id) * expert_scale_stride) / sizeof(BScaleDataType),
             b_scale_grid_desc_bn_ak.GetElementSpaceSize());

         // A matrix in LDS memory, dst of blockwise copy
@@ -1933,7 +1936,7 @@ struct GridwiseMoeGemmMXBNS
             {
                 const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize;
                 const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
-                    p_b_grid_up + expert_id * expert_stride / BPackedSize,
+                    p_b_grid_up + static_cast<long_index_t>(expert_id) * expert_stride / BPackedSize,
                     b_grid_desc_bpreshuffled.GetElementSpaceSize());
                 auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2<
                     BDataType,
@@ -1952,7 +1955,7 @@ struct GridwiseMoeGemmMXBNS
                     KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
                 const BScaleDataType* p_b_scale_grid_up = p_b_scale_grid + expert_scale_stride / 2;
                 const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
-                    p_b_scale_grid_up + expert_id * expert_scale_stride,
+                    p_b_scale_grid_up + static_cast<long_index_t>(expert_id) * expert_scale_stride,
                     b_scale_grid_desc_bn_ak.GetElementSpaceSize());
                 auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2<
                     BScaleDataType,

diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bpreshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bpreshuffle.hpp
index c8a917a115..3254c5d043 100644
--- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bpreshuffle.hpp
+++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bpreshuffle.hpp
@@ -1329,10 +1329,10 @@ struct GridwiseMoeGemmMX_BPreshuffle
             gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
         });
-        const index_t expert_stride =
-            __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
-        const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
-            problem.N * (IsInputGemm ? 2 : 1) *
+        const long_index_t expert_stride = __builtin_amdgcn_readfirstlane(
+            static_cast<long_index_t>(problem.N) * problem.K * (IsInputGemm ? 2 : 1));
+        const long_index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
+            static_cast<long_index_t>(problem.N) * (IsInputGemm ? 2 : 1) *
             math::integer_divide_ceil(problem.K, ScaleBlockSize / BPackedSize));

         // N0, K0, Blocksize*KPack
@@ -1343,13 +1343,15 @@ struct GridwiseMoeGemmMX_BPreshuffle
         const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
         const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_b_grid + expert_id * expert_stride, b_grid_desc_bpreshuffled.GetElementSpaceSize());
+            p_b_grid + static_cast<long_index_t>(expert_id) * expert_stride,
+            b_grid_desc_bpreshuffled.GetElementSpaceSize());

         // A, B scale buffer
         const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
         const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_b_scale_grid + (expert_id * expert_scale_stride) / sizeof(BScaleDataType),
+            p_b_scale_grid + (static_cast<long_index_t>(expert_id) * expert_scale_stride) /
+                sizeof(BScaleDataType),
             b_scale_grid_desc_bn_ak.GetElementSpaceSize());

         // A matrix in LDS memory, dst of blockwise copy
@@ -1487,7 +1489,7 @@ struct GridwiseMoeGemmMX_BPreshuffle
             {
                 const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2;
                 const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
-                    p_b_grid_up + expert_id * expert_stride,
+                    p_b_grid_up + static_cast<long_index_t>(expert_id) * expert_stride,
                     b_grid_desc_bpreshuffled.GetElementSpaceSize());
                 auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2<
@@ ... @@ struct GridwiseMoeGemmMX_BPreshuffle
                 const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
-                    p_b_scale_grid_up + expert_id * expert_scale_stride / sizeof(BScaleDataType),
+                    p_b_scale_grid_up + static_cast<long_index_t>(expert_id) * expert_scale_stride /
+                        sizeof(BScaleDataType),
                     b_scale_grid_desc_bn_ak.GetElementSpaceSize());
                 auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2<
@@ -1819,10 +1822,10 @@ struct GridwiseMoeGemmMX_BPreshuffle
             gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
         });
-        const index_t expert_stride =
-            __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
-        const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
-            problem.N * (IsInputGemm ? 2 : 1) *
+        const long_index_t expert_stride = __builtin_amdgcn_readfirstlane(
+            static_cast<long_index_t>(problem.N) * problem.K * (IsInputGemm ? 2 : 1));
+        const long_index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
+            static_cast<long_index_t>(problem.N) * (IsInputGemm ? 2 : 1) *
             math::integer_divide_ceil(problem.K, ScaleBlockSize / BPackedSize));

         // N0, K0, Blocksize*KPack
@@ -1833,13 +1836,15 @@ struct GridwiseMoeGemmMX_BPreshuffle
         const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
         const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_b_grid + expert_id * expert_stride, b_grid_desc_bpreshuffled.GetElementSpaceSize());
+            p_b_grid + static_cast<long_index_t>(expert_id) * expert_stride,
+            b_grid_desc_bpreshuffled.GetElementSpaceSize());

         // A, B scale buffer
         const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
         const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_b_scale_grid + (expert_id * expert_scale_stride) / sizeof(BScaleDataType),
+            p_b_scale_grid + (static_cast<long_index_t>(expert_id) * expert_scale_stride) /
+                sizeof(BScaleDataType),
             b_scale_grid_desc_bn_ak.GetElementSpaceSize());

         // A matrix in LDS memory, dst of blockwise copy
@@ -1980,7 +1985,7 @@ struct GridwiseMoeGemmMX_BPreshuffle
             {
                 const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2;
                 const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
-                    p_b_grid_up + expert_id * expert_stride,
+                    p_b_grid_up + static_cast<long_index_t>(expert_id) * expert_stride,
                     b_grid_desc_bpreshuffled.GetElementSpaceSize());
                 auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2<
@@ ... @@ struct GridwiseMoeGemmMX_BPreshuffle
                 const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
-                    p_b_scale_grid_up + expert_id * expert_scale_stride / sizeof(BScaleDataType),
+                    p_b_scale_grid_up + static_cast<long_index_t>(expert_id) * expert_scale_stride /
+                        sizeof(BScaleDataType),
                     b_scale_grid_desc_bn_ak.GetElementSpaceSize());
                 auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2<
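
The sketch referenced in the motivation: a minimal standalone C++ program (not CK code; the `N`, `K`, and `expert_id` values come from the motivation above) that reproduces the 32-bit wraparound and demonstrates the fix pattern this patch applies everywhere, i.e. widening one operand to a 64-bit type before the multiply.

```cpp
// Standalone repro sketch -- not part of the patch. Emulates the overflow
// described in the motivation and the fix pattern used in the hunks above
// (static_cast<long_index_t>(expert_id) * expert_stride).
#include <cstdint>
#include <cstdio>

int main()
{
    const std::int32_t N = 1024, K = 7168;        // shape from the motivation
    const std::int32_t expert_stride = N * K * 2; // 14,680,064 -- fits in int32
    const std::int32_t expert_id     = 147;       // first overflowing expert

    // Buggy 32-bit offset. The multiply goes through uint32_t to emulate the
    // wraparound deterministically (signed overflow itself is UB in C++);
    // on the GPU it shows up as a negative offset -> illegal memory access.
    const std::int32_t bad_offset = static_cast<std::int32_t>(
        static_cast<std::uint32_t>(expert_id) * static_cast<std::uint32_t>(expert_stride));

    // Fixed offset: widen one operand first so the multiply happens in 64 bits.
    const std::int64_t good_offset =
        static_cast<std::int64_t>(expert_id) * expert_stride;

    std::printf("int32 offset: %d\n", static_cast<int>(bad_offset));          // wrapped, negative
    std::printf("int64 offset: %lld\n", static_cast<long long>(good_offset)); // 2157969408
    return 0;
}
```

Note that the cast must be applied to an operand, not to the product: `static_cast<long_index_t>(expert_id * expert_stride)` would still multiply in 32 bits and overflow before widening, which is why every hunk above casts `expert_id` (or `problem.N`) itself.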