diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v3.hpp new file mode 100644 index 0000000000..c2fad6132f --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v3.hpp @@ -0,0 +1,973 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp" + +namespace ck { + +// Compute optimized pipeline +// GlobalPrefetchStages: 2 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 1 +// LocalSharedMemoryBuffer: 1 + +template +struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3 +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3 + : BlockwiseGemmXdlops_pipeline_base + +{ + using Base = BlockwiseGemmXdlops_pipeline_base; + using Base::A_K1; + using Base::B_K1; + using Base::I0; + using Base::I1; + using Base::I2; + using Base::KGroup; + using Base::KRepeat; + using Base::xdlops_gemm; + using typename Base::HotLoopInstList; + + using Base::a_block_desc_m0_m1_m2_k; + using Base::CalculateCThreadOriginDataIndex; + using Base::CalculateCThreadOriginDataIndex8D; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::AMmaKStride; + using Base::BMmaKStride; + + using Base::MWaves; + + static constexpr index_t PrefetchStages = 2; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 1; + static constexpr index_t HotloopLocalBufSwitch = MRepeat % 2 == 0 ? 0 : 1; + + template + __host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&) + { + constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{}); + constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{}); + constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{}); + constexpr index_t K2 = KPack / KGroup; + constexpr index_t K1 = 64 / NPerXDL; + constexpr index_t K0 = KRepeat * KGroup; + + return transform_tensor_descriptor( + TileDesc_M0_M1_M2_K{}, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4, 5>{})); + } + + static constexpr auto a_block_desc_m0_m1_m2_k0_k1_k2 = + MakeAGemmMmaTileDescriptor(a_block_desc_m0_m1_m2_k); + + __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd; + } + + __device__ static constexpr auto HotLoopScheduler() + { + // A/B split schedule + // compiler is likely to use ds_read2 when instruction width smaller than 16bytes + constexpr auto num_ds_read_inst_a = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 + ? HotLoopInstList::A_LDS_Read_Inst_Num + : HotLoopInstList::A_LDS_Read_Inst_Num / 2; + + constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num; + + constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num * 2; + + static_assert(num_buffer_load_inst_a == num_ds_write_inst_a); + + constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num * 2; + constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle; + + constexpr auto ds_read_a_issue_cycle = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4; + constexpr auto ds_read_a_mfma_rate = + math::integer_divide_ceil(mfma_cycle - 4, 2 * ds_read_a_issue_cycle); + + // constexpr auto num_dsread_a_mfma = + // (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate; + + constexpr auto num_total_stages = MRepeat; + + // Group num_mfma_perstage num_ds_read_a_perstage + // since we want to reuse a local register buffer + constexpr auto num_mfma_perstage = num_mfma_inst / num_total_stages; + constexpr auto num_ds_read_a_perstage = num_ds_read_inst_a / num_total_stages; + + constexpr auto num_ds_read_a_mfma_perstage = + math::integer_divide_ceil(num_ds_read_a_perstage, ds_read_a_mfma_rate); + + constexpr auto num_ds_read_a_prefetch_stages = 2; + + constexpr auto buffer_load_perstage_more = math::integer_divide_ceil( + (num_buffer_load_inst_a + num_buffer_load_inst_b), (num_total_stages - 2)); + constexpr auto buffer_load_perstage_less = math::integer_divide_floor( + (num_buffer_load_inst_a + num_buffer_load_inst_b), (num_total_stages - 2)); + + constexpr auto buffer_load_stages_more = + (num_buffer_load_inst_a + num_buffer_load_inst_b) - + math::integer_divide_floor((num_buffer_load_inst_a + num_buffer_load_inst_b), + (num_total_stages - 2)) * + ((num_total_stages - 2)); + + constexpr auto buffer_load_b_stages = + buffer_load_perstage_more * buffer_load_stages_more > num_buffer_load_inst_b + ? num_buffer_load_inst_b / buffer_load_perstage_more + : (buffer_load_stages_more + + (num_buffer_load_inst_b - buffer_load_perstage_more * buffer_load_stages_more) / + buffer_load_perstage_less); + + constexpr auto buffer_load_a_stages = + num_total_stages - num_ds_read_a_prefetch_stages - buffer_load_b_stages; + + constexpr auto buffer_load_issue_point_b = 0; + constexpr auto buffer_load_issue_point_interval_more = + num_mfma_perstage / buffer_load_perstage_more; + constexpr auto buffer_load_issue_point_interval_less = + num_mfma_perstage / buffer_load_perstage_less; + constexpr auto ds_write_issue_point = 0; + constexpr auto buffer_load_issue_point_a = num_mfma_perstage >= 3 ? 1 : 0; + + // B global read + static_for<0, buffer_load_b_stages, 1>{}([&](auto i) { + static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + + if constexpr(((i < buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_more == + buffer_load_issue_point_b)) || + ((i >= buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_less == + buffer_load_issue_point_b))) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + }); + }); + + // A global read + A local write + static_for<0, buffer_load_a_stages, 1>{}([&](auto i) { + static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_more == + ds_write_issue_point)) || + (((i + buffer_load_b_stages) >= buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_less == + ds_write_issue_point))) + { + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + } + if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_more == + buffer_load_issue_point_a)) || + (((i + buffer_load_b_stages) >= buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_less == + buffer_load_issue_point_a))) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + }); + }); + + // lds synchronization, prefetch next loop local A + static_for<0, num_ds_read_a_prefetch_stages, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + }); + }); + } + + template + __device__ static constexpr auto EpilogueScheduler_1(Stage stage) + { + constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num; + constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num; + constexpr auto num_buffer_load_inst_b = + MWaves * HotLoopInstList::B_Buffer_Load_Inst_Num * 2; + + constexpr auto num_mfma = HotLoopInstList::C_MFMA_Inst_Num * 2; + + constexpr auto staged_num_ds_read_inst_a = num_ds_read_inst_a / MRepeat; + constexpr auto staged_num_mfma = num_mfma / MRepeat; + + constexpr auto staged_num_mfma_per_ds_read_a = staged_num_mfma / staged_num_ds_read_inst_a; + + if constexpr(stage.value == 0) + { + constexpr auto staged_num_buffer_load_b_per_ds_read_a = + num_buffer_load_inst_b / staged_num_ds_read_inst_a; + constexpr auto staged_num_mfma_per_buffer_load_b = + staged_num_mfma / num_buffer_load_inst_b; + // B global + static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) { + ignore = i_inst; + + static_for<0, staged_num_buffer_load_b_per_ds_read_a, 1>{}([&](auto ibuf_inst) { + ignore = ibuf_inst; + __builtin_amdgcn_sched_group_barrier( + 0x008, staged_num_mfma_per_buffer_load_b, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + }); + + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier( + 0x008, staged_num_mfma_per_buffer_load_b - 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + }); + + __builtin_amdgcn_sched_barrier(0); + } + else if constexpr(stage.value == 1) + { +#if 0 + constexpr auto staged_num_ds_write_a_per_ds_read_a = + num_ds_write_inst_a / staged_num_ds_read_inst_a; + constexpr auto staged_num_mfma_per_ds_write_a = staged_num_mfma / num_ds_write_inst_a; + // A local write + static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) { + ignore = i_inst; + + static_for<0, staged_num_ds_write_a_per_ds_read_a, 1>{}([&](auto idswrite_inst) { + ignore = idswrite_inst; + __builtin_amdgcn_sched_group_barrier( + 0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write + }); + + __builtin_amdgcn_sched_group_barrier( + 0x008, staged_num_ds_write_a_per_ds_read_a, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); +#elif 1 + constexpr auto staged_num_mfma_per_ds_write_a = + math::integer_divide_ceil(staged_num_mfma, num_ds_write_inst_a); + + constexpr auto stage_more_mfma = + staged_num_mfma - (staged_num_mfma_per_ds_write_a - 1) * num_ds_write_inst_a; + + // A local write + static_for<0, num_ds_write_inst_a, 1>{}([&](auto i_inst) { + if constexpr(i_inst.value < stage_more_mfma) + { + if(i_inst.value < staged_num_ds_read_inst_a) + { + __builtin_amdgcn_sched_group_barrier( + 0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier( + 0x008, staged_num_mfma_per_ds_write_a, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write + } + } + else + { + if(i_inst.value < staged_num_ds_read_inst_a) + { + __builtin_amdgcn_sched_group_barrier( + 0x008, staged_num_mfma_per_ds_write_a - 2, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier( + 0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write + } + } + }); +#endif + __builtin_amdgcn_sched_barrier(0); + } + else + { + // A local Read + static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) { + ignore = i_inst; + __builtin_amdgcn_sched_group_barrier( + 0x008, staged_num_mfma_per_ds_read_a, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + + __builtin_amdgcn_sched_barrier(0); + } + } + + __device__ static constexpr auto EpilogueScheduler_2() + { + constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num; + + constexpr auto num_mfma = HotLoopInstList::C_MFMA_Inst_Num * 2; + + constexpr auto staged_num_ds_read_inst_a = num_ds_read_inst_a / MRepeat; + constexpr auto staged_num_mfma = num_mfma / MRepeat; + + constexpr auto staged_num_mfma_per_ds_read_a = staged_num_mfma / staged_num_ds_read_inst_a; + + // A local Read + static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) { + ignore = i_inst; + __builtin_amdgcn_sched_group_barrier(0x008, staged_num_mfma_per_ds_read_a, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + + __builtin_amdgcn_sched_barrier(0); + } + + template + __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + BBlockTransfer& b_blockwise_copy, + BBlockTransfer& b_blockwise_copy_up, + const BGridBuffer& b_grid_buf, + const BGridBuffer& b_grid_buf_up, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + CThreadBuffer& c_thread_buf, + CThreadBuffer& c_thread_buf_up, + index_t num_loop) const + { + ignore = b_block_buf; + __builtin_amdgcn_sched_barrier(0); + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + StaticallyIndexedArray{}> b_thread_bufs; + StaticallyIndexedArray{}> b_thread_bufs_up; + constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0); + + // Global prefetch A1 B1 + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I0)); + + b_blockwise_copy_up.Run(b_grid_desc, + b_grid_buf_up, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs_up(I0)); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + __builtin_amdgcn_sched_barrier(0); + + // // Local prefill A1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I0)); + + // // Global prefetch A2 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + + // Local prefetch A1 + block_sync_lds(); + static_for<0, 2, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); + }); + }); + + // Initialize C + c_thread_buf.Clear(); + c_thread_buf_up.Clear(); + + __builtin_amdgcn_sched_barrier(0); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + do + { + auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) { + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(local_read_buf)); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + b_blockwise_copy_up.Run(b_grid_desc, + b_grid_buf_up, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs_up(local_read_buf)); + b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(local_read_buf)); + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf] + [Number{}]; + + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up[mfma_reg_buf] + [Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); + }); + }); + + if constexpr(m0.value == MRepeat - 2) + { + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(local_read_buf), + a_thread_desc_, + make_tuple( + Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % + 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); + }); + }); + } + else if constexpr(m0.value == (MRepeat - 1)) + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(local_read_buf), + a_thread_desc_, + make_tuple( + Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % + 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); + }); + }); + } + else + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(mfma_reg_buf), + a_thread_desc_, + make_tuple( + Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % + 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); + }); + }); + } + }); + HotLoopScheduler(); + }; + + LoopFunc(I0, I1); + LoopFunc(I1, I0); + + i += 2; + } while(i < (num_loop - 2)); + } + // tail + if constexpr(TailNum == TailNumber::Even) + { + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I1)); + + b_blockwise_copy_up.Run(b_grid_desc, + b_grid_buf_up, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs_up(I1)); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1)); + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; + + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up[I0][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); + }); + }); + if constexpr(m0.value == (MRepeat - 2)) + { + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(I1), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); + }); + } + else if constexpr(m0.value == MRepeat - 1) + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(I1), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); + }); + } + else + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); + }); + } + }); + + HotLoopScheduler(); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I1][Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up[I1][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); + }); + }); + + if constexpr(m0.value < (MRepeat - 2)) + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple( + Number{}, I0, I0, Number{}, I0, I0), + a_block_buf.At(I1), + a_thread_desc_, + make_tuple(Number<(m0 + 2 + HotloopLocalBufSwitch) % 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); + }); + }); + } + }); + + HotLoopScheduler(); + // Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle + // latency + } + else if constexpr(TailNum == TailNumber::Odd) + { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up[I0][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); + }); + }); + + if constexpr(m0.value < (MRepeat - 2)) + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple( + Number{}, I0, I0, Number{}, I0, I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); + }); + } + }); + } + } + + protected: + // MRepeat MWave MLane KRepeat KLane KPack + // KRepeat -> MRepeat-> Mwave->KLane->MLane->KPack + // Reduce the vgpr usage here. + static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(I2, I1, I1, Number{}, I1, Number{})); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + A_K1, + A_K1>; + + AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex6D()}; + + static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, Number{}, Number{})); + + static constexpr BTileDesc b_block_desc_n0_n1_k0_k1; + + using Base::c_thread_desc_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp index 6d115e7620..02da036d10 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -123,6 +123,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3{}); constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{}); constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{}); - constexpr index_t K2 = KPack; + constexpr index_t K2 = KPack / KGroup; constexpr index_t K1 = 64 / NPerXDL; - constexpr index_t K0 = KRepeat; + constexpr index_t K0 = KRepeat * KGroup; return transform_tensor_descriptor( TileDesc_M0_M1_M2_K{}, @@ -184,298 +185,230 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3 - __device__ static constexpr auto HotLoopScheduler(Stage stage) + __device__ static constexpr auto HotLoopScheduler() { - constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num; - constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num; - constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num; - constexpr auto num_buffer_load_inst_b = MWaves * HotLoopInstList::B_Buffer_Load_Inst_Num; - - constexpr auto num_mfma = HotLoopInstList::C_MFMA_Inst_Num; - - constexpr auto staged_num_ds_read_inst_a = num_ds_read_inst_a / MRepeat; - constexpr auto staged_num_mfma = num_mfma / MRepeat; - - constexpr auto staged_num_mfma_per_ds_read_a = staged_num_mfma / staged_num_ds_read_inst_a; - - if constexpr(stage.value == 0) - { - constexpr auto staged_num_buffer_load_b_per_ds_read_a = - num_buffer_load_inst_b / staged_num_ds_read_inst_a; - constexpr auto staged_num_mfma_per_buffer_load_b = - staged_num_mfma / num_buffer_load_inst_b; - // B global - static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) { - ignore = i_inst; - - static_for<0, staged_num_buffer_load_b_per_ds_read_a - 1, 1>{}([&](auto ibuf_inst) { - ignore = ibuf_inst; - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_buffer_load_b, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - }); - - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_buffer_load_b - 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - }); - - __builtin_amdgcn_sched_barrier(0); - } - else if constexpr(stage.value == 1) - { - constexpr auto staged_num_mfma_per_ds_write_a = - math::integer_divide_ceil(staged_num_mfma, num_ds_write_inst_a); - - constexpr auto stage_more_mfma = - staged_num_mfma - (staged_num_mfma_per_ds_write_a - 1) * num_ds_write_inst_a; - - // A local write - static_for<0, num_ds_write_inst_a, 1>{}([&](auto i_inst) { - if constexpr(i_inst.value < stage_more_mfma) - { - if(i_inst.value < staged_num_ds_read_inst_a) - { - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - } - else - { - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_ds_write_a, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write - } - } - else - { - if(i_inst.value < staged_num_ds_read_inst_a) - { - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_ds_write_a - 2, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - } - else - { - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write - } - } - }); - - __builtin_amdgcn_sched_barrier(0); - } - else if constexpr(stage.value == 2) - { - constexpr auto staged_num_mfma_per_buffer_load_a = - math::integer_divide_ceil(staged_num_mfma, num_buffer_load_inst_a); - - constexpr auto stage_more_mfma = - staged_num_mfma - (staged_num_mfma_per_buffer_load_a - 1) * num_buffer_load_inst_a; - - // A global - static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i_inst) { - if constexpr(i_inst.value < stage_more_mfma) - { - if(i_inst.value < staged_num_ds_read_inst_a) - { - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_buffer_load_a - 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - } - else - { - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_buffer_load_a, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - } - } - else - { - if(i_inst.value < staged_num_ds_read_inst_a) - { - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_buffer_load_a - 2, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - } - else - { - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_buffer_load_a - 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - } - } - }); - - __builtin_amdgcn_sched_barrier(0); - } - else - { - // A local Read - static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) { - ignore = i_inst; - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_ds_read_a, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - }); - - __builtin_amdgcn_sched_barrier(0); - } - } - - template - __device__ static constexpr auto EpilogueScheduler_1(Stage stage) - { - constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num; - constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num; - constexpr auto num_buffer_load_inst_b = MWaves * HotLoopInstList::B_Buffer_Load_Inst_Num; - - constexpr auto num_mfma = HotLoopInstList::C_MFMA_Inst_Num; - - constexpr auto staged_num_ds_read_inst_a = num_ds_read_inst_a / MRepeat; - constexpr auto staged_num_mfma = num_mfma / MRepeat; - - constexpr auto staged_num_mfma_per_ds_read_a = staged_num_mfma / staged_num_ds_read_inst_a; - - if constexpr(stage.value == 0) - { - constexpr auto staged_num_buffer_load_b_per_ds_read_a = - num_buffer_load_inst_b / staged_num_ds_read_inst_a; - constexpr auto staged_num_mfma_per_buffer_load_b = - staged_num_mfma / num_buffer_load_inst_b; - // B global - static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) { - ignore = i_inst; - - static_for<0, staged_num_buffer_load_b_per_ds_read_a, 1>{}([&](auto ibuf_inst) { - ignore = ibuf_inst; - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_buffer_load_b, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - }); - - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_buffer_load_b - 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - }); - - __builtin_amdgcn_sched_barrier(0); - } - else if constexpr(stage.value == 1) - { #if 0 - constexpr auto staged_num_ds_write_a_per_ds_read_a = - num_ds_write_inst_a / staged_num_ds_read_inst_a; - constexpr auto staged_num_mfma_per_ds_write_a = staged_num_mfma / num_ds_write_inst_a; - // A local write - static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) { - ignore = i_inst; + // A/B split schedule + // compiler is likely to use ds_read2 when instruction width smaller than 16bytes + constexpr auto num_ds_read_inst_a = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 + ? HotLoopInstList::A_LDS_Read_Inst_Num + : HotLoopInstList::A_LDS_Read_Inst_Num / 2; - static_for<0, staged_num_ds_write_a_per_ds_read_a, 1>{}([&](auto idswrite_inst) { - ignore = idswrite_inst; - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write - }); + constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num; - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_ds_write_a_per_ds_read_a, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - }); + constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num; + + constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num; + constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle; + + constexpr auto ds_read_a_issue_cycle = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4; + constexpr auto ds_read_a_mfma_rate = + math::integer_divide_ceil(mfma_cycle - 4, 2 * ds_read_a_issue_cycle); + + // constexpr auto num_dsread_a_mfma = + // (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate; + + constexpr auto num_stages = MRepeat; + + // Group num_mfma_perstage num_ds_read_a_perstage + // since we want to reuse a local register buffer + constexpr auto num_mfma_perstage = num_mfma_inst / num_stages; + constexpr auto num_ds_read_a_perstage = num_ds_read_inst_a / num_stages; + + constexpr auto num_ds_read_a_mfma_perstage = + math::integer_divide_ceil(num_ds_read_a_perstage, ds_read_a_mfma_rate); + + constexpr auto num_mfma_per_issue_more = math::integer_divide_ceil( + num_mfma_inst, num_buffer_load_inst_a + num_buffer_load_inst_b); + constexpr auto num_mfma_per_issue_less = math::integer_divide_floor( + num_mfma_inst, num_buffer_load_inst_a + num_buffer_load_inst_b); + // Insert more mfmas between bufferloads + constexpr auto num_stage1_bufferloads = + num_mfma_inst - + (num_buffer_load_inst_a + num_buffer_load_inst_b) * num_mfma_per_issue_less; + constexpr auto num_stage1_mfma = num_mfma_per_issue_more * num_stage1_bufferloads; + // Insert less mfmas between bufferloads + // constexpr auto num_stage2_mfma = num_mfma_inst - num_stage1_mfma; + + constexpr auto buffer_load_issue_point = 0; + constexpr auto ds_write_issue_point_stage1 = num_mfma_per_issue_more >= 3 ? 1 : 0; + constexpr auto ds_write_issue_point_stage2 = num_mfma_per_issue_less >= 3 ? 1 : 0; + + static_for<0, num_mfma_inst, 1>{}([&](auto i) { + constexpr auto current_buffer_load_issue = + i < num_stage1_mfma + ? (i / num_mfma_per_issue_more) + : (num_stage1_bufferloads + (i - num_stage1_mfma) / num_mfma_per_issue_less); + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + + // Group num_mfma_perstage num_ds_read_a_perstage + // Hide A lds rd issue latency at begining of each stage + if constexpr((i % num_mfma_perstage) >= + (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + + // Schedule VMEM access instruction distributed evenly in the loop + // Hide B/A global rd issue latency + if constexpr(((i < num_stage1_mfma) && + (i % num_mfma_per_issue_more == buffer_load_issue_point)) || + ((i >= num_stage1_mfma) && + ((i - num_stage1_mfma) % num_mfma_per_issue_less == + buffer_load_issue_point))) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + + // Hide A lds wr issue latency + if constexpr((current_buffer_load_issue >= num_buffer_load_inst_b) && + ((((i < num_stage1_mfma) && + (i % num_mfma_per_issue_more == ds_write_issue_point_stage1)) || + ((i >= num_stage1_mfma) && + ((i - num_stage1_mfma) % num_mfma_per_issue_less == + ds_write_issue_point_stage2))) && + (((i < num_stage1_mfma) && + ((i / num_mfma_per_issue_more - num_buffer_load_inst_b) < num_ds_write_inst_a)) || + ((i >= num_stage1_mfma) && + ((i - num_stage1_mfma) / num_mfma_per_issue_less + + num_stage1_bufferloads - num_buffer_load_inst_b) < num_ds_write_inst_a)))) + { + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + } + }); #elif 1 - constexpr auto staged_num_mfma_per_ds_write_a = - math::integer_divide_ceil(staged_num_mfma, num_ds_write_inst_a); + // A/B split schedule + // compiler is likely to use ds_read2 when instruction width smaller than 16bytes + constexpr auto num_ds_read_inst_a = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 + ? HotLoopInstList::A_LDS_Read_Inst_Num + : HotLoopInstList::A_LDS_Read_Inst_Num / 2; - constexpr auto stage_more_mfma = - staged_num_mfma - (staged_num_mfma_per_ds_write_a - 1) * num_ds_write_inst_a; + constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num; - // A local write - static_for<0, num_ds_write_inst_a, 1>{}([&](auto i_inst) { - if constexpr(i_inst.value < stage_more_mfma) + constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num; + + static_assert(num_buffer_load_inst_a == num_ds_write_inst_a); + + constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num; + constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle; + + constexpr auto ds_read_a_issue_cycle = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4; + constexpr auto ds_read_a_mfma_rate = + math::integer_divide_ceil(mfma_cycle - 4, 2 * ds_read_a_issue_cycle); + + // constexpr auto num_dsread_a_mfma = + // (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate; + + constexpr auto num_total_stages = MRepeat; + + // Group num_mfma_perstage num_ds_read_a_perstage + // since we want to reuse a local register buffer + constexpr auto num_mfma_perstage = num_mfma_inst / num_total_stages; + constexpr auto num_ds_read_a_perstage = num_ds_read_inst_a / num_total_stages; + + constexpr auto num_ds_read_a_mfma_perstage = + math::integer_divide_ceil(num_ds_read_a_perstage, ds_read_a_mfma_rate); + + constexpr auto num_ds_read_a_prefetch_stages = 2; + + constexpr auto buffer_load_perstage_more = math::integer_divide_ceil( + (num_buffer_load_inst_a + num_buffer_load_inst_b), (num_total_stages - 2)); + constexpr auto buffer_load_perstage_less = math::integer_divide_floor( + (num_buffer_load_inst_a + num_buffer_load_inst_b), (num_total_stages - 2)); + + constexpr auto buffer_load_stages_more = + (num_buffer_load_inst_a + num_buffer_load_inst_b) - + math::integer_divide_floor((num_buffer_load_inst_a + num_buffer_load_inst_b), + (num_total_stages - 2)) * + ((num_total_stages - 2)); + + constexpr auto buffer_load_b_stages = + buffer_load_perstage_more * buffer_load_stages_more > num_buffer_load_inst_b + ? num_buffer_load_inst_b / buffer_load_perstage_more + : (buffer_load_stages_more + + (num_buffer_load_inst_b - buffer_load_perstage_more * buffer_load_stages_more) / + buffer_load_perstage_less); + + constexpr auto buffer_load_a_stages = + num_total_stages - num_ds_read_a_prefetch_stages - buffer_load_b_stages; + + constexpr auto buffer_load_issue_point_b = 0; + constexpr auto buffer_load_issue_point_interval_more = + num_mfma_perstage / buffer_load_perstage_more; + constexpr auto buffer_load_issue_point_interval_less = + num_mfma_perstage / buffer_load_perstage_less; + constexpr auto ds_write_issue_point = 0; + constexpr auto buffer_load_issue_point_a = num_mfma_perstage >= 3 ? 1 : 0; + + // B global read + static_for<0, buffer_load_b_stages, 1>{}([&](auto i) { + static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + + if constexpr(((i < buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_more == + buffer_load_issue_point_b)) || + ((i >= buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_less == + buffer_load_issue_point_b))) { - if(i_inst.value < staged_num_ds_read_inst_a) - { - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - } - else - { - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_ds_write_a, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write - } + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read } - else + + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) { - if(i_inst.value < staged_num_ds_read_inst_a) - { - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_ds_write_a - 2, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - } - else - { - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write - } + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read } }); -#endif - __builtin_amdgcn_sched_barrier(0); - } - else - { - // A local Read - static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) { - ignore = i_inst; - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_ds_read_a, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - }); - - __builtin_amdgcn_sched_barrier(0); - } - } - - __device__ static constexpr auto EpilogueScheduler_2() - { - constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num; - - constexpr auto num_mfma = HotLoopInstList::C_MFMA_Inst_Num; - - constexpr auto staged_num_ds_read_inst_a = num_ds_read_inst_a / MRepeat; - constexpr auto staged_num_mfma = num_mfma / MRepeat; - - constexpr auto staged_num_mfma_per_ds_read_a = staged_num_mfma / staged_num_ds_read_inst_a; - - // A local Read - static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) { - ignore = i_inst; - __builtin_amdgcn_sched_group_barrier(0x008, staged_num_mfma_per_ds_read_a, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read }); - __builtin_amdgcn_sched_barrier(0); + // A global read + A local write + static_for<0, buffer_load_a_stages, 1>{}([&](auto i) { + static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_more == + ds_write_issue_point)) || + (((i + buffer_load_b_stages) >= buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_less == + ds_write_issue_point))) + { + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + } + if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_more == + buffer_load_issue_point_a)) || + (((i + buffer_load_b_stages) >= buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_less == + buffer_load_issue_point_a))) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + }); + }); + + // lds synchronization, prefetch next loop local A + static_for<0, num_ds_read_a_prefetch_stages, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + }); + }); +#endif } template {}([&](auto k0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(I0, I0, I0, k0, I0, I0), - a_block_buf.At(I0), - a_thread_desc_, - make_tuple(I0, I0, I0, k0, I0, I0), - a_thread_buf); + static_for<0, 2, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); + }); }); // Initialize C @@ -558,26 +495,18 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3{}([&](auto m0) { - if constexpr(m0.value == 0) - { - b_blockwise_copy.Run(b_grid_desc, - b_grid_buf, - b_block_desc_n0_n1_k0_k1, - b_block_origin_idx, - b_thread_bufs(local_read_buf)); - b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); - } - else if constexpr(m0.value == 1) - { - a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(local_read_buf)); - } - else if constexpr(m0.value == 2) - { - a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); - } + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(local_read_buf)); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(local_read_buf)); + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + + static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, KRepeat, 1>{}([&](auto k0) { static_for<0, NRepeat, 1>{}([&](auto n0) { vector_type a_thread_vec; @@ -613,49 +542,88 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3{}([&](auto k0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number<(m0 + 1) % MRepeat>{}, I0, I0, k0, I0, I0), - a_block_buf.At(local_read_buf), - a_thread_desc_, - make_tuple( - Number<(m0 + 1 + HotloopLocalBufSwitch * mfma_reg_buf) % - 2>{}, - I0, - I0, - k0, - I0, - I0), - a_thread_buf); + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(local_read_buf), + a_thread_desc_, + make_tuple( + Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % + 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); + }); + }); + } + else if constexpr(m0.value == (MRepeat - 1)) + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(local_read_buf), + a_thread_desc_, + make_tuple( + Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % + 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); + }); }); } else { static_for<0, KRepeat, 1>{}([&](auto k0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number<(m0 + 1) % MRepeat>{}, I0, I0, k0, I0, I0), - a_block_buf.At(mfma_reg_buf), - a_thread_desc_, - make_tuple( - Number<(m0 + 1 + HotloopLocalBufSwitch * mfma_reg_buf) % - 2>{}, - I0, - I0, - k0, - I0, - I0), - a_thread_buf); + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(mfma_reg_buf), + a_thread_desc_, + make_tuple( + Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % + 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); + }); }); } - - HotLoopScheduler(m0); }); + HotLoopScheduler(); }; LoopFunc(I0, I1); @@ -667,20 +635,14 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3{}([&](auto m0) { - if constexpr(m0.value == 0) - { - b_blockwise_copy.Run(b_grid_desc, - b_grid_buf, - b_block_desc_n0_n1_k0_k1, - b_block_origin_idx, - b_thread_bufs(I1)); - } - else if constexpr(m0.value == MRepeat - 1) - { - a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1)); - } + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I1)); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1)); + static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, KRepeat, 1>{}([&](auto k0) { static_for<0, NRepeat, 1>{}([&](auto n0) { vector_type a_thread_vec; @@ -707,36 +669,72 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3{}([&](auto k0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number<(m0 + 1) % MRepeat>{}, I0, I0, k0, I0, I0), - a_block_buf.At(I1), - a_thread_desc_, - make_tuple(Number<(m0 + 1) % 2>{}, I0, I0, k0, I0, I0), - a_thread_buf); + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(I1), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); + }); + } + else if constexpr(m0.value == (MRepeat - 1)) + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(I1), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); }); } else { static_for<0, KRepeat, 1>{}([&](auto k0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number<(m0 + 1) % MRepeat>{}, I0, I0, k0, I0, I0), - a_block_buf.At(I0), - a_thread_desc_, - make_tuple(Number<(m0 + 1) % 2>{}, I0, I0, k0, I0, I0), - a_thread_buf); + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); }); } - - EpilogueScheduler_1(m0); }); + HotLoopScheduler(); + static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, KRepeat, 1>{}([&](auto k0) { static_for<0, NRepeat, 1>{}([&](auto n0) { @@ -764,25 +762,31 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3{}([&](auto k0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number{}, I0, I0, k0, I0, I0), - a_block_buf.At(I1), - a_thread_desc_, - make_tuple( - Number<(m0 + 1 + HotloopLocalBufSwitch) % 2>{}, I0, I0, k0, I0, I0), - a_thread_buf); + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple( + Number{}, I0, I0, Number{}, I0, I0), + a_block_buf.At(I1), + a_thread_desc_, + make_tuple(Number<(m0 + 2 + HotloopLocalBufSwitch) % 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); + }); }); - - EpilogueScheduler_2(); } }); + + HotLoopScheduler(); // Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle // latency - // __builtin_amdgcn_sched_barrier(0); } else if constexpr(TailNum == TailNumber::Odd) { @@ -813,18 +817,21 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3{}([&](auto k0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number{}, I0, I0, k0, I0, I0), - a_block_buf.At(I0), - a_thread_desc_, - make_tuple(Number<(m0 + 1) % 2>{}, I0, I0, k0, I0, I0), - a_thread_buf); + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple( + Number{}, I0, I0, Number{}, I0, I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); }); - - EpilogueScheduler_2(); } }); } @@ -841,7 +848,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3, + Sequence<1, 1, 1, 1, 1, KPack / KGroup>, Sequence<0, 1, 2, 3, 4, 5>, 5, A_K1, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp b/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp index 08d177035e..a6110d2bfc 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp @@ -264,77 +264,152 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle= 256) ? 1 : 2; - constexpr auto MemoryDataOp = - IsInputGemm ? InMemoryDataOperationEnum::Set : InMemoryDataOperationEnum::AtomicAdd; - if(has_main_k_block_loop) + if(IsInputGemm || arg.TopK == 1) { - // Tail number always full - if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + constexpr auto MemoryDataOp = InMemoryDataOperationEnum::Set; + + if(has_main_k_block_loop) { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_moe_gemm; + RunKernel(kernel); + } + else + { + const auto kernel = kernel_moe_gemm; + RunKernel(kernel); + } + } + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) { if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) { - const auto kernel = kernel_moe_gemm; + const auto kernel = kernel_moe_gemm_2lds; RunKernel(kernel); } else { - const auto kernel = kernel_moe_gemm; + const auto kernel = kernel_moe_gemm_2lds; RunKernel(kernel); } } - } - else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2 || - BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + else { - const auto kernel = kernel_moe_gemm_2lds; + throw std::runtime_error("todo: only v1 & v2 support now"); + } + } +#if 1 + else + { + // Tail number always 1 + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + const auto kernel = kernel_moe_gemm; RunKernel(kernel); } + } +#endif + } + else + { + constexpr auto MemoryDataOp = InMemoryDataOperationEnum::AtomicAdd; + + if(has_main_k_block_loop) + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_moe_gemm; + RunKernel(kernel); + } + else + { + const auto kernel = kernel_moe_gemm; + RunKernel(kernel); + } + } + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_moe_gemm_2lds; + RunKernel(kernel); + } + else + { + const auto kernel = kernel_moe_gemm_2lds; + RunKernel(kernel); + } + } else { - const auto kernel = kernel_moe_gemm_2lds; + throw std::runtime_error("todo: only v1 & v2 support now"); + } + } +#if 1 + else + { + // Tail number always 1 + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + const auto kernel = kernel_moe_gemm; RunKernel(kernel); } } - else - { - throw std::runtime_error("todo: only v1 & v2 support now"); - } - } -#if 1 - else - { - // Tail number always 1 - if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) - { - const auto kernel = kernel_moe_gemm; - RunKernel(kernel); - } - } #endif - + } return ave_time; } diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp index 255fb8cff4..54bc1a76df 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp @@ -188,7 +188,10 @@ struct GridwiseMoeGemm math::max(math::lcm(AK1Number, BK1Number), mfma_selector::selected_mfma.k_per_blk); static constexpr index_t KLane = mfma_selector::GetKPerXdlops() / mfma_selector::GetK1PerXdlops(); - static constexpr index_t KRepeat = KPerBlock / KLane / KPack; + + static constexpr index_t KGroup = mfma_selector::selected_mfma.k_per_blk == 32 ? 2 : 1; + // static_assert(KGroup == 2, ""); + static constexpr index_t KRepeat = KPerBlock / KLane / (KPack / KGroup); static constexpr index_t NLane = NPerXdl; static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave; // static constexpr index_t NumTokens = 1; @@ -249,7 +252,7 @@ struct GridwiseMoeGemm } __host__ __device__ static auto CalculateBK0Shuffled(index_t K) { - return math::integer_divide_ceil(K, KLane * KPack); + return math::integer_divide_ceil(K, KLane * KPack / KGroup); } __host__ __device__ static auto CalculateKPadded(index_t K) @@ -391,7 +394,7 @@ struct GridwiseMoeGemm __host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0) { - constexpr index_t NkSwizzleNumber = Number{}; + constexpr index_t NkSwizzleNumber = Number{}; return make_naive_tensor_descriptor( make_tuple(N0 / NWave, NWave, K0, NkSwizzleNumber), make_tuple(NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, I1)); @@ -1301,7 +1304,7 @@ struct GridwiseMoeGemm make_multi_index(n_block_data_idx_on_grid, get_warp_local_1d_id() % NWave, 0, - KPack * (get_thread_local_1d_id() % warpSize))); + KPack / KGroup * (get_thread_local_1d_id() % warpSize))); // LDS allocation for A and B: be careful of alignment // Cast after lds @@ -1347,7 +1350,7 @@ struct GridwiseMoeGemm make_multi_index(n_block_data_idx_on_grid, get_warp_local_1d_id() % NWave, 0, - KPack * (get_thread_local_1d_id() % warpSize))); + KPack / KGroup * (get_thread_local_1d_id() % warpSize))); blockwise_gemm_pipeline.template Run( a_grid_desc_ak0_m_ak1, a_block_desc_ak0_m_ak1, @@ -1886,7 +1889,8 @@ struct GridwiseMoeGemm const auto c_grid_desc_mblock_mperblock_nblock_nperblock = MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( c_grid_desc_m_n, problem.MBlock, problem.NBlock); - const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]); + const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]); + // static_assert(NSwizzle == false, "to do fix: need another pr in sorting merged"); const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y; if(expert_block_id * MPerBlock >= max_token_id) return; @@ -1895,12 +1899,13 @@ struct GridwiseMoeGemm const auto block_mn = [&]() -> std::pair { if constexpr(NSwizzle) { - const index_t ecnt_prefix = p_max_token_id[1 + expert_id]; - const index_t prefix_block = ecnt_prefix * problem.NBlock; - const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix; - const index_t expert_swizzle = ecnt > 0 ? ecnt : 1; - const index_t bid_new = blockIdx.x - prefix_block; - const index_t nid = __builtin_amdgcn_readfirstlane( + const index_t ecnt_prefix = p_max_token_id[1 + expert_id]; + const index_t prefix_block = ecnt_prefix * problem.NBlock; + const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix; + const index_t expert_swizzle = + ecnt > 0 ? ecnt : 1; // p_max_token_id[expert_id + 1]; // 2 + const index_t bid_new = blockIdx.x - prefix_block; + const index_t nid = __builtin_amdgcn_readfirstlane( bid_new % 8 + bid_new / (8 * expert_swizzle) * 8); const index_t mid = __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle); @@ -1911,9 +1916,9 @@ struct GridwiseMoeGemm return {blockIdx.x, blockIdx.y}; } }(); + const index_t block_n_id = block_mn.first; const index_t block_m_id = block_mn.second; - const index_t token0 = __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff); @@ -1925,11 +1930,9 @@ struct GridwiseMoeGemm constexpr auto AMRepeats = MPerBlock / AMThreads; const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats; - if(token_pos >= max_token_id || expert_block_id * MPerBlock >= max_token_id || - token0 >= problem.NumTokens) + if(token_pos >= max_token_id || token0 >= problem.NumTokens) return; - StaticallyIndexedArray - gather_offsets; //= p_sorted_token_ids[token_pos]; + StaticallyIndexedArray gather_offsets; static_for<0, AMRepeats, 1>{}([&](auto m0) { const index_t fused_token = p_sorted_token_ids[token_pos + m0]; index_t token_offset = fused_token & 0xffffff; @@ -1939,7 +1942,8 @@ struct GridwiseMoeGemm } gather_offsets(m0) = static_cast(token_offset) * problem.K; }); - const index_t expert_stride = __builtin_amdgcn_readfirstlane(problem.N * problem.K); + const index_t expert_stride = + __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1)); // N0, K0, Blocksize*KPack const index_t n_block_data_idx_on_grid = @@ -1950,7 +1954,6 @@ struct GridwiseMoeGemm const auto b_grid_buf = make_dynamic_buffer( p_b_grid + expert_id * expert_stride / BPackedSize, b_grid_desc_bpreshuffled.GetElementSpaceSize()); - // A matrix in LDS memory, dst of blockwise copy constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); @@ -2012,7 +2015,7 @@ struct GridwiseMoeGemm make_multi_index(n_block_data_idx_on_grid, get_warp_local_1d_id() % NWave, 0, - KPack * (get_thread_local_1d_id() % warpSize))); + KPack / KGroup * (get_thread_local_1d_id() % warpSize))); // LDS allocation for A and B: be careful of alignment // Cast after lds @@ -2029,24 +2032,76 @@ struct GridwiseMoeGemm static_assert(std::is_default_constructible_v); auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + decltype(c_thread_buf) c_thread_buf_up; + + StaticBufferTupleOfVector + c_thread_buf_fp32; const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / KPerBlock); - blockwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, - a_block_desc_ak0_m_ak1, - a_blockwise_copy, - a_grid_buf, - a_block_bufs, - a_block_slice_copy_step, - b_grid_desc_bpreshuffled, - b_blockwise_copy, - b_grid_buf, - b_block_bufs, - b_block_slice_copy_step, - c_thread_buf, - num_k_block_main_loop); + if constexpr(IsInputGemm) + { + const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize; + const auto b_grid_buf_up = make_dynamic_buffer( + p_b_grid_up + expert_id * expert_stride / BPackedSize, + b_grid_desc_bpreshuffled.GetElementSpaceSize()); + auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2< + BDataType, + BDataType, + decltype(b_grid_desc_bpreshuffled), + decltype(b_block_desc_bk0_n_bk1), + Sequence{}, I1, Number{}, Number{}>, + Sequence<1, 2, 0, 3>, + 3, + BBlockTransferSrcScalarPerVector, + BThreadTransferSrcResetCoordinateAfterRun, + true>(b_grid_desc_bpreshuffled, + make_multi_index(n_block_data_idx_on_grid, + get_warp_local_1d_id() % NWave, + 0, + KPack / KGroup * (get_thread_local_1d_id() % warpSize))); + blockwise_gemm_pipeline.template Run( + a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_bufs, + a_block_slice_copy_step, + b_grid_desc_bpreshuffled, + b_blockwise_copy, + b_blockwise_copy_up, + b_grid_buf, + b_grid_buf_up, + b_block_bufs, + b_block_slice_copy_step, + c_thread_buf, + c_thread_buf_up, + num_k_block_main_loop); + } + else + { + + blockwise_gemm_pipeline.template Run( + a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_bufs, + a_block_slice_copy_step, + b_grid_desc_bpreshuffled, + b_blockwise_copy, + b_grid_buf, + b_block_bufs, + b_block_slice_copy_step, + c_thread_buf, + num_k_block_main_loop); + } // shuffle C and write out { @@ -2074,6 +2129,185 @@ struct GridwiseMoeGemm constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + // mul scales + const float* p_sorted_weights_0 = p_ds_grid[I0]; + const float* p_scale_b = p_ds_grid[I1]; + + static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock); + static_assert(M4 == 4); + const index_t m1 = get_warp_local_1d_id() / NWave; + const index_t m3 = threadIdx.x % get_warp_size() / MPerXdl; + + if(p_sorted_weights_0 != nullptr && p_scale_b != nullptr) + { + if constexpr(PerTokenQuant) + { + constexpr index_t scale_stride = (IsInputGemm ? 2 : 1); + p_scale_b += expert_id * problem.N * scale_stride + block_n_id * NPerBlock + + get_warp_local_1d_id() % NWave * NPerXdl + threadIdx.x % NPerXdl; + } + else + { + p_scale_b += expert_id; + } + + vector_type scale_token_ids; + vector_type topk_weights; + static_for<0, NXdlPerWave, 1>{}([&](auto n0) { + const float scale_b = p_scale_b[n0 * NWave * NPerXdl * PerTokenQuant]; + static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave + static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk + const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 + + m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4; + if constexpr(PerTokenQuant) + { + scale_token_ids = + *c_style_pointer_cast*>( + p_sorted_token_ids + m_pos); + } + if constexpr(MulRoutedWeight) + { + topk_weights = *c_style_pointer_cast*>( + p_ds_grid[I2] + m_pos); + } + static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size + float scale_a = [&]() { + if constexpr(PerTokenQuant) + { + index_t fused_token = scale_token_ids.AsType()[m4]; + const index_t token_offset = fused_token & 0xffffff; + return token_offset < problem.NumTokens + ? p_sorted_weights_0[token_offset] + : 0.0; + } + else + { + return p_sorted_weights_0[0]; + } + }(); + constexpr index_t c_offset = + blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset( + make_tuple(m0, n0, m2 * M4 + m4)); + constexpr auto cidx = Number{}; + if constexpr(IsInputGemm) // gu fusion + { + if constexpr(ActivationOperation == Activation::silu_and_mul) + { + const float scale_up = + p_scale_b[(n0 * NWave * NPerXdl + problem.N) * + PerTokenQuant]; + float gate = scale_a * scale_b * c_thread_buf[cidx]; + float up = scale_a * scale_up * c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.AsType()[m4]; + up = up * topk_weights.AsType()[m4]; + } + if constexpr(is_same_v, pk_i4_t>) + { + gate *= 16; + up *= 16; + } + tensor_operation::element_wise::Silu{}(gate, gate); + c_thread_buf_fp32(cidx) = gate * up; + } + else if(ActivationOperation == Activation::gelu_and_mul) + { + const float scale_up = + p_scale_b[(n0 * NWave * NPerXdl + problem.N) * + PerTokenQuant]; + float gate = scale_a * scale_b * c_thread_buf[cidx]; + float up = scale_a * scale_up * c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.AsType()[m4]; + up = up * topk_weights.AsType()[m4]; + } + if constexpr(is_same_v, pk_i4_t>) + { + gate *= 16; + up *= 16; + } + tensor_operation::element_wise::Gelu{}(gate, gate); + c_thread_buf_fp32(cidx) = gate * up; + } + } + else + { + c_thread_buf_fp32(cidx) = + scale_a * scale_b * c_thread_buf[cidx]; + if constexpr(MulRoutedWeight) + { + c_thread_buf_fp32(cidx) = c_thread_buf_fp32(cidx) * + topk_weights.AsType()[m4]; + } + } + }); + }); + }); + }); + } + else + { + vector_type topk_weights; // for gemm2 only + static_for<0, NXdlPerWave, 1>{}([&](auto n0) { + static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave + static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk + const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 + + m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4; + if constexpr(MulRoutedWeight) + { + topk_weights = *c_style_pointer_cast*>( + p_ds_grid[I2] + m_pos); + } + static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size + constexpr index_t c_offset = + blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset( + make_tuple(m0, n0, m2 * M4 + m4)); + constexpr auto cidx = Number{}; + + if constexpr(IsInputGemm) // gu fusion + { + if constexpr(ActivationOperation == Activation::silu_and_mul) + { + float gate = c_thread_buf[cidx]; + float up = c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.AsType()[m4]; + up = up * topk_weights.AsType()[m4]; + } + tensor_operation::element_wise::Silu{}(gate, gate); + c_thread_buf_fp32(cidx) = gate * up; + } + else if(ActivationOperation == Activation::gelu_and_mul) + { + float gate = c_thread_buf[cidx]; + float up = c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.AsType()[m4]; + up = up * topk_weights.AsType()[m4]; + } + tensor_operation::element_wise::Gelu{}(gate, gate); + c_thread_buf_fp32(cidx) = gate * up; + } + } + else + { + c_thread_buf_fp32(cidx) = c_thread_buf[cidx]; + if constexpr(MulRoutedWeight) + { + c_thread_buf_fp32(cidx) = topk_weights.AsType()[m4] * + c_thread_buf_fp32[cidx]; + } + } + }); + }); + }); + }); + } + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); @@ -2171,18 +2405,8 @@ struct GridwiseMoeGemm const auto ds_grid_buf = generate_tuple( [&](auto i) { - using DDataType = remove_cvref_t>; - const DDataType* ptr_ = p_ds_grid[i]; - // hack logic here to support different kind of strides. todo fix it. - // ascale t, 1; bscale E, N, 1, move ptr to E - // if(i.value == 1) - // { - // ptr_ += - // expert_id * (problem.StrideDs[1] ? problem.StrideDs[1] * problem.N : - // 1); - // } return make_dynamic_buffer( - ptr_, ds_grid_desc_m_n[i].GetElementSpaceSize()); + p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize()); }, Number{}); @@ -2258,7 +2482,6 @@ struct GridwiseMoeGemm auto c_grid_buf = make_dynamic_buffer( p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - // space filling curve for threadwise C in VGPR constexpr auto sfc_c_vgpr = SpaceFillingCurve, Sequence<0, 1, 2, 3, 4, 5, 6, 7>, @@ -2297,7 +2520,7 @@ struct GridwiseMoeGemm block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1); static_for<0, EMRepeats, 1>{}([&](auto m0) { const index_t fused_token = p_sorted_token_ids[c_token_pos + m0]; - index_t token_offset = fused_token & 0xffffff; + IndexType token_offset = fused_token & 0xffffff; if constexpr(IsInputGemm) { token_offset = token_offset * problem.TopK + (fused_token >> 24); @@ -2310,7 +2533,7 @@ struct GridwiseMoeGemm // each thread write its data from VGPR to LDS c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, sfc_c_vgpr.GetIndexTupleOfNumber(access_id), - c_thread_buf, + c_thread_buf_fp32, c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, c_shuffle_block_buf);