mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 19:40:04 +00:00
* parse examples inside the add_example_executable function
* fix the example 64 cmake file
* add xdl flag to the gemm_bias_softmax_gemm_permute example
* add filtering of tests based on architecture type
* enable test_grouped_gemm for gfx9 only
* enable test_transpose only for gfx9
* only link test_transpose if it gets built
* split the gemm instances by architectures
* split gemm_bilinear,grouped_conv_bwd_weight instances by targets
* split instances by architecture
* split grouped_conv instances by architecture
* fix clang format
* fix the if-else logic in group_conv headers
* small fix for grouped convolution instances
* fix the grouped conv bwd weight dl instances
* fix client examples
* only enable client examples 3 and 4 on gfx9
* set the gfx9 macro
* make sure the architecture macros are set by cmake
* use separate set of xdl/wmma flags for host code
* simplify the main cmake file
* add conv_fwd_bf8 instance declaration
[ROCm/composable_kernel commit: ae57e5938e]
42 lines
1.8 KiB
C++
42 lines
1.8 KiB
C++
// SPDX-License-Identifier: MIT
|
|
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
|
|
|
#include "gtest/gtest.h"
|
|
#include "ck/ck.hpp"
|
|
#include "profiler/profile_gemm_add_silu_impl.hpp"
|
|
#include "test_gemm_add_xdl.hpp"
|
|
|
|
template <typename Tuple>
|
|
class TestGemmAddSilu : public TestGemmAdd<Tuple>
|
|
{
|
|
private:
|
|
using ADataType = std::tuple_element_t<0, Tuple>;
|
|
using BDataType = std::tuple_element_t<1, Tuple>;
|
|
using AccDataType = std::tuple_element_t<2, Tuple>;
|
|
using D0DataType = std::tuple_element_t<3, Tuple>;
|
|
using EDataType = std::tuple_element_t<4, Tuple>;
|
|
using ALayout = std::tuple_element_t<5, Tuple>;
|
|
using BLayout = std::tuple_element_t<6, Tuple>;
|
|
using D0Layout = std::tuple_element_t<7, Tuple>;
|
|
using ELayout = std::tuple_element_t<8, Tuple>;
|
|
|
|
constexpr static auto ProfileGemmAddSiluImpl =
|
|
ck::profiler::profile_gemm_add_silu_impl<ADataType,
|
|
BDataType,
|
|
AccDataType,
|
|
D0DataType,
|
|
EDataType,
|
|
ALayout,
|
|
BLayout,
|
|
D0Layout,
|
|
ELayout>;
|
|
|
|
decltype(ProfileGemmAddSiluImpl) GetImpl() override { return ProfileGemmAddSiluImpl; }
|
|
};
|
|
|
|
// Type combinations run by the typed suite. Each tuple follows the fixture's
// parameter order:
//   <ADataType, BDataType, AccDataType, D0DataType, EDataType,
//    ALayout, BLayout, D0Layout, ELayout>
// NOTE(review): F16/BF16/I8/F32/Row are assumed to be CK's fp16/bf16/int8/fp32
// and row-major shorthands declared in test_gemm_add_xdl.hpp. Confirm there.
using KernelTypes = ::testing::Types<
    // F16 A, I8 B, F32 accumulation, F16 D0 and E, all Row layouts.
    std::tuple<F16, I8, F32, F16, F16, Row, Row, Row, Row>,
    // BF16 A, I8 B, F32 accumulation, BF16 D0 and E, all Row layouts.
    std::tuple<BF16, I8, F32, BF16, BF16, Row, Row, Row, Row>>;
|
|
|
|
// Register TestGemmAddSilu as a typed suite instantiated once per entry of KernelTypes.
TYPED_TEST_SUITE(TestGemmAddSilu, KernelTypes);
|
|
// The single test case, executed once for each KernelTypes tuple (the F16 and
// BF16 A-type variants, each paired with an I8 B operand). It delegates
// entirely to Run(). TestGemmAddSilu does not define Run(), so it comes from
// the TestGemmAdd base (presumably declared in test_gemm_add_xdl.hpp).
TYPED_TEST(TestGemmAddSilu, Test_BF16FP16_INT8) { this->Run(); }
|