diff --git a/example/ck_tile/03_gemm/universal_gemm_invoker.hpp b/example/ck_tile/03_gemm/universal_gemm_invoker.hpp index 0fcf9680bc..4a83a2c4ab 100644 --- a/example/ck_tile/03_gemm/universal_gemm_invoker.hpp +++ b/example/ck_tile/03_gemm/universal_gemm_invoker.hpp @@ -63,25 +63,30 @@ struct UniversalInvoker const auto Run = [&](const auto memory_operation_) { constexpr auto memory_operation = memory_operation_.value; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; + using GemmEpilogue = ck_tile::CShuffleEpilogue>; using Kernel = ck_tile::GemmKernel; auto kargs = Kernel::MakeKernelArgs(args); diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index ad1862306a..53bfa6041d 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -35,7 +35,8 @@ template // The number of continuous xdl_output per warp + index_t BlockedXDLN_PerWarp_ = 1, // The number of continuous xdl_output per warp + bool DoubleSmemBuffer_ = false> struct CShuffleEpilogueProblem { using AsDataType = remove_cvref_t; @@ -59,6 +60,7 @@ struct CShuffleEpilogueProblem static constexpr bool FixedVectorSize = FixedVectorSize_; static constexpr index_t VectorSizeC = VectorSizeC_; static constexpr index_t BlockedXDLN_PerWarp = BlockedXDLN_PerWarp_; + static constexpr bool DoubleSmemBuffer = DoubleSmemBuffer_; static constexpr bool TiledMMAPermuteN = TiledMMAPermuteN_; static constexpr index_t kNumWaveGroups = kNumWaveGroups_; static constexpr index_t NumDTensor = DsDataType::size(); @@ -118,6 +120,7 @@ struct CShuffleEpilogue static constexpr bool FixedVectorSize = Problem::FixedVectorSize; static constexpr bool TiledMMAPermuteN = Problem::TiledMMAPermuteN; static constexpr index_t BlockedXDLN_PerWarp = Problem::BlockedXDLN_PerWarp; + static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer; static constexpr index_t VectorSizeC = 
Problem::VectorSizeC;
     static constexpr index_t MPerIteration = MPerXdl * MWave;
     static constexpr index_t NPerIteration = NPerXdl * NWave;
@@ -204,6 +207,26 @@ struct CShuffleEpilogue
         }
         return max_vector_size / sizeof(DiDataType);
     }
+
+    /**
+     * @brief Shuffle tile configuration parameters check and alignment
+     *
+     * @details Return tuple(1, 1) if shuffle_tile values are too large for SMEM or DoubleSmemBuffer is enabled.
+     */
+    template <index_t m_shuffle_tile, index_t n_shuffle_tile>
+    CK_TILE_HOST_DEVICE static constexpr auto AlignShuffleTileWithSmem()
+    {
+        constexpr index_t m_val = MPerXdl * MWave * m_shuffle_tile;
+        constexpr index_t n_val = NPerXdl * NWave * n_shuffle_tile;
+
+        constexpr auto shuffle_tile =
+            m_val * n_val * sizeof(ODataType) > get_smem_capacity() || DoubleSmemBuffer
+                ? std::make_tuple(1, 1)
+                : std::make_tuple(m_shuffle_tile, n_shuffle_tile);
+
+        return shuffle_tile;
+    }
+
     /**
      * @brief Shuffle tile configuration parameters
      *
@@ -214,20 +237,23 @@ struct CShuffleEpilogue
      */
     static constexpr auto shuffle_tile_tuple = [] {
         constexpr index_t elem_per_thread = MPerXdl * NPerXdl / get_warp_size();
-        if constexpr(elem_per_thread >= GetVectorSizeC())
+        if constexpr(elem_per_thread <= GetVectorSizeC())
         {
             return std::make_tuple(1, 1);
         }
         else
         {
-            constexpr index_t num_xdl_shuffles = GetVectorSizeC() / elem_per_thread;
+            constexpr index_t num_xdl_shuffles = elem_per_thread / GetVectorSizeC();
+            static_assert(elem_per_thread % GetVectorSizeC() == 0);
             if constexpr(std::is_same_v)
             {
                 static_assert((kMPerBlock % (MPerXdl * MWave) == 0) &&
                                   (kMPerBlock % num_xdl_shuffles == 0),
                               "kMPerBlock must be divisible by MPerXdl*MWave and "
                               "num_xdl_shuffles for CShuffleEpilogue");
-                return std::make_tuple(min(num_xdl_shuffles, kMPerBlock / (MPerXdl * MWave)), 1);
+                return AlignShuffleTileWithSmem<min(num_xdl_shuffles,
+                                                    kMPerBlock / (MPerXdl * MWave)),
+                                                1>();
            }
            else
            {
@@ -235,7 +261,9 @@ struct CShuffleEpilogue
                               (kNPerBlock % num_xdl_shuffles == 0),
                               "kNPerBlock must be divisible by NPerXdl*NWave and "
                               "num_xdl_shuffles for CShuffleEpilogue");
-                return std::make_tuple(1, min(num_xdl_shuffles, kNPerBlock / 
(NPerXdl * NWave))); + return AlignShuffleTileWithSmem<1, + min(num_xdl_shuffles, + kNPerBlock / (NPerXdl * NWave))>(); } } }(); diff --git a/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp index 8adbfb9723..2761b16571 100644 --- a/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp @@ -232,7 +232,7 @@ struct BatchedGemmKernel if constexpr(GemmPipeline::DoubleSmemBuffer == true) { - __shared__ char smem_ptr1[GetSmemSize()]; + __shared__ char smem_ptr1[GemmPipeline::GetSmemSize()]; UniversalGemmKernel::RunGemm2LDS({a_ptr}, {b_ptr}, {/*ds_ptr*/}, diff --git a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp index 838fc236d2..95114e8496 100644 --- a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp @@ -310,7 +310,7 @@ struct GroupedGemmKernel if constexpr(GemmPipeline::DoubleSmemBuffer == true) { - __shared__ char smem_ptr_1[GetSmemSize()]; + __shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()]; RunGemmWithPipelineSelection2LDS(a_ptr, b_ptr, c_ptr, diff --git a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp index 866a4cc693..5f7e78fac2 100644 --- a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp @@ -1084,7 +1084,7 @@ struct UniversalGemmKernel if constexpr(GemmPipeline::DoubleSmemBuffer == true) { - __shared__ char smem_ptr_1[GetSmemSize()]; + __shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()]; if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add && EpiloguePipeline::GetVectorSizeC() % 2 != 0 && is_any_of::value)) @@ -1169,7 +1169,7 @@ struct UniversalGemmKernel // Run the GEMM if constexpr(GemmPipeline::DoubleSmemBuffer == 
true) { - __shared__ char smem_ptr_1[GetSmemSize()]; + __shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()]; if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add && EpiloguePipeline::GetVectorSizeC() % 2 != 0 && diff --git a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp index d9af5cce1f..3e97380374 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp @@ -1324,7 +1324,7 @@ struct QuantGemmKernel assert(kargs.k_batch == 1); if constexpr(GemmPipeline::DoubleSmemBuffer == true) { - __shared__ char smem_ptr_1[GetSmemSize()]; + __shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()]; RunGemm2LDS(a_ptr, b_ptr, diff --git a/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp index 726f678d37..7e246961cb 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp @@ -325,7 +325,7 @@ struct QuantGroupedGemmKernel kQuantType == QuantType::BQuantGrouped) { - __shared__ char smem_ptr_1[GetSmemSize()]; + __shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()]; RunGemmWithPipelineSelection2LDS(a_ptr, b_ptr, aq_ptr, diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp index 46c60cb6d7..6e1ac39509 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp @@ -1048,7 +1048,7 @@ struct GroupedConvolutionBackwardDataKernel if constexpr(GemmPipeline::DoubleSmemBuffer == true) { - __shared__ char smem_ptr_1[GetSmemSize()]; + 
__shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()]; if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add && GroupedConvTraitsType_::VectorSizeC % 2 != 0 && is_any_of::value)) diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp index c9e81d4744..2e80ff64c1 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp @@ -1005,7 +1005,7 @@ struct GroupedConvolutionBackwardWeightKernel if constexpr(GemmPipeline::DoubleSmemBuffer == true) { - __shared__ char smem_ptr_1[GetSmemSize()]; + __shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()]; if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add && GroupedConvTraitsType_::VectorSizeC % 2 != 0 && diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp index a9f3274805..0f143d7ff7 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp @@ -1184,7 +1184,7 @@ struct GroupedConvolutionForwardKernel if constexpr(GemmPipeline::DoubleSmemBuffer == true) { - __shared__ char smem_ptr_1[GetSmemSize()]; + __shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()]; if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add && GroupedConvTraitsType_::VectorSizeC % 2 != 0 && diff --git a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp index a0c078a1e9..e949ed45e6 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp +++ 
b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp @@ -202,7 +202,13 @@ class TestCkTileGemmPipeline : public ::testing::Test N_Warp_Tile, K_Warp_Tile, UniversalGemmProblem::TransposeC, - memory_operation>>; + memory_operation, + 1, /*kNumWaveGroups_*/ + false, /*FixedVectorSize_*/ + 1, /*VectorSizeC_*/ + false, /*TiledMMAPermuteN_*/ + 1, /*BlockedXDLN_PerWarp_*/ + DoubleSmemBuffer /*DoubleSmemBuffer*/>>; using Kernel = ck_tile::GemmKernel; auto kargs = Kernel::MakeKernelArgs(args);