diff --git a/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp b/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp index 8a0970f494..2e907c2fa8 100644 --- a/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp @@ -25,13 +25,19 @@ struct Default2DEpilogueProblem static constexpr bool kPadN = kPadN_; static constexpr bool UseRawStore = UseRawStore_; static constexpr memory_operation_enum MemoryOperation = MemoryOperation_; + static constexpr index_t NumDTensor = 0; }; template ; using BDataType = remove_cvref_t; using CLayout = remove_cvref_t; + using DsDataType = remove_cvref_t; + using CDElementwise = remove_cvref_t; + using DsLayout = remove_cvref_t; + static constexpr index_t kMPerBlock = kM_; + static constexpr index_t kNPerBlock = kN_; static constexpr index_t kMPerXdl = kMPerXdl_; static constexpr index_t kNPerXdl = kNPerXdl_; static constexpr index_t kKPerXdl = kKPerXdl_; static constexpr index_t isCTransposed = isCTransposed_; + + static constexpr index_t NumDTensor = DsDataType::size(); + + static_assert(NumDTensor == DsLayout::size(), + "The size of DsDataType and DsLayout should be the same"); }; template @@ -71,43 +87,70 @@ struct Default2DEpilogue // TODO: this function assume store out vector size is the same as OAccTile last dimension size // how do we fix this ? - template - CK_TILE_DEVICE auto - operator()(ODramWindowTmp& o_dram_window_tmp, const OAccTile& o_acc_tile, void* = nullptr) const - { - // TODO: this is ugly - if constexpr(UseRawStore && (kPadM || kPadN)) - { - if constexpr(MemoryOperation == memory_operation_enum::set) - { - store_tile_raw(o_dram_window_tmp, cast_tile(o_acc_tile)); - } - else - { - update_tile_raw(o_dram_window_tmp, cast_tile(o_acc_tile)); - } - buffer_store_fence(); - } - else - { - if constexpr(MemoryOperation == memory_operation_enum::set) - { - store_tile(o_dram_window_tmp, cast_tile(o_acc_tile)); - } - else - { - update_tile(o_dram_window_tmp, cast_tile(o_acc_tile)); - } - } - } - template CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp, const OAccTile& o_acc_tile, - const DsDramWindows& /* unused */, - void* = nullptr) const + const DsDramWindows& ds_dram_windows, + void* = nullptr) { - return operator()(o_dram_window_tmp, o_acc_tile); + const auto storeOrUpdateTile = [&](const auto& o_tile) { + // TODO: this is ugly + if constexpr(UseRawStore && (kPadM || kPadN)) + { + if constexpr(MemoryOperation == memory_operation_enum::set) + { + store_tile_raw(o_dram_window_tmp, cast_tile(o_tile)); + } + else + { + update_tile_raw(o_dram_window_tmp, cast_tile(o_tile)); + } + buffer_store_fence(); + } + else + { + if constexpr(MemoryOperation == memory_operation_enum::set) + { + store_tile(o_dram_window_tmp, cast_tile(o_tile)); + } + else + { + update_tile(o_dram_window_tmp, cast_tile(o_tile)); + } + } + }; + + if constexpr(!std::is_same_v && Problem::NumDTensor >= 1) + { + using elementwise_result_t = decltype(load_tile( + make_tile_window(ds_dram_windows[number<0>{}].get_bottom_tensor_view(), + make_tuple(Problem::kMPerBlock, Problem::kNPerBlock), + ds_dram_windows[number<0>{}].get_window_origin(), + o_acc_tile.get_tile_distribution()))); + + elementwise_result_t elementwise_result; + + const auto d_tensor_tuple = generate_tuple( + [&](auto idx) { + const auto d_tile_window = + make_tile_window(ds_dram_windows[idx], o_acc_tile.get_tile_distribution()); + return load_tile(d_tile_window); + }, + number{}); + + const auto c_d_tuple = concat_tuple_of_reference( + tie(elementwise_result, o_acc_tile), + generate_tie([&](auto idx) -> const auto& { return d_tensor_tuple[idx]; }, + number{})); + + tile_elementwise_inout_unpack(typename Problem::CDElementwise{}, c_d_tuple); + + storeOrUpdateTile(elementwise_result); + } + else + { + storeOrUpdateTile(o_acc_tile); + } } }; @@ -122,8 +165,9 @@ struct DefaultGemm2DEpilogue : public Default2DEpilogue // Used for weight-only quantization kernel, B would be dequantized to the same data type as A using BTypeToUse = std::conditional_t, ADataType, BDataType>; - using DsDataType = ck_tile::tuple<>; - using DsLayout = ck_tile::tuple<>; + using DsDataType = remove_cvref_t; + using DsLayout = remove_cvref_t; + using CDElementwise = remove_cvref_t; using CLayout = remove_cvref_t; static constexpr index_t kMPerXdl = Problem::kMPerXdl; static constexpr index_t kNPerXdl = Problem::kNPerXdl; @@ -192,7 +236,11 @@ struct DefaultGemm2DEpilogue : public Default2DEpilogue } } - CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeD() { return 1; } + template + CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeD([[maybe_unused]] number index) + { + return GetVectorSizeC(); + } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp index 3f5bef366e..c1f85cb5e6 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp @@ -1134,8 +1134,8 @@ struct FmhaBwdDQDKDVKernel scale_rp_undrop, dropout); - KGradEpiloguePipeline{}(dk_dram_window, dk_acc_tile); - VGradEpiloguePipeline{}(dv_dram_window, dv_acc_tile); + KGradEpiloguePipeline{}(dk_dram_window, dk_acc_tile, nullptr); + VGradEpiloguePipeline{}(dv_dram_window, dv_acc_tile, nullptr); } else { diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 6d35afaa26..ddc5c5447f 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -1509,7 +1509,7 @@ struct FmhaFwdKernel make_tuple(number{}, number{}), {i_m0, i_n1}); - EpiloguePipeline{}(o_dram_window, o_acc_tile); + EpiloguePipeline{}(o_dram_window, o_acc_tile, nullptr); } else { @@ -2180,7 +2180,7 @@ struct FmhaFwdKernel make_tuple(number{}, number{}), {i_m0, i_n1}); - EpiloguePipeline{}(o_dram_window, o_acc_tile); + EpiloguePipeline{}(o_dram_window, o_acc_tile, nullptr); } } }; diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp index 9a3e8ac304..58ef6ba87e 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp @@ -1358,7 +1358,6 @@ struct FmhaFwdPagedKVKernel make_tuple(kargs.stride_o, 1), number{}, number<1>{}); - return pad_tensor_view( o_dram_naive, make_tuple(number{}, number{}), @@ -1370,7 +1369,7 @@ struct FmhaFwdPagedKVKernel make_tuple(number{}, number{}), {i_m0, i_n1}); - EpiloguePipeline{}(o_dram_window, o_acc_tile); + EpiloguePipeline{}(o_dram_window, o_acc_tile, nullptr); } }; diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp index ee1236d465..cf819c4b8d 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp @@ -484,7 +484,7 @@ struct FmhaFwdSplitKVCombineKernel make_tuple(number{}, number{}), {i_m0, i_n1}); - EpiloguePipeline{}(o_dram_window, o_acc_tile); + EpiloguePipeline{}(o_dram_window, o_acc_tile, nullptr); } }; diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp index c50537f3fe..9293c97a31 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp @@ -1134,7 +1134,7 @@ struct FmhaFwdSplitKVKernel make_tuple(number{}, number{}), {i_m0, i_n1}); - EpiloguePipeline{}(o_acc_dram_window, o_acc_tile); + EpiloguePipeline{}(o_acc_dram_window, o_acc_tile, nullptr); } }; diff --git a/include/ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp b/include/ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp index 34c4e72b22..9d3ac8b901 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp @@ -175,6 +175,12 @@ struct GemmKernelMultiD CK_TILE_HOST static auto IsSupportedArgument(const typename UniversalGemmKernel::KernelArgs& kargs) -> bool { + // Currently MultiD kernel doesn't support k_batch > 1 + if(kargs.k_batch > 1) + { + return false; + } + return UniversalGemmKernel::IsSupportedArgument(kargs); } diff --git a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp index 60e716e7e7..788d507bf5 100644 --- a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp +++ b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp @@ -193,7 +193,7 @@ struct Layernorm2dFwdPipelineOnePass Epilogue{}(y_window_, sm_scale_window_, y_scale_window, ln, smem); } else - Epilogue{}(y_window_, ln); + Epilogue{}(y_window_, ln, nullptr); } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp index 73cdd084c6..0de1ada87c 100644 --- a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp +++ b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp @@ -255,7 +255,7 @@ struct Layernorm2dFwdPipelineTwoPass }); static_assert(kFusedQuant != Layernorm2dFusedQuantEnum::DYNAMIC_QUANT); - Epilogue{}(y_window, ln); + Epilogue{}(y_window, ln, nullptr); move_tile_window(gamma_window, {-Block_N}); move_tile_window(beta_window, {-Block_N}); diff --git a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_model_sensitive_pass.hpp b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_model_sensitive_pass.hpp index 810c3c5243..c5923ba10d 100644 --- a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_model_sensitive_pass.hpp +++ b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_model_sensitive_pass.hpp @@ -221,7 +221,7 @@ struct Rmsnorm2dFwdPipelineModelSensitiveT5Pass } else { - Epilogue{}(y_window_, rmsn); + Epilogue{}(y_window_, rmsn, nullptr); } } }; diff --git a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp index c77d61872e..39d7c65d3e 100644 --- a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp +++ b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp @@ -160,7 +160,7 @@ struct Rmsnorm2dFwdPipelineOnePass } else { - Epilogue{}(y_window_, rmsn); + Epilogue{}(y_window_, rmsn, nullptr); } } }; diff --git a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp index 4ca1dbc5da..d01f37879a 100644 --- a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp +++ b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp @@ -195,7 +195,7 @@ struct Rmsnorm2dFwdPipelineTwoPass }); static_assert(kFusedQuant == Rmsnorm2dFusedQuantEnum::NO_SWEEP); - Epilogue{}(y_window, rmsn); + Epilogue{}(y_window, rmsn, nullptr); move_tile_window(gamma_window, {-Block_N}); move_tile_window(y_window, {0, -Block_N}); diff --git a/test/ck_tile/gemm_multi_d/CMakeLists.txt b/test/ck_tile/gemm_multi_d/CMakeLists.txt index a50de7178b..c9d53e53e2 100644 --- a/test/ck_tile/gemm_multi_d/CMakeLists.txt +++ b/test/ck_tile/gemm_multi_d/CMakeLists.txt @@ -5,6 +5,8 @@ if(CK_USE_OCP_FP8) endif() if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95") - add_gtest_executable(test_ck_tile_gemm_multi_d test_gemm_multi_d.cpp) - target_compile_definitions(test_ck_tile_gemm_multi_d PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + add_gtest_executable(test_gemm_multi_d_cshuffle test_gemm_multi_d_cshuffle.cpp) + add_gtest_executable(test_gemm_multi_d_default2d test_gemm_multi_d_default2d.cpp) + target_compile_definitions(test_gemm_multi_d_cshuffle PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + target_compile_definitions(test_gemm_multi_d_default2d PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) endif() diff --git a/test/ck_tile/gemm_multi_d/test_gemm_multi_d.cpp b/test/ck_tile/gemm_multi_d/test_gemm_multi_d_cshuffle.cpp similarity index 75% rename from test/ck_tile/gemm_multi_d/test_gemm_multi_d.cpp rename to test/ck_tile/gemm_multi_d/test_gemm_multi_d_cshuffle.cpp index a634d825b7..8ac847e888 100644 --- a/test/ck_tile/gemm_multi_d/test_gemm_multi_d.cpp +++ b/test/ck_tile/gemm_multi_d/test_gemm_multi_d_cshuffle.cpp @@ -18,22 +18,23 @@ using Col = ck_tile::tensor_layout::gemm::ColumnMajor; // clang-format off using KernelTypes = ::testing::Types< - // ALayout, BLayout, CLayout, D0Layout, D1Layout, ADataType, BDataType, D0DataType, D1DataType, AccDataType, CDataType, CDElementWiseFn - std::tuple< Row, Col, Row, Row, Row, F16, F16, BF16, BF16, F32, F16, ElementWiseAddAdd>, - std::tuple< Row, Col, Row, Row, Row, F16, F16, F32, F32, F32, F16, ElementWiseAddAdd>, - std::tuple< Row, Col, Row, Row, Row, F16, F16, F32, F32, F32, F16, ElementWiseAddAdd>, - std::tuple< Row, Col, Row, Row, Row, F8, F8, BF16, BF16, F32, F32, ElementWiseAddAdd>, - std::tuple< Row, Col, Row, Row, Row, F8, F8, F8, F8, F32, F16, ElementWiseAddAdd>, + // Has cshuffle epilogue enabled + // ALayout, BLayout, CLayout, D0Layout, D1Layout, ADataType, BDataType, D0DataType, D1DataType, AccDataType, EDataType, CDElementWiseFn, UseCshuffleEpilog + std::tuple< Row, Col, Row, Row, Row, F16, F16, BF16, BF16, F32, F16, ElementWiseAddAdd, std::true_type>, + std::tuple< Row, Col, Row, Row, Row, F16, F16, F32, F32, F32, F16, ElementWiseAddAdd, std::true_type>, + std::tuple< Row, Col, Row, Row, Row, F16, F16, F32, F32, F32, F16, ElementWiseAddAdd, std::true_type>, + std::tuple< Row, Col, Row, Row, Row, F8, F8, BF16, BF16, F32, F32, ElementWiseAddAdd, std::true_type>, + std::tuple< Row, Col, Row, Row, Row, F8, F8, F8, F8, F32, F16, ElementWiseAddAdd, std::true_type>, - std::tuple< Row, Col, Row, Row, Row, F16, F16, F16, F16, F32, F16, MultiplyMultiply>, - std::tuple< Row, Col, Row, Row, Row, F16, F16, BF16, BF16, F32, F32, MultiplyMultiply>, - std::tuple< Row, Col, Row, Row, Row, F16, F16, F32, F32, F32, F32, MultiplyMultiply>, - std::tuple< Row, Col, Row, Row, Row, F16, F16, F32, F32, F32, F16, MultiplyMultiply>, - std::tuple< Row, Col, Row, Row, Row, F8, F8, BF16, BF16, F32, F32, MultiplyMultiply>, - std::tuple< Row, Col, Row, Row, Row, F8, F8, F8, F8, F32, F32, MultiplyMultiply> + std::tuple< Row, Col, Row, Row, Row, F16, F16, F16, F16, F32, F16, MultiplyMultiply, std::true_type>, + std::tuple< Row, Col, Row, Row, Row, F16, F16, BF16, BF16, F32, F32, MultiplyMultiply, std::true_type>, + std::tuple< Row, Col, Row, Row, Row, F16, F16, F32, F32, F32, F32, MultiplyMultiply, std::true_type>, + std::tuple< Row, Col, Row, Row, Row, F16, F16, F32, F32, F32, F16, MultiplyMultiply, std::true_type>, + std::tuple< Row, Col, Row, Row, Row, F8, F8, BF16, BF16, F32, F32, MultiplyMultiply, std::true_type>, + std::tuple< Row, Col, Row, Row, Row, F8, F8, F8, F8, F32, F32, MultiplyMultiply, std::true_type> >; // clang-format on TYPED_TEST_SUITE(TestCkTileGemmMultiD, KernelTypes); -#include "test_gemm_multi_d_ut_cases.inc" +#include "test_gemm_multi_d_ut_cases_cshuffle.inc" diff --git a/test/ck_tile/gemm_multi_d/test_gemm_multi_d_default2d.cpp b/test/ck_tile/gemm_multi_d/test_gemm_multi_d_default2d.cpp new file mode 100644 index 0000000000..4f14cc49f9 --- /dev/null +++ b/test/ck_tile/gemm_multi_d/test_gemm_multi_d_default2d.cpp @@ -0,0 +1,43 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "gtest/gtest.h" + +#include "ck_tile/host.hpp" +#include "test_gemm_multi_d_util.hpp" + +using F16 = ck_tile::half_t; +using BF16 = ck_tile::bf16_t; +using F32 = float; +using F8 = ck_tile::fp8_t; + +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + +// clang-format off +using KernelTypes = ::testing::Types< + // Has cshuffle epilogue disabled + // ALayout, BLayout, CLayout, D0Layout, D1Layout, ADataType, BDataType, D0DataType, D1DataType, AccDataType, EDataType, CDElementWiseFn, UseCshuffleEpilog + std::tuple< Row, Col, Row, Row, Row, F16, F16, F32, F32, F32, F16, ElementWiseAddAdd, std::false_type>, + std::tuple< Row, Col, Row, Row, Row, F16, F16, F32, F32, F32, F16, ElementWiseAddAdd, std::false_type>, + std::tuple< Row, Col, Row, Row, Row, F8, F8, BF16, BF16, F32, F32, ElementWiseAddAdd, std::false_type>, + std::tuple< Row, Col, Row, Row, Row, F16, F16, F32, F32, F32, F32, ElementWiseAddAdd, std::false_type>, + std::tuple< Row, Col, Row, Row, Row, F16, F16, BF16, BF16, F32, BF16, ElementWiseAddAdd, std::false_type>, + std::tuple< Row, Col, Row, Row, Row, F8, F8, BF16, BF16, F32, BF16, ElementWiseAddAdd, std::false_type>, + std::tuple< Row, Col, Row, Row, Row, F8, F8, F16, F16, F32, F16, ElementWiseAddAdd, std::false_type>, + + std::tuple< Row, Col, Row, Row, Row, F16, F16, F16, F16, F32, F16, MultiplyMultiply, std::false_type>, + std::tuple< Row, Col, Row, Row, Row, F16, F16, F32, F32, F32, F16, MultiplyMultiply, std::false_type>, + std::tuple< Row, Col, Row, Row, Row, F8, F8, BF16, BF16, F32, F32, MultiplyMultiply, std::false_type>, + std::tuple< Row, Col, Row, Row, Row, F16, F16, F32, F32, F32, F32, MultiplyMultiply, std::false_type>, + std::tuple< Row, Col, Row, Row, Row, F16, F16, BF16, BF16, F32, BF16, MultiplyMultiply, std::false_type>, + std::tuple< Row, Col, Row, Row, Row, F8, F8, BF16, BF16, F32, BF16, MultiplyMultiply, std::false_type>, + std::tuple< Row, Col, Row, Row, Row, F8, F8, F16, F16, F32, F16, MultiplyMultiply, std::false_type> + >; +// clang-format on + +TYPED_TEST_SUITE(TestCkTileGemmMultiD, KernelTypes); + +#include "test_gemm_multi_d_ut_cases_default2d.inc" diff --git a/test/ck_tile/gemm_multi_d/test_gemm_multi_d_ut_cases.inc b/test/ck_tile/gemm_multi_d/test_gemm_multi_d_ut_cases.inc deleted file mode 100644 index 22d887fa83..0000000000 --- a/test/ck_tile/gemm_multi_d/test_gemm_multi_d_ut_cases.inc +++ /dev/null @@ -1,334 +0,0 @@ -#pragma once - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch1_256x512x256) -{ - constexpr int M = 256; - constexpr int N = 512; - constexpr int K = 256; - constexpr int kBatch = 1; - this->Run(M, N, K, kBatch); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch1_512x256x256) -{ - constexpr int M = 512; - constexpr int N = 256; - constexpr int K = 256; - constexpr int kBatch = 1; - this->Run(M, N, K, kBatch); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch1_512x512x256) -{ - constexpr int M = 512; - constexpr int N = 512; - constexpr int K = 256; - constexpr int kBatch = 1; - this->Run(M, N, K, kBatch); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch1_256x256x256) -{ - constexpr int M = 256; - constexpr int N = 256; - constexpr int K = 256; - constexpr int kBatch = 1; - this->Run(M, N, K, kBatch); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch1_512x768x256) -{ - constexpr int M = 512; - constexpr int N = 768; - constexpr int K = 256; - constexpr int kBatch = 1; - this->Run(M, N, K, kBatch); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch1_512x1280x256) -{ - constexpr int M = 512; - constexpr int N = 1280; - constexpr int K = 256; - constexpr int kBatch = 1; - this->Run(M, N, K, kBatch); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch1_256x1280x256) -{ - constexpr int M = 256; - constexpr int N = 1280; - constexpr int K = 256; - constexpr int kBatch = 1; - this->Run(M, N, K, kBatch); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch1_768x512x256) -{ - constexpr int M = 768; - constexpr int N = 512; - constexpr int K = 256; - constexpr int kBatch = 1; - this->Run(M, N, K, kBatch); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch1_1280x512x256) -{ - constexpr int M = 1280; - constexpr int N = 512; - constexpr int K = 256; - constexpr int kBatch = 1; - this->Run(M, N, K, kBatch); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch1_1280x256x256) -{ - constexpr int M = 1280; - constexpr int N = 256; - constexpr int K = 256; - constexpr int kBatch = 1; - this->Run(M, N, K, kBatch); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch1_256x512x256) -{ - constexpr int M = 256; - constexpr int N = 512; - constexpr int K = 256; - constexpr int kBatch = 1; - this->Run(M, N, K, kBatch); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch1_512x256x256) -{ - constexpr int M = 512; - constexpr int N = 256; - constexpr int K = 256; - constexpr int kBatch = 1; - this->Run(M, N, K, kBatch); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch1_512x512x256) -{ - constexpr int M = 512; - constexpr int N = 512; - constexpr int K = 256; - constexpr int kBatch = 1; - this->Run(M, N, K, kBatch); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch1_256x256x256) -{ - constexpr int M = 256; - constexpr int N = 256; - constexpr int K = 256; - constexpr int kBatch = 1; - this->Run(M, N, K, kBatch); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch1_512x768x256) -{ - constexpr int M = 512; - constexpr int N = 768; - constexpr int K = 256; - constexpr int kBatch = 1; - this->Run(M, N, K, kBatch); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch1_512x1280x256) -{ - constexpr int M = 512; - constexpr int N = 1280; - constexpr int K = 256; - constexpr int kBatch = 1; - this->Run(M, N, K, kBatch); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch1_256x1280x256) -{ - constexpr int M = 256; - constexpr int N = 1280; - constexpr int K = 256; - constexpr int kBatch = 1; - this->Run(M, N, K, kBatch); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch1_768x512x256) -{ - constexpr int M = 768; - constexpr int N = 512; - constexpr int K = 256; - constexpr int kBatch = 1; - this->Run(M, N, K, kBatch); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch1_1280x512x256) -{ - constexpr int M = 1280; - constexpr int N = 512; - constexpr int K = 256; - constexpr int kBatch = 1; - this->Run(M, N, K, kBatch); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch1_1280x256x256) -{ - constexpr int M = 1280; - constexpr int N = 256; - constexpr int K = 256; - constexpr int kBatch = 1; - this->Run(M, N, K, kBatch); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch2_256x256x512) -{ - constexpr int M = 256; - constexpr int N = 256; - constexpr int K = 512; - constexpr int kBatch = 2; - this->Run(M, N, K, kBatch); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch2_512x768x512) -{ - constexpr int M = 512; - constexpr int N = 768; - constexpr int K = 512; - constexpr int kBatch = 2; - this->Run(M, N, K, kBatch); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch2_512x1280x512) -{ - constexpr int M = 512; - constexpr int N = 1280; - constexpr int K = 512; - constexpr int kBatch = 2; - this->Run(M, N, K, kBatch); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch2_256x1280x512) -{ - constexpr int M = 256; - constexpr int N = 1280; - constexpr int K = 512; - constexpr int kBatch = 2; - this->Run(M, N, K, kBatch); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch2_768x512x512) -{ - constexpr int M = 768; - constexpr int N = 512; - constexpr int K = 512; - constexpr int kBatch = 2; - this->Run(M, N, K, kBatch); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch2_1280x512x512) -{ - constexpr int M = 1280; - constexpr int N = 512; - constexpr int K = 512; - constexpr int kBatch = 2; - this->Run(M, N, K, kBatch); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch2_1280x256x512) -{ - constexpr int M = 1280; - constexpr int N = 256; - constexpr int K = 512; - constexpr int kBatch = 2; - this->Run(M, N, K, kBatch); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch2_256x512x512) -{ - constexpr int M = 256; - constexpr int N = 512; - constexpr int K = 512; - constexpr int kBatch = 2; - this->Run(M, N, K, kBatch); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch2_512x256x512) -{ - constexpr int M = 512; - constexpr int N = 256; - constexpr int K = 512; - constexpr int kBatch = 2; - this->Run(M, N, K, kBatch); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch2_512x512x512) -{ - constexpr int M = 512; - constexpr int N = 512; - constexpr int K = 512; - constexpr int kBatch = 2; - this->Run(M, N, K, kBatch); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch2_256x256x512) -{ - constexpr int M = 256; - constexpr int N = 256; - constexpr int K = 512; - constexpr int kBatch = 2; - this->Run(M, N, K, kBatch); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch2_512x768x512) -{ - constexpr int M = 512; - constexpr int N = 768; - constexpr int K = 512; - constexpr int kBatch = 2; - this->Run(M, N, K, kBatch); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch2_512x1280x512) -{ - constexpr int M = 512; - constexpr int N = 1280; - constexpr int K = 512; - constexpr int kBatch = 2; - this->Run(M, N, K, kBatch); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch2_256x1280x512) -{ - constexpr int M = 256; - constexpr int N = 1280; - constexpr int K = 512; - constexpr int kBatch = 2; - this->Run(M, N, K, kBatch); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch2_768x512x512) -{ - constexpr int M = 768; - constexpr int N = 512; - constexpr int K = 512; - constexpr int kBatch = 2; - this->Run(M, N, K, kBatch); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch2_1280x512x512) -{ - constexpr int M = 1280; - constexpr int N = 512; - constexpr int K = 512; - constexpr int kBatch = 2; - this->Run(M, N, K, kBatch); -} - -TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch2_1280x256x512) -{ - constexpr int M = 1280; - constexpr int N = 256; - constexpr int K = 512; - constexpr int kBatch = 2; - this->Run(M, N, K, kBatch); -} diff --git a/test/ck_tile/gemm_multi_d/test_gemm_multi_d_ut_cases_cshuffle.inc b/test/ck_tile/gemm_multi_d/test_gemm_multi_d_ut_cases_cshuffle.inc new file mode 100644 index 0000000000..8d21c65692 --- /dev/null +++ b/test/ck_tile/gemm_multi_d/test_gemm_multi_d_ut_cases_cshuffle.inc @@ -0,0 +1,211 @@ +#pragma once + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1CShuffle_256x512x256) +{ + constexpr int M = 256; + constexpr int N = 512; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1CShuffle_512x256x256) +{ + constexpr int M = 512; + constexpr int N = 256; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1CShuffle_512x512x256) +{ + constexpr int M = 512; + constexpr int N = 512; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1CShuffle_256x256x256) +{ + constexpr int M = 256; + constexpr int N = 256; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1CShuffle_512x768x256) +{ + constexpr int M = 512; + constexpr int N = 768; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1CShuffle_512x1280x256) +{ + constexpr int M = 512; + constexpr int N = 1280; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1CShuffle_256x1280x256) +{ + constexpr int M = 256; + constexpr int N = 1280; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1CShuffle_768x512x256) +{ + constexpr int M = 768; + constexpr int N = 512; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1CShuffle_1280x512x256) +{ + constexpr int M = 1280; + constexpr int N = 512; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1CShuffle_1280x256x256) +{ + constexpr int M = 1280; + constexpr int N = 256; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2CShuffle_512x512x512) +{ + constexpr int M = 512; + constexpr int N = 512; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2CShuffle_256x512x256) +{ + constexpr int M = 256; + constexpr int N = 512; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2CShuffle_512x256x256) +{ + constexpr int M = 512; + constexpr int N = 256; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2CShuffle_512x512x256) +{ + constexpr int M = 512; + constexpr int N = 512; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2CShuffle_256x256x256) +{ + constexpr int M = 256; + constexpr int N = 256; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2CShuffle_512x768x256) +{ + constexpr int M = 512; + constexpr int N = 768; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2CShuffle_512x1280x256) +{ + constexpr int M = 512; + constexpr int N = 1280; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2CShuffle_256x1280x256) +{ + constexpr int M = 256; + constexpr int N = 1280; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2CShuffle_768x512x256) +{ + constexpr int M = 768; + constexpr int N = 512; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2CShuffle_1280x512x256) +{ + constexpr int M = 1280; + constexpr int N = 512; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2CShuffle_1280x256x256) +{ + constexpr int M = 1280; + constexpr int N = 256; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} diff --git a/test/ck_tile/gemm_multi_d/test_gemm_multi_d_ut_cases_default2d.inc b/test/ck_tile/gemm_multi_d/test_gemm_multi_d_ut_cases_default2d.inc new file mode 100644 index 0000000000..35b40a896a --- /dev/null +++ b/test/ck_tile/gemm_multi_d/test_gemm_multi_d_ut_cases_default2d.inc @@ -0,0 +1,211 @@ +#pragma once + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1Default_256x512x256) +{ + constexpr int M = 256; + constexpr int N = 512; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1Default_512x256x256) +{ + constexpr int M = 512; + constexpr int N = 256; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1Default_512x512x256) +{ + constexpr int M = 512; + constexpr int N = 512; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1Default_256x256x256) +{ + constexpr int M = 256; + constexpr int N = 256; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1Default_512x768x256) +{ + constexpr int M = 512; + constexpr int N = 768; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1Default_512x1280x256) +{ + constexpr int M = 512; + constexpr int N = 1280; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1Default_256x1280x256) +{ + constexpr int M = 256; + constexpr int N = 1280; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1Default_768x512x256) +{ + constexpr int M = 768; + constexpr int N = 512; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1Default_1280x512x256) +{ + constexpr int M = 1280; + constexpr int N = 512; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1Default_1280x256x256) +{ + constexpr int M = 1280; + constexpr int N = 256; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2Default_512x512x512) +{ + constexpr int M = 512; + constexpr int N = 512; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2Default_256x512x256) +{ + constexpr int M = 256; + constexpr int N = 512; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2Default_512x256x256) +{ + constexpr int M = 512; + constexpr int N = 256; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2Default_512x512x256) +{ + constexpr int M = 512; + constexpr int N = 512; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2Default_256x256x256) +{ + constexpr int M = 256; + constexpr int N = 256; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2Default_512x768x256) +{ + constexpr int M = 512; + constexpr int N = 768; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2Default_512x1280x256) +{ + constexpr int M = 512; + constexpr int N = 1280; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2Default_256x1280x256) +{ + constexpr int M = 256; + constexpr int N = 1280; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2Default_768x512x256) +{ + constexpr int M = 768; + constexpr int N = 512; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2Default_1280x512x256) +{ + constexpr int M = 1280; + constexpr int N = 512; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2Default_1280x256x256) +{ + constexpr int M = 1280; + constexpr int N = 256; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} diff --git a/test/ck_tile/gemm_multi_d/test_gemm_multi_d_util.hpp b/test/ck_tile/gemm_multi_d/test_gemm_multi_d_util.hpp index d21777c92b..8399bc7ee3 100644 --- a/test/ck_tile/gemm_multi_d/test_gemm_multi_d_util.hpp +++ b/test/ck_tile/gemm_multi_d/test_gemm_multi_d_util.hpp @@ -70,20 +70,21 @@ template class TestCkTileGemmMultiD : public ::testing::Test { protected: - using ALayout = std::tuple_element_t<0, Tuple>; - using BLayout = std::tuple_element_t<1, Tuple>; - using D0Layout = std::tuple_element_t<2, Tuple>; - using D1Layout = std::tuple_element_t<3, Tuple>; - using ELayout = std::tuple_element_t<4, Tuple>; - using ADataType = std::tuple_element_t<5, Tuple>; - using BDataType = std::tuple_element_t<6, Tuple>; - using D0DataType = std::tuple_element_t<7, Tuple>; - using D1DataType = std::tuple_element_t<8, Tuple>; - using AccDataType = std::tuple_element_t<9, Tuple>; - using EDataType = std::tuple_element_t<10, Tuple>; - using CDElementWiseFn = std::tuple_element_t<11, Tuple>; - using DsLayout = ck_tile::tuple; - using DsDataType = ck_tile::tuple; + using ALayout = std::tuple_element_t<0, Tuple>; + using BLayout = std::tuple_element_t<1, Tuple>; + using D0Layout = std::tuple_element_t<2, Tuple>; + using D1Layout = std::tuple_element_t<3, Tuple>; + using ELayout = std::tuple_element_t<4, Tuple>; + using ADataType = std::tuple_element_t<5, Tuple>; + using BDataType = std::tuple_element_t<6, Tuple>; + using D0DataType = std::tuple_element_t<7, Tuple>; + using D1DataType = std::tuple_element_t<8, Tuple>; + using AccDataType = std::tuple_element_t<9, Tuple>; + using EDataType = std::tuple_element_t<10, Tuple>; + using CDElementWiseFn = std::tuple_element_t<11, Tuple>; + using UseCshuffleEpilog = std::tuple_element_t<12, Tuple>; + using DsLayout = ck_tile::tuple; + using DsDataType = ck_tile::tuple; template ; using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; - using GemmEpilogue = ck_tile::CShuffleEpilogue< + + using DefaultGemmEpilogue = ck_tile::DefaultGemm2DEpilogue< + ck_tile::DefaultGemm2DEpilogueProblem>; + + using CShuffleGemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem>; + using GemmEpilogue = std:: + conditional_t; + using Kernel = ck_tile::GemmKernelMultiD; auto kargs = Kernel::MakeKernelArgs(args); @@ -218,6 +243,7 @@ class TestCkTileGemmMultiD : public ::testing::Test const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { if(args.k_batch == 1) { + std::cout << "Run without SplitK" << std::endl; Run(has_hot_loop_, tail_number_, ck_tile::integral_constant{}); } }; - if(has_hot_loop) - { - if(tail_num == ck_tile::TailNumber::Full) - { - RunSplitk( - ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else - { - std::ostringstream err; - err << "For compute pipeline tail number should always be Full, but have \"" - << tail_num << "\" which is not supported! PrefetchStages: " - << BaseGemmPipeline::PrefetchStages << "\n File: " << __FILE__ << ":" - << __LINE__ << ", in function: " << __func__; - throw std::runtime_error(err.str()); - } - } - else - { - std::ostringstream err; - err << "Num K loop must be larger than number of prefetech stages." - << "\n PrefetchStages: " << BaseGemmPipeline::PrefetchStages - << "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; - throw std::runtime_error(err.str()); - } + + BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); } public: - void Run(const int M, + bool Run(const int M, const int N, const int K, const int k_batch, @@ -401,6 +404,6 @@ class TestCkTileGemmMultiD : public ::testing::Test << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) << std::endl; - EXPECT_TRUE(pass); + return pass; } }; diff --git a/tile_engine/ops/gemm/codegen_utils.py b/tile_engine/ops/gemm/codegen_utils.py index dd9de36865..392125aa0b 100644 --- a/tile_engine/ops/gemm/codegen_utils.py +++ b/tile_engine/ops/gemm/codegen_utils.py @@ -31,9 +31,14 @@ DEFAULT_EPILOGUE = """ using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue< ck_tile::DefaultGemm2DEpilogueProblem, AccDataType, CDataType, + ck_tile::tuple<>, CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, kPadM, kPadN, WarpTileM,