Files
composable_kernel/test/gemm_split_k/test_gemm_splitk_xdl.cpp
Illia Silin ae57e5938e Split the instances by architecture. (#1223)
* parse examples inside the add_example_executable function

* fix the example 64 cmake file

* add xdl flag to the gemm_bias_softmax_gemm_permute example

* add filtering of tests based on architecture type

* enable test_grouped_gemm for gfx9 only

* enable test_transpose only for gfx9

* only link test_transpose if it gets built

* split the gemm instances by architectures

* split gemm_bilinear,grouped_conv_bwd_weight instances by targets

* split instances by architecture

* split grouped_conv instances by architecture

* fix clang format

* fix the if-else logic in group_conv headers

* small fix for grouped convolution instances

* fix the grouped conv bwd weight dl instances

* fix client examples

* only enable client examples 3 and 4 on gfx9

* set the gfx9 macro

* make sure the architecture macros are set by cmake

* use separate set of xdl/wmma flags for host code

* simplify the main cmake file

* add conv_fwd_bf8 instance declaration
2024-04-02 09:42:17 -07:00

67 lines
1.7 KiB
C++

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <tuple>
#include "gtest/gtest.h"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "test_gemm_splitk_util.hpp"
// Short aliases for the element types and matrix layouts referenced by the
// fixtures and type list further down in this file.
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
namespace {

// Type-level concatenation of two std::tuple specializations.
//
// The primary template is declared but never defined, so asking for `::type`
// with anything other than two std::tuple arguments fails to compile.
template <typename Lhs, typename Rhs>
struct tuple_concat;

// tuple_concat<std::tuple<A...>, std::tuple<B...>>::type is std::tuple<A..., B...>:
// the element types of the first tuple followed by those of the second.
template <typename... LhsElems, typename... RhsElems>
struct tuple_concat<std::tuple<LhsElems...>, std::tuple<RhsElems...>>
{
    using type = std::tuple<LhsElems..., RhsElems...>;
};

} // namespace
// Typed-test fixture for the "MK_KN" suite. The layout pair <Row, Row> is
// placed in front of the elements of DataTypes, and the combined tuple is the
// template argument of ck::test::TestGemmSplitK.
// NOTE(review): the pair is presumably (A layout, B layout), as the MK_KN
// suffix suggests -- confirm against test_gemm_splitk_util.hpp.
template <typename DataTypes>
class TestGemmSplitK_MK_KN
    : public ck::test::TestGemmSplitK<
          typename tuple_concat<std::tuple<Row, Row>, DataTypes>::type>
{
};
// Typed-test fixture for the "MK_NK" suite. The layout pair <Row, Col> is
// placed in front of the elements of DataTypes, and the combined tuple is the
// template argument of ck::test::TestGemmSplitK.
// NOTE(review): the pair is presumably (A layout, B layout), as the MK_NK
// suffix suggests -- confirm against test_gemm_splitk_util.hpp.
template <typename DataTypes>
class TestGemmSplitK_MK_NK
    : public ck::test::TestGemmSplitK<
          typename tuple_concat<std::tuple<Row, Col>, DataTypes>::type>
{
};
// Typed-test fixture for the "KM_KN" suite. The layout pair <Col, Row> is
// placed in front of the elements of DataTypes, and the combined tuple is the
// template argument of ck::test::TestGemmSplitK.
// NOTE(review): the pair is presumably (A layout, B layout), as the KM_KN
// suffix suggests -- confirm against test_gemm_splitk_util.hpp.
template <typename DataTypes>
class TestGemmSplitK_KM_KN
    : public ck::test::TestGemmSplitK<
          typename tuple_concat<std::tuple<Col, Row>, DataTypes>::type>
{
};
// Typed-test fixture for the "KM_NK" suite. The layout pair <Col, Col> is
// placed in front of the elements of DataTypes, and the combined tuple is the
// template argument of ck::test::TestGemmSplitK.
// NOTE(review): the pair is presumably (A layout, B layout), as the KM_NK
// suffix suggests -- confirm against test_gemm_splitk_util.hpp.
template <typename DataTypes>
class TestGemmSplitK_KM_NK
    : public ck::test::TestGemmSplitK<
          typename tuple_concat<std::tuple<Col, Col>, DataTypes>::type>
{
};
// Type list handed to every TYPED_TEST_SUITE below: one std::tuple of element
// types per instantiation. Each fixture prepends its own layout pair to the
// tuple before forwarding it to ck::test::TestGemmSplitK.
// clang-format off
using KernelTypes = ::testing::Types<
// ADataType, BDataType, CDataType
std::tuple< F16, F16, F16>,
std::tuple< F32, F32, F32>
>;
// clang-format on
// Bind each layout-specific fixture to the shared KernelTypes list, so every
// suite is instantiated once per data-type tuple.
TYPED_TEST_SUITE(TestGemmSplitK_MK_KN, KernelTypes);
TYPED_TEST_SUITE(TestGemmSplitK_MK_NK, KernelTypes);
TYPED_TEST_SUITE(TestGemmSplitK_KM_KN, KernelTypes);
TYPED_TEST_SUITE(TestGemmSplitK_KM_NK, KernelTypes);
// NOTE(review): this .inc file presumably holds the TYPED_TEST bodies for the
// suites registered above; it is included last because GoogleTest needs a
// suite to be declared before tests are added to it -- confirm.
#include "test_gemm_splitk_ut_cases.inc"