From ceaa5a984b058bf67453ba5d58d4462879157a24 Mon Sep 17 00:00:00 2001 From: coderfeli Date: Thu, 24 Apr 2025 11:05:08 +0000 Subject: [PATCH] gu fusion v3 --- .../moe_gemm1_xdl_fp8.cpp | 8 +- ..._pipeline_xdlops_b_preshuffle_selector.hpp | 69 ++-- .../gpu/grid/gridwise_moe_gemm.hpp | 301 +++++++++++++++--- 3 files changed, 313 insertions(+), 65 deletions(-) diff --git a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp index 3b31460953..001a30370b 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp @@ -158,8 +158,8 @@ using BElementOp = PassThrough; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr ck::index_t MPerBlock = 128; -static constexpr ck::index_t MXDLPerWave = 4; -static constexpr ck::index_t NXDLPerWave = 2; +static constexpr ck::index_t MXDLPerWave = 8; +static constexpr ck::index_t NXDLPerWave = 1; static constexpr ck::index_t BLOCKSIZE = 256; static constexpr ck::index_t NPerBlock = 64; static constexpr ck::index_t MNPerXDL = 16; @@ -190,8 +190,8 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceM // CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| // MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - 2, 2, S<1, 32, 1, 8>, S, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, ActOP, Nswizzle, true, MulRoutedWeight, true, int32_t, A0DataType>; + 2, 1, S<1, 32, 1, 8>, S, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, ActOP, Nswizzle, true, MulRoutedWeight, true, int32_t, A0DataType>; // clang-format on diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_selector.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_selector.hpp index 074b5873ee..df59513a88 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_selector.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_selector.hpp @@ -8,6 +8,7 @@ #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v1.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_dequant_v1.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v2.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v3.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v3.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp" @@ -171,26 +172,54 @@ constexpr auto BlockGemmBPreshufflePipeline_Selector() static_assert(MRepeat >= 4, "MRepeat should at least be 4 in BlockGemmPipelineVersion::v3"); if constexpr(std::is_same::value) { - return BlockwiseGemmXdlops_pipeline_bpreshuffle_v3{}; + if constexpr(GUFusion) + { + return BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3< + BlkGemmPipeSche, + BlockSize, + ADataType, + BDataType, + ComputeDataType, + AccDataType, + ATileDesc, + BTileDesc, + AMmaTileDesc, + BMmaTileDesc, + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MPerXDL, + NPerXDL, + MRepeat, + NRepeat, + KPack>{}; + } + else + { + + return BlockwiseGemmXdlops_pipeline_bpreshuffle_v3{}; + } } else { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp index 255fb8cff4..ead4399d81 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp @@ -1886,7 +1886,8 @@ struct GridwiseMoeGemm const auto c_grid_desc_mblock_mperblock_nblock_nperblock = MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( c_grid_desc_m_n, problem.MBlock, problem.NBlock); - const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]); + const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]); + // static_assert(NSwizzle == false, "to do fix: need another pr in sorting merged"); const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y; if(expert_block_id * MPerBlock >= max_token_id) return; @@ -1895,12 +1896,13 @@ struct GridwiseMoeGemm const auto block_mn = [&]() -> std::pair { if constexpr(NSwizzle) { - const index_t ecnt_prefix = p_max_token_id[1 + expert_id]; - const index_t prefix_block = ecnt_prefix * problem.NBlock; - const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix; - const index_t expert_swizzle = ecnt > 0 ? ecnt : 1; - const index_t bid_new = blockIdx.x - prefix_block; - const index_t nid = __builtin_amdgcn_readfirstlane( + const index_t ecnt_prefix = p_max_token_id[1 + expert_id]; + const index_t prefix_block = ecnt_prefix * problem.NBlock; + const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix; + const index_t expert_swizzle = + ecnt > 0 ? ecnt : 1; // p_max_token_id[expert_id + 1]; // 2 + const index_t bid_new = blockIdx.x - prefix_block; + const index_t nid = __builtin_amdgcn_readfirstlane( bid_new % 8 + bid_new / (8 * expert_swizzle) * 8); const index_t mid = __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle); @@ -1911,9 +1913,9 @@ struct GridwiseMoeGemm return {blockIdx.x, blockIdx.y}; } }(); + const index_t block_n_id = block_mn.first; const index_t block_m_id = block_mn.second; - const index_t token0 = __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff); @@ -1925,11 +1927,9 @@ struct GridwiseMoeGemm constexpr auto AMRepeats = MPerBlock / AMThreads; const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats; - if(token_pos >= max_token_id || expert_block_id * MPerBlock >= max_token_id || - token0 >= problem.NumTokens) + if(token_pos >= max_token_id || token0 >= problem.NumTokens) return; - StaticallyIndexedArray - gather_offsets; //= p_sorted_token_ids[token_pos]; + StaticallyIndexedArray gather_offsets; static_for<0, AMRepeats, 1>{}([&](auto m0) { const index_t fused_token = p_sorted_token_ids[token_pos + m0]; index_t token_offset = fused_token & 0xffffff; @@ -1939,7 +1939,8 @@ struct GridwiseMoeGemm } gather_offsets(m0) = static_cast(token_offset) * problem.K; }); - const index_t expert_stride = __builtin_amdgcn_readfirstlane(problem.N * problem.K); + const index_t expert_stride = + __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1)); // N0, K0, Blocksize*KPack const index_t n_block_data_idx_on_grid = @@ -1950,7 +1951,6 @@ struct GridwiseMoeGemm const auto b_grid_buf = make_dynamic_buffer( p_b_grid + expert_id * expert_stride / BPackedSize, b_grid_desc_bpreshuffled.GetElementSpaceSize()); - // A matrix in LDS memory, dst of blockwise copy constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); @@ -2029,24 +2029,75 @@ struct GridwiseMoeGemm static_assert(std::is_default_constructible_v); auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + decltype(c_thread_buf) c_thread_buf_up; + + StaticBufferTupleOfVector + c_thread_buf_fp32; const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / KPerBlock); - blockwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, - a_block_desc_ak0_m_ak1, - a_blockwise_copy, - a_grid_buf, - a_block_bufs, - a_block_slice_copy_step, - b_grid_desc_bpreshuffled, - b_blockwise_copy, - b_grid_buf, - b_block_bufs, - b_block_slice_copy_step, - c_thread_buf, - num_k_block_main_loop); + if constexpr(IsInputGemm) + { + const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize; + const auto b_grid_buf_up = make_dynamic_buffer( + p_b_grid_up + expert_id * expert_stride / BPackedSize, + b_grid_desc_bpreshuffled.GetElementSpaceSize()); + auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2< + BDataType, + BDataType, + decltype(b_grid_desc_bpreshuffled), + decltype(b_block_desc_bk0_n_bk1), + Sequence{}, I1, Number{}, Number{}>, + Sequence<1, 2, 0, 3>, + 3, + BBlockTransferSrcScalarPerVector, + BThreadTransferSrcResetCoordinateAfterRun, + true>(b_grid_desc_bpreshuffled, + make_multi_index(n_block_data_idx_on_grid, + get_warp_local_1d_id() % NWave, + 0, + KPack * (get_thread_local_1d_id() % warpSize))); + blockwise_gemm_pipeline.template Run( + a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_bufs, + a_block_slice_copy_step, + b_grid_desc_bpreshuffled, + b_blockwise_copy, + b_blockwise_copy_up, + b_grid_buf, + b_grid_buf_up, + b_block_bufs, + b_block_slice_copy_step, + c_thread_buf, + c_thread_buf_up, + num_k_block_main_loop); + } + else + { + + blockwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_bufs, + a_block_slice_copy_step, + b_grid_desc_bpreshuffled, + b_blockwise_copy, + b_grid_buf, + b_block_bufs, + b_block_slice_copy_step, + c_thread_buf, + num_k_block_main_loop); + } // shuffle C and write out { @@ -2074,6 +2125,185 @@ struct GridwiseMoeGemm constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + // mul scales + const float* p_sorted_weights_0 = p_ds_grid[I0]; + const float* p_scale_b = p_ds_grid[I1]; + + static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock); + static_assert(M4 == 4); + const index_t m1 = get_warp_local_1d_id() / NWave; + const index_t m3 = threadIdx.x % get_warp_size() / MPerXdl; + + if(p_sorted_weights_0 != nullptr && p_scale_b != nullptr) + { + if constexpr(PerTokenQuant) + { + constexpr index_t scale_stride = (IsInputGemm ? 2 : 1); + p_scale_b += expert_id * problem.N * scale_stride + block_n_id * NPerBlock + + get_warp_local_1d_id() % NWave * NPerXdl + threadIdx.x % NPerXdl; + } + else + { + p_scale_b += expert_id; + } + + vector_type scale_token_ids; + vector_type topk_weights; + static_for<0, NXdlPerWave, 1>{}([&](auto n0) { + const float scale_b = p_scale_b[n0 * NWave * NPerXdl * PerTokenQuant]; + static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave + static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk + const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 + + m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4; + if constexpr(PerTokenQuant) + { + scale_token_ids = + *c_style_pointer_cast*>( + p_sorted_token_ids + m_pos); + } + if constexpr(MulRoutedWeight) + { + topk_weights = *c_style_pointer_cast*>( + p_ds_grid[I2] + m_pos); + } + static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size + float scale_a = [&]() { + if constexpr(PerTokenQuant) + { + index_t fused_token = scale_token_ids.AsType()[m4]; + const index_t token_offset = fused_token & 0xffffff; + return token_offset < problem.NumTokens + ? p_sorted_weights_0[token_offset] + : 0.0; + } + else + { + return p_sorted_weights_0[0]; + } + }(); + constexpr index_t c_offset = + blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset( + make_tuple(m0, n0, m2 * M4 + m4)); + constexpr auto cidx = Number{}; + if constexpr(IsInputGemm) // gu fusion + { + if constexpr(ActivationOperation == Activation::silu_and_mul) + { + const float scale_up = + p_scale_b[(n0 * NWave * NPerXdl + problem.N) * + PerTokenQuant]; + float gate = scale_a * scale_b * c_thread_buf[cidx]; + float up = scale_a * scale_up * c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.AsType()[m4]; + up = up * topk_weights.AsType()[m4]; + } + if constexpr(is_same_v, pk_i4_t>) + { + gate *= 16; + up *= 16; + } + tensor_operation::element_wise::Silu{}(gate, gate); + c_thread_buf_fp32(cidx) = gate * up; + } + else if(ActivationOperation == Activation::gelu_and_mul) + { + const float scale_up = + p_scale_b[(n0 * NWave * NPerXdl + problem.N) * + PerTokenQuant]; + float gate = scale_a * scale_b * c_thread_buf[cidx]; + float up = scale_a * scale_up * c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.AsType()[m4]; + up = up * topk_weights.AsType()[m4]; + } + if constexpr(is_same_v, pk_i4_t>) + { + gate *= 16; + up *= 16; + } + tensor_operation::element_wise::Gelu{}(gate, gate); + c_thread_buf_fp32(cidx) = gate * up; + } + } + else + { + c_thread_buf_fp32(cidx) = + scale_a * scale_b * c_thread_buf[cidx]; + if constexpr(MulRoutedWeight) + { + c_thread_buf_fp32(cidx) = c_thread_buf_fp32(cidx) * + topk_weights.AsType()[m4]; + } + } + }); + }); + }); + }); + } + else + { + vector_type topk_weights; // for gemm2 only + static_for<0, NXdlPerWave, 1>{}([&](auto n0) { + static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave + static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk + const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 + + m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4; + if constexpr(MulRoutedWeight) + { + topk_weights = *c_style_pointer_cast*>( + p_ds_grid[I2] + m_pos); + } + static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size + constexpr index_t c_offset = + blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset( + make_tuple(m0, n0, m2 * M4 + m4)); + constexpr auto cidx = Number{}; + + if constexpr(IsInputGemm) // gu fusion + { + if constexpr(ActivationOperation == Activation::silu_and_mul) + { + float gate = c_thread_buf[cidx]; + float up = c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.AsType()[m4]; + up = up * topk_weights.AsType()[m4]; + } + tensor_operation::element_wise::Silu{}(gate, gate); + c_thread_buf_fp32(cidx) = gate * up; + } + else if(ActivationOperation == Activation::gelu_and_mul) + { + float gate = c_thread_buf[cidx]; + float up = c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.AsType()[m4]; + up = up * topk_weights.AsType()[m4]; + } + tensor_operation::element_wise::Gelu{}(gate, gate); + c_thread_buf_fp32(cidx) = gate * up; + } + } + else + { + c_thread_buf_fp32(cidx) = c_thread_buf[cidx]; + if constexpr(MulRoutedWeight) + { + c_thread_buf_fp32(cidx) = topk_weights.AsType()[m4] * + c_thread_buf_fp32[cidx]; + } + } + }); + }); + }); + }); + } + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); @@ -2171,18 +2401,8 @@ struct GridwiseMoeGemm const auto ds_grid_buf = generate_tuple( [&](auto i) { - using DDataType = remove_cvref_t>; - const DDataType* ptr_ = p_ds_grid[i]; - // hack logic here to support different kind of strides. todo fix it. - // ascale t, 1; bscale E, N, 1, move ptr to E - // if(i.value == 1) - // { - // ptr_ += - // expert_id * (problem.StrideDs[1] ? problem.StrideDs[1] * problem.N : - // 1); - // } return make_dynamic_buffer( - ptr_, ds_grid_desc_m_n[i].GetElementSpaceSize()); + p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize()); }, Number{}); @@ -2258,7 +2478,6 @@ struct GridwiseMoeGemm auto c_grid_buf = make_dynamic_buffer( p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - // space filling curve for threadwise C in VGPR constexpr auto sfc_c_vgpr = SpaceFillingCurve, Sequence<0, 1, 2, 3, 4, 5, 6, 7>, @@ -2297,7 +2516,7 @@ struct GridwiseMoeGemm block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1); static_for<0, EMRepeats, 1>{}([&](auto m0) { const index_t fused_token = p_sorted_token_ids[c_token_pos + m0]; - index_t token_offset = fused_token & 0xffffff; + IndexType token_offset = fused_token & 0xffffff; if constexpr(IsInputGemm) { token_offset = token_offset * problem.TopK + (fused_token >> 24); @@ -2310,7 +2529,7 @@ struct GridwiseMoeGemm // each thread write its data from VGPR to LDS c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, sfc_c_vgpr.GetIndexTupleOfNumber(access_id), - c_thread_buf, + c_thread_buf_fp32, c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, c_shuffle_block_buf);