mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 01:10:17 +00:00
* parse examples inside the add_example_executable function * fix the example 64 cmake file * add xdl flag to the gemm_bias_softmax_gemm_permute example * add filtering of tests based on architecture type * enable test_grouped_gemm for gfx9 only * enable test_transpose only for gfx9 * only link test_transpose if it gets built * split the gemm instances by architectures * split gemm_bilinear, grouped_conv_bwd_weight instances by targets * split instances by architecture * split grouped_conv instances by architecture * fix clang format * fix the if-else logic in group_conv headers * small fix for grouped convolution instances * fix the grouped conv bwd weight dl instances * fix client examples * only enable client examples 3 and 4 on gfx9 * set the gfx9 macro * make sure the architecture macros are set by cmake * use separate set of xdl/wmma flags for host code * simplify the main cmake file * add conv_fwd_bf8 instance declaration
42 lines
1.8 KiB
C++
42 lines
1.8 KiB
C++
// SPDX-License-Identifier: MIT
|
|
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
|
|
|
#include "gtest/gtest.h"
|
|
#include "ck/ck.hpp"
|
|
#include "profiler/profile_gemm_add_silu_impl.hpp"
|
|
#include "test_gemm_add_xdl.hpp"
|
|
|
|
template <typename Tuple>
|
|
class TestGemmAddSilu : public TestGemmAdd<Tuple>
|
|
{
|
|
private:
|
|
using ADataType = std::tuple_element_t<0, Tuple>;
|
|
using BDataType = std::tuple_element_t<1, Tuple>;
|
|
using AccDataType = std::tuple_element_t<2, Tuple>;
|
|
using D0DataType = std::tuple_element_t<3, Tuple>;
|
|
using EDataType = std::tuple_element_t<4, Tuple>;
|
|
using ALayout = std::tuple_element_t<5, Tuple>;
|
|
using BLayout = std::tuple_element_t<6, Tuple>;
|
|
using D0Layout = std::tuple_element_t<7, Tuple>;
|
|
using ELayout = std::tuple_element_t<8, Tuple>;
|
|
|
|
constexpr static auto ProfileGemmAddSiluImpl =
|
|
ck::profiler::profile_gemm_add_silu_impl<ADataType,
|
|
BDataType,
|
|
AccDataType,
|
|
D0DataType,
|
|
EDataType,
|
|
ALayout,
|
|
BLayout,
|
|
D0Layout,
|
|
ELayout>;
|
|
|
|
decltype(ProfileGemmAddSiluImpl) GetImpl() override { return ProfileGemmAddSiluImpl; }
|
|
};
|
|
|
|
using KernelTypes = ::testing::Types<std::tuple<F16, I8, F32, F16, F16, Row, Row, Row, Row>,
|
|
std::tuple<BF16, I8, F32, BF16, BF16, Row, Row, Row, Row>>;
|
|
|
|
TYPED_TEST_SUITE(TestGemmAddSilu, KernelTypes);
|
|
TYPED_TEST(TestGemmAddSilu, Test_BF16FP16_INT8) { this->Run(); }
|