diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp index 82be6ac7ce..48ccb49db4 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp @@ -1235,9 +1235,9 @@ struct GridwiseMoeGemm } gather_offsets(m0) = static_cast(token_offset) * problem.K; }); - const index_t expert_stride = + const IndexType expert_stride = __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1)); - + const IndexType expert_offset = expert_id * expert_stride / BPackedSize; // N0, K0, Blocksize*KPack const index_t n_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave); @@ -1245,8 +1245,7 @@ struct GridwiseMoeGemm const auto a_grid_buf = make_dynamic_buffer( p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); const auto b_grid_buf = make_dynamic_buffer( - p_b_grid + expert_id * expert_stride / BPackedSize, - b_grid_desc_bpreshuffled.GetElementSpaceSize()); + p_b_grid + expert_offset, b_grid_desc_bpreshuffled.GetElementSpaceSize()); // A matrix in LDS memory, dst of blockwise copy constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); @@ -1335,8 +1334,7 @@ struct GridwiseMoeGemm { const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize; const auto b_grid_buf_up = make_dynamic_buffer( - p_b_grid_up + expert_id * expert_stride / BPackedSize, - b_grid_desc_bpreshuffled.GetElementSpaceSize()); + p_b_grid_up + expert_offset, b_grid_desc_bpreshuffled.GetElementSpaceSize()); auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2< BDataType, BDataType, @@ -1947,9 +1945,9 @@ struct GridwiseMoeGemm } gather_offsets(m0) = static_cast(token_offset) * problem.K; }); - const index_t expert_stride = + const IndexType expert_stride = __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1)); - + const IndexType expert_offset = expert_id * expert_stride / BPackedSize; // N0, K0, Blocksize*KPack const index_t n_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave); @@ -1957,8 +1955,7 @@ struct GridwiseMoeGemm const auto a_grid_buf = make_dynamic_buffer( p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); const auto b_grid_buf = make_dynamic_buffer( - p_b_grid + expert_id * expert_stride / BPackedSize, - b_grid_desc_bpreshuffled.GetElementSpaceSize()); + p_b_grid + expert_offset, b_grid_desc_bpreshuffled.GetElementSpaceSize()); // A matrix in LDS memory, dst of blockwise copy constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); @@ -2055,8 +2052,7 @@ struct GridwiseMoeGemm { const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize; const auto b_grid_buf_up = make_dynamic_buffer( - p_b_grid_up + expert_id * expert_stride / BPackedSize, - b_grid_desc_bpreshuffled.GetElementSpaceSize()); + p_b_grid_up + expert_offset, b_grid_desc_bpreshuffled.GetElementSpaceSize()); auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2< BDataType, BDataType,