Added test for gemm_add_relu wmma instance

This commit is contained in:
apoorva
2025-07-01 13:44:18 +00:00
parent f5843dd22b
commit 6ec0ad2758
3 changed files with 44 additions and 5 deletions

View File

@@ -138,14 +138,14 @@ struct DeviceOperationInstanceFactory<
#elif defined(CK_USE_WMMA)
// For WMMA, ADataType must be the same as BDataType.
(CK_ENABLE_FP16) if constexpr(is_same_v<ADataType, half_t> &&
is_same_v<BDataType, half_t> &&
is_same_v<D0DataType, half_t> && is_same_v<EDataType, half_t>)
#if defined(CK_ENABLE_FP16)
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<D0DataType, half_t> && is_same_v<EDataType, half_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<D0Layout, Row> && is_same_v<ELayout, Row>)
{
add_device_gemm_add_relu_wmma_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instances(
add_device_gemm_add_relu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances(
op_ptrs);
}
}
@@ -159,7 +159,7 @@ struct DeviceOperationInstanceFactory<
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<D0Layout, Row> && is_same_v<ELayout, Row>)
{
add_device_gemm_add_relu_wmma_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instances(
add_device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances(
op_ptrs);
}
}

View File

@@ -39,3 +39,8 @@ add_gtest_executable(test_gemm_bilinear_wmma test_gemm_bilinear_wmma.cpp)
if(result EQUAL 0)
target_link_libraries(test_gemm_bilinear_wmma PRIVATE utility device_gemm_bilinear_instance)
endif()
# Register the gemm_add_relu WMMA gtest executable, built from
# test_gemm_add_relu_wmma.cpp.
add_gtest_executable(test_gemm_add_relu_wmma test_gemm_add_relu_wmma.cpp)
# Link only when `result` is 0.
# NOTE(review): `result` is presumably set by add_gtest_executable to signal
# that the target was actually created -- confirm against its definition.
if(result EQUAL 0)
target_link_libraries(test_gemm_add_relu_wmma PRIVATE utility device_gemm_add_relu_instance)
endif()

View File

@@ -0,0 +1,34 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "ck/ck.hpp"
#include "profiler/profile_gemm_add_relu_impl.hpp"
#include "test_gemm_common.hpp"
// Typed test fixture for the GEMM + Add + ReLU operation.
//
// Derives from TestGemmD0Common<Tuple> and overrides its GetImpl() hook so
// that it returns ck::profiler::profile_gemm_add_relu_impl instantiated with
// the data-type and layout aliases published by the base class.
template <typename Tuple>
class TestGemmAddRelu : public TestGemmD0Common<Tuple>
{
    // Short name for the dependent base class; keeps the template argument
    // list below readable.
    using Fixture = TestGemmD0Common<Tuple>;

    // Return type of GetImpl(), as declared by the base class.
    using ProfileCall = typename Fixture::ProfileCall;

    // Supplies the profiler entry point for this operation. The template
    // arguments are forwarded in the order: A, B, accumulator, D0 and E data
    // types, followed by the A, B, D0 and E layouts.
    ProfileCall GetImpl() override
    {
        return ck::profiler::profile_gemm_add_relu_impl<typename Fixture::ADataType,
                                                        typename Fixture::BDataType,
                                                        typename Fixture::AccDataType,
                                                        typename Fixture::D0DataType,
                                                        typename Fixture::EDataType,
                                                        typename Fixture::ALayout,
                                                        typename Fixture::BLayout,
                                                        typename Fixture::D0Layout,
                                                        typename Fixture::ELayout>;
    }
};
using KernelTypes =
::testing::Types<std::tuple<F16, F16, F32, F16, F16, Row, Row, ck::Tuple<Row>, Row>,
std::tuple<BF16, BF16, F32, BF16, BF16, Row, Row, ck::Tuple<Row>, Row>>;
TYPED_TEST_SUITE(TestGemmAddRelu, KernelTypes);
TYPED_TEST(TestGemmAddRelu, Test_BF16FP16) { this->Run(); }