mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-07 08:15:04 +00:00
Added test for gemm_add_relu wmma instance
This commit is contained in:
@@ -138,14 +138,14 @@ struct DeviceOperationInstanceFactory<
|
||||
|
||||
#elif defined(CK_USE_WMMA)
|
||||
// For wmma ADataType must be same as BDatatype.
|
||||
(CK_ENABLE_FP16) if constexpr(is_same_v<ADataType, half_t> &&
|
||||
is_same_v<BDataType, half_t> &&
|
||||
is_same_v<D0DataType, half_t> && is_same_v<EDataType, half_t>)
|
||||
#if defined(CK_ENABLE_FP16)
|
||||
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
|
||||
is_same_v<D0DataType, half_t> && is_same_v<EDataType, half_t>)
|
||||
{
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<D0Layout, Row> && is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_gemm_add_relu_wmma_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instances(
|
||||
add_device_gemm_add_relu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
@@ -159,7 +159,7 @@ struct DeviceOperationInstanceFactory<
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<D0Layout, Row> && is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_gemm_add_relu_wmma_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instances(
|
||||
add_device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -39,3 +39,8 @@ add_gtest_executable(test_gemm_bilinear_wmma test_gemm_bilinear_wmma.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_gemm_bilinear_wmma PRIVATE utility device_gemm_bilinear_instance)
|
||||
endif()
|
||||
|
||||
add_gtest_executable(test_gemm_add_relu_wmma test_gemm_add_relu_wmma.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_gemm_add_relu_wmma PRIVATE utility device_gemm_add_relu_instance)
|
||||
endif()
|
||||
34
test/gemm_add/test_gemm_add_relu_wmma.cpp
Normal file
34
test/gemm_add/test_gemm_add_relu_wmma.cpp
Normal file
@@ -0,0 +1,34 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ck/ck.hpp"
|
||||
#include "profiler/profile_gemm_add_relu_impl.hpp"
|
||||
#include "test_gemm_common.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmAddRelu : public TestGemmD0Common<Tuple>
|
||||
{
|
||||
using ProfileCall = typename TestGemmD0Common<Tuple>::ProfileCall;
|
||||
|
||||
ProfileCall GetImpl() override
|
||||
{
|
||||
return ck::profiler::profile_gemm_add_relu_impl<
|
||||
typename TestGemmD0Common<Tuple>::ADataType,
|
||||
typename TestGemmD0Common<Tuple>::BDataType,
|
||||
typename TestGemmD0Common<Tuple>::AccDataType,
|
||||
typename TestGemmD0Common<Tuple>::D0DataType,
|
||||
typename TestGemmD0Common<Tuple>::EDataType,
|
||||
typename TestGemmD0Common<Tuple>::ALayout,
|
||||
typename TestGemmD0Common<Tuple>::BLayout,
|
||||
typename TestGemmD0Common<Tuple>::D0Layout,
|
||||
typename TestGemmD0Common<Tuple>::ELayout>;
|
||||
}
|
||||
};
|
||||
|
||||
using KernelTypes =
|
||||
::testing::Types<std::tuple<F16, F16, F32, F16, F16, Row, Row, ck::Tuple<Row>, Row>,
|
||||
std::tuple<BF16, BF16, F32, BF16, BF16, Row, Row, ck::Tuple<Row>, Row>>;
|
||||
|
||||
TYPED_TEST_SUITE(TestGemmAddRelu, KernelTypes);
|
||||
TYPED_TEST(TestGemmAddRelu, Test_BF16FP16) { this->Run(); }
|
||||
Reference in New Issue
Block a user