From cef79d5f82f030c31540c9f6f2e2cc8f2a4072d3 Mon Sep 17 00:00:00 2001 From: John Afaganis Date: Tue, 26 Aug 2025 10:48:49 -0600 Subject: [PATCH] Revert "[CK-TILE] Default epilogue, adding support for D (#2629)" (#2746) This reverts commit 92037686aef6ca9ee4176605717bd92146563c60. [ROCm/composable_kernel commit: 508e7912f9bb758c22c7f7c1fc5dbb4cd3030c06] --- .../ops/epilogue/default_2d_epilogue.hpp | 108 ++---- .../ops/gemm/kernel/gemm_multi_d_kernel.hpp | 6 - test/ck_tile/gemm_multi_d/CMakeLists.txt | 6 +- ...i_d_cshuffle.cpp => test_gemm_multi_d.cpp} | 27 +- .../test_gemm_multi_d_default2d.cpp | 43 --- .../test_gemm_multi_d_ut_cases.inc | 334 ++++++++++++++++++ .../test_gemm_multi_d_ut_cases_cshuffle.inc | 211 ----------- .../test_gemm_multi_d_ut_cases_default2d.inc | 211 ----------- .../gemm_multi_d/test_gemm_multi_d_util.hpp | 89 +++-- tile_engine/ops/gemm/codegen_utils.py | 5 - 10 files changed, 422 insertions(+), 618 deletions(-) rename test/ck_tile/gemm_multi_d/{test_gemm_multi_d_cshuffle.cpp => test_gemm_multi_d.cpp} (75%) delete mode 100644 test/ck_tile/gemm_multi_d/test_gemm_multi_d_default2d.cpp create mode 100644 test/ck_tile/gemm_multi_d/test_gemm_multi_d_ut_cases.inc delete mode 100644 test/ck_tile/gemm_multi_d/test_gemm_multi_d_ut_cases_cshuffle.inc delete mode 100644 test/ck_tile/gemm_multi_d/test_gemm_multi_d_ut_cases_default2d.inc diff --git a/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp b/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp index 401f90f78f..8a0970f494 100644 --- a/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp @@ -29,14 +29,9 @@ struct Default2DEpilogueProblem template ; using BDataType = remove_cvref_t; using CLayout = remove_cvref_t; - using DsDataType = remove_cvref_t; - using DsLayout = remove_cvref_t; - using CDElementwise = remove_cvref_t; - static constexpr index_t kMPerBlock = kM_; - static constexpr index_t kNPerBlock = kN_; static constexpr index_t kMPerXdl = kMPerXdl_; static constexpr index_t kNPerXdl = kNPerXdl_; static constexpr index_t kKPerXdl = kKPerXdl_; static constexpr index_t isCTransposed = isCTransposed_; - - static constexpr index_t NumDTensor = DsDataType::size(); - - static_assert(NumDTensor == DsLayout::size(), - "The size of DsDataType and DsLayout should be the same"); }; template @@ -77,7 +62,6 @@ struct Default2DEpilogue using Problem = remove_cvref_t; using AccDataType = remove_cvref_t; using ODataType = remove_cvref_t; - using CDElementwise = remove_cvref_t; static constexpr bool kPadM = Problem::kPadM; static constexpr bool kPadN = Problem::kPadN; static constexpr bool UseRawStore = Problem::UseRawStore; @@ -87,71 +71,44 @@ struct Default2DEpilogue // TODO: this function assume store out vector size is the same as OAccTile last dimension size // how do we fix this ? - template - CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp, - const OAccTile& o_acc_tile, - const DsDramWindows& ds_dram_windows, - void* = nullptr) + template + CK_TILE_DEVICE auto + operator()(ODramWindowTmp& o_dram_window_tmp, const OAccTile& o_acc_tile, void* = nullptr) const { - const auto storeOrUpdateTile = [&](const auto& o_tile) { - // TODO: this is ugly - if constexpr(UseRawStore && (kPadM || kPadN)) + // TODO: this is ugly + if constexpr(UseRawStore && (kPadM || kPadN)) + { + if constexpr(MemoryOperation == memory_operation_enum::set) { - if constexpr(MemoryOperation == memory_operation_enum::set) - { - store_tile_raw(o_dram_window_tmp, cast_tile(o_tile)); - } - else - { - update_tile_raw(o_dram_window_tmp, cast_tile(o_tile)); - } - buffer_store_fence(); + store_tile_raw(o_dram_window_tmp, cast_tile(o_acc_tile)); } else { - if constexpr(MemoryOperation == memory_operation_enum::set) - { - store_tile(o_dram_window_tmp, cast_tile(o_tile)); - } - else - { - update_tile(o_dram_window_tmp, cast_tile(o_tile)); - } + update_tile_raw(o_dram_window_tmp, cast_tile(o_acc_tile)); } - }; - - if constexpr(Problem::NumDTensor >= 1) - { - using elementwise_result_t = decltype(load_tile( - make_tile_window(ds_dram_windows[number<0>{}].get_bottom_tensor_view(), - make_tuple(Problem::kMPerBlock, Problem::kNPerBlock), - ds_dram_windows[number<0>{}].get_window_origin(), - o_acc_tile.get_tile_distribution()))); - - elementwise_result_t elementwise_result; - - const auto d_tensor_tuple = generate_tuple( - [&](auto idx) { - const auto d_tile_window = - make_tile_window(ds_dram_windows[idx], o_acc_tile.get_tile_distribution()); - return load_tile(d_tile_window); - }, - number{}); - - const auto c_d_tuple = concat_tuple_of_reference( - tie(elementwise_result, o_acc_tile), - generate_tie([&](auto idx) -> const auto& { return d_tensor_tuple[idx]; }, - number{})); - - tile_elementwise_inout_unpack(typename Problem::CDElementwise{}, c_d_tuple); - - storeOrUpdateTile(elementwise_result); + buffer_store_fence(); } else { - storeOrUpdateTile(o_acc_tile); + if constexpr(MemoryOperation == memory_operation_enum::set) + { + store_tile(o_dram_window_tmp, cast_tile(o_acc_tile)); + } + else + { + update_tile(o_dram_window_tmp, cast_tile(o_acc_tile)); + } } } + + template + CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp, + const OAccTile& o_acc_tile, + const DsDramWindows& /* unused */, + void* = nullptr) const + { + return operator()(o_dram_window_tmp, o_acc_tile); + } }; template @@ -165,9 +122,8 @@ struct DefaultGemm2DEpilogue : public Default2DEpilogue // Used for weight-only quantization kernel, B would be dequantized to the same data type as A using BTypeToUse = std::conditional_t, ADataType, BDataType>; - using DsDataType = remove_cvref_t; - using DsLayout = remove_cvref_t; - using CDElementwise = remove_cvref_t; + using DsDataType = ck_tile::tuple<>; + using DsLayout = ck_tile::tuple<>; using CLayout = remove_cvref_t; static constexpr index_t kMPerXdl = Problem::kMPerXdl; static constexpr index_t kNPerXdl = Problem::kNPerXdl; @@ -236,11 +192,7 @@ struct DefaultGemm2DEpilogue : public Default2DEpilogue } } - template - CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeD([[maybe_unused]] number index) - { - return GetVectorSizeC(); - } + CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeD() { return 1; } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp b/include/ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp index 9d3ac8b901..34c4e72b22 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp @@ -175,12 +175,6 @@ struct GemmKernelMultiD CK_TILE_HOST static auto IsSupportedArgument(const typename UniversalGemmKernel::KernelArgs& kargs) -> bool { - // Currently MultiD kernel doesn't support k_batch > 1 - if(kargs.k_batch > 1) - { - return false; - } - return UniversalGemmKernel::IsSupportedArgument(kargs); } diff --git a/test/ck_tile/gemm_multi_d/CMakeLists.txt b/test/ck_tile/gemm_multi_d/CMakeLists.txt index c9d53e53e2..a50de7178b 100644 --- a/test/ck_tile/gemm_multi_d/CMakeLists.txt +++ b/test/ck_tile/gemm_multi_d/CMakeLists.txt @@ -5,8 +5,6 @@ if(CK_USE_OCP_FP8) endif() if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95") - add_gtest_executable(test_gemm_multi_d_cshuffle test_gemm_multi_d_cshuffle.cpp) - add_gtest_executable(test_gemm_multi_d_default2d test_gemm_multi_d_default2d.cpp) - target_compile_definitions(test_gemm_multi_d_cshuffle PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) - target_compile_definitions(test_gemm_multi_d_default2d PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + add_gtest_executable(test_ck_tile_gemm_multi_d test_gemm_multi_d.cpp) + target_compile_definitions(test_ck_tile_gemm_multi_d PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) endif() diff --git a/test/ck_tile/gemm_multi_d/test_gemm_multi_d_cshuffle.cpp b/test/ck_tile/gemm_multi_d/test_gemm_multi_d.cpp similarity index 75% rename from test/ck_tile/gemm_multi_d/test_gemm_multi_d_cshuffle.cpp rename to test/ck_tile/gemm_multi_d/test_gemm_multi_d.cpp index 8ac847e888..a634d825b7 100644 --- a/test/ck_tile/gemm_multi_d/test_gemm_multi_d_cshuffle.cpp +++ b/test/ck_tile/gemm_multi_d/test_gemm_multi_d.cpp @@ -18,23 +18,22 @@ using Col = ck_tile::tensor_layout::gemm::ColumnMajor; // clang-format off using KernelTypes = ::testing::Types< - // Has cshuffle epilogue enabled - // ALayout, BLayout, CLayout, D0Layout, D1Layout, ADataType, BDataType, D0DataType, D1DataType, AccDataType, EDataType, CDElementWiseFn, UseCshuffleEpilog - std::tuple< Row, Col, Row, Row, Row, F16, F16, BF16, BF16, F32, F16, ElementWiseAddAdd, std::true_type>, - std::tuple< Row, Col, Row, Row, Row, F16, F16, F32, F32, F32, F16, ElementWiseAddAdd, std::true_type>, - std::tuple< Row, Col, Row, Row, Row, F16, F16, F32, F32, F32, F16, ElementWiseAddAdd, std::true_type>, - std::tuple< Row, Col, Row, Row, Row, F8, F8, BF16, BF16, F32, F32, ElementWiseAddAdd, std::true_type>, - std::tuple< Row, Col, Row, Row, Row, F8, F8, F8, F8, F32, F16, ElementWiseAddAdd, std::true_type>, + // ALayout, BLayout, CLayout, D0Layout, D1Layout, ADataType, BDataType, D0DataType, D1DataType, AccDataType, CDataType, CDElementWiseFn + std::tuple< Row, Col, Row, Row, Row, F16, F16, BF16, BF16, F32, F16, ElementWiseAddAdd>, + std::tuple< Row, Col, Row, Row, Row, F16, F16, F32, F32, F32, F16, ElementWiseAddAdd>, + std::tuple< Row, Col, Row, Row, Row, F16, F16, F32, F32, F32, F16, ElementWiseAddAdd>, + std::tuple< Row, Col, Row, Row, Row, F8, F8, BF16, BF16, F32, F32, ElementWiseAddAdd>, + std::tuple< Row, Col, Row, Row, Row, F8, F8, F8, F8, F32, F16, ElementWiseAddAdd>, - std::tuple< Row, Col, Row, Row, Row, F16, F16, F16, F16, F32, F16, MultiplyMultiply, std::true_type>, - std::tuple< Row, Col, Row, Row, Row, F16, F16, BF16, BF16, F32, F32, MultiplyMultiply, std::true_type>, - std::tuple< Row, Col, Row, Row, Row, F16, F16, F32, F32, F32, F32, MultiplyMultiply, std::true_type>, - std::tuple< Row, Col, Row, Row, Row, F16, F16, F32, F32, F32, F16, MultiplyMultiply, std::true_type>, - std::tuple< Row, Col, Row, Row, Row, F8, F8, BF16, BF16, F32, F32, MultiplyMultiply, std::true_type>, - std::tuple< Row, Col, Row, Row, Row, F8, F8, F8, F8, F32, F32, MultiplyMultiply, std::true_type> + std::tuple< Row, Col, Row, Row, Row, F16, F16, F16, F16, F32, F16, MultiplyMultiply>, + std::tuple< Row, Col, Row, Row, Row, F16, F16, BF16, BF16, F32, F32, MultiplyMultiply>, + std::tuple< Row, Col, Row, Row, Row, F16, F16, F32, F32, F32, F32, MultiplyMultiply>, + std::tuple< Row, Col, Row, Row, Row, F16, F16, F32, F32, F32, F16, MultiplyMultiply>, + std::tuple< Row, Col, Row, Row, Row, F8, F8, BF16, BF16, F32, F32, MultiplyMultiply>, + std::tuple< Row, Col, Row, Row, Row, F8, F8, F8, F8, F32, F32, MultiplyMultiply> >; // clang-format on TYPED_TEST_SUITE(TestCkTileGemmMultiD, KernelTypes); -#include "test_gemm_multi_d_ut_cases_cshuffle.inc" +#include "test_gemm_multi_d_ut_cases.inc" diff --git a/test/ck_tile/gemm_multi_d/test_gemm_multi_d_default2d.cpp b/test/ck_tile/gemm_multi_d/test_gemm_multi_d_default2d.cpp deleted file mode 100644 index 4f14cc49f9..0000000000 --- a/test/ck_tile/gemm_multi_d/test_gemm_multi_d_default2d.cpp +++ /dev/null @@ -1,43 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#include "gtest/gtest.h" - -#include "ck_tile/host.hpp" -#include "test_gemm_multi_d_util.hpp" - -using F16 = ck_tile::half_t; -using BF16 = ck_tile::bf16_t; -using F32 = float; -using F8 = ck_tile::fp8_t; - -using Row = ck_tile::tensor_layout::gemm::RowMajor; -using Col = ck_tile::tensor_layout::gemm::ColumnMajor; - -// clang-format off -using KernelTypes = ::testing::Types< - // Has cshuffle epilogue disabled - // ALayout, BLayout, CLayout, D0Layout, D1Layout, ADataType, BDataType, D0DataType, D1DataType, AccDataType, EDataType, CDElementWiseFn, UseCshuffleEpilog - std::tuple< Row, Col, Row, Row, Row, F16, F16, F32, F32, F32, F16, ElementWiseAddAdd, std::false_type>, - std::tuple< Row, Col, Row, Row, Row, F16, F16, F32, F32, F32, F16, ElementWiseAddAdd, std::false_type>, - std::tuple< Row, Col, Row, Row, Row, F8, F8, BF16, BF16, F32, F32, ElementWiseAddAdd, std::false_type>, - std::tuple< Row, Col, Row, Row, Row, F16, F16, F32, F32, F32, F32, ElementWiseAddAdd, std::false_type>, - std::tuple< Row, Col, Row, Row, Row, F16, F16, BF16, BF16, F32, BF16, ElementWiseAddAdd, std::false_type>, - std::tuple< Row, Col, Row, Row, Row, F8, F8, BF16, BF16, F32, BF16, ElementWiseAddAdd, std::false_type>, - std::tuple< Row, Col, Row, Row, Row, F8, F8, F16, F16, F32, F16, ElementWiseAddAdd, std::false_type>, - - std::tuple< Row, Col, Row, Row, Row, F16, F16, F16, F16, F32, F16, MultiplyMultiply, std::false_type>, - std::tuple< Row, Col, Row, Row, Row, F16, F16, F32, F32, F32, F16, MultiplyMultiply, std::false_type>, - std::tuple< Row, Col, Row, Row, Row, F8, F8, BF16, BF16, F32, F32, MultiplyMultiply, std::false_type>, - std::tuple< Row, Col, Row, Row, Row, F16, F16, F32, F32, F32, F32, MultiplyMultiply, std::false_type>, - std::tuple< Row, Col, Row, Row, Row, F16, F16, BF16, BF16, F32, BF16, MultiplyMultiply, std::false_type>, - std::tuple< Row, Col, Row, Row, Row, F8, F8, BF16, BF16, F32, BF16, MultiplyMultiply, std::false_type>, - std::tuple< Row, Col, Row, Row, Row, F8, F8, F16, F16, F32, F16, MultiplyMultiply, std::false_type> - >; -// clang-format on - -TYPED_TEST_SUITE(TestCkTileGemmMultiD, KernelTypes); - -#include "test_gemm_multi_d_ut_cases_default2d.inc" diff --git a/test/ck_tile/gemm_multi_d/test_gemm_multi_d_ut_cases.inc b/test/ck_tile/gemm_multi_d/test_gemm_multi_d_ut_cases.inc new file mode 100644 index 0000000000..22d887fa83 --- /dev/null +++ b/test/ck_tile/gemm_multi_d/test_gemm_multi_d_ut_cases.inc @@ -0,0 +1,334 @@ +#pragma once + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch1_256x512x256) +{ + constexpr int M = 256; + constexpr int N = 512; + constexpr int K = 256; + constexpr int kBatch = 1; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch1_512x256x256) +{ + constexpr int M = 512; + constexpr int N = 256; + constexpr int K = 256; + constexpr int kBatch = 1; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch1_512x512x256) +{ + constexpr int M = 512; + constexpr int N = 512; + constexpr int K = 256; + constexpr int kBatch = 1; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch1_256x256x256) +{ + constexpr int M = 256; + constexpr int N = 256; + constexpr int K = 256; + constexpr int kBatch = 1; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch1_512x768x256) +{ + constexpr int M = 512; + constexpr int N = 768; + constexpr int K = 256; + constexpr int kBatch = 1; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch1_512x1280x256) +{ + constexpr int M = 512; + constexpr int N = 1280; + constexpr int K = 256; + constexpr int kBatch = 1; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch1_256x1280x256) +{ + constexpr int M = 256; + constexpr int N = 1280; + constexpr int K = 256; + constexpr int kBatch = 1; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch1_768x512x256) +{ + constexpr int M = 768; + constexpr int N = 512; + constexpr int K = 256; + constexpr int kBatch = 1; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch1_1280x512x256) +{ + constexpr int M = 1280; + constexpr int N = 512; + constexpr int K = 256; + constexpr int kBatch = 1; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch1_1280x256x256) +{ + constexpr int M = 1280; + constexpr int N = 256; + constexpr int K = 256; + constexpr int kBatch = 1; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch1_256x512x256) +{ + constexpr int M = 256; + constexpr int N = 512; + constexpr int K = 256; + constexpr int kBatch = 1; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch1_512x256x256) +{ + constexpr int M = 512; + constexpr int N = 256; + constexpr int K = 256; + constexpr int kBatch = 1; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch1_512x512x256) +{ + constexpr int M = 512; + constexpr int N = 512; + constexpr int K = 256; + constexpr int kBatch = 1; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch1_256x256x256) +{ + constexpr int M = 256; + constexpr int N = 256; + constexpr int K = 256; + constexpr int kBatch = 1; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch1_512x768x256) +{ + constexpr int M = 512; + constexpr int N = 768; + constexpr int K = 256; + constexpr int kBatch = 1; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch1_512x1280x256) +{ + constexpr int M = 512; + constexpr int N = 1280; + constexpr int K = 256; + constexpr int kBatch = 1; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch1_256x1280x256) +{ + constexpr int M = 256; + constexpr int N = 1280; + constexpr int K = 256; + constexpr int kBatch = 1; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch1_768x512x256) +{ + constexpr int M = 768; + constexpr int N = 512; + constexpr int K = 256; + constexpr int kBatch = 1; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch1_1280x512x256) +{ + constexpr int M = 1280; + constexpr int N = 512; + constexpr int K = 256; + constexpr int kBatch = 1; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch1_1280x256x256) +{ + constexpr int M = 1280; + constexpr int N = 256; + constexpr int K = 256; + constexpr int kBatch = 1; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch2_256x256x512) +{ + constexpr int M = 256; + constexpr int N = 256; + constexpr int K = 512; + constexpr int kBatch = 2; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch2_512x768x512) +{ + constexpr int M = 512; + constexpr int N = 768; + constexpr int K = 512; + constexpr int kBatch = 2; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch2_512x1280x512) +{ + constexpr int M = 512; + constexpr int N = 1280; + constexpr int K = 512; + constexpr int kBatch = 2; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch2_256x1280x512) +{ + constexpr int M = 256; + constexpr int N = 1280; + constexpr int K = 512; + constexpr int kBatch = 2; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch2_768x512x512) +{ + constexpr int M = 768; + constexpr int N = 512; + constexpr int K = 512; + constexpr int kBatch = 2; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch2_1280x512x512) +{ + constexpr int M = 1280; + constexpr int N = 512; + constexpr int K = 512; + constexpr int kBatch = 2; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch2_1280x256x512) +{ + constexpr int M = 1280; + constexpr int N = 256; + constexpr int K = 512; + constexpr int kBatch = 2; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch2_256x512x512) +{ + constexpr int M = 256; + constexpr int N = 512; + constexpr int K = 512; + constexpr int kBatch = 2; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch2_512x256x512) +{ + constexpr int M = 512; + constexpr int N = 256; + constexpr int K = 512; + constexpr int kBatch = 2; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch2_512x512x512) +{ + constexpr int M = 512; + constexpr int N = 512; + constexpr int K = 512; + constexpr int kBatch = 2; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch2_256x256x512) +{ + constexpr int M = 256; + constexpr int N = 256; + constexpr int K = 512; + constexpr int kBatch = 2; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch2_512x768x512) +{ + constexpr int M = 512; + constexpr int N = 768; + constexpr int K = 512; + constexpr int kBatch = 2; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch2_512x1280x512) +{ + constexpr int M = 512; + constexpr int N = 1280; + constexpr int K = 512; + constexpr int kBatch = 2; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch2_256x1280x512) +{ + constexpr int M = 256; + constexpr int N = 1280; + constexpr int K = 512; + constexpr int kBatch = 2; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch2_768x512x512) +{ + constexpr int M = 768; + constexpr int N = 512; + constexpr int K = 512; + constexpr int kBatch = 2; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch2_1280x512x512) +{ + constexpr int M = 1280; + constexpr int N = 512; + constexpr int K = 512; + constexpr int kBatch = 2; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch2_1280x256x512) +{ + constexpr int M = 1280; + constexpr int N = 256; + constexpr int K = 512; + constexpr int kBatch = 2; + this->Run(M, N, K, kBatch); +} diff --git a/test/ck_tile/gemm_multi_d/test_gemm_multi_d_ut_cases_cshuffle.inc b/test/ck_tile/gemm_multi_d/test_gemm_multi_d_ut_cases_cshuffle.inc deleted file mode 100644 index 8d21c65692..0000000000 --- a/test/ck_tile/gemm_multi_d/test_gemm_multi_d_ut_cases_cshuffle.inc +++ /dev/null @@ -1,211 +0,0 @@ -#pragma once - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1CShuffle_256x512x256) -{ - constexpr int M = 256; - constexpr int N = 512; - constexpr int K = 256; - constexpr int kBatch = 1; - - EXPECT_EQ(this->Run(M, N, K, kBatch), true); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1CShuffle_512x256x256) -{ - constexpr int M = 512; - constexpr int N = 256; - constexpr int K = 256; - constexpr int kBatch = 1; - - EXPECT_EQ(this->Run(M, N, K, kBatch), true); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1CShuffle_512x512x256) -{ - constexpr int M = 512; - constexpr int N = 512; - constexpr int K = 256; - constexpr int kBatch = 1; - - EXPECT_EQ(this->Run(M, N, K, kBatch), true); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1CShuffle_256x256x256) -{ - constexpr int M = 256; - constexpr int N = 256; - constexpr int K = 256; - constexpr int kBatch = 1; - - EXPECT_EQ(this->Run(M, N, K, kBatch), true); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1CShuffle_512x768x256) -{ - constexpr int M = 512; - constexpr int N = 768; - constexpr int K = 256; - constexpr int kBatch = 1; - - EXPECT_EQ(this->Run(M, N, K, kBatch), true); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1CShuffle_512x1280x256) -{ - constexpr int M = 512; - constexpr int N = 1280; - constexpr int K = 256; - constexpr int kBatch = 1; - - EXPECT_EQ(this->Run(M, N, K, kBatch), true); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1CShuffle_256x1280x256) -{ - constexpr int M = 256; - constexpr int N = 1280; - constexpr int K = 256; - constexpr int kBatch = 1; - - EXPECT_EQ(this->Run(M, N, K, kBatch), true); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1CShuffle_768x512x256) -{ - constexpr int M = 768; - constexpr int N = 512; - constexpr int K = 256; - constexpr int kBatch = 1; - - EXPECT_EQ(this->Run(M, N, K, kBatch), true); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1CShuffle_1280x512x256) -{ - constexpr int M = 1280; - constexpr int N = 512; - constexpr int K = 256; - constexpr int kBatch = 1; - - EXPECT_EQ(this->Run(M, N, K, kBatch), true); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1CShuffle_1280x256x256) -{ - constexpr int M = 1280; - constexpr int N = 256; - constexpr int K = 256; - constexpr int kBatch = 1; - - EXPECT_EQ(this->Run(M, N, K, kBatch), true); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2CShuffle_512x512x512) -{ - constexpr int M = 512; - constexpr int N = 512; - constexpr int K = 512; - constexpr int kBatch = 2; - - EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2CShuffle_256x512x256) -{ - constexpr int M = 256; - constexpr int N = 512; - constexpr int K = 512; - constexpr int kBatch = 2; - - EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2CShuffle_512x256x256) -{ - constexpr int M = 512; - constexpr int N = 256; - constexpr int K = 512; - constexpr int kBatch = 2; - - EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2CShuffle_512x512x256) -{ - constexpr int M = 512; - constexpr int N = 512; - constexpr int K = 512; - constexpr int kBatch = 2; - - EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2CShuffle_256x256x256) -{ - constexpr int M = 256; - constexpr int N = 256; - constexpr int K = 512; - constexpr int kBatch = 2; - - EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2CShuffle_512x768x256) -{ - constexpr int M = 512; - constexpr int N = 768; - constexpr int K = 512; - constexpr int kBatch = 2; - - EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2CShuffle_512x1280x256) -{ - constexpr int M = 512; - constexpr int N = 1280; - constexpr int K = 512; - constexpr int kBatch = 2; - - EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2CShuffle_256x1280x256) -{ - constexpr int M = 256; - constexpr int N = 1280; - constexpr int K = 512; - constexpr int kBatch = 2; - - EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2CShuffle_768x512x256) -{ - constexpr int M = 768; - constexpr int N = 512; - constexpr int K = 512; - constexpr int kBatch = 2; - - EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2CShuffle_1280x512x256) -{ - constexpr int M = 1280; - constexpr int N = 512; - constexpr int K = 512; - constexpr int kBatch = 2; - - EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2CShuffle_1280x256x256) -{ - constexpr int M = 1280; - constexpr int N = 256; - constexpr int K = 512; - constexpr int kBatch = 2; - - EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); -} diff --git a/test/ck_tile/gemm_multi_d/test_gemm_multi_d_ut_cases_default2d.inc b/test/ck_tile/gemm_multi_d/test_gemm_multi_d_ut_cases_default2d.inc deleted file mode 100644 index 35b40a896a..0000000000 --- a/test/ck_tile/gemm_multi_d/test_gemm_multi_d_ut_cases_default2d.inc +++ /dev/null @@ -1,211 +0,0 @@ -#pragma once - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1Default_256x512x256) -{ - constexpr int M = 256; - constexpr int N = 512; - constexpr int K = 256; - constexpr int kBatch = 1; - - EXPECT_EQ(this->Run(M, N, K, kBatch), true); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1Default_512x256x256) -{ - constexpr int M = 512; - constexpr int N = 256; - constexpr int K = 256; - constexpr int kBatch = 1; - - EXPECT_EQ(this->Run(M, N, K, kBatch), true); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1Default_512x512x256) -{ - constexpr int M = 512; - constexpr int N = 512; - constexpr int K = 256; - constexpr int kBatch = 1; - - EXPECT_EQ(this->Run(M, N, K, kBatch), true); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1Default_256x256x256) -{ - constexpr int M = 256; - constexpr int N = 256; - constexpr int K = 256; - constexpr int kBatch = 1; - - EXPECT_EQ(this->Run(M, N, K, kBatch), true); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1Default_512x768x256) -{ - constexpr int M = 512; - constexpr int N = 768; - constexpr int K = 256; - constexpr int kBatch = 1; - - EXPECT_EQ(this->Run(M, N, K, kBatch), true); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1Default_512x1280x256) -{ - constexpr int M = 512; - constexpr int N = 1280; - constexpr int K = 256; - constexpr int kBatch = 1; - - EXPECT_EQ(this->Run(M, N, K, kBatch), true); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1Default_256x1280x256) -{ - constexpr int M = 256; - constexpr int N = 1280; - constexpr int K = 256; - constexpr int kBatch = 1; - - EXPECT_EQ(this->Run(M, N, K, kBatch), true); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1Default_768x512x256) -{ - constexpr int M = 768; - constexpr int N = 512; - constexpr int K = 256; - constexpr int kBatch = 1; - - EXPECT_EQ(this->Run(M, N, K, kBatch), true); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1Default_1280x512x256) -{ - constexpr int M = 1280; - constexpr int N = 512; - constexpr int K = 256; - constexpr int kBatch = 1; - - EXPECT_EQ(this->Run(M, N, K, kBatch), true); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1Default_1280x256x256) -{ - constexpr int M = 1280; - constexpr int N = 256; - constexpr int K = 256; - constexpr int kBatch = 1; - - EXPECT_EQ(this->Run(M, N, K, kBatch), true); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2Default_512x512x512) -{ - constexpr int M = 512; - constexpr int N = 512; - constexpr int K = 512; - constexpr int kBatch = 2; - - EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2Default_256x512x256) -{ - constexpr int M = 256; - constexpr int N = 512; - constexpr int K = 512; - constexpr int kBatch = 2; - - EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2Default_512x256x256) -{ - constexpr int M = 512; - constexpr int N = 256; - constexpr int K = 512; - constexpr int kBatch = 2; - - EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2Default_512x512x256) -{ - constexpr int M = 512; - constexpr int N = 512; - constexpr int K = 512; - constexpr int kBatch = 2; - - EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2Default_256x256x256) -{ - constexpr int M = 256; - constexpr int N = 256; - constexpr int K = 512; - constexpr int kBatch = 2; - - EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2Default_512x768x256) -{ - constexpr int M = 512; - constexpr int N = 768; - constexpr int K = 512; - constexpr int kBatch = 2; - - EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2Default_512x1280x256) -{ - constexpr int M = 512; - constexpr int N = 1280; - constexpr int K = 512; - constexpr int kBatch = 2; - - EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2Default_256x1280x256) -{ - constexpr int M = 256; - constexpr int N = 1280; - constexpr int K = 512; - constexpr int kBatch = 2; - - EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2Default_768x512x256) -{ - constexpr int M = 768; - constexpr int N = 512; - constexpr int K = 512; - constexpr int kBatch = 2; - - EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2Default_1280x512x256) -{ - constexpr int M = 1280; - constexpr int N = 512; - constexpr int K = 512; - constexpr int kBatch = 2; - - EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2Default_1280x256x256) -{ - constexpr int M = 1280; - constexpr int N = 256; - constexpr int K = 512; - constexpr int kBatch = 2; - - EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); -} diff --git a/test/ck_tile/gemm_multi_d/test_gemm_multi_d_util.hpp b/test/ck_tile/gemm_multi_d/test_gemm_multi_d_util.hpp index 8399bc7ee3..d21777c92b 100644 --- a/test/ck_tile/gemm_multi_d/test_gemm_multi_d_util.hpp +++ b/test/ck_tile/gemm_multi_d/test_gemm_multi_d_util.hpp @@ -70,21 +70,20 @@ template class TestCkTileGemmMultiD : public ::testing::Test { protected: - using ALayout = std::tuple_element_t<0, Tuple>; - using BLayout = std::tuple_element_t<1, Tuple>; - using D0Layout = std::tuple_element_t<2, Tuple>; - using D1Layout = std::tuple_element_t<3, Tuple>; - using ELayout = std::tuple_element_t<4, Tuple>; - using ADataType = std::tuple_element_t<5, Tuple>; - using BDataType = std::tuple_element_t<6, Tuple>; - using D0DataType = std::tuple_element_t<7, Tuple>; - using D1DataType = std::tuple_element_t<8, Tuple>; - using AccDataType = std::tuple_element_t<9, Tuple>; - using EDataType = std::tuple_element_t<10, Tuple>; - using CDElementWiseFn = std::tuple_element_t<11, Tuple>; - using UseCshuffleEpilog = std::tuple_element_t<12, Tuple>; - using DsLayout = ck_tile::tuple; - using DsDataType = ck_tile::tuple; + using ALayout = std::tuple_element_t<0, Tuple>; + using BLayout = std::tuple_element_t<1, Tuple>; + using D0Layout = std::tuple_element_t<2, Tuple>; + using D1Layout = std::tuple_element_t<3, Tuple>; + using ELayout = std::tuple_element_t<4, Tuple>; + using ADataType = std::tuple_element_t<5, Tuple>; + using BDataType = std::tuple_element_t<6, Tuple>; + using D0DataType = std::tuple_element_t<7, Tuple>; + using D1DataType = std::tuple_element_t<8, Tuple>; + using AccDataType = std::tuple_element_t<9, Tuple>; + using EDataType = std::tuple_element_t<10, Tuple>; + using CDElementWiseFn = std::tuple_element_t<11, Tuple>; + using DsLayout = ck_tile::tuple; + using DsDataType = ck_tile::tuple; template ; using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; - - using DefaultGemmEpilogue = ck_tile::DefaultGemm2DEpilogue< - ck_tile::DefaultGemm2DEpilogueProblem>; - - using CShuffleGemmEpilogue = ck_tile::CShuffleEpilogue< + using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem>; - using GemmEpilogue = std:: - conditional_t; - using Kernel = ck_tile::GemmKernelMultiD; auto kargs = Kernel::MakeKernelArgs(args); @@ -243,7 +218,6 @@ class TestCkTileGemmMultiD : public ::testing::Test const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { if(args.k_batch == 1) { - std::cout << "Run without SplitK" << std::endl; Run(has_hot_loop_, tail_number_, ck_tile::integral_constant{}); } }; - - BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + if(has_hot_loop) + { + if(tail_num == ck_tile::TailNumber::Full) + { + RunSplitk( + ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else + { + std::ostringstream err; + err << "For compute pipeline tail number should always be Full, but have \"" + << tail_num << "\" which is not supported! PrefetchStages: " + << BaseGemmPipeline::PrefetchStages << "\n File: " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + } + else + { + std::ostringstream err; + err << "Num K loop must be larger than number of prefetech stages." + << "\n PrefetchStages: " << BaseGemmPipeline::PrefetchStages + << "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } } public: - bool Run(const int M, + void Run(const int M, const int N, const int K, const int k_batch, @@ -404,6 +401,6 @@ class TestCkTileGemmMultiD : public ::testing::Test << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) << std::endl; - return pass; + EXPECT_TRUE(pass); } }; diff --git a/tile_engine/ops/gemm/codegen_utils.py b/tile_engine/ops/gemm/codegen_utils.py index 392125aa0b..dd9de36865 100644 --- a/tile_engine/ops/gemm/codegen_utils.py +++ b/tile_engine/ops/gemm/codegen_utils.py @@ -31,14 +31,9 @@ DEFAULT_EPILOGUE = """ using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue< ck_tile::DefaultGemm2DEpilogueProblem, AccDataType, CDataType, - ck_tile::tuple<>, CLayout, - ck_tile::element_wise::PassThrough, - TilePartitioner::MPerBlock, - TilePartitioner::NPerBlock, kPadM, kPadN, WarpTileM,