From e2c8f98fef3bc08dc38ce5951cff15f1a7979bc0 Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Fri, 16 May 2025 10:41:59 +0000 Subject: [PATCH] generalize the pipeline scheduling. --- example/67_gemm_microscaling/gemm_mx_fp4.cpp | 4 +- .../blockwise_gemm_pipeline_xdlops_v3_mx.hpp | 128 ++++++++++++------ .../impl/device_gemm_xdl_cshuffle_v3_mx.hpp | 8 -- 3 files changed, 90 insertions(+), 50 deletions(-) diff --git a/example/67_gemm_microscaling/gemm_mx_fp4.cpp b/example/67_gemm_microscaling/gemm_mx_fp4.cpp index 84c12bd0ac..b3b6345871 100644 --- a/example/67_gemm_microscaling/gemm_mx_fp4.cpp +++ b/example/67_gemm_microscaling/gemm_mx_fp4.cpp @@ -50,14 +50,14 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffle GemmSpec, // GemmSpec ScaleBlockSize, // ScaleBlockSize: Scaling block size 256, // BlockSize: Thread block size - 128, // MPerBlock + 192, // MPerBlock 256, // NPerBlock KPerBlock, // KPerBlock 16, // AK1 16, // BK1 16, // MPerXDL 16, // NPerXDL - 4, // MXdlPerWave + 6, // MXdlPerWave 8, // NXdlPerWave S<8, 32, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_mx.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_mx.hpp index 9f9a104fe5..aa91227bce 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_mx.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_mx.hpp @@ -208,6 +208,9 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx{}([&](auto i) { - if constexpr(i < num_buffer_load_a_scale) - { + if constexpr(i< mfma_stages_more){ + static_for<0, mfma_perstage_more, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr(imfma < num_dswrite_per_issue_a){ + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + } + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + else{ + static_for<0, mfma_perstage_less, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr(imfma < num_dswrite_per_issue_a){ + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + } + }); __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read } - static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) { - ignore = idswrite; - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write - }); - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier( - 0x008, num_mfma_per_issue - num_dswrite_per_issue_a, 0); // MFMA }); + static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) { - if constexpr(i < num_buffer_load_b_scale) - { + if constexpr((i+num_buffer_load_inst_a)< mfma_stages_more){ + static_for<0, mfma_perstage_more, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr(imfma < num_dswrite_per_issue_a){ + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + } + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + else{ + static_for<0, mfma_perstage_less, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr(imfma < num_dswrite_per_issue_b){ + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + } + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + }); + + static_for<0, num_buffer_load_a_scale, 1>{}([&](auto i) { + if constexpr((i+num_buffer_load_inst_a+num_buffer_load_inst_b)< mfma_stages_more){ + static_for<0, mfma_perstage_more, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + else{ + static_for<0, mfma_perstage_less, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + }); + + static_for<0, num_buffer_load_b_scale, 1>{}([&](auto i) { + if constexpr((i+num_buffer_load_inst_a+num_buffer_load_inst_b+num_buffer_load_a_scale)< mfma_stages_more){ + static_for<0, mfma_perstage_more, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + else{ + static_for<0, mfma_perstage_less, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read } - static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) { - ignore = idswrite; - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write - }); - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier( - 0x008, num_mfma_per_issue - num_dswrite_per_issue_b, 0); // MFMA }); // stage 2 static_for<0, num_dsread_a_mfma, 1>{}([&](auto i) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >= ds_read_a_mfma_rate) { @@ -279,10 +332,10 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx{}([&](auto i) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >= ds_read_b_mfma_rate) { @@ -295,9 +348,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx{}([&](auto m0) { static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { @@ -503,9 +561,6 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx{}([&](auto n0) { static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { @@ -528,9 +583,6 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx()(ik) = b_thread_buf[Number{}]; - CK_PRINT, - Number, - Number - >(); }); using mfma_input_type_a = diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp index 699c6c40cb..ebed17cfb0 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp @@ -220,14 +220,6 @@ struct DeviceGemmMX_Xdl_CShuffleV3 : public DeviceGemmMX, pk_i4_t> || - is_same_v, f4x2_pk_t>) - return 2; - else - return 1; - }(); - if(stream_config.log_level_ > 0) { arg.Print();