From 62877fa07409d16239839f75c0d004f5584cda07 Mon Sep 17 00:00:00 2001 From: OscarXu Date: Thu, 8 May 2025 14:33:31 +0800 Subject: [PATCH] 16x16 function merged to moe --- ..._xdlops_moe_blockscale_b_preshuffle_v1.hpp | 111 +- ..._xdlops_moe_blockscale_b_preshuffle_v3.hpp | 961 +++++++----------- .../impl/device_moe_gemm_blockscale.hpp | 26 +- .../gpu/grid/gridwise_moe_gemm_blockscale.hpp | 51 +- 4 files changed, 458 insertions(+), 691 deletions(-) diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v1.hpp index 031f9b6683..6b2e3dc0e4 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v1.hpp @@ -197,96 +197,28 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v1< constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num; constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num * MWaves; - constexpr auto num_pk_fma_per_kscaleblock = MPerXDL == 16 ? 2 : 8; - constexpr auto num_mfma_per_kscaleblock = - MPerXDL == 16 ? KScaleBlock / 32 : KScaleBlock / 16; -#if 0 // B global static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) { + ignore = i; __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - /* Judging issue v_pk_fma */ - if constexpr((i + 1) % num_mfma_per_kscaleblock == 0) - { - __builtin_amdgcn_sched_group_barrier( - 0x800, num_pk_fma_per_kscaleblock, 0); // PK_FMA - } __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read }); // A global static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) { + ignore = i; __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - if constexpr((num_buffer_load_inst_b + 2 * i + 1) % num_mfma_per_kscaleblock == 0) - { - __builtin_amdgcn_sched_group_barrier( - 0x800, num_pk_fma_per_kscaleblock, 0); // PK_FMA - } __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - if constexpr((num_buffer_load_inst_b + 2 * i + 2) % num_mfma_per_kscaleblock == 0) - { - __builtin_amdgcn_sched_group_barrier( - 0x800, num_pk_fma_per_kscaleblock, 0); // PK_FMA - } __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read }); // A local static_for<0, num_ds_read_inst_a / 2, 1>{}([&](auto i) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - if constexpr((num_buffer_load_inst_b + 2 * num_buffer_load_inst_a + i + 1) % - num_mfma_per_kscaleblock == - 0) - { - __builtin_amdgcn_sched_group_barrier( - 0x800, num_pk_fma_per_kscaleblock, 0); // PK_FMA - } - __builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS read - }); -#elif 1 // v_mul occured too early causing vmcnt stall - // B global - static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - /* Judging issue v_pk_fma */ - if constexpr((i + 1) % num_mfma_per_kscaleblock == 0) - { - __builtin_amdgcn_sched_group_barrier( - 0x800, num_pk_fma_per_kscaleblock, 0); // PK_FMA - } - }); - - // A global - static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write - if constexpr((num_buffer_load_inst_b + 2 * i + 1) % num_mfma_per_kscaleblock == 0) - { - __builtin_amdgcn_sched_group_barrier( - 0x800, num_pk_fma_per_kscaleblock, 0); // PK_FMA - } - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - if constexpr((num_buffer_load_inst_b + 2 * i + 2) % num_mfma_per_kscaleblock == 0) - { - __builtin_amdgcn_sched_group_barrier( - 0x800, num_pk_fma_per_kscaleblock, 0); // PK_FMA - } - }); - - // A local - static_for<0, num_ds_read_inst_a / 2, 1>{}([&](auto i) { + ignore = i; __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA __builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS read - if constexpr((num_buffer_load_inst_b + 2 * num_buffer_load_inst_a + i + 1) % - num_mfma_per_kscaleblock == - 0) - { - __builtin_amdgcn_sched_group_barrier( - 0x800, num_pk_fma_per_kscaleblock, 0); // PK_FMA - } }); -#endif } template \n", - // blockIdx.y, - // blockIdx.x, - // threadIdx.x, - // c_scale_thread_buf[Number<0>{}]); - // Local prefill A1 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0); @@ -583,6 +509,22 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v1< }); }); + block_sync_lds(); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, k0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, I0), + a_thread_buf); + }); + }); + + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, num_scale_n_block, 1>{}([&](auto n0) { static_for<0, num_scale_k_block, 1>{}([&](auto k0) { @@ -600,19 +542,6 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v1< }); }); - block_sync_lds(); - - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, k0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, I0), - a_thread_buf); - }); - }); - a_scale_thread_copy.Run(a_scale_grid_desc, a_scale_grid_buf, a_scale_thread_desc, @@ -638,8 +567,6 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v1< b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step); - HotLoopScheduler(); - __builtin_amdgcn_sched_barrier(0); }; LoopFunc(I0, I1); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v3.hpp index 9902128bdc..d3bd088a40 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v3.hpp @@ -134,6 +134,7 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v3< using Base::I0; using Base::I1; using Base::I2; + using Base::KGroup; using Base::KRepeat; using Base::xdlops_gemm; using typename Base::HotLoopInstList; @@ -163,9 +164,9 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v3< constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{}); constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{}); constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{}); - constexpr index_t K2 = KPack; + constexpr index_t K2 = KPack / KGroup; constexpr index_t K1 = 64 / NPerXDL; - constexpr index_t K0 = KRepeat; + constexpr index_t K0 = KRepeat * KGroup; return transform_tensor_descriptor( TileDesc_M0_M1_M2_K{}, @@ -191,466 +192,156 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v3< return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd; } - template - __device__ static constexpr auto HotLoopScheduler(Stage stage) + __device__ static constexpr auto HotLoopScheduler() { - constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num; - constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num; + // A/B split schedule + // compiler is likely to use ds_read2 when instruction width smaller than 16bytes + constexpr auto num_ds_read_inst_a = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 + ? HotLoopInstList::A_LDS_Read_Inst_Num + : HotLoopInstList::A_LDS_Read_Inst_Num / 2; + + constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num; + constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num; - constexpr auto num_buffer_load_inst_b = MWaves * HotLoopInstList::B_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num; - constexpr auto num_mfma = HotLoopInstList::C_MFMA_Inst_Num; + static_assert(num_buffer_load_inst_a == num_ds_write_inst_a); - constexpr auto staged_num_ds_read_inst_a = num_ds_read_inst_a / MRepeat; - constexpr auto staged_num_mfma = num_mfma / MRepeat; + constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num; + constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle; - constexpr auto staged_num_mfma_per_ds_read_a = staged_num_mfma / staged_num_ds_read_inst_a; + constexpr auto ds_read_a_issue_cycle = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4; + constexpr auto ds_read_a_mfma_rate = + math::integer_divide_ceil(mfma_cycle - 4, 2 * ds_read_a_issue_cycle); - constexpr auto num_pk_fma_per_kscaleblock = MPerXDL == 16 ? 2 : 8; - constexpr auto num_mfma_per_kscaleblock = MPerXDL == 16 ? KPerBlock / 32 : KPerBlock / 16; + // constexpr auto num_dsread_a_mfma = + // (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate; - if constexpr(stage.value == 0) - { - // B VMEM access. - constexpr auto staged_num_buffer_load_b_per_ds_read_a = - num_buffer_load_inst_b / staged_num_ds_read_inst_a; - constexpr auto staged_num_mfma_per_buffer_load_b = - staged_num_mfma / num_buffer_load_inst_b; - // B global - static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) { - static_for<0, staged_num_buffer_load_b_per_ds_read_a - 1, 1>{}([&](auto ibuf_inst) { - static_for<0, staged_num_mfma_per_buffer_load_b, 1>{}([&](auto imfma) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + constexpr auto num_total_stages = MRepeat; - /* Judging issue v_pk_fma */ - if constexpr((i_inst * staged_num_mfma_per_buffer_load_b * - staged_num_buffer_load_b_per_ds_read_a + - ibuf_inst * staged_num_mfma_per_buffer_load_b + imfma + 1) % - num_mfma_per_kscaleblock == - 0) - { - __builtin_amdgcn_sched_group_barrier( - 0x800, num_pk_fma_per_kscaleblock, 0); // PK_FMA - } - }); - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - }); + // Group num_mfma_perstage num_ds_read_a_perstage + // since we want to reuse a local register buffer + constexpr auto num_mfma_perstage = num_mfma_inst / num_total_stages; + constexpr auto num_ds_read_a_perstage = num_ds_read_inst_a / num_total_stages; - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + constexpr auto num_ds_read_a_mfma_perstage = + math::integer_divide_ceil(num_ds_read_a_perstage, ds_read_a_mfma_rate); - /* Judging issue v_pk_fma */ - if constexpr((i_inst * staged_num_mfma_per_buffer_load_b * - staged_num_buffer_load_b_per_ds_read_a + - (staged_num_buffer_load_b_per_ds_read_a - 1) * - staged_num_mfma_per_buffer_load_b + - 1) % - num_mfma_per_kscaleblock == - 0) - { - __builtin_amdgcn_sched_group_barrier( - 0x800, num_pk_fma_per_kscaleblock, 0); // PK_FMA - } - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + constexpr auto num_ds_read_a_prefetch_stages = 2; - static_for<0, staged_num_mfma_per_buffer_load_b - 1, 1>{}([&](auto imfma) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + constexpr auto buffer_load_perstage_more = math::integer_divide_ceil( + (num_buffer_load_inst_a + num_buffer_load_inst_b), (num_total_stages - 2)); + constexpr auto buffer_load_perstage_less = math::integer_divide_floor( + (num_buffer_load_inst_a + num_buffer_load_inst_b), (num_total_stages - 2)); - /* Judging issue v_pk_fma */ - if constexpr((i_inst * staged_num_mfma_per_buffer_load_b * - staged_num_buffer_load_b_per_ds_read_a + - (staged_num_buffer_load_b_per_ds_read_a - 1) * - staged_num_mfma_per_buffer_load_b + - imfma + 2) % - num_mfma_per_kscaleblock == - 0) - { - __builtin_amdgcn_sched_group_barrier( - 0x800, num_pk_fma_per_kscaleblock, 0); // PK_FMA - } - }); + constexpr auto buffer_load_stages_more = + (num_buffer_load_inst_a + num_buffer_load_inst_b) - + math::integer_divide_floor((num_buffer_load_inst_a + num_buffer_load_inst_b), + (num_total_stages - 2)) * + ((num_total_stages - 2)); + + constexpr auto buffer_load_b_stages = + buffer_load_perstage_more * buffer_load_stages_more > num_buffer_load_inst_b + ? num_buffer_load_inst_b / buffer_load_perstage_more + : (buffer_load_stages_more + + (num_buffer_load_inst_b - buffer_load_perstage_more * buffer_load_stages_more) / + buffer_load_perstage_less); + + constexpr auto buffer_load_a_stages = + num_total_stages - num_ds_read_a_prefetch_stages - buffer_load_b_stages; + + constexpr auto buffer_load_issue_point_b = 0; + constexpr auto buffer_load_issue_point_interval_more = + num_mfma_perstage / buffer_load_perstage_more; + constexpr auto buffer_load_issue_point_interval_less = + num_mfma_perstage / buffer_load_perstage_less; + constexpr auto ds_write_issue_point = 0; + constexpr auto buffer_load_issue_point_a = num_mfma_perstage >= 3 ? 1 : 0; + + // B global read + static_for<0, buffer_load_b_stages, 1>{}([&](auto i) { + // Scale load, 1B + if constexpr (i.value==0){ __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - }); - - __builtin_amdgcn_sched_barrier(0); - } - else if constexpr(stage.value == 1) - { - // A LDS write access. - constexpr auto staged_num_mfma_per_ds_write_a = - math::integer_divide_ceil(staged_num_mfma, num_ds_write_inst_a); - - constexpr auto stage_more_mfma = - staged_num_mfma - (staged_num_mfma_per_ds_write_a - 1) * num_ds_write_inst_a; - - // A local write - static_for<0, num_ds_write_inst_a, 1>{}([&](auto i_inst) { - if constexpr(i_inst.value < stage_more_mfma) - { - if(i_inst.value < staged_num_ds_read_inst_a) - { - static_for<0, staged_num_mfma_per_ds_write_a - 1, 1>{}([&](auto i_mfma) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - - /* Judging issue v_pk_fma */ - if constexpr((i_inst * staged_num_mfma_per_ds_write_a + i_mfma + 1) % - num_mfma_per_kscaleblock == - 0) - { - __builtin_amdgcn_sched_group_barrier( - 0x800, num_pk_fma_per_kscaleblock, 0); // PK_FMA - } - }); - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - - /* Judging issue v_pk_fma */ - if constexpr(((i_inst + 1) * staged_num_mfma_per_ds_write_a) % - num_mfma_per_kscaleblock == - 0) - { - __builtin_amdgcn_sched_group_barrier( - 0x800, num_pk_fma_per_kscaleblock, 0); // PK_FMA - } - - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - } - else - { - static_for<0, staged_num_mfma_per_ds_write_a, 1>{}([&](auto i_mfma) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - - /* Judging issue v_pk_fma */ - if constexpr((i_inst * staged_num_mfma_per_ds_write_a + i_mfma + 1) % - num_mfma_per_kscaleblock == - 0) - { - __builtin_amdgcn_sched_group_barrier( - 0x800, num_pk_fma_per_kscaleblock, 0); // PK_FMA - } - }); - - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write - } - } - else - { - if(i_inst.value < staged_num_ds_read_inst_a) - { - static_for<0, staged_num_mfma_per_ds_write_a - 2, 1>{}([&](auto i_mfma) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - - /* Judging issue v_pk_fma */ - if constexpr((stage_more_mfma * staged_num_mfma_per_ds_write_a + - (i_inst - stage_more_mfma) * - (staged_num_mfma_per_ds_write_a - 1) + - i_mfma + 1) % - num_mfma_per_kscaleblock == - 0) - { - __builtin_amdgcn_sched_group_barrier( - 0x800, num_pk_fma_per_kscaleblock, 0); // PK_FMA - } - }); - - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - - /* Judging issue v_pk_fma */ - if constexpr((stage_more_mfma * staged_num_mfma_per_ds_write_a + - (i_inst - stage_more_mfma + 1) * - (staged_num_mfma_per_ds_write_a - 1)) % - num_mfma_per_kscaleblock == - 0) - { - __builtin_amdgcn_sched_group_barrier( - 0x800, num_pk_fma_per_kscaleblock, 0); // PK_FMA - } - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - } - else - { - static_for<0, staged_num_mfma_per_ds_write_a - 1, 1>{}([&](auto i_mfma) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - - /* Judging issue v_pk_fma */ - if constexpr((stage_more_mfma * staged_num_mfma_per_ds_write_a + - (i_inst - stage_more_mfma) * - (staged_num_mfma_per_ds_write_a - 1) + - i_mfma + 1) % - num_mfma_per_kscaleblock == - 0) - { - __builtin_amdgcn_sched_group_barrier( - 0x800, num_pk_fma_per_kscaleblock, 0); // PK_FMA - } - }); - - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write - } - } - }); - - __builtin_amdgcn_sched_barrier(0); - } - else if constexpr(stage.value == 2) - { - // A VMEM access. - constexpr auto staged_num_mfma_per_buffer_load_a = - math::integer_divide_ceil(staged_num_mfma, num_buffer_load_inst_a); - - constexpr auto stage_more_mfma = - staged_num_mfma - (staged_num_mfma_per_buffer_load_a - 1) * num_buffer_load_inst_a; - - // A global - static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i_inst) { - if constexpr(i_inst.value < stage_more_mfma) - { - if(i_inst.value < staged_num_ds_read_inst_a) - { - static_for<0, staged_num_mfma_per_buffer_load_a - 1, 1>{}([&](auto i_mfma) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - - /* Judging issue v_pk_fma */ - if constexpr((i_inst * staged_num_mfma_per_buffer_load_a + i_mfma + 1) % - num_mfma_per_kscaleblock == - 0) - { - __builtin_amdgcn_sched_group_barrier( - 0x800, num_pk_fma_per_kscaleblock, 0); // PK_FMA - } - }); - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - - /* Judging issue v_pk_fma */ - if constexpr(((i_inst + 1) * staged_num_mfma_per_buffer_load_a) % - num_mfma_per_kscaleblock == - 0) - { - __builtin_amdgcn_sched_group_barrier( - 0x800, num_pk_fma_per_kscaleblock, 0); // PK_FMA - } - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - } - else - { - static_for<0, staged_num_mfma_per_buffer_load_a, 1>{}([&](auto i_mfma) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - - /* Judging issue v_pk_fma */ - if constexpr((i_inst * staged_num_mfma_per_buffer_load_a + i_mfma + 1) % - num_mfma_per_kscaleblock == - 0) - { - __builtin_amdgcn_sched_group_barrier( - 0x800, num_pk_fma_per_kscaleblock, 0); // PK_FMA - } - }); - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - } - } - else - { - if(i_inst.value < staged_num_ds_read_inst_a) - { - static_for<0, staged_num_mfma_per_buffer_load_a - 2, 1>{}([&](auto i_mfma) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - - /* Judging issue v_pk_fma */ - if constexpr((stage_more_mfma * staged_num_mfma_per_buffer_load_a + - (i_inst - stage_more_mfma) * - (staged_num_mfma_per_buffer_load_a - 1) + - i_mfma + 1) % - num_mfma_per_kscaleblock == - 0) - { - __builtin_amdgcn_sched_group_barrier( - 0x800, num_pk_fma_per_kscaleblock, 0); // PK_FMA - } - }); - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - - /* Judging issue v_pk_fma */ - if constexpr((stage_more_mfma * staged_num_mfma_per_buffer_load_a + - (i_inst - stage_more_mfma + 1) * - (staged_num_mfma_per_buffer_load_a - 1)) % - num_mfma_per_kscaleblock == - 0) - { - __builtin_amdgcn_sched_group_barrier( - 0x800, num_pk_fma_per_kscaleblock, 0); // PK_FMA - } - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - } - else - { - static_for<0, staged_num_mfma_per_buffer_load_a - 1, 1>{}([&](auto i_mfma) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - - /* Judging issue v_pk_fma */ - if constexpr((stage_more_mfma * staged_num_mfma_per_buffer_load_a + - (i_inst - stage_more_mfma) * - (staged_num_mfma_per_buffer_load_a - 1) + - i_mfma + 1) % - num_mfma_per_kscaleblock == - 0) - { - __builtin_amdgcn_sched_group_barrier( - 0x800, num_pk_fma_per_kscaleblock, 0); // PK_FMA - } - }); - - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - } - } - }); - - __builtin_amdgcn_sched_barrier(0); - } - else - { - // A local Read - static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) { - static_for<0, staged_num_mfma_per_ds_read_a, 1>{}([&](auto i_mfma) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - - /* Judging issue v_pk_fma */ - if constexpr((i_inst * staged_num_mfma_per_ds_read_a + i_mfma + 1) % - num_mfma_per_kscaleblock == - 0) - { - __builtin_amdgcn_sched_group_barrier( - 0x800, num_pk_fma_per_kscaleblock, 0); // PK_FMA - } - }); - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - }); - - __builtin_amdgcn_sched_barrier(0); - } - } - - template - __device__ static constexpr auto EpilogueScheduler_1(Stage stage) - { - constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num; - constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num; - constexpr auto num_buffer_load_inst_b = MWaves * HotLoopInstList::B_Buffer_Load_Inst_Num; - - constexpr auto num_mfma = HotLoopInstList::C_MFMA_Inst_Num; - - constexpr auto staged_num_ds_read_inst_a = num_ds_read_inst_a / MRepeat; - constexpr auto staged_num_mfma = num_mfma / MRepeat; - - constexpr auto staged_num_mfma_per_ds_read_a = staged_num_mfma / staged_num_ds_read_inst_a; - - if constexpr(stage.value == 0) - { - constexpr auto staged_num_buffer_load_b_per_ds_read_a = - num_buffer_load_inst_b / staged_num_ds_read_inst_a; - constexpr auto staged_num_mfma_per_buffer_load_b = - staged_num_mfma / num_buffer_load_inst_b; - // B global - static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) { - ignore = i_inst; - - static_for<0, staged_num_buffer_load_b_per_ds_read_a, 1>{}([&](auto ibuf_inst) { - ignore = ibuf_inst; - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_buffer_load_b, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - }); + } + // Scale load, 1A + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_buffer_load_b - 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - }); - __builtin_amdgcn_sched_barrier(0); - } - else if constexpr(stage.value == 1) - { - constexpr auto staged_num_mfma_per_ds_write_a = - math::integer_divide_ceil(staged_num_mfma, num_ds_write_inst_a); - - constexpr auto stage_more_mfma = - staged_num_mfma - (staged_num_mfma_per_ds_write_a - 1) * num_ds_write_inst_a; - - // A local write - static_for<0, num_ds_write_inst_a, 1>{}([&](auto i_inst) { - if constexpr(i_inst.value < stage_more_mfma) + if constexpr(((i < buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_more == + buffer_load_issue_point_b)) || + ((i >= buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_less == + buffer_load_issue_point_b))) { - if(i_inst.value < staged_num_ds_read_inst_a) - { - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - } - else - { - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_ds_write_a, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write - } + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read } - else + + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) { - if(i_inst.value < staged_num_ds_read_inst_a) - { - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_ds_write_a - 2, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - } - else - { - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write - } + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read } + __builtin_amdgcn_sched_group_barrier(0x800, 2, 0); // v_pk_fma }); - __builtin_amdgcn_sched_barrier(0); - } - else - { - // A local Read - static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) { - ignore = i_inst; - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_ds_read_a, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - }); - - __builtin_amdgcn_sched_barrier(0); - } - } - - __device__ static constexpr auto EpilogueScheduler_2() - { - constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num; - - constexpr auto num_mfma = HotLoopInstList::C_MFMA_Inst_Num; - - constexpr auto staged_num_ds_read_inst_a = num_ds_read_inst_a / MRepeat; - constexpr auto staged_num_mfma = num_mfma / MRepeat; - - constexpr auto staged_num_mfma_per_ds_read_a = staged_num_mfma / staged_num_ds_read_inst_a; - - // A local Read - static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) { - ignore = i_inst; - __builtin_amdgcn_sched_group_barrier(0x008, staged_num_mfma_per_ds_read_a, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + // __builtin_amdgcn_sched_barrier(0); }); - __builtin_amdgcn_sched_barrier(0); + // A global read + A local write + static_for<0, buffer_load_a_stages, 1>{}([&](auto i) { + // Scale load, 1A + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_more == + ds_write_issue_point)) || + (((i + buffer_load_b_stages) >= buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_less == + ds_write_issue_point))) + { + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + } + if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_more == + buffer_load_issue_point_a)) || + (((i + buffer_load_b_stages) >= buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_less == + buffer_load_issue_point_a))) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x800, 2, 0); // v_pk_fma + }); + // __builtin_amdgcn_sched_barrier(0); + }); + + // lds synchronization, prefetch next loop local A + static_for<0, num_ds_read_a_prefetch_stages, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + // Scale load, 1A + if constexpr(imfma == 0){ + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x800, 2, 0); // v_pk_fma + }); + // __builtin_amdgcn_sched_barrier(0); + }); } template ( b_thread_desc_.GetElementSpaceSize()); + StaticallyIndexedArray{}> b_thread_bufs; constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0); + auto a_scale_thread_buf = make_static_buffer( a_scale_thread_desc.GetElementSpaceSize()); auto b_scale_thread_buf = make_static_buffer( @@ -734,6 +427,10 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v3< auto c_scale_thread_buf = make_static_buffer( c_scale_thread_desc.GetElementSpaceSize()); + StaticallyIndexedArray{}> a_scale_thread_bufs; + StaticallyIndexedArray{}> b_scale_thread_bufs; + // StaticallyIndexedArray{}> c_scale_thread_bufs; + // Global prefetch A1 B1, AScale1 BScale1 b_blockwise_copy.Run(b_grid_desc, b_grid_buf, @@ -750,7 +447,7 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v3< a_scale_grid_buf, a_scale_thread_desc, make_tuple(I0, I0), - a_scale_thread_buf); + a_scale_thread_bufs(I0)); if constexpr(NumKBlockPerScale == 1) { @@ -767,12 +464,12 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v3< b_scale_grid_buf, b_scale_thread_desc, make_tuple(I0, I0), - b_scale_thread_buf); + b_scale_thread_bufs(I0)); b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step); static_for<0, MRepeat, 1>{}([&](auto m0) { - c_scale_thread_buf(m0) = a_scale_thread_buf[m0] * b_scale_thread_buf[I0]; + c_scale_thread_buf(m0) = a_scale_thread_bufs[I0][m0] * b_scale_thread_bufs[I0][I0]; }); // Local prefill A1 @@ -786,7 +483,7 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v3< a_scale_grid_buf, a_scale_thread_desc, make_tuple(I0, I0), - a_scale_thread_buf); + a_scale_thread_bufs(I0)); if constexpr(NumKBlockPerScale == 1) { @@ -803,13 +500,16 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v3< b_scale_grid_buf, b_scale_thread_desc, make_tuple(I0, I0), - b_scale_thread_buf); + b_scale_thread_bufs(I0)); b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step); // Initialize C c_thread_buf.Clear(); + // Double register buffer for non-scaled gemm computation + // 1. Reduce register pressure + // 2. Decouple the dependency between mfma instruction and scale-fma instruction following. StaticBufferTupleOfVector{}([&](auto k0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(I0, I0, I0, k0, I0, I0), - a_block_buf.At(I0), - a_thread_desc_, - make_tuple(I0, I0, I0, k0, I0, I0), - a_thread_buf); + static_for<0, 2, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); + }); }); +#if 0 + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + + // Fill first mfma buffer + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = a_thread_buf + [Number{}]; + b_thread_vec.template AsType()(ik) = b_thread_bufs + [I0][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + }); +#endif __builtin_amdgcn_sched_barrier(0); // main body @@ -837,26 +567,43 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v3< do { auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - if constexpr(m0.value == 0) - { - b_blockwise_copy.Run(b_grid_desc, - b_grid_buf, - b_block_desc_n0_n1_k0_k1, - b_block_origin_idx, - b_thread_bufs(local_read_buf)); - b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); - } - else if constexpr(m0.value == 1) - { - a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(local_read_buf)); - } - else if constexpr(m0.value == 2) - { - a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); - } + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(local_read_buf)); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(local_read_buf)); + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(I0, I0), + a_scale_thread_bufs(local_read_buf)); + + if constexpr(NumKBlockPerScale == 1) + { + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<1>{})); + } + else + { + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<0>{})); + } + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(I0, I0), + b_scale_thread_bufs(local_read_buf)); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + b_scale_thread_copy_step); + + static_for<0, MRepeat, 1>{}([&](auto m0) { vector_type c_scale_thread_vec; c_scale_thread_vec.template AsType()(Number<0>{}) = c_scale_thread_buf[m0]; @@ -914,82 +661,93 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v3< }); }); - if constexpr(m0.value == MRepeat - 1) + if constexpr(m0.value == (MRepeat - 2)) { block_sync_lds(); static_for<0, KRepeat, 1>{}([&](auto k0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number<(m0 + 1) % MRepeat>{}, I0, I0, k0, I0, I0), - a_block_buf.At(local_read_buf), - a_thread_desc_, - make_tuple( - Number<(m0 + 1 + HotloopLocalBufSwitch * mfma_reg_buf) % - 2>{}, - I0, - I0, - k0, - I0, - I0), - a_thread_buf); + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(local_read_buf), + a_thread_desc_, + make_tuple( + Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % + 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); + }); + }); + } + else if constexpr(m0.value == (MRepeat - 1)) + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(local_read_buf), + a_thread_desc_, + make_tuple( + Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % + 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); + }); }); } else { static_for<0, KRepeat, 1>{}([&](auto k0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number<(m0 + 1) % MRepeat>{}, I0, I0, k0, I0, I0), - a_block_buf.At(mfma_reg_buf), - a_thread_desc_, - make_tuple( - Number<(m0 + 1 + HotloopLocalBufSwitch * mfma_reg_buf) % - 2>{}, - I0, - I0, - k0, - I0, - I0), - a_thread_buf); + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(mfma_reg_buf), + a_thread_desc_, + make_tuple( + Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % + 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); + }); }); } - - HotLoopScheduler(m0); }); static_for<0, MRepeat, 1>{}([&](auto m0) { - c_scale_thread_buf(m0) = a_scale_thread_buf[m0] * b_scale_thread_buf[I0]; + c_scale_thread_buf(m0) = a_scale_thread_bufs[mfma_reg_buf][m0] * b_scale_thread_bufs[mfma_reg_buf][I0]; }); - - a_scale_thread_copy.Run(a_scale_grid_desc, - a_scale_grid_buf, - a_scale_thread_desc, - make_tuple(I0, I0), - a_scale_thread_buf); - - if constexpr(NumKBlockPerScale == 1) - { - a_scale_thread_copy.MoveSrcSliceWindow( - a_scale_grid_desc, a_scale_thread_copy_step.At(Number<1>{})); - } - else - { - a_scale_thread_copy.MoveSrcSliceWindow( - a_scale_grid_desc, a_scale_thread_copy_step.At(Number<0>{})); - } - - b_scale_thread_copy.Run(b_scale_grid_desc, - b_scale_grid_buf, - b_scale_thread_desc, - make_tuple(I0, I0), - b_scale_thread_buf); - - b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, - b_scale_thread_copy_step); - - __builtin_amdgcn_sched_group_barrier(0x020, MRepeat + 1, 0); // VMEM read + HotLoopScheduler(); __builtin_amdgcn_sched_barrier(0); }; @@ -1003,20 +761,14 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v3< // tail if constexpr(TailNum == TailNumber::Even) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - if constexpr(m0.value == 0) - { - b_blockwise_copy.Run(b_grid_desc, - b_grid_buf, - b_block_desc_n0_n1_k0_k1, - b_block_origin_idx, - b_thread_bufs(I1)); - } - else if constexpr(m0.value == MRepeat - 1) - { - a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1)); - } + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I1)); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1)); + static_for<0, MRepeat, 1>{}([&](auto m0) { vector_type c_scale_thread_vec; c_scale_thread_vec.template AsType()(Number<0>{}) = c_scale_thread_buf[m0]; @@ -1066,38 +818,74 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v3< }); }); - if constexpr(m0.value == MRepeat - 1) + if constexpr(m0.value == (MRepeat - 2)) { block_sync_lds(); static_for<0, KRepeat, 1>{}([&](auto k0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number<(m0 + 1) % MRepeat>{}, I0, I0, k0, I0, I0), - a_block_buf.At(I1), - a_thread_desc_, - make_tuple(Number<(m0 + 1) % 2>{}, I0, I0, k0, I0, I0), - a_thread_buf); + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(I1), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); + }); + } + else if constexpr(m0.value == (MRepeat - 1)) + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(I1), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); }); } else { static_for<0, KRepeat, 1>{}([&](auto k0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number<(m0 + 1) % MRepeat>{}, I0, I0, k0, I0, I0), - a_block_buf.At(I0), - a_thread_desc_, - make_tuple(Number<(m0 + 1) % 2>{}, I0, I0, k0, I0, I0), - a_thread_buf); + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); }); } - - HotLoopScheduler(m0); }); + HotLoopScheduler(); + static_for<0, MRepeat, 1>{}([&](auto m0) { - c_scale_thread_buf(m0) = a_scale_thread_buf[m0] * b_scale_thread_buf[I0]; + c_scale_thread_buf(m0) = a_scale_thread_bufs[I0][m0] * b_scale_thread_bufs[I0][I0]; }); static_for<0, MRepeat, 1>{}([&](auto m0) { @@ -1149,20 +937,25 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v3< }); }); - if constexpr(m0.value != (MRepeat - 1)) + if constexpr(m0.value < (MRepeat - 2)) { static_for<0, KRepeat, 1>{}([&](auto k0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number{}, I0, I0, k0, I0, I0), - a_block_buf.At(I1), - a_thread_desc_, - make_tuple( - Number<(m0 + 1 + HotloopLocalBufSwitch) % 2>{}, I0, I0, k0, I0, I0), - a_thread_buf); + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple( + Number{}, I0, I0, Number{}, I0, I0), + a_block_buf.At(I1), + a_thread_desc_, + make_tuple(Number<(m0 + 2 + HotloopLocalBufSwitch) % 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); + }); }); - - EpilogueScheduler_2(); } }); // Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle @@ -1220,18 +1013,21 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v3< }); }); - if constexpr(m0.value != (MRepeat - 1)) + if constexpr(m0.value < (MRepeat - 2)) { static_for<0, KRepeat, 1>{}([&](auto k0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number{}, I0, I0, k0, I0, I0), - a_block_buf.At(I0), - a_thread_desc_, - make_tuple(Number<(m0 + 1) % 2>{}, I0, I0, k0, I0, I0), - a_thread_buf); + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple( + Number{}, I0, I0, Number{}, I0, I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); }); - - EpilogueScheduler_2(); } }); } @@ -1248,7 +1044,7 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v3< ComputeDataType, decltype(a_block_desc_m0_m1_m2_k0_k1_k2), decltype(a_thread_desc_), - Sequence<1, 1, 1, 1, 1, KPack>, + Sequence<1, 1, 1, 1, 1, KPack / KGroup>, Sequence<0, 1, 2, 3, 4, 5>, 5, A_K1, @@ -1260,6 +1056,7 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v3< make_tuple(Number{}, I1, Number{}, Number{})); static constexpr BTileDesc b_block_desc_n0_n1_k0_k1; + using Base::c_thread_desc_; }; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp b/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp index 318ac96ea1..af5f9c49ad 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp @@ -401,16 +401,16 @@ struct DeviceMoeGemmBlockScale }; #endif - constexpr auto estimated_reg_a = MPerBlock * KPerBlock * sizeof(ADataType) / BlockSize / - 4 * (1 + GridwiseGemm::NWave); - constexpr auto estimated_reg_b = - NPerBlock * KPerBlock * sizeof(BDataType) / BlockSize / 4 * (2); - constexpr auto estimated_reg_c = - MPerBlock * NPerBlock * sizeof(GemmAccDataType) / BlockSize / 4; - constexpr auto estimated_reg_total = - estimated_reg_a + estimated_reg_b + estimated_reg_c; + // constexpr auto estimated_reg_a = MPerBlock * KPerBlock * sizeof(ADataType) / BlockSize / + // 4 * (1 + GridwiseGemm::NWave); + // constexpr auto estimated_reg_b = + // NPerBlock * KPerBlock * sizeof(BDataType) / BlockSize / 4 * (2); + // constexpr auto estimated_reg_c = + // MPerBlock * NPerBlock * sizeof(GemmAccDataType) / BlockSize / 4; + // constexpr auto estimated_reg_total = + // estimated_reg_a + estimated_reg_b + estimated_reg_c; - constexpr index_t minimum_occupancy = (estimated_reg_total >= 256) ? 1 : 2; + constexpr index_t minimum_occupancy = 2; constexpr auto MemoryDataOp = IsInputGemm ? InMemoryDataOperationEnum::Set : InMemoryDataOperationEnum::AtomicAdd; @@ -704,7 +704,7 @@ struct DeviceMoeGemmBlockScale index_t StrideC, const void* p_a_scale, const void* p_b_scale, - index_t KBatch, + // index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) override @@ -727,7 +727,7 @@ struct DeviceMoeGemmBlockScale StrideC, static_cast(p_a_scale), static_cast(p_b_scale), - KBatch, + 1, //KBatch, a_element_op, b_element_op, c_element_op); @@ -749,7 +749,9 @@ struct DeviceMoeGemmBlockScale {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; std::map BlkGemmPipelineVersionToString{ - {BlockGemmPipelineVersion::v1, "v1"}, {BlockGemmPipelineVersion::v2, "v2"}}; + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}}; // clang-format off str << "DeviceMoeGEmm" diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp index 56d49f1430..3b2c59ba72 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp @@ -187,9 +187,10 @@ struct GridwiseMoeGemmBlockScale using mfma_selector = MfmaSelector; static constexpr index_t KPack = math::max(math::lcm(AK1Number, BK1Number), mfma_selector::selected_mfma.k_per_blk); + static constexpr index_t KGroup = mfma_selector::selected_mfma.k_per_blk == 32 ? 2 : 1; static constexpr index_t KLane = mfma_selector::GetKPerXdlops() / mfma_selector::GetK1PerXdlops(); - static constexpr index_t KRepeat = KPerBlock / KLane / KPack; + static constexpr index_t KRepeat = KPerBlock / KLane / (KPack / KGroup); static constexpr index_t NLane = NPerXdl; static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave; // static constexpr index_t NumTokens = 1; @@ -249,7 +250,7 @@ struct GridwiseMoeGemmBlockScale } __host__ __device__ static auto CalculateBK0Shuffled(index_t K) { - return math::integer_divide_ceil(K, KLane * KPack); + return math::integer_divide_ceil(K, KLane * KPack / KGroup); } __host__ __device__ static auto CalculateKPadded(index_t K) @@ -391,7 +392,7 @@ struct GridwiseMoeGemmBlockScale __host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0) { - constexpr index_t NkSwizzleNumber = Number{}; + constexpr index_t NkSwizzleNumber = Number{}; return make_naive_tensor_descriptor( make_tuple(N0 / NWave, NWave, K0, NkSwizzleNumber), make_tuple(NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, I1)); @@ -1334,7 +1335,7 @@ struct GridwiseMoeGemmBlockScale make_multi_index(n_block_data_idx_on_grid, get_warp_local_1d_id() % NWave, 0, - KPack * (get_thread_local_1d_id() % warpSize))); + KPack / KGroup * (get_thread_local_1d_id() % warpSize))); // LDS allocation for A and B: be careful of alignment // Cast after lds @@ -1946,7 +1947,7 @@ struct GridwiseMoeGemmBlockScale make_multi_index(n_block_data_idx_on_grid, get_warp_local_1d_id() % NWave, 0, - KPack * (get_thread_local_1d_id() % warpSize))); + KPack / KGroup * (get_thread_local_1d_id() % warpSize))); // LDS allocation for A and B: be careful of alignment // Cast after lds @@ -2076,6 +2077,7 @@ struct GridwiseMoeGemmBlockScale // shuffle C and write out { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, "wrong!"); @@ -2371,6 +2373,45 @@ struct GridwiseMoeGemmBlockScale I0, cde_lds_and_global_step); } + + // // print C + // printf("tid: %d, blkid: %d, " + // "c_thread_buf = <%1.f, %1.f, %1.f>\n " + // // "%1.f, %1.f, %1.f, %1.f, %1.f, %1.f, %1.f," + // // "%1.f, %1.f, %1.f, %1.f, %1.f, %1.f\n" + // , get_thread_local_1d_id(), block_m_id, + // c_thread_buf.GetVectorTypeReference(Number<0>{}) .template + // AsType()[Number<0>{}], + // c_thread_buf.GetVectorTypeReference(Number<0>{}) .template + // AsType()[Number<1>{}], + // c_thread_buf.GetVectorTypeReference(Number<0>{}) .template + // AsType()[Number<2>{}], + // c_thread_buf.GetVectorTypeReference(Number<0>{}) .template + // AsType()[Number<3>{}], + // c_thread_buf.GetVectorTypeReference(Number<0>{}) .template + // AsType()[Number<4>{}], + // c_thread_buf.GetVectorTypeReference(Number<0>{}) .template + // AsType()[Number<5>{}], + // c_thread_buf.GetVectorTypeReference(Number<0>{}) .template + // AsType()[Number<6>{}], + // c_thread_buf.GetVectorTypeReference(Number<0>{}) .template + // AsType()[Number<7>{}], + // c_thread_buf.GetVectorTypeReference(Number<0>{}) .template + // AsType()[Number<8>{}], + // c_thread_buf.GetVectorTypeReference(Number<0>{}) .template + // AsType()[Number<9>{}], + // c_thread_buf.GetVectorTypeReference(Number<0>{}) .template + // AsType()[Number<10>{}], + // c_thread_buf.GetVectorTypeReference(Number<0>{}) .template + // AsType()[Number<11>{}], + // c_thread_buf.GetVectorTypeReference(Number<0>{}) .template + // AsType()[Number<12>{}], + // c_thread_buf.GetVectorTypeReference(Number<0>{}) .template + // AsType()[Number<13>{}], + // c_thread_buf.GetVectorTypeReference(Number<0>{}) .template + // AsType()[Number<14>{}], + // c_thread_buf.GetVectorTypeReference(Number<0>{}) .template + // AsType()[Number<3>{}]); }); } }