diff --git a/example/ck_tile/03_gemm/CMakeLists.txt b/example/ck_tile/03_gemm/CMakeLists.txt index bc3799f015..30cfee22f6 100644 --- a/example/ck_tile/03_gemm/CMakeLists.txt +++ b/example/ck_tile/03_gemm/CMakeLists.txt @@ -1,2 +1,5 @@ add_executable(tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp) add_executable(tile_example_gemm_universal EXCLUDE_FROM_ALL universal_gemm.cpp) +target_compile_options(tile_example_gemm_universal PRIVATE + -mllvm -enable-noalias-to-md-conversion=0 +) diff --git a/example/ck_tile/03_gemm/gemm_basic.hpp b/example/ck_tile/03_gemm/gemm_basic.hpp index ed02f89fac..636b34981f 100644 --- a/example/ck_tile/03_gemm/gemm_basic.hpp +++ b/example/ck_tile/03_gemm/gemm_basic.hpp @@ -11,21 +11,26 @@ #include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/gemm.hpp" -#define CK_TILE_PIPELINE_COMPUTE 1 +#define CK_TILE_PIPELINE_COMPUTE_V3 1 #define CK_TILE_PIPELINE_MEMORY 2 +#define CK_TILE_PIPELINE_COMPUTE_V4 3 #ifndef CK_TILE_PIPELINE_DEFAULT -#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE +#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE_V3 #endif -#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE) +#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) #define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrMem #define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrMem #define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Interwave -#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE) +#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3) #define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV3 #define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV3 #define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave +#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4) +#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV4 +#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV4 +#define GEMM_PIPELINE_SCHEDULER 
ck_tile::GemmPipelineScheduler::Intrawave #else #error "unsupported CK_TILE_PIPELINE_DEFAULT value" #endif @@ -126,7 +131,8 @@ auto create_args(int argc, char* argv[]) .insert("warmup", "50", "number of iterations before benchmark the kernel") .insert("repeat", "100", "number of iterations to benchmark the kernel") .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") - .insert("split_k", "1", "splitK value"); + .insert("split_k", "1", "splitK value") + .insert("init", "0", "0:random, 1:linear, 2:constant(1)"); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc index 13a1c30e43..042ad372dc 100644 --- a/example/ck_tile/03_gemm/run_gemm_example.inc +++ b/example/ck_tile/03_gemm/run_gemm_example.inc @@ -110,6 +110,7 @@ int run_gemm_example_with_layouts(int argc, ck_tile::index_t kbatch = arg_parser.get_int("split_k"); int n_warmup = arg_parser.get_int("warmup"); int n_repeat = arg_parser.get_int("repeat"); + ck_tile::index_t init_method = arg_parser.get_int("init"); stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout)); stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout)); @@ -122,9 +123,19 @@ int run_gemm_example_with_layouts(int argc, ck_tile::HostTensor c_m_n_dev_result( ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); - // TODO: add different init types - ck_tile::FillUniformDistribution{-5.f, 5.f}(a_m_k); - ck_tile::FillUniformDistribution{-5.f, 5.f}(b_k_n); + if (init_method == 0) { + ck_tile::FillUniformDistribution{-1.f, 1.f}(a_m_k); + ck_tile::FillUniformDistribution{-1.f, 1.f}(b_k_n); + } else if (init_method == 1) { + ck_tile::FillMonotonicSeq{}(a_m_k); + ck_tile::FillMonotonicSeq{}(b_k_n); + } else if (init_method == 2) { + ck_tile::FillConstant{static_cast(1)}(a_m_k); + ck_tile::FillConstant{static_cast(1)}(b_k_n); + } else { + 
a_m_k.SetZero(); + b_k_n.SetZero(); + } ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes()); ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes()); diff --git a/example/ck_tile/03_gemm/script/benchmark_basic.sh b/example/ck_tile/03_gemm/script/benchmark_basic.sh index a1646da5bd..64d2ddbb5c 100755 --- a/example/ck_tile/03_gemm/script/benchmark_basic.sh +++ b/example/ck_tile/03_gemm/script/benchmark_basic.sh @@ -2,7 +2,6 @@ EXE="$(find . -name tile_example_gemm_basic -type f | head -n 1)" VALID=1 - for b_matrix_layout in "C"; do for m in "64" "512" "1024" "2048"; do for n in "512" "1024" "2048"; do diff --git a/example/ck_tile/03_gemm/universal_gemm.cpp b/example/ck_tile/03_gemm/universal_gemm.cpp index 08a9cdb24b..668d6e4201 100644 --- a/example/ck_tile/03_gemm/universal_gemm.cpp +++ b/example/ck_tile/03_gemm/universal_gemm.cpp @@ -34,8 +34,10 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& constexpr ck_tile::index_t M_Warp_Tile = 32; constexpr ck_tile::index_t N_Warp_Tile = 32; constexpr ck_tile::index_t K_Warp_Tile = 8; + + constexpr bool DoubleSmemBuffer = false; #endif -#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE) +#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3) // Compute friendly for Intrawave scheduler constexpr ck_tile::index_t M_Tile = 256; constexpr ck_tile::index_t N_Tile = 256; @@ -48,6 +50,24 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& constexpr ck_tile::index_t M_Warp_Tile = 32; constexpr ck_tile::index_t N_Warp_Tile = 32; constexpr ck_tile::index_t K_Warp_Tile = 16; + + constexpr bool DoubleSmemBuffer = false; +#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4) + // Compute friendly for Intrawave scheduler + // Using the ping pong reader in the lds level + constexpr ck_tile::index_t M_Tile = 256; + constexpr ck_tile::index_t N_Tile = 256; + constexpr ck_tile::index_t K_Tile = 32; + + constexpr 
ck_tile::index_t M_Warp = 2; + constexpr ck_tile::index_t N_Warp = 2; + constexpr ck_tile::index_t K_Warp = 1; + + constexpr ck_tile::index_t M_Warp_Tile = 32; + constexpr ck_tile::index_t N_Warp_Tile = 32; + constexpr ck_tile::index_t K_Warp_Tile = 16; + + constexpr bool DoubleSmemBuffer = true; #endif constexpr bool kPadM = false; @@ -70,8 +90,14 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& GemmSpatiallyLocalTilePartitioner; using Traits = ck_tile::TileGemmTraits; - using GemmUniversalTraits = ck_tile:: - TileGemmUniversalTraits; + using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; using GemmPipelineProblem = ck_tile::GemmPipelineProblem; @@ -99,8 +125,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& has_hot_loop_v, tail_number_v>; - using GemmPipeline = - GEMM_PIPELINE; + using GemmPipeline = GEMM_PIPELINE; using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem{}, @@ -215,6 +240,17 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& ck_tile::integral_constant{}); } } +#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4) + if(tail_num == ck_tile::TailNumber::Three) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } #endif } else diff --git a/include/ck_tile/core/utility/transpose_vectors.hpp b/include/ck_tile/core/utility/transpose_vectors.hpp index a164c3f946..497fd3b948 100644 --- a/include/ck_tile/core/utility/transpose_vectors.hpp +++ b/include/ck_tile/core/utility/transpose_vectors.hpp @@ -68,52 +68,82 @@ struct transpose_vectors } else if constexpr(sizeof(S) == 1) { - static_assert((NX % 4 == 0 && NY % 4 == 0), "wrong!"); + static_assert(((NX % 4 == 0 && NY % 4 == 0) || (NX % 2 == 0 && NY % 2 == 0)), "wrong!"); using S4 = array; // typename array::type; + using S2 = array; // typename array::type; - // 
loop over 4x4 tile and transpose data from vx_tuple into vy_tuple - static_for<0, NY, 4>{}([&](auto iy) { - static_for<0, NX, 4>{}([&](auto ix) { - // 4 int8x4 data from vx_tuple - const int32_t x_s4_0 = - bit_cast(vx_tuple[ix].template get_as()[iy / I4]); - const int32_t x_s4_1 = - bit_cast(vx_tuple[ix + I1].template get_as()[iy / I4]); - const int32_t x_s4_2 = - bit_cast(vx_tuple[ix + I2].template get_as()[iy / I4]); - const int32_t x_s4_3 = - bit_cast(vx_tuple[ix + I3].template get_as()[iy / I4]); + if constexpr(NX % 4 == 0 && NY % 4 == 0) + { + // loop over 4x4 tile and transpose data from vx_tuple into vy_tuple + static_for<0, NY, 4>{}([&](auto iy) { + static_for<0, NX, 4>{}([&](auto ix) { + // 4 int8x4 data from vx_tuple + const int32_t x_s4_0 = + bit_cast(vx_tuple[ix].template get_as()[iy / I4]); + const int32_t x_s4_1 = + bit_cast(vx_tuple[ix + I1].template get_as()[iy / I4]); + const int32_t x_s4_2 = + bit_cast(vx_tuple[ix + I2].template get_as()[iy / I4]); + const int32_t x_s4_3 = + bit_cast(vx_tuple[ix + I3].template get_as()[iy / I4]); - // transpose - int32_t t_s4_0, t_s4_1; - int32_t y_s4_0, y_s4_1, y_s4_2, y_s4_3; + // transpose + int32_t t_s4_0, t_s4_1; + int32_t y_s4_0, y_s4_1, y_s4_2, y_s4_3; - constexpr int32_t m0 = 0x05010400; - constexpr int32_t m1 = 0x05040100; - constexpr int32_t m2 = 0x07060302; - constexpr int32_t m3 = 0x07030602; + constexpr int32_t m0 = 0x05010400; + constexpr int32_t m1 = 0x05040100; + constexpr int32_t m2 = 0x07060302; + constexpr int32_t m3 = 0x07030602; - // ex: v_perm_b32(0x 11 22 33 44, 0x 55 66 77 88, 0x 05 01 04 00) -> 0x33774488 - // -- -- -- -- -- -- -- -- - - - - - // index 7 6 5 4 3 2 1 0 33 77 44 88 - // index is reversed because of little endianness (least significant bits first) - t_s4_0 = __builtin_amdgcn_perm(x_s4_1, x_s4_0, m0); - t_s4_1 = __builtin_amdgcn_perm(x_s4_3, x_s4_2, m0); - y_s4_0 = __builtin_amdgcn_perm(t_s4_1, t_s4_0, m1); - y_s4_1 = __builtin_amdgcn_perm(t_s4_1, t_s4_0, m2); - t_s4_0 = 
__builtin_amdgcn_perm(x_s4_1, x_s4_0, m3); - t_s4_1 = __builtin_amdgcn_perm(x_s4_3, x_s4_2, m3); - y_s4_2 = __builtin_amdgcn_perm(t_s4_1, t_s4_0, m1); - y_s4_3 = __builtin_amdgcn_perm(t_s4_1, t_s4_0, m2); + // ex: v_perm_b32(0x 11 22 33 44, 0x 55 66 77 88, 0x 05 01 04 00) -> + // 0x33774488 + // -- -- -- -- -- -- -- -- - - - - + // index 7 6 5 4 3 2 1 0 33 77 44 88 + // index is reversed because of little endianness (least significant bits + // first) + t_s4_0 = __builtin_amdgcn_perm(x_s4_1, x_s4_0, m0); + t_s4_1 = __builtin_amdgcn_perm(x_s4_3, x_s4_2, m0); + y_s4_0 = __builtin_amdgcn_perm(t_s4_1, t_s4_0, m1); + y_s4_1 = __builtin_amdgcn_perm(t_s4_1, t_s4_0, m2); + t_s4_0 = __builtin_amdgcn_perm(x_s4_1, x_s4_0, m3); + t_s4_1 = __builtin_amdgcn_perm(x_s4_3, x_s4_2, m3); + y_s4_2 = __builtin_amdgcn_perm(t_s4_1, t_s4_0, m1); + y_s4_3 = __builtin_amdgcn_perm(t_s4_1, t_s4_0, m2); - // 4 int8x4 data from vy_tuple - vy_tuple(iy).template get_as()(ix / I4) = bit_cast(y_s4_0); - vy_tuple(iy + I1).template get_as()(ix / I4) = bit_cast(y_s4_1); - vy_tuple(iy + I2).template get_as()(ix / I4) = bit_cast(y_s4_2); - vy_tuple(iy + I3).template get_as()(ix / I4) = bit_cast(y_s4_3); + // 4 int8x4 data from vy_tuple + vy_tuple(iy).template get_as()(ix / I4) = bit_cast(y_s4_0); + vy_tuple(iy + I1).template get_as()(ix / I4) = bit_cast(y_s4_1); + vy_tuple(iy + I2).template get_as()(ix / I4) = bit_cast(y_s4_2); + vy_tuple(iy + I3).template get_as()(ix / I4) = bit_cast(y_s4_3); + }); }); - }); + } + else if constexpr(NX % 2 == 0 && NY % 2 == 0) + { + static_for<0, NY, 2>{}([&](auto ix) { + static_for<0, NX, 2>{}([&](auto iy) { + const int16_t x_s2_0 = + bit_cast(vx_tuple[ix].template get_as()[iy / I2]); + const int16_t x_s2_1 = + bit_cast(vx_tuple[ix + I1].template get_as()[iy / I2]); + constexpr int32_t m0 = 0x05040100; + constexpr int32_t m1 = 0x07060302; + + const int32_t x0_32 = static_cast(x_s2_0 & 0xFFFF); + const int32_t x1_32 = static_cast(x_s2_1 & 0xFFFF); + + const int32_t 
y_s2_0 = __builtin_amdgcn_perm(x1_32, x0_32, m0); + const int32_t y_s2_1 = __builtin_amdgcn_perm(x1_32, x0_32, m1); + + vy_tuple(iy).template get_as()[ix / I2] = + bit_cast(static_cast(y_s2_0 & 0xFFFF)); + vy_tuple(iy + I1).template get_as()[ix / I2] = + bit_cast(static_cast(y_s2_1 & 0xFFFF)); + }); + }); + } } else { diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index a94628a59a..794f7f21f2 100644 --- a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -29,6 +29,8 @@ #include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp" diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp index 521f236ab7..b4362d9069 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp @@ -14,24 +14,54 @@ namespace ck_tile { template struct BlockGemmARegBRegCRegV1 { - using Problem = remove_cvref_t; - using Policy = remove_cvref_t; - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using CDataType = remove_cvref_t; - using BlockGemmShape = remove_cvref_t; + private: + template + struct GemmTraits_ + { + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; - static constexpr 
index_t kBlockSize = Problem::kBlockSize; - static constexpr index_t MPerBlock = BlockGemmShape::kM; - static constexpr index_t NPerBlock = BlockGemmShape::kN; - static constexpr index_t KPerBlock = BlockGemmShape::kK; - static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); - using WG = remove_cvref_t())>; - static constexpr index_t MWarp = config.template at<1>(); - static constexpr index_t NWarp = config.template at<2>(); - static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); - static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); - static constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr index_t MPerBlock = BlockGemmShape::kM; + static constexpr index_t NPerBlock = BlockGemmShape::kN; + static constexpr index_t KPerBlock = BlockGemmShape::kK; + + static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + using WarpGemm = remove_cvref_t())>; + + static constexpr index_t MWarp = config.template at<1>(); + static constexpr index_t NWarp = config.template at<2>(); + static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM); + static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN); + static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK; + + static constexpr index_t KPack = WarpGemm::kKPerThread; + }; + + public: + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + + using Traits = GemmTraits_; + + using WarpGemm = typename Traits::WarpGemm; + using BlockGemmShape = typename Traits::BlockGemmShape; + + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + + static constexpr index_t KIterPerWarp = Traits::KIterPerWarp; + static constexpr index_t MIterPerWarp = Traits::MIterPerWarp; + static constexpr index_t NIterPerWarp = Traits::NIterPerWarp; + + static constexpr index_t MWarp = Traits::MWarp; + 
static constexpr index_t NWarp = Traits::NWarp; CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode() { @@ -43,7 +73,7 @@ struct BlockGemmARegBRegCRegV1 sequence<1, 2>, sequence<0, 0>>{}; constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{}); + a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); return a_block_dstr_encode; } @@ -58,7 +88,7 @@ struct BlockGemmARegBRegCRegV1 sequence<1, 2>, sequence<0, 0>>{}; constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - b_block_outer_dstr_encoding, typename WG::BWarpDstrEncoding{}); + b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); return b_block_dstr_encode; } @@ -73,7 +103,7 @@ struct BlockGemmARegBRegCRegV1 sequence<1, 2>, sequence<0, 0>>{}; constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); + c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); return c_block_dstr_encode; } @@ -112,13 +142,13 @@ struct BlockGemmARegBRegCRegV1 .get_static_tile_distribution_encoding())>>, "C distribution is wrong!"); - using AWarpDstr = typename WG::AWarpDstr; - using BWarpDstr = typename WG::BWarpDstr; - using CWarpDstr = typename WG::CWarpDstr; + using AWarpDstr = typename WarpGemm::AWarpDstr; + using BWarpDstr = typename WarpGemm::BWarpDstr; + using CWarpDstr = typename WarpGemm::CWarpDstr; - using AWarpTensor = typename WG::AWarpTensor; - using BWarpTensor = typename WG::BWarpTensor; - using CWarpTensor = typename WG::CWarpTensor; + using AWarpTensor = typename WarpGemm::AWarpTensor; + using BWarpTensor = typename WarpGemm::BWarpTensor; + using CWarpTensor = typename WarpGemm::CWarpTensor; constexpr auto a_warp_y_lengths = to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); @@ -157,7 +187,7 @@ struct BlockGemmARegBRegCRegV1 
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); // warp GEMM - WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); // write C warp tensor into C block tensor c_block_tensor.set_y_sliced_thread_data( @@ -180,7 +210,7 @@ struct BlockGemmARegBRegCRegV1 sequence<0, 0>>{}; constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); + c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); return c_block_tensor; diff --git a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp index 4ed3006c89..3107d07bc9 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp @@ -463,7 +463,9 @@ struct GemmKernel * @param a_ptr input A pointer * @param b_ptr input B pointer * @param c_ptr output C pointer + * @param smem_ptr_0 The start memory pointer of the shared memory block. * @param kargs GEMM kernel arguments + * @param splitk_batch_offset splitk_batch_offset Utility structure used to calculate k batch. * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup. * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup. * @@ -473,7 +475,7 @@ struct GemmKernel CK_TILE_DEVICE static void RunGemm(const ADataType* a_ptr, const BDataType* b_ptr, CDataType* c_ptr, - void* smem_ptr, + void* smem_ptr_0, const GemmKernelArgs& kargs, const SplitKBatchOffset& splitk_batch_offset, const index_t block_idx_m, @@ -491,15 +493,67 @@ struct GemmKernel // Run GEMM cooperatively by whole workgroup. 
const auto& a_block_window = gemm_tile_windows.at(I0); const auto& b_block_window = gemm_tile_windows.at(I1); - const auto& c_block_tile = - GemmPipeline{}.template operator()(a_block_window, b_block_window, num_loop, smem_ptr); + + const auto& c_block_tile = GemmPipeline{}.template operator()( + a_block_window, b_block_window, num_loop, smem_ptr_0); // Run Epilogue Pipeline auto& c_block_window = gemm_tile_windows.at(I2); EpiloguePipeline{} .template operator()( - c_block_window, c_block_tile, smem_ptr); + c_block_window, c_block_tile, smem_ptr_0); + } + + /** + * @brief Runs single GEMM problem cooperatively by whole workgroup. + * + * @note RunGEMM2LDS in with two shared memory buffers using the ping pong buffer mechanism. + * + * @param a_ptr input A pointer + * @param b_ptr input B pointer + * @param c_ptr output C pointer + * @param smem_ptr_0 The starting pointer of 1st shared memory block. + * @param smem_ptr_1 The starting pointer of 2nd shared memory block. + * @param kargs GEMM kernel arguments + * @param splitk_batch_offset Utility structure used to calculate k batch. + * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup. + * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup. + * + * @tparam DstInMemOp Destination memory operation (default: set). 
+ */ + template + CK_TILE_DEVICE static void RunGemm2LDS(const ADataType* a_ptr, + const BDataType* b_ptr, + CDataType* c_ptr, + void* __restrict__ smem_ptr_0, + void* __restrict__ smem_ptr_1, + const GemmKernelArgs& kargs, + const SplitKBatchOffset& splitk_batch_offset, + const index_t block_idx_m, + const index_t block_idx_n) + { + // Create Gemm tensor views, pad views and tile windows + const auto& gemm_tensor_views_tuple = + MakeGemmTensorViews(a_ptr, b_ptr, c_ptr, kargs, splitk_batch_offset); + const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); + auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); + + const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k); + + // Run GEMM cooperatively by whole workgroup. + const auto& a_block_window = gemm_tile_windows.at(I0); + const auto& b_block_window = gemm_tile_windows.at(I1); + + const auto& c_block_tile = GemmPipeline{}.template operator()( + a_block_window, b_block_window, num_loop, smem_ptr_0, smem_ptr_1); + + // Run Epilogue Pipeline + auto& c_block_window = gemm_tile_windows.at(I2); + + EpiloguePipeline{} + .template operator()( + c_block_window, c_block_tile, smem_ptr_0); } CK_TILE_DEVICE void operator()(GemmKernelArgs kargs) const @@ -517,11 +571,27 @@ struct GemmKernel CDataType* c_ptr = static_cast(kargs.c_ptr); // allocate LDS - __shared__ char smem_ptr[GetSmemSize()]; + __shared__ char smem_ptr_0[GetSmemSize()]; + __shared__ char smem_ptr_1[GetSmemSize()]; if(kargs.k_batch == 1) { - RunGemm(a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n); + if constexpr(GemmPipeline::DoubleSmemBuffer == true) + { + RunGemm2LDS(a_ptr, + b_ptr, + c_ptr, + smem_ptr_0, + smem_ptr_1, + kargs, + splitk_batch_offset, + i_m, + i_n); + } + else + { + RunGemm(a_ptr, b_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n); + } } else { @@ -530,8 +600,23 @@ struct GemmKernel if constexpr(!(EpiloguePipeline::template 
GetVectorSizeC() % 2 != 0 && is_any_of::value)) { - RunGemm( - a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n); + if constexpr(GemmPipeline::DoubleSmemBuffer == true) + { + RunGemm2LDS(a_ptr, + b_ptr, + c_ptr, + smem_ptr_0, + smem_ptr_1, + kargs, + splitk_batch_offset, + i_m, + i_n); + } + else + { + RunGemm( + a_ptr, b_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n); + } } } } diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp index c08fe45465..4855df0e0e 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp @@ -41,20 +41,26 @@ struct GemmPipelineAgBgCrImplBase store_tile(lds_tile_window, block_tile_tmp); } + template + CK_TILE_DEVICE void LocalPrefetch(DstBlockTile& dst_block_tile, + const SrcTileWindow& lds_tile_window) const + { + load_tile(dst_block_tile, lds_tile_window); + } + CK_TILE_DEVICE auto GetABLdsTensorViews(void* p_smem) const { // A tile in LDS - ADataType* p_a_lds = static_cast(p_smem); + ADataType* __restrict__ p_a_lds = static_cast(p_smem); constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor(); auto a_lds_block = make_tensor_view(p_a_lds, a_lds_block_desc); // TODO: LDS alignment should come from Policy! 
- constexpr index_t a_lds_block_space_size_aligned = - integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16) * - 16; + constexpr index_t a_lds_block_space_size_aligned = integer_least_multiple( + sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16); // B tile in LDS - BDataType* p_b_lds = static_cast( + BDataType* __restrict__ p_b_lds = static_cast( static_cast(static_cast(p_smem) + a_lds_block_space_size_aligned)); constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor(); auto b_lds_block = make_tensor_view(p_b_lds, b_lds_block_desc); diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp index eec3886e2f..69c50c7cd0 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp @@ -76,6 +76,8 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 static constexpr bool kPadN = Problem::kPadN; static constexpr bool kPadK = Problem::kPadK; + static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer; + static constexpr bool HasHotLoop = Problem::HasHotLoop; static constexpr auto TailNum = Problem::TailNum; static constexpr auto Scheduler = Problem::Scheduler; diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp new file mode 100644 index 0000000000..ea8d063fd5 --- /dev/null +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp @@ -0,0 +1,559 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+#pragma once +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp" + +namespace ck_tile { + +// A Tile Window: global memory +// B Tile Window: global memory +// C Distributed tensor: register +template +struct BaseGemmPipelineAgBgCrCompV4 +{ + static constexpr index_t PrefetchStages = 2; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 1; + + CK_TILE_HOST static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + CK_TILE_HOST static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop) + { + if(num_loop % PrefetchStages == 1) + { + return TailNumber::Three; + } + else + { + return TailNumber::Two; + } + } +}; + +/** + * @brief Compute optimized pipeline version 4 + * + * This version introduces a dual LDS window mechanism using a ping-pong buffer approach + * for more efficient data handling from global memory. Unlike compute version 3, this method + * allows one LDS to fetch data from global memory while the other LDS executes warps for MFMA + * matrix multiplication. This dual operation helps in keeping the Warp unit continuously busy, + * thereby significantly reducing memory load times and enhancing overall performance. + * + * @note This version shows improved performance over Compute Version 3 with the same block tile. + * It is particularly more efficient for large matrices where M, N, and K are greater than 8K, + * even when Compute Version 3's block size is twice that of Compute Version 4. 
+ */ +template +struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 +{ + using Base = BaseGemmPipelineAgBgCrCompV4; + using PipelineImplBase = GemmPipelineAgBgCrImplBase; + + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + + using ALayout = remove_cvref_t; + using BLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + + using BlockGemm = remove_cvref_t())>; + using I0 = number<0>; + using I1 = number<1>; + using I2 = number<2>; + + static constexpr index_t BlockSize = Problem::kBlockSize; + + static constexpr index_t MPerBlock = BlockGemmShape::kM; + static constexpr index_t NPerBlock = BlockGemmShape::kN; + static constexpr index_t KPerBlock = BlockGemmShape::kK; + + static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA(); } + static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB(); } + static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC(); } + + static constexpr bool kPadM = Problem::kPadM; + static constexpr bool kPadN = Problem::kPadN; + static constexpr bool kPadK = Problem::kPadK; + + static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer; + + static constexpr bool HasHotLoop = Problem::HasHotLoop; + static constexpr auto TailNum = Problem::TailNum; + static constexpr auto Scheduler = Problem::Scheduler; + + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() + { + return Policy::template IsTransposeC(); + } + + template + struct PipelineImpl : public PipelineImplBase + { + }; + + template <> + struct PipelineImpl : public PipelineImplBase + { + using Base = PipelineImplBase; + + CK_TILE_DEVICE static constexpr auto HotLoopScheduler() + { + constexpr index_t MPerXDL = BlockGemmShape::WarpTile::at(I0{}); + constexpr 
index_t NPerXDL = BlockGemmShape::WarpTile::at(I1{}); + constexpr index_t KPerXDL = BlockGemmShape::WarpTile::at(I2{}); + + constexpr index_t WaveSize = 64; + constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{}); + constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{}); + + constexpr index_t A_LDS_Read_Width = KPerXDL; + constexpr index_t B_LDS_Read_Width = KPerXDL; + + constexpr index_t A_Buffer_Load_Inst_Num = + MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA()); + constexpr index_t B_Buffer_Load_Inst_Num = + NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB()); + + constexpr index_t A_LDS_Write_Inst_Num = MPerBlock * KPerBlock / (BlockSize * KPerXDL); + constexpr index_t B_LDS_Write_Inst_Num = NPerBlock * KPerBlock / (BlockSize * KPerXDL); + + constexpr index_t A_LDS_Read_Inst_Num = + WaveNumN * MPerBlock * KPerBlock / (BlockSize * KPerXDL); + constexpr index_t B_LDS_Read_Inst_Num = + WaveNumM * MPerBlock * KPerBlock / (BlockSize * KPerXDL); + + constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock / + (BlockSize / WaveSize) / + (MPerXDL * NPerXDL * KPerXDL); + + constexpr auto num_ds_read_inst_a = A_LDS_Read_Width * sizeof(ADataType) == 16 + ? A_LDS_Read_Inst_Num + : A_LDS_Read_Inst_Num / 2; + constexpr auto num_ds_read_inst_b = B_LDS_Read_Width * sizeof(BDataType) == 16 + ? 
B_LDS_Read_Inst_Num + : B_LDS_Read_Inst_Num / 2; + + constexpr auto num_ds_read_inst = num_ds_read_inst_a + num_ds_read_inst_b; + constexpr auto num_ds_write_inst = A_LDS_Write_Inst_Num + B_LDS_Write_Inst_Num; + constexpr auto num_buffer_load_inst = A_Buffer_Load_Inst_Num + B_Buffer_Load_Inst_Num; + constexpr auto num_issue = num_buffer_load_inst; + + static_for<0, num_buffer_load_inst, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA : 1 + __builtin_amdgcn_sched_group_barrier( + 0x100, num_ds_read_inst / num_issue, 0); // DS read : 2 + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA: 1 + __builtin_amdgcn_sched_group_barrier( + 0x200, num_ds_write_inst / num_issue, 0); // DS write : 1 + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA : 1 + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read :1 + __builtin_amdgcn_sched_group_barrier( + 0x008, C_MFMA_Inst_Num / num_issue - 3, 0); // MFMA : 5 + }); + __builtin_amdgcn_sched_barrier(0); + } + + template + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + index_t num_loop, + void* __restrict__ p_smem_0, + void* __restrict__ p_smem_1) const + { + static_assert( + std::is_same_v> && + std::is_same_v>, + "Data Type conflict on A and B matrix input data type."); + + constexpr bool is_a_col_major = + std::is_same_v; + constexpr bool is_b_row_major = std::is_same_v; + + static_assert(is_a_col_major + ? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && + MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]) + : (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && + KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]), + "A block window has incorrect lengths for defined ALayout!"); + static_assert(is_b_row_major + ? 
(KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] && + NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]) + : (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] && + KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]), + "B block window has incorrect lengths for defined BLayout!"); + + ////////////// global window & register ///////////////// + // A DRAM tile window for load + auto a_copy_dram_window = + make_tile_window_linear(a_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + a_dram_block_window_tmp.get_window_origin(), + Policy::template MakeADramTileDistribution()); + + // B DRAM tile window for load + auto b_copy_dram_window = + make_tile_window_linear(b_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + b_dram_block_window_tmp.get_window_origin(), + Policy::template MakeBDramTileDistribution()); + + // A register tile for global load + constexpr auto ABlockTileDistr = a_copy_dram_window.get_tile_distribution(); + constexpr auto BBlockTileDistr = b_copy_dram_window.get_tile_distribution(); + using ABlockTile = decltype(make_static_distributed_tensor(ABlockTileDistr)); + using BBlockTile = decltype(make_static_distributed_tensor(BBlockTileDistr)); + ABlockTile a_global_load_tile; + BBlockTile b_global_load_tile; + + using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex; + using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex; + + constexpr ADramTileWindowStep a_dram_tile_window_step = + is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock); + constexpr BDramTileWindowStep b_dram_tile_window_step = + is_b_row_major ? 
make_array(KPerBlock, 0) : make_array(0, KPerBlock); + + // global prefetch 0 + // global read 0 + Base::GlobalPrefetch(a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step); + Base::GlobalPrefetch(b_global_load_tile, b_copy_dram_window, b_dram_tile_window_step); + ////////////// LDS desc, window & register ///////////////// + auto&& [a_lds_block0, b_lds_block0] = Base::GetABLdsTensorViews(p_smem_0); + auto&& [a_lds_block1, b_lds_block1] = Base::GetABLdsTensorViews(p_smem_1); + + auto a_copy_lds_window0 = make_tile_window( + a_lds_block0, make_tuple(number{}, number{}), {0, 0}); + + auto a_copy_lds_window1 = make_tile_window( + a_lds_block1, make_tuple(number{}, number{}), {0, 0}); + + auto b_copy_lds_window0 = make_tile_window( + b_lds_block0, make_tuple(number{}, number{}), {0, 0}); + + auto b_copy_lds_window1 = make_tile_window( + b_lds_block1, make_tuple(number{}, number{}), {0, 0}); + + // Block GEMM + auto block_gemm = BlockGemm(); + auto c_block_tile = block_gemm.MakeCBlockTile(); + + // initialize C + tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); + + // LDS write 0 + if constexpr(is_a_col_major) + { + auto a_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledARegTileDistribution()); + transpose_tile2d(a_shuffle_tmp, a_global_load_tile); + Base::LocalPrefill(a_copy_lds_window0, a_shuffle_tmp, a_element_func); + } + else + { + Base::LocalPrefill(a_copy_lds_window0, a_global_load_tile, a_element_func); + } + if constexpr(is_b_row_major) + { + auto b_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledBRegTileDistribution()); + transpose_tile2d(b_shuffle_tmp, b_global_load_tile); + Base::LocalPrefill(b_copy_lds_window0, b_shuffle_tmp, b_element_func); + } + else + { + Base::LocalPrefill(b_copy_lds_window0, b_global_load_tile, b_element_func); + } + + // global read 1 + Base::GlobalPrefetch(a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step); + 
Base::GlobalPrefetch(b_global_load_tile, b_copy_dram_window, b_dram_tile_window_step); + + block_sync_lds(); + + constexpr auto ALdsTileDistr = decltype(make_static_tile_distribution( + BlockGemm::MakeABlockDistributionEncode())){}; + constexpr auto BLdsTileDistr = decltype(make_static_tile_distribution( + BlockGemm::MakeBBlockDistributionEncode())){}; + + using ALdsTile = decltype(make_static_distributed_tensor(ALdsTileDistr)); + using BLdsTile = decltype(make_static_distributed_tensor(BLdsTileDistr)); + + ALdsTile a_block_tile0; + ALdsTile a_block_tile1; + + BLdsTile b_block_tile0; + BLdsTile b_block_tile1; + + auto a_lds_ld_window0 = + make_tile_window_linear(a_lds_block0, + make_tuple(number{}, number{}), + {0, 0}, + ALdsTileDistr); + auto a_lds_ld_window1 = + make_tile_window_linear(a_lds_block1, + make_tuple(number{}, number{}), + {0, 0}, + ALdsTileDistr); + auto b_lds_ld_window0 = + make_tile_window_linear(b_lds_block0, + make_tuple(number{}, number{}), + {0, 0}, + BLdsTileDistr); + auto b_lds_ld_window1 = + make_tile_window_linear(b_lds_block1, + make_tuple(number{}, number{}), + {0, 0}, + BLdsTileDistr); + + Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0); + Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0); + + if constexpr(is_a_col_major) + { + auto a_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledARegTileDistribution()); + transpose_tile2d(a_shuffle_tmp, a_global_load_tile); + Base::LocalPrefill(a_copy_lds_window1, a_shuffle_tmp, a_element_func); + } + else + { + Base::LocalPrefill(a_copy_lds_window1, a_global_load_tile, a_element_func); + } + if constexpr(is_b_row_major) + { + auto b_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledBRegTileDistribution()); + transpose_tile2d(b_shuffle_tmp, b_global_load_tile); + Base::LocalPrefill(b_copy_lds_window1, b_shuffle_tmp, b_element_func); + } + else + { + Base::LocalPrefill(b_copy_lds_window1, b_global_load_tile, b_element_func); + } + + 
Base::GlobalPrefetch(a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step); + Base::GlobalPrefetch(b_global_load_tile, b_copy_dram_window, b_dram_tile_window_step); + + if(HasHotLoop) + { + // minus 2 because we have ping-pong double buffer. + index_t iCounter = __builtin_amdgcn_readfirstlane(num_loop - 2); + do + { + // ping + { + block_sync_lds(); + Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1); + Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1); + + if constexpr(is_a_col_major) + { + auto a_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledARegTileDistribution()); + transpose_tile2d(a_shuffle_tmp, a_global_load_tile); + Base::LocalPrefill(a_copy_lds_window0, a_shuffle_tmp, a_element_func); + } + else + { + Base::LocalPrefill( + a_copy_lds_window0, a_global_load_tile, a_element_func); + } + if constexpr(is_b_row_major) + { + auto b_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledBRegTileDistribution()); + transpose_tile2d(b_shuffle_tmp, b_global_load_tile); + Base::LocalPrefill(b_copy_lds_window0, b_shuffle_tmp, b_element_func); + } + else + { + Base::LocalPrefill( + b_copy_lds_window0, b_global_load_tile, b_element_func); + } + + Base::GlobalPrefetch( + a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step); + Base::GlobalPrefetch( + b_global_load_tile, b_copy_dram_window, b_dram_tile_window_step); + // gemm + block_gemm(c_block_tile, a_block_tile0, b_block_tile0); + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + } + // pong + { + block_sync_lds(); + Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0); + Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0); + + if constexpr(is_a_col_major) + { + auto a_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledARegTileDistribution()); + transpose_tile2d(a_shuffle_tmp, a_global_load_tile); + Base::LocalPrefill(a_copy_lds_window1, a_shuffle_tmp, a_element_func); + } + else + { + 
Base::LocalPrefill( + a_copy_lds_window1, a_global_load_tile, a_element_func); + } + if constexpr(is_b_row_major) + { + auto b_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledBRegTileDistribution()); + transpose_tile2d(b_shuffle_tmp, b_global_load_tile); + Base::LocalPrefill(b_copy_lds_window1, b_shuffle_tmp, b_element_func); + } + else + { + Base::LocalPrefill( + b_copy_lds_window1, b_global_load_tile, b_element_func); + } + + Base::GlobalPrefetch( + a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step); + Base::GlobalPrefetch( + b_global_load_tile, b_copy_dram_window, b_dram_tile_window_step); + // gemm + block_gemm(c_block_tile, a_block_tile1, b_block_tile1); + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + } + iCounter -= 2; + } while(iCounter > 1); + } + + // tail 3 + if(TailNum == TailNumber::Three) + { + // 3 + { + block_sync_lds(); + Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1); + Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1); + if constexpr(is_a_col_major) + { + auto a_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledARegTileDistribution()); + transpose_tile2d(a_shuffle_tmp, a_global_load_tile); + Base::LocalPrefill(a_copy_lds_window0, a_shuffle_tmp, a_element_func); + } + else + { + Base::LocalPrefill(a_copy_lds_window0, a_global_load_tile, a_element_func); + } + if constexpr(is_b_row_major) + { + auto b_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledBRegTileDistribution()); + transpose_tile2d(b_shuffle_tmp, b_global_load_tile); + Base::LocalPrefill(b_copy_lds_window0, b_shuffle_tmp, b_element_func); + } + else + { + Base::LocalPrefill(b_copy_lds_window0, b_global_load_tile, b_element_func); + } + block_gemm(c_block_tile, a_block_tile0, b_block_tile0); + } + // 2 + { + block_sync_lds(); + Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0); + Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0); + block_gemm(c_block_tile,
a_block_tile1, b_block_tile1); + } + // 1 + { + block_gemm(c_block_tile, a_block_tile0, b_block_tile0); + __builtin_amdgcn_sched_barrier(0); + } + } + else + { + // 2 + { + block_sync_lds(); + Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1); + Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1); + block_gemm(c_block_tile, a_block_tile0, b_block_tile0); + static_for<0, 8, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x008, 8, 0); // MFMA + }); + __builtin_amdgcn_sched_barrier(0); + } + // 1 + { + block_gemm(c_block_tile, a_block_tile1, b_block_tile1); + __builtin_amdgcn_sched_barrier(0); + } + } + return c_block_tile; + } + }; + + template + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + index_t num_loop, + void* p_smem_0, + void* p_smem_1) const + { + return PipelineImpl{}.template operator()( + a_dram_block_window_tmp, + a_element_func, + b_dram_block_window_tmp, + b_element_func, + num_loop, + p_smem_0, + p_smem_1); + } + + public: + template + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + const index_t num_loop, + void* __restrict__ p_smem_0, + void* __restrict__ p_smem_1) const + { + return PipelineImpl{}.template operator()( + a_dram_block_window_tmp, + [](const ADataType& a) { return a; }, + b_dram_block_window_tmp, + [](const BDataType& b) { return b; }, + num_loop, + p_smem_0, + p_smem_1); + } +}; +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp new file mode 100644 index 0000000000..e528847438 --- /dev/null +++ 
b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp @@ -0,0 +1,92 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" + +namespace ck_tile { +// Default policy for GemmPipelineAGmemBGmemCregComputeV4, except the block gemm method, it shares +// the same vector size implementation, SmemSize, Global memory tile distiribution as the +// UniversalGemm Pipeline Policy. +// Default policy class should not be templated, put template on +// member functions instead. +struct GemmPipelineAgBgCrCompV4DefaultPolicy + : public UniversalGemmBasePolicy +{ + template + CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() + { + using namespace ck_tile; + + constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t KPack = GetSmemPackA(); + + constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, number{}, number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto a_lds_block_desc = transform_tensor_descriptor( + a_lds_block_desc_0, + make_tuple( + make_pass_through_transform(number{}), + make_merge_transform(make_tuple(number{} / KPack, number{}))), + make_tuple(sequence<1>{}, sequence<0, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return a_lds_block_desc; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor() + { + constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t KPack = GetSmemPackB(); + + constexpr auto b_lds_block_desc_0 = 
make_naive_tensor_descriptor( + make_tuple(number{}, number{}, number{}), + make_tuple(number<(kNPerBlock)*KPack>{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto b_lds_block_desc = transform_tensor_descriptor( + b_lds_block_desc_0, + make_tuple( + make_pass_through_transform(number{}), + make_merge_transform(make_tuple(number{}, number{}))), + make_tuple(sequence<1>{}, sequence<0, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return b_lds_block_desc; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() + { + using AccDataType = float; + using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; + using WarpTile = typename Problem::BlockGemmShape::WarpTile; + using WarpGemm = WarpGemmMfmaDispatcher; + using BlockGemmPolicy = BlockGemmARegBRegCRegV1CustomPolicy; + + return BlockGemmARegBRegCRegV1{}; + } +}; +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp index f8dd2348cb..cde31f087b 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp @@ -124,6 +124,8 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem static constexpr bool kPadN = Problem::kPadN; static constexpr bool kPadK = Problem::kPadK; + static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer; + // Where is the right place for HasHotLoop and TailNum ??? 
static constexpr bool HasHotLoop = Problem::HasHotLoop; static constexpr auto TailNum = Problem::TailNum; diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp index a2a14d1017..33945651ae 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp @@ -52,6 +52,9 @@ struct GemmPipelineAGmemBGmemCRegV1 // clang-format on } + // For the basic gemm pipeline, DoubleSmemBuffer is naturally set to false. + static constexpr bool DoubleSmemBuffer = false; + CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; } CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp index d7fa1c0c61..2d9f95627c 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp @@ -338,7 +338,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy { using ALayout = remove_cvref_t; using ADataType = remove_cvref_t; - static_assert(std::is_same_v); + static_assert(std::is_same_v); constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp index dd631876b4..771662f566 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp @@ -36,6 +36,8 @@ struct GemmPipelineProblemBase static constexpr bool kPadN = Traits::kPadN; static
constexpr bool kPadK = Traits::kPadK; + static constexpr bool DoubleSmemBuffer = Traits::DoubleSmemBuffer; + static constexpr auto Scheduler = GemmPipelineScheduler::Default; static constexpr index_t VectorLoadSize = Traits::_VectorSize; @@ -173,6 +175,8 @@ struct UniversalGemmPipelineProblem static constexpr bool kPadN = Traits::kPadN; static constexpr bool kPadK = Traits::kPadK; + static constexpr bool DoubleSmemBuffer = Traits::DoubleSmemBuffer; + static constexpr auto Scheduler = Scheduler_; static constexpr auto HasHotLoop = HasHotLoop_; static constexpr auto TailNum = TailNum_; diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp index 2a9683b36e..c20d09cea4 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp @@ -9,8 +9,8 @@ namespace ck_tile { -// UniversalGemm Policy -struct UniversalGemmPipelineAgBgCrPolicy +template +struct UniversalGemmBasePolicy { static constexpr auto I0 = number<0>{}; static constexpr auto I1 = number<1>{}; @@ -113,7 +113,7 @@ struct UniversalGemmPipelineAgBgCrPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC() { - using BlockGemm = remove_cvref_t())>; + using BlockGemm = remove_cvref_t())>; using WG = typename BlockGemm::WarpGemm; constexpr bool TransposeC = Problem::TransposeC; @@ -166,10 +166,116 @@ struct UniversalGemmPipelineAgBgCrPolicy } } + template + CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() + { + return Problem::TransposeC; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution() + { + using ALayout = remove_cvref_t; + + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t 
VecLoadSize = GetVectorSizeA(); + + // Tile: MPerBlock X KPerBlock + if constexpr(std::is_same_v) + { + using TileEncodingPattern = TileDistributionEncodingPattern2D; + return TileEncodingPattern::Make2DStaticTileDistribution(); + } + // Tile: KPerBlock X MPerBlock + else + { + using TileEncodingPattern = TileDistributionEncodingPattern2D; + return TileEncodingPattern::Make2DStaticTileDistribution(); + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution() + { + using BLayout = remove_cvref_t; + + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t VecLoadSize = GetVectorSizeB(); + + // Tile: KPerBlock X NPerBlock + if constexpr(std::is_same_v) + { + using TileEncodingPattern = TileDistributionEncodingPattern2D; + return TileEncodingPattern::Make2DStaticTileDistribution(); + } + // Tile: NPerBlock X KPerBlock + else + { + using TileEncodingPattern = TileDistributionEncodingPattern2D; + return TileEncodingPattern::Make2DStaticTileDistribution(); + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledARegTileDistribution() + { + using ALayout = remove_cvref_t; + static_assert(std::is_same_v); + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t VecLoadSize = GetVectorSizeA(); + + using TileEncodingPattern = TileDistributionEncodingPattern2D; + return TileEncodingPattern::MakeShuffled2DStaticTileDistribution(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBRegTileDistribution() + { + using BLayout = remove_cvref_t; + static_assert(std::is_same_v); + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t KPerBlock = 
Problem::BlockGemmShape::kK; + constexpr index_t VecLoadSize = GetVectorSizeB(); + + using TileEncodingPattern = TileDistributionEncodingPattern2D; + return TileEncodingPattern::MakeShuffled2DStaticTileDistribution(); + } + template CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackA() { - using BlockGemm = decltype(GetBlockGemm()); + using BlockGemm = remove_cvref_t())>; constexpr index_t KPack = BlockGemm::Traits::KPack; return KPack; } @@ -177,11 +283,43 @@ struct UniversalGemmPipelineAgBgCrPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackB() { - using BlockGemm = decltype(GetBlockGemm()); + using BlockGemm = remove_cvref_t())>; constexpr index_t KPack = BlockGemm::Traits::KPack; return KPack; } + template + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA() + { + constexpr auto a_lds_desc = Derived::template MakeALdsBlockDescriptor(); + constexpr index_t smem_size_a = integer_least_multiple( + sizeof(typename Problem::ADataType) * a_lds_desc.get_element_space_size(), 16); + return smem_size_a; + } + + template + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeB() + { + constexpr auto b_lds_desc = Derived::template MakeBLdsBlockDescriptor(); + constexpr index_t smem_size_b = integer_least_multiple( + sizeof(typename Problem::BDataType) * b_lds_desc.get_element_space_size(), 16); + return smem_size_b; + } + + template + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + constexpr index_t smem_size_a = GetSmemSizeA(); + constexpr index_t smem_size_b = GetSmemSizeB(); + + return smem_size_a + smem_size_b; + } +}; + +// UniversalGemm Policy +struct UniversalGemmPipelineAgBgCrPolicy + : public UniversalGemmBasePolicy +{ template CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() { @@ -421,133 +559,6 @@ struct UniversalGemmPipelineAgBgCrPolicy #endif } - template - CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA() - { - constexpr index_t smem_size_a = sizeof(typename Problem::ADataType) * 
- MakeALdsBlockDescriptor().get_element_space_size(); - return smem_size_a; - } - - template - CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeB() - { - constexpr index_t smem_size_b = sizeof(typename Problem::BDataType) * - MakeBLdsBlockDescriptor().get_element_space_size(); - return smem_size_b; - } - - template - CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() - { - constexpr index_t smem_size_a = GetSmemSizeA(); - constexpr index_t smem_size_b = GetSmemSizeB(); - index_t smem_size = 0; - smem_size += smem_size_a + smem_size_b; - - return smem_size; - } - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution() - { - using ALayout = remove_cvref_t; - - constexpr index_t BlockSize = Problem::kBlockSize; - constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t VecLoadSize = GetVectorSizeA(); - - // Tile: MPerBlock X KPerBlock - if constexpr(std::is_same_v) - { - using TileEncodingPattern = TileDistributionEncodingPattern2D; - return TileEncodingPattern::Make2DStaticTileDistribution(); - } - // Tile: KPerBlock X MPerBlock - else - { - using TileEncodingPattern = TileDistributionEncodingPattern2D; - return TileEncodingPattern::Make2DStaticTileDistribution(); - } - } - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution() - { - using BLayout = remove_cvref_t; - - constexpr index_t BlockSize = Problem::kBlockSize; - constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; - constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t VecLoadSize = GetVectorSizeB(); - - // Tile: KPerBlock X NPerBlock - if constexpr(std::is_same_v) - { - using TileEncodingPattern = TileDistributionEncodingPattern2D; - return TileEncodingPattern::Make2DStaticTileDistribution(); - } - // Tile: NPerBlock X KPerBlock - else - { - using TileEncodingPattern = TileDistributionEncodingPattern2D; - return 
TileEncodingPattern::Make2DStaticTileDistribution(); - } - } - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledARegTileDistribution() - { - using ALayout = remove_cvref_t; - static_assert(std::is_same_v); - constexpr index_t BlockSize = Problem::kBlockSize; - constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t VecLoadSize = GetVectorSizeA(); - - using TileEncodingPattern = TileDistributionEncodingPattern2D; - return TileEncodingPattern::MakeShuffled2DStaticTileDistribution(); - } - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBRegTileDistribution() - { - using BLayout = remove_cvref_t; - static_assert(std::is_same_v); - constexpr index_t BlockSize = Problem::kBlockSize; - constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; - constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t VecLoadSize = GetVectorSizeB(); - - using TileEncodingPattern = TileDistributionEncodingPattern2D; - return TileEncodingPattern::MakeShuffled2DStaticTileDistribution(); - } - template CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() { diff --git a/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp b/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp index 3d7441c942..d0e1f60d38 100644 --- a/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp +++ b/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp @@ -32,6 +32,7 @@ struct TileGemmTraits template ; using Mem = ck_tile::integral_constant; -using Comp = ck_tile::integral_constant; +using CompV3 = ck_tile::integral_constant; +using CompV4 = ck_tile::integral_constant; // clang-format off using KernelTypes = ::testing::Types< // ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, GemmPipelineScheduler, PipelineType std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, Mem>, - std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, Comp>, + 
std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, CompV3>, + std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, CompV4>, std::tuple< Row, Row, Row, F16, F16, F32, F16, Interwave, Mem>, std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, Mem>, - std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, Comp>, + std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, CompV3>, + std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, CompV4>, std::tuple< Row, Col, Row, F16, F16, F32, F16, Interwave, Mem>, std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, Mem>, - std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, Comp>, + std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, CompV3>, + std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, CompV4>, std::tuple< Col, Row, Row, F16, F16, F32, F16, Interwave, Mem>, std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, Mem>, - std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, Comp>, + std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, CompV3>, + std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, CompV4>, std::tuple< Col, Col, Row, F16, F16, F32, F16, Interwave, Mem> >; // clang-format on diff --git a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp index dc685567eb..155234cddc 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp @@ -14,7 +14,32 @@ enum struct GemmPipelineType { Mem, - Comp + CompV3, + CompV4 +}; + +template +struct GemmPipelineTypeSelector; + +template +struct GemmPipelineTypeSelector +{ + using base_pipeline = ck_tile::BaseGemmPipelineAgBgCrMem; + using pipeline = ck_tile::GemmPipelineAgBgCrMem; +}; + +template +struct GemmPipelineTypeSelector +{ + using base_pipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3; + using pipeline = ck_tile::GemmPipelineAgBgCrCompV3; +}; + +template +struct GemmPipelineTypeSelector +{ + using base_pipeline 
= ck_tile::BaseGemmPipelineAgBgCrCompV4; + using pipeline = ck_tile::GemmPipelineAgBgCrCompV4; }; template @@ -36,8 +61,8 @@ class TestCkTileGemmPipeline : public ::testing::Test void invoke_gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) { // TODO: This should be parameterized in tests - constexpr ck_tile::index_t M_Tile = 128; - constexpr ck_tile::index_t N_Tile = 128; + constexpr ck_tile::index_t M_Tile = 256; + constexpr ck_tile::index_t N_Tile = 256; constexpr ck_tile::index_t K_Tile = 32; constexpr ck_tile::index_t M_Warp = 2; @@ -52,6 +77,8 @@ class TestCkTileGemmPipeline : public ::testing::Test constexpr bool kPadN = PadN; constexpr bool kPadK = PadK; + constexpr bool DoubleSmemBuffer = (PipelineType == GemmPipelineType::CompV4) ? true : false; + // TODO: For now - but this should also be a test parameter constexpr bool TransposeC = false; @@ -69,16 +96,20 @@ class TestCkTileGemmPipeline : public ::testing::Test GemmSpatiallyLocalTilePartitioner; using Traits = ck_tile::TileGemmTraits; - using GemmUniversalTraits = ck_tile:: - TileGemmUniversalTraits; + using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; using GemmPipelineProblem = ck_tile::GemmPipelineProblem; using BaseGemmPipeline = - std::conditional_t, - ck_tile::BaseGemmPipelineAgBgCrCompV3>; + typename GemmPipelineTypeSelector::base_pipeline; const ck_tile::index_t k_grain = args.k_batch * K_Tile; const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile; @@ -99,12 +130,8 @@ class TestCkTileGemmPipeline : public ::testing::Test has_hot_loop_v, tail_number_v>; - using GemmPipeline = std::conditional_t< - PipelineType == GemmPipelineType::Mem, - ck_tile::GemmPipelineAgBgCrMem, - ck_tile::GemmPipelineAgBgCrCompV3>; + using GemmPipeline = + typename GemmPipelineTypeSelector::pipeline; using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem{}, + ck_tile::integral_constant{}); + } + else + { + Run(ck_tile::bool_constant{}, + 
ck_tile::integral_constant{}); + } + } } else { @@ -258,7 +301,19 @@ class TestCkTileGemmPipeline : public ::testing::Test public: std::vector k_batches_; - void SetUp() override { k_batches_ = {1, 2}; } + void SetUp() override + { + if constexpr(PipelineType == GemmPipelineType::CompV4) + { + // Only do k_batch = 1 when pipeline is CompV4 + k_batches_ = {1}; + } + else + { + // Otherwise, use k_batch = 1 and 2 + k_batches_ = {1, 2}; + } + } template void Run(const int M,