From 6ec0ad2758eba9beb54b995882488b0c7eb3354d Mon Sep 17 00:00:00 2001 From: apoorva Date: Tue, 1 Jul 2025 13:44:18 +0000 Subject: [PATCH] Added test for gemm_add_relu wmma instance --- .../gpu/gemm_add_relu.hpp | 10 +++--- test/gemm_add/CMakeLists.txt | 5 +++ test/gemm_add/test_gemm_add_relu_wmma.cpp | 34 +++++++++++++++++++ 3 files changed, 44 insertions(+), 5 deletions(-) create mode 100644 test/gemm_add/test_gemm_add_relu_wmma.cpp diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_relu.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_relu.hpp index b79059de9a..2cc7cab5e6 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_relu.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_relu.hpp @@ -138,14 +138,14 @@ struct DeviceOperationInstanceFactory< #elif defined(CK_USE_WMMA) // For wmma ADataType must be same as BDatatype. - (CK_ENABLE_FP16) if constexpr(is_same_v && - is_same_v && - is_same_v && is_same_v) +#if defined(CK_ENABLE_FP16) + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) { if constexpr(is_same_v && is_same_v && is_same_v && is_same_v) { - add_device_gemm_add_relu_wmma_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instances( + add_device_gemm_add_relu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances( op_ptrs); } } @@ -159,7 +159,7 @@ struct DeviceOperationInstanceFactory< if constexpr(is_same_v && is_same_v && is_same_v && is_same_v) { - add_device_gemm_add_relu_wmma_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instances( + add_device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances( op_ptrs); } } diff --git a/test/gemm_add/CMakeLists.txt b/test/gemm_add/CMakeLists.txt index 18fc3ee8f8..2e50516082 100644 --- a/test/gemm_add/CMakeLists.txt +++ b/test/gemm_add/CMakeLists.txt @@ -39,3 +39,8 @@ add_gtest_executable(test_gemm_bilinear_wmma test_gemm_bilinear_wmma.cpp) if(result EQUAL 0) target_link_libraries(test_gemm_bilinear_wmma PRIVATE utility device_gemm_bilinear_instance) endif() + +add_gtest_executable(test_gemm_add_relu_wmma test_gemm_add_relu_wmma.cpp) +if(result EQUAL 0) + target_link_libraries(test_gemm_add_relu_wmma PRIVATE utility device_gemm_add_relu_instance) +endif() \ No newline at end of file diff --git a/test/gemm_add/test_gemm_add_relu_wmma.cpp b/test/gemm_add/test_gemm_add_relu_wmma.cpp new file mode 100644 index 0000000000..e1e304f70f --- /dev/null +++ b/test/gemm_add/test_gemm_add_relu_wmma.cpp @@ -0,0 +1,34 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gtest/gtest.h" +#include "ck/ck.hpp" +#include "profiler/profile_gemm_add_relu_impl.hpp" +#include "test_gemm_common.hpp" + +template +class TestGemmAddRelu : public TestGemmD0Common +{ + using ProfileCall = typename TestGemmD0Common::ProfileCall; + + ProfileCall GetImpl() override + { + return ck::profiler::profile_gemm_add_relu_impl< + typename TestGemmD0Common::ADataType, + typename TestGemmD0Common::BDataType, + typename TestGemmD0Common::AccDataType, + typename TestGemmD0Common::D0DataType, + typename TestGemmD0Common::EDataType, + typename TestGemmD0Common::ALayout, + typename TestGemmD0Common::BLayout, + typename TestGemmD0Common::D0Layout, + typename TestGemmD0Common::ELayout>; + } +}; + +using KernelTypes = + ::testing::Types, Row>, + std::tuple, Row>>; + +TYPED_TEST_SUITE(TestGemmAddRelu, KernelTypes); +TYPED_TEST(TestGemmAddRelu, Test_BF16FP16) { this->Run(); }