diff --git a/example/ck_tile/03_gemm/README.md b/example/ck_tile/03_gemm/README.md index aacbdf6863..e9ffe72a91 100644 --- a/example/ck_tile/03_gemm/README.md +++ b/example/ck_tile/03_gemm/README.md @@ -8,7 +8,10 @@ This folder contains example for GEMM using ck_tile tile-programming implementat mkdir build && cd build # you can replace with the appropriate architecture (for example gfx90a or gfx942) or leave it blank sh ../script/cmake-ck-dev.sh ../ +# The basic pipeline method on the gemm calculation make tile_example_gemm_basic -j +# The memory bound pipeline on the gemm calculation +make tile_example_gemm_mem_pipeline -j ``` This will result in an executable `build/bin/tile_example_gemm_basic` diff --git a/example/ck_tile/03_gemm/gemm_basic.cpp b/example/ck_tile/03_gemm/gemm_basic.cpp index 09427217c5..b7d8693442 100644 --- a/example/ck_tile/03_gemm/gemm_basic.cpp +++ b/example/ck_tile/03_gemm/gemm_basic.cpp @@ -17,10 +17,11 @@ template float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) { - // The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part. - constexpr bool kPadA = true; - constexpr bool kPadB = true; - constexpr bool kPadC = true; + // The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part. + constexpr bool kPadM = false; + constexpr bool kPadN = false; + constexpr bool kPadK = false; + constexpr bool kTilePermute = false; // The rank and permutation will also be generate out by the CodeGen part. constexpr ck_tile::index_t kOutputRank = 2; @@ -56,8 +57,8 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) CShuffleEpilogue, ck_tile::CShuffleEpilogue>, ck_tile::Default2DEpilogue< - ck_tile::Default2DEpilogueProblem>>; + ck_tile::Default2DEpilogueProblem>>; using CodegenGemmTraits = - ck_tile::TileGemmTraits; + ck_tile::TileGemmTraits; using CodegenPipelineProblem = ck_tile:: GemmPipelineProblem; - using CodegenGemmPolicy = ck_tile::UniversalGemmPipelineAgBgCrPolicy; + using CodegenGemmPolicy = ck_tile::UniversalGemmPipelineAgBgCrPolicy; using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1; // ToDo: Will add the codegen part to test different pipeline policies in GEMM. diff --git a/example/ck_tile/03_gemm/gemm_mem_pipeline.cpp b/example/ck_tile/03_gemm/gemm_mem_pipeline.cpp index 2ee0395e47..ff9d8bad32 100644 --- a/example/ck_tile/03_gemm/gemm_mem_pipeline.cpp +++ b/example/ck_tile/03_gemm/gemm_mem_pipeline.cpp @@ -31,9 +31,9 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) constexpr ck_tile::index_t K_Warp_Tile = 8; // The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part. - constexpr bool kPadA = true; - constexpr bool kPadB = true; - constexpr bool kPadC = true; + constexpr bool kPadM = true; + constexpr bool kPadN = true; + constexpr bool kPadK = true; constexpr int kBlockPerCu = 1; @@ -46,9 +46,9 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) using TilePartitioner = ck_tile::GemmTilePartitioner; using GemmEpilogue = ck_tile::Default2DEpilogue< - ck_tile::Default2DEpilogueProblem>; + ck_tile::Default2DEpilogueProblem>; - using Traits = ck_tile::TileGemmTraits; + using Traits = ck_tile::TileGemmTraits; using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem< ck_tile::GemmPipelineProblem>; diff --git a/include/ck_tile/core/tensor/shuffle_tile.hpp b/include/ck_tile/core/tensor/shuffle_tile.hpp index da3c7117e5..55e3274cde 100644 --- a/include/ck_tile/core/tensor/shuffle_tile.hpp +++ b/include/ck_tile/core/tensor/shuffle_tile.hpp @@ -170,7 +170,7 @@ CK_TILE_DEVICE void shuffle_tile(OutTensor& out, const InTensor& in) } else { - // NOT implemented + static_assert(false, "The shuffle should always happen!"); } } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp index fbb05e1641..a3a29bb540 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp @@ -863,6 +863,8 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy{}, number{}), - // somehow clang-format is splitting below line into multiple. - // clang-format off - sequence{}); + auto a_pad_view = [&]() { + if constexpr(std::is_same_v) + { + return pad_tensor_view( + a_tensor_view, + make_tuple(number{}, number{}), + sequence{}); + } + else + { + return pad_tensor_view( + a_tensor_view, + make_tuple(number{}, number{}), + sequence{}); + } + }(); // clang-format on auto a_block_window = make_tile_window( @@ -128,12 +138,22 @@ struct GemmKernel make_tuple(number{}, number{}), {i_m, 0}); - auto b_pad_view = pad_tensor_view( - b_tensor_view, - make_tuple(number{}, number{}), - // clang-format off - sequence{}); - // clang-format on + auto b_pad_view = [&]() { + if constexpr(std::is_same_v) + { + return pad_tensor_view( + b_tensor_view, + make_tuple(number{}, number{}), + sequence{}); + } + else + { + return pad_tensor_view( + b_tensor_view, + make_tuple(number{}, number{}), + sequence{}); + } + }(); auto b_block_window = make_tile_window( b_pad_view, @@ -171,18 +191,28 @@ struct GemmKernel } }(); - auto c_pad_view = pad_tensor_view( - c_tensor_view, - make_tuple(number{}, number{}), - // clang-format off - sequence{}); - // clang-format on - auto c_block_window = make_tile_window( + auto c_pad_view = [&]() { + if constexpr(std::is_same_v) + { + return pad_tensor_view( + c_tensor_view, + make_tuple(number{}, number{}), + sequence{}); + } + else + { + return pad_tensor_view( + c_tensor_view, + make_tuple(number{}, number{}), + sequence{}); + } + }(); + auto CBlockWindow_pad = make_tile_window( c_pad_view, make_tuple(number{}, number{}), {i_m, i_n}); - EpiloguePipeline{}(c_block_window, c_block_tile); + EpiloguePipeline{}(CBlockWindow_pad, c_block_tile); } }; diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp index b9b45d3f42..85c5c58056 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp @@ -113,9 +113,9 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem static constexpr index_t VectorSizeB = Problem::VectorSizeB; static constexpr index_t VectorSizeC = Problem::VectorSizeC; - static constexpr bool kPadA = Problem::kPadA; - static constexpr bool kPadB = Problem::kPadB; - static constexpr bool kPadC = Problem::kPadC; + static constexpr bool kPadM = Problem::kPadM; + static constexpr bool kPadN = Problem::kPadN; + static constexpr bool kPadK = Problem::kPadK; // Where is the right place for HasHotLoop and TailNum ??? static constexpr bool HasHotLoop = Problem::HasHotLoop; diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp index a2424290e6..c0817e736b 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp @@ -33,9 +33,9 @@ struct GemmPipelineAGmemBGmemCRegV1 static constexpr index_t VectorSizeB = Problem::VectorSizeB; static constexpr index_t VectorSizeC = Problem::VectorSizeC; - static constexpr bool kPadA = Problem::kPadA; - static constexpr bool kPadB = Problem::kPadB; - static constexpr bool kPadC = Problem::kPadC; + static constexpr bool kPadM = Problem::kPadM; + static constexpr bool kPadN = Problem::kPadN; + static constexpr bool kPadK = Problem::kPadK; CK_TILE_HOST_DEVICE static constexpr index_t GetStaticLdsSize() { @@ -101,11 +101,8 @@ struct GemmPipelineAGmemBGmemCRegV1 Policy::template MakeADramTileDistribution()); // A LDS tile window for store - auto a_copy_lds_window = - make_tile_window(a_lds_block, - make_tuple(number{}, number{}), - {0, 0}, - a_copy_dram_window.get_tile_distribution()); + auto a_copy_lds_window = make_tile_window( + a_lds_block, make_tuple(number{}, number{}), {0, 0}); // B DRAM tile window for load auto b_copy_dram_window = @@ -115,11 +112,8 @@ struct GemmPipelineAGmemBGmemCRegV1 Policy::template MakeBDramTileDistribution()); // B LDS tile window for store - auto b_copy_lds_window = - make_tile_window(b_lds_block, - make_tuple(number{}, number{}), - {0, 0}, - b_copy_dram_window.get_tile_distribution()); + auto b_copy_lds_window = make_tile_window( + b_lds_block, make_tuple(number{}, number{}), {0, 0}); // A LDS tile for block GEMM auto a_lds_gemm_window = make_tile_window( @@ -149,12 +143,32 @@ struct GemmPipelineAGmemBGmemCRegV1 tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); // LDS write 0 - const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); - store_tile(a_copy_lds_window, a_block_tile_tmp); + if constexpr(std::is_same_v) + { + auto a_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledARegBlockDescriptor()); + shuffle_tile(a_shuffle_tmp, a_block_tile); + const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_shuffle_tmp); + store_tile(a_copy_lds_window, a_block_tile_tmp); + } + else + { + store_tile(a_copy_lds_window, tile_elementwise_in(a_element_func, a_block_tile)); + } // LDS write 0 - const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile); - store_tile(b_copy_lds_window, b_block_tile_tmp); + if constexpr(std::is_same_v) + { + auto b_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledBRegBlockDescriptor()); + shuffle_tile(b_shuffle_tmp, b_block_tile); + const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_shuffle_tmp); + store_tile(b_copy_lds_window, b_block_tile_tmp); + } + else + { + store_tile(b_copy_lds_window, tile_elementwise_in(b_element_func, b_block_tile)); + } } index_t iCounter = num_loop - 1; @@ -180,8 +194,19 @@ struct GemmPipelineAGmemBGmemCRegV1 store_tile(a_copy_lds_window, a_block_tile_tmp); // LDS write i + 1 - const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile); - store_tile(b_copy_lds_window, b_block_tile_tmp); + if constexpr(std::is_same_v) + { + auto b_shuffle_tmp_loop = make_static_distributed_tensor( + Policy::template MakeShuffledBRegBlockDescriptor()); + shuffle_tile(b_shuffle_tmp_loop, b_block_tile); + store_tile(b_copy_lds_window, + tile_elementwise_in(b_element_func, b_shuffle_tmp_loop)); + } + else + { + const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile); + store_tile(b_copy_lds_window, b_block_tile_tmp); + } iCounter--; } diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp index 199ba56aac..c765b3ce9d 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp @@ -11,6 +11,7 @@ namespace ck_tile { // Default policy class should not be templated, put template on member functions instead struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy { + #if 0 // 2d template @@ -116,6 +117,20 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy return smem_size; } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackA() + { + using ADataType = remove_cvref_t; + return Problem::VectorLoadSize / sizeof(ADataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackB() + { + using BDataType = remove_cvref_t; + return Problem::VectorLoadSize / sizeof(BDataType); + } #elif 1 // fake XOR template @@ -192,80 +207,269 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution() { using ADataType = remove_cvref_t; + using ALayout = remove_cvref_t; - constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t BlockSize = Problem::kBlockSize; - constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t K1 = 16 / sizeof(ADataType); - constexpr index_t K0 = kKPerBlock / K1; - constexpr index_t M2 = get_warp_size() / K0; -#if 1 // coalesce reading for each blocks - constexpr index_t M1 = kBlockSize / get_warp_size(); - static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error."); - static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error."); - constexpr index_t M0 = kMPerBlock / (M2 * M1); + if constexpr(std::is_same_v) + { + constexpr index_t M1 = Problem::VectorLoadSize / sizeof(ADataType); + constexpr index_t M0 = MPerBlock / M1; + constexpr index_t total_pixels = MPerBlock * KPerBlock / BlockSize; + static_assert(total_pixels % M1 == 0); + constexpr index_t K3 = total_pixels / M1; + constexpr index_t KPack = GetSmemPackA(); + static_assert(KPack % K3 == 0); + constexpr index_t K2 = KPack / K3; + if constexpr(get_warp_size() % (K2 * M0)) + { + constexpr index_t K1 = get_warp_size() / (K2 * M0); + constexpr index_t K0 = BlockSize / get_warp_size(); + static_assert(KPerBlock == K0 * K1 * K2 * K3); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1, 2>>, + tuple, sequence<1, 0, 2>>, + sequence<2, 1>, + sequence<3, 1>>{}); + } + else + { + constexpr index_t K1 = (K2 * M0) / get_warp_size(); + constexpr index_t K2_m = K2 / K1; + constexpr index_t K0 = BlockSize / get_warp_size() / K1; + static_assert(KPerBlock == K0 * K1 * K2_m * K3); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<0, 2>>, + sequence<2, 1>, + sequence<3, 1>>{}); + } + } + else + { + constexpr index_t K1 = 16 / sizeof(ADataType); + constexpr index_t K0 = KPerBlock / K1; + constexpr index_t M2 = get_warp_size() / K0; + // coalesce reading for each blocks + if constexpr(get_warp_size() % (M2 * K0) == 0) + { + constexpr index_t M1 = BlockSize / get_warp_size(); + static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error."); + static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error."); + constexpr index_t M0 = MPerBlock / (M2 * M1); - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<2, 0>>, - sequence<1, 2>, - sequence<0, 1>>{}); -#else // coalesce reading for each warps - constexpr index_t M0 = kBlockSize / get_warp_size(); - constexpr index_t M1 = kMPerBlock / (M2 * M0); - - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<2, 0>>, - sequence<1, 2>, - sequence<1, 1>>{}); -#endif + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + else + { + constexpr index_t M0 = BlockSize / get_warp_size(); + constexpr index_t M1 = MPerBlock / (M2 * M0); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<1, 1>>{}); + } + } } template CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution() { using BDataType = remove_cvref_t; + using BLayout = remove_cvref_t; + constexpr index_t BlockSize = Problem::kBlockSize; + + constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + + if constexpr(std::is_same_v) + { + constexpr index_t N1 = Problem::VectorLoadSize / sizeof(BDataType); + constexpr index_t N0 = NPerBlock / N1; + constexpr index_t total_pixels = NPerBlock * KPerBlock / BlockSize; + static_assert(total_pixels % N1 == 0); + constexpr index_t K3 = total_pixels / N1; + constexpr index_t KPack = GetSmemPackB(); + static_assert(KPack % K3 == 0); + constexpr index_t K2 = KPack / K3; + if constexpr(get_warp_size() % (K2 * N0) == 0) + { + constexpr index_t K1 = get_warp_size() / (K2 * N0); + constexpr index_t K0 = BlockSize / get_warp_size(); + static_assert(KPerBlock == K0 * K1 * K2 * K3); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1, 2>>, + tuple, sequence<1, 0, 2>>, + sequence<2, 1>, + sequence<3, 1>>{}); + } + else + { + constexpr index_t K1 = (K2 * N0) / get_warp_size(); + constexpr index_t K2_m = K2 / K1; + constexpr index_t K0 = BlockSize / get_warp_size() / K1; + static_assert(KPerBlock == K0 * K1 * K2_m * K3); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<0, 2>>, + sequence<2, 1>, + sequence<3, 1>>{}); + } + } + else + { + + constexpr index_t K1 = Problem::VectorLoadSize / sizeof(BDataType); + constexpr index_t K0 = KPerBlock / K1; + constexpr index_t N2 = get_warp_size() / K0; + // coalesce reading for each blocks + if constexpr(get_warp_size() % (N2 * K0) == 0) + { + constexpr index_t N1 = BlockSize / get_warp_size(); + static_assert(N2 != 0, "N2 is zero, which will lead to a division by zero error."); + static_assert(N1 != 0, "N1 is zero, which will lead to a division by zero error."); + constexpr index_t N0 = NPerBlock / (N2 * N1); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + // coalesce reading for each warps + else + { + constexpr index_t N0 = BlockSize / get_warp_size(); + constexpr index_t N1 = NPerBlock / (N2 * N0); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<1, 1>>{}); + } + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBRegBlockDescriptor() + { + using BLayout = remove_cvref_t; + using BDataType = remove_cvref_t; + static_assert(std::is_same_v); constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t K1 = 16 / sizeof(BDataType); - constexpr index_t K0 = kKPerBlock / K1; - constexpr index_t N2 = get_warp_size() / K0; -#if 1 // coalesce reading for each blocks - constexpr index_t N1 = kBlockSize / get_warp_size(); - static_assert(N2 != 0, "M2 is zero, which will lead to a division by zero error."); - static_assert(N1 != 0, "M1 is zero, which will lead to a division by zero error."); - constexpr index_t N0 = kNPerBlock / (N2 * N1); + constexpr index_t N1 = Problem::VectorLoadSize / sizeof(BDataType); + constexpr index_t N0 = kNPerBlock / N1; + constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; + static_assert(total_pixels % N1 == 0); + constexpr index_t K3 = total_pixels / N1; + constexpr index_t kKPack = GetSmemPackB(); + static_assert(kKPack % K3 == 0); + constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave + constexpr index_t warp_size = get_warp_size(); + if constexpr(warp_size % (K2 * N0) == 0) + { + constexpr index_t K1 = warp_size / (K2 * N0); + constexpr index_t K0 = kBlockSize / warp_size; - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<2, 0>>, - sequence<1, 2>, - sequence<0, 1>>{}); -#else // coalesce reading for each warps - constexpr index_t N0 = kBlockSize / get_warp_size(); - constexpr index_t N1 = kNPerBlock / (N2 * N0); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1, 2>>, + tuple, sequence<1, 0, 2>>, + sequence<1, 2>, + sequence<1, 3>>{}); + } + else + { + constexpr index_t K1 = (K2 * N0) / get_warp_size(); + constexpr index_t K2_m = K2 / K1; + constexpr index_t K0 = kBlockSize / get_warp_size() / K1; + static_assert(kKPerBlock == K0 * K1 * K2_m * K3); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<0, 2>>, + sequence<1, 2>, + sequence<1, 3>>{}); + } + } - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<2, 0>>, - sequence<1, 2>, - sequence<1, 1>>{}); -#endif + template + CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledARegBlockDescriptor() + { + using ALayout = remove_cvref_t; + using ADataType = remove_cvref_t; + static_assert(std::is_same_v); + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + + constexpr index_t M1 = Problem::VectorLoadSize / sizeof(ADataType); + constexpr index_t M0 = kMPerBlock / M1; + constexpr index_t total_pixels = kMPerBlock * kKPerBlock / kBlockSize; + static_assert(total_pixels % M1 == 0); + constexpr index_t K3 = total_pixels / M1; + constexpr index_t kKPack = GetSmemPackA(); + static_assert(kKPack % K3 == 0); + constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave + constexpr index_t warp_size = get_warp_size(); + if constexpr(warp_size % (K2 * M0) == 0) + { + constexpr index_t K1 = warp_size / (K2 * M0); + constexpr index_t K0 = kBlockSize / warp_size; + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1, 2>>, + tuple, sequence<1, 0, 2>>, + sequence<1, 2>, + sequence<1, 3>>{}); + } + else + { + constexpr index_t K1 = (K2 * M0) / get_warp_size(); + constexpr index_t K2_m = K2 / K1; + constexpr index_t K0 = kBlockSize / get_warp_size() / K1; + static_assert(kKPerBlock == K0 * K1 * K2_m * K3); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<0, 2>>, + sequence<1, 2>, + sequence<1, 3>>{}); + } } template diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp index 1156f549b6..3c43790bd6 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp @@ -3,40 +3,133 @@ #pragma once -#include "ck_tile/core.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" namespace ck_tile { -static constexpr int _VectorSize = 16; - template -struct GemmPipelineProblem +struct GemmPipelineProblemBase { - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using CDataType = remove_cvref_t; + using GemmTraits = remove_cvref_t; + + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; - using GemmTraits = remove_cvref_t; using ALayout = remove_cvref_t; using BLayout = remove_cvref_t; using CLayout = remove_cvref_t; - static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size(); - static constexpr bool kPadA = GemmTraits::kPadA; - static constexpr bool kPadB = GemmTraits::kPadB; - static constexpr bool kPadC = GemmTraits::kPadC; + static constexpr index_t VectorLoadSize = GemmTraits::_VectorSize; + static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size(); - static constexpr index_t VectorSizeA = kPadA ? 1 : _VectorSize / sizeof(ADataType); - static constexpr index_t VectorSizeB = kPadB ? 1 : _VectorSize / sizeof(BDataType); - static constexpr index_t VectorSizeC = kPadC ? 1 : _VectorSize / sizeof(CDataType); + static constexpr bool kPadM = GemmTraits::kPadM; + static constexpr bool kPadN = GemmTraits::kPadN; + static constexpr bool kPadK = GemmTraits::kPadK; + + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentA() + { + if constexpr(std::is_same_v) + { + constexpr index_t pixels_per_thread = + BlockGemmShape::kM * BlockGemmShape::kK / kBlockSize; + return pixels_per_thread < VectorLoadSize / sizeof(ADataType) + ? pixels_per_thread + : VectorLoadSize / sizeof(ADataType); + } + else + { + return VectorLoadSize / sizeof(ADataType); + } + } + + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentB() + { + if constexpr(std::is_same_v) + { + constexpr index_t pixels_per_thread = + BlockGemmShape::kN * BlockGemmShape::kK / kBlockSize; + return pixels_per_thread < VectorLoadSize / sizeof(BDataType) + ? pixels_per_thread + : VectorLoadSize / sizeof(BDataType); + } + else + { + return VectorLoadSize / sizeof(BDataType); + } + } + + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentC() + { + if constexpr(std::is_same_v) + { + constexpr index_t N1 = kBlockSize / get_warp_size(); + constexpr index_t N2 = std::min(BlockGemmShape::kN / N1, get_warp_size()); + constexpr index_t M0 = get_warp_size() / N2; + constexpr index_t M1 = BlockGemmShape::kM / M0; + + return std::min(M1, static_cast(VectorLoadSize / sizeof(CDataType))); + } + else + { + constexpr index_t M1 = kBlockSize / get_warp_size(); + constexpr index_t M2 = std::min(BlockGemmShape::kM / M1, get_warp_size()); + constexpr index_t N0 = get_warp_size() / M2; + constexpr index_t N1 = BlockGemmShape::kN / N0; + + return std::min(N1, static_cast(VectorLoadSize / sizeof(CDataType))); + } + } + + static constexpr index_t VectorSizeA = []() { + if constexpr(std::is_same_v) + { + return kPadK ? 1 : GetAlignmentA(); + } + else + { + return kPadM ? 1 : GetAlignmentA(); + } + }(); + + static constexpr index_t VectorSizeB = []() { + if constexpr(std::is_same_v) + { + return kPadN ? 1 : GetAlignmentB(); + } + else + { + return kPadK ? 1 : GetAlignmentB(); + } + }(); + + static constexpr index_t VectorSizeC = []() { + if constexpr(std::is_same_v) + { + return kPadN ? 1 : GetAlignmentC(); + } + else + { + return kPadM ? 1 : GetAlignmentC(); + } + }(); }; +// Alias for GemmPipelineProblem +template +using GemmPipelineProblem = + GemmPipelineProblemBase; + template -struct UniversalGemmPipelineProblem +struct UniversalGemmPipelineProblem : public GemmPipelineProblemBase { - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using CDataType = remove_cvref_t; - using BlockGemmShape = remove_cvref_t; - using GemmTraits = remove_cvref_t; - - using ALayout = remove_cvref_t; - using BLayout = remove_cvref_t; - using CLayout = remove_cvref_t; - - static constexpr auto Scheduler = Scheduler_; - static constexpr auto HasHotLoop = HasHotLoop_; - static constexpr auto TailNum = TailNum_; - static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size(); - - static constexpr bool kPadA = GemmTraits::kPadA; - static constexpr bool kPadB = GemmTraits::kPadB; - static constexpr bool kPadC = GemmTraits::kPadC; - - static constexpr index_t VectorSizeA = kPadA ? _VectorSize / sizeof(ADataType) : 1; - static constexpr index_t VectorSizeB = kPadB ? _VectorSize / sizeof(BDataType) : 1; - static constexpr index_t VectorSizeC = kPadC ? _VectorSize / sizeof(CDataType) : 1; + static constexpr auto Scheduler = Scheduler_; + static constexpr auto HasHotLoop = HasHotLoop_; + static constexpr auto TailNum = TailNum_; }; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp index 7044a53140..207f1f9e4b 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp @@ -9,12 +9,8 @@ namespace ck_tile { // UniversalGemm Policy -template struct UniversalGemmPipelineAgBgCrPolicy { - using LayoutA = remove_cvref_t; - using LayoutB = remove_cvref_t; - using LayoutC = remove_cvref_t; static constexpr auto I0 = number<0>{}; static constexpr auto I1 = number<1>{}; @@ -34,13 +30,14 @@ struct UniversalGemmPipelineAgBgCrPolicy TransposeC>; using ADataType = remove_cvref_t; + using ALayout = remove_cvref_t; constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; constexpr index_t K1 = WarpGemm::kK; constexpr index_t K0 = KPerBlock / K1; - if constexpr(std::is_same::value) + if constexpr(std::is_same::value) { constexpr auto MLdsLayer = 32 * 4 / KPerBlock / sizeof(ADataType) < 1 ? 1 @@ -176,13 +173,15 @@ struct UniversalGemmPipelineAgBgCrPolicy using BDataType = remove_cvref_t; + using BLayout = remove_cvref_t; + constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; constexpr index_t K1 = WarpGemm::kK; constexpr index_t K0 = KPerBlock / K1; - if constexpr(std::is_same::value) + if constexpr(std::is_same::value) { // NLdsLayer * K0 as logical Bank constexpr auto NLdsLayer = 32 * 4 / KPerBlock / sizeof(BDataType) < 1 @@ -331,72 +330,285 @@ struct UniversalGemmPipelineAgBgCrPolicy return smem_size; } + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackA() + { + using ADataType = remove_cvref_t; + return Problem::VectorLoadSize / sizeof(ADataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackB() + { + using BDataType = remove_cvref_t; + return Problem::VectorLoadSize / sizeof(BDataType); + } + template CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution() { - using WarpGemm = WarpGemmMfmaDispatcher; + using ADataType = remove_cvref_t; + using ALayout = remove_cvref_t; constexpr index_t BlockSize = Problem::kBlockSize; constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t K1 = WarpGemm::kK; - constexpr index_t K0 = KPerBlock / K1; - constexpr index_t M2 = get_warp_size() / K0; - - constexpr index_t M1 = BlockSize / get_warp_size(); - static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error."); - static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error."); - constexpr index_t M0 = MPerBlock / (M2 * M1); - - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<2, 0>>, - sequence<1, 2>, - sequence<0, 1>>{}); + if constexpr(std::is_same_v) + { + constexpr index_t M1 = Problem::VectorLoadSize / sizeof(ADataType); + constexpr index_t M0 = MPerBlock / M1; + constexpr index_t total_pixels = MPerBlock * KPerBlock / BlockSize; + static_assert(total_pixels % M1 == 0); + constexpr index_t K3 = total_pixels / M1; + constexpr index_t KPack = GetSmemPackA(); + static_assert(KPack % K3 == 0); + constexpr index_t K2 = KPack / K3; + if constexpr(get_warp_size() % (K2 * M0) == 0) + { + constexpr index_t K1 = get_warp_size() / (K2 * M0); + constexpr index_t K0 = BlockSize / get_warp_size(); + static_assert(KPerBlock == K0 * K1 * K2 * K3); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1, 2>>, + tuple, sequence<1, 0, 2>>, + sequence<2, 1>, + sequence<3, 1>>{}); + } + else + { + constexpr index_t K1 = (K2 * M0) / get_warp_size(); + constexpr index_t K2_m = K2 / K1; + constexpr index_t K0 = BlockSize / get_warp_size() / K1; + static_assert(KPerBlock == K0 * K1 * K2_m * K3); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<0, 2>>, + sequence<2, 1>, + sequence<3, 1>>{}); + } + } + else + { + constexpr index_t K1 = Problem::VectorLoadSize / sizeof(ADataType); + constexpr index_t K0 = KPerBlock / K1; + constexpr index_t M2 = get_warp_size() / K0; + if constexpr(get_warp_size() % (M2 * K0) == 0) + { + constexpr index_t M1 = BlockSize / get_warp_size(); + static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error."); + static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error."); + constexpr index_t M0 = MPerBlock / (M2 * M1); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + else + { + constexpr index_t M0 = BlockSize / get_warp_size(); + constexpr index_t M1 = MPerBlock / (M2 * M0); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<1, 1>>{}); + } + } } template CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution() { - using WarpGemm = WarpGemmMfmaDispatcher; + using BDataType = remove_cvref_t; + using BLayout = remove_cvref_t; constexpr index_t BlockSize = Problem::kBlockSize; constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t K1 = WarpGemm::kK; - constexpr index_t K0 = KPerBlock / K1; - constexpr index_t N2 = get_warp_size() / K0; + if constexpr(std::is_same_v) + { + constexpr index_t N1 = Problem::VectorLoadSize / sizeof(BDataType); + constexpr index_t N0 = NPerBlock / N1; + constexpr index_t total_pixels = NPerBlock * KPerBlock / BlockSize; + static_assert(total_pixels % N1 == 0); + constexpr index_t K3 = total_pixels / N1; + constexpr index_t KPack = GetSmemPackB(); + static_assert(KPack % K3 == 0); + constexpr index_t K2 = KPack / K3; + if constexpr(get_warp_size() % (K2 * N0) == 0) + { + constexpr index_t K1 = get_warp_size() / (K2 * N0); + constexpr index_t K0 = BlockSize / get_warp_size(); + static_assert(KPerBlock == K0 * K1 * K2 * K3); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1, 2>>, + tuple, sequence<1, 0, 2>>, + sequence<2, 1>, + sequence<3, 1>>{}); + } + else + { + constexpr index_t K1 = (K2 * N0) / get_warp_size(); + constexpr index_t K2_m = K2 / K1; + constexpr index_t K0 = BlockSize / get_warp_size() / K1; + static_assert(KPerBlock == K0 * K1 * K2_m * K3); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<0, 2>>, + sequence<2, 1>, + sequence<3, 1>>{}); + } + } + else + { - constexpr index_t N1 = BlockSize / get_warp_size(); - static_assert(N2 != 0, "M2 is zero, which will lead to a division by zero error."); - static_assert(N1 != 0, "M1 is zero, which will lead to a division by zero error."); - constexpr index_t N0 = NPerBlock / (N2 * N1); + constexpr index_t K1 = Problem::VectorLoadSize / sizeof(BDataType); + constexpr index_t K0 = KPerBlock / K1; + constexpr index_t N2 = get_warp_size() / K0; + // coalesce reading for each blocks + if constexpr(get_warp_size() % (N2 * K0) == 0) + { + constexpr index_t N1 = BlockSize / get_warp_size(); + static_assert(N2 != 0, "N2 is zero, which will lead to a division by zero error."); + static_assert(N1 != 0, "N1 is zero, which will lead to a division by zero error."); + constexpr index_t N0 = NPerBlock / (N2 * N1); - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<2, 0>>, - sequence<1, 2>, - sequence<0, 1>>{}); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + // coalesce reading for each warps + else + { + constexpr index_t N0 = BlockSize / get_warp_size(); + constexpr index_t N1 = NPerBlock / (N2 * N0); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<1, 1>>{}); + } + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledARegBlockDescriptor() + { + using ALayout = remove_cvref_t; + using ADataType = remove_cvref_t; + static_assert(std::is_same_v); + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t MPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + + constexpr index_t M1 = Problem::VectorLoadSize / sizeof(ADataType); + constexpr index_t M0 = MPerBlock / M1; + constexpr index_t total_pixels = MPerBlock * KPerBlock / BlockSize; + static_assert(total_pixels % M1 == 0); + constexpr index_t K3 = total_pixels / M1; + constexpr index_t kKPack = GetSmemPackB(); + static_assert(kKPack % K3 == 0); + constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave + constexpr index_t warp_size = get_warp_size(); + if constexpr(warp_size % (K2 * M0) == 0) + { + constexpr index_t K1 = warp_size / (K2 * M0); + constexpr index_t K0 = BlockSize / warp_size; + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1, 2>>, + tuple, sequence<1, 0, 2>>, + sequence<1, 2>, + sequence<1, 3>>{}); + } + else + { + constexpr index_t K1 = (K2 * M0) / get_warp_size(); + constexpr index_t K2_m = K2 / K1; + constexpr index_t K0 = BlockSize / get_warp_size() / K1; + static_assert(KPerBlock == K0 * K1 * K2_m * K3); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<0, 2>>, + sequence<1, 2>, + sequence<1, 3>>{}); + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBRegBlockDescriptor() + { + using BLayout = remove_cvref_t; + using BDataType = remove_cvref_t; + static_assert(std::is_same_v); + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + + constexpr index_t N1 = Problem::VectorLoadSize / sizeof(BDataType); + constexpr index_t N0 = NPerBlock / N1; + constexpr index_t total_pixels = NPerBlock * KPerBlock / BlockSize; + static_assert(total_pixels % N1 == 0); + constexpr index_t K3 = total_pixels / N1; + constexpr index_t kKPack = GetSmemPackB(); + static_assert(kKPack % K3 == 0); + constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave + constexpr index_t warp_size = get_warp_size(); + if constexpr(warp_size % (K2 * N0) == 0) + { + constexpr index_t K1 = warp_size / (K2 * N0); + constexpr index_t K0 = BlockSize / warp_size; + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1, 2>>, + tuple, sequence<1, 0, 2>>, + sequence<1, 2>, + sequence<1, 3>>{}); + } + else + { + constexpr index_t K1 = (K2 * N0) / get_warp_size(); + constexpr index_t K2_m = K2 / K1; + constexpr index_t K0 = BlockSize / get_warp_size() / K1; + static_assert(KPerBlock == K0 * K1 * K2_m * K3); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<0, 2>>, + sequence<1, 2>, + sequence<1, 3>>{}); + } } template diff --git a/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp b/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp index 9d050be2fb..34756c3ff6 100644 --- a/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp +++ b/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp @@ -3,19 +3,23 @@ #pragma once +#include "ck_tile/core.hpp" + namespace ck_tile { -template struct TileGemmTraits { - static constexpr bool kPadA = kPadA_; - static constexpr bool kPadB = kPadB_; - static constexpr bool kPadC = kPadC_; + static constexpr bool kPadM = kPadM_; + static constexpr bool kPadN = kPadN_; + static constexpr bool kPadK = kPadK_; + + static constexpr int _VectorSize = 16; using ALayout = ALayout_; using BLayout = BLayout_; diff --git a/test/ck_tile/gemm/test_gemm_mem_pipeline_util.hpp b/test/ck_tile/gemm/test_gemm_mem_pipeline_util.hpp index 1b243ab437..6b47898339 100644 --- a/test/ck_tile/gemm/test_gemm_mem_pipeline_util.hpp +++ b/test/ck_tile/gemm/test_gemm_mem_pipeline_util.hpp @@ -53,9 +53,9 @@ class TestCkTileGemmMemPipeline : public ::testing::Test constexpr ck_tile::index_t N_Warp_Tile = 32; constexpr ck_tile::index_t K_Warp_Tile = 8; - constexpr bool kPadA = true; - constexpr bool kPadB = true; - constexpr bool kPadC = true; + constexpr bool kPadM = true; + constexpr bool kPadN = true; + constexpr bool kPadK = true; constexpr int kBlockPerCu = 1; @@ -68,9 +68,9 @@ class TestCkTileGemmMemPipeline : public ::testing::Test using TilePartitioner = ck_tile::GemmTilePartitioner; using GemmEpilogue = ck_tile::Default2DEpilogue< - ck_tile::Default2DEpilogueProblem>; + ck_tile::Default2DEpilogueProblem>; - using Traits = ck_tile::TileGemmTraits; + using Traits = ck_tile::TileGemmTraits; using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem< ck_tile::GemmPipelineProblem>; @@ -108,7 +108,7 @@ class TestCkTileGemmMemPipeline : public ::testing::Test if(s.log_level_ > 0) { - std::cout << "Lunching kernel with args:" + std::cout << "Launching kernel with args:" << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;