diff --git a/CHANGELOG.md b/CHANGELOG.md index 60fe2df99d..4be173dd85 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,8 +19,11 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj ### Optimized + +* Optimize the gemm multiply multiply preshuffle & lds bypass with Pack of KGroup and better instruction layout. (#2166) * Added Vectorize Transpose optimization for CK Tile (#2131) + ### Fixes None diff --git a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp index e4e6a4f1a7..9f758d5fc5 100644 --- a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp +++ b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp @@ -9,7 +9,6 @@ #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" @@ -142,12 +141,12 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 128, 16, 16, - 32, 32, - 2, 2, + 16, 16, + 8, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, - 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, FP8>; + 2, 1, S<1, 32, 1, 8>, S<8, 8, 1>, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>; // clang-format on int main(int argc, char* argv[]) diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp 
b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp index d751543175..1d27a74bd7 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp @@ -122,6 +122,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1{}); constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{}); constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{}); - constexpr index_t K2 = KPack; + constexpr index_t K2 = KPack / KGroup; constexpr index_t K1 = 64 / NPerXDL; - constexpr index_t K0 = KRepeat; + constexpr index_t K0 = KRepeat * KGroup; return transform_tensor_descriptor( TileDesc_M0_M1_M2_K{}, @@ -280,12 +281,14 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1{}([&](auto m0) { static_for<0, KRepeat, 1>{}([&](auto k0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, k0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, I0), - a_thread_buf); + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); }); }); @@ -346,14 +349,18 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1{}([&](auto m0) { static_for<0, KRepeat, 1>{}([&](auto k0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, k0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, I0), - a_thread_buf); + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); }); }); @@ -409,14 +416,18 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1{}([&](auto m0) { 
static_for<0, KRepeat, 1>{}([&](auto k0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, k0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, I0), - a_thread_buf); + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); }); }); @@ -495,7 +506,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1, + Sequence<1, 1, 1, 1, 1, KPack / KGroup>, Sequence<0, 1, 2, 3, 4, 5>, 5, A_K1, diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v2.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v2.hpp index 4c019a41a4..7bbaaca5b6 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v2.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v2.hpp @@ -122,6 +122,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2{}); constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{}); constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{}); - constexpr index_t K2 = KPack; + constexpr index_t K2 = KPack / KGroup; constexpr index_t K1 = 64 / NPerXDL; - constexpr index_t K0 = KRepeat; + constexpr index_t K0 = KRepeat * KGroup; return transform_tensor_descriptor( TileDesc_M0_M1_M2_K{}, @@ -281,12 +282,14 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2{}([&](auto m0) { static_for<0, KRepeat, 1>{}([&](auto k0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, k0, I0, I0), - a_block_buf.At(I0), - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, I0), - a_thread_bufs(I0)); + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + 
a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); }); }); @@ -318,14 +321,18 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2{}([&](auto m0) { static_for<0, KRepeat, 1>{}([&](auto k0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, k0, I0, I0), - a_block_buf.At(local_read_buf), - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, I0), - a_thread_bufs(local_read_buf)); + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); }); }); @@ -389,14 +396,18 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2{}([&](auto m0) { static_for<0, KRepeat, 1>{}([&](auto k0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, k0, I0, I0), - a_block_buf.At(local_read_reg), - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, I0), - a_thread_bufs(local_read_reg)); + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); }); }); @@ -445,12 +456,15 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2{}([&](auto m0) { static_for<0, KRepeat, 1>{}([&](auto k0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, k0, I0, I0), - a_block_buf.At(local_read_reg), - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, I0), - a_thread_bufs(local_read_reg)); + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); }); }); @@ -539,7 +553,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2, + Sequence<1, 1, 1, 1, 1, KPack / 
KGroup>, Sequence<0, 1, 2, 3, 4, 5>, 5, A_K1, diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp index 6d115e7620..6f3a7e6357 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp @@ -5,6 +5,16 @@ #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp" +#define DS_READ_A_PREFETCH_STAGES 2 + +template +constexpr auto compute_stage_loads(T total_loads, T stages) +{ + return std::make_pair((total_loads + stages - 1) / stages, // ceil + total_loads / stages // floor + ); +} + namespace ck { // Compute optimized pipeline @@ -123,6 +133,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3{}); constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{}); constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{}); - constexpr index_t K2 = KPack; + constexpr index_t K2 = KPack / KGroup; constexpr index_t K1 = 64 / NPerXDL; - constexpr index_t K0 = KRepeat; + constexpr index_t K0 = KRepeat * KGroup; return transform_tensor_descriptor( TileDesc_M0_M1_M2_K{}, @@ -184,298 +191,132 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3 - __device__ static constexpr auto HotLoopScheduler(Stage stage) + __device__ static constexpr auto HotLoopScheduler() { - constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num; - constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num; + constexpr auto num_ds_read_inst_a = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 + ? 
HotLoopInstList::A_LDS_Read_Inst_Num + : HotLoopInstList::A_LDS_Read_Inst_Num / 2; + + constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num; + constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num; - constexpr auto num_buffer_load_inst_b = MWaves * HotLoopInstList::B_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num; - constexpr auto num_mfma = HotLoopInstList::C_MFMA_Inst_Num; + static_assert(num_buffer_load_inst_a == num_ds_write_inst_a); - constexpr auto staged_num_ds_read_inst_a = num_ds_read_inst_a / MRepeat; - constexpr auto staged_num_mfma = num_mfma / MRepeat; + constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num; + constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle; - constexpr auto staged_num_mfma_per_ds_read_a = staged_num_mfma / staged_num_ds_read_inst_a; + constexpr auto ds_read_a_issue_cycle = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 
8 : 4; + constexpr auto ds_read_a_mfma_rate = + math::integer_divide_ceil(mfma_cycle - 4, 2 * ds_read_a_issue_cycle); - if constexpr(stage.value == 0) - { - constexpr auto staged_num_buffer_load_b_per_ds_read_a = - num_buffer_load_inst_b / staged_num_ds_read_inst_a; - constexpr auto staged_num_mfma_per_buffer_load_b = - staged_num_mfma / num_buffer_load_inst_b; - // B global - static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) { - ignore = i_inst; + constexpr auto num_total_stages = MRepeat; - static_for<0, staged_num_buffer_load_b_per_ds_read_a - 1, 1>{}([&](auto ibuf_inst) { - ignore = ibuf_inst; + // Group num_mfma_perstage num_ds_read_a_perstage + // since we want to reuse a local register buffer + constexpr auto num_mfma_perstage = num_mfma_inst / MRepeat; + constexpr auto num_ds_read_a_perstage = num_ds_read_inst_a / MRepeat; + + constexpr auto num_ds_read_a_mfma_perstage = + math::integer_divide_ceil(num_ds_read_a_perstage, ds_read_a_mfma_rate); + + constexpr auto total_buffer_loads = num_buffer_load_inst_a + num_buffer_load_inst_b; + constexpr auto stages_available = MRepeat - DS_READ_A_PREFETCH_STAGES; + + constexpr auto stage_loads = compute_stage_loads(total_buffer_loads, stages_available); + + constexpr auto buffer_load_perstage_more = stage_loads.first; + constexpr auto buffer_load_perstage_less = stage_loads.second; + + constexpr auto buffer_load_stages_more = total_buffer_loads % stages_available; + + constexpr auto buffer_b_heavy_loads = buffer_load_perstage_more * buffer_load_stages_more; + constexpr auto buffer_b_remaining = + num_buffer_load_inst_b - buffer_load_perstage_more * buffer_load_stages_more; + + constexpr auto buffer_load_b_stages = + buffer_b_heavy_loads > num_buffer_load_inst_b + ? 
num_buffer_load_inst_b / buffer_load_perstage_more + : (buffer_load_stages_more + buffer_b_remaining / buffer_load_perstage_less); + + constexpr auto buffer_load_a_stages = + num_total_stages - DS_READ_A_PREFETCH_STAGES - buffer_load_b_stages; + + static_assert(buffer_load_a_stages > 0, + "The buffer load a stages should always have a value over 0."); + + constexpr auto buffer_load_issue_point_interval_more = + math::integer_divide_ceil(num_mfma_perstage, buffer_load_perstage_more); + constexpr auto buffer_load_issue_point_interval_less = + buffer_load_perstage_less == 0 + ? INT32_MAX + : math::integer_divide_ceil(num_mfma_perstage, buffer_load_perstage_less); + constexpr auto buffer_load_issue_point_a = num_mfma_perstage >= 3 ? 1 : 0; + + // B global read + static_for<0, buffer_load_b_stages, 1>{}([&](auto i) { + static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(SCHED_GROUP_MFMA, 1, 0); + + if constexpr(((i < buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_more == 0)) || + ((i >= buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_less == 0))) + { + __builtin_amdgcn_sched_group_barrier(SCHED_GROUP_VMEM, 1, 0); + } + + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_buffer_load_b, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - }); - - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_buffer_load_b - 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - }); - - __builtin_amdgcn_sched_barrier(0); - } - else if constexpr(stage.value == 1) - { - constexpr auto staged_num_mfma_per_ds_write_a = - math::integer_divide_ceil(staged_num_mfma, num_ds_write_inst_a); - - constexpr auto 
stage_more_mfma = - staged_num_mfma - (staged_num_mfma_per_ds_write_a - 1) * num_ds_write_inst_a; - - // A local write - static_for<0, num_ds_write_inst_a, 1>{}([&](auto i_inst) { - if constexpr(i_inst.value < stage_more_mfma) - { - if(i_inst.value < staged_num_ds_read_inst_a) - { - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - } - else - { - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_ds_write_a, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write - } - } - else - { - if(i_inst.value < staged_num_ds_read_inst_a) - { - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_ds_write_a - 2, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - } - else - { - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write - } + SCHED_GROUP_LDS_READ, ds_read_a_mfma_rate, 0); } }); - - __builtin_amdgcn_sched_barrier(0); - } - else if constexpr(stage.value == 2) - { - constexpr auto staged_num_mfma_per_buffer_load_a = - math::integer_divide_ceil(staged_num_mfma, num_buffer_load_inst_a); - - constexpr auto stage_more_mfma = - staged_num_mfma - (staged_num_mfma_per_buffer_load_a - 1) * num_buffer_load_inst_a; - - // A global - static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i_inst) { - if constexpr(i_inst.value < stage_more_mfma) - { - if(i_inst.value < staged_num_ds_read_inst_a) - { - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_buffer_load_a - 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 
0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - } - else - { - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_buffer_load_a, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - } - } - else - { - if(i_inst.value < staged_num_ds_read_inst_a) - { - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_buffer_load_a - 2, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - } - else - { - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_buffer_load_a - 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - } - } - }); - - __builtin_amdgcn_sched_barrier(0); - } - else - { - // A local Read - static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) { - ignore = i_inst; - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_ds_read_a, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - }); - - __builtin_amdgcn_sched_barrier(0); - } - } - - template - __device__ static constexpr auto EpilogueScheduler_1(Stage stage) - { - constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num; - constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num; - constexpr auto num_buffer_load_inst_b = MWaves * HotLoopInstList::B_Buffer_Load_Inst_Num; - - constexpr auto num_mfma = HotLoopInstList::C_MFMA_Inst_Num; - - constexpr auto staged_num_ds_read_inst_a = num_ds_read_inst_a / MRepeat; - constexpr auto staged_num_mfma = num_mfma / MRepeat; - - constexpr auto staged_num_mfma_per_ds_read_a = staged_num_mfma / staged_num_ds_read_inst_a; - - if constexpr(stage.value == 0) - { - constexpr auto staged_num_buffer_load_b_per_ds_read_a = - num_buffer_load_inst_b / 
staged_num_ds_read_inst_a; - constexpr auto staged_num_mfma_per_buffer_load_b = - staged_num_mfma / num_buffer_load_inst_b; - // B global - static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) { - ignore = i_inst; - - static_for<0, staged_num_buffer_load_b_per_ds_read_a, 1>{}([&](auto ibuf_inst) { - ignore = ibuf_inst; - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_buffer_load_b, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - }); - - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_buffer_load_b - 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - }); - - __builtin_amdgcn_sched_barrier(0); - } - else if constexpr(stage.value == 1) - { -#if 0 - constexpr auto staged_num_ds_write_a_per_ds_read_a = - num_ds_write_inst_a / staged_num_ds_read_inst_a; - constexpr auto staged_num_mfma_per_ds_write_a = staged_num_mfma / num_ds_write_inst_a; - // A local write - static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) { - ignore = i_inst; - - static_for<0, staged_num_ds_write_a_per_ds_read_a, 1>{}([&](auto idswrite_inst) { - ignore = idswrite_inst; - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write - }); - - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_ds_write_a_per_ds_read_a, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - }); -#elif 1 - constexpr auto staged_num_mfma_per_ds_write_a = - math::integer_divide_ceil(staged_num_mfma, num_ds_write_inst_a); - - constexpr auto stage_more_mfma = - staged_num_mfma - (staged_num_mfma_per_ds_write_a - 1) * num_ds_write_inst_a; - - // A local write - static_for<0, num_ds_write_inst_a, 1>{}([&](auto i_inst) { - if constexpr(i_inst.value < 
stage_more_mfma) - { - if(i_inst.value < staged_num_ds_read_inst_a) - { - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - } - else - { - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_ds_write_a, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write - } - } - else - { - if(i_inst.value < staged_num_ds_read_inst_a) - { - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_ds_write_a - 2, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - } - else - { - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write - } - } - }); -#endif - __builtin_amdgcn_sched_barrier(0); - } - else - { - // A local Read - static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) { - ignore = i_inst; - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_ds_read_a, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - }); - - __builtin_amdgcn_sched_barrier(0); - } - } - - __device__ static constexpr auto EpilogueScheduler_2() - { - constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num; - - constexpr auto num_mfma = HotLoopInstList::C_MFMA_Inst_Num; - - constexpr auto staged_num_ds_read_inst_a = num_ds_read_inst_a / MRepeat; - constexpr auto staged_num_mfma = num_mfma / MRepeat; - - constexpr auto staged_num_mfma_per_ds_read_a = staged_num_mfma / staged_num_ds_read_inst_a; - - // A local Read - static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) { - ignore = i_inst; - 
__builtin_amdgcn_sched_group_barrier(0x008, staged_num_mfma_per_ds_read_a, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read }); - __builtin_amdgcn_sched_barrier(0); + // A global read + A local write + static_for<0, buffer_load_a_stages, 1>{}([&](auto i) { + static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(SCHED_GROUP_MFMA, 1, 0); + if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_more == 0)) || + (((i + buffer_load_b_stages) >= buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_less == 0))) + { + __builtin_amdgcn_sched_group_barrier(SCHED_GROUP_LDS_WRITE, 1, 0); + } + if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_more == + buffer_load_issue_point_a)) || + (((i + buffer_load_b_stages) >= buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_less == + buffer_load_issue_point_a))) + { + __builtin_amdgcn_sched_group_barrier(SCHED_GROUP_VMEM, 1, 0); + } + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier( + SCHED_GROUP_LDS_READ, ds_read_a_mfma_rate, 0); + } + }); + }); + + // lds synchronization, prefetch next loop local A + static_for<0, DS_READ_A_PREFETCH_STAGES, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(SCHED_GROUP_MFMA, 1, 0); + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier( + SCHED_GROUP_LDS_READ, ds_read_a_mfma_rate, 0); + } + }); + }); } template {}([&](auto k0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(I0, I0, I0, k0, I0, I0), - a_block_buf.At(I0), - a_thread_desc_, - make_tuple(I0, I0, I0, k0, I0, I0), - a_thread_buf); + static_for<0, DS_READ_A_PREFETCH_STAGES, 
1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + // K = k0 × KGroup × k1 = k0 × kg0 × A_K1 + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); + }); }); // Initialize C @@ -558,26 +404,18 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3{}([&](auto m0) { - if constexpr(m0.value == 0) - { - b_blockwise_copy.Run(b_grid_desc, - b_grid_buf, - b_block_desc_n0_n1_k0_k1, - b_block_origin_idx, - b_thread_bufs(local_read_buf)); - b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); - } - else if constexpr(m0.value == 1) - { - a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(local_read_buf)); - } - else if constexpr(m0.value == 2) - { - a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); - } + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(local_read_buf)); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(local_read_buf)); + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + + static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, KRepeat, 1>{}([&](auto k0) { static_for<0, NRepeat, 1>{}([&](auto n0) { vector_type a_thread_vec; @@ -613,49 +451,88 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3{}([&](auto k0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number<(m0 + 1) % MRepeat>{}, I0, I0, k0, I0, I0), - a_block_buf.At(local_read_buf), - a_thread_desc_, - make_tuple( - Number<(m0 + 1 + HotloopLocalBufSwitch * mfma_reg_buf) % - 2>{}, - I0, - I0, - k0, - I0, - I0), - a_thread_buf); + static_for<0, KGroup, 1>{}([&](auto kg0) 
{ + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<0>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(local_read_buf), + a_thread_desc_, + make_tuple( + Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % + 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); + }); + }); + } + else if constexpr(m0.value == (MRepeat - 1)) + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(local_read_buf), + a_thread_desc_, + make_tuple( + Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % + 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); + }); }); } else { static_for<0, KRepeat, 1>{}([&](auto k0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number<(m0 + 1) % MRepeat>{}, I0, I0, k0, I0, I0), - a_block_buf.At(mfma_reg_buf), - a_thread_desc_, - make_tuple( - Number<(m0 + 1 + HotloopLocalBufSwitch * mfma_reg_buf) % - 2>{}, - I0, - I0, - k0, - I0, - I0), - a_thread_buf); + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(mfma_reg_buf), + a_thread_desc_, + make_tuple( + Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % + 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); + }); }); } - - HotLoopScheduler(m0); }); + HotLoopScheduler(); }; LoopFunc(I0, I1); @@ -667,20 +544,14 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3{}([&](auto m0) { - if constexpr(m0.value == 0) - { - b_blockwise_copy.Run(b_grid_desc, - b_grid_buf, - b_block_desc_n0_n1_k0_k1, - b_block_origin_idx, - b_thread_bufs(I1)); - } - else if constexpr(m0.value == MRepeat - 1) - { - a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1)); - } + 
b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I1)); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1)); + static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, KRepeat, 1>{}([&](auto k0) { static_for<0, NRepeat, 1>{}([&](auto n0) { vector_type a_thread_vec; @@ -707,36 +578,68 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3{}([&](auto k0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number<(m0 + 1) % MRepeat>{}, I0, I0, k0, I0, I0), - a_block_buf.At(I1), - a_thread_desc_, - make_tuple(Number<(m0 + 1) % 2>{}, I0, I0, k0, I0, I0), - a_thread_buf); + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple( + Number<0>{}, I0, I0, Number{}, I0, I0), + a_block_buf.At(I1), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); + }); + } + else if constexpr(m0.value == (MRepeat - 1)) + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(I1), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); }); } else { static_for<0, KRepeat, 1>{}([&](auto k0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number<(m0 + 1) % MRepeat>{}, I0, I0, k0, I0, I0), - a_block_buf.At(I0), - a_thread_desc_, - make_tuple(Number<(m0 + 1) % 2>{}, I0, I0, k0, I0, I0), - a_thread_buf); + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); }); } - - 
EpilogueScheduler_1(m0); }); + HotLoopScheduler(); + static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, KRepeat, 1>{}([&](auto k0) { static_for<0, NRepeat, 1>{}([&](auto n0) { @@ -764,25 +667,29 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3{}([&](auto k0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number{}, I0, I0, k0, I0, I0), - a_block_buf.At(I1), - a_thread_desc_, - make_tuple( - Number<(m0 + 1 + HotloopLocalBufSwitch) % 2>{}, I0, I0, k0, I0, I0), - a_thread_buf); + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple( + Number{}, I0, I0, Number{}, I0, I0), + a_block_buf.At(I1), + a_thread_desc_, + make_tuple(Number<(m0 + 2 + HotloopLocalBufSwitch) % 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); + }); }); - - EpilogueScheduler_2(); } }); - // Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle - // latency - // __builtin_amdgcn_sched_barrier(0); + + HotLoopScheduler(); } else if constexpr(TailNum == TailNumber::Odd) { @@ -813,18 +720,21 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3{}([&](auto k0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number{}, I0, I0, k0, I0, I0), - a_block_buf.At(I0), - a_thread_desc_, - make_tuple(Number<(m0 + 1) % 2>{}, I0, I0, k0, I0, I0), - a_thread_buf); + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple( + Number{}, I0, I0, Number{}, I0, I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); }); - - EpilogueScheduler_2(); } }); } @@ -841,7 +751,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3, + Sequence<1, 1, 1, 1, 1, KPack / KGroup>, Sequence<0, 1, 2, 3, 4, 5>, 5, A_K1, diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp 
b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp index ce507ca8d3..6c1c5b1c4d 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp @@ -58,6 +58,11 @@ struct BlockwiseGemmXdlops_pipeline_base static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops; static constexpr index_t KRepeat = KPerThread / KPack; static constexpr index_t KPerInnerLoop = KPack; + static constexpr index_t KGroup = + ((MPerXDL == 16 && NPerXDL == 16 && xdlops_gemm.KPerXdlops == 128) || + (MPerXDL == 32 && NPerXDL == 32 && xdlops_gemm.KPerXdlops == 64)) + ? 2 + : 1; static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL); static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp index 238ab14606..c0d9464136 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp @@ -167,11 +167,13 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle using mfma_selector = MfmaSelector; static constexpr index_t KPack = math::max(math::lcm(AK1Number, BK1Number), mfma_selector::selected_mfma.k_per_blk); + static constexpr index_t KGroup = mfma_selector::selected_mfma.k_per_blk == 32 ?
2 : 1; static constexpr index_t KLane = mfma_selector::GetKPerXdlops() / mfma_selector::GetK1PerXdlops(); - static constexpr index_t KRepeat = KPerBlock / KLane / KPack; - static constexpr index_t NLane = NPerXdl; - static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave; + static constexpr index_t KPackPerGroup = KPack / KGroup; + static constexpr index_t KRepeat = KPerBlock / KLane / KPackPerGroup; + static constexpr index_t NLane = NPerXdl; + static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave; static constexpr auto MakeDsGridPointer() { @@ -209,7 +211,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle } __host__ __device__ static auto CalculateBK0Shuffled(index_t K) { - return math::integer_divide_ceil(K, KLane * KPack); + return math::integer_divide_ceil(K, KLane * KPackPerGroup); } __host__ __device__ static auto CalculateKPadded(index_t K) @@ -351,7 +353,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle __host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0) { - constexpr index_t NkSwizzleNumber = Number{}; + constexpr index_t NkSwizzleNumber = Number{}; return make_naive_tensor_descriptor( make_tuple(N0 / NWave, NWave, K0, NkSwizzleNumber), make_tuple(NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, I1)); @@ -1228,7 +1230,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle make_multi_index(n_block_data_idx_on_grid, get_warp_local_1d_id() % NWave, 0, - KPack * (get_thread_local_1d_id() % warpSize))); + KPackPerGroup * (get_thread_local_1d_id() % warpSize))); // LDS allocation for A and B: be careful of alignment // Cast after lds @@ -1668,7 +1670,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle make_multi_index(n_block_data_idx_on_grid, get_warp_local_1d_id() % NWave, 0, - KPack * (get_thread_local_1d_id() % warpSize))); + KPackPerGroup * (get_thread_local_1d_id() % warpSize))); // LDS allocation for A and B: be careful of alignment // Cast 
after lds diff --git a/include/ck/utility/blkgemmpipe_scheduler.hpp b/include/ck/utility/blkgemmpipe_scheduler.hpp index 39407cb8f6..6c788fb41e 100644 --- a/include/ck/utility/blkgemmpipe_scheduler.hpp +++ b/include/ck/utility/blkgemmpipe_scheduler.hpp @@ -48,6 +48,15 @@ enum struct TailNumber // prefetchstages Full, }; + +enum SchedulerGroup : uint32_t +{ + SCHED_GROUP_MFMA = 0x008, // Matrix FMA instructions + SCHED_GROUP_VMEM = 0x020, // Global memory operations + SCHED_GROUP_LDS_READ = 0x100, // LDS read operations + SCHED_GROUP_LDS_WRITE = 0x200 // LDS write operations +}; + template >; // fp16 2:4 structured sparsity -#if defined(__gfx94__) || defined(__gfx95__) using WarpGemmSmfmacF16F16F32M32N32K16 = WarpGemmSmfmacImpl>>; using WarpGemmSmfmacF16F16F32M16N16K32 = WarpGemmSmfmacImpl>>; -#else // gfx 90a does not support smfmac -using WarpGemmSmfmacF16F16F32M32N32K16 = WarpGemmImpl, - 2>>; -using WarpGemmSmfmacF16F16F32M16N16K32 = WarpGemmImpl, - 2>>; -#endif // bf16 using WarpGemmMfmaBf16Bf16F32M32N32K8 = WarpGemmImpl< diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac_impl.hpp index 97fd2a8742..cd6cd3a399 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac_impl.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac_impl.hpp @@ -49,7 +49,7 @@ struct WarpGemmAttributeSmfmacImplF16F16F32M32N32K16 const int32_t& idx, bool_constant = {}) const { -#if defined(__gfx9__) +#if defined(__gfx94__) or defined(__gfx95__) c_vec = __builtin_amdgcn_smfmac_f32_32x32x16_f16(a_vec, b_vec, c_vec, idx, 0, 0); #else ck_tile::ignore = c_vec; @@ -100,7 +100,7 @@ struct WarpGemmAttributeSmfmacImplF16F16F32M16N16K32 const int32_t& idx, bool_constant = {}) const { -#if defined(__gfx9__) +#if defined(__gfx94__) or defined(__gfx95__) c_vec = __builtin_amdgcn_smfmac_f32_16x16x32_f16(a_vec, b_vec, c_vec, idx, 0, 0); #else ck_tile::ignore = c_vec; diff --git
a/library/include/ck/library/tensor_operation_instance/gpu/gemm_multiply_multiply_wp.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_multiply_multiply_wp.hpp index 07891ea932..90a9fa381d 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_multiply_multiply_wp.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_multiply_multiply_wp.hpp @@ -18,173 +18,6 @@ namespace device { namespace instance { #if(defined(CK_ENABLE_F16) || defined(CK_ENABLE_FP8)) -void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instances( - std::vector, - Row, - F8, - F8, - Tuple, - F16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& - instances); - -void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instances( - std::vector, - Row, - F8, - F8, - Tuple, - F16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& - instances); - -void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instances( - std::vector, - Row, - F8, - F8, - Tuple, - F16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& - instances); - -void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p4_default_instances( - std::vector, - Row, - F8, - F8, - Tuple, - F16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& - instances); - -void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p5_default_instances( - std::vector, - Row, - F8, - F8, - Tuple, - F16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& - instances); - -void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instances_v2( - std::vector, - Row, - F8, - F8, - Tuple, - F16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& - instances); - -void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instances_v2( - std::vector, - Row, - F8, - 
F8, - Tuple, - F16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& - instances); - -void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instances_v2( - std::vector, - Row, - F8, - F8, - Tuple, - F16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& - instances); - -void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p4_default_instances_v2( - std::vector, - Row, - F8, - F8, - Tuple, - F16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& - instances); - -void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p5_default_instances_v2( - std::vector, - Row, - F8, - F8, - Tuple, - F16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& - instances); - -void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_compute_default_instances_p1( - std::vector, - Row, - F8, - F8, - Tuple, - F16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& - instances); - -void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_compute_default_instances_p2( - std::vector, - Row, - F8, - F8, - Tuple, - F16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& - instances); void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instances_p1( std::vector, - Row, - F8, - F8, - Tuple, - BF16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& - instances); - -void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instances( - std::vector, - Row, - F8, - F8, - Tuple, - BF16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& - instances); - -void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instances( - std::vector, - Row, - F8, - F8, - Tuple, - BF16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& - instances); - -void 
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p4_default_instances( - std::vector, - Row, - F8, - F8, - Tuple, - BF16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& - instances); - -void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p5_default_instances( - std::vector, - Row, - F8, - F8, - Tuple, - BF16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& - instances); - -void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instances_v2( - std::vector, - Row, - F8, - F8, - Tuple, - BF16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& - instances); - -void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instances_v2( - std::vector, - Row, - F8, - F8, - Tuple, - BF16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& - instances); - -void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instances_v2( - std::vector, - Row, - F8, - F8, - Tuple, - BF16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& - instances); - -void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p4_default_instances_v2( - std::vector, - Row, - F8, - F8, - Tuple, - BF16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& - instances); - -void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p5_default_instances_v2( - std::vector, - Row, - F8, - F8, - Tuple, - BF16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& - instances); - -void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_compute_default_instances_p1( - std::vector, - Row, - F8, - F8, - Tuple, - BF16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& - instances); - -void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_compute_default_instances_p2( - std::vector, - Row, - F8, - F8, - Tuple, - BF16, - 
PassThrough, - PassThrough, - MultiplyMultiply>>>& - instances); - void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instances_p1( std::vector && is_same_v && is_same_v) { - add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instances( - op_ptrs); - add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instances( - op_ptrs); - add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instances( - op_ptrs); - add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p4_default_instances( - op_ptrs); - add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p5_default_instances( - op_ptrs); - - add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instances_v2( - op_ptrs); - add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instances_v2( - op_ptrs); - add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instances_v2( - op_ptrs); - add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p4_default_instances_v2( - op_ptrs); - add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p5_default_instances_v2( - op_ptrs); - - add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_compute_default_instances_p1( - op_ptrs); - add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_compute_default_instances_p2( - op_ptrs); - add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instances_p1( op_ptrs); add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instances_p2( @@ -612,33 +250,6 @@ struct DeviceOperationInstanceFactory< if constexpr(is_same_v && is_same_v && is_same_v) { - 
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instances( - op_ptrs); - add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instances( - op_ptrs); - add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instances( - op_ptrs); - add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p4_default_instances( - op_ptrs); - add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p5_default_instances( - op_ptrs); - - add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instances_v2( - op_ptrs); - add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instances_v2( - op_ptrs); - add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instances_v2( - op_ptrs); - add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p4_default_instances_v2( - op_ptrs); - add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p5_default_instances_v2( - op_ptrs); - - add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_compute_default_instances_p1( - op_ptrs); - add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_compute_default_instances_p2( - op_ptrs); - add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instances_p1( op_ptrs); add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instances_p2( diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_wp/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_wp/CMakeLists.txt index 37233ac5b4..743a0272f7 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_wp/CMakeLists.txt +++ 
b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_wp/CMakeLists.txt @@ -2,18 +2,6 @@ set(GEMM_MULTIPLY_MULTIPLY_WEIGHT_PRESHUFFLE_INSTANCES) list(APPEND GEMM_MULTIPLY_MULTIPLY_WEIGHT_PRESHUFFLE_INSTANCES - f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instance.cpp - f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instance.cpp - f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instance.cpp - f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_p4_default_instance.cpp - f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_p5_default_instance.cpp - f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instance_v2.cpp - f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instance_v2.cpp - f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instance_v2.cpp - f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_p4_default_instance_v2.cpp - f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_p5_default_instance_v2.cpp - f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_compute_default_instance_p1.cpp - f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_compute_default_instance_p2.cpp f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instance_p1.cpp f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instance_p2.cpp f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instance_p3.cpp @@ -21,18 +9,6 @@ list(APPEND GEMM_MULTIPLY_MULTIPLY_WEIGHT_PRESHUFFLE_INSTANCES f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instance_p5.cpp 
f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instance_p6.cpp - f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instance.cpp - f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instance.cpp - f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instance.cpp - f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_p4_default_instance.cpp - f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_p5_default_instance.cpp - f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instance_v2.cpp - f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instance_v2.cpp - f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instance_v2.cpp - f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_p4_default_instance_v2.cpp - f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_p5_default_instance_v2.cpp - f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_compute_default_instance_p1.cpp - f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_compute_default_instance_p2.cpp f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance_p1.cpp f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance_p2.cpp f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance_p3.cpp @@ -41,18 +17,6 @@ list(APPEND GEMM_MULTIPLY_MULTIPLY_WEIGHT_PRESHUFFLE_INSTANCES f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance_p6.cpp ) -set_source_files_properties(f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") 
-set_source_files_properties(f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") -set_source_files_properties(f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") -set_source_files_properties(f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_p4_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") -set_source_files_properties(f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_p5_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") -set_source_files_properties(f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instance_v2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") -set_source_files_properties(f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instance_v2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") -set_source_files_properties(f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instance_v2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") -set_source_files_properties(f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_p4_default_instance_v2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") -set_source_files_properties(f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_p5_default_instance_v2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") -set_source_files_properties(f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_compute_default_instance_p1.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") 
-set_source_files_properties(f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_compute_default_instance_p2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") set_source_files_properties(f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instance_p1.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") set_source_files_properties(f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instance_p2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") set_source_files_properties(f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instance_p3.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") @@ -60,18 +24,6 @@ set_source_files_properties(f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f set_source_files_properties(f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instance_p5.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") set_source_files_properties(f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instance_p6.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") -set_source_files_properties(f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") -set_source_files_properties(f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") -set_source_files_properties(f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") 
-set_source_files_properties(f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_p4_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") -set_source_files_properties(f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_p5_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") -set_source_files_properties(f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instance_v2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") -set_source_files_properties(f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instance_v2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") -set_source_files_properties(f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instance_v2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") -set_source_files_properties(f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_p4_default_instance_v2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") -set_source_files_properties(f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_p5_default_instance_v2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") -set_source_files_properties(f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_compute_default_instance_p1.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") -set_source_files_properties(f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_compute_default_instance_p2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") set_source_files_properties(f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance_p1.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") 
set_source_files_properties(f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance_p2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") set_source_files_properties(f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance_p3.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_wp/f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_wp/f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn.hpp index e5ada03a46..4613a0f24d 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_wp/f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_wp/f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn.hpp @@ -171,13 +171,13 @@ using device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x1 //############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // Compute friendly // 256x[64, 256, 32]x128 - DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 256, 128, 16, 16, 16, 16, 8, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 256, 128, 16, 16, 16, 16, 16, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, 
S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 224, 128, 16, 16, 16, 16, 8, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 192, 128, 16, 16, 16, 16, 8, 6, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 192, 128, 16, 16, 16, 16, 16, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 160, 128, 16, 16, 16, 16, 8, 5, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 128, 128, 16, 16, 16, 16, 8, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 
0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 128, 128, 16, 16, 16, 16, 16, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 96, 128, 16, 16, 16, 16, 8, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 64, 128, 16, 16, 16, 16, 8, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 64, 128, 16, 16, 16, 16, 16, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> // clang-format on >; @@ -190,13 +190,13 @@ using device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x1 //############################################| | | | | | | | | | | 
Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| //############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // 224x[64, 256, 32]x128 - DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 256, 128, 16, 16, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 256, 128, 16, 16, 16, 16, 14, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 224, 128, 16, 16, 16, 16, 7, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 192, 128, 16, 16, 16, 16, 7, 6, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, 
BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 192, 128, 16, 16, 16, 16, 14, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 160, 128, 16, 16, 16, 16, 7, 5, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 128, 128, 16, 16, 16, 16, 7, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 128, 128, 16, 16, 16, 16, 14, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 96, 128, 16, 16, 16, 16, 7, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 
1, 1, S<1, 32, 1, 8>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 64, 128, 16, 16, 16, 16, 7, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 64, 128, 16, 16, 16, 16, 14, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> // clang-format on >; template @@ -208,13 +208,13 @@ using device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x1 //############################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| //############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // 192x[64, 256, 32]x128, 192x[64]x256 - DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 256, 128, 16, 16, 16, 16, 6, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + 
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 256, 128, 16, 16, 16, 16, 12, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 224, 128, 16, 16, 16, 16, 6, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 192, 128, 16, 16, 16, 16, 6, 6, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 192, 128, 16, 16, 16, 16, 12, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 160, 128, 16, 16, 16, 16, 6, 5, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, 
BlockGemmPipelineVersion::v3, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 128, 128, 16, 16, 16, 16, 6, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 128, 128, 16, 16, 16, 16, 12, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 96, 128, 16, 16, 16, 16, 6, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 64, 128, 16, 16, 16, 16, 6, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 64, 128, 16, 16, 16, 16, 12, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8, 8, 1>, 
BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> // clang-format on >; template @@ -226,13 +226,13 @@ using device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x1 //############################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| //############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // 160x[64, 256, 32]x128, 160x[64, 96, 32]x256 - DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 256, 128, 16, 16, 16, 16, 5, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 256, 128, 16, 16, 16, 16, 10, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 224, 128, 16, 16, 16, 16, 5, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, - 
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 192, 128, 16, 16, 16, 16, 5, 6, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 192, 128, 16, 16, 16, 16, 10, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 160, 128, 16, 16, 16, 16, 5, 5, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 128, 128, 16, 16, 16, 16, 5, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 128, 128, 16, 16, 16, 16, 10, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, 
BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 96, 128, 16, 16, 16, 16, 5, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 64, 128, 16, 16, 16, 16, 5, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 64, 128, 16, 16, 16, 16, 10, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> // clang-format on >; template @@ -244,10 +244,10 @@ using device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x1 //############################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| //############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, 
MultiplyMultiply, GemmSpec, 256, 128, 96, 128, 16, 16, 16, 16, 4, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 64, 128, 16, 16, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 256, 16, 16, 16, 16, 4, 4, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 64, 128, 16, 16, 16, 16, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 256, 16, 16, 16, 16, 8, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, 
F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 96, 256, 16, 16, 16, 16, 4, 3, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 64, 256, 16, 16, 16, 16, 4, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 64, 256, 16, 16, 16, 16, 8, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> // clang-format on >; @@ -259,11 +259,11 @@ using device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x1 //############################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //############################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | 
PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| //############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 256, 128, 16, 16, 16, 16, 4, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 256, 128, 16, 16, 16, 16, 8, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 224, 128, 16, 16, 16, 16, 4, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 192, 128, 16, 16, 16, 16, 4, 6, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, 
MultiplyMultiply, GemmSpec, 256, 128, 192, 128, 16, 16, 16, 16, 8, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 160, 128, 16, 16, 16, 16, 4, 5, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 128, 16, 16, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 128, 16, 16, 16, 16, 8, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> // clang-format on >; diff --git a/tile_engine/ops/gemm/gemm_instance_builder.py b/tile_engine/ops/gemm/gemm_instance_builder.py index c00554df8f..3839523e3d 100755 --- a/tile_engine/ops/gemm/gemm_instance_builder.py +++ b/tile_engine/ops/gemm/gemm_instance_builder.py @@ -535,11 +535,7 @@ struct GemmDispatcher { ((tile[6] == 32 and tile[7] == 32 and tile[8] == 16) or (tile[6] == 16 and tile[7] == 16 and tile[8] == 32)) content += f""" -#if defined(__gfx908__) - 
run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}, {BOOL_MAP(False)}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, stream); -#else - run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}, {BOOL_MAP(sparse)}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, stream); -#endif""" + run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}, {BOOL_MAP(sparse)}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, stream);""" content += f""" }} else {{""" for tile in tile_params: