diff --git a/example/65_gemm_multiply_multiply/CMakeLists.txt b/example/65_gemm_multiply_multiply/CMakeLists.txt index 3c1947c058..c2301c405d 100644 --- a/example/65_gemm_multiply_multiply/CMakeLists.txt +++ b/example/65_gemm_multiply_multiply/CMakeLists.txt @@ -22,3 +22,10 @@ foreach(gpu IN LISTS GPU_TARGETS) set(target 1) endif() endforeach() + +set(GEMM_OPTIONS) +list(APPEND GEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32") +list(APPEND GEMM_OPTIONS -v --save-temps -Wno-gnu-line-marker) +target_compile_options(example_gemm_multiply_multiply_xdl_fp8_bpreshuffle PRIVATE ${GEMM_OPTIONS}) +target_compile_options(example_moe_gemm1_xdl_fp8 PRIVATE ${GEMM_OPTIONS}) +target_compile_options(example_moe_gemm2_xdl_fp8 PRIVATE ${GEMM_OPTIONS}) diff --git a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp index e4e6a4f1a7..c20dcaf31b 100644 --- a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp +++ b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp @@ -140,14 +140,14 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu // clang-format off < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, - 128, 128, 128, + 256, 256, 128, 16, 16, - 32, 32, - 2, 2, + 16, 16, + 16, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, - 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, FP8>; + 2, 1, S<1, 32, 1, 8>, S<8, 8, 1>, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>; // clang-format on int main(int argc, char* argv[]) diff --git a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp index dc56a044b3..c8ee8fc79b 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp @@ -157,11 +157,11 @@ using AElementOp = PassThrough; using BElementOp = PassThrough; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; -static constexpr ck::index_t MPerBlock = 128; -static constexpr ck::index_t MXDLPerWave = 8; -static constexpr ck::index_t NXDLPerWave = 1; +static constexpr ck::index_t MPerBlock = 256; +static constexpr ck::index_t MXDLPerWave = 16; +static constexpr ck::index_t NXDLPerWave = 4; static constexpr ck::index_t BLOCKSIZE = 256; -static constexpr ck::index_t NPerBlock = 64; +static constexpr ck::index_t NPerBlock = 256; static constexpr ck::index_t MNPerXDL = 16; static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType); static constexpr ck::index_t Nswizzle = false; @@ -183,14 +183,14 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceM // mn_perxdl MNPerXDL, MNPerXDL, // mn_xdlperwave - MXDLPerWave, NXDLPerWave, + MXDLPerWave, NXDLPerWave, // a,b: loadtranfer cluster, cluster order, srcorder,VECDIM, srcpervec, dstpervec, lds_extra S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0, // CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| // MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - 2, 1, S<1, 32, 1, 8>, S, + 2, 2, S<1, 32, 1, 8>, S, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, ActOP, Nswizzle, true, MulRoutedWeight, true, int32_t, A0DataType>; // clang-format on @@ -268,6 +268,7 @@ int main(int argc, char* argv[]) { expert_ids.mData[i] = i / (valid_tile_num / experts); } + int token_per_tile = (tokens * topk + valid_tile_num - 1) / valid_tile_num; int tokenid = 0; diff --git a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8.cpp b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8.cpp index 42d892fe26..3039b7fb19 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8.cpp @@ -123,11 +123,11 @@ using BElementOp = PassThrough; using CDEElementOp = MulABScaleExpertWeight; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; -static constexpr ck::index_t MPerBlock = 128; +static constexpr ck::index_t MPerBlock = 256; static constexpr ck::index_t BLOCKSIZE = 256; -static constexpr ck::index_t MXDLPerWave = 4; +static constexpr ck::index_t MXDLPerWave = 16; static constexpr ck::index_t NXDLPerWave = 4; -static constexpr ck::index_t NPerBlock = 128; +static constexpr ck::index_t NPerBlock = 256; static constexpr ck::index_t MNPerXDL = 16; static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType); @@ -136,11 +136,12 @@ static constexpr ck::index_t CShuffleMLane = BLOCKSIZE / CShuffleNLane; static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType); static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType); static constexpr ck::index_t EVec = 2; -static constexpr ck::index_t D0Vec = 1; -static constexpr ck::index_t D1Vec = 1; -static constexpr ck::index_t D2Vec = 1; -static constexpr bool MulRoutedWeight = true; -using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm +// TODO: Epilogue performance issue. AtomicAdd lose 15~20% performance compare with Set. +static constexpr ck::index_t D0Vec = 1; +static constexpr ck::index_t D1Vec = 1; +static constexpr ck::index_t D2Vec = 1; +static constexpr bool MulRoutedWeight = true; +using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm // clang-format off ///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| ///######| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| @@ -164,7 +165,7 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic // S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, // S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0, - S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0, // CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| // MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| @@ -186,8 +187,8 @@ int main(int argc, char* argv[]) ck::index_t N = 4096; ck::index_t K = 4096; ck::index_t experts = 8; - ck::index_t sorted_tile_num = 16; - ck::index_t valid_tile_num = 13; + ck::index_t sorted_tile_num = 133; + ck::index_t valid_tile_num = 128; ck::index_t sorted_size = sorted_tile_num * MPerBlock; ck::index_t valid_size = valid_tile_num * MPerBlock; ck::index_t tokens = 128; @@ -247,11 +248,11 @@ int main(int argc, char* argv[]) Tensor max_token_id(HostTensorDescriptor({1})); max_token_id.mData = {valid_size, 0, 2, 3, 4, 6, 8, 10, 12, 13}; - int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 7, 7, 3, 3, 3}; + // int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 7, 7, 3, 3, 3}; for(int i = 0; i < sorted_tile_num; i++) { - expert_ids.mData[i] = eids[i]; + expert_ids.mData[i] = i / ((valid_tile_num + experts - 1) / experts); } if(tokens * topk > valid_size) { diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v3.hpp index 60be7e2d07..b63c813955 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v3.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -59,25 +59,25 @@ template struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3 + BlockSize, + ADataType, + BDataType, + ComputeDataType, + AccDataType, + ATileDesc, + BTileDesc, + AMmaTileDesc, + BMmaTileDesc, + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MPerXDL, + NPerXDL, + MRepeat, + NRepeat, + KPack> : BlockwiseGemmXdlops_pipeline_base{}); constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{}); constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{}); - constexpr index_t K2 = KPack; + constexpr index_t K2 = KPack / KGroup; constexpr index_t K1 = 64 / NPerXDL; - constexpr index_t K0 = KRepeat; + constexpr index_t K0 = KRepeat * KGroup; return transform_tensor_descriptor( TileDesc_M0_M1_M2_K{}, @@ -185,163 +186,145 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3 - __device__ static constexpr auto HotLoopScheduler(Stage stage) + __device__ static constexpr auto HotLoopScheduler() { - constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num; - constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num; + // A/B split schedule + // compiler is likely to use ds_read2 when instruction width smaller than 16bytes + constexpr auto num_ds_read_inst_a = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 + ? HotLoopInstList::A_LDS_Read_Inst_Num + : HotLoopInstList::A_LDS_Read_Inst_Num / 2; + + constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num; + constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num; - constexpr auto num_buffer_load_inst_b = MWaves * HotLoopInstList::B_Buffer_Load_Inst_Num * 2; + constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num * 2; - constexpr auto num_mfma = HotLoopInstList::C_MFMA_Inst_Num * 2; + static_assert(num_buffer_load_inst_a == num_ds_write_inst_a); - constexpr auto staged_num_ds_read_inst_a = num_ds_read_inst_a / MRepeat; - constexpr auto staged_num_mfma = num_mfma / MRepeat; + constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num * 2; + constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle; - constexpr auto staged_num_mfma_per_ds_read_a = staged_num_mfma / staged_num_ds_read_inst_a; + constexpr auto ds_read_a_issue_cycle = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4; + constexpr auto ds_read_a_mfma_rate = + math::integer_divide_ceil(mfma_cycle - 4, 2 * ds_read_a_issue_cycle); - if constexpr(stage.value == 0) - { - constexpr auto staged_num_buffer_load_b_per_ds_read_a = - num_buffer_load_inst_b / staged_num_ds_read_inst_a; - constexpr auto staged_num_mfma_per_buffer_load_b = - staged_num_mfma / num_buffer_load_inst_b; - // B global - static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) { - ignore = i_inst; + // constexpr auto num_dsread_a_mfma = + // (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate; - static_for<0, staged_num_buffer_load_b_per_ds_read_a - 1, 1>{}([&](auto ibuf_inst) { - ignore = ibuf_inst; - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_buffer_load_b, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - }); + constexpr auto num_total_stages = MRepeat; + // Group num_mfma_perstage num_ds_read_a_perstage + // since we want to reuse a local register buffer + constexpr auto num_mfma_perstage = num_mfma_inst / num_total_stages; + constexpr auto num_ds_read_a_perstage = num_ds_read_inst_a / num_total_stages; + + constexpr auto num_ds_read_a_mfma_perstage = + math::integer_divide_ceil(num_ds_read_a_perstage, ds_read_a_mfma_rate); + + constexpr auto num_ds_read_a_prefetch_stages = 2; + + constexpr auto buffer_load_perstage_more = math::integer_divide_ceil( + (num_buffer_load_inst_a + num_buffer_load_inst_b), (num_total_stages - 2)); + constexpr auto buffer_load_perstage_less = math::integer_divide_floor( + (num_buffer_load_inst_a + num_buffer_load_inst_b), (num_total_stages - 2)); + + constexpr auto buffer_load_stages_more = + (num_buffer_load_inst_a + num_buffer_load_inst_b) - + math::integer_divide_floor((num_buffer_load_inst_a + num_buffer_load_inst_b), + (num_total_stages - 2)) * + ((num_total_stages - 2)); + + constexpr auto buffer_load_b_stages = + buffer_load_perstage_more * buffer_load_stages_more > num_buffer_load_inst_b + ? num_buffer_load_inst_b / buffer_load_perstage_more + : (buffer_load_stages_more + + (num_buffer_load_inst_b - buffer_load_perstage_more * buffer_load_stages_more) / + buffer_load_perstage_less); + + constexpr auto buffer_load_a_stages = + num_total_stages - num_ds_read_a_prefetch_stages - buffer_load_b_stages; + + constexpr auto buffer_load_issue_point_b = 0; + constexpr auto buffer_load_issue_point_interval_more = + num_mfma_perstage / buffer_load_perstage_more; + constexpr auto buffer_load_issue_point_interval_less = + num_mfma_perstage / buffer_load_perstage_less; + constexpr auto ds_write_issue_point = 0; + constexpr auto buffer_load_issue_point_a = num_mfma_perstage >= 3 ? 1 : 0; + + // B global read + static_for<0, buffer_load_b_stages, 1>{}([&](auto i) { + static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_buffer_load_b - 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - }); - __builtin_amdgcn_sched_barrier(0); - } - else if constexpr(stage.value == 1) - { - constexpr auto staged_num_mfma_per_ds_write_a = - math::integer_divide_ceil(staged_num_mfma, num_ds_write_inst_a); - - constexpr auto stage_more_mfma = - staged_num_mfma - (staged_num_mfma_per_ds_write_a - 1) * num_ds_write_inst_a; - - // A local write - static_for<0, num_ds_write_inst_a, 1>{}([&](auto i_inst) { - if constexpr(i_inst.value < stage_more_mfma) + if constexpr(((i < buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_more == + buffer_load_issue_point_b)) || + ((i >= buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_less == + buffer_load_issue_point_b))) { - if(i_inst.value < staged_num_ds_read_inst_a) - { - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - } - else - { - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_ds_write_a, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write - } + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read } - else + + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) { - if(i_inst.value < staged_num_ds_read_inst_a) - { - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_ds_write_a - 2, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - } - else - { - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write - } + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read } }); + }); - __builtin_amdgcn_sched_barrier(0); - } - else if constexpr(stage.value == 2) - { - constexpr auto staged_num_mfma_per_buffer_load_a = - math::integer_divide_ceil(staged_num_mfma, num_buffer_load_inst_a); - - constexpr auto stage_more_mfma = - staged_num_mfma - (staged_num_mfma_per_buffer_load_a - 1) * num_buffer_load_inst_a; - - // A global - static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i_inst) { - if constexpr(i_inst.value < stage_more_mfma) + // A global read + A local write + static_for<0, buffer_load_a_stages, 1>{}([&](auto i) { + static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_more == + ds_write_issue_point)) || + (((i + buffer_load_b_stages) >= buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_less == + ds_write_issue_point))) { - if(i_inst.value < staged_num_ds_read_inst_a) - { - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_buffer_load_a - 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - } - else - { - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_buffer_load_a, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - } + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write } - else + if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_more == + buffer_load_issue_point_a)) || + (((i + buffer_load_b_stages) >= buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_less == + buffer_load_issue_point_a))) { - if(i_inst.value < staged_num_ds_read_inst_a) - { - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_buffer_load_a - 2, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - } - else - { - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_buffer_load_a - 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - } + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read } }); + }); - __builtin_amdgcn_sched_barrier(0); - } - else - { - // A local Read - static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) { - ignore = i_inst; - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_ds_read_a, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + // lds synchronization, prefetch next loop local A + static_for<0, num_ds_read_a_prefetch_stages, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } }); - - __builtin_amdgcn_sched_barrier(0); - } + }); } template __device__ static constexpr auto EpilogueScheduler_1(Stage stage) { - constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num; - constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num; - constexpr auto num_buffer_load_inst_b = MWaves * HotLoopInstList::B_Buffer_Load_Inst_Num * 2; + constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num; + constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num; + constexpr auto num_buffer_load_inst_b = + MWaves * HotLoopInstList::B_Buffer_Load_Inst_Num * 2; constexpr auto num_mfma = HotLoopInstList::C_MFMA_Inst_Num * 2; @@ -527,12 +510,12 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3{}([&](auto m0) { - if constexpr(m0.value == 0) - { - b_blockwise_copy.Run(b_grid_desc, - b_grid_buf, - b_block_desc_n0_n1_k0_k1, - b_block_origin_idx, - b_thread_bufs(local_read_buf)); - b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); - b_blockwise_copy_up.Run(b_grid_desc, - b_grid_buf_up, - b_block_desc_n0_n1_k0_k1, - b_block_origin_idx, - b_thread_bufs_up(local_read_buf)); - b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); - } - else if constexpr(m0.value == 1) - { - a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(local_read_buf)); - } - else if constexpr(m0.value == 2) - { - a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); - } + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(local_read_buf)); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + b_blockwise_copy_up.Run(b_grid_desc, + b_grid_buf_up, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs_up(local_read_buf)); + b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(local_read_buf)); + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, KRepeat, 1>{}([&](auto k0) { static_for<0, NRepeat, 1>{}([&](auto n0) { vector_type a_thread_vec; @@ -617,11 +591,11 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3{}]; - + b_thread_vec_up.template AsType()(ik) = b_thread_bufs_up[mfma_reg_buf] - [Number{}]; + [Number{}]; }); using mfma_input_type = @@ -635,7 +609,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3(), b_thread_vec.template AsType(), c_thread_buf.GetVectorTypeReference(Number{})); - + xdlops_gemm.Run( a_thread_vec.template AsType(), b_thread_vec_up.template AsType(), @@ -643,48 +617,87 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3{}([&](auto k0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number<(m0 + 1) % MRepeat>{}, I0, I0, k0, I0, I0), - a_block_buf.At(local_read_buf), - a_thread_desc_, - make_tuple( - Number<(m0 + 1 + HotloopLocalBufSwitch * mfma_reg_buf) % - 2>{}, - I0, - I0, - k0, - I0, - I0), - a_thread_buf); + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(local_read_buf), + a_thread_desc_, + make_tuple( + Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % + 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); + }); + }); + } + else if constexpr(m0.value == (MRepeat - 1)) + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(local_read_buf), + a_thread_desc_, + make_tuple( + Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % + 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); + }); }); } else { static_for<0, KRepeat, 1>{}([&](auto k0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number<(m0 + 1) % MRepeat>{}, I0, I0, k0, I0, I0), - a_block_buf.At(mfma_reg_buf), - a_thread_desc_, - make_tuple( - Number<(m0 + 1 + HotloopLocalBufSwitch * mfma_reg_buf) % - 2>{}, - I0, - I0, - k0, - I0, - I0), - a_thread_buf); + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(mfma_reg_buf), + a_thread_desc_, + make_tuple( + Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % + 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); + }); }); } - - HotLoopScheduler(m0); + HotLoopScheduler(); }); }; @@ -697,26 +710,19 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3{}([&](auto m0) { - if constexpr(m0.value == 0) - { - b_blockwise_copy.Run(b_grid_desc, - b_grid_buf, - b_block_desc_n0_n1_k0_k1, - b_block_origin_idx, - b_thread_bufs(I1)); - - b_blockwise_copy_up.Run(b_grid_desc, - b_grid_buf_up, - b_block_desc_n0_n1_k0_k1, - b_block_origin_idx, - b_thread_bufs_up(I1)); - } - else if constexpr(m0.value == MRepeat - 1) - { - a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1)); - } + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I1)); + b_blockwise_copy_up.Run(b_grid_desc, + b_grid_buf_up, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs_up(I1)); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1)); + static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, KRepeat, 1>{}([&](auto k0) { static_for<0, NRepeat, 1>{}([&](auto n0) { vector_type a_thread_vec; @@ -730,7 +736,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3()(ik) = b_thread_bufs[I0][Number{}]; - + b_thread_vec_up.template AsType()(ik) = b_thread_bufs_up[I0][Number{}]; @@ -745,43 +751,78 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3(), b_thread_vec.template AsType(), c_thread_buf.GetVectorTypeReference(Number{})); - + xdlops_gemm.Run(a_thread_vec.template AsType(), b_thread_vec_up.template AsType(), c_thread_buf_up.GetVectorTypeReference(Number{})); }); }); - - if constexpr(m0.value == MRepeat - 1) + if constexpr(m0.value == (MRepeat - 2)) { block_sync_lds(); static_for<0, KRepeat, 1>{}([&](auto k0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number<(m0 + 1) % MRepeat>{}, I0, I0, k0, I0, I0), - a_block_buf.At(I1), - a_thread_desc_, - make_tuple(Number<(m0 + 1) % 2>{}, I0, I0, k0, I0, I0), - a_thread_buf); + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(I1), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); + }); + } + else if constexpr(m0.value == MRepeat - 1) + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(I1), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); }); } else { static_for<0, KRepeat, 1>{}([&](auto k0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number<(m0 + 1) % MRepeat>{}, I0, I0, k0, I0, I0), - a_block_buf.At(I0), - a_thread_desc_, - make_tuple(Number<(m0 + 1) % 2>{}, I0, I0, k0, I0, I0), - a_thread_buf); + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); }); } - - EpilogueScheduler_1(m0); }); + HotLoopScheduler(); + static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, KRepeat, 1>{}([&](auto k0) { static_for<0, NRepeat, 1>{}([&](auto n0) { @@ -810,7 +851,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3(), b_thread_vec.template AsType(), c_thread_buf.GetVectorTypeReference(Number{})); - + xdlops_gemm.Run(a_thread_vec.template AsType(), b_thread_vec_up.template AsType(), c_thread_buf_up.GetVectorTypeReference(Number{})); @@ -820,22 +861,28 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3{}([&](auto k0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number{}, I0, I0, k0, I0, I0), - a_block_buf.At(I1), - a_thread_desc_, - make_tuple( - Number<(m0 + 1 + HotloopLocalBufSwitch) % 2>{}, I0, I0, k0, I0, I0), - a_thread_buf); + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple( + Number{}, I0, I0, Number{}, I0, I0), + a_block_buf.At(I1), + a_thread_desc_, + make_tuple(Number<(m0 + 2 + HotloopLocalBufSwitch) % 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); + }); }); - - EpilogueScheduler_2(); } }); + + HotLoopScheduler(); // Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle // latency - // __builtin_amdgcn_sched_barrier(0); } else if constexpr(TailNum == TailNumber::Odd) { @@ -873,18 +920,21 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3{}([&](auto k0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number{}, I0, I0, k0, I0, I0), - a_block_buf.At(I0), - a_thread_desc_, - make_tuple(Number<(m0 + 1) % 2>{}, I0, I0, k0, I0, I0), - a_thread_buf); + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple( + Number{}, I0, I0, Number{}, I0, I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); }); - - EpilogueScheduler_2(); } }); } @@ -913,7 +963,6 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3{}, I1, Number{}, Number{})); static constexpr BTileDesc b_block_desc_n0_n1_k0_k1; - }; } // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp index d751543175..a003befc3c 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp @@ -122,6 +122,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1{}); constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{}); constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{}); - constexpr index_t K2 = KPack; + constexpr index_t K2 = KPack / KGroup; constexpr index_t K1 = 64 / NPerXDL; - constexpr index_t K0 = KRepeat; + constexpr index_t K0 = KRepeat * KGroup; return transform_tensor_descriptor( TileDesc_M0_M1_M2_K{}, @@ -280,12 +281,14 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1{}([&](auto m0) { static_for<0, KRepeat, 1>{}([&](auto k0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, k0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, I0), - a_thread_buf); + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); }); }); @@ -348,12 +351,15 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1{}([&](auto m0) { static_for<0, KRepeat, 1>{}([&](auto k0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, k0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, I0), - a_thread_buf); + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); }); }); @@ -411,12 +417,14 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1{}([&](auto m0) { static_for<0, KRepeat, 1>{}([&](auto k0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, k0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, I0), - a_thread_buf); + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); }); }); @@ -495,7 +503,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1, + Sequence<1, 1, 1, 1, 1, KPack / KGroup>, Sequence<0, 1, 2, 3, 4, 5>, 5, A_K1, diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v2.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v2.hpp index 4c019a41a4..2d4ae048ac 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v2.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v2.hpp @@ -122,6 +122,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2{}); constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{}); constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{}); - constexpr index_t K2 = KPack; + constexpr index_t K2 = KPack / KGroup; constexpr index_t K1 = 64 / NPerXDL; - constexpr index_t K0 = KRepeat; + constexpr index_t K0 = KRepeat * KGroup; return transform_tensor_descriptor( TileDesc_M0_M1_M2_K{}, @@ -281,12 +282,14 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2{}([&](auto m0) { static_for<0, KRepeat, 1>{}([&](auto k0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, k0, I0, I0), - a_block_buf.At(I0), - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, I0), - a_thread_bufs(I0)); + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_bufs(I0)); + }); }); }); @@ -320,12 +323,15 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2{}([&](auto m0) { static_for<0, KRepeat, 1>{}([&](auto k0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, k0, I0, I0), - a_block_buf.At(local_read_buf), - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, I0), - a_thread_bufs(local_read_buf)); + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf.At(local_read_buf), + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_bufs(local_read_buf)); + }); }); }); @@ -391,12 +397,14 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2{}([&](auto m0) { static_for<0, KRepeat, 1>{}([&](auto k0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, k0, I0, I0), - a_block_buf.At(local_read_reg), - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, I0), - a_thread_bufs(local_read_reg)); + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf.At(local_read_reg), + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_bufs(local_read_reg)); + }); }); }); @@ -445,12 +453,14 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2{}([&](auto m0) { static_for<0, KRepeat, 1>{}([&](auto k0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, k0, I0, I0), - a_block_buf.At(local_read_reg), - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, I0), - a_thread_bufs(local_read_reg)); + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf.At(local_read_reg), + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_bufs(local_read_reg)); + }); }); }); @@ -539,7 +549,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2, + Sequence<1, 1, 1, 1, 1, KPack / KGroup>, Sequence<0, 1, 2, 3, 4, 5>, 5, A_K1, diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp index 6d115e7620..6af20af484 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp @@ -123,6 +123,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3{}); constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{}); constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{}); - constexpr index_t K2 = KPack; + constexpr index_t K2 = KPack / KGroup; constexpr index_t K1 = 64 / NPerXDL; - constexpr index_t K0 = KRepeat; + constexpr index_t K0 = KRepeat * KGroup; return transform_tensor_descriptor( TileDesc_M0_M1_M2_K{}, @@ -184,298 +185,230 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3 - __device__ static constexpr auto HotLoopScheduler(Stage stage) + __device__ static constexpr auto HotLoopScheduler() { - constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num; - constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num; - constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num; - constexpr auto num_buffer_load_inst_b = MWaves * HotLoopInstList::B_Buffer_Load_Inst_Num; - - constexpr auto num_mfma = HotLoopInstList::C_MFMA_Inst_Num; - - constexpr auto staged_num_ds_read_inst_a = num_ds_read_inst_a / MRepeat; - constexpr auto staged_num_mfma = num_mfma / MRepeat; - - constexpr auto staged_num_mfma_per_ds_read_a = staged_num_mfma / staged_num_ds_read_inst_a; - - if constexpr(stage.value == 0) - { - constexpr auto staged_num_buffer_load_b_per_ds_read_a = - num_buffer_load_inst_b / staged_num_ds_read_inst_a; - constexpr auto staged_num_mfma_per_buffer_load_b = - staged_num_mfma / num_buffer_load_inst_b; - // B global - static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) { - ignore = i_inst; - - static_for<0, staged_num_buffer_load_b_per_ds_read_a - 1, 1>{}([&](auto ibuf_inst) { - ignore = ibuf_inst; - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_buffer_load_b, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - }); - - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_buffer_load_b - 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - }); - - __builtin_amdgcn_sched_barrier(0); - } - else if constexpr(stage.value == 1) - { - constexpr auto staged_num_mfma_per_ds_write_a = - math::integer_divide_ceil(staged_num_mfma, num_ds_write_inst_a); - - constexpr auto stage_more_mfma = - staged_num_mfma - (staged_num_mfma_per_ds_write_a - 1) * num_ds_write_inst_a; - - // A local write - static_for<0, num_ds_write_inst_a, 1>{}([&](auto i_inst) { - if constexpr(i_inst.value < stage_more_mfma) - { - if(i_inst.value < staged_num_ds_read_inst_a) - { - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - } - else - { - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_ds_write_a, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write - } - } - else - { - if(i_inst.value < staged_num_ds_read_inst_a) - { - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_ds_write_a - 2, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - } - else - { - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write - } - } - }); - - __builtin_amdgcn_sched_barrier(0); - } - else if constexpr(stage.value == 2) - { - constexpr auto staged_num_mfma_per_buffer_load_a = - math::integer_divide_ceil(staged_num_mfma, num_buffer_load_inst_a); - - constexpr auto stage_more_mfma = - staged_num_mfma - (staged_num_mfma_per_buffer_load_a - 1) * num_buffer_load_inst_a; - - // A global - static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i_inst) { - if constexpr(i_inst.value < stage_more_mfma) - { - if(i_inst.value < staged_num_ds_read_inst_a) - { - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_buffer_load_a - 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - } - else - { - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_buffer_load_a, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - } - } - else - { - if(i_inst.value < staged_num_ds_read_inst_a) - { - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_buffer_load_a - 2, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - } - else - { - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_buffer_load_a - 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - } - } - }); - - __builtin_amdgcn_sched_barrier(0); - } - else - { - // A local Read - static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) { - ignore = i_inst; - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_ds_read_a, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - }); - - __builtin_amdgcn_sched_barrier(0); - } - } - - template - __device__ static constexpr auto EpilogueScheduler_1(Stage stage) - { - constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num; - constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num; - constexpr auto num_buffer_load_inst_b = MWaves * HotLoopInstList::B_Buffer_Load_Inst_Num; - - constexpr auto num_mfma = HotLoopInstList::C_MFMA_Inst_Num; - - constexpr auto staged_num_ds_read_inst_a = num_ds_read_inst_a / MRepeat; - constexpr auto staged_num_mfma = num_mfma / MRepeat; - - constexpr auto staged_num_mfma_per_ds_read_a = staged_num_mfma / staged_num_ds_read_inst_a; - - if constexpr(stage.value == 0) - { - constexpr auto staged_num_buffer_load_b_per_ds_read_a = - num_buffer_load_inst_b / staged_num_ds_read_inst_a; - constexpr auto staged_num_mfma_per_buffer_load_b = - staged_num_mfma / num_buffer_load_inst_b; - // B global - static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) { - ignore = i_inst; - - static_for<0, staged_num_buffer_load_b_per_ds_read_a, 1>{}([&](auto ibuf_inst) { - ignore = ibuf_inst; - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_buffer_load_b, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - }); - - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_buffer_load_b - 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - }); - - __builtin_amdgcn_sched_barrier(0); - } - else if constexpr(stage.value == 1) - { #if 0 - constexpr auto staged_num_ds_write_a_per_ds_read_a = - num_ds_write_inst_a / staged_num_ds_read_inst_a; - constexpr auto staged_num_mfma_per_ds_write_a = staged_num_mfma / num_ds_write_inst_a; - // A local write - static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) { - ignore = i_inst; + // A/B split schedule + // compiler is likely to use ds_read2 when instruction width smaller than 16bytes + constexpr auto num_ds_read_inst_a = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 + ? HotLoopInstList::A_LDS_Read_Inst_Num + : HotLoopInstList::A_LDS_Read_Inst_Num / 2; - static_for<0, staged_num_ds_write_a_per_ds_read_a, 1>{}([&](auto idswrite_inst) { - ignore = idswrite_inst; - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write - }); + constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num; - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_ds_write_a_per_ds_read_a, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - }); + constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num; + + constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num; + constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle; + + constexpr auto ds_read_a_issue_cycle = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4; + constexpr auto ds_read_a_mfma_rate = + math::integer_divide_ceil(mfma_cycle - 4, 2 * ds_read_a_issue_cycle); + + // constexpr auto num_dsread_a_mfma = + // (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate; + + constexpr auto num_stages = MRepeat; + + // Group num_mfma_perstage num_ds_read_a_perstage + // since we want to reuse a local register buffer + constexpr auto num_mfma_perstage = num_mfma_inst / num_stages; + constexpr auto num_ds_read_a_perstage = num_ds_read_inst_a / num_stages; + + constexpr auto num_ds_read_a_mfma_perstage = + math::integer_divide_ceil(num_ds_read_a_perstage, ds_read_a_mfma_rate); + + constexpr auto num_mfma_per_issue_more = math::integer_divide_ceil( + num_mfma_inst, num_buffer_load_inst_a + num_buffer_load_inst_b); + constexpr auto num_mfma_per_issue_less = math::integer_divide_floor( + num_mfma_inst, num_buffer_load_inst_a + num_buffer_load_inst_b); + // Insert more mfmas between bufferloads + constexpr auto num_stage1_bufferloads = + num_mfma_inst - + (num_buffer_load_inst_a + num_buffer_load_inst_b) * num_mfma_per_issue_less; + constexpr auto num_stage1_mfma = num_mfma_per_issue_more * num_stage1_bufferloads; + // Insert less mfmas between bufferloads + // constexpr auto num_stage2_mfma = num_mfma_inst - num_stage1_mfma; + + constexpr auto buffer_load_issue_point = 0; + constexpr auto ds_write_issue_point_stage1 = num_mfma_per_issue_more >= 3 ? 1 : 0; + constexpr auto ds_write_issue_point_stage2 = num_mfma_per_issue_less >= 3 ? 1 : 0; + + static_for<0, num_mfma_inst, 1>{}([&](auto i) { + constexpr auto current_buffer_load_issue = + i < num_stage1_mfma + ? (i / num_mfma_per_issue_more) + : (num_stage1_bufferloads + (i - num_stage1_mfma) / num_mfma_per_issue_less); + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + + // Group num_mfma_perstage num_ds_read_a_perstage + // Hide A lds rd issue latency at begining of each stage + if constexpr((i % num_mfma_perstage) >= + (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + + // Schedule VMEM access instruction distributed evenly in the loop + // Hide B/A global rd issue latency + if constexpr(((i < num_stage1_mfma) && + (i % num_mfma_per_issue_more == buffer_load_issue_point)) || + ((i >= num_stage1_mfma) && + ((i - num_stage1_mfma) % num_mfma_per_issue_less == + buffer_load_issue_point))) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + + // Hide A lds wr issue latency + if constexpr((current_buffer_load_issue >= num_buffer_load_inst_b) && + ((((i < num_stage1_mfma) && + (i % num_mfma_per_issue_more == ds_write_issue_point_stage1)) || + ((i >= num_stage1_mfma) && + ((i - num_stage1_mfma) % num_mfma_per_issue_less == + ds_write_issue_point_stage2))) && + (((i < num_stage1_mfma) && + ((i / num_mfma_per_issue_more - num_buffer_load_inst_b) < num_ds_write_inst_a)) || + ((i >= num_stage1_mfma) && + ((i - num_stage1_mfma) / num_mfma_per_issue_less + + num_stage1_bufferloads - num_buffer_load_inst_b) < num_ds_write_inst_a)))) + { + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + } + }); #elif 1 - constexpr auto staged_num_mfma_per_ds_write_a = - math::integer_divide_ceil(staged_num_mfma, num_ds_write_inst_a); + // A/B split schedule + // compiler is likely to use ds_read2 when instruction width smaller than 16bytes + constexpr auto num_ds_read_inst_a = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 + ? HotLoopInstList::A_LDS_Read_Inst_Num + : HotLoopInstList::A_LDS_Read_Inst_Num / 2; - constexpr auto stage_more_mfma = - staged_num_mfma - (staged_num_mfma_per_ds_write_a - 1) * num_ds_write_inst_a; + constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num; - // A local write - static_for<0, num_ds_write_inst_a, 1>{}([&](auto i_inst) { - if constexpr(i_inst.value < stage_more_mfma) + constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num; + + static_assert(num_buffer_load_inst_a == num_ds_write_inst_a); + + constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num; + constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle; + + constexpr auto ds_read_a_issue_cycle = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4; + constexpr auto ds_read_a_mfma_rate = + math::integer_divide_ceil(mfma_cycle - 4, 2 * ds_read_a_issue_cycle); + + // constexpr auto num_dsread_a_mfma = + // (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate; + + constexpr auto num_total_stages = MRepeat; + + // Group num_mfma_perstage num_ds_read_a_perstage + // since we want to reuse a local register buffer + constexpr auto num_mfma_perstage = num_mfma_inst / num_total_stages; + constexpr auto num_ds_read_a_perstage = num_ds_read_inst_a / num_total_stages; + + constexpr auto num_ds_read_a_mfma_perstage = + math::integer_divide_ceil(num_ds_read_a_perstage, ds_read_a_mfma_rate); + + constexpr auto num_ds_read_a_prefetch_stages = 2; + + constexpr auto buffer_load_perstage_more = math::integer_divide_ceil( + (num_buffer_load_inst_a + num_buffer_load_inst_b), (num_total_stages - 2)); + constexpr auto buffer_load_perstage_less = math::integer_divide_floor( + (num_buffer_load_inst_a + num_buffer_load_inst_b), (num_total_stages - 2)); + + constexpr auto buffer_load_stages_more = + (num_buffer_load_inst_a + num_buffer_load_inst_b) - + math::integer_divide_floor((num_buffer_load_inst_a + num_buffer_load_inst_b), + (num_total_stages - 2)) * + ((num_total_stages - 2)); + + constexpr auto buffer_load_b_stages = + buffer_load_perstage_more * buffer_load_stages_more > num_buffer_load_inst_b + ? num_buffer_load_inst_b / buffer_load_perstage_more + : (buffer_load_stages_more + + (num_buffer_load_inst_b - buffer_load_perstage_more * buffer_load_stages_more) / + buffer_load_perstage_less); + + constexpr auto buffer_load_a_stages = + num_total_stages - num_ds_read_a_prefetch_stages - buffer_load_b_stages; + + constexpr auto buffer_load_issue_point_b = 0; + constexpr auto buffer_load_issue_point_interval_more = + num_mfma_perstage / buffer_load_perstage_more; + constexpr auto buffer_load_issue_point_interval_less = + num_mfma_perstage / buffer_load_perstage_less; + constexpr auto ds_write_issue_point = 0; + constexpr auto buffer_load_issue_point_a = num_mfma_perstage >= 3 ? 1 : 0; + + // B global read + static_for<0, buffer_load_b_stages, 1>{}([&](auto i) { + static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + + if constexpr(((i < buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_more == + buffer_load_issue_point_b)) || + ((i >= buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_less == + buffer_load_issue_point_b))) { - if(i_inst.value < staged_num_ds_read_inst_a) - { - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - } - else - { - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_ds_write_a, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write - } + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read } - else + + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) { - if(i_inst.value < staged_num_ds_read_inst_a) - { - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_ds_write_a - 2, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - } - else - { - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write - } + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read } }); -#endif - __builtin_amdgcn_sched_barrier(0); - } - else - { - // A local Read - static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) { - ignore = i_inst; - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_ds_read_a, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - }); - - __builtin_amdgcn_sched_barrier(0); - } - } - - __device__ static constexpr auto EpilogueScheduler_2() - { - constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num; - - constexpr auto num_mfma = HotLoopInstList::C_MFMA_Inst_Num; - - constexpr auto staged_num_ds_read_inst_a = num_ds_read_inst_a / MRepeat; - constexpr auto staged_num_mfma = num_mfma / MRepeat; - - constexpr auto staged_num_mfma_per_ds_read_a = staged_num_mfma / staged_num_ds_read_inst_a; - - // A local Read - static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) { - ignore = i_inst; - __builtin_amdgcn_sched_group_barrier(0x008, staged_num_mfma_per_ds_read_a, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read }); - __builtin_amdgcn_sched_barrier(0); + // A global read + A local write + static_for<0, buffer_load_a_stages, 1>{}([&](auto i) { + static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_more == + ds_write_issue_point)) || + (((i + buffer_load_b_stages) >= buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_less == + ds_write_issue_point))) + { + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + } + if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_more == + buffer_load_issue_point_a)) || + (((i + buffer_load_b_stages) >= buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_less == + buffer_load_issue_point_a))) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + }); + }); + + // lds synchronization, prefetch next loop local A + static_for<0, num_ds_read_a_prefetch_stages, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + }); + }); +#endif } template {}([&](auto k0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(I0, I0, I0, k0, I0, I0), - a_block_buf.At(I0), - a_thread_desc_, - make_tuple(I0, I0, I0, k0, I0, I0), - a_thread_buf); + static_for<0, 2, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); + }); }); // Initialize C @@ -558,26 +495,18 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3{}([&](auto m0) { - if constexpr(m0.value == 0) - { - b_blockwise_copy.Run(b_grid_desc, - b_grid_buf, - b_block_desc_n0_n1_k0_k1, - b_block_origin_idx, - b_thread_bufs(local_read_buf)); - b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); - } - else if constexpr(m0.value == 1) - { - a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(local_read_buf)); - } - else if constexpr(m0.value == 2) - { - a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); - } + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(local_read_buf)); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(local_read_buf)); + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + + static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, KRepeat, 1>{}([&](auto k0) { static_for<0, NRepeat, 1>{}([&](auto n0) { vector_type a_thread_vec; @@ -613,49 +542,88 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3{}([&](auto k0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number<(m0 + 1) % MRepeat>{}, I0, I0, k0, I0, I0), - a_block_buf.At(local_read_buf), - a_thread_desc_, - make_tuple( - Number<(m0 + 1 + HotloopLocalBufSwitch * mfma_reg_buf) % - 2>{}, - I0, - I0, - k0, - I0, - I0), - a_thread_buf); + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(local_read_buf), + a_thread_desc_, + make_tuple( + Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % + 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); + }); + }); + } + else if constexpr(m0.value == (MRepeat - 1)) + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(local_read_buf), + a_thread_desc_, + make_tuple( + Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % + 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); + }); }); } else { static_for<0, KRepeat, 1>{}([&](auto k0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number<(m0 + 1) % MRepeat>{}, I0, I0, k0, I0, I0), - a_block_buf.At(mfma_reg_buf), - a_thread_desc_, - make_tuple( - Number<(m0 + 1 + HotloopLocalBufSwitch * mfma_reg_buf) % - 2>{}, - I0, - I0, - k0, - I0, - I0), - a_thread_buf); + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(mfma_reg_buf), + a_thread_desc_, + make_tuple( + Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % + 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); + }); }); } - - HotLoopScheduler(m0); }); + HotLoopScheduler(); }; LoopFunc(I0, I1); @@ -667,20 +635,14 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3{}([&](auto m0) { - if constexpr(m0.value == 0) - { - b_blockwise_copy.Run(b_grid_desc, - b_grid_buf, - b_block_desc_n0_n1_k0_k1, - b_block_origin_idx, - b_thread_bufs(I1)); - } - else if constexpr(m0.value == MRepeat - 1) - { - a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1)); - } + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I1)); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1)); + static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, KRepeat, 1>{}([&](auto k0) { static_for<0, NRepeat, 1>{}([&](auto n0) { vector_type a_thread_vec; @@ -707,36 +669,72 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3{}([&](auto k0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number<(m0 + 1) % MRepeat>{}, I0, I0, k0, I0, I0), - a_block_buf.At(I1), - a_thread_desc_, - make_tuple(Number<(m0 + 1) % 2>{}, I0, I0, k0, I0, I0), - a_thread_buf); + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(I1), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); + }); + } + else if constexpr(m0.value == (MRepeat - 1)) + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(I1), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); }); } else { static_for<0, KRepeat, 1>{}([&](auto k0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number<(m0 + 1) % MRepeat>{}, I0, I0, k0, I0, I0), - a_block_buf.At(I0), - a_thread_desc_, - make_tuple(Number<(m0 + 1) % 2>{}, I0, I0, k0, I0, I0), - a_thread_buf); + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); }); } - - EpilogueScheduler_1(m0); }); + HotLoopScheduler(); + static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, KRepeat, 1>{}([&](auto k0) { static_for<0, NRepeat, 1>{}([&](auto n0) { @@ -764,25 +762,31 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3{}([&](auto k0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number{}, I0, I0, k0, I0, I0), - a_block_buf.At(I1), - a_thread_desc_, - make_tuple( - Number<(m0 + 1 + HotloopLocalBufSwitch) % 2>{}, I0, I0, k0, I0, I0), - a_thread_buf); + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple( + Number{}, I0, I0, Number{}, I0, I0), + a_block_buf.At(I1), + a_thread_desc_, + make_tuple(Number<(m0 + 2 + HotloopLocalBufSwitch) % 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); + }); }); - - EpilogueScheduler_2(); } }); + + HotLoopScheduler(); // Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle // latency - // __builtin_amdgcn_sched_barrier(0); } else if constexpr(TailNum == TailNumber::Odd) { @@ -813,18 +817,21 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3{}([&](auto k0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number{}, I0, I0, k0, I0, I0), - a_block_buf.At(I0), - a_thread_desc_, - make_tuple(Number<(m0 + 1) % 2>{}, I0, I0, k0, I0, I0), - a_thread_buf); + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple( + Number{}, I0, I0, Number{}, I0, I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); }); - - EpilogueScheduler_2(); } }); } @@ -841,7 +848,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3, + Sequence<1, 1, 1, 1, 1, KPack / KGroup>, Sequence<0, 1, 2, 3, 4, 5>, 5, A_K1, diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp index ce507ca8d3..6c1c5b1c4d 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp @@ -58,6 +58,11 @@ struct BlockwiseGemmXdlops_pipeline_base static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops; static constexpr index_t KRepeat = KPerThread / KPack; static constexpr index_t KPerInnerLoop = KPack; + static constexpr index_t KGroup = + ((MPerXDL == 16 && MPerXDL == 16 && xdlops_gemm.KPerXdlops == 128) || + (MPerXDL == 32 && MPerXDL == 32 && xdlops_gemm.KPerXdlops == 64)) + ? 2 + : 1; static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL); static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_ab_scale.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_ab_scale.hpp index 8375e81fa0..ea4f5e4a28 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_ab_scale.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_ab_scale.hpp @@ -205,7 +205,7 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale; static constexpr index_t KPack = math::max(math::lcm(AK1Number, BK1Number), mfma_selector::selected_mfma.k_per_blk); + static constexpr index_t KGroup = mfma_selector::selected_mfma.k_per_blk == 32 ? 2 : 1; static constexpr index_t KLane = mfma_selector::GetKPerXdlops() / mfma_selector::GetK1PerXdlops(); - static constexpr index_t KRepeat = KPerBlock / KLane / KPack; + static constexpr index_t KRepeat = KPerBlock / KLane / (KPack / KGroup); static constexpr index_t NLane = NPerXdl; static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave; @@ -209,7 +210,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle } __host__ __device__ static auto CalculateBK0Shuffled(index_t K) { - return math::integer_divide_ceil(K, KLane * KPack); + return math::integer_divide_ceil(K, KLane * KPack / KGroup); } __host__ __device__ static auto CalculateKPadded(index_t K) @@ -351,7 +352,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle __host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0) { - constexpr index_t NkSwizzleNumber = Number{}; + constexpr index_t NkSwizzleNumber = Number{}; return make_naive_tensor_descriptor( make_tuple(N0 / NWave, NWave, K0, NkSwizzleNumber), make_tuple(NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, I1)); @@ -1228,7 +1229,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle make_multi_index(n_block_data_idx_on_grid, get_warp_local_1d_id() % NWave, 0, - KPack * (get_thread_local_1d_id() % warpSize))); + KPack / KGroup * (get_thread_local_1d_id() % warpSize))); // LDS allocation for A and B: be careful of alignment // Cast after lds @@ -1668,7 +1669,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle make_multi_index(n_block_data_idx_on_grid, get_warp_local_1d_id() % NWave, 0, - KPack * (get_thread_local_1d_id() % warpSize))); + KPack / KGroup * (get_thread_local_1d_id() % warpSize))); // LDS allocation for A and B: be careful of alignment // Cast after lds diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp index ead4399d81..d963980f26 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp @@ -188,7 +188,9 @@ struct GridwiseMoeGemm math::max(math::lcm(AK1Number, BK1Number), mfma_selector::selected_mfma.k_per_blk); static constexpr index_t KLane = mfma_selector::GetKPerXdlops() / mfma_selector::GetK1PerXdlops(); - static constexpr index_t KRepeat = KPerBlock / KLane / KPack; + + static constexpr index_t KGroup = mfma_selector::selected_mfma.k_per_blk == 32 ? 2 : 1; + static constexpr index_t KRepeat = KPerBlock / KLane / (KPack / KGroup); static constexpr index_t NLane = NPerXdl; static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave; // static constexpr index_t NumTokens = 1; @@ -249,7 +251,7 @@ struct GridwiseMoeGemm } __host__ __device__ static auto CalculateBK0Shuffled(index_t K) { - return math::integer_divide_ceil(K, KLane * KPack); + return math::integer_divide_ceil(K, KLane * KPack / KGroup); } __host__ __device__ static auto CalculateKPadded(index_t K) @@ -391,7 +393,7 @@ struct GridwiseMoeGemm __host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0) { - constexpr index_t NkSwizzleNumber = Number{}; + constexpr index_t NkSwizzleNumber = Number{}; return make_naive_tensor_descriptor( make_tuple(N0 / NWave, NWave, K0, NkSwizzleNumber), make_tuple(NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, I1)); @@ -1301,7 +1303,7 @@ struct GridwiseMoeGemm make_multi_index(n_block_data_idx_on_grid, get_warp_local_1d_id() % NWave, 0, - KPack * (get_thread_local_1d_id() % warpSize))); + KPack / KGroup * (get_thread_local_1d_id() % warpSize))); // LDS allocation for A and B: be careful of alignment // Cast after lds @@ -2012,7 +2014,7 @@ struct GridwiseMoeGemm make_multi_index(n_block_data_idx_on_grid, get_warp_local_1d_id() % NWave, 0, - KPack * (get_thread_local_1d_id() % warpSize))); + KPack / KGroup * (get_thread_local_1d_id() % warpSize))); // LDS allocation for A and B: be careful of alignment // Cast after lds @@ -2081,23 +2083,24 @@ struct GridwiseMoeGemm c_thread_buf_up, num_k_block_main_loop); } - else - { + else + { - blockwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, - a_block_desc_ak0_m_ak1, - a_blockwise_copy, - a_grid_buf, - a_block_bufs, - a_block_slice_copy_step, - b_grid_desc_bpreshuffled, - b_blockwise_copy, - b_grid_buf, - b_block_bufs, - b_block_slice_copy_step, - c_thread_buf, - num_k_block_main_loop); - } + blockwise_gemm_pipeline.template Run( + a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_bufs, + a_block_slice_copy_step, + b_grid_desc_bpreshuffled, + b_blockwise_copy, + b_grid_buf, + b_block_bufs, + b_block_slice_copy_step, + c_thread_buf, + num_k_block_main_loop); + } // shuffle C and write out { diff --git a/include/ck/utility/blkgemmpipe_scheduler.hpp b/include/ck/utility/blkgemmpipe_scheduler.hpp index 39407cb8f6..34c874353d 100644 --- a/include/ck/utility/blkgemmpipe_scheduler.hpp +++ b/include/ck/utility/blkgemmpipe_scheduler.hpp @@ -90,6 +90,17 @@ struct BlockwiseGemmXdlops_pipeline_hotloop_inst static constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock / (BlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL); + static constexpr index_t C_MFMA_Inst_Cycle = []() { + if constexpr(NPerXDL == 16) + { + return KPerXDL == 128 ? 32 : 16; + } + else if constexpr(NPerXDL == 32) + { + return KPerXDL == 64 ? 64 : 32; + } + }(); + static constexpr auto Print() { printf(" Blk/Wave Size: %d, %d, M/N/K PerBlk: %d, %d, %d, M/N/K PerXdl: %d, %d, %d\n", @@ -103,7 +114,7 @@ struct BlockwiseGemmXdlops_pipeline_hotloop_inst KPerXDL); printf(" A/B buffer load inst: %d, %d\n A/B LDS write inst: %d, %d\n A/B LDS read inst: " - "%d, %d\n C MFMA inst: %d\n" + "%d, %d\n C MFMA inst: %d C MFMA cycle: %d\n" "A/B LDS read width: %d, %d, A/B LDS write width: %d, %d, A/B buffer load width: " "%d/ %d\n", A_Buffer_Load_Inst_Num, @@ -113,6 +124,7 @@ struct BlockwiseGemmXdlops_pipeline_hotloop_inst A_LDS_Read_Inst_Num, B_LDS_Read_Inst_Num, C_MFMA_Inst_Num, + C_MFMA_Inst_Cycle, A_LDS_Read_Width, B_LDS_Read_Width, ALDSWriteWidth,