diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp index 673b6b2f21..d579501077 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp @@ -827,7 +827,15 @@ struct GridwiseMoeGemmBlockScale struct SplitKBatchOffset { - __device__ SplitKBatchOffset(Argument& karg, index_t k_id) + __device__ SplitKBatchOffset() + : a_k_split_offset(0), + b_k_split_offset(0), + ascale_k_split_offset(0), + bscale_k_split_offset(0) + { + } + + __device__ SplitKBatchOffset(const Problem& karg, index_t k_id) { if constexpr(is_same_v) { @@ -847,19 +855,9 @@ struct GridwiseMoeGemmBlockScale } else if constexpr(is_same_v) { - // KPack * NLane * KLane * K0 * N0 b_k_split_offset = k_id * karg.KRead * NLane / BPackedSize; bscale_k_split_offset = k_id * karg.KRead / ScaleBlockK; } - - // if(k_id < karg.KBatch - 1) - // { - // karg.K = karg.KRead; - // } - // else - // { - // karg.K = karg.K - karg.KRead * (karg.KBatch - 1); - // } } index_t a_k_split_offset; @@ -1234,18 +1232,43 @@ struct GridwiseMoeGemmBlockScale const index_t n_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave); + // When SplitK is enabled, base pointers have been shifted by + // SplitKBatchOffset in the kernel entry, but buffer descriptor element + // spaces are still based on full K. Subtract the pointer shift from + // each element space so the hardware buffer resource doesn't extend + // beyond the actual tensor allocation. + const auto splitk_offset = [&]() -> SplitKBatchOffset { + if constexpr(IsSplitK) + { + return SplitKBatchOffset(problem, blockIdx.z); + } + else + { + return SplitKBatchOffset(); + } + }(); + + assert(a_grid_desc_ak0_m_ak1.GetElementSpaceSize() >= splitk_offset.a_k_split_offset); + assert(b_grid_desc_bpreshuffled.GetElementSpaceSize() >= splitk_offset.b_k_split_offset); + assert(a_scale_grid_desc_am_ak.GetElementSpaceSize() >= + splitk_offset.ascale_k_split_offset); + assert(b_scale_grid_desc_bn_ak.GetElementSpaceSize() >= + splitk_offset.bscale_k_split_offset); + const auto a_grid_buf = make_dynamic_buffer( - p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize() - splitk_offset.a_k_split_offset); const auto b_grid_buf = make_dynamic_buffer( p_b_grid + static_cast(expert_id) * expert_stride / BPackedSize, - b_grid_desc_bpreshuffled.GetElementSpaceSize()); + b_grid_desc_bpreshuffled.GetElementSpaceSize() - splitk_offset.b_k_split_offset); const auto a_scale_grid_buf = make_dynamic_buffer( - p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize()); + p_a_scale_grid, + a_scale_grid_desc_am_ak.GetElementSpaceSize() - splitk_offset.ascale_k_split_offset); const auto b_scale_grid_buf = make_dynamic_buffer( p_b_scale_grid + static_cast(expert_id) * expert_scale_stride, - b_scale_grid_desc_bn_ak.GetElementSpaceSize()); + b_scale_grid_desc_bn_ak.GetElementSpaceSize() - + splitk_offset.bscale_k_split_offset); // A matrix in LDS memory, dst of blockwise copy constexpr auto a_block_desc_ak0_m_ak1 = @@ -1742,18 +1765,39 @@ struct GridwiseMoeGemmBlockScale const index_t n_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave); + // Same fix as Run(): reduce buffer element spaces by split offset + const auto splitk_offset = [&]() -> SplitKBatchOffset { + if constexpr(IsSplitK) + { + return SplitKBatchOffset(problem, blockIdx.z); + } + else + { + return SplitKBatchOffset(); + } + }(); + + assert(a_grid_desc_ak0_m_ak1.GetElementSpaceSize() >= splitk_offset.a_k_split_offset); + assert(b_grid_desc_bpreshuffled.GetElementSpaceSize() >= splitk_offset.b_k_split_offset); + assert(a_scale_grid_desc_am_ak.GetElementSpaceSize() >= + splitk_offset.ascale_k_split_offset); + assert(b_scale_grid_desc_bn_ak.GetElementSpaceSize() >= + splitk_offset.bscale_k_split_offset); + const auto a_grid_buf = make_dynamic_buffer( - p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize() - splitk_offset.a_k_split_offset); const auto b_grid_buf = make_dynamic_buffer( p_b_grid + static_cast(expert_id) * expert_stride / BPackedSize, - b_grid_desc_bpreshuffled.GetElementSpaceSize()); + b_grid_desc_bpreshuffled.GetElementSpaceSize() - splitk_offset.b_k_split_offset); const auto a_scale_grid_buf = make_dynamic_buffer( - p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize()); + p_a_scale_grid, + a_scale_grid_desc_am_ak.GetElementSpaceSize() - splitk_offset.ascale_k_split_offset); const auto b_scale_grid_buf = make_dynamic_buffer( p_b_scale_grid + static_cast(expert_id) * expert_scale_stride, - b_scale_grid_desc_bn_ak.GetElementSpaceSize()); + b_scale_grid_desc_bn_ak.GetElementSpaceSize() - + splitk_offset.bscale_k_split_offset); // A matrix in LDS memory, dst of blockwise copy constexpr auto a_block_desc_ak0_m_ak1 =