mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[rocm-libraries] ROCm/rocm-libraries#5094 (commit d4548e6)
[CK] use int64 for ptr offset

## Motivation

When the number of experts (E) is large (e.g., E=257 in DeepSeek-V3), the `expert_id * expert_stride` calculation in MOE GEMM kernels overflows `int32` (`index_t`), causing the weight matrix (B) pointer to wrap to an invalid address and triggering a GPU memory access fault.

For example, with `N=1024, K=7168, IsInputGemm=true`:
- `expert_stride = N * K * 2 = 14,680,064`
- `INT32_MAX / expert_stride ≈ 146`
- Any `expert_id >= 147` causes overflow → negative offset → illegal memory access → GPU crash

## Technical Details
<!-- Explain the changes along with any relevant GitHub links. -->

## Test Plan
<!-- Explain any relevant testing done to verify this PR. -->

## Test Result
<!-- Briefly summarize test outcomes. -->

## Submission Checklist
- [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.

Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
Co-authored-by: amd-shiraz <shiraz.ali@amd.com>
This commit is contained in:
committed by
assistant-librarian[bot]
parent
7ee6c44387
commit
b0c13f3124
@@ -1175,9 +1175,10 @@ struct GridwiseMoeGemm : public GridwiseGemm_xdl_cshuffle_base<
|
||||
}
|
||||
gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
|
||||
});
|
||||
const IndexType expert_stride =
|
||||
__builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
|
||||
const IndexType expert_offset = expert_id * expert_stride / BPackedSize;
|
||||
const long_index_t expert_stride = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(problem.N) * problem.K * (IsInputGemm ? 2 : 1));
|
||||
const long_index_t expert_offset =
|
||||
static_cast<long_index_t>(expert_id) * expert_stride / BPackedSize;
|
||||
// N0, K0, Blocksize*KPack
|
||||
const index_t n_block_data_idx_on_grid =
|
||||
__builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
|
||||
@@ -1640,9 +1641,10 @@ struct GridwiseMoeGemm : public GridwiseGemm_xdl_cshuffle_base<
|
||||
}
|
||||
gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
|
||||
});
|
||||
const IndexType expert_stride =
|
||||
__builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
|
||||
const IndexType expert_offset = expert_id * expert_stride / BPackedSize;
|
||||
const long_index_t expert_stride = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(problem.N) * problem.K * (IsInputGemm ? 2 : 1));
|
||||
const long_index_t expert_offset =
|
||||
static_cast<long_index_t>(expert_id) * expert_stride / BPackedSize;
|
||||
// N0, K0, Blocksize*KPack
|
||||
const index_t n_block_data_idx_on_grid =
|
||||
__builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
|
||||
|
||||
@@ -1224,11 +1224,11 @@ struct GridwiseMoeGemmBlockScale
|
||||
}
|
||||
gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
|
||||
});
|
||||
const index_t expert_stride =
|
||||
__builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
|
||||
const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
|
||||
math::integer_divide_ceil(problem.N, ScaleBlockN) * (IsInputGemm ? 2 : 1) *
|
||||
math::integer_divide_ceil(problem.K, ScaleBlockK));
|
||||
const long_index_t expert_stride = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(problem.N) * problem.K * (IsInputGemm ? 2 : 1));
|
||||
const long_index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(math::integer_divide_ceil(problem.N, ScaleBlockN)) *
|
||||
(IsInputGemm ? 2 : 1) * math::integer_divide_ceil(problem.K, ScaleBlockK));
|
||||
|
||||
// N0, K0, Blocksize*KPack
|
||||
const index_t n_block_data_idx_on_grid =
|
||||
@@ -1237,14 +1237,14 @@ struct GridwiseMoeGemmBlockScale
|
||||
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global, b_coherence_flag>(
|
||||
p_b_grid + expert_id * static_cast<long_index_t>(expert_stride) / BPackedSize,
|
||||
p_b_grid + static_cast<long_index_t>(expert_id) * expert_stride / BPackedSize,
|
||||
b_grid_desc_bpreshuffled.GetElementSpaceSize());
|
||||
|
||||
const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
|
||||
const auto b_scale_grid_buf =
|
||||
make_dynamic_buffer<AddressSpaceEnum::Global, b_coherence_flag>(
|
||||
p_b_scale_grid + expert_id * expert_scale_stride,
|
||||
p_b_scale_grid + static_cast<long_index_t>(expert_id) * expert_scale_stride,
|
||||
b_scale_grid_desc_bn_ak.GetElementSpaceSize());
|
||||
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
@@ -1406,7 +1406,7 @@ struct GridwiseMoeGemmBlockScale
|
||||
const auto b_grid_buf_up =
|
||||
make_dynamic_buffer<AddressSpaceEnum::Global, b_coherence_flag>(
|
||||
p_b_grid_up +
|
||||
expert_id * static_cast<long_index_t>(expert_stride) / BPackedSize,
|
||||
static_cast<long_index_t>(expert_id) * expert_stride / BPackedSize,
|
||||
b_grid_desc_bpreshuffled.GetElementSpaceSize());
|
||||
auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2<
|
||||
BDataType,
|
||||
@@ -1427,7 +1427,7 @@ struct GridwiseMoeGemmBlockScale
|
||||
p_b_scale_grid + expert_scale_stride / 2 / BPackedSize;
|
||||
const auto b_scale_grid_buf_up =
|
||||
make_dynamic_buffer<AddressSpaceEnum::Global, b_coherence_flag>(
|
||||
p_b_scale_grid_up + expert_id * expert_scale_stride,
|
||||
p_b_scale_grid_up + static_cast<long_index_t>(expert_id) * expert_scale_stride,
|
||||
b_scale_grid_desc_bn_ak.GetElementSpaceSize());
|
||||
auto b_scale_thread_copy_up =
|
||||
ThreadwiseTensorSliceTransfer_v2<BScaleType,
|
||||
@@ -1733,11 +1733,11 @@ struct GridwiseMoeGemmBlockScale
|
||||
}
|
||||
gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
|
||||
});
|
||||
const index_t expert_stride =
|
||||
__builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
|
||||
const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
|
||||
math::integer_divide_ceil(problem.N, ScaleBlockN) * (IsInputGemm ? 2 : 1) *
|
||||
math::integer_divide_ceil(problem.K, ScaleBlockK));
|
||||
const long_index_t expert_stride = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(problem.N) * problem.K * (IsInputGemm ? 2 : 1));
|
||||
const long_index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(math::integer_divide_ceil(problem.N, ScaleBlockN)) *
|
||||
(IsInputGemm ? 2 : 1) * math::integer_divide_ceil(problem.K, ScaleBlockK));
|
||||
// N0, K0, Blocksize*KPack
|
||||
const index_t n_block_data_idx_on_grid =
|
||||
__builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
|
||||
@@ -1745,14 +1745,14 @@ struct GridwiseMoeGemmBlockScale
|
||||
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global, b_coherence_flag>(
|
||||
p_b_grid + expert_id * static_cast<long_index_t>(expert_stride) / BPackedSize,
|
||||
p_b_grid + static_cast<long_index_t>(expert_id) * expert_stride / BPackedSize,
|
||||
b_grid_desc_bpreshuffled.GetElementSpaceSize());
|
||||
|
||||
const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
|
||||
const auto b_scale_grid_buf =
|
||||
make_dynamic_buffer<AddressSpaceEnum::Global, b_coherence_flag>(
|
||||
p_b_scale_grid + expert_id * expert_scale_stride,
|
||||
p_b_scale_grid + static_cast<long_index_t>(expert_id) * expert_scale_stride,
|
||||
b_scale_grid_desc_bn_ak.GetElementSpaceSize());
|
||||
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
@@ -1922,7 +1922,7 @@ struct GridwiseMoeGemmBlockScale
|
||||
const auto b_grid_buf_up =
|
||||
make_dynamic_buffer<AddressSpaceEnum::Global, b_coherence_flag>(
|
||||
p_b_grid_up +
|
||||
expert_id * static_cast<long_index_t>(expert_stride) / BPackedSize,
|
||||
static_cast<long_index_t>(expert_id) * expert_stride / BPackedSize,
|
||||
b_grid_desc_bpreshuffled.GetElementSpaceSize());
|
||||
auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2<
|
||||
BDataType,
|
||||
@@ -1943,7 +1943,8 @@ struct GridwiseMoeGemmBlockScale
|
||||
p_b_scale_grid + expert_scale_stride / 2 / BPackedSize;
|
||||
const auto b_scale_grid_buf_up =
|
||||
make_dynamic_buffer<AddressSpaceEnum::Global, b_coherence_flag>(
|
||||
p_b_scale_grid_up + expert_id * expert_scale_stride / BPackedSize,
|
||||
p_b_scale_grid_up +
|
||||
static_cast<long_index_t>(expert_id) * expert_scale_stride / BPackedSize,
|
||||
b_scale_grid_desc_bn_ak.GetElementSpaceSize());
|
||||
auto b_scale_thread_copy_up =
|
||||
ThreadwiseTensorSliceTransfer_v2<BScaleType,
|
||||
|
||||
@@ -1342,10 +1342,10 @@ struct GridwiseMoeGemmMX
|
||||
gather_offsets(m0) = static_cast<IndexType>(token_offset);
|
||||
});
|
||||
|
||||
const index_t expert_stride =
|
||||
__builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
|
||||
const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
|
||||
problem.N * (IsInputGemm ? 2 : 1) *
|
||||
const long_index_t expert_stride =
|
||||
__builtin_amdgcn_readfirstlane(static_cast<long_index_t>(problem.N) * problem.K * (IsInputGemm ? 2 : 1));
|
||||
const long_index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(problem.N) * (IsInputGemm ? 2 : 1) *
|
||||
math::integer_divide_ceil(problem.K, ScaleBlockSize / BPackedSize));
|
||||
|
||||
// N0, K0, Blocksize*KPack
|
||||
@@ -1356,13 +1356,13 @@ struct GridwiseMoeGemmMX
|
||||
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_grid + expert_id * expert_stride, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
|
||||
p_b_grid + static_cast<long_index_t>(expert_id) * expert_stride, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
|
||||
|
||||
// A, B scale buffer
|
||||
const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
|
||||
const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_scale_grid + (expert_id * expert_scale_stride) / sizeof(BScaleDataType),
|
||||
p_b_scale_grid + (static_cast<long_index_t>(expert_id) * expert_scale_stride) / sizeof(BScaleDataType),
|
||||
b_scale_grid_desc_bn_ak.GetElementSpaceSize());
|
||||
|
||||
// lds max alignment
|
||||
@@ -1502,7 +1502,7 @@ struct GridwiseMoeGemmMX
|
||||
|
||||
const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2;
|
||||
const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_grid_up + expert_id * expert_stride,
|
||||
p_b_grid_up + static_cast<long_index_t>(expert_id) * expert_stride,
|
||||
b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
|
||||
|
||||
auto b_blockwise_copy_up = ThreadGroupTensorSliceTransfer_DirectLoad<
|
||||
@@ -1525,7 +1525,7 @@ struct GridwiseMoeGemmMX
|
||||
const BScaleDataType* p_b_scale_grid_up =
|
||||
p_b_scale_grid + expert_scale_stride / 2 / sizeof(BScaleDataType);
|
||||
const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_scale_grid_up + expert_id * expert_scale_stride / sizeof(BScaleDataType),
|
||||
p_b_scale_grid_up + static_cast<long_index_t>(expert_id) * expert_scale_stride / sizeof(BScaleDataType),
|
||||
b_scale_grid_desc_bn_ak.GetElementSpaceSize());
|
||||
|
||||
auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2<
|
||||
@@ -2105,10 +2105,10 @@ struct GridwiseMoeGemmMX
|
||||
gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
|
||||
});
|
||||
|
||||
const index_t expert_stride =
|
||||
__builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
|
||||
const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
|
||||
problem.N * (IsInputGemm ? 2 : 1) *
|
||||
const long_index_t expert_stride = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(problem.N) * problem.K * (IsInputGemm ? 2 : 1));
|
||||
const long_index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(problem.N) * (IsInputGemm ? 2 : 1) *
|
||||
math::integer_divide_ceil(problem.K, ScaleBlockSize / BPackedSize));
|
||||
|
||||
// N0, K0, Blocksize*KPack
|
||||
@@ -2119,13 +2119,15 @@ struct GridwiseMoeGemmMX
|
||||
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_grid + expert_id * expert_stride, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
|
||||
p_b_grid + static_cast<long_index_t>(expert_id) * expert_stride,
|
||||
b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
|
||||
|
||||
// A, B scale buffer
|
||||
const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
|
||||
const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_scale_grid + (expert_id * expert_scale_stride) / sizeof(BScaleDataType),
|
||||
p_b_scale_grid + (static_cast<long_index_t>(expert_id) * expert_scale_stride) /
|
||||
sizeof(BScaleDataType),
|
||||
b_scale_grid_desc_bn_ak.GetElementSpaceSize());
|
||||
|
||||
// lds max alignment
|
||||
@@ -2274,7 +2276,7 @@ struct GridwiseMoeGemmMX
|
||||
{
|
||||
const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2;
|
||||
const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_grid_up + expert_id * expert_stride,
|
||||
p_b_grid_up + static_cast<long_index_t>(expert_id) * expert_stride,
|
||||
b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
|
||||
|
||||
// lds ping pong buffers for up
|
||||
@@ -2313,7 +2315,8 @@ struct GridwiseMoeGemmMX
|
||||
const BScaleDataType* p_b_scale_grid_up =
|
||||
p_b_scale_grid + expert_scale_stride / 2 / sizeof(BScaleDataType);
|
||||
const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_scale_grid_up + expert_id * expert_scale_stride / sizeof(BScaleDataType),
|
||||
p_b_scale_grid_up + static_cast<long_index_t>(expert_id) * expert_scale_stride /
|
||||
sizeof(BScaleDataType),
|
||||
b_scale_grid_desc_bn_ak.GetElementSpaceSize());
|
||||
|
||||
auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2<
|
||||
|
||||
@@ -1235,10 +1235,10 @@ struct GridwiseMoeGemmMXBNS
|
||||
gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
|
||||
});
|
||||
|
||||
const index_t expert_stride =
|
||||
__builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
|
||||
const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
|
||||
problem.N * (IsInputGemm ? 2 : 1) *
|
||||
const long_index_t expert_stride = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(problem.N) * problem.K * (IsInputGemm ? 2 : 1));
|
||||
const long_index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(problem.N) * (IsInputGemm ? 2 : 1) *
|
||||
math::integer_divide_ceil(problem.K, ScaleBlockSize / BPackedSize));
|
||||
|
||||
// N0, K0, Blocksize*KPack
|
||||
@@ -1249,13 +1249,15 @@ struct GridwiseMoeGemmMXBNS
|
||||
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_grid + expert_id * expert_stride, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
|
||||
p_b_grid + static_cast<long_index_t>(expert_id) * expert_stride,
|
||||
b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
|
||||
|
||||
// A, B scale buffer
|
||||
const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
|
||||
const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_scale_grid + (expert_id * expert_scale_stride) / sizeof(BScaleDataType),
|
||||
p_b_scale_grid + (static_cast<long_index_t>(expert_id) * expert_scale_stride) /
|
||||
sizeof(BScaleDataType),
|
||||
b_scale_grid_desc_bn_ak.GetElementSpaceSize());
|
||||
|
||||
// lds max alignment
|
||||
@@ -1421,7 +1423,7 @@ struct GridwiseMoeGemmMXBNS
|
||||
|
||||
const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2;
|
||||
const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_grid_up + expert_id * expert_stride,
|
||||
p_b_grid_up + static_cast<long_index_t>(expert_id) * expert_stride,
|
||||
b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
|
||||
|
||||
auto b_blockwise_copy_up =
|
||||
@@ -1457,7 +1459,8 @@ struct GridwiseMoeGemmMXBNS
|
||||
const BScaleDataType* p_b_scale_grid_up =
|
||||
p_b_scale_grid + expert_scale_stride / 2 / sizeof(BScaleDataType);
|
||||
const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_scale_grid_up + expert_id * expert_scale_stride / sizeof(BScaleDataType),
|
||||
p_b_scale_grid_up + static_cast<long_index_t>(expert_id) * expert_scale_stride /
|
||||
sizeof(BScaleDataType),
|
||||
b_scale_grid_desc_bn_ak.GetElementSpaceSize());
|
||||
|
||||
auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2<
|
||||
@@ -1763,10 +1766,10 @@ struct GridwiseMoeGemmMXBNS
|
||||
gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
|
||||
});
|
||||
|
||||
const index_t expert_stride =
|
||||
__builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
|
||||
const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
|
||||
problem.N * math::integer_divide_ceil(problem.K, ScaleBlockSize / BPackedSize));
|
||||
const long_index_t expert_stride =
|
||||
__builtin_amdgcn_readfirstlane(static_cast<long_index_t>(problem.N) * problem.K * (IsInputGemm ? 2 : 1));
|
||||
const long_index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(problem.N) * math::integer_divide_ceil(problem.K, ScaleBlockSize / BPackedSize));
|
||||
|
||||
// N0, K0, Blocksize*KPack
|
||||
const index_t n_block_data_idx_on_grid =
|
||||
@@ -1776,12 +1779,12 @@ struct GridwiseMoeGemmMXBNS
|
||||
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
|
||||
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_grid + expert_id * expert_stride, b_grid_desc_bpreshuffled.GetElementSpaceSize());
|
||||
p_b_grid + static_cast<long_index_t>(expert_id) * expert_stride, b_grid_desc_bpreshuffled.GetElementSpaceSize());
|
||||
|
||||
const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
|
||||
const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_scale_grid + (expert_id * expert_scale_stride) / sizeof(BScaleDataType),
|
||||
p_b_scale_grid + (static_cast<long_index_t>(expert_id) * expert_scale_stride) / sizeof(BScaleDataType),
|
||||
b_scale_grid_desc_bn_ak.GetElementSpaceSize());
|
||||
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
@@ -1933,7 +1936,7 @@ struct GridwiseMoeGemmMXBNS
|
||||
{
|
||||
const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize;
|
||||
const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_grid_up + expert_id * expert_stride / BPackedSize,
|
||||
p_b_grid_up + static_cast<long_index_t>(expert_id) * expert_stride / BPackedSize,
|
||||
b_grid_desc_bpreshuffled.GetElementSpaceSize());
|
||||
auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2<
|
||||
BDataType,
|
||||
@@ -1952,7 +1955,7 @@ struct GridwiseMoeGemmMXBNS
|
||||
KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
|
||||
const BScaleDataType* p_b_scale_grid_up = p_b_scale_grid + expert_scale_stride / 2;
|
||||
const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_scale_grid_up + expert_id * expert_scale_stride,
|
||||
p_b_scale_grid_up + static_cast<long_index_t>(expert_id) * expert_scale_stride,
|
||||
b_scale_grid_desc_bn_ak.GetElementSpaceSize());
|
||||
auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2<
|
||||
BScaleDataType,
|
||||
|
||||
@@ -1329,10 +1329,10 @@ struct GridwiseMoeGemmMX_BPreshuffle
|
||||
gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
|
||||
});
|
||||
|
||||
const index_t expert_stride =
|
||||
__builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
|
||||
const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
|
||||
problem.N * (IsInputGemm ? 2 : 1) *
|
||||
const long_index_t expert_stride = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(problem.N) * problem.K * (IsInputGemm ? 2 : 1));
|
||||
const long_index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(problem.N) * (IsInputGemm ? 2 : 1) *
|
||||
math::integer_divide_ceil(problem.K, ScaleBlockSize / BPackedSize));
|
||||
|
||||
// N0, K0, Blocksize*KPack
|
||||
@@ -1343,13 +1343,15 @@ struct GridwiseMoeGemmMX_BPreshuffle
|
||||
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_grid + expert_id * expert_stride, b_grid_desc_bpreshuffled.GetElementSpaceSize());
|
||||
p_b_grid + static_cast<long_index_t>(expert_id) * expert_stride,
|
||||
b_grid_desc_bpreshuffled.GetElementSpaceSize());
|
||||
|
||||
// A, B scale buffer
|
||||
const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
|
||||
const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_scale_grid + (expert_id * expert_scale_stride) / sizeof(BScaleDataType),
|
||||
p_b_scale_grid + (static_cast<long_index_t>(expert_id) * expert_scale_stride) /
|
||||
sizeof(BScaleDataType),
|
||||
b_scale_grid_desc_bn_ak.GetElementSpaceSize());
|
||||
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
@@ -1487,7 +1489,7 @@ struct GridwiseMoeGemmMX_BPreshuffle
|
||||
{
|
||||
const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2;
|
||||
const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_grid_up + expert_id * expert_stride,
|
||||
p_b_grid_up + static_cast<long_index_t>(expert_id) * expert_stride,
|
||||
b_grid_desc_bpreshuffled.GetElementSpaceSize());
|
||||
auto b_blockwise_copy_up =
|
||||
ThreadwiseTensorSliceTransfer_v2<BDataType,
|
||||
@@ -1513,7 +1515,8 @@ struct GridwiseMoeGemmMX_BPreshuffle
|
||||
const BScaleDataType* p_b_scale_grid_up =
|
||||
p_b_scale_grid + expert_scale_stride / 2 / sizeof(BScaleDataType);
|
||||
const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_scale_grid_up + expert_id * expert_scale_stride / sizeof(BScaleDataType),
|
||||
p_b_scale_grid_up + static_cast<long_index_t>(expert_id) * expert_scale_stride /
|
||||
sizeof(BScaleDataType),
|
||||
b_scale_grid_desc_bn_ak.GetElementSpaceSize());
|
||||
|
||||
auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2<
|
||||
@@ -1819,10 +1822,10 @@ struct GridwiseMoeGemmMX_BPreshuffle
|
||||
gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
|
||||
});
|
||||
|
||||
const index_t expert_stride =
|
||||
__builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
|
||||
const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
|
||||
problem.N * (IsInputGemm ? 2 : 1) *
|
||||
const long_index_t expert_stride = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(problem.N) * problem.K * (IsInputGemm ? 2 : 1));
|
||||
const long_index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(problem.N) * (IsInputGemm ? 2 : 1) *
|
||||
math::integer_divide_ceil(problem.K, ScaleBlockSize / BPackedSize));
|
||||
|
||||
// N0, K0, Blocksize*KPack
|
||||
@@ -1833,13 +1836,15 @@ struct GridwiseMoeGemmMX_BPreshuffle
|
||||
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_grid + expert_id * expert_stride, b_grid_desc_bpreshuffled.GetElementSpaceSize());
|
||||
p_b_grid + static_cast<long_index_t>(expert_id) * expert_stride,
|
||||
b_grid_desc_bpreshuffled.GetElementSpaceSize());
|
||||
|
||||
// A, B scale buffer
|
||||
const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
|
||||
const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_scale_grid + (expert_id * expert_scale_stride) / sizeof(BScaleDataType),
|
||||
p_b_scale_grid + (static_cast<long_index_t>(expert_id) * expert_scale_stride) /
|
||||
sizeof(BScaleDataType),
|
||||
b_scale_grid_desc_bn_ak.GetElementSpaceSize());
|
||||
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
@@ -1980,7 +1985,7 @@ struct GridwiseMoeGemmMX_BPreshuffle
|
||||
{
|
||||
const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2;
|
||||
const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_grid_up + expert_id * expert_stride,
|
||||
p_b_grid_up + static_cast<long_index_t>(expert_id) * expert_stride,
|
||||
b_grid_desc_bpreshuffled.GetElementSpaceSize());
|
||||
auto b_blockwise_copy_up =
|
||||
ThreadwiseTensorSliceTransfer_v2<BDataType,
|
||||
@@ -2006,7 +2011,8 @@ struct GridwiseMoeGemmMX_BPreshuffle
|
||||
const BScaleDataType* p_b_scale_grid_up =
|
||||
p_b_scale_grid + expert_scale_stride / 2 / sizeof(BScaleDataType);
|
||||
const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_scale_grid_up + expert_id * expert_scale_stride / sizeof(BScaleDataType),
|
||||
p_b_scale_grid_up + static_cast<long_index_t>(expert_id) * expert_scale_stride /
|
||||
sizeof(BScaleDataType),
|
||||
b_scale_grid_desc_bn_ak.GetElementSpaceSize());
|
||||
|
||||
auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2<
|
||||
|
||||
Reference in New Issue
Block a user