diff --git a/example/01_gemm/gemm_xdl_fp16.cpp b/example/01_gemm/gemm_xdl_fp16.cpp index a9ec0e1819..dbce3bcc04 100644 --- a/example/01_gemm/gemm_xdl_fp16.cpp +++ b/example/01_gemm/gemm_xdl_fp16.cpp @@ -13,7 +13,7 @@ #include "device_tensor.hpp" #include "device_gemm_xdl.hpp" #include "device_gemm_xdl_cshuffle.hpp" -#include "device_gemm_xdl_cshuffle_v2.hpp" +#include "device_gemm_xdl_producer_consumer_cshuffle.hpp" #include "element_wise_operation.hpp" #include "reference_gemm.hpp" #include "gemm_specialization.hpp" @@ -57,24 +57,14 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle // < Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; // // 2-stage prefetch // < Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; -#elif 0 -using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle_v2 +#elif 1 +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_ProducerConsumer_CShuffle //######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| ABBlockTransfer| BlockGemm| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| ThreadGroupSize| ThreadGroupSize| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - // all thread - < Row, Col, Row, F16, F16, F16, F32, F32, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 0, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; -// < Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 0, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; -#elif 0 -using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle_v2 -//######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| ABBlockTransfer| BlockGemm| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| -//######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| ThreadGroupSize| ThreadGroupSize| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| -//######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| -//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - // producer & consumer - < Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 8>, 8>; -// < Row, Col, Row, F16, F16, F16, F32, F16, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 8>, 8>; +// < Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 8>, 8>; + < Row, Col, Row, F16, F16, F16, F32, F16, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 8>, 8>; #elif 1 using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl //######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle_v2.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_producer_consumer_cshuffle.hpp similarity index 98% rename from include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle_v2.hpp rename to include/ck/tensor_operation/gpu/device/device_gemm_xdl_producer_consumer_cshuffle.hpp index 7875be9dd7..9a208b99c1 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle_v2.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_producer_consumer_cshuffle.hpp @@ -7,8 +7,8 @@ #include "tensor_layout.hpp" #include "tensor_descriptor.hpp" #include "tensor_descriptor_helper.hpp" -#include "gridwise_gemm_xdl_cshuffle_v2.hpp" -#include "tensor_operation/gpu/device/gemm_specialization.hpp" +#include "gridwise_gemm_xdl_producer_consumer_cshuffle.hpp" +#include "gemm_specialization.hpp" namespace ck { namespace tensor_operation { @@ -56,10 +56,10 @@ template -struct DeviceGemm_Xdl_CShuffle_v2 +struct DeviceGemm_Xdl_ProducerConsumer_CShuffle : public DeviceGemm { - using DeviceOp = DeviceGemm_Xdl_CShuffle_v2; + using DeviceOp = DeviceGemm_Xdl_ProducerConsumer_CShuffle; static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -334,7 +334,7 @@ struct DeviceGemm_Xdl_CShuffle_v2 using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); // GridwiseGemm - using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2< + using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdl_producer_consumer_cshuffle< ADataType, // TODO: distinguish A/B datatype GemmAccDataType, CShuffleDataType, @@ -471,7 +471,7 @@ struct DeviceGemm_Xdl_CShuffle_v2 if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) { - const auto kernel = kernel_gemm_xdl_cshuffle_v2< + const auto kernel = kernel_gemm_xdl_producer_consumer_cshuffle< GridwiseGemm, ADataType, // TODO: distiguish A/B datatype CDataType, @@ -523,7 +523,7 @@ struct DeviceGemm_Xdl_CShuffle_v2 } else { - const auto kernel = kernel_gemm_xdl_cshuffle_v2< + const auto kernel = kernel_gemm_xdl_producer_consumer_cshuffle< GridwiseGemm, ADataType, // TODO: distiguish A/B datatype CDataType, @@ -672,7 +672,7 @@ struct DeviceGemm_Xdl_CShuffle_v2 auto str = std::stringstream(); // clang-format off - str << "DeviceGemm_Xdl_CShuffle_v2" + str << "DeviceGemm_Xdl_ProducerConsumer_CShuffle" << "<" << ABBlockTransferThreadGroupSize << ", " << BlockGemmThreadGroupSize << ", " diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_producer_consumer.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_producer_consumer.hpp new file mode 100644 index 0000000000..6b04881715 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_producer_consumer.hpp @@ -0,0 +1,234 @@ +#pragma once + +#include "common_header.hpp" + +namespace ck { + +template +struct GridwiseGemmPipelineProducerConsumer; + +// 1-stage prefetch +template +struct GridwiseGemmPipelineProducerConsumer +{ + __host__ __device__ static constexpr bool IsSupported(index_t num_loop) + { + // TODO: improve applicability + return num_loop % 2 == 0; + } + + __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop) + { + return num_loop / 2 > 1; + } + + template + static __device__ void RunABBlockTransferPipeline(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_block_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_block_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + index_t num_loop) + { + // global read 0 + a_block_copy.RunRead(a_grid_desc, a_grid_buf); + b_block_copy.RunRead(b_grid_desc, b_grid_buf); + + // move to 1 + a_block_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_block_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // LDS write 0 + a_block_copy.RunWrite(a_block_desc, a_block_buf); + // global Read 1 + a_block_copy.RunRead(a_grid_desc, a_grid_buf); + + // LDS write 0 + b_block_copy.RunWrite(b_block_desc, b_block_buf); + // global Read 1 + b_block_copy.RunRead(b_grid_desc, b_grid_buf); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + + do + { + block_sync_lds(); + + // GEMM i + + block_sync_lds(); + + // move to i + 2 + a_block_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_block_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // LDS write i + 1 + a_block_copy.RunWrite(a_block_desc, a_block_buf); + // global read i + 2 + a_block_copy.RunRead(a_grid_desc, a_grid_buf); + + // LDS write i + 1 + b_block_copy.RunWrite(b_block_desc, b_block_buf); + // global read i + 2 + b_block_copy.RunRead(b_grid_desc, b_grid_buf); + + ++i; + } while(i < (num_loop - 2)); + } + + // tail + { + block_sync_lds(); + + // GEMM num_loop - 2 + + block_sync_lds(); + + // LDS write num_loop - 1 + a_block_copy.RunWrite(a_block_desc, a_block_buf); + b_block_copy.RunWrite(b_block_desc, b_block_buf); + + block_sync_lds(); + + // GEMM num_loop - 1 + } + } + + template + static __device__ void RunBlockGemmPipeline(ABlockBuffer& a_block_buf, + BBlockBuffer& b_block_buf, + const BlockwiseGemm& block_gemm, + CThreadBuffer& c_thread_buf, + index_t num_loop) + { + // Initialize C + c_thread_buf.Clear(); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + + do + { + block_sync_lds(); + + // GEMM i + block_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + block_sync_lds(); + + // move to i + 2 + + // LDS write i + 1 + // global read i + 2 + + // LDS write i + 1 + // global read i + 2 + + ++i; + } while(i < (num_loop - 2)); + } + + // tail + { + block_sync_lds(); + + // GEMM num_loop - 2 + block_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + block_sync_lds(); + + // LDS write num_loop - 1 + + block_sync_lds(); + + // GEMM num_loop - 1 + block_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + } + } + + template + static __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_block_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_block_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + const BlockwiseGemm& block_gemm, + CThreadBuffer& c_thread_buf, + index_t num_loop) + { + if(ABBlockTransferThreadGroup::IsBelong()) + { + RunABBlockTransferPipeline(a_grid_desc, + a_block_desc, + a_block_copy, + a_grid_buf, + a_block_buf, + a_block_copy_step, + b_grid_desc, + b_block_desc, + b_block_copy, + b_grid_buf, + b_block_buf, + b_block_copy_step, + num_loop); + } + else if(BlockGemmThreadGroup::IsBelong()) + { + RunBlockGemmPipeline( + a_block_buf, b_block_buf, block_gemm, c_thread_buf, num_loop); + } + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp index a11a5343c3..6a1b6eef31 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp @@ -51,7 +51,6 @@ struct GridwiseGemmPipeline_v1<1> CThreadBuffer& c_thread_buf, index_t num_loop) { -#if 1 // preload data into LDS a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); @@ -98,80 +97,6 @@ struct GridwiseGemmPipeline_v1<1> blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); } -#elif 1 - // global read 0 - a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); - b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); - - // move to 1 - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); - b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); - - // Initialize C - c_thread_buf.Clear(); - - // LDS write 0 - a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); - // global Read 1 - a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); - - // LDS write 0 - b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); - // global Read 1 - b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); - - // main body - // FIXME: HasMainLoop = (num_loop) > 2 - if constexpr(HasMainLoop) - { - index_t i = 0; - - do - { - block_sync_lds(); - - // GEMM i - blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); - - block_sync_lds(); - - // move to i + 2 - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); - b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); - - // LDS write i + 1 - a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); - // global read i + 2 - a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); - - // LDS write i + 1 - b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); - // global read i + 2 - b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); - - ++i; - } while(i < (num_loop - 2)); - } - - // tail - { - block_sync_lds(); - - // GEMM num_loop - 2 - blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); - - block_sync_lds(); - - // LDS write num_loop - 1 - a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); - b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); - - block_sync_lds(); - - // GEMM num_loop - 1 - blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); - } -#endif } }; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp index 7c1947e4de..a5e53f1489 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp @@ -1,20 +1,10 @@ #pragma once - #include "common_header.hpp" namespace ck { -template struct GridwiseGemmPipeline_v2 { - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - - __device__ constexpr GridwiseGemmPipeline_v2() - { - // TODO static assert - } - __host__ __device__ static constexpr bool IsSupported(index_t num_loop) { // TODO: improve applicability @@ -23,161 +13,7 @@ struct GridwiseGemmPipeline_v2 __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop) { - return num_loop / 2 > 1; - } - - template - static __device__ void RunABBlockTransferPipeline(const AGridDesc& a_grid_desc, - const ABlockDesc& a_block_desc, - ABlockTransfer& a_block_copy, - const AGridBuffer& a_grid_buf, - ABlockBuffer& a_block_buf, - const ABlockTransferStep& a_block_copy_step, - const BGridDesc& b_grid_desc, - const BBlockDesc& b_block_desc, - BBlockTransfer& b_block_copy, - const BGridBuffer& b_grid_buf, - BBlockBuffer& b_block_buf, - const BBlockTransferStep& b_block_copy_step, - index_t num_loop) - { - // global read 0 - a_block_copy.RunRead(a_grid_desc, a_grid_buf); - b_block_copy.RunRead(b_grid_desc, b_grid_buf); - - // move to 1 - a_block_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); - b_block_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); - - // LDS write 0 - a_block_copy.RunWrite(a_block_desc, a_block_buf); - // global Read 1 - a_block_copy.RunRead(a_grid_desc, a_grid_buf); - - // LDS write 0 - b_block_copy.RunWrite(b_block_desc, b_block_buf); - // global Read 1 - b_block_copy.RunRead(b_grid_desc, b_grid_buf); - - // main body - // FIXME: HasMainLoop = (num_loop) > 2 - if constexpr(HasMainLoop) - { - index_t i = 0; - - do - { - block_sync_lds(); - - // GEMM i - - block_sync_lds(); - - // move to i + 2 - a_block_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); - b_block_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); - - // LDS write i + 1 - a_block_copy.RunWrite(a_block_desc, a_block_buf); - // global read i + 2 - a_block_copy.RunRead(a_grid_desc, a_grid_buf); - - // LDS write i + 1 - b_block_copy.RunWrite(b_block_desc, b_block_buf); - // global read i + 2 - b_block_copy.RunRead(b_grid_desc, b_grid_buf); - - ++i; - } while(i < (num_loop - 2)); - } - - // tail - { - block_sync_lds(); - - // GEMM num_loop - 2 - - block_sync_lds(); - - // LDS write num_loop - 1 - a_block_copy.RunWrite(a_block_desc, a_block_buf); - b_block_copy.RunWrite(b_block_desc, b_block_buf); - - block_sync_lds(); - - // GEMM num_loop - 1 - } - } - - template - static __device__ void RunBlockGemmPipeline(ABlockBuffer& a_block_buf, - BBlockBuffer& b_block_buf, - const BlockwiseGemm& block_gemm, - CThreadBuffer& c_thread_buf, - index_t num_loop) - { - // Initialize C - c_thread_buf.Clear(); - - // main body - if constexpr(HasMainLoop) - { - index_t i = 0; - - do - { - block_sync_lds(); - - // GEMM i - block_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); - - block_sync_lds(); - - // move to i + 2 - - // LDS write i + 1 - // global read i + 2 - - // LDS write i + 1 - // global read i + 2 - - ++i; - } while(i < (num_loop - 2)); - } - - // tail - { - block_sync_lds(); - - // GEMM num_loop - 2 - block_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); - - block_sync_lds(); - - // LDS write num_loop - 1 - - block_sync_lds(); - - // GEMM num_loop - 1 - block_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); - } + return (num_loop / 2) > 1; } template - static __device__ void Run(const AGridDesc& a_grid_desc, + __device__ static void Run(const AGridDesc& a_grid_desc, const ABlockDesc& a_block_desc, - ABlockTransfer& a_block_copy, + ABlockTransfer& a_blockwise_copy, const AGridBuffer& a_grid_buf, ABlockBuffer& a_block_buf, const ABlockTransferStep& a_block_copy_step, const BGridDesc& b_grid_desc, const BBlockDesc& b_block_desc, - BBlockTransfer& b_block_copy, + BBlockTransfer& b_blockwise_copy, const BGridBuffer& b_grid_buf, BBlockBuffer& b_block_buf, const BBlockTransferStep& b_block_copy_step, - const BlockwiseGemm& block_gemm, + const BlockwiseGemm& blockwise_gemm, CThreadBuffer& c_thread_buf, index_t num_loop) { - if(ABBlockTransferThreadGroup::IsBelong()) + // global read 0 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + // move to 1 + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + + // LDS write 0 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + // global Read 1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + + // LDS write 0 + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + // global Read 1 + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + // main body + if constexpr(HasMainLoop) { - RunABBlockTransferPipeline(a_grid_desc, - a_block_desc, - a_block_copy, - a_grid_buf, - a_block_buf, - a_block_copy_step, - b_grid_desc, - b_block_desc, - b_block_copy, - b_grid_buf, - b_block_buf, - b_block_copy_step, - num_loop); + index_t i = 0; + + do + { + block_sync_lds(); + + // GEMM i + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + block_sync_lds(); + + // move to i + 2 + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // LDS write i + 1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + // global read i + 2 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + + // LDS write i + 1 + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + // global read i + 2 + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + ++i; + } while(i < (num_loop - 2)); } - else if(BlockGemmThreadGroup::IsBelong()) + + // tail { - RunBlockGemmPipeline( - a_block_buf, b_block_buf, block_gemm, c_thread_buf, num_loop); + block_sync_lds(); + + // GEMM num_loop - 2 + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + block_sync_lds(); + + // LDS write num_loop - 1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + block_sync_lds(); + + // GEMM num_loop - 1 + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); } } }; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp index b0a3774a26..02e70fd3f8 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp @@ -8,6 +8,7 @@ #include "thread_group_tensor_slice_transfer_v6r1.hpp" #include "threadwise_tensor_slice_transfer.hpp" #include "gridwise_gemm_pipeline_v1.hpp" +#include "gridwise_gemm_pipeline_v2.hpp" namespace ck { @@ -127,7 +128,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 using ThisThreadBlock = ThisThreadBlock; +#if 1 using GridwiseGemmPipe = GridwiseGemmPipeline_v1; +#else + using GridwiseGemmPipe = GridwiseGemmPipeline_v2; +#endif __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_producer_consumer_cshuffle.hpp similarity index 95% rename from include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp rename to include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_producer_consumer_cshuffle.hpp index f4a0741390..f5cf8de41a 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_producer_consumer_cshuffle.hpp @@ -7,8 +7,7 @@ #include "thread_group_tensor_slice_transfer_v4r1.hpp" #include "thread_group_tensor_slice_transfer_v6r1.hpp" #include "threadwise_tensor_slice_transfer.hpp" -#include "gridwise_gemm_pipeline_v1.hpp" -#include "gridwise_gemm_pipeline_v2.hpp" +#include "gridwise_gemm_pipeline_producer_consumer.hpp" namespace ck { @@ -27,18 +26,20 @@ __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_gemm_xdl_cshuffle_v2(const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CElementwiseOperation c_element_op, - const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, - const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, - const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - c_grid_desc_mblock_mperblock_nblock_nperblock, - const Block2CTileMap block_2_ctile_map) + kernel_gemm_xdl_producer_consumer_cshuffle( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2CTileMap block_2_ctile_map) { +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_a_grid, @@ -52,6 +53,18 @@ __global__ void b_grid_desc_bk0_n_bk1, c_grid_desc_mblock_mperblock_nblock_nperblock, block_2_ctile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; + ignore = a_grid_desc_ak0_m_ak1; + ignore = b_grid_desc_bk0_n_bk1; + ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = block_2_ctile_map; +#endif // end of #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) } template -struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2 +struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_producer_consumer_cshuffle { static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -114,10 +127,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2 static constexpr auto AK1 = Number{}; static constexpr auto BK1 = Number{}; - using ThisThreadBlock = - ThisThreadBlock; - -#if 0 struct ABBlockTransferThreadGroup { __device__ static constexpr index_t GetNumOfThread() @@ -151,22 +160,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2 } }; - using CShuffleBlockTransferThreadGroup = ThisThreadBlock; -#else - using ABBlockTransferThreadGroup = ThisThreadBlock; - using BlockGemmThreadGroup = ThisThreadBlock; - using CShuffleBlockTransferThreadGroup = ThisThreadBlock; -#endif + using CShuffleBlockTransferThreadGroup = + ThisThreadBlock; -#if 1 - // gridwise GEMM pipeline - using GridwiseGemmPipe = GridwiseGemmPipeline_v1; -#else - // gridwise GEMM pipeline - using GridwiseGemmPipe = GridwiseGemmPipeline_v2; -#endif + using GridwiseGemmPipe = GridwiseGemmPipelineProducerConsumer; __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() { diff --git a/library/src/utility/conv_fwd_util.cpp b/library/src/utility/conv_fwd_util.cpp index fde2caa56b..1658450388 100644 --- a/library/src/utility/conv_fwd_util.cpp +++ b/library/src/utility/conv_fwd_util.cpp @@ -37,16 +37,16 @@ std::size_t get_flops(ck::index_t N, } ConvParams::ConvParams() - : num_dim_spatial(2), - N(128), - K(256), - C(192), - filter_spatial_lengths(2, 3), - input_spatial_lengths(2, 71), - conv_filter_strides(2, 2), - conv_filter_dilations(2, 1), - input_left_pads(2, 1), - input_right_pads(2, 1) + : num_dim_spatial(2), + N(128), + K(256), + C(192), + filter_spatial_lengths(2, 3), + input_spatial_lengths(2, 71), + conv_filter_strides(2, 2), + conv_filter_dilations(2, 1), + input_left_pads(2, 1), + input_right_pads(2, 1) { } @@ -77,9 +77,9 @@ ConvParams::ConvParams(ck::index_t n_dim, conv_filter_dilations.size() != num_dim_spatial || input_left_pads.size() != num_dim_spatial || input_right_pads.size() != num_dim_spatial) { - throw(std::runtime_error( - "ConvParams::GetOutputSpatialLengths: " - "parameter size is different from number of declared dimensions!")); + throw( + std::runtime_error("ConvParams::GetOutputSpatialLengths: " + "parameter size is different from number of declared dimensions!")); } } @@ -91,9 +91,9 @@ std::vector ConvParams::GetOutputSpatialLengths() const conv_filter_dilations.size() != num_dim_spatial || input_left_pads.size() != num_dim_spatial || input_right_pads.size() != num_dim_spatial) { - throw(std::runtime_error( - "ConvParams::GetOutputSpatialLengths: " - "parameter size is different from number of declared dimensions!")); + throw( + std::runtime_error("ConvParams::GetOutputSpatialLengths: " + "parameter size is different from number of declared dimensions!")); } std::vector out_spatial_len(num_dim_spatial, 0); @@ -101,8 +101,7 @@ std::vector ConvParams::GetOutputSpatialLengths() const { // XEff = (X - 1) * conv_dilation_w + 1; // Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; - const ck::index_t idx_eff = - (filter_spatial_lengths[i] - 1) * conv_filter_dilations[i] + 1; + const ck::index_t idx_eff = (filter_spatial_lengths[i] - 1) * conv_filter_dilations[i] + 1; out_spatial_len[i] = (input_spatial_lengths[i] + input_left_pads[i] + input_right_pads[i] - idx_eff) / conv_filter_strides[i] +