From c79879c669138f74ee42e38c607912430f3d99f0 Mon Sep 17 00:00:00 2001 From: oscar Date: Fri, 28 Nov 2025 17:41:29 +0800 Subject: [PATCH] splitk hack pass --- .../moe_gemm1_xdl_fp8_blockscale_splitk.cpp | 21 +++-- .../impl/device_moe_gemm_blockscale.hpp | 10 ++- .../gpu/grid/gridwise_moe_gemm_blockscale.hpp | 77 +++++++++++-------- 3 files changed, 64 insertions(+), 44 deletions(-) diff --git a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8_blockscale_splitk.cpp b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8_blockscale_splitk.cpp index 3ba49b28d6..cd321d8574 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8_blockscale_splitk.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8_blockscale_splitk.cpp @@ -185,16 +185,21 @@ int main(int argc, char* argv[]) // GEMM shape ck::index_t N = 4096; ck::index_t K = 6144; + // ck::index_t N = 128; + // ck::index_t K = 512; ck::index_t experts = 8; ck::index_t topk = 2; // ck::index_t sorted_tile_num = 515; // ck::index_t valid_tile_num = 512; - // ck::index_t tokens = 8192; + // ck::index_t tokens = 208; // ck::index_t sorted_tile_num = 15; // ck::index_t valid_tile_num = 13; - ck::index_t sorted_tile_num = 259; - ck::index_t valid_tile_num = 256; - ck::index_t tokens = 4096; + // ck::index_t sorted_tile_num = 259; + // ck::index_t valid_tile_num = 256; + // ck::index_t tokens = 4096; + ck::index_t sorted_tile_num = 2; + ck::index_t valid_tile_num = 2; + ck::index_t tokens = 32; #else // deepseek ck::index_t N = 2048; @@ -256,14 +261,14 @@ int main(int argc, char* argv[]) } ck::index_t StrideA = K; ck::index_t StrideB = K; - ck::index_t StrideE = N; + ck::index_t StrideE = N * 2; constexpr ck::index_t NumDTensor = DsDataType::Size(); constexpr auto StrideDs = std::array{0}; ck::index_t Scale_Stride_AM = (K + Scale_Block_K - 1) / Scale_Block_K; ck::index_t Scale_Stride_BN = (K + Scale_Block_K - 1) / Scale_Block_K; ck::index_t Scale_Stride_B = (N + Scale_Block_N - 1) / Scale_Block_N * 2; - ck::index_t KBatch = 1; + ck::index_t KBatch = 6; Tensor expert_ids(HostTensorDescriptor({sorted_tile_num}, {1})); Tensor sorted_token_ids(HostTensorDescriptor({sorted_size}, {1})); @@ -319,9 +324,9 @@ int main(int argc, char* argv[]) { case 0: break; case 1: - a0_t_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + a0_t_k.GenerateTensorValue(GeneratorTensor_3{-1.0, 1.0}); a1_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-1.0, 1.0}); b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); break; case 2: diff --git a/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp b/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp index 18c8843c4a..0fd3e0b53a 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp @@ -203,12 +203,11 @@ struct DeviceMoeGemmBlockScale } index_t gdx, gdy, gdz; - std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N); + std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N * (IsInputGemm && IsSplitK ? 2 : 1), arg.K, arg.KBatch); float ave_time = 0; - index_t k_grain = arg.KBatch * KPerBlock; - index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock; + index_t K_split = arg.KBatch == 1 ? arg.K : arg.KBatch * KPerBlock; const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); const auto RunKernel = [&](const auto& kernel) { @@ -443,6 +442,11 @@ struct DeviceMoeGemmBlockScale { return false; } + if (arg.KBatch > 1 && arg.K % (KPerBlock * arg.KBatch) != 0) + { + // Not support Kpadding with KBatch > 1 + return false; + } if(get_warp_size() == 64) { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp index 43134cecf5..e3ba4794f6 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp @@ -51,6 +51,10 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + // printf("splitk_batch_offset.a_k_split_offset: %d\n", splitk_batch_offset.a_k_split_offset); + // printf("splitk_batch_offset.b_k_split_offset: %d\n", splitk_batch_offset.b_k_split_offset); + // printf("splitk_batch_offset.ascale_k_split_offset: %d\n", splitk_batch_offset.ascale_k_split_offset); + // printf("splitk_batch_offset.bscale_k_split_offset: %d\n", splitk_batch_offset.bscale_k_split_offset); GridwiseGemm::template Run( karg.p_sorted_token_ids, @@ -60,8 +64,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) karg.p_b_grid + splitk_batch_offset.b_k_split_offset, karg.p_ds_grid, karg.p_c_grid, - karg.p_a_scale_grid, - karg.p_b_scale_grid, + karg.p_a_scale_grid + splitk_batch_offset.ascale_k_split_offset, + karg.p_b_scale_grid + splitk_batch_offset.bscale_k_split_offset, p_shared, karg, karg.a_element_op, @@ -101,8 +105,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) karg.p_b_grid + splitk_batch_offset.b_k_split_offset, karg.p_ds_grid, karg.p_c_grid, - karg.p_a_scale_grid, - karg.p_b_scale_grid, + karg.p_a_scale_grid + splitk_batch_offset.ascale_k_split_offset, + karg.p_b_scale_grid + splitk_batch_offset.bscale_k_split_offset, p_shared, p_shared1, karg, @@ -250,13 +254,15 @@ struct GridwiseMoeGemmBlockScale return 1; }(); - __host__ static auto CalculateGridSize(index_t M, index_t N) + __host__ static auto CalculateGridSize(index_t M, index_t N, index_t K, index_t KBatch) { const index_t nblock = math::integer_divide_ceil(N, NPerBlock); const index_t mblock = math::integer_divide_ceil(M, MPerBlock); const index_t gridx = NSwizzle ? nblock * mblock : nblock; const index_t gridy = NSwizzle ? 1 : mblock; - return std::make_tuple(gridx, gridy, 1); + const index_t gridz = KBatch == 1 ? 1 : math::integer_divide_ceil(K, KPerBlock * KBatch); + + return std::make_tuple(gridx, gridy, gridz); } __host__ __device__ static auto CalculateMPadded(index_t M) @@ -285,27 +291,31 @@ struct GridwiseMoeGemmBlockScale __host__ __device__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1) { - auto K_t = K_Batch * KPerBlock; - return (K + K_t - 1) / K_t * (KPerBlock / AK1Value); + // auto K_t = K_Batch * KPerBlock; + // return (K + K_t - 1) / K_t * (KPerBlock / AK1Value); + return K_Batch == 1 ? K / AK1Value : K_Batch * KPerBlock / AK1Value; } __host__ __device__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1) { - auto K_t = K_Batch * KPerBlock; - return (K + K_t - 1) / K_t * (KPerBlock / BK1Value); + // auto K_t = K_Batch * KPerBlock; + // return (K + K_t - 1) / K_t * (KPerBlock / BK1Value); + return K_Batch == 1 ? K / BK1Value : K_Batch * KPerBlock / BK1Value; } __host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1) { - auto K_t = K_Batch * KPerBlock; - return (K + K_t - 1) / K_t * KPerBlock; + // auto K_t = K_Batch * KPerBlock; + // return (K + K_t - 1) / K_t * KPerBlock; + return K_Batch == 1 ? K : K_Batch * KPerBlock; } __host__ __device__ static auto CalculateKRead(index_t K, index_t K_Batch = 1) { constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); - auto K_t = K_Batch * KReadVec; - return (K + K_t - 1) / K_t * KReadVec; + // auto K_t = K_Batch * KReadVec; + // return (K + K_t - 1) / K_t * KReadVec; + return K_Batch == 1 ? math::integer_divide_ceil(K, KReadVec) * KReadVec : K_Batch * KPerBlock; } __host__ __device__ static auto CalculateMBlock(index_t M) @@ -410,7 +420,6 @@ struct GridwiseMoeGemmBlockScale make_pass_through_transform(M)), make_tuple(Sequence<1>{}, Sequence<0>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - return a_grid_desc_ak0_m_ak1; } } @@ -758,19 +767,22 @@ struct GridwiseMoeGemmBlockScale // KPack * NLane * KLane * K0 * N0 b_k_split_offset = k_id * karg.KRead * NLane / BPackedSize; } - - if(k_id < karg.KBatch - 1) - { - karg.K = karg.KRead; - } - else - { - karg.K = karg.K - karg.KRead * (karg.KBatch - 1); - } + ascale_k_split_offset = math::integer_divide_ceil(a_k_split_offset, ScaleBlockK); + bscale_k_split_offset = math::integer_divide_ceil(b_k_split_offset, ScaleBlockK); + // if(k_id < karg.KBatch - 1) + // { + // karg.K = karg.KRead; + // } + // else + // { + // karg.K = karg.K - karg.KRead * (karg.KBatch - 1); + // } } index_t a_k_split_offset; index_t b_k_split_offset; + index_t ascale_k_split_offset; + index_t bscale_k_split_offset; }; __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() @@ -1191,7 +1203,7 @@ struct GridwiseMoeGemmBlockScale CElementwiseOperation c_element_op) { ignore = b_element_op; - index_t BN0Shuffled = CalculateBN0Shuffled(problem.N * (IsInputGemm && IsSplitK ? 2 : 1)); + index_t BN0Shuffled = CalculateBN0Shuffled(problem.N); index_t BK0Shuffled = CalculateBK0Shuffled(problem.K); const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK, @@ -1206,7 +1218,7 @@ struct GridwiseMoeGemmBlockScale IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens, problem.MPadded, problem.N * (IsInputGemm && IsSplitK ? 2 : 1), - problem.NPadded, + problem.NPadded * (IsInputGemm && IsSplitK ? 2 : 1), problem.StrideC); const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor( @@ -1372,9 +1384,8 @@ struct GridwiseMoeGemmBlockScale decltype(c_thread_buf) c_thread_buf_up; const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( - (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / - KPerBlock); - + problem.KBatch == 1 ?(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock : problem.KBatch); constexpr index_t ScaleSliceSizeM = MXdlPerWave; constexpr index_t ScaleSliceSizeN = math::integer_divide_ceil(NPerBlock, ScaleBlockN); constexpr index_t ScaleSliceSizeK = math::integer_divide_ceil(KPerBlock, ScaleBlockK); @@ -1940,7 +1951,7 @@ struct GridwiseMoeGemmBlockScale CElementwiseOperation c_element_op) { ignore = b_element_op; - index_t BN0Shuffled = CalculateBN0Shuffled(problem.N * (IsInputGemm && IsSplitK ? 2 : 1)); + index_t BN0Shuffled = CalculateBN0Shuffled(problem.N); index_t BK0Shuffled = CalculateBK0Shuffled(problem.K); const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK, @@ -1955,7 +1966,7 @@ struct GridwiseMoeGemmBlockScale IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens, problem.MPadded, problem.N * (IsInputGemm && IsSplitK ? 2 : 1), - problem.NPadded, + problem.NPadded * (IsInputGemm && IsSplitK ? 2 : 1), problem.StrideC); const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor( @@ -2126,8 +2137,8 @@ struct GridwiseMoeGemmBlockScale decltype(c_thread_buf) c_thread_buf_up; const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( - (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / - KPerBlock); + problem.KBatch == 1 ?(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock : problem.KBatch); // scale constexpr index_t ScaleSliceSizeM = MXdlPerWave;