mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 10:09:41 +00:00
Merge commit '508e7912f9bb758c22c7f7c1fc5dbb4cd3030c06' into develop
This commit is contained in:
@@ -29,14 +29,9 @@ struct Default2DEpilogueProblem
|
||||
|
||||
template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
typename DsDataType_,
|
||||
typename AccDataType_,
|
||||
typename ODataType_,
|
||||
typename DsLayout_,
|
||||
typename CLayout_,
|
||||
typename CDElementwise_,
|
||||
index_t kM_,
|
||||
index_t kN_,
|
||||
bool kPadM_,
|
||||
bool kPadN_,
|
||||
index_t kMPerXdl_,
|
||||
@@ -55,20 +50,10 @@ struct DefaultGemm2DEpilogueProblem : public Default2DEpilogueProblem<AccDataTyp
|
||||
using ADataType = remove_cvref_t<ADataType_>;
|
||||
using BDataType = remove_cvref_t<BDataType_>;
|
||||
using CLayout = remove_cvref_t<CLayout_>;
|
||||
using DsDataType = remove_cvref_t<DsDataType_>;
|
||||
using DsLayout = remove_cvref_t<DsLayout_>;
|
||||
using CDElementwise = remove_cvref_t<CDElementwise_>;
|
||||
static constexpr index_t kMPerBlock = kM_;
|
||||
static constexpr index_t kNPerBlock = kN_;
|
||||
static constexpr index_t kMPerXdl = kMPerXdl_;
|
||||
static constexpr index_t kNPerXdl = kNPerXdl_;
|
||||
static constexpr index_t kKPerXdl = kKPerXdl_;
|
||||
static constexpr index_t isCTransposed = isCTransposed_;
|
||||
|
||||
static constexpr index_t NumDTensor = DsDataType::size();
|
||||
|
||||
static_assert(NumDTensor == DsLayout::size(),
|
||||
"The size of DsDataType and DsLayout should be the same");
|
||||
};
|
||||
|
||||
template <typename Problem_, typename Policy_ = void>
|
||||
@@ -77,7 +62,6 @@ struct Default2DEpilogue
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
using CDElementwise = remove_cvref_t<typename Problem::CDElementwise>;
|
||||
static constexpr bool kPadM = Problem::kPadM;
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
static constexpr bool UseRawStore = Problem::UseRawStore;
|
||||
@@ -87,71 +71,44 @@ struct Default2DEpilogue
|
||||
|
||||
// TODO: this function assume store out vector size is the same as OAccTile last dimension size
|
||||
// how do we fix this ?
|
||||
template <typename ODramWindowTmp, typename OAccTile, typename DsDramWindows>
|
||||
CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp,
|
||||
const OAccTile& o_acc_tile,
|
||||
const DsDramWindows& ds_dram_windows,
|
||||
void* = nullptr)
|
||||
template <typename ODramWindowTmp, typename OAccTile>
|
||||
CK_TILE_DEVICE auto
|
||||
operator()(ODramWindowTmp& o_dram_window_tmp, const OAccTile& o_acc_tile, void* = nullptr) const
|
||||
{
|
||||
const auto storeOrUpdateTile = [&](const auto& o_tile) {
|
||||
// TODO: this is ugly
|
||||
if constexpr(UseRawStore && (kPadM || kPadN))
|
||||
// TODO: this is ugly
|
||||
if constexpr(UseRawStore && (kPadM || kPadN))
|
||||
{
|
||||
if constexpr(MemoryOperation == memory_operation_enum::set)
|
||||
{
|
||||
if constexpr(MemoryOperation == memory_operation_enum::set)
|
||||
{
|
||||
store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_tile));
|
||||
}
|
||||
else
|
||||
{
|
||||
update_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_tile));
|
||||
}
|
||||
buffer_store_fence();
|
||||
store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(MemoryOperation == memory_operation_enum::set)
|
||||
{
|
||||
store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_tile));
|
||||
}
|
||||
else
|
||||
{
|
||||
update_tile(o_dram_window_tmp, cast_tile<ODataType>(o_tile));
|
||||
}
|
||||
update_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
|
||||
}
|
||||
};
|
||||
|
||||
if constexpr(Problem::NumDTensor >= 1)
|
||||
{
|
||||
using elementwise_result_t = decltype(load_tile(
|
||||
make_tile_window(ds_dram_windows[number<0>{}].get_bottom_tensor_view(),
|
||||
make_tuple(Problem::kMPerBlock, Problem::kNPerBlock),
|
||||
ds_dram_windows[number<0>{}].get_window_origin(),
|
||||
o_acc_tile.get_tile_distribution())));
|
||||
|
||||
elementwise_result_t elementwise_result;
|
||||
|
||||
const auto d_tensor_tuple = generate_tuple(
|
||||
[&](auto idx) {
|
||||
const auto d_tile_window =
|
||||
make_tile_window(ds_dram_windows[idx], o_acc_tile.get_tile_distribution());
|
||||
return load_tile(d_tile_window);
|
||||
},
|
||||
number<Problem::NumDTensor>{});
|
||||
|
||||
const auto c_d_tuple = concat_tuple_of_reference(
|
||||
tie(elementwise_result, o_acc_tile),
|
||||
generate_tie([&](auto idx) -> const auto& { return d_tensor_tuple[idx]; },
|
||||
number<Problem::NumDTensor>{}));
|
||||
|
||||
tile_elementwise_inout_unpack(typename Problem::CDElementwise{}, c_d_tuple);
|
||||
|
||||
storeOrUpdateTile(elementwise_result);
|
||||
buffer_store_fence();
|
||||
}
|
||||
else
|
||||
{
|
||||
storeOrUpdateTile(o_acc_tile);
|
||||
if constexpr(MemoryOperation == memory_operation_enum::set)
|
||||
{
|
||||
store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
|
||||
}
|
||||
else
|
||||
{
|
||||
update_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ODramWindowTmp, typename OAccTile, typename DsDramWindows>
|
||||
CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp,
|
||||
const OAccTile& o_acc_tile,
|
||||
const DsDramWindows& /* unused */,
|
||||
void* = nullptr) const
|
||||
{
|
||||
return operator()<ODramWindowTmp, OAccTile>(o_dram_window_tmp, o_acc_tile);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Problem_, typename Policy_ = void>
|
||||
@@ -165,9 +122,8 @@ struct DefaultGemm2DEpilogue : public Default2DEpilogue<Problem_, Policy_>
|
||||
// Used for weight-only quantization kernel, B would be dequantized to the same data type as A
|
||||
using BTypeToUse =
|
||||
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
|
||||
using DsDataType = remove_cvref_t<typename Problem::DsDataType>;
|
||||
using DsLayout = remove_cvref_t<typename Problem::DsLayout>;
|
||||
using CDElementwise = remove_cvref_t<typename Problem::CDElementwise>;
|
||||
using DsDataType = ck_tile::tuple<>;
|
||||
using DsLayout = ck_tile::tuple<>;
|
||||
using CLayout = remove_cvref_t<typename Problem::CLayout>;
|
||||
static constexpr index_t kMPerXdl = Problem::kMPerXdl;
|
||||
static constexpr index_t kNPerXdl = Problem::kNPerXdl;
|
||||
@@ -236,11 +192,7 @@ struct DefaultGemm2DEpilogue : public Default2DEpilogue<Problem_, Policy_>
|
||||
}
|
||||
}
|
||||
|
||||
template <index_t I>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeD([[maybe_unused]] number<I> index)
|
||||
{
|
||||
return GetVectorSizeC();
|
||||
}
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeD() { return 1; }
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -175,12 +175,6 @@ struct GemmKernelMultiD
|
||||
CK_TILE_HOST static auto
|
||||
IsSupportedArgument(const typename UniversalGemmKernel::KernelArgs& kargs) -> bool
|
||||
{
|
||||
// Currently MultiD kernel doesn't support k_batch > 1
|
||||
if(kargs.k_batch > 1)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return UniversalGemmKernel::IsSupportedArgument(kargs);
|
||||
}
|
||||
|
||||
|
||||
@@ -5,8 +5,6 @@ if(CK_USE_OCP_FP8)
|
||||
endif()
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95")
|
||||
add_gtest_executable(test_gemm_multi_d_cshuffle test_gemm_multi_d_cshuffle.cpp)
|
||||
add_gtest_executable(test_gemm_multi_d_default2d test_gemm_multi_d_default2d.cpp)
|
||||
target_compile_definitions(test_gemm_multi_d_cshuffle PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
target_compile_definitions(test_gemm_multi_d_default2d PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
add_gtest_executable(test_ck_tile_gemm_multi_d test_gemm_multi_d.cpp)
|
||||
target_compile_definitions(test_ck_tile_gemm_multi_d PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
endif()
|
||||
|
||||
@@ -18,23 +18,22 @@ using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
// clang-format off
|
||||
using KernelTypes = ::testing::Types<
|
||||
// Has cshuffle epilogue enabled
|
||||
// ALayout, BLayout, CLayout, D0Layout, D1Layout, ADataType, BDataType, D0DataType, D1DataType, AccDataType, EDataType, CDElementWiseFn, UseCshuffleEpilog
|
||||
std::tuple< Row, Col, Row, Row, Row, F16, F16, BF16, BF16, F32, F16, ElementWiseAddAdd, std::true_type>,
|
||||
std::tuple< Row, Col, Row, Row, Row, F16, F16, F32, F32, F32, F16, ElementWiseAddAdd, std::true_type>,
|
||||
std::tuple< Row, Col, Row, Row, Row, F16, F16, F32, F32, F32, F16, ElementWiseAddAdd, std::true_type>,
|
||||
std::tuple< Row, Col, Row, Row, Row, F8, F8, BF16, BF16, F32, F32, ElementWiseAddAdd, std::true_type>,
|
||||
std::tuple< Row, Col, Row, Row, Row, F8, F8, F8, F8, F32, F16, ElementWiseAddAdd, std::true_type>,
|
||||
// ALayout, BLayout, CLayout, D0Layout, D1Layout, ADataType, BDataType, D0DataType, D1DataType, AccDataType, CDataType, CDElementWiseFn
|
||||
std::tuple< Row, Col, Row, Row, Row, F16, F16, BF16, BF16, F32, F16, ElementWiseAddAdd>,
|
||||
std::tuple< Row, Col, Row, Row, Row, F16, F16, F32, F32, F32, F16, ElementWiseAddAdd>,
|
||||
std::tuple< Row, Col, Row, Row, Row, F16, F16, F32, F32, F32, F16, ElementWiseAddAdd>,
|
||||
std::tuple< Row, Col, Row, Row, Row, F8, F8, BF16, BF16, F32, F32, ElementWiseAddAdd>,
|
||||
std::tuple< Row, Col, Row, Row, Row, F8, F8, F8, F8, F32, F16, ElementWiseAddAdd>,
|
||||
|
||||
std::tuple< Row, Col, Row, Row, Row, F16, F16, F16, F16, F32, F16, MultiplyMultiply, std::true_type>,
|
||||
std::tuple< Row, Col, Row, Row, Row, F16, F16, BF16, BF16, F32, F32, MultiplyMultiply, std::true_type>,
|
||||
std::tuple< Row, Col, Row, Row, Row, F16, F16, F32, F32, F32, F32, MultiplyMultiply, std::true_type>,
|
||||
std::tuple< Row, Col, Row, Row, Row, F16, F16, F32, F32, F32, F16, MultiplyMultiply, std::true_type>,
|
||||
std::tuple< Row, Col, Row, Row, Row, F8, F8, BF16, BF16, F32, F32, MultiplyMultiply, std::true_type>,
|
||||
std::tuple< Row, Col, Row, Row, Row, F8, F8, F8, F8, F32, F32, MultiplyMultiply, std::true_type>
|
||||
std::tuple< Row, Col, Row, Row, Row, F16, F16, F16, F16, F32, F16, MultiplyMultiply>,
|
||||
std::tuple< Row, Col, Row, Row, Row, F16, F16, BF16, BF16, F32, F32, MultiplyMultiply>,
|
||||
std::tuple< Row, Col, Row, Row, Row, F16, F16, F32, F32, F32, F32, MultiplyMultiply>,
|
||||
std::tuple< Row, Col, Row, Row, Row, F16, F16, F32, F32, F32, F16, MultiplyMultiply>,
|
||||
std::tuple< Row, Col, Row, Row, Row, F8, F8, BF16, BF16, F32, F32, MultiplyMultiply>,
|
||||
std::tuple< Row, Col, Row, Row, Row, F8, F8, F8, F8, F32, F32, MultiplyMultiply>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileGemmMultiD, KernelTypes);
|
||||
|
||||
#include "test_gemm_multi_d_ut_cases_cshuffle.inc"
|
||||
#include "test_gemm_multi_d_ut_cases.inc"
|
||||
@@ -1,43 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <tuple>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "test_gemm_multi_d_util.hpp"
|
||||
|
||||
using F16 = ck_tile::half_t;
|
||||
using BF16 = ck_tile::bf16_t;
|
||||
using F32 = float;
|
||||
using F8 = ck_tile::fp8_t;
|
||||
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
// clang-format off
|
||||
using KernelTypes = ::testing::Types<
|
||||
// Has cshuffle epilogue disabled
|
||||
// ALayout, BLayout, CLayout, D0Layout, D1Layout, ADataType, BDataType, D0DataType, D1DataType, AccDataType, EDataType, CDElementWiseFn, UseCshuffleEpilog
|
||||
std::tuple< Row, Col, Row, Row, Row, F16, F16, F32, F32, F32, F16, ElementWiseAddAdd, std::false_type>,
|
||||
std::tuple< Row, Col, Row, Row, Row, F16, F16, F32, F32, F32, F16, ElementWiseAddAdd, std::false_type>,
|
||||
std::tuple< Row, Col, Row, Row, Row, F8, F8, BF16, BF16, F32, F32, ElementWiseAddAdd, std::false_type>,
|
||||
std::tuple< Row, Col, Row, Row, Row, F16, F16, F32, F32, F32, F32, ElementWiseAddAdd, std::false_type>,
|
||||
std::tuple< Row, Col, Row, Row, Row, F16, F16, BF16, BF16, F32, BF16, ElementWiseAddAdd, std::false_type>,
|
||||
std::tuple< Row, Col, Row, Row, Row, F8, F8, BF16, BF16, F32, BF16, ElementWiseAddAdd, std::false_type>,
|
||||
std::tuple< Row, Col, Row, Row, Row, F8, F8, F16, F16, F32, F16, ElementWiseAddAdd, std::false_type>,
|
||||
|
||||
std::tuple< Row, Col, Row, Row, Row, F16, F16, F16, F16, F32, F16, MultiplyMultiply, std::false_type>,
|
||||
std::tuple< Row, Col, Row, Row, Row, F16, F16, F32, F32, F32, F16, MultiplyMultiply, std::false_type>,
|
||||
std::tuple< Row, Col, Row, Row, Row, F8, F8, BF16, BF16, F32, F32, MultiplyMultiply, std::false_type>,
|
||||
std::tuple< Row, Col, Row, Row, Row, F16, F16, F32, F32, F32, F32, MultiplyMultiply, std::false_type>,
|
||||
std::tuple< Row, Col, Row, Row, Row, F16, F16, BF16, BF16, F32, BF16, MultiplyMultiply, std::false_type>,
|
||||
std::tuple< Row, Col, Row, Row, Row, F8, F8, BF16, BF16, F32, BF16, MultiplyMultiply, std::false_type>,
|
||||
std::tuple< Row, Col, Row, Row, Row, F8, F8, F16, F16, F32, F16, MultiplyMultiply, std::false_type>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileGemmMultiD, KernelTypes);
|
||||
|
||||
#include "test_gemm_multi_d_ut_cases_default2d.inc"
|
||||
334
test/ck_tile/gemm_multi_d/test_gemm_multi_d_ut_cases.inc
Normal file
334
test/ck_tile/gemm_multi_d/test_gemm_multi_d_ut_cases.inc
Normal file
@@ -0,0 +1,334 @@
|
||||
#pragma once
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch1_256x512x256)
|
||||
{
|
||||
constexpr int M = 256;
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 256;
|
||||
constexpr int kBatch = 1;
|
||||
this->Run(M, N, K, kBatch);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch1_512x256x256)
|
||||
{
|
||||
constexpr int M = 512;
|
||||
constexpr int N = 256;
|
||||
constexpr int K = 256;
|
||||
constexpr int kBatch = 1;
|
||||
this->Run(M, N, K, kBatch);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch1_512x512x256)
|
||||
{
|
||||
constexpr int M = 512;
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 256;
|
||||
constexpr int kBatch = 1;
|
||||
this->Run(M, N, K, kBatch);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch1_256x256x256)
|
||||
{
|
||||
constexpr int M = 256;
|
||||
constexpr int N = 256;
|
||||
constexpr int K = 256;
|
||||
constexpr int kBatch = 1;
|
||||
this->Run(M, N, K, kBatch);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch1_512x768x256)
|
||||
{
|
||||
constexpr int M = 512;
|
||||
constexpr int N = 768;
|
||||
constexpr int K = 256;
|
||||
constexpr int kBatch = 1;
|
||||
this->Run(M, N, K, kBatch);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch1_512x1280x256)
|
||||
{
|
||||
constexpr int M = 512;
|
||||
constexpr int N = 1280;
|
||||
constexpr int K = 256;
|
||||
constexpr int kBatch = 1;
|
||||
this->Run(M, N, K, kBatch);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch1_256x1280x256)
|
||||
{
|
||||
constexpr int M = 256;
|
||||
constexpr int N = 1280;
|
||||
constexpr int K = 256;
|
||||
constexpr int kBatch = 1;
|
||||
this->Run(M, N, K, kBatch);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch1_768x512x256)
|
||||
{
|
||||
constexpr int M = 768;
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 256;
|
||||
constexpr int kBatch = 1;
|
||||
this->Run(M, N, K, kBatch);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch1_1280x512x256)
|
||||
{
|
||||
constexpr int M = 1280;
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 256;
|
||||
constexpr int kBatch = 1;
|
||||
this->Run(M, N, K, kBatch);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch1_1280x256x256)
|
||||
{
|
||||
constexpr int M = 1280;
|
||||
constexpr int N = 256;
|
||||
constexpr int K = 256;
|
||||
constexpr int kBatch = 1;
|
||||
this->Run(M, N, K, kBatch);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch1_256x512x256)
|
||||
{
|
||||
constexpr int M = 256;
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 256;
|
||||
constexpr int kBatch = 1;
|
||||
this->Run(M, N, K, kBatch);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch1_512x256x256)
|
||||
{
|
||||
constexpr int M = 512;
|
||||
constexpr int N = 256;
|
||||
constexpr int K = 256;
|
||||
constexpr int kBatch = 1;
|
||||
this->Run(M, N, K, kBatch);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch1_512x512x256)
|
||||
{
|
||||
constexpr int M = 512;
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 256;
|
||||
constexpr int kBatch = 1;
|
||||
this->Run(M, N, K, kBatch);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch1_256x256x256)
|
||||
{
|
||||
constexpr int M = 256;
|
||||
constexpr int N = 256;
|
||||
constexpr int K = 256;
|
||||
constexpr int kBatch = 1;
|
||||
this->Run(M, N, K, kBatch);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch1_512x768x256)
|
||||
{
|
||||
constexpr int M = 512;
|
||||
constexpr int N = 768;
|
||||
constexpr int K = 256;
|
||||
constexpr int kBatch = 1;
|
||||
this->Run(M, N, K, kBatch);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch1_512x1280x256)
|
||||
{
|
||||
constexpr int M = 512;
|
||||
constexpr int N = 1280;
|
||||
constexpr int K = 256;
|
||||
constexpr int kBatch = 1;
|
||||
this->Run(M, N, K, kBatch);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch1_256x1280x256)
|
||||
{
|
||||
constexpr int M = 256;
|
||||
constexpr int N = 1280;
|
||||
constexpr int K = 256;
|
||||
constexpr int kBatch = 1;
|
||||
this->Run(M, N, K, kBatch);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch1_768x512x256)
|
||||
{
|
||||
constexpr int M = 768;
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 256;
|
||||
constexpr int kBatch = 1;
|
||||
this->Run(M, N, K, kBatch);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch1_1280x512x256)
|
||||
{
|
||||
constexpr int M = 1280;
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 256;
|
||||
constexpr int kBatch = 1;
|
||||
this->Run(M, N, K, kBatch);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch1_1280x256x256)
|
||||
{
|
||||
constexpr int M = 1280;
|
||||
constexpr int N = 256;
|
||||
constexpr int K = 256;
|
||||
constexpr int kBatch = 1;
|
||||
this->Run(M, N, K, kBatch);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch2_256x256x512)
|
||||
{
|
||||
constexpr int M = 256;
|
||||
constexpr int N = 256;
|
||||
constexpr int K = 512;
|
||||
constexpr int kBatch = 2;
|
||||
this->Run(M, N, K, kBatch);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch2_512x768x512)
|
||||
{
|
||||
constexpr int M = 512;
|
||||
constexpr int N = 768;
|
||||
constexpr int K = 512;
|
||||
constexpr int kBatch = 2;
|
||||
this->Run(M, N, K, kBatch);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch2_512x1280x512)
|
||||
{
|
||||
constexpr int M = 512;
|
||||
constexpr int N = 1280;
|
||||
constexpr int K = 512;
|
||||
constexpr int kBatch = 2;
|
||||
this->Run(M, N, K, kBatch);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch2_256x1280x512)
|
||||
{
|
||||
constexpr int M = 256;
|
||||
constexpr int N = 1280;
|
||||
constexpr int K = 512;
|
||||
constexpr int kBatch = 2;
|
||||
this->Run(M, N, K, kBatch);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch2_768x512x512)
|
||||
{
|
||||
constexpr int M = 768;
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 512;
|
||||
constexpr int kBatch = 2;
|
||||
this->Run(M, N, K, kBatch);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch2_1280x512x512)
|
||||
{
|
||||
constexpr int M = 1280;
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 512;
|
||||
constexpr int kBatch = 2;
|
||||
this->Run(M, N, K, kBatch);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch2_1280x256x512)
|
||||
{
|
||||
constexpr int M = 1280;
|
||||
constexpr int N = 256;
|
||||
constexpr int K = 512;
|
||||
constexpr int kBatch = 2;
|
||||
this->Run(M, N, K, kBatch);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch2_256x512x512)
|
||||
{
|
||||
constexpr int M = 256;
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 512;
|
||||
constexpr int kBatch = 2;
|
||||
this->Run(M, N, K, kBatch);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch2_512x256x512)
|
||||
{
|
||||
constexpr int M = 512;
|
||||
constexpr int N = 256;
|
||||
constexpr int K = 512;
|
||||
constexpr int kBatch = 2;
|
||||
this->Run(M, N, K, kBatch);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch2_512x512x512)
|
||||
{
|
||||
constexpr int M = 512;
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 512;
|
||||
constexpr int kBatch = 2;
|
||||
this->Run(M, N, K, kBatch);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch2_256x256x512)
|
||||
{
|
||||
constexpr int M = 256;
|
||||
constexpr int N = 256;
|
||||
constexpr int K = 512;
|
||||
constexpr int kBatch = 2;
|
||||
this->Run(M, N, K, kBatch);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch2_512x768x512)
|
||||
{
|
||||
constexpr int M = 512;
|
||||
constexpr int N = 768;
|
||||
constexpr int K = 512;
|
||||
constexpr int kBatch = 2;
|
||||
this->Run(M, N, K, kBatch);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch2_512x1280x512)
|
||||
{
|
||||
constexpr int M = 512;
|
||||
constexpr int N = 1280;
|
||||
constexpr int K = 512;
|
||||
constexpr int kBatch = 2;
|
||||
this->Run(M, N, K, kBatch);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch2_256x1280x512)
|
||||
{
|
||||
constexpr int M = 256;
|
||||
constexpr int N = 1280;
|
||||
constexpr int K = 512;
|
||||
constexpr int kBatch = 2;
|
||||
this->Run(M, N, K, kBatch);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch2_768x512x512)
|
||||
{
|
||||
constexpr int M = 768;
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 512;
|
||||
constexpr int kBatch = 2;
|
||||
this->Run(M, N, K, kBatch);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch2_1280x512x512)
|
||||
{
|
||||
constexpr int M = 1280;
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 512;
|
||||
constexpr int kBatch = 2;
|
||||
this->Run(M, N, K, kBatch);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch2_1280x256x512)
|
||||
{
|
||||
constexpr int M = 1280;
|
||||
constexpr int N = 256;
|
||||
constexpr int K = 512;
|
||||
constexpr int kBatch = 2;
|
||||
this->Run(M, N, K, kBatch);
|
||||
}
|
||||
@@ -1,211 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1CShuffle_256x512x256)
|
||||
{
|
||||
constexpr int M = 256;
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 256;
|
||||
constexpr int kBatch = 1;
|
||||
|
||||
EXPECT_EQ(this->Run(M, N, K, kBatch), true);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1CShuffle_512x256x256)
|
||||
{
|
||||
constexpr int M = 512;
|
||||
constexpr int N = 256;
|
||||
constexpr int K = 256;
|
||||
constexpr int kBatch = 1;
|
||||
|
||||
EXPECT_EQ(this->Run(M, N, K, kBatch), true);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1CShuffle_512x512x256)
|
||||
{
|
||||
constexpr int M = 512;
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 256;
|
||||
constexpr int kBatch = 1;
|
||||
|
||||
EXPECT_EQ(this->Run(M, N, K, kBatch), true);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1CShuffle_256x256x256)
|
||||
{
|
||||
constexpr int M = 256;
|
||||
constexpr int N = 256;
|
||||
constexpr int K = 256;
|
||||
constexpr int kBatch = 1;
|
||||
|
||||
EXPECT_EQ(this->Run(M, N, K, kBatch), true);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1CShuffle_512x768x256)
|
||||
{
|
||||
constexpr int M = 512;
|
||||
constexpr int N = 768;
|
||||
constexpr int K = 256;
|
||||
constexpr int kBatch = 1;
|
||||
|
||||
EXPECT_EQ(this->Run(M, N, K, kBatch), true);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1CShuffle_512x1280x256)
|
||||
{
|
||||
constexpr int M = 512;
|
||||
constexpr int N = 1280;
|
||||
constexpr int K = 256;
|
||||
constexpr int kBatch = 1;
|
||||
|
||||
EXPECT_EQ(this->Run(M, N, K, kBatch), true);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1CShuffle_256x1280x256)
|
||||
{
|
||||
constexpr int M = 256;
|
||||
constexpr int N = 1280;
|
||||
constexpr int K = 256;
|
||||
constexpr int kBatch = 1;
|
||||
|
||||
EXPECT_EQ(this->Run(M, N, K, kBatch), true);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1CShuffle_768x512x256)
|
||||
{
|
||||
constexpr int M = 768;
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 256;
|
||||
constexpr int kBatch = 1;
|
||||
|
||||
EXPECT_EQ(this->Run(M, N, K, kBatch), true);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1CShuffle_1280x512x256)
|
||||
{
|
||||
constexpr int M = 1280;
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 256;
|
||||
constexpr int kBatch = 1;
|
||||
|
||||
EXPECT_EQ(this->Run(M, N, K, kBatch), true);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1CShuffle_1280x256x256)
|
||||
{
|
||||
constexpr int M = 1280;
|
||||
constexpr int N = 256;
|
||||
constexpr int K = 256;
|
||||
constexpr int kBatch = 1;
|
||||
|
||||
EXPECT_EQ(this->Run(M, N, K, kBatch), true);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2CShuffle_512x512x512)
|
||||
{
|
||||
constexpr int M = 512;
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 512;
|
||||
constexpr int kBatch = 2;
|
||||
|
||||
EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2CShuffle_256x512x256)
|
||||
{
|
||||
constexpr int M = 256;
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 512;
|
||||
constexpr int kBatch = 2;
|
||||
|
||||
EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2CShuffle_512x256x256)
|
||||
{
|
||||
constexpr int M = 512;
|
||||
constexpr int N = 256;
|
||||
constexpr int K = 512;
|
||||
constexpr int kBatch = 2;
|
||||
|
||||
EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2CShuffle_512x512x256)
|
||||
{
|
||||
constexpr int M = 512;
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 512;
|
||||
constexpr int kBatch = 2;
|
||||
|
||||
EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2CShuffle_256x256x256)
|
||||
{
|
||||
constexpr int M = 256;
|
||||
constexpr int N = 256;
|
||||
constexpr int K = 512;
|
||||
constexpr int kBatch = 2;
|
||||
|
||||
EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2CShuffle_512x768x256)
|
||||
{
|
||||
constexpr int M = 512;
|
||||
constexpr int N = 768;
|
||||
constexpr int K = 512;
|
||||
constexpr int kBatch = 2;
|
||||
|
||||
EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2CShuffle_512x1280x256)
|
||||
{
|
||||
constexpr int M = 512;
|
||||
constexpr int N = 1280;
|
||||
constexpr int K = 512;
|
||||
constexpr int kBatch = 2;
|
||||
|
||||
EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2CShuffle_256x1280x256)
|
||||
{
|
||||
constexpr int M = 256;
|
||||
constexpr int N = 1280;
|
||||
constexpr int K = 512;
|
||||
constexpr int kBatch = 2;
|
||||
|
||||
EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2CShuffle_768x512x256)
|
||||
{
|
||||
constexpr int M = 768;
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 512;
|
||||
constexpr int kBatch = 2;
|
||||
|
||||
EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2CShuffle_1280x512x256)
|
||||
{
|
||||
constexpr int M = 1280;
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 512;
|
||||
constexpr int kBatch = 2;
|
||||
|
||||
EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2CShuffle_1280x256x256)
|
||||
{
|
||||
constexpr int M = 1280;
|
||||
constexpr int N = 256;
|
||||
constexpr int K = 512;
|
||||
constexpr int kBatch = 2;
|
||||
|
||||
EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
|
||||
}
|
||||
@@ -1,211 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
// M = 256, N = 512, K = 256 with kBatch = 1: Run() is expected to return true.
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1Default_256x512x256)
{
    constexpr int M{256};
    constexpr int N{512};
    constexpr int K{256};
    constexpr int kBatch{1};

    EXPECT_EQ(this->Run(M, N, K, kBatch), true);
}
|
||||
|
||||
// M = 512, N = 256, K = 256 with kBatch = 1: Run() is expected to return true.
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1Default_512x256x256)
{
    constexpr int M{512};
    constexpr int N{256};
    constexpr int K{256};
    constexpr int kBatch{1};

    EXPECT_EQ(this->Run(M, N, K, kBatch), true);
}
|
||||
|
||||
// M = 512, N = 512, K = 256 with kBatch = 1: Run() is expected to return true.
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1Default_512x512x256)
{
    constexpr int M{512};
    constexpr int N{512};
    constexpr int K{256};
    constexpr int kBatch{1};

    EXPECT_EQ(this->Run(M, N, K, kBatch), true);
}
|
||||
|
||||
// M = 256, N = 256, K = 256 with kBatch = 1: Run() is expected to return true.
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1Default_256x256x256)
{
    constexpr int M{256};
    constexpr int N{256};
    constexpr int K{256};
    constexpr int kBatch{1};

    EXPECT_EQ(this->Run(M, N, K, kBatch), true);
}
|
||||
|
||||
// M = 512, N = 768, K = 256 with kBatch = 1: Run() is expected to return true.
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1Default_512x768x256)
{
    constexpr int M{512};
    constexpr int N{768};
    constexpr int K{256};
    constexpr int kBatch{1};

    EXPECT_EQ(this->Run(M, N, K, kBatch), true);
}
|
||||
|
||||
// M = 512, N = 1280, K = 256 with kBatch = 1: Run() is expected to return true.
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1Default_512x1280x256)
{
    constexpr int M{512};
    constexpr int N{1280};
    constexpr int K{256};
    constexpr int kBatch{1};

    EXPECT_EQ(this->Run(M, N, K, kBatch), true);
}
|
||||
|
||||
// M = 256, N = 1280, K = 256 with kBatch = 1: Run() is expected to return true.
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1Default_256x1280x256)
{
    constexpr int M{256};
    constexpr int N{1280};
    constexpr int K{256};
    constexpr int kBatch{1};

    EXPECT_EQ(this->Run(M, N, K, kBatch), true);
}
|
||||
|
||||
// M = 768, N = 512, K = 256 with kBatch = 1: Run() is expected to return true.
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1Default_768x512x256)
{
    constexpr int M{768};
    constexpr int N{512};
    constexpr int K{256};
    constexpr int kBatch{1};

    EXPECT_EQ(this->Run(M, N, K, kBatch), true);
}
|
||||
|
||||
// M = 1280, N = 512, K = 256 with kBatch = 1: Run() is expected to return true.
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1Default_1280x512x256)
{
    constexpr int M{1280};
    constexpr int N{512};
    constexpr int K{256};
    constexpr int kBatch{1};

    EXPECT_EQ(this->Run(M, N, K, kBatch), true);
}
|
||||
|
||||
// M = 1280, N = 256, K = 256 with kBatch = 1: Run() is expected to return true.
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1Default_1280x256x256)
{
    constexpr int M{1280};
    constexpr int N{256};
    constexpr int K{256};
    constexpr int kBatch{1};

    EXPECT_EQ(this->Run(M, N, K, kBatch), true);
}
|
||||
|
||||
// M = 512, N = 512, K = 512 with kBatch = 2: Run() is expected to throw std::runtime_error.
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2Default_512x512x512)
{
    constexpr int M{512};
    constexpr int N{512};
    constexpr int K{512};
    constexpr int kBatch{2};

    EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
}
|
||||
|
||||
// M = 256, N = 512, K = 512 with kBatch = 2: Run() is expected to throw std::runtime_error.
// NOTE(review): the name suffix reads x256 while K is 512 (perhaps K per split,
// 512 / kBatch) — confirm which is intended.
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2Default_256x512x256)
{
    constexpr int M{256};
    constexpr int N{512};
    constexpr int K{512};
    constexpr int kBatch{2};

    EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
}
|
||||
|
||||
// M = 512, N = 256, K = 512 with kBatch = 2: Run() is expected to throw std::runtime_error.
// NOTE(review): the name suffix reads x256 while K is 512 (perhaps K per split,
// 512 / kBatch) — confirm which is intended.
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2Default_512x256x256)
{
    constexpr int M{512};
    constexpr int N{256};
    constexpr int K{512};
    constexpr int kBatch{2};

    EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
}
|
||||
|
||||
// M = 512, N = 512, K = 512 with kBatch = 2: Run() is expected to throw std::runtime_error.
// NOTE(review): the name suffix reads x256 while K is 512 (perhaps K per split,
// 512 / kBatch) — confirm which is intended. M, N, K and kBatch also equal those of
// TestCkTileGemmMultiDKBatch2Default_512x512x512, so this case looks redundant — confirm.
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2Default_512x512x256)
{
    constexpr int M{512};
    constexpr int N{512};
    constexpr int K{512};
    constexpr int kBatch{2};

    EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
}
|
||||
|
||||
// M = 256, N = 256, K = 512 with kBatch = 2: Run() is expected to throw std::runtime_error.
// NOTE(review): the name suffix reads x256 while K is 512 (perhaps K per split,
// 512 / kBatch) — confirm which is intended.
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2Default_256x256x256)
{
    constexpr int M{256};
    constexpr int N{256};
    constexpr int K{512};
    constexpr int kBatch{2};

    EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
}
|
||||
|
||||
// M = 512, N = 768, K = 512 with kBatch = 2: Run() is expected to throw std::runtime_error.
// NOTE(review): the name suffix reads x256 while K is 512 (perhaps K per split,
// 512 / kBatch) — confirm which is intended.
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2Default_512x768x256)
{
    constexpr int M{512};
    constexpr int N{768};
    constexpr int K{512};
    constexpr int kBatch{2};

    EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
}
|
||||
|
||||
// M = 512, N = 1280, K = 512 with kBatch = 2: Run() is expected to throw std::runtime_error.
// NOTE(review): the name suffix reads x256 while K is 512 (perhaps K per split,
// 512 / kBatch) — confirm which is intended.
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2Default_512x1280x256)
{
    constexpr int M{512};
    constexpr int N{1280};
    constexpr int K{512};
    constexpr int kBatch{2};

    EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
}
|
||||
|
||||
// M = 256, N = 1280, K = 512 with kBatch = 2: Run() is expected to throw std::runtime_error.
// NOTE(review): the name suffix reads x256 while K is 512 (perhaps K per split,
// 512 / kBatch) — confirm which is intended.
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2Default_256x1280x256)
{
    constexpr int M{256};
    constexpr int N{1280};
    constexpr int K{512};
    constexpr int kBatch{2};

    EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
}
|
||||
|
||||
// M = 768, N = 512, K = 512 with kBatch = 2: Run() is expected to throw std::runtime_error.
// NOTE(review): the name suffix reads x256 while K is 512 (perhaps K per split,
// 512 / kBatch) — confirm which is intended.
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2Default_768x512x256)
{
    constexpr int M{768};
    constexpr int N{512};
    constexpr int K{512};
    constexpr int kBatch{2};

    EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
}
|
||||
|
||||
// M = 1280, N = 512, K = 512 with kBatch = 2: Run() is expected to throw std::runtime_error.
// NOTE(review): the name suffix reads x256 while K is 512 (perhaps K per split,
// 512 / kBatch) — confirm which is intended.
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2Default_1280x512x256)
{
    constexpr int M{1280};
    constexpr int N{512};
    constexpr int K{512};
    constexpr int kBatch{2};

    EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
}
|
||||
|
||||
// M = 1280, N = 256, K = 512 with kBatch = 2: Run() is expected to throw std::runtime_error.
// NOTE(review): the name suffix reads x256 while K is 512 (perhaps K per split,
// 512 / kBatch) — confirm which is intended.
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2Default_1280x256x256)
{
    constexpr int M{1280};
    constexpr int N{256};
    constexpr int K{512};
    constexpr int kBatch{2};

    EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
}
|
||||
@@ -70,21 +70,20 @@ template <typename Tuple>
|
||||
class TestCkTileGemmMultiD : public ::testing::Test
|
||||
{
|
||||
protected:
|
||||
using ALayout = std::tuple_element_t<0, Tuple>;
|
||||
using BLayout = std::tuple_element_t<1, Tuple>;
|
||||
using D0Layout = std::tuple_element_t<2, Tuple>;
|
||||
using D1Layout = std::tuple_element_t<3, Tuple>;
|
||||
using ELayout = std::tuple_element_t<4, Tuple>;
|
||||
using ADataType = std::tuple_element_t<5, Tuple>;
|
||||
using BDataType = std::tuple_element_t<6, Tuple>;
|
||||
using D0DataType = std::tuple_element_t<7, Tuple>;
|
||||
using D1DataType = std::tuple_element_t<8, Tuple>;
|
||||
using AccDataType = std::tuple_element_t<9, Tuple>;
|
||||
using EDataType = std::tuple_element_t<10, Tuple>;
|
||||
using CDElementWiseFn = std::tuple_element_t<11, Tuple>;
|
||||
using UseCshuffleEpilog = std::tuple_element_t<12, Tuple>;
|
||||
using DsLayout = ck_tile::tuple<D0Layout, D1Layout>;
|
||||
using DsDataType = ck_tile::tuple<D0DataType, D1DataType>;
|
||||
using ALayout = std::tuple_element_t<0, Tuple>;
|
||||
using BLayout = std::tuple_element_t<1, Tuple>;
|
||||
using D0Layout = std::tuple_element_t<2, Tuple>;
|
||||
using D1Layout = std::tuple_element_t<3, Tuple>;
|
||||
using ELayout = std::tuple_element_t<4, Tuple>;
|
||||
using ADataType = std::tuple_element_t<5, Tuple>;
|
||||
using BDataType = std::tuple_element_t<6, Tuple>;
|
||||
using D0DataType = std::tuple_element_t<7, Tuple>;
|
||||
using D1DataType = std::tuple_element_t<8, Tuple>;
|
||||
using AccDataType = std::tuple_element_t<9, Tuple>;
|
||||
using EDataType = std::tuple_element_t<10, Tuple>;
|
||||
using CDElementWiseFn = std::tuple_element_t<11, Tuple>;
|
||||
using DsLayout = ck_tile::tuple<D0Layout, D1Layout>;
|
||||
using DsDataType = ck_tile::tuple<D0DataType, D1DataType>;
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
@@ -170,28 +169,7 @@ class TestCkTileGemmMultiD : public ::testing::Test
|
||||
tail_number_v>;
|
||||
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<UniversalGemmProblem>;
|
||||
|
||||
using DefaultGemmEpilogue = ck_tile::DefaultGemm2DEpilogue<
|
||||
ck_tile::DefaultGemm2DEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
EDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
kPadM,
|
||||
kPadN,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
true,
|
||||
memory_operation>>;
|
||||
|
||||
using CShuffleGemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
@@ -210,9 +188,6 @@ class TestCkTileGemmMultiD : public ::testing::Test
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
|
||||
using GemmEpilogue = std::
|
||||
conditional_t<UseCshuffleEpilog::value, CShuffleGemmEpilogue, DefaultGemmEpilogue>;
|
||||
|
||||
using Kernel = ck_tile::GemmKernelMultiD<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
@@ -243,7 +218,6 @@ class TestCkTileGemmMultiD : public ::testing::Test
|
||||
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
if(args.k_batch == 1)
|
||||
{
|
||||
std::cout << "Run without SplitK" << std::endl;
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
@@ -251,19 +225,42 @@ class TestCkTileGemmMultiD : public ::testing::Test
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "Run using SplitK" << std::endl;
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
}
|
||||
};
|
||||
|
||||
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
if(has_hot_loop)
|
||||
{
|
||||
if(tail_num == ck_tile::TailNumber::Full)
|
||||
{
|
||||
RunSplitk(
|
||||
ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
std::ostringstream err;
|
||||
err << "For compute pipeline tail number should always be Full, but have \""
|
||||
<< tail_num << "\" which is not supported! PrefetchStages: "
|
||||
<< BaseGemmPipeline::PrefetchStages << "\n File: " << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__;
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
std::ostringstream err;
|
||||
err << "Num K loop must be larger than number of prefetech stages."
|
||||
<< "\n PrefetchStages: " << BaseGemmPipeline::PrefetchStages
|
||||
<< "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
bool Run(const int M,
|
||||
void Run(const int M,
|
||||
const int N,
|
||||
const int K,
|
||||
const int k_batch,
|
||||
@@ -404,6 +401,6 @@ class TestCkTileGemmMultiD : public ::testing::Test
|
||||
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
|
||||
<< std::endl;
|
||||
|
||||
return pass;
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -31,14 +31,9 @@ DEFAULT_EPILOGUE = """
|
||||
using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue<
|
||||
ck_tile::DefaultGemm2DEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
kPadM,
|
||||
kPadN,
|
||||
WarpTileM,
|
||||
|
||||
Reference in New Issue
Block a user