From 89093ac4317b976d0bb41fe709180865422a98a9 Mon Sep 17 00:00:00 2001 From: Adam Osewski <19374865+aosewski@users.noreply.github.com> Date: Mon, 27 Jan 2025 16:37:19 +0100 Subject: [PATCH] [CK-Tile] Enable vectorized reads on all layouts & improve perf. (#1835) * Refactor universal gemm policy. * Adapt example to refactor changes. * Introduce static encoding pattern * Adding shuffled encoding patterns. * Fix err in reverse tuple. * Add transpose_tile2d * Small refactoring + doc * Enable reading on contiguous dimension in all layouts. * Transpose A/B register tile if needed for comp v3 pipeline. * Take contiguous dim size when calculating dram vector load size. * A/B smem pack size taken from WarpGemm attributes * Update B LDS layout and setup tile distribution pattern at class level. * Fix static assert. * Fix errors in examples. * Formatting & fix IsTranspose * Fix VectorSize & refactor. * Add error logging messages. * Fix VecLoadSize and TransposeC for mem pipeline. * Update unit-tests & disable mem pipeline. * Clang format * Update include/ck_tile/core/tensor/tile_window.hpp Co-authored-by: jakpiase * Fix compilation and reviewers' comments. * Refactor unit-test. Fallback to non-universal gemm. Need to use GemmPipelineAGmemBGmemCRegV1 for now, since GemmKernel now also supports non-K major vector reads. 
--------- Co-authored-by: jakpiase [ROCm/composable_kernel commit: 39dc25a9b8d9d835ec5716f6078bc9dd5501fcb6] --- example/ck_tile/03_gemm/gemm_basic.cpp | 26 +- example/ck_tile/03_gemm/run_gemm_example.inc | 36 - example/ck_tile/03_gemm/universal_gemm.cpp | 68 +- .../ck_tile/16_batched_gemm/batched_gemm.cpp | 4 +- .../ck_tile/16_batched_gemm/batched_gemm.hpp | 2 +- .../run_batched_gemm_example.inc | 83 ++- .../ck_tile/17_grouped_gemm/grouped_gemm.cpp | 5 +- .../ck_tile/17_grouped_gemm/grouped_gemm.hpp | 2 +- .../run_grouped_gemm_example.inc | 17 +- include/ck_tile/core.hpp | 2 + .../algorithm/static_encoding_pattern.hpp | 210 ++++++ include/ck_tile/core/container/tuple.hpp | 2 +- include/ck_tile/core/tensor/tile_window.hpp | 23 +- .../ck_tile/core/tensor/transpose_tile.hpp | 202 ++++++ .../block/block_universal_gemm_as_bs_cr.hpp | 2 +- .../ck_tile/ops/gemm/kernel/gemm_kernel.hpp | 112 ++- .../ops/gemm/kernel/grouped_gemm_kernel.hpp | 3 +- .../pipeline/gemm_pipeline_ag_bg_cr_base.hpp | 42 +- .../gemm_pipeline_ag_bg_cr_comp_v3.hpp | 107 ++- .../pipeline/gemm_pipeline_ag_bg_cr_mem.hpp | 11 +- .../gemm_pipeline_agmem_bgmem_creg_v1.hpp | 19 +- ...ine_agmem_bgmem_creg_v1_default_policy.hpp | 116 +-- .../gemm/pipeline/gemm_pipeline_problem.hpp | 53 +- ...emm_universal_pipeline_ag_bg_cr_policy.hpp | 673 ++++++++++-------- .../ops/gemm/pipeline/tile_gemm_traits.hpp | 23 + .../batched_gemm/test_batched_gemm.cpp | 2 +- test/ck_tile/gemm/test_gemm_pipeline.cpp | 28 +- .../gemm/test_gemm_pipeline_ut_cases.inc | 31 +- test/ck_tile/gemm/test_gemm_pipeline_util.hpp | 170 +++-- .../grouped_gemm/test_grouped_gemm.cpp | 2 +- .../grouped_gemm/test_grouped_gemm_util.hpp | 5 +- 31 files changed, 1393 insertions(+), 688 deletions(-) create mode 100644 include/ck_tile/core/algorithm/static_encoding_pattern.hpp create mode 100644 include/ck_tile/core/tensor/transpose_tile.hpp diff --git a/example/ck_tile/03_gemm/gemm_basic.cpp b/example/ck_tile/03_gemm/gemm_basic.cpp index 
16f1466dd3..c3a66ba3ea 100644 --- a/example/ck_tile/03_gemm/gemm_basic.cpp +++ b/example/ck_tile/03_gemm/gemm_basic.cpp @@ -70,9 +70,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& ck_tile::TileGemmTraits; using CodegenPipelineProblem = ck_tile:: GemmPipelineProblem; - using CodegenGemmPolicy = ck_tile::UniversalGemmPipelineAgBgCrPolicy; - using CodegenGemmPipeline = - ck_tile::GemmPipelineAGmemBGmemCRegV1; + using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1; // ToDo: Will add the codegen part to test different pipeline policies in GEMM. // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy. using Kernel = ck_tile::GemmKernel; @@ -103,4 +101,26 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& #include "run_gemm_example.inc" +int run_gemm_example(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + using Row = ck_tile::tensor_layout::gemm::RowMajor; + using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + + std::string a_layout = arg_parser.get_str("a_layout"); + std::string b_layout = arg_parser.get_str("b_layout"); + + if(a_layout == "R" && b_layout == "C") + { + return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + } + else + { + throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!"); + } +} + int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc index e29ba272f5..d32ec57be5 100644 --- a/example/ck_tile/03_gemm/run_gemm_example.inc +++ b/example/ck_tile/03_gemm/run_gemm_example.inc @@ -217,39 +217,3 @@ int run_gemm_example_with_layouts(int argc, return pass; } - -int run_gemm_example(int argc, char* argv[]) -{ - auto [result, arg_parser] = create_args(argc, argv); - if(!result) - return -1; - - using Row = 
ck_tile::tensor_layout::gemm::RowMajor; - using Col = ck_tile::tensor_layout::gemm::ColumnMajor; - - std::string a_layout = arg_parser.get_str("a_layout"); - std::string b_layout = arg_parser.get_str("b_layout"); - - if(a_layout == "R" && b_layout == "R") - { - return run_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{}); - } - else if(a_layout == "R" && b_layout == "C") - { - return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); - } - // TODO: Fixme: with latest changes to GemmPipelineAGmemBGmemCRegV1DefaultPolicy below do not - // work. - // else if(a_layout == "C" && b_layout == "C") - // { - // return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{}); - // } - // else if(a_layout == "C" && b_layout == "R") - // { - // return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{}); - // } - else - { - throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!"); - } -} diff --git a/example/ck_tile/03_gemm/universal_gemm.cpp b/example/ck_tile/03_gemm/universal_gemm.cpp index bff243d559..5d2bd2df31 100644 --- a/example/ck_tile/03_gemm/universal_gemm.cpp +++ b/example/ck_tile/03_gemm/universal_gemm.cpp @@ -28,8 +28,8 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& constexpr ck_tile::index_t M_Warp_Tile = 32; constexpr ck_tile::index_t N_Warp_Tile = 32; constexpr ck_tile::index_t K_Warp_Tile = 8; - -#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE) +#endif +#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE) // Compute friendly for Intrawave scheduler constexpr ck_tile::index_t M_Tile = 256; constexpr ck_tile::index_t N_Tile = 256; @@ -48,6 +48,8 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& constexpr bool kPadN = false; constexpr bool kPadK = false; + constexpr bool TransposeC = false; + constexpr int kBlockPerCu = 1; // =============================================== @@ -62,7 +64,8 @@ float 
gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& ck_tile::Default2DEpilogueProblem>; using Traits = ck_tile::TileGemmTraits; - + using GemmUniversalTraits = ck_tile:: + TileGemmUniversalTraits; using GemmPipelineProblem = ck_tile::GemmPipelineProblem; @@ -85,14 +88,15 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& BDataType, AccDataType, GemmShape, - Traits, + GemmUniversalTraits, scheduler, has_hot_loop_v, tail_number_v>; - using GemmPipeline = GEMM_PIPELINE; - using Kernel = ck_tile::GemmKernel; - auto kargs = Kernel::MakeKernelArgs(args); + using GemmPipeline = + GEMM_PIPELINE; + using Kernel = ck_tile::GemmKernel; + auto kargs = Kernel::MakeKernelArgs(args); const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); constexpr dim3 blocks = Kernel::BlockSize(); @@ -117,6 +121,21 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& if(has_hot_loop) { +#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE) + if(tail_num == ck_tile::TailNumber::Full) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else + { + std::ostringstream err; + err << "For compute pipeline tail number should always be Full, but have \"" << tail_num + << "\" which is not supported! 
PrefetchStages: " << BaseGemmPipeline::PrefetchStages + << "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } +#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) // Tail pipeline One to Seven if(tail_num == ck_tile::TailNumber::One) { @@ -177,6 +196,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& ck_tile::integral_constant{}); } } +#endif } else { @@ -201,4 +221,38 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& #include "run_gemm_example.inc" +int run_gemm_example(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + using Row = ck_tile::tensor_layout::gemm::RowMajor; + using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + + std::string a_layout = arg_parser.get_str("a_layout"); + std::string b_layout = arg_parser.get_str("b_layout"); + + if(a_layout == "R" && b_layout == "R") + { + return run_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{}); + } + else if(a_layout == "R" && b_layout == "C") + { + return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + } + else if(a_layout == "C" && b_layout == "C") + { + return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{}); + } + else if(a_layout == "C" && b_layout == "R") + { + return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{}); + } + else + { + throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!"); + } +} + int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/ck_tile/16_batched_gemm/batched_gemm.cpp b/example/ck_tile/16_batched_gemm/batched_gemm.cpp index 5cb2aa5045..720802236c 100644 --- a/example/ck_tile/16_batched_gemm/batched_gemm.cpp +++ b/example/ck_tile/16_batched_gemm/batched_gemm.cpp @@ -72,9 +72,7 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const 
ck_tile::stre ck_tile::TileGemmTraits; using CodegenPipelineProblem = ck_tile:: GemmPipelineProblem; - using CodegenGemmPolicy = ck_tile::UniversalGemmPipelineAgBgCrPolicy; - using CodegenGemmPipeline = - ck_tile::GemmPipelineAGmemBGmemCRegV1; + using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1; // ToDo: Will add the codegen part to test different pipeline policies in GEMM. // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy. using Kernel = ck_tile::BatchedGemmKernel; diff --git a/example/ck_tile/16_batched_gemm/batched_gemm.hpp b/example/ck_tile/16_batched_gemm/batched_gemm.hpp index 62f0058fd1..7b7e22160a 100644 --- a/example/ck_tile/16_batched_gemm/batched_gemm.hpp +++ b/example/ck_tile/16_batched_gemm/batched_gemm.hpp @@ -39,7 +39,7 @@ auto create_args(int argc, char* argv[]) .insert("stride_b", "0", "Tensor B stride") .insert("stride_c", "0", "Tensor C stride") .insert("a_layout", "R", "A tensor data layout - Row by default") - .insert("b_layout", "R", "B tensor data layout - Row by default") + .insert("b_layout", "C", "B tensor data layout - Row by default") .insert("c_layout", "R", "C tensor data layout - Row by default") .insert("batch_stride_a", "32768", "Batch A stride") .insert("batch_stride_b", "16384", "Batch B stride") diff --git a/example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc b/example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc index c3ed76f5ef..d0df8845cc 100644 --- a/example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc +++ b/example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc @@ -3,13 +3,6 @@ #pragma once -template -static constexpr inline auto is_row_major(Layout layout_) -{ - return ck_tile::bool_constant, - ck_tile::tensor_layout::gemm::RowMajor>>{}; -} - auto calculate_rtol_atol(const ck_tile::index_t K, const ck_tile::index_t kbatch, const float max_accumulated_value) @@ -113,16 +106,56 @@ int run_batched_gemm_example_with_layouts(int argc, int n_warmup = 
arg_parser.get_int("warmup"); int n_repeat = arg_parser.get_int("repeat"); - stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout)); - stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout)); - stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(c_layout)); + using namespace ck_tile::literals; - ck_tile::HostTensor a_m_k(ck_tile::host_tensor_descriptor( - batch_count, M, K, stride_A, batch_stride_A, is_row_major(a_layout))); - ck_tile::HostTensor b_k_n(ck_tile::host_tensor_descriptor( - batch_count, K, N, stride_B, batch_stride_B, is_row_major(b_layout))); - ck_tile::HostTensor c_m_n_dev_result(ck_tile::host_tensor_descriptor( - batch_count, M, N, stride_C, batch_stride_C, is_row_major(c_layout))); + auto f_host_tensor_descriptor = [](std::size_t batch_count_, + std::size_t row, + std::size_t col, + std::size_t stride, + std::size_t batch_stride, + auto layout) { + if constexpr(std::is_same_v) + { + return ck_tile::HostTensorDescriptor({batch_count_, row, col}, + {batch_stride, stride, 1_uz}); + } + else + { + return ck_tile::HostTensorDescriptor({batch_count_, row, col}, + {batch_stride, 1_uz, stride}); + } + }; + + auto f_get_default_stride = [](std::size_t row, + std::size_t col, + std::size_t stride, + auto layout) { + if(stride == 0) + { + // give a chance if stride is zero, return a default packed stride + if constexpr(std::is_same_v) + { + return col; + } + else + { + return row; + } + } + else + return stride; + }; + + stride_A = f_get_default_stride(M, K, stride_A, a_layout); + stride_B = f_get_default_stride(K, N, stride_B, b_layout); + stride_C = f_get_default_stride(M, N, stride_C, c_layout); + + ck_tile::HostTensor a_m_k( + f_host_tensor_descriptor(batch_count, M, K, stride_A, batch_stride_A, a_layout)); + ck_tile::HostTensor b_k_n( + f_host_tensor_descriptor(batch_count, K, N, stride_B, batch_stride_B, b_layout)); + ck_tile::HostTensor c_m_n_dev_result( + 
f_host_tensor_descriptor(batch_count, M, N, stride_C, batch_stride_C, c_layout)); ck_tile::FillUniformDistribution{-5.f, 5.f}(a_m_k); ck_tile::FillUniformDistribution{-5.f, 5.f}(b_k_n); @@ -158,8 +191,8 @@ int run_batched_gemm_example_with_layouts(int argc, if(arg_parser.get_int("v") == 1) { - ck_tile::HostTensor c_m_n_host_ref(ck_tile::host_tensor_descriptor( - batch_count, M, N, stride_C, batch_stride_C, is_row_major(CLayout){})); + ck_tile::HostTensor c_m_n_host_ref( + f_host_tensor_descriptor(batch_count, M, N, stride_C, batch_stride_C, CLayout{})); c_m_n_host_ref.SetZero(); const auto b_n_k = b_k_n.transpose({0, 2, 1}); @@ -183,8 +216,8 @@ int run_batched_gemm_example_with_layouts(int argc, } else if(arg_parser.get_int("v") == 2) { - ck_tile::HostTensor c_m_n_gpu_ref(ck_tile::host_tensor_descriptor( - batch_count, M, N, stride_C, batch_stride_C, is_row_major(CLayout){})); + ck_tile::HostTensor c_m_n_gpu_ref( + f_host_tensor_descriptor(batch_count, M, N, stride_C, batch_stride_C, CLayout{})); ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_gpu_ref.get_element_space_size_in_bytes()); c_m_n_gpu_ref.SetZero(); c_m_n_gpu_buf_ref.SetZero(); @@ -268,11 +301,11 @@ int run_batched_gemm_example(int argc, char* argv[]) std::string a_layout = arg_parser.get_str("a_layout"); std::string b_layout = arg_parser.get_str("b_layout"); - if(a_layout == "R" && b_layout == "R") - { - return run_batched_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{}); - } - else if(a_layout == "R" && b_layout == "C") + // if(a_layout == "R" && b_layout == "R") + // { + // return run_batched_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{}); + // } + if(a_layout == "R" && b_layout == "C") { return run_batched_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); } diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp index 6b51f696a3..bb4bdbf514 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp +++ 
b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp @@ -88,12 +88,9 @@ using CodegenPipelineProblem = CodegenGemmShape, CodegenGemmTraits>; -using CodegenGemmPolicy = ck_tile::UniversalGemmPipelineAgBgCrPolicy; - template using CodegenGemmPipeline = - ck_tile::GemmPipelineAGmemBGmemCRegV1, - CodegenGemmPolicy>; + ck_tile::GemmPipelineAGmemBGmemCRegV1>; template using Kernel = ck_tile::GroupedGemmKernel( ck_tile::host_tensor_descriptor(M, K, stride_As[i], is_row_major(a_layout)))); @@ -229,10 +226,10 @@ int run_grouped_gemm_example(int argc, char* argv[]) { return run_grouped_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); } - else if(a_layout == "R" && b_layout == "R") - { - return run_grouped_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{}); - } + // else if(a_layout == "R" && b_layout == "R") + // { + // return run_grouped_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{}); + // } else { throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!"); diff --git a/include/ck_tile/core.hpp b/include/ck_tile/core.hpp index 34f8ec5245..5610c093ca 100644 --- a/include/ck_tile/core.hpp +++ b/include/ck_tile/core.hpp @@ -7,6 +7,7 @@ #include "ck_tile/core/algorithm/coordinate_transform.hpp" #include "ck_tile/core/algorithm/indexing_adaptor.hpp" #include "ck_tile/core/algorithm/space_filling_curve.hpp" +#include "ck_tile/core/algorithm/static_encoding_pattern.hpp" #include "ck_tile/core/arch/amd_buffer_addressing.hpp" #include "ck_tile/core/arch/arch.hpp" #include "ck_tile/core/arch/generic_memory_space_atomic.hpp" @@ -53,6 +54,7 @@ #include "ck_tile/core/tensor/tile_window.hpp" #include "ck_tile/core/tensor/tile_window_linear.hpp" #include "ck_tile/core/tensor/tile_window_utils.hpp" +#include "ck_tile/core/tensor/transpose_tile.hpp" #include "ck_tile/core/tensor/update_tile.hpp" #include "ck_tile/core/utility/bit_cast.hpp" #include "ck_tile/core/utility/functional.hpp" diff --git 
a/include/ck_tile/core/algorithm/static_encoding_pattern.hpp b/include/ck_tile/core/algorithm/static_encoding_pattern.hpp new file mode 100644 index 0000000000..78884f3f9f --- /dev/null +++ b/include/ck_tile/core/algorithm/static_encoding_pattern.hpp @@ -0,0 +1,210 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/arch/arch.hpp" +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/container/sequence.hpp" +#include "ck_tile/core/container/tuple.hpp" +#include "ck_tile/core/numeric/integer.hpp" +#include "ck_tile/core/tensor/tile_distribution.hpp" +#include "ck_tile/core/tensor/tile_distribution_encoding.hpp" + +namespace ck_tile { + +/** + * @brief Enumeration describing static tile distribution patterns. + * + */ +enum struct tile_distribution_pattern +{ + /** + * @brief Thread raked pattern. + * + */ + thread_raked, + /** + * @brief Warp raked pattern. + * + */ + warp_raked, + /** + * @brief Block raked pattern - aka linear. + * + */ + block_raked, +}; + +struct TileDistributionEncodingPattern +{ +}; + +/** + * @brief Class creating 2D static tile distribution with different load/store patterns. + * + * @note We always assume that Tile is YPerTile x XPerTile where X dim (rightmost) + * is contiguous and we can do vector load on this dimension. + * + * @tparam BlockSize Number of threads in a workgroup. + * @tparam YPerTile The tile size of outer/leftmost dimension. + * @tparam XPerTile The tile size of inner/rightmost dimension (contiguous). + * @tparam VecSize The vector access size. + * @tparam DistributionPattern The enumeration describing used access pattern. 
+ */ +template +struct TileDistributionEncodingPattern2D : public TileDistributionEncodingPattern +{ +}; + +// Thread raked +template +struct TileDistributionEncodingPattern2D + : public TileDistributionEncodingPattern +{ + + // TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk! + static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!"); + static constexpr index_t warp_size = get_warp_size(); + static constexpr index_t num_warps = BlockSize / get_warp_size(); + static constexpr index_t X1 = VecSize; + static constexpr index_t X0 = XPerTile / X1; // # of threads in X dim + + // # of rows in Y dim accessed by single wavefront in one iteration + static constexpr index_t Y1 = warp_size / X0; + static_assert(X0 * Y1 == warp_size, "X0 * Y1 must cover whole wavefront!"); + + static constexpr index_t Y0 = num_warps; + // YPerWarp = YPerTile / Y0; + // Y2 = YPerWarp / Y1; + static constexpr index_t Y2 = YPerTile / (Y1 * Y0); // # of iters within wavefront + + static_assert(X0 * Y1 * Y0 == BlockSize, "X0 * warp_ys * Y0 must cover whole workgroup!"); + static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover whole YPerTile"); + + CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution() + { + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<1, 0>>, + sequence<1, 2>, + sequence<2, 1>>{}); + } + + CK_TILE_HOST_DEVICE static constexpr auto MakeShuffled2DStaticTileDistribution() + { + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1>>, + tuple, sequence<1, 0>>, + sequence<1, 2>, + sequence<1, 2>>{}); + } +}; + +// Warp raked +template +struct TileDistributionEncodingPattern2D + : public TileDistributionEncodingPattern +{ + + static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!"); + static constexpr index_t warp_size = 
get_warp_size(); + static constexpr index_t num_warps = BlockSize / get_warp_size(); + static constexpr index_t X1 = VecSize; + static constexpr index_t X0 = XPerTile / X1; // # of threads in X dim + + static constexpr index_t Y2 = warp_size / X0; // # of rows in Y dim to cover whole wavefront + static_assert(X0 * Y2 == warp_size, "X0 * Y2 must cover whole wavefront!"); + + static constexpr index_t Y0 = num_warps; + static_assert(X0 * Y2 * Y0 == BlockSize, "X0 * Y2 * Y1 must cover whole workgroup!"); + + static constexpr index_t Y1 = YPerTile / (Y2 * Y0); // # of iters within wavefront + static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover whole YPerTile"); + + CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution() + { + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<1, 1>>{}); + } + + CK_TILE_HOST_DEVICE static constexpr auto MakeShuffled2DStaticTileDistribution() + { + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<1, 1>>{}); + } +}; + +// Block raked +template +struct TileDistributionEncodingPattern2D + : public TileDistributionEncodingPattern +{ + + // TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk! 
+ static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!"); + static constexpr index_t warp_size = get_warp_size(); + static constexpr index_t num_warps = BlockSize / get_warp_size(); + static constexpr index_t X1 = VecSize; + static constexpr index_t X0 = XPerTile / X1; // # of threads in X dim + static constexpr index_t Y2 = warp_size / X0; // # of rows in Y dim to cover whole wavefront + static_assert(X0 * Y2 == warp_size, "X0 * Y2 must cover whole wavefront!"); + static constexpr index_t Y1 = num_warps; + static_assert(X0 * Y2 * Y1 == BlockSize, "X0 * Y2 * Y1 must cover whole workgroup!"); + static constexpr index_t Y0 = YPerTile / (Y2 * Y1); // # of iters + static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover whole YPerTile"); + + CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution() + { + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + + CK_TILE_HOST_DEVICE static constexpr auto MakeShuffled2DStaticTileDistribution() + { + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<1, 0>>{}); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/core/container/tuple.hpp b/include/ck_tile/core/container/tuple.hpp index 19d853ad5c..74575f4c6e 100644 --- a/include/ck_tile/core/container/tuple.hpp +++ b/include/ck_tile/core/container/tuple.hpp @@ -546,7 +546,7 @@ CK_TILE_HOST_DEVICE constexpr auto tuple_reverse(const tuple& t) using Idx = number::size() - i - 1>; return t.at(Idx{}); }, - number::size()()>{}); + number::size()>{}); } // Reduce tuple values in specific range using Function diff --git a/include/ck_tile/core/tensor/tile_window.hpp b/include/ck_tile/core/tensor/tile_window.hpp index caeb038521..27c2c24ad5 100644 --- 
a/include/ck_tile/core/tensor/tile_window.hpp +++ b/include/ck_tile/core/tensor/tile_window.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -18,8 +18,17 @@ namespace ck_tile { -// Note: this tile window do not support single issue -// you need to use tile_window_linear structure for this purpose +/** + * @brief This class provides tile (windowed) view and access to the device memory. + * + * @note This tile window does not support single issue you need to use tile_window_linear + * structure for this purpose + * + * @tparam BottomTensorView_ Class describing & holding device tensor memory. + * @tparam WindowLengths_ Spatial sizes of windowed view on tensor. + * @tparam StaticTileDistribution_ Thread distribution (mapping) into Tile dimensions + * @tparam NumCoord TBD + */ template struct tile_window_with_static_lengths { diff --git a/include/ck_tile/core/tensor/transpose_tile.hpp b/include/ck_tile/core/tensor/transpose_tile.hpp new file mode 100644 index 0000000000..f34efe5c2f --- /dev/null +++ b/include/ck_tile/core/tensor/transpose_tile.hpp @@ -0,0 +1,202 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/numeric/integer.hpp" +#include "ck_tile/core/numeric/integral_constant.hpp" +#include "ck_tile/core/utility/functional.hpp" +#include "ck_tile/core/algorithm/coordinate_transform.hpp" +#include "ck_tile/core/algorithm/space_filling_curve.hpp" +#include "ck_tile/core/container/container_helper.hpp" +#include "ck_tile/core/container/thread_buffer.hpp" +#include "ck_tile/core/container/statically_indexed_array.hpp" +#include "ck_tile/core/numeric/math.hpp" +#include "ck_tile/core/utility/type_traits.hpp" +#include "ck_tile/core/tensor/tile_elementwise.hpp" +#include "ck_tile/core/utility/transpose_vectors.hpp" + +namespace ck_tile { +namespace detail { + +template +CK_TILE_DEVICE void transpose_tile2d_impl_in_thread(OutTensor& out_tensor, + const InTensor& in_tensor) +{ + constexpr auto I0 = number<0>{}; + + static_assert(std::is_same_v, + "Data type for InTensor and OutTensor must be the same!"); + + using DataType = typename InTensor::DataType; + + constexpr auto y_in_desc = InTensor::get_tile_distribution().get_ys_to_d_descriptor(); + constexpr auto y_out_desc = OutTensor::get_tile_distribution().get_ys_to_d_descriptor(); + + // y_dim_out_to_in + // For swapped Hs tile case I need only get_rh_minor_to_y + // since rh_major are already swapped due to swapped Hs. + constexpr auto get_rh_minor_to_y = [](auto dstr_tensor) { + using DstrEncode = typename decltype(dstr_tensor.get_tile_distribution())::DstrEncode; + + map rh_minor_to_y_; + + static_for<0, DstrEncode::NDimY, 1>{}([&](auto i) { + constexpr index_t rh_minor = DstrEncode::ys_to_rhs_minor_[i]; + + rh_minor_to_y_(rh_minor) = i; + }); + + return rh_minor_to_y_; + }; + + // In swapped Hs case -> tile + // we have same rh_major, but reversed rh_minor! + constexpr auto rh_minor_to_y_in = get_rh_minor_to_y(InTensor{}); + constexpr auto rh_minor_to_y_out = get_rh_minor_to_y(OutTensor{}); + + // Is this really needed?? 
Should we have simple reverse here?? + constexpr auto y_dim_out_to_in = [&] { + map y_dim_out_to_in_; + + for(const auto& [rh_minor, y_out] : rh_minor_to_y_out) + { + y_dim_out_to_in_(y_out) = rh_minor_to_y_in[rh_minor]; + } + + return y_dim_out_to_in_; + }(); + + constexpr index_t NDimY = InTensor::get_tile_distribution().get_num_of_dimension_y(); + constexpr auto y_lengths = to_sequence(y_in_desc.get_lengths()); + + // input and output vector dim in the order of input Y dims + constexpr index_t y_dim_vec_in = NDimY - 1; + constexpr index_t y_dim_vec_out = y_dim_out_to_in[NDimY - 1]; + + // vector lengths + constexpr index_t vec_length_in = y_lengths[y_dim_vec_in]; + constexpr index_t vec_length_out = y_lengths[y_dim_vec_out]; + + // # of vectors + constexpr index_t num_vec_in = vec_length_out; + constexpr index_t num_vec_out = vec_length_in; + + using InVec = array; + using OutVec = array; + + // SFC + constexpr auto scalars_per_access_arr = generate_array( + [&](auto i) { return (i == y_dim_vec_in or i == y_dim_vec_out) ? y_lengths[i] : 1; }, + number{}); + + constexpr auto scalars_per_access = TO_SEQUENCE(scalars_per_access_arr, NDimY); + + using SFC_Y = space_filling_curve::type, + decltype(scalars_per_access)>; + + constexpr index_t num_access = SFC_Y::get_num_of_access(); + + static_assert(num_access > 0, "wrong! num_access should be larger than 0"); + + // in/out vectors to be transposed + thread_buffer in_vectors; + thread_buffer out_vectors; + + // loop over SFC and do transpose + static_for<0, num_access, 1>{}([&](auto iAccess) { + // data index [y0, y1, ...] in the order of input tensor + constexpr auto idx_y_start = SFC_Y::get_index(iAccess); + + // get input vectors + static_for<0, num_vec_in, 1>{}([&](auto i) { + constexpr auto idx_y_in = generate_tuple( + [&](auto ii) { + return ii == y_dim_vec_out ? 
idx_y_start[ii] + i : idx_y_start[ii]; + }, + number{}); + + constexpr index_t in_offset = y_in_desc.calculate_offset(idx_y_in); + static_assert(in_offset % vec_length_in == 0); + + in_vectors(i).template get_as()(I0) = + in_tensor.get_thread_buffer() + .template get_as()[number{}]; + }); + + // transpose + transpose_vectors{}(in_vectors, out_vectors); + + // set output vectors + static_for<0, num_vec_out, 1>{}([&](auto i) { + constexpr auto idx_y_out_tmp = generate_array( + [&](auto ii) { return ii == y_dim_vec_in ? idx_y_start[ii] + i : idx_y_start[ii]; }, + number{}); + + constexpr auto idx_y_out = + container_reorder_given_new2old(idx_y_out_tmp, y_dim_out_to_in); + + constexpr index_t out_offset = y_out_desc.calculate_offset(idx_y_out); + static_assert(out_offset % vec_length_out == 0); + + out_tensor.get_thread_buffer().template set_as( + number{}, + out_vectors[i].template get_as()[I0]); + }); + }); +} + +} // namespace detail + +template +CK_TILE_DEVICE void transpose_tile2d(OutTensor& out, const InTensor& in) +{ + using InDataType = typename InTensor::DataType; + using OutDataType = typename OutTensor::DataType; + + using InTileDistr = typename InTensor::StaticTileDistribution; + using OutTileDistr = typename OutTensor::StaticTileDistribution; + + using InDstrEncode = typename InTileDistr::DstrEncode; + using OutDstrEncode = typename OutTileDistr::DstrEncode; + + using InThreadTensorDesc = typename InTensor::ThreadTensorDesc; + using OutThreadTensorDesc = typename OutTensor::ThreadTensorDesc; + + // Ys: + constexpr auto in_thread_desc_lengths = InThreadTensorDesc{}.get_lengths(); + constexpr auto out_thread_desc_lengths = OutThreadTensorDesc{}.get_lengths(); + + // type convert + const auto in_tmp = [&]() { + if constexpr(std::is_same_v) + { + return in; + } + else + { + return tile_elementwise_in(type_convert, in); + } + }(); + + // Scenario where we switch from tile -> - only 2D tiles! 
+ // we preserve Ps but swap Ys: -> + if constexpr(InDstrEncode::rs_lengths_ == OutDstrEncode::rs_lengths_ && + InDstrEncode::hs_lengthss_ == tuple_reverse(OutDstrEncode::hs_lengthss_) && + InDstrEncode::NDimY == OutDstrEncode::NDimY && InDstrEncode::NDimY == 2 && + in_thread_desc_lengths == tuple_reverse(out_thread_desc_lengths)) + // Any condition on Ps ?? + // InDstrEncode::ps_to_rhss_major_ == OutDstrEncode::ps_to_rhss_major_ && + // InDstrEncode::ps_to_rhss_minor_ == OutDstrEncode::ps_to_rhss_minor_ && + { + detail::transpose_tile2d_impl_in_thread(out, in_tmp); + } + else + { + static_assert(false, "Provided tensors could not be transposed!"); + } +} + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp index 0fe0a9f40d..646d380a18 100644 --- a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp +++ b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp @@ -80,7 +80,7 @@ struct BlockUniversalGemmAsBsCr static constexpr index_t InterWaveSchedulingMacClusters = 1; static constexpr index_t KPack = WarpGemm::kKPerThread; - static constexpr index_t KPerThread = KPerBlock / WarpGemm::kK * KPack; + static constexpr index_t KPerThread = KIterPerWarp * KPack; static constexpr index_t KRepeat = KPerThread / KPack; }; diff --git a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp index 76cfaa2cf0..8d640831df 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp @@ -8,7 +8,6 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/common.hpp" -#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" namespace ck_tile { @@ -69,6 +68,7 @@ struct GemmKernel using ADataType = remove_cvref_t; using BDataType = remove_cvref_t; + // Below type is actually accumulation data type - the output of block GEMM. 
using CDataType = remove_cvref_t; static constexpr auto I0 = number<0>(); @@ -168,6 +168,7 @@ struct GemmKernel { if(kargs.KBatch != 1) { + std::cerr << "Conditions not met for Kbatch >1 !" << std::endl; return false; } } @@ -176,10 +177,14 @@ struct GemmKernel { if(kargs.K % TilePartitioner::KPerBlock != 0 && GemmPipeline::kPadK == false) { + std::cerr << "Can't support K that is not a multiple of KPerBlock" + " without padding!" + << std::endl; return false; } if(kargs.K % GemmPipeline::VectorSizeA != 0) { + std::cerr << "K is not a multiple of vector load size for A tensor!" << std::endl; return false; } } @@ -187,10 +192,14 @@ struct GemmKernel { if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false) { + std::cerr << "Can't support M that is not a multiple of MPerBlock" + " without padding!" + << std::endl; return false; } if(kargs.M % GemmPipeline::VectorSizeA != 0) { + std::cerr << "M is not a multiple of vector load size for A tensor!" << std::endl; return false; } } @@ -199,10 +208,14 @@ struct GemmKernel { if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false) { + std::cerr << "Can't support N that is not a multiple of NPerBlock" + " without padding!" + << std::endl; return false; } if(kargs.N % GemmPipeline::VectorSizeB != 0) { + std::cerr << "N is not a multiple of vector load size for B tensor!" << std::endl; return false; } } @@ -210,10 +223,14 @@ struct GemmKernel { if(kargs.K % TilePartitioner::KPerBlock != 0 && GemmPipeline::kPadK == false) { + std::cerr << "Can't support K that is not a multiple of KPerBlock" + " without padding!" + << std::endl; return false; } if(kargs.K % GemmPipeline::VectorSizeB != 0) { + std::cerr << "K is not a multiple of vector load size for B tensor!" 
<< std::endl; return false; } } @@ -222,10 +239,14 @@ struct GemmKernel { if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false) { + std::cerr << "Can't support N that is not a multiple of NPerBlock" + " without padding!" + << std::endl; return false; } if(kargs.N % GemmPipeline::VectorSizeC != 0) { + std::cerr << "N is not a multiple of vector load size for C tensor!" << std::endl; return false; } } @@ -233,10 +254,14 @@ struct GemmKernel { if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false) { + std::cerr << "Can't support M that is not a multiple of MPerBlock" + " without padding!" + << std::endl; return false; } if(kargs.M % GemmPipeline::VectorSizeC != 0) { + std::cerr << "M is not a multiple of vector load size for C tensor!" << std::endl; return false; } } @@ -250,6 +275,14 @@ struct GemmKernel const GemmKernelArgs& kargs, const SplitKBatchOffset& splitk_batch_offset) { + // const auto idxs = TilePartitioner{}(); + // const auto i_m = idxs.at(number<0>{}); + // const auto i_n = idxs.at(number<1>{}); + // // options + // const ADataType* a_start = static_cast(kargs.a_ptr); + // const BDataType* b_start = static_cast(kargs.b_ptr); + // // Convert pointers to tensor views + // auto a_tensor_view = [&]() { const auto& a_tensor_view = [&]() { if constexpr(std::is_same_v) { @@ -264,9 +297,9 @@ struct GemmKernel { return make_naive_tensor_view( a_ptr, - make_tuple(kargs.M, splitk_batch_offset.splitted_k), - make_tuple(1, kargs.stride_A), - number<1>{}, + make_tuple(splitk_batch_offset.splitted_k, kargs.M), + make_tuple(kargs.stride_A, 1), + number{}, number<1>{}); } }(); @@ -276,9 +309,9 @@ struct GemmKernel { return make_naive_tensor_view( b_ptr, - make_tuple(kargs.N, splitk_batch_offset.splitted_k), - make_tuple(1, kargs.stride_B), - number<1>{}, + make_tuple(splitk_batch_offset.splitted_k, kargs.N), + make_tuple(kargs.stride_B, 1), + number{}, number<1>{}); } else @@ -292,6 +325,7 @@ struct GemmKernel } }(); + // 
TODO: enable vector write for C in ColMajor const auto& c_tensor_view = [&]() { if constexpr(std::is_same_v) { @@ -331,9 +365,9 @@ struct GemmKernel else { return pad_tensor_view(a_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); + make_tuple(number{}, + number{}), + sequence{}); } }(); @@ -349,12 +383,13 @@ struct GemmKernel else { return pad_tensor_view(b_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); + make_tuple(number{}, + number{}), + sequence{}); } }(); + // TODO vector write in for C in ColMajor const auto& c_pad_view = [&]() { const auto& c_tensor_view = views.at(I2); if constexpr(std::is_same_v) @@ -380,20 +415,45 @@ struct GemmKernel CK_TILE_DEVICE static auto MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n) { - const auto& a_pad_view = views.at(I0); - const auto& a_block_window = make_tile_window( - a_pad_view, - make_tuple(number{}, number{}), - {i_m, 0}); - - const auto& b_pad_view = views.at(I1); - const auto& b_block_window = make_tile_window( - b_pad_view, - make_tuple(number{}, number{}), - {i_n, 0}); - + const auto& a_pad_view = views.at(I0); + const auto& b_pad_view = views.at(I1); const auto& c_pad_view = views.at(I2); - auto c_block_window = make_tile_window( + + const auto& a_block_window = [&]() { + if constexpr(std::is_same_v) + { + return make_tile_window(a_pad_view, + make_tuple(number{}, + number{}), + {i_m, 0}); + } + else + { + return make_tile_window(a_pad_view, + make_tuple(number{}, + number{}), + {0, i_m}); + } + }(); + + const auto& b_block_window = [&]() { + if constexpr(std::is_same_v) + { + return make_tile_window(b_pad_view, + make_tuple(number{}, + number{}), + {i_n, 0}); + } + else + { + return make_tile_window(b_pad_view, + make_tuple(number{}, + number{}), + {0, i_n}); + } + }(); + + auto c_block_window = make_tile_window( c_pad_view, make_tuple(number{}, number{}), {i_m, i_n}); diff --git a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp 
b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp index 6dbb1d6b82..656939770c 100644 --- a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp @@ -50,7 +50,6 @@ struct GroupedGemmKernel : public GemmKernel; using BDataType = remove_cvref_t; + using ALayout = remove_cvref_t; + using BLayout = remove_cvref_t; using BlockGemmShape = remove_cvref_t; static constexpr index_t MPerBlock = BlockGemmShape::kM; static constexpr index_t NPerBlock = BlockGemmShape::kN; static constexpr index_t KPerBlock = BlockGemmShape::kK; - template + template CK_TILE_DEVICE void GlobalPrefetch(DstBlockTile& dst_block_tile, - SrcTileWindow& dram_tile_window) const + SrcTileWindow& dram_tile_window, + const DramTileWindowStep& dram_tile_window_step) const { load_tile(dst_block_tile, dram_tile_window); - move_tile_window(dram_tile_window, {0, KPerBlock}); + move_tile_window(dram_tile_window, dram_tile_window_step); } template @@ -60,19 +64,21 @@ struct GemmPipelineAgBgCrImplBase CK_TILE_DEVICE auto GetAWindows(const ADramBlockWindowTmp& a_dram_block_window_tmp, const ALdsTensorView& a_lds_block_view) const { + constexpr bool is_col_major = std::is_same_v; + + using YPerTile = std::conditional_t, number>; + using XPerTile = std::conditional_t, number>; + // A DRAM tile window for load auto a_copy_dram_window = make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(), - make_tuple(number{}, number{}), + make_tuple(YPerTile{}, XPerTile{}), a_dram_block_window_tmp.get_window_origin(), Policy::template MakeADramTileDistribution()); // A LDS tile window for store - auto a_copy_lds_window = - make_tile_window(a_lds_block_view, - make_tuple(number{}, number{}), - {0, 0}, - a_copy_dram_window.get_tile_distribution()); + auto a_copy_lds_window = make_tile_window( + a_lds_block_view, make_tuple(number{}, number{}), {0, 0}); auto a_lds_gemm_window = make_tile_window( a_lds_block_view, make_tuple(number{}, number{}), {0, 
0}); @@ -86,18 +92,22 @@ struct GemmPipelineAgBgCrImplBase CK_TILE_DEVICE auto GetBWindows(const BDramBlockWindowTmp& b_dram_block_window_tmp, const BLdsTensorView& b_lds_block_view) const { + constexpr bool is_row_major = std::is_same_v; + + using YPerTile = std::conditional_t, number>; + using XPerTile = std::conditional_t, number>; + auto b_copy_dram_window = make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(), - make_tuple(number{}, number{}), + make_tuple(YPerTile{}, XPerTile{}), b_dram_block_window_tmp.get_window_origin(), Policy::template MakeBDramTileDistribution()); + // TODO: Do we really need those two tile windows??? + // They're exactly same... // B LDS tile window for store - auto b_copy_lds_window = - make_tile_window(b_lds_block_view, - make_tuple(number{}, number{}), - {0, 0}, - b_copy_dram_window.get_tile_distribution()); + auto b_copy_lds_window = make_tile_window( + b_lds_block_view, make_tuple(number{}, number{}), {0, 0}); auto b_lds_gemm_window = make_tile_window( b_lds_block_view, make_tuple(number{}, number{}), {0, 0}); diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp index 40628b1868..70de4014c1 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp @@ -1,10 +1,10 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once #include "ck_tile/core.hpp" -#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp" @@ -37,7 +37,7 @@ struct BaseGemmPipelineAgBgCrCompV3 // LocalPreFillStages: 1 // LocalPreFetchStages: 1 // LocalSharedMemoryBuffer: 1 -template +template struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 { using Base = BaseGemmPipelineAgBgCrCompV3; @@ -62,15 +62,14 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 static constexpr index_t NPerBlock = BlockGemmShape::kN; static constexpr index_t KPerBlock = BlockGemmShape::kK; - static constexpr index_t VectorSizeA = Problem::VectorSizeA; - static constexpr index_t VectorSizeB = Problem::VectorSizeB; - static constexpr index_t VectorSizeC = Problem::VectorSizeC; + static constexpr index_t VectorSizeA = Policy::template GetVectorSizeA(); + static constexpr index_t VectorSizeB = Policy::template GetVectorSizeB(); + static constexpr index_t VectorSizeC = Policy::template GetVectorSizeC(); static constexpr bool kPadM = Problem::kPadM; static constexpr bool kPadN = Problem::kPadN; static constexpr bool kPadK = Problem::kPadK; - // Where is the right place for HasHotLoop and TailNum ??? 
static constexpr bool HasHotLoop = Problem::HasHotLoop; static constexpr auto TailNum = Problem::TailNum; static constexpr auto Scheduler = Problem::Scheduler; @@ -82,7 +81,10 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 return Policy::template GetSmemSize(); } - CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() { return Policy::IsTransposeC(); } + CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() + { + return Policy::template IsTransposeC(); + } template struct PipelineImpl : public PipelineImplBase @@ -248,11 +250,22 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 "A/B Dram block window should have the same data type as appropriate " "([A|B]DataType) defined in Problem definition!"); - static_assert(MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && - NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] && - KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}], - "A/B block window appropriate sizes must be equal to MPerBlock/NPerblock" - " or KPerBlock!"); + constexpr bool is_a_col_major = + std::is_same_v; + constexpr bool is_b_row_major = std::is_same_v; + + static_assert(is_a_col_major + ? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && + MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]) + : (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && + KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]), + "A block window has incorrect lengths for defined ALayout!"); + static_assert(is_b_row_major + ? 
(KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] && + NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]) + : (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] && + KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]), + "B block window has incorrect lengths for defined BLayout!"); // ------------------------------------------------------------------------------------ // Definitions of all needed tiles @@ -287,23 +300,51 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 ABlockTile a_block_tile; BBlockTile b_block_tile; + using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex; + using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex; + + constexpr ADramTileWindowStep a_dram_tile_window_step = + is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock); + constexpr BDramTileWindowStep b_dram_tile_window_step = + is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock); + // ----------------------------------------------------------------------------------------- // Gemm pipeline start // prefetch // global read 0 - Base::GlobalPrefetch(a_block_tile, a_copy_dram_window); - Base::GlobalPrefetch(b_block_tile, b_copy_dram_window); + Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); + Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); // initialize C tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); // LDS write 0 - Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); - Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func); + if constexpr(is_a_col_major) + { + auto a_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledARegTileDistribution()); + transpose_tile2d(a_shuffle_tmp, a_block_tile); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); + } + else + { + 
Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); + } + if constexpr(is_b_row_major) + { + auto b_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledBRegTileDistribution()); + transpose_tile2d(b_shuffle_tmp, b_block_tile); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); + } + else + { + Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func); + } - Base::GlobalPrefetch(a_block_tile, a_copy_dram_window); - Base::GlobalPrefetch(b_block_tile, b_copy_dram_window); + Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); + Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); block_sync_lds(); block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); @@ -318,11 +359,31 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 { block_sync_lds(); - Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); - Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func); + if constexpr(is_a_col_major) + { + auto a_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledARegTileDistribution()); + transpose_tile2d(a_shuffle_tmp, a_block_tile); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); + } + else + { + Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); + } + if constexpr(is_b_row_major) + { + auto b_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledBRegTileDistribution()); + transpose_tile2d(b_shuffle_tmp, b_block_tile); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); + } + else + { + Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func); + } - Base::GlobalPrefetch(a_block_tile, a_copy_dram_window); - Base::GlobalPrefetch(b_block_tile, b_copy_dram_window); + Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); + Base::GlobalPrefetch(b_block_tile, 
b_copy_dram_window, b_dram_tile_window_step); block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp index f169a17bc0..1d6a9a0b87 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp @@ -113,9 +113,9 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem static constexpr index_t NPerBlock = BlockGemmShape::kN; static constexpr index_t KPerBlock = BlockGemmShape::kK; - static constexpr index_t VectorSizeA = Problem::VectorSizeA; - static constexpr index_t VectorSizeB = Problem::VectorSizeB; - static constexpr index_t VectorSizeC = Problem::VectorSizeC; + static constexpr index_t VectorSizeA = Policy::template GetVectorSizeA(); + static constexpr index_t VectorSizeB = Policy::template GetVectorSizeB(); + static constexpr index_t VectorSizeC = Policy::template GetVectorSizeC(); static constexpr bool kPadM = Problem::kPadM; static constexpr bool kPadN = Problem::kPadN; @@ -133,7 +133,10 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem return Policy::template GetSmemSize(); } - CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() { return Policy::IsTransposeC(); } + CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() + { + return Policy::template IsTransposeC(); + } template struct PipelineImpl : public PipelineImplBase diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp index 22e2b214b0..ccb2f81d4b 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -39,17 +39,6 @@ struct GemmPipelineAGmemBGmemCRegV1 static constexpr bool kPadN = Problem::kPadN; static constexpr bool kPadK = Problem::kPadK; - CK_TILE_HOST_DEVICE static constexpr index_t GetStaticLdsSize() - { - return integer_divide_ceil( - sizeof(ADataType) * - Policy::template MakeALdsBlockDescriptor().get_element_space_size(), - 16) * - 16 + - sizeof(BDataType) * - Policy::template MakeBLdsBlockDescriptor().get_element_space_size(); - } - CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return Policy::template GetSmemSize(); @@ -150,7 +139,7 @@ struct GemmPipelineAGmemBGmemCRegV1 if constexpr(std::is_same_v) { auto a_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledARegBlockDescriptor()); + Policy::template MakeShuffledARegBlockDistribution()); shuffle_tile(a_shuffle_tmp, a_block_tile); const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_shuffle_tmp); store_tile(a_copy_lds_window, a_block_tile_tmp); @@ -164,7 +153,7 @@ struct GemmPipelineAGmemBGmemCRegV1 if constexpr(std::is_same_v) { auto b_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledBRegBlockDescriptor()); + Policy::template MakeShuffledBRegBlockDistribution()); shuffle_tile(b_shuffle_tmp, b_block_tile); const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_shuffle_tmp); store_tile(b_copy_lds_window, b_block_tile_tmp); @@ -201,7 +190,7 @@ struct GemmPipelineAGmemBGmemCRegV1 if constexpr(std::is_same_v) { auto b_shuffle_tmp_loop = make_static_distributed_tensor( - Policy::template MakeShuffledBRegBlockDescriptor()); + Policy::template MakeShuffledBRegBlockDistribution()); shuffle_tile(b_shuffle_tmp_loop, b_block_tile); store_tile(b_copy_lds_window, tile_elementwise_in(b_element_func, b_shuffle_tmp_loop)); diff --git 
a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp index 0250ae051d..ce22ab7ab1 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -18,37 +18,6 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy static constexpr bool TransposeC = true; -#if 0 - // 2d - template - CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() - { - using namespace ck_tile; - - constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; - - constexpr auto a_lds_block_desc = - make_naive_tensor_descriptor_packed(make_tuple(kMPerBlock, kKPerBlock), number<32>{}); - - return a_lds_block_desc; - } - - // 2d - template - CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor() - { - using namespace ck_tile; - - constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; - constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; - - constexpr auto b_lds_block_desc = - make_naive_tensor_descriptor_packed(make_tuple(kNPerBlock, kKPerBlock), number<32>{}); - - return b_lds_block_desc; - } -#elif 1 // 3d + padding template CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() @@ -58,7 +27,6 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; - // TODO: this 8 is AK1! should be a policy parameter! 
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( make_tuple(number{}, number{}, number<8>{}), make_tuple(number<(kMPerBlock + 1) * 8>{}, number<8>{}, number<1>{}), @@ -127,87 +95,14 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackA() { - using ADataType = remove_cvref_t; - return Problem::VectorLoadSize / sizeof(ADataType); + return Problem::VectorLoadSize; } template CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackB() { - using BDataType = remove_cvref_t; - return Problem::VectorLoadSize / sizeof(BDataType); + return Problem::VectorLoadSize; } -#elif 1 - // fake XOR - template - CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() - { - using namespace ck_tile; - - using ADataType = remove_cvref_t; - - constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; - - constexpr auto a_lds_block_desc_d1_d2_d3 = make_naive_tensor_descriptor_packed( - make_tuple(number{}, number<2>{}, number{}), - number{}); - - constexpr index_t kK1 = 16 / sizeof(ADataType); - - constexpr auto a_lds_block_desc_d4_d5_d6 = transform_tensor_descriptor( - a_lds_block_desc_d1_d2_d3, - make_tuple( - make_xor_transform(make_tuple(number{}, number{}), kK1), - make_pass_through_transform(2)), - make_tuple(sequence<0, 2>{}, sequence<1>{}), - make_tuple(sequence<0, 2>{}, sequence<1>{})); - - constexpr auto a_lds_block_desc_m_k = transform_tensor_descriptor( - a_lds_block_desc_d4_d5_d6, - make_tuple(make_merge_transform(make_tuple(number{}, number<2>{})), - make_pass_through_transform(kKPerBlock)), - make_tuple(sequence<0, 1>{}, sequence<2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - return a_lds_block_desc_m_k; - } - - // fake XOR - template - CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor() - { - using namespace ck_tile; - - using BDataType = remove_cvref_t; - - constexpr index_t kNPerBlock = 
Problem::BlockGemmShape::kN; - constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; - - constexpr auto b_lds_block_desc_d1_d2_d3 = make_naive_tensor_descriptor_packed( - make_tuple(number{}, number<2>{}, number{}), - number{}); - - constexpr index_t kK1 = 16 / sizeof(BDataType); - - constexpr auto b_lds_block_desc_d4_d5_d6 = transform_tensor_descriptor( - b_lds_block_desc_d1_d2_d3, - make_tuple( - make_xor_transform(make_tuple(number{}, number{}), kK1), - make_pass_through_transform(2)), - make_tuple(sequence<0, 2>{}, sequence<1>{}), - make_tuple(sequence<0, 2>{}, sequence<1>{})); - - constexpr auto b_lds_block_desc_n_k = transform_tensor_descriptor( - b_lds_block_desc_d4_d5_d6, - make_tuple(make_merge_transform(make_tuple(number{}, number<2>{})), - make_pass_through_transform(kKPerBlock)), - make_tuple(sequence<0, 1>{}, sequence<2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - return b_lds_block_desc_n_k; - } -#endif template CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution() @@ -273,7 +168,6 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy static_assert(M0 * M1 * M2 == MPerBlock, "Incorrect M0, M2, M1 configuration! 
" "M0, M1, M2 must cover whole MPerBlock!"); - return make_static_tile_distribution( tile_distribution_encoding, tuple, sequence>, @@ -394,7 +288,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy } template - CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBRegBlockDescriptor() + CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBRegBlockDistribution() { using BLayout = remove_cvref_t; using BDataType = remove_cvref_t; @@ -442,7 +336,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy } template - CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledARegBlockDescriptor() + CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledARegBlockDistribution() { using ALayout = remove_cvref_t; using ADataType = remove_cvref_t; diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp index bf51577aeb..dc2ea81d6f 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp @@ -3,6 +3,7 @@ #pragma once +#include "ck_tile/core.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" namespace ck_tile { @@ -11,10 +12,10 @@ template + typename Traits_> struct GemmPipelineProblemBase { - using GemmTraits = remove_cvref_t; + using Traits = remove_cvref_t; using ADataType = remove_cvref_t; using BDataType = remove_cvref_t; @@ -22,19 +23,19 @@ struct GemmPipelineProblemBase using BlockGemmShape = remove_cvref_t; - using ALayout = remove_cvref_t; - using BLayout = remove_cvref_t; - using CLayout = remove_cvref_t; + using ALayout = remove_cvref_t; + using BLayout = remove_cvref_t; + using CLayout = remove_cvref_t; - static constexpr index_t VectorLoadSize = GemmTraits::_VectorSize; - static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size(); + static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size(); - static constexpr bool kPadM = GemmTraits::kPadM; 
- static constexpr bool kPadN = GemmTraits::kPadN; - static constexpr bool kPadK = GemmTraits::kPadK; + static constexpr bool kPadM = Traits::kPadM; + static constexpr bool kPadN = Traits::kPadN; + static constexpr bool kPadK = Traits::kPadK; static constexpr auto Scheduler = GemmPipelineScheduler::Default; + static constexpr index_t VectorLoadSize = Traits::_VectorSize; CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentA() { if constexpr(std::is_same_v) @@ -128,27 +129,43 @@ template + typename Traits_> using GemmPipelineProblem = - GemmPipelineProblemBase; + GemmPipelineProblemBase; template -struct UniversalGemmPipelineProblem : public GemmPipelineProblemBase +struct UniversalGemmPipelineProblem { + using Traits = remove_cvref_t; + + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + + using BlockGemmShape = remove_cvref_t; + + using ALayout = remove_cvref_t; + using BLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + + static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size(); + + static constexpr bool kPadM = Traits::kPadM; + static constexpr bool kPadN = Traits::kPadN; + static constexpr bool kPadK = Traits::kPadK; + static constexpr auto Scheduler = Scheduler_; static constexpr auto HasHotLoop = HasHotLoop_; static constexpr auto TailNum = TailNum_; + + static constexpr bool TransposeC = Traits::TransposeC; }; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp index b26ee071df..31a837aa45 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp @@ -1,10 +1,11 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "ck_tile/core.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" namespace ck_tile { @@ -15,30 +16,43 @@ struct UniversalGemmPipelineAgBgCrPolicy static constexpr auto I1 = number<1>{}; static constexpr auto I2 = number<2>{}; - static constexpr bool TransposeC = true; + static constexpr auto ATileAccessPattern = tile_distribution_pattern::thread_raked; + static constexpr auto BTileAccessPattern = tile_distribution_pattern::thread_raked; - template - CK_TILE_HOST_DEVICE static constexpr auto GetVectorLoadSize() + /** + * @brief Get the maximum global memory vector load size. + * + * @tparam Problem The UniversalGemmPipelineProblem object. + * @tparam DataType The tensor data type we're considering. + * @tparam MNPerBlock The MPerBlock or NPerBlock value depending on tensor (A/B). + * @tparam XPerTile The contiguous Tile dimension size. + * @return Maximum DRAM vector load size. + */ + template + CK_TILE_HOST_DEVICE static constexpr auto GetGlobalVectorLoadSize() { constexpr index_t BlockSize = Problem::kBlockSize; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; constexpr index_t elements_per_thread = MNPerBlock * KPerBlock / BlockSize; - if constexpr(elements_per_thread % (16 / sizeof(DataType)) == 0) + // Assume DataType is even! 
+ if constexpr(XPerTile % (16 / sizeof(DataType)) == 0 && + elements_per_thread % (16 / sizeof(DataType)) == 0) { return (16 / sizeof(DataType)); } - else if constexpr(elements_per_thread % (8 / sizeof(DataType)) == 0) + else if constexpr(XPerTile % (8 / sizeof(DataType)) == 0 && + elements_per_thread % (8 / sizeof(DataType)) == 0) { return (8 / sizeof(DataType)); } - else if constexpr(elements_per_thread % (4 / sizeof(DataType)) == 0 && - sizeof(DataType) >= 4) + else if constexpr(sizeof(DataType) >= 4 && XPerTile % (4 / sizeof(DataType)) == 0 && + elements_per_thread % (4 / sizeof(DataType)) == 0) { return (4 / sizeof(DataType)); } - else if constexpr(elements_per_thread % (2 / sizeof(DataType)) == 0 && - sizeof(DataType) >= 2) + else if constexpr(sizeof(DataType) >= 2 && XPerTile % (2 / sizeof(DataType)) == 0 && + elements_per_thread % (2 / sizeof(DataType)) == 0) { return (2 / sizeof(DataType)); } @@ -48,6 +62,126 @@ struct UniversalGemmPipelineAgBgCrPolicy } } + template + CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeA() + { + using ALayout = remove_cvref_t; + using ADataType = remove_cvref_t; + constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + + if constexpr(std::is_same_v) + { + return GetGlobalVectorLoadSize(); + } + else + { + return GetGlobalVectorLoadSize(); + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeB() + { + using BLayout = remove_cvref_t; + using BDataType = remove_cvref_t; + constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + + if constexpr(std::is_same_v) + { + return GetGlobalVectorLoadSize(); + } + else + { + return GetGlobalVectorLoadSize(); + } + } + + /** + * @brief Get the vector store size for C tensor. + * + * @tparam Problem - Gemm pipeline problem class. 
+ * + * @note The vector store size for output C tensor would depend on multiple factors + * like its data layout and warp gemm C transposition. In general it would + * be the number of consecutive elements in contiguous C dimension hold by + * single thread. + * + * @return The vector store size for C tensor. + */ + template + CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC() + { + using BlockGemm = remove_cvref_t())>; + using WG = typename BlockGemm::WarpGemm; + + constexpr bool TransposeC = Problem::TransposeC; + using CLayout = typename Problem::CLayout; + using CWarpDstr = typename WG::CWarpDstr; + + // N is contiguous dimension + if constexpr(std::is_same_v) + { + if constexpr(TransposeC) + { + // In this case each thread has multiple consecutive elements in + // N dimension, however consecutive threads' elements have stride. + constexpr index_t NDimY = CWarpDstr::NDimY; + constexpr auto c_warp_y_lengths = + CWarpDstr{}.get_ys_to_d_descriptor().get_lengths(); + static_assert(WG::WarpGemmAttribute::Impl::kCM1PerLane == + c_warp_y_lengths.get(number{})); + return c_warp_y_lengths.get(number{}); + } + else + { + // In this case each thread has just a single item in Ndim + return WG::WarpGemmAttribute::Impl::kCNLane / WG::kN; + } + } + // M is contiguous dimension + else if constexpr(std::is_same_v) + { + if constexpr(TransposeC) + { + // In this case each thread has just a single item in Mdim + return WG::WarpGemmAttribute::Impl::kCNLane / WG::kN; + } + else + { + // In this case each thread has multiple consecutive elements in + // M dimension, however consecutive threads' elements have stride. 
+ constexpr index_t NDimY = CWarpDstr::NDimY; + constexpr auto c_warp_y_lengths = + CWarpDstr{}.get_ys_to_d_descriptor().get_lengths(); + static_assert(WG::WarpGemmAttribute::Impl::kCM1PerLane == + c_warp_y_lengths.get(number{})); + return c_warp_y_lengths.get(number{}); + } + } + else + { + static_assert(false, "Unsupported CLayout!"); + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackA() + { + using BlockGemm = decltype(GetBlockGemm()); + constexpr index_t KPack = BlockGemm::Traits::KPack; + return KPack; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackB() + { + using BlockGemm = decltype(GetBlockGemm()); + constexpr index_t KPack = BlockGemm::Traits::KPack; + return KPack; + } + template CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() { @@ -56,7 +190,7 @@ struct UniversalGemmPipelineAgBgCrPolicy constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t KPack = GetVectorLoadSize(); + constexpr index_t KPack = GetSmemPackA(); constexpr auto DataTypeSize = sizeof(ADataType); constexpr auto MLdsLayer = @@ -99,54 +233,193 @@ struct UniversalGemmPipelineAgBgCrPolicy return a_lds_block_desc; } + /** + * @brief Create LDS block descriptor for B tensor. + * + * @tparam Problem Gemm pipeline problem. + * @return B tensor LDS block descriptor. + */ template CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor() { - + // using BLayout = remove_cvref_t; using BDataType = remove_cvref_t; constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t KPack = GetVectorLoadSize(); - constexpr auto DataTypeSize = sizeof(BDataType); - constexpr auto NLdsLayer = - (32 * 4 / KPerBlock / DataTypeSize) < 1 ? 
1 : (32 * 4 / KPerBlock / DataTypeSize); +#if 1 + // if constexpr(std::is_same_v) + { + constexpr index_t KPack = GetSmemPackB(); + constexpr auto BK0 = number{}; + constexpr auto DataTypeSize = sizeof(BDataType); + constexpr auto NLdsLayer = + (32 * 4 / KPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / KPerBlock / DataTypeSize); - constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(number{}, - number{}, - number{}), - make_tuple(number{}, number{}, number<1>{}), - number{}, - number<1>{}); + constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple( + BK0 * number{}, number{}, number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); - constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( - b_lds_block_desc_0, - make_tuple(make_xor_transform(make_tuple(number{}, - number{})), - make_pass_through_transform(number{})), - make_tuple(sequence<1, 0>{}, sequence<2>{}), - make_tuple(sequence<1, 0>{}, sequence<2>{})); + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc_0, + make_tuple(make_xor_transform(make_tuple(number{}, + BK0 * number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<1, 0>{}, sequence<2>{}), + make_tuple(sequence<1, 0>{}, sequence<2>{})); - constexpr auto b_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor( - b_lds_block_desc_permuted, - make_tuple(make_unmerge_transform( - make_tuple(number{}, number{})), - make_pass_through_transform(number{}), - make_pass_through_transform(number{})), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), - make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); + constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple(make_unmerge_transform(make_tuple(BK0, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(number{})), + 
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); - constexpr auto b_lds_block_desc = transform_tensor_descriptor( - b_lds_block_desc_xk0_mnldslayer_mn_xk1, - make_tuple(make_merge_transform_v3_division_mod( - make_tuple(number{}, number{})), - make_merge_transform_v3_division_mod( - make_tuple(number{}, number{}))), - make_tuple(sequence<1, 2>{}, sequence<0, 3>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - return b_lds_block_desc; + constexpr auto b_lds_block_desc = transform_tensor_descriptor( + b_lds_block_desc_bk0_nldslayer_n_bk1, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(number{}, number{})), + make_merge_transform_v3_division_mod(make_tuple(BK0, number{}))), + make_tuple(sequence<1, 2>{}, sequence<0, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + return b_lds_block_desc; + } +#else + else // B is Row Major + { + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t VecLoadSize = GetVectorSizeB(); + using TileEncodingPattern = TileDistributionEncodingPattern2D; + + constexpr auto BK0 = number{}; + constexpr auto BK1 = number{}; + // constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1); + constexpr auto N0 = TileEncodingPattern::X0; + constexpr auto N1 = NPerBlock / N0; + + using WarpTile = typename Problem::BlockGemmShape::WarpTile; + constexpr auto NPerXdl = number{}; + + // constexpr auto KThreadWrite = + // BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0); + constexpr auto KThreadWrite = TileEncodingPattern::Y2; + constexpr auto K0PerThreadWrite = BK0 / KThreadWrite; + constexpr auto KThreadRead = 64 / NPerXdl; + constexpr auto K0PerThreadRead = BK0 / KThreadRead; + + constexpr auto kfold = + (BK1 * N0 * sizeof(BDataType) > 128) ? 1 : 128 / (BK1 * N0 * sizeof(BDataType)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? 
KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=npair<=n0 + constexpr auto npair = (BK1 * NPerXdl * sizeof(BDataType) > 128) + ? 1 + : ((128 / (BK1 * NPerXdl * sizeof(BDataType))) > N0 + ? N0 + : 128 / (BK1 * NPerXdl * sizeof(BDataType))); + + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(number{}, + number{}, + number{}, + number{}, + number{}, + BK1)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple( + make_pass_through_transform(number{}), + make_pass_through_transform(number{}), + make_xor_transform( + make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(BK1)), + make_tuple( + sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{}), + make_tuple( + sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{})); + + constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(number{}), + make_pass_through_transform(number{}), + make_unmerge_transform(make_tuple(number{}, number{})), + make_unmerge_transform(make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(BK1)), + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<2>{}, + sequence<3>{}, + sequence<4>{}, + sequence<5>{}), + make_tuple(sequence<1>{}, + sequence<2>{}, + sequence<0, 3>{}, + sequence<4, 5>{}, + sequence<6>{}, + sequence<7>{})); + + // constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( + // b_lds_block_desc_unmerged, + // make_tuple(make_merge_transform_v3_division_mod( + // make_tuple(number{}, + // number{}, + // number{}, + // number{})), + // make_merge_transform_v3_division_mod( + // make_tuple(number{}, number{}, number{})), + // make_pass_through_transform(BK1)), + // make_tuple(sequence<0, 1, 4, 2>{}, 
sequence<5, 6, 3>{}, sequence<7>{}), + // make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); + + constexpr auto b_lds_block_desc_kn = transform_tensor_descriptor( + b_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(number{}, + number{}, + number{}, + number{}, + BK1)), + make_merge_transform_v3_division_mod( + make_tuple(number{}, number{}, number{}))), + make_tuple(sequence<0, 1, 4, 2, 7>{}, sequence<5, 6, 3>{}), + make_tuple(sequence<1>{}, sequence<0>{})); + + // return b_lds_block_desc_bk0_n_bk1; + return b_lds_block_desc_kn; + + // constexpr auto b_lds_block_desc_bk0_n_bk1 = make_naive_tensor_descriptor( + // make_tuple(BK0, number{}, number{}), + // make_tuple(number{}, number{}, number<1>{}), + // number{}, + // number<1>{}); + + // constexpr auto b_lds_block_desc = transform_tensor_descriptor( + // b_lds_block_desc_bk0_n_bk1, + // make_tuple(make_pass_through_transform(number{}), + // make_merge_transform_v3_division_mod(make_tuple(BK0, + // number{}))), + // make_tuple(sequence<1>{}, sequence<0, 2>{}), + // make_tuple(sequence<0>{}, sequence<1>{})); + + // return b_lds_block_desc; + } +#endif } template @@ -179,291 +452,127 @@ struct UniversalGemmPipelineAgBgCrPolicy template CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution() { - using ADataType = remove_cvref_t; - using ALayout = remove_cvref_t; + using ALayout = remove_cvref_t; - constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t VecLoadSize = GetVectorSizeA(); - constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - - if constexpr(std::is_same_v) + // Tile: MPerBlock X KPerBlock + if constexpr(std::is_same_v) { - constexpr index_t M1 = Problem::VectorLoadSize / sizeof(ADataType); - 
constexpr index_t M0 = MPerBlock / M1; - constexpr index_t total_pixels = MPerBlock * KPerBlock / BlockSize; - static_assert(total_pixels % M1 == 0); - constexpr index_t K3 = total_pixels / M1; - constexpr index_t KPack = GetVectorLoadSize(); - static_assert(KPack % K3 == 0); - constexpr index_t K2 = KPack / K3; - if constexpr(get_warp_size() % (K2 * M0) == 0) - { - constexpr index_t K1 = get_warp_size() / (K2 * M0); - constexpr index_t K0 = BlockSize / get_warp_size(); - static_assert(KPerBlock == K0 * K1 * K2 * K3); - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<2, 1, 2>>, - tuple, sequence<1, 0, 2>>, - sequence<2, 1>, - sequence<3, 1>>{}); - } - else - { - constexpr index_t K1 = (K2 * M0) / get_warp_size(); - constexpr index_t K2_m = K2 / K1; - constexpr index_t K0 = BlockSize / get_warp_size() / K1; - static_assert(KPerBlock == K0 * K1 * K2_m * K3); - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<0, 2>>, - sequence<2, 1>, - sequence<3, 1>>{}); - } + using TileEncodingPattern = TileDistributionEncodingPattern2D; + return TileEncodingPattern::Make2DStaticTileDistribution(); } + // Tile: KPerBlock X MPerBlock else { - constexpr index_t K1 = Problem::VectorLoadSize / sizeof(ADataType); - constexpr index_t K0 = KPerBlock / K1; - constexpr index_t M2 = get_warp_size() / K0; - if constexpr(get_warp_size() % (M2 * K0) == 0) - { - constexpr index_t M1 = BlockSize / get_warp_size(); - static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error."); - static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error."); - constexpr index_t M0 = MPerBlock / (M2 * M1); - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<2, 0>>, - sequence<1, 2>, - sequence<0, 1>>{}); - } - else - { - constexpr index_t M0 = BlockSize 
/ get_warp_size(); - constexpr index_t M1 = MPerBlock / (M2 * M0); - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<2, 0>>, - sequence<1, 2>, - sequence<1, 1>>{}); - } + using TileEncodingPattern = TileDistributionEncodingPattern2D; + return TileEncodingPattern::Make2DStaticTileDistribution(); } } template CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution() { - using BDataType = remove_cvref_t; - using BLayout = remove_cvref_t; + using BLayout = remove_cvref_t; - constexpr index_t BlockSize = Problem::kBlockSize; - - constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; - constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t VecLoadSize = GetVectorSizeB(); + // Tile: KPerBlock X NPerBlock if constexpr(std::is_same_v) { - constexpr index_t N1 = Problem::VectorLoadSize / sizeof(BDataType); - constexpr index_t N0 = NPerBlock / N1; - constexpr index_t total_pixels = NPerBlock * KPerBlock / BlockSize; - static_assert(total_pixels % N1 == 0); - constexpr index_t K3 = total_pixels / N1; - constexpr index_t KPack = GetVectorLoadSize(); - static_assert(KPack % K3 == 0); - constexpr index_t K2 = KPack / K3; - if constexpr(get_warp_size() % (K2 * N0) == 0) - { - constexpr index_t K1 = get_warp_size() / (K2 * N0); - constexpr index_t K0 = BlockSize / get_warp_size(); - static_assert(KPerBlock == K0 * K1 * K2 * K3); - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<2, 1, 2>>, - tuple, sequence<1, 0, 2>>, - sequence<2, 1>, - sequence<3, 1>>{}); - } - else - { - constexpr index_t K1 = (K2 * N0) / get_warp_size(); - constexpr index_t K2_m = K2 / K1; - constexpr index_t K0 = BlockSize / get_warp_size() / K1; - 
static_assert(KPerBlock == K0 * K1 * K2_m * K3); - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<0, 2>>, - sequence<2, 1>, - sequence<3, 1>>{}); - } + using TileEncodingPattern = TileDistributionEncodingPattern2D; + return TileEncodingPattern::Make2DStaticTileDistribution(); } + // Tile: NPerBlock X KPerBlock else { - - constexpr index_t K1 = Problem::VectorLoadSize / sizeof(BDataType); - constexpr index_t K0 = KPerBlock / K1; - constexpr index_t N2 = get_warp_size() / K0; - // coalesce reading for each blocks - if constexpr(get_warp_size() % (N2 * K0) == 0) - { - constexpr index_t N1 = BlockSize / get_warp_size(); - static_assert(N2 != 0, "N2 is zero, which will lead to a division by zero error."); - static_assert(N1 != 0, "N1 is zero, which will lead to a division by zero error."); - constexpr index_t N0 = NPerBlock / (N2 * N1); - - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<2, 0>>, - sequence<1, 2>, - sequence<0, 1>>{}); - } - // coalesce reading for each warps - else - { - constexpr index_t N0 = BlockSize / get_warp_size(); - constexpr index_t N1 = NPerBlock / (N2 * N0); - - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<2, 0>>, - sequence<1, 2>, - sequence<1, 1>>{}); - } + using TileEncodingPattern = TileDistributionEncodingPattern2D; + return TileEncodingPattern::Make2DStaticTileDistribution(); } } template - CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledARegBlockDescriptor() + CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledARegTileDistribution() { - using ALayout = remove_cvref_t; - using ADataType = remove_cvref_t; + using ALayout = remove_cvref_t; static_assert(std::is_same_v); - constexpr index_t BlockSize = Problem::kBlockSize; - constexpr index_t MPerBlock = 
Problem::BlockGemmShape::kN; - constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t MPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t VecLoadSize = GetVectorSizeA(); - constexpr index_t M1 = Problem::VectorLoadSize / sizeof(ADataType); - constexpr index_t M0 = MPerBlock / M1; - constexpr index_t total_pixels = MPerBlock * KPerBlock / BlockSize; - static_assert(total_pixels % M1 == 0); - constexpr index_t K3 = total_pixels / M1; - constexpr index_t kKPack = GetVectorLoadSize(); - static_assert(kKPack % K3 == 0); - constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave - constexpr index_t warp_size = get_warp_size(); - if constexpr(warp_size % (K2 * M0) == 0) - { - constexpr index_t K1 = warp_size / (K2 * M0); - constexpr index_t K0 = BlockSize / warp_size; - - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<2, 1, 2>>, - tuple, sequence<1, 0, 2>>, - sequence<1, 2>, - sequence<1, 3>>{}); - } - else - { - constexpr index_t K1 = (K2 * M0) / get_warp_size(); - constexpr index_t K2_m = K2 / K1; - constexpr index_t K0 = BlockSize / get_warp_size() / K1; - static_assert(KPerBlock == K0 * K1 * K2_m * K3); - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<0, 2>>, - sequence<1, 2>, - sequence<1, 3>>{}); - } + using TileEncodingPattern = TileDistributionEncodingPattern2D; + return TileEncodingPattern::MakeShuffled2DStaticTileDistribution(); } template - CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBRegBlockDescriptor() + CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBRegTileDistribution() { - using BLayout = remove_cvref_t; - using BDataType = remove_cvref_t; + using BLayout = remove_cvref_t; static_assert(std::is_same_v); - constexpr 
index_t BlockSize = Problem::kBlockSize; - constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; - constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t VecLoadSize = GetVectorSizeB(); - constexpr index_t N1 = Problem::VectorLoadSize / sizeof(BDataType); - constexpr index_t N0 = NPerBlock / N1; - constexpr index_t total_pixels = NPerBlock * KPerBlock / BlockSize; - static_assert(total_pixels % N1 == 0); - constexpr index_t K3 = total_pixels / N1; - constexpr index_t kKPack = GetVectorLoadSize(); - static_assert(kKPack % K3 == 0); - constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave - constexpr index_t warp_size = get_warp_size(); - if constexpr(warp_size % (K2 * N0) == 0) - { - constexpr index_t K1 = warp_size / (K2 * N0); - constexpr index_t K0 = BlockSize / warp_size; - - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<2, 1, 2>>, - tuple, sequence<1, 0, 2>>, - sequence<1, 2>, - sequence<1, 3>>{}); - } - else - { - constexpr index_t K1 = (K2 * N0) / get_warp_size(); - constexpr index_t K2_m = K2 / K1; - constexpr index_t K0 = BlockSize / get_warp_size() / K1; - static_assert(KPerBlock == K0 * K1 * K2_m * K3); - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<0, 2>>, - sequence<1, 2>, - sequence<1, 3>>{}); - } + using TileEncodingPattern = TileDistributionEncodingPattern2D; + return TileEncodingPattern::MakeShuffled2DStaticTileDistribution(); } - CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() { return TransposeC; } + template + CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() + { + return Problem::TransposeC; + } template CK_TILE_HOST_DEVICE static constexpr auto 
GetBlockGemm() { - using AccDataType = float; using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; using WarpTile = typename Problem::BlockGemmShape::WarpTile; using WarpGemm = WarpGemmMfmaDispatcher; + Problem::TransposeC>; using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy; - return BlockGemmASmemBSmemCRegV1{}; + return BlockUniversalGemmAsBsCr{}; } }; diff --git a/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp b/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp index 34756c3ff6..3d7441c942 100644 --- a/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp +++ b/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp @@ -19,11 +19,34 @@ struct TileGemmTraits static constexpr bool kPadN = kPadN_; static constexpr bool kPadK = kPadK_; + // TODO this can't be hardcoded here! Should be in policy! static constexpr int _VectorSize = 16; using ALayout = ALayout_; using BLayout = BLayout_; using CLayout = CLayout_; + + static constexpr bool TransposeC = false; +}; + +template +struct TileGemmUniversalTraits +{ + static constexpr bool kPadM = kPadM_; + static constexpr bool kPadN = kPadN_; + static constexpr bool kPadK = kPadK_; + + using ALayout = ALayout_; + using BLayout = BLayout_; + using CLayout = CLayout_; + + static constexpr bool TransposeC = TransposeC_; }; } // namespace ck_tile diff --git a/test/ck_tile/batched_gemm/test_batched_gemm.cpp b/test/ck_tile/batched_gemm/test_batched_gemm.cpp index 29bed8d2fd..3e3b821498 100644 --- a/test/ck_tile/batched_gemm/test_batched_gemm.cpp +++ b/test/ck_tile/batched_gemm/test_batched_gemm.cpp @@ -17,7 +17,7 @@ using Col = ck_tile::tensor_layout::gemm::ColumnMajor; // clang-format off using KernelTypes = ::testing::Types< // ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType - std::tuple< Row, Row, Row, F16, F16, F32, F16>, + // std::tuple< Row, Row, Row, F16, F16, F32, F16>, //std::tuple< Col, Row, Row, F16, F16, F32, F16>, std::tuple< Row, Col, Row, F16, F16, F32, 
F16>//, //std::tuple< Col, Col, Row, F16, F16, F32, F16> diff --git a/test/ck_tile/gemm/test_gemm_pipeline.cpp b/test/ck_tile/gemm/test_gemm_pipeline.cpp index 48a2b86a63..faffe848d5 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline.cpp +++ b/test/ck_tile/gemm/test_gemm_pipeline.cpp @@ -14,26 +14,28 @@ using Row = ck_tile::tensor_layout::gemm::RowMajor; using Col = ck_tile::tensor_layout::gemm::ColumnMajor; using Intrawave = ck_tile::integral_constant; -using Interwave = ck_tile::integral_constant; -using Mem = ck_tile::integral_constant; -using Comp = ck_tile::integral_constant; +// using Interwave = ck_tile::integral_constant; +// using Mem = ck_tile::integral_constant; +using Comp = ck_tile::integral_constant; + +// TODO: Enable Memory pipeline, when it would be updated for vector loads on non-K major tensors. // clang-format off using KernelTypes = ::testing::Types< // ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, GemmPipelineScheduler, PipelineType - std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, Mem>, + // std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, Mem>, std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, Comp>, - std::tuple< Row, Row, Row, F16, F16, F32, F16, Interwave, Mem>, - std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, Mem>, + // std::tuple< Row, Row, Row, F16, F16, F32, F16, Interwave, Mem>, + // std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, Mem>, std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, Comp>, - std::tuple< Row, Col, Row, F16, F16, F32, F16, Interwave, Mem>, - std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, Mem>, + // std::tuple< Row, Col, Row, F16, F16, F32, F16, Interwave, Mem>, + // std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, Mem>, std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, Comp>, - std::tuple< Col, Row, Row, F16, F16, F32, F16, Interwave, Mem>, - std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, Mem>, - 
std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, Comp>, - std::tuple< Col, Col, Row, F16, F16, F32, F16, Interwave, Mem> + // std::tuple< Col, Row, Row, F16, F16, F32, F16, Interwave, Mem>, + // std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, Mem>, + std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, Comp> + // std::tuple< Col, Col, Row, F16, F16, F32, F16, Interwave, Mem> >; // clang-format on diff --git a/test/ck_tile/gemm/test_gemm_pipeline_ut_cases.inc b/test/ck_tile/gemm/test_gemm_pipeline_ut_cases.inc index c78d69601c..e53015a975 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_ut_cases.inc +++ b/test/ck_tile/gemm/test_gemm_pipeline_ut_cases.inc @@ -10,22 +10,43 @@ TYPED_TEST(TestCkTileGemmPipeline, SmallM) constexpr int K = 320; for(int M : Ms) - this->Run(M, N, K); + { + if constexpr(std::is_same_v) + EXPECT_THROW((this->Run(M, N, K)), std::runtime_error); + else + this->Run(M, N, K); + } } TYPED_TEST(TestCkTileGemmPipeline, MidLargeM) { std::vector Ms{127, 255, 312, 799, 1573}; - constexpr int N = 1024; - constexpr int K = 320; + constexpr int N = 1024; + constexpr int K = 320; + constexpr int VecLoadSize = 8; for(int M : Ms) - this->Run(M, N, K); + { + if constexpr(std::is_same_v) + { + // TODO: Can we anyhow deduce used vector load size? 
+ if(M % VecLoadSize == 0) + this->Run(M, N, K); + else + EXPECT_THROW((this->Run(M, N, K)), std::runtime_error); + } + else + { + this->Run(M, N, K); + } + } } TYPED_TEST(TestCkTileGemmPipeline, PaddK) { - std::vector Ms{127}; + std::vector Ms{128}; constexpr int N = 1024; constexpr int K = 432; diff --git a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp index 96199f33e8..1474498726 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp @@ -16,6 +16,7 @@ enum struct GemmPipelineType Mem, Comp }; + template class TestCkTileGemmPipeline : public ::testing::Test { @@ -51,6 +52,9 @@ class TestCkTileGemmPipeline : public ::testing::Test constexpr bool kPadN = PadN; constexpr bool kPadK = PadK; + // TODO: For now - but this should also be a test parameter + constexpr bool TransposeC = false; + constexpr int kBlockPerCu = 1; // =============================================== @@ -65,14 +69,16 @@ class TestCkTileGemmPipeline : public ::testing::Test ck_tile::Default2DEpilogueProblem>; using Traits = ck_tile::TileGemmTraits; + using GemmUniversalTraits = ck_tile:: + TileGemmUniversalTraits; - using BaseGemmPipeline = std::conditional_t< - PipelineType == GemmPipelineType::Mem, - ck_tile::BaseGemmPipelineAgBgCrMem< - ck_tile::GemmPipelineProblem>, - ck_tile::BaseGemmPipelineAgBgCrCompV3< - ck_tile:: - GemmPipelineProblem>>; + using GemmPipelineProblem = + ck_tile::GemmPipelineProblem; + + using BaseGemmPipeline = + std::conditional_t, + ck_tile::BaseGemmPipelineAgBgCrCompV3>; const ck_tile::index_t k_grain = args.k_batch * K_Tile; const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile; @@ -84,26 +90,22 @@ class TestCkTileGemmPipeline : public ::testing::Test constexpr bool has_hot_loop_v = has_hot_loop_.value; constexpr auto tail_number_v = tail_number_.value; - using GemmPipeline = - std::conditional_t>, - ck_tile::GemmPipelineAgBgCrCompV3< - 
ck_tile::UniversalGemmPipelineProblem>>; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + + using GemmPipeline = std::conditional_t< + PipelineType == GemmPipelineType::Mem, + ck_tile::GemmPipelineAgBgCrMem, + ck_tile::GemmPipelineAgBgCrCompV3>; + using Kernel = ck_tile::GemmKernel; auto kargs = Kernel::MakeKernelArgs(args); @@ -129,70 +131,94 @@ class TestCkTileGemmPipeline : public ::testing::Test if(has_hot_loop) { - // Tail pipeline One to Seven - if(tail_num == ck_tile::TailNumber::One) + if constexpr(PipelineType == GemmPipelineType::Comp) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else if(tail_num == ck_tile::TailNumber::Full) - { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + if(tail_num == ck_tile::TailNumber::Full) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else + { + std::ostringstream err; + err << "For compute pipeline tail number should always be Full, but have \"" + << tail_num << "\" which is not supported! 
PrefetchStages: " + << BaseGemmPipeline::PrefetchStages << "\n File: " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } } - if constexpr(BaseGemmPipeline::PrefetchStages > 2) + if constexpr(PipelineType == GemmPipelineType::Mem) { - if(tail_num == ck_tile::TailNumber::Two) + // Tail pipeline One to Seven + if(tail_num == ck_tile::TailNumber::One) { Run(ck_tile::bool_constant{}, ck_tile::integral_constant{}); + ck_tile::TailNumber::One>{}); } - } - if constexpr(BaseGemmPipeline::PrefetchStages > 3) - { - if(tail_num == ck_tile::TailNumber::Three) + else if(tail_num == ck_tile::TailNumber::Full) { Run(ck_tile::bool_constant{}, ck_tile::integral_constant{}); + ck_tile::TailNumber::Full>{}); } - } - if constexpr(BaseGemmPipeline::PrefetchStages > 4) - { - if(tail_num == ck_tile::TailNumber::Four) + + if constexpr(BaseGemmPipeline::PrefetchStages > 2) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + if(tail_num == ck_tile::TailNumber::Two) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } } - } - if constexpr(BaseGemmPipeline::PrefetchStages > 5) - { - if(tail_num == ck_tile::TailNumber::Five) + if constexpr(BaseGemmPipeline::PrefetchStages > 3) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + if(tail_num == ck_tile::TailNumber::Three) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } } - } - if constexpr(BaseGemmPipeline::PrefetchStages > 6) - { - if(tail_num == ck_tile::TailNumber::Six) + if constexpr(BaseGemmPipeline::PrefetchStages > 4) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + if(tail_num == ck_tile::TailNumber::Four) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } } - } - if constexpr(BaseGemmPipeline::PrefetchStages > 7) - { - if(tail_num == ck_tile::TailNumber::Seven) + if constexpr(BaseGemmPipeline::PrefetchStages > 5) { - Run(ck_tile::bool_constant{}, - 
ck_tile::integral_constant{}); + if(tail_num == ck_tile::TailNumber::Five) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(BaseGemmPipeline::PrefetchStages > 6) + { + if(tail_num == ck_tile::TailNumber::Six) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(BaseGemmPipeline::PrefetchStages > 7) + { + if(tail_num == ck_tile::TailNumber::Seven) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } } } } diff --git a/test/ck_tile/grouped_gemm/test_grouped_gemm.cpp b/test/ck_tile/grouped_gemm/test_grouped_gemm.cpp index 1bce0f8aa9..7ea4c2b6dc 100644 --- a/test/ck_tile/grouped_gemm/test_grouped_gemm.cpp +++ b/test/ck_tile/grouped_gemm/test_grouped_gemm.cpp @@ -17,7 +17,7 @@ using Col = ck_tile::tensor_layout::gemm::ColumnMajor; // clang-format off using KernelTypes = ::testing::Types< // ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType - std::tuple< Row, Row, Row, F16, F16, F32, F16>, + // std::tuple< Row, Row, Row, F16, F16, F32, F16>, //std::tuple< Col, Row, Row, F16, F16, F32, F16>, std::tuple< Row, Col, Row, F16, F16, F32, F16>//, //std::tuple< Col, Col, Row, F16, F16, F32, F16> diff --git a/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp b/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp index f532de21dc..a1b767d853 100644 --- a/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp +++ b/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp @@ -96,12 +96,9 @@ class TestCkTileGroupedGemm : public ::testing::Test CodegenGemmShape, CodegenGemmTraits>; - using CodegenGemmPolicy = ck_tile::UniversalGemmPipelineAgBgCrPolicy; - template using CodegenGemmPipeline = - ck_tile::GemmPipelineAGmemBGmemCRegV1, - CodegenGemmPolicy>; + ck_tile::GemmPipelineAGmemBGmemCRegV1>; template using Kernel = ck_tile::GroupedGemmKernel