mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-20 04:49:54 +00:00
Wmma support for multiple Ds based GEMMs (#2613)
* Fixed cmake errors related to gemm_bilinear. Previously, if the above flags are set, cmake build fails: GPU_TARGETS="gfx1100;gfx1201" -D DTYPES="fp16;bf16;fp8" * Fixed cmake build errors related to test_fp8 * Updates to support mixed precision (cherry picked from commit e65d71180393e7b66169c56565a6bac740427de6) Co-authored-by: Anca Hamuraru <anca@streamhpc.com> * Adding support for RRR, F8xF16xF16 gemm_universal_wmma - wip (cherry picked from commit f8c06322df0abcbd5945a56cdf5bffe56480f9f0) Co-authored-by: Anca Hamuraru <anca@streamhpc.com> * Added support for F8xF16xF16 to gemm_wmma_universal (cherry picked from commit 15c851de6daa513a12c2e3af299bab0176175fb5) Co-authored-by: Anca Hamuraru <anca@streamhpc.com> * Added support for F16xF8xF16 to gemm_wmma_universal * Added support for BF16xI4xBF16 to gemm_wmma_universal (cherry picked from commit c6a4a69d2d43d59bae8bdabfae80d648646f217e) Co-authored-by: Anca Hamuraru <anca@streamhpc.com> * Added support for F16xI4xF16 to gemm_wmma_universal * Fixed IsSupportedArgument to check ComputeTypeA, ComputeTypeB instead of ADataType, BDataType * Added missing test class for FP16_KM_NK * Pre-commit hooks fixes * Added padding instances for f16xf16xf16 * Fixed cmake errors related to gemm_bilinear. 
Previously, if the above flags are set, cmake build fails: GPU_TARGETS="gfx1100;gfx1201" -D DTYPES="fp16;bf16;fp8" (cherry picked from commit 5bdc993dbf) Co-authored-by: Anca Hamuraru <anca@streamhpc.com> * Fixed cmake build errors related to test_fp8 (cherry picked from commit 12176616b6) Co-authored-by: Anca Hamuraru <anca@streamhpc.com> * Amending changes for adding support for padding instances for f16xf16xf16 * Fixes for padding instances for f16xf16xf16 * Added padding instances for bf16xbf16, f8xf8 * Added packed instances for bf16xi4xbf16 * Added padding instances for f8xf16xf16 * Added padding instances for f16xf8xf16, f16xi4xf16 * Fixed typos for bf16xbf16xbf16 padding instances * Fixed typos for padded instances * Added tests for fp16, KM_KN and KM_NK * Padding is not supported when BDataType is pk_i4_t. Added fix for correct check and removed padding instances. * Fixed typos * Updated the set of tests for FP16 * Updated the set of tests for FP16 * Fix typo * Moved f16xi4 test under the correct data layout group * example for gemm_universal_bf16 * Adding examples for gemm_wmma instances * Added the missing parameters * Fixed review comments and added executable to cmakeLists * Fixing clang format * Fixing build errors * Fixed compilation failure. * Modified some code as per gemm_universal_examples * Fixed the gemm specialization error * Fixed the build errors. * Fix strides of a/b_thread_desc The descriptors are larger than needed (even though the compiler doesn't allocate registers for unused values). 
* Load in M/NRepeat dims with thread copy's slice instead of a loop * Clone BlockwiseGemmXdlops_pipeline_v1 for WMMA implementation * Implement Intrawave and Interwave variants of pipeline v1 * Add instances for Interwave and Intrawave v1 * Add instances with ABlockLdsExtraM and BBlockLdsExtraN = 0 * Remove instances that are too slow (mostly because of register spilling) * Add a workaround for fp8/bf8->f32 packed conversion issue * Add instances for Interwave and Intrawave v1 * Enable profiling of mixed precision with f8 and int4 on WMMA * Fix segfault in profiler when B is pk_i4_t b_device_buf's size in bytes is larger than b_k_n_permute so b_device_buf.ToDevice reads out-of-bounds. * Remove instances that are too slow (mostly because of register spilling) * Add missing add_device_gemm_wmma_universal_f8_f8_bf16 declarations * Add test case for bf16_i4 * Add missing Regular tests * Add test_gemm_universal_xdl/wmma_fp16 to REGRESSION_TESTS They take more than 30 seconds * Fix a bug that fp16_i4 validation passes only with PermuteB A permutation required by conversion from pk_i4_t to half_t does not depend on PermuteB, they can be used independently. * Use PermuteB with f16_i4 in most instances (as xdl) Some instances use PermuteB = false for checking correctness. See also the previous commit. * Fix cache flushing for pk_i4 * Add mixed precision examples * Disable all tests and instances with f8 on gfx11 Even though f8_f16 and f16_f8 don't require f8 WMMA instructions, gfx11 still lacks hardware instructions for fast f8->f32 conversion. * Add FP16 KM_NK and KM_KN test suites for XDL These tests were added to common .inc for better testing of WMMA instances * Support multiple D in GridwiseGemm_wmma_cshuffle_v3 DeviceGemm_Wmma_CShuffleV3 is changed for new template parameters. 
* Use ThreadGroupTensorSliceTransfer_v7r3 * Clone for device_gemm_wmma_cshuffle_v3.hpp for future Multiple D support * Clone example/65_gemm_multiply_multiply/gemm_add_add_xdl_fp16.cpp for wmma * Implement DeviceGemmMultipleD_Wmma_CShuffleV3 * Make gemm_add_add_wmma to work with DeviceGemmMultipleD_Wmma_CShuffleV3 * Prepare gemma_add tests for adding wmma * Add gemm_add_fastgelu instances and test * Add a special wrapper to use DeviceGemmMultipleD_Wmma_CShuffleV3 with old API ckProfiler uses DeviceGemmMultipleD (tests also call its functions), the wrapper allows to use DeviceGemmMultipleDSplitK instances there. * removed unnecessary ck parts from compilation * initial gemm_add_multiply instance implementations * fixed profiler help message for gemm_add_multiply * improved multiply_add profiler layout help * fixed template arguments for test instances * added test for gemm_add_multiply * Support multiple D in GridwiseGemm_wmma_cshuffle_v3 DeviceGemm_Wmma_CShuffleV3 is changed for new template parameters. * Use ThreadGroupTensorSliceTransfer_v7r3 * Clone for device_gemm_wmma_cshuffle_v3.hpp for future Multiple D support * Clone example/65_gemm_multiply_multiply/gemm_add_add_xdl_fp16.cpp for wmma * Implement DeviceGemmMultipleD_Wmma_CShuffleV3 * Make gemm_add_add_wmma to work with DeviceGemmMultipleD_Wmma_CShuffleV3 * Prepare gemma_add tests for adding wmma * Add gemm_add_fastgelu instances and test * Add a special wrapper to use DeviceGemmMultipleD_Wmma_CShuffleV3 with old API ckProfiler uses DeviceGemmMultipleD (tests also call its functions), the wrapper allows to use DeviceGemmMultipleDSplitK instances there. * switched to splitK interface * log print added to splitk benchmarks * revert main cmake comments * newline change reverted * added add_fastgelu instances * revert unintended change in xdl add_fastgelu * created gemm_add_add_fastgelu instances * created fastegelu instances * added tests for all splitk fastgelus * Added tests. 
* multiply_add instances created * updates to add_multiply splitk instances * splitk xdl test fixes * added wmma multiply_multiply instances * fixed ONLY_XDL_AND_WMMA_KERNELS tag * Added gemm_add examples for wmma v1 and v3 * fixed / workarounded i8 instances * Modified the v3 code to added one fp16 bxdl instance. * added bf16 xdl instance. * adding gemm_add wmma_cshuffle and other support (cherry picked from commit ec447e7f564095ea969eddc39ec77b843aa52976) Co-authored-by: Cenxuan <cenxuan@streamhpc.com> * add instances into camkelists (cherry picked from commit 23bf2d2771c939ea3ca7f493433c55255bffd08e) Co-authored-by: Cenxuan <cenxuan@streamhpc.com> * This is work in progress, edited the template parameters in order to build (cherry picked from commit b4fde8a3314cb44659c4bbda35f1a0133c63dc41) Co-authored-by: Cenxuan <cenxuan@streamhpc.com> * temp work saved, changed the BDataType to f16 or bf16 since wmma currently not support non-equal A and B datatype (cherry picked from commit 22fbd68f1db458ab50780a394ee2544c7a1484d1) Co-authored-by: Cenxuan <cenxuan@streamhpc.com> * added datatype and use clang-format-12 (cherry picked from commit ae4e853682ef1bb27784b2f965b4a66b3751ceec) Co-authored-by: Cenxuan <cenxuan@streamhpc.com> * Fixing build errors * Added instances for v3 * Adding instances and executables * Code update of template parameters modified. * Renamed file. * Added tests. * resolved error tests. * Fixing build errors * Updated comments * removed the changes as per the MR review comment. * Updated tests. * fp8 instances - not tested * Restored the Cmake file that was reverted by mistake during rebase. * fixed wmma_op test * Updated comments. 
* Updated the template parameter description * fixed rdna4 instances * fixed back compatibility on gfx11 * cleanups * fix ckProfiler * one more cmake fix * added fp8 instances * Updated tests to add BF16 instances as per review comment * Added include file and cleaned up (as per review comment) * Updated and optimized the example code for all types. * Fixed clang format * Resolve "Implement `device_gemm_bilinear` for RDNA4" * test generalization to handle FP16 shuffle better * added missing changes * Added bf16 wmma instance for add_relu * Added f16 wmma instance and corrected bf16 instance errors. * Added instances to Cmake * Modified the template parameters to make the instances work. * Fixed typo in profiler * Added v3 instances for gemm_add_relu * addressed core review comments * Added test for gemm_add_relu wmma instance * Cleaned up the code. * Added examples for gemm_add_relu * Fixing typo to resolve build errors. * Fixes applied to fix the precision loss. * fix bilinear test after merge * Removed the old wmma instances. * Added wrapper and renamed the wmma_v3 instances * Updated copyrights and added wrappers. * Fixes applied according to review comments * Apply 1 suggestion(s) to 1 file(s) Co-authored-by: Robin Voetter <robin@streamhpc.com> * Removed the old wmma instances. * Updated wrapper for the v3 instances * removed the old wmma examples * Renamed the v3 instances * Deleted the gtest file added by mistake. * Updated the profiler with wrapper * Fixed test errors. * Fixed the review comments * Fixed the if condition MACROS. * REVERTED THE PROFILER CHANGES * Revert "REVERTED THE PROFILER CHANGES" This reverts commit 21cb98546c. * Revert "Fixed test errors." This reverts commit 13efcc6fe1. * Revert "Updated the profiler with wrapper" This reverts commit 536f86661d. * Added missing wrapper instances * Updated copyrights. * Fixed typo. * Fixed copyrights. * Updated copyrights. * updated copyrights. 
* comments on the atomics workaround * fixed cmake comment * Fix bug from merge * clang-format-18 * Fix compilation error * Fix linking error * Fix bug in add and add_relu examples * Fix error including file (typo) * Quick fix to compile examples for different targets * Fix for multi target * implemented f16 and bf16 instances for gemm_silu * addressed review comments * addressed review comments * Fix clang format * Fix clang format --------- Co-authored-by: Anca Hamuraru <anca@streamhpc.com> Co-authored-by: apoorva <apoorva@streamhpc.com> Co-authored-by: Anton Gorenko <anton@streamhpc.com> Co-authored-by: Zoltan Lakatos <zoltan.lakatos@streamhpc.com> Co-authored-by: Cenxuan <cenxuan@streamhpc.com> Co-authored-by: Robin Voetter <robin@streamhpc.com> Co-authored-by: Kiefer van Teutem <kiefer.van.teutem@streamhpc.com> Co-authored-by: Kevin Abraham <kevin.abraham@streamhpc.com> Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> [ROCm/composable_kernel commit:b740380906]
This commit is contained in:
@@ -1,19 +1,71 @@
|
||||
add_gtest_executable(test_gemm_add test_gemm_add_xdl.hpp)
|
||||
# Implements test instances for MultipleD with xdl and wmma support.
|
||||
|
||||
add_gtest_executable(test_gemm_add_xdl test_gemm_add_xdl.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_gemm_add PRIVATE utility device_gemm_add_instance)
|
||||
target_link_libraries(test_gemm_add_xdl PRIVATE utility device_gemm_add_instance)
|
||||
endif()
|
||||
|
||||
add_gtest_executable(test_gemm_add_relu test_gemm_add_relu_xdl.cpp)
|
||||
add_gtest_executable(test_gemm_add_relu_xdl test_gemm_add_relu_xdl.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_gemm_add_relu PRIVATE utility device_gemm_add_instance device_gemm_add_relu_instance)
|
||||
target_link_libraries(test_gemm_add_relu_xdl PRIVATE utility device_gemm_add_instance device_gemm_add_relu_instance)
|
||||
endif()
|
||||
|
||||
add_gtest_executable(test_gemm_add_silu test_gemm_add_silu_xdl.cpp)
|
||||
add_gtest_executable(test_gemm_add_silu_xdl test_gemm_add_silu_xdl.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_gemm_add_silu PRIVATE utility device_gemm_add_instance device_gemm_add_silu_instance)
|
||||
target_link_libraries(test_gemm_add_silu_xdl PRIVATE utility device_gemm_add_instance device_gemm_add_silu_instance)
|
||||
endif()
|
||||
|
||||
add_gtest_executable(test_gemm_add_fastgelu test_gemm_add_fastgelu_xdl.cpp)
|
||||
add_gtest_executable(test_gemm_add_silu_wmma test_gemm_add_silu_wmma.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_gemm_add_fastgelu PRIVATE utility device_gemm_add_instance device_gemm_add_fastgelu_instance)
|
||||
target_link_libraries(test_gemm_add_silu_wmma PRIVATE utility device_gemm_add_instance device_gemm_add_silu_instance)
|
||||
endif()
|
||||
|
||||
add_gtest_executable(test_gemm_add_fastgelu_xdl test_gemm_add_fastgelu_xdl.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_gemm_add_fastgelu_xdl PRIVATE utility device_gemm_add_instance device_gemm_add_fastgelu_instance)
|
||||
endif()
|
||||
|
||||
add_gtest_executable(test_gemm_fastgelu_wmma test_gemm_fastgelu_wmma.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_gemm_fastgelu_wmma PRIVATE utility device_gemm_fastgelu_instance)
|
||||
endif()
|
||||
|
||||
add_gtest_executable(test_gemm_add_fastgelu_wmma test_gemm_add_fastgelu_wmma.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_gemm_add_fastgelu_wmma PRIVATE utility device_gemm_add_fastgelu_instance)
|
||||
endif()
|
||||
|
||||
add_gtest_executable(test_gemm_add_wmma test_gemm_add_wmma.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_gemm_add_wmma PRIVATE utility device_gemm_add_instance)
|
||||
endif()
|
||||
|
||||
add_gtest_executable(test_gemm_add_add_fastgelu_wmma test_gemm_add_add_fastgelu_wmma.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_gemm_add_add_fastgelu_wmma PRIVATE utility device_gemm_add_add_fastgelu_instance)
|
||||
endif()
|
||||
|
||||
add_gtest_executable(test_gemm_multiply_multiply_wmma test_gemm_multiply_multiply_wmma.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_gemm_multiply_multiply_wmma PRIVATE utility device_gemm_multiply_multiply_instance)
|
||||
endif()
|
||||
|
||||
add_gtest_executable(test_gemm_add_multiply_wmma test_gemm_add_multiply_wmma.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_gemm_add_multiply_wmma PRIVATE utility device_gemm_add_multiply_instance)
|
||||
endif()
|
||||
|
||||
add_gtest_executable(test_gemm_multiply_add_wmma test_gemm_multiply_add_wmma.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_gemm_multiply_add_wmma PRIVATE utility device_gemm_multiply_add_instance)
|
||||
endif()
|
||||
|
||||
add_gtest_executable(test_gemm_bilinear_wmma test_gemm_bilinear_wmma.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_gemm_bilinear_wmma PRIVATE utility device_gemm_bilinear_instance)
|
||||
endif()
|
||||
|
||||
add_gtest_executable(test_gemm_add_relu_wmma test_gemm_add_relu_wmma.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_gemm_add_relu_wmma PRIVATE utility device_gemm_add_relu_instance)
|
||||
endif()
|
||||
39
test/gemm_add/test_gemm_add_add_fastgelu_wmma.cpp
Normal file
39
test/gemm_add/test_gemm_add_add_fastgelu_wmma.cpp
Normal file
@@ -0,0 +1,39 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ck/ck.hpp"
|
||||
#include "profiler/profile_gemm_add_add_fastgelu_impl.hpp"
|
||||
#include "test_gemm_common.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmAddAddFastgelu : public TestGemmD0D1Common<Tuple>
|
||||
{
|
||||
using ProfileCall = typename TestGemmD0D1Common<Tuple>::ProfileCall;
|
||||
|
||||
public:
|
||||
ProfileCall GetImpl() override
|
||||
{
|
||||
return ck::profiler::profile_gemm_add_add_fastgelu_impl<
|
||||
typename TestGemmD0D1Common<Tuple>::ADataType,
|
||||
typename TestGemmD0D1Common<Tuple>::BDataType,
|
||||
typename TestGemmD0D1Common<Tuple>::AccDataType,
|
||||
typename TestGemmD0D1Common<Tuple>::D0DataType,
|
||||
typename TestGemmD0D1Common<Tuple>::D1DataType,
|
||||
typename TestGemmD0D1Common<Tuple>::EDataType,
|
||||
typename TestGemmD0D1Common<Tuple>::ALayout,
|
||||
typename TestGemmD0D1Common<Tuple>::BLayout,
|
||||
typename TestGemmD0D1Common<Tuple>::D0Layout,
|
||||
typename TestGemmD0D1Common<Tuple>::D1Layout,
|
||||
typename TestGemmD0D1Common<Tuple>::ELayout>;
|
||||
}
|
||||
};
|
||||
|
||||
// Type list for the typed suite: four std::tuple entries, each holding six data-type tags
// followed by five layout tags. The entries are identical except for the first two layout
// tags, which cover every Row/Col combination.
// NOTE(review): presumably the tag order is A, B, Acc, D0, D1, E data types and then
// A, B, D0, D1, E layouts — confirm against TestGemmD0D1Common in test_gemm_common.hpp.
using KernelTypes =
|
||||
::testing::Types<std::tuple<F16, F16, F32, F16, F16, F16, Row, Row, Row, Row, Row>,
|
||||
std::tuple<F16, F16, F32, F16, F16, F16, Row, Col, Row, Row, Row>,
|
||||
std::tuple<F16, F16, F32, F16, F16, F16, Col, Row, Row, Row, Row>,
|
||||
std::tuple<F16, F16, F32, F16, F16, F16, Col, Col, Row, Row, Row>>;
|
||||
|
||||
// Instantiate the fixture once per tuple in KernelTypes.
TYPED_TEST_SUITE(TestGemmAddAddFastgelu, KernelTypes);
|
||||
// Single test case: calls the fixture's Run() with no arguments.
TYPED_TEST(TestGemmAddAddFastgelu, Test_FP16FP16) { this->Run(); }
|
||||
35
test/gemm_add/test_gemm_add_fastgelu_wmma.cpp
Normal file
35
test/gemm_add/test_gemm_add_fastgelu_wmma.cpp
Normal file
@@ -0,0 +1,35 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ck/ck.hpp"
|
||||
#include "profiler/profile_gemm_add_fastgelu_impl.hpp"
|
||||
#include "test_gemm_common.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmAddFastgelu : public TestGemmD0Common<Tuple>
|
||||
{
|
||||
using ProfileCall = typename TestGemmD0Common<Tuple>::ProfileCall;
|
||||
|
||||
ProfileCall GetImpl() override
|
||||
{
|
||||
return ck::profiler::profile_gemm_add_fastgelu_impl<
|
||||
typename TestGemmD0Common<Tuple>::ADataType,
|
||||
typename TestGemmD0Common<Tuple>::BDataType,
|
||||
typename TestGemmD0Common<Tuple>::AccDataType,
|
||||
typename TestGemmD0Common<Tuple>::D0DataType,
|
||||
typename TestGemmD0Common<Tuple>::EDataType,
|
||||
typename TestGemmD0Common<Tuple>::ALayout,
|
||||
typename TestGemmD0Common<Tuple>::BLayout,
|
||||
typename TestGemmD0Common<Tuple>::D0Layout,
|
||||
typename TestGemmD0Common<Tuple>::ELayout>;
|
||||
}
|
||||
};
|
||||
|
||||
// Type list for the typed suite: four std::tuple entries, each holding five data-type tags
// followed by four layout tags. The entries are identical except for the first two layout
// tags, which cover every Row/Col combination.
// NOTE(review): presumably the tag order is A, B, Acc, D0, E data types and then
// A, B, D0, E layouts — confirm against TestGemmD0Common in test_gemm_common.hpp.
using KernelTypes = ::testing::Types<std::tuple<F16, F16, F32, F16, F16, Row, Row, Row, Row>,
|
||||
std::tuple<F16, F16, F32, F16, F16, Row, Col, Row, Row>,
|
||||
std::tuple<F16, F16, F32, F16, F16, Col, Row, Row, Row>,
|
||||
std::tuple<F16, F16, F32, F16, F16, Col, Col, Row, Row>>;
|
||||
|
||||
// Instantiate the fixture once per tuple in KernelTypes.
TYPED_TEST_SUITE(TestGemmAddFastgelu, KernelTypes);
|
||||
// Single test case: calls the fixture's Run() with no arguments.
TYPED_TEST(TestGemmAddFastgelu, Test_FP16FP16) { this->Run(); }
|
||||
@@ -1,37 +1,29 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ck/ck.hpp"
|
||||
#include "profiler/profile_gemm_add_fastgelu_impl.hpp"
|
||||
#include "test_gemm_add_xdl.hpp"
|
||||
#include "test_gemm_common.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmAddFastgelu : public TestGemmAdd<Tuple>
|
||||
class TestGemmAddFastgelu : public TestGemmD0Common<Tuple>
|
||||
{
|
||||
private:
|
||||
using ADataType = std::tuple_element_t<0, Tuple>;
|
||||
using BDataType = std::tuple_element_t<1, Tuple>;
|
||||
using AccDataType = std::tuple_element_t<2, Tuple>;
|
||||
using D0DataType = std::tuple_element_t<3, Tuple>;
|
||||
using EDataType = std::tuple_element_t<4, Tuple>;
|
||||
using ALayout = std::tuple_element_t<5, Tuple>;
|
||||
using BLayout = std::tuple_element_t<6, Tuple>;
|
||||
using D0Layout = std::tuple_element_t<7, Tuple>;
|
||||
using ELayout = std::tuple_element_t<8, Tuple>;
|
||||
using ProfileCall = typename TestGemmD0Common<Tuple>::ProfileCall;
|
||||
|
||||
constexpr static auto ProfileGemmAddFastgeluImpl =
|
||||
ck::profiler::profile_gemm_add_fastgelu_impl<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
D0DataType,
|
||||
EDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
D0Layout,
|
||||
ELayout>;
|
||||
|
||||
decltype(ProfileGemmAddFastgeluImpl) GetImpl() override { return ProfileGemmAddFastgeluImpl; }
|
||||
ProfileCall GetImpl() override
|
||||
{
|
||||
return ck::profiler::profile_gemm_add_fastgelu_impl<
|
||||
typename TestGemmD0Common<Tuple>::ADataType,
|
||||
typename TestGemmD0Common<Tuple>::BDataType,
|
||||
typename TestGemmD0Common<Tuple>::AccDataType,
|
||||
typename TestGemmD0Common<Tuple>::D0DataType,
|
||||
typename TestGemmD0Common<Tuple>::EDataType,
|
||||
typename TestGemmD0Common<Tuple>::ALayout,
|
||||
typename TestGemmD0Common<Tuple>::BLayout,
|
||||
typename TestGemmD0Common<Tuple>::D0Layout,
|
||||
typename TestGemmD0Common<Tuple>::ELayout>;
|
||||
}
|
||||
};
|
||||
|
||||
using KernelTypes = ::testing::Types<std::tuple<F16, I8, F32, F16, F16, Row, Row, Row, Row>,
|
||||
|
||||
39
test/gemm_add/test_gemm_add_multiply_wmma.cpp
Normal file
39
test/gemm_add/test_gemm_add_multiply_wmma.cpp
Normal file
@@ -0,0 +1,39 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ck/ck.hpp"
|
||||
#include "test_gemm_common.hpp"
|
||||
#include "profiler/profile_gemm_add_multiply_impl.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmAddMultiply : public TestGemmD0D1Common<Tuple>
|
||||
{
|
||||
using ProfileCall = typename TestGemmD0D1Common<Tuple>::ProfileCall;
|
||||
|
||||
ProfileCall GetImpl() override
|
||||
{
|
||||
return ck::profiler::profile_gemm_add_multiply_impl<
|
||||
typename TestGemmD0D1Common<Tuple>::ADataType,
|
||||
typename TestGemmD0D1Common<Tuple>::BDataType,
|
||||
typename TestGemmD0D1Common<Tuple>::AccDataType,
|
||||
typename TestGemmD0D1Common<Tuple>::D0DataType,
|
||||
typename TestGemmD0D1Common<Tuple>::D1DataType,
|
||||
typename TestGemmD0D1Common<Tuple>::EDataType,
|
||||
typename TestGemmD0D1Common<Tuple>::ALayout,
|
||||
typename TestGemmD0D1Common<Tuple>::BLayout,
|
||||
typename TestGemmD0D1Common<Tuple>::D0Layout,
|
||||
typename TestGemmD0D1Common<Tuple>::D1Layout,
|
||||
typename TestGemmD0D1Common<Tuple>::ELayout>;
|
||||
}
|
||||
};
|
||||
|
||||
// Type list for the typed suite: four std::tuple entries, each holding six data-type tags
// followed by five layout tags. The entries are identical except for the first two layout
// tags, which cover every Row/Col combination.
// NOTE(review): presumably the tag order is A, B, Acc, D0, D1, E data types and then
// A, B, D0, D1, E layouts — confirm against TestGemmD0D1Common in test_gemm_common.hpp.
using KernelTypes =
|
||||
::testing::Types<std::tuple<F16, F16, F32, F16, F16, F16, Row, Col, Row, Row, Row>,
|
||||
std::tuple<F16, F16, F32, F16, F16, F16, Row, Row, Row, Row, Row>,
|
||||
std::tuple<F16, F16, F32, F16, F16, F16, Col, Col, Row, Row, Row>,
|
||||
std::tuple<F16, F16, F32, F16, F16, F16, Col, Row, Row, Row, Row>>;
|
||||
|
||||
// Instantiate the fixture once per tuple in KernelTypes.
TYPED_TEST_SUITE(TestGemmAddMultiply, KernelTypes);
|
||||
// Because of the F16 shuffle data type, these tests have to run with a limited K size.
// Open question: change the instances to FP32?
// NOTE(review): each brace-enclosed triple passed to Run() is presumably one {M, N, K}
// problem size — confirm against Run() in test_gemm_common.hpp.
|
||||
TYPED_TEST(TestGemmAddMultiply, Test) { this->Run({{16, 32, 64}, {2048, 1024, 256}}); }
|
||||
33
test/gemm_add/test_gemm_add_relu_wmma.cpp
Normal file
33
test/gemm_add/test_gemm_add_relu_wmma.cpp
Normal file
@@ -0,0 +1,33 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ck/ck.hpp"
|
||||
#include "profiler/profile_gemm_add_relu_impl.hpp"
|
||||
#include "test_gemm_common.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmAddRelu : public TestGemmD0Common<Tuple>
|
||||
{
|
||||
using ProfileCall = typename TestGemmD0Common<Tuple>::ProfileCall;
|
||||
|
||||
ProfileCall GetImpl() override
|
||||
{
|
||||
return ck::profiler::profile_gemm_add_relu_impl<
|
||||
typename TestGemmD0Common<Tuple>::ADataType,
|
||||
typename TestGemmD0Common<Tuple>::BDataType,
|
||||
typename TestGemmD0Common<Tuple>::AccDataType,
|
||||
typename TestGemmD0Common<Tuple>::D0DataType,
|
||||
typename TestGemmD0Common<Tuple>::EDataType,
|
||||
typename TestGemmD0Common<Tuple>::ALayout,
|
||||
typename TestGemmD0Common<Tuple>::BLayout,
|
||||
typename TestGemmD0Common<Tuple>::D0Layout,
|
||||
typename TestGemmD0Common<Tuple>::ELayout>;
|
||||
}
|
||||
};
|
||||
|
||||
// Type list for the typed suite: one all-F16 and one all-BF16 tuple, both with an F32 tag
// in the third slot and Row for all four layout tags.
// NOTE(review): presumably the tag order is A, B, Acc, D0, E data types and then
// A, B, D0, E layouts — confirm against TestGemmD0Common in test_gemm_common.hpp.
using KernelTypes = ::testing::Types<std::tuple<F16, F16, F32, F16, F16, Row, Row, Row, Row>,
|
||||
std::tuple<BF16, BF16, F32, BF16, BF16, Row, Row, Row, Row>>;
|
||||
|
||||
// Instantiate the fixture once per tuple in KernelTypes.
TYPED_TEST_SUITE(TestGemmAddRelu, KernelTypes);
|
||||
// Single test case: calls the fixture's Run() with no arguments.
TYPED_TEST(TestGemmAddRelu, Test_BF16FP16) { this->Run(); }
|
||||
@@ -1,37 +1,29 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ck/ck.hpp"
|
||||
#include "profiler/profile_gemm_add_relu_impl.hpp"
|
||||
#include "test_gemm_add_xdl.hpp"
|
||||
#include "test_gemm_common.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmAddRelu : public TestGemmAdd<Tuple>
|
||||
class TestGemmAddRelu : public TestGemmD0Common<Tuple>
|
||||
{
|
||||
private:
|
||||
using ADataType = std::tuple_element_t<0, Tuple>;
|
||||
using BDataType = std::tuple_element_t<1, Tuple>;
|
||||
using AccDataType = std::tuple_element_t<2, Tuple>;
|
||||
using D0DataType = std::tuple_element_t<3, Tuple>;
|
||||
using EDataType = std::tuple_element_t<4, Tuple>;
|
||||
using ALayout = std::tuple_element_t<5, Tuple>;
|
||||
using BLayout = std::tuple_element_t<6, Tuple>;
|
||||
using D0Layout = std::tuple_element_t<7, Tuple>;
|
||||
using ELayout = std::tuple_element_t<8, Tuple>;
|
||||
using ProfileCall = typename TestGemmD0Common<Tuple>::ProfileCall;
|
||||
|
||||
constexpr static auto ProfileGemmAddReluImpl =
|
||||
ck::profiler::profile_gemm_add_relu_impl<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
D0DataType,
|
||||
EDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
D0Layout,
|
||||
ELayout>;
|
||||
|
||||
decltype(ProfileGemmAddReluImpl) GetImpl() override { return ProfileGemmAddReluImpl; }
|
||||
ProfileCall GetImpl() override
|
||||
{
|
||||
return ck::profiler::profile_gemm_add_relu_impl<
|
||||
typename TestGemmD0Common<Tuple>::ADataType,
|
||||
typename TestGemmD0Common<Tuple>::BDataType,
|
||||
typename TestGemmD0Common<Tuple>::AccDataType,
|
||||
typename TestGemmD0Common<Tuple>::D0DataType,
|
||||
typename TestGemmD0Common<Tuple>::EDataType,
|
||||
typename TestGemmD0Common<Tuple>::ALayout,
|
||||
typename TestGemmD0Common<Tuple>::BLayout,
|
||||
typename TestGemmD0Common<Tuple>::D0Layout,
|
||||
typename TestGemmD0Common<Tuple>::ELayout>;
|
||||
}
|
||||
};
|
||||
|
||||
using KernelTypes = ::testing::Types<std::tuple<F16, I8, F32, F16, F16, Row, Row, Row, Row>,
|
||||
|
||||
34
test/gemm_add/test_gemm_add_silu_wmma.cpp
Normal file
34
test/gemm_add/test_gemm_add_silu_wmma.cpp
Normal file
@@ -0,0 +1,34 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ck/ck.hpp"
|
||||
#include "profiler/profile_gemm_add_silu_impl.hpp"
|
||||
#include "test_gemm_common.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmAddSilu : public TestGemmD0Common<Tuple>
|
||||
{
|
||||
using ProfileCall = typename TestGemmD0Common<Tuple>::ProfileCall;
|
||||
|
||||
ProfileCall GetImpl() override
|
||||
{
|
||||
return ck::profiler::profile_gemm_add_silu_impl<
|
||||
typename TestGemmD0Common<Tuple>::ADataType,
|
||||
typename TestGemmD0Common<Tuple>::BDataType,
|
||||
typename TestGemmD0Common<Tuple>::AccDataType,
|
||||
typename TestGemmD0Common<Tuple>::D0DataType,
|
||||
typename TestGemmD0Common<Tuple>::EDataType,
|
||||
typename TestGemmD0Common<Tuple>::ALayout,
|
||||
typename TestGemmD0Common<Tuple>::BLayout,
|
||||
typename TestGemmD0Common<Tuple>::D0Layout,
|
||||
typename TestGemmD0Common<Tuple>::ELayout>;
|
||||
}
|
||||
};
|
||||
|
||||
// Type list for the typed suite: one all-F16 and one all-BF16 tuple, both with an F32 tag
// in the third slot and Row for all four layout tags.
// NOTE(review): presumably the tag order is A, B, Acc, D0, E data types and then
// A, B, D0, E layouts — confirm against TestGemmD0Common in test_gemm_common.hpp.
using KernelTypes = ::testing::Types<std::tuple<F16, F16, F32, F16, F16, Row, Row, Row, Row>,
|
||||
std::tuple<BF16, BF16, F32, BF16, BF16, Row, Row, Row, Row>>;
|
||||
|
||||
// Instantiate the fixture once per tuple in KernelTypes.
TYPED_TEST_SUITE(TestGemmAddSilu, KernelTypes);
|
||||
// Single test case: calls the fixture's Run() with no arguments.
TYPED_TEST(TestGemmAddSilu, Test_BF16FP16_BF16FP16) { this->Run(); }
|
||||
@@ -1,37 +1,29 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ck/ck.hpp"
|
||||
#include "profiler/profile_gemm_add_silu_impl.hpp"
|
||||
#include "test_gemm_add_xdl.hpp"
|
||||
#include "test_gemm_common.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmAddSilu : public TestGemmAdd<Tuple>
|
||||
class TestGemmAddSilu : public TestGemmD0Common<Tuple>
|
||||
{
|
||||
private:
|
||||
using ADataType = std::tuple_element_t<0, Tuple>;
|
||||
using BDataType = std::tuple_element_t<1, Tuple>;
|
||||
using AccDataType = std::tuple_element_t<2, Tuple>;
|
||||
using D0DataType = std::tuple_element_t<3, Tuple>;
|
||||
using EDataType = std::tuple_element_t<4, Tuple>;
|
||||
using ALayout = std::tuple_element_t<5, Tuple>;
|
||||
using BLayout = std::tuple_element_t<6, Tuple>;
|
||||
using D0Layout = std::tuple_element_t<7, Tuple>;
|
||||
using ELayout = std::tuple_element_t<8, Tuple>;
|
||||
using ProfileCall = typename TestGemmD0Common<Tuple>::ProfileCall;
|
||||
|
||||
constexpr static auto ProfileGemmAddSiluImpl =
|
||||
ck::profiler::profile_gemm_add_silu_impl<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
D0DataType,
|
||||
EDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
D0Layout,
|
||||
ELayout>;
|
||||
|
||||
decltype(ProfileGemmAddSiluImpl) GetImpl() override { return ProfileGemmAddSiluImpl; }
|
||||
ProfileCall GetImpl() override
|
||||
{
|
||||
return ck::profiler::profile_gemm_add_silu_impl<
|
||||
typename TestGemmD0Common<Tuple>::ADataType,
|
||||
typename TestGemmD0Common<Tuple>::BDataType,
|
||||
typename TestGemmD0Common<Tuple>::AccDataType,
|
||||
typename TestGemmD0Common<Tuple>::D0DataType,
|
||||
typename TestGemmD0Common<Tuple>::EDataType,
|
||||
typename TestGemmD0Common<Tuple>::ALayout,
|
||||
typename TestGemmD0Common<Tuple>::BLayout,
|
||||
typename TestGemmD0Common<Tuple>::D0Layout,
|
||||
typename TestGemmD0Common<Tuple>::ELayout>;
|
||||
}
|
||||
};
|
||||
|
||||
using KernelTypes = ::testing::Types<std::tuple<F16, I8, F32, F16, F16, Row, Row, Row, Row>,
|
||||
|
||||
32
test/gemm_add/test_gemm_add_wmma.cpp
Normal file
32
test/gemm_add/test_gemm_add_wmma.cpp
Normal file
@@ -0,0 +1,32 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ck/ck.hpp"
|
||||
#include "profiler/profile_gemm_add_impl.hpp"
|
||||
#include "test_gemm_common.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmAdd : public TestGemmD0Common<Tuple>
|
||||
{
|
||||
using ProfileCall = typename TestGemmD0Common<Tuple>::ProfileCall;
|
||||
|
||||
ProfileCall GetImpl() override
|
||||
{
|
||||
return ck::profiler::profile_gemm_add_impl<typename TestGemmD0Common<Tuple>::ADataType,
|
||||
typename TestGemmD0Common<Tuple>::BDataType,
|
||||
typename TestGemmD0Common<Tuple>::AccDataType,
|
||||
typename TestGemmD0Common<Tuple>::D0DataType,
|
||||
typename TestGemmD0Common<Tuple>::EDataType,
|
||||
typename TestGemmD0Common<Tuple>::ALayout,
|
||||
typename TestGemmD0Common<Tuple>::BLayout,
|
||||
typename TestGemmD0Common<Tuple>::D0Layout,
|
||||
typename TestGemmD0Common<Tuple>::ELayout>;
|
||||
}
|
||||
};
|
||||
|
||||
using KernelTypes = ::testing::Types<std::tuple<F16, F16, F32, F16, F16, Row, Row, Row, Row>,
|
||||
std::tuple<BF16, BF16, F32, BF16, BF16, Row, Row, Row, Row>>;
|
||||
|
||||
TYPED_TEST_SUITE(TestGemmAdd, KernelTypes);
|
||||
TYPED_TEST(TestGemmAdd, Test_BF16FP16) { this->Run(); }
|
||||
32
test/gemm_add/test_gemm_add_xdl.cpp
Normal file
32
test/gemm_add/test_gemm_add_xdl.cpp
Normal file
@@ -0,0 +1,32 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ck/ck.hpp"
|
||||
#include "profiler/profile_gemm_add_impl.hpp"
|
||||
#include "test_gemm_common.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmAdd : public TestGemmD0Common<Tuple>
|
||||
{
|
||||
using ProfileCall = typename TestGemmD0Common<Tuple>::ProfileCall;
|
||||
|
||||
ProfileCall GetImpl() override
|
||||
{
|
||||
return ck::profiler::profile_gemm_add_impl<typename TestGemmD0Common<Tuple>::ADataType,
|
||||
typename TestGemmD0Common<Tuple>::BDataType,
|
||||
typename TestGemmD0Common<Tuple>::AccDataType,
|
||||
typename TestGemmD0Common<Tuple>::D0DataType,
|
||||
typename TestGemmD0Common<Tuple>::EDataType,
|
||||
typename TestGemmD0Common<Tuple>::ALayout,
|
||||
typename TestGemmD0Common<Tuple>::BLayout,
|
||||
typename TestGemmD0Common<Tuple>::D0Layout,
|
||||
typename TestGemmD0Common<Tuple>::ELayout>;
|
||||
}
|
||||
};
|
||||
|
||||
using KernelTypes = ::testing::Types<std::tuple<F16, I8, F32, F16, F16, Row, Row, Row, Row>,
|
||||
std::tuple<BF16, I8, F32, BF16, BF16, Row, Row, Row, Row>>;
|
||||
|
||||
TYPED_TEST_SUITE(TestGemmAdd, KernelTypes);
|
||||
TYPED_TEST(TestGemmAdd, Test_BF16FP16_INT8) { this->Run(); }
|
||||
@@ -1,72 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ck/ck.hpp"
|
||||
#include "profiler/profile_gemm_add_impl.hpp"
|
||||
|
||||
// Layout tags referenced by the test tuples.
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

// Element type shorthands referenced by the test tuples.
using F32  = float;
using F16  = ck::half_t;
using BF16 = ck::bhalf_t;
using I8   = int8_t;
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmAdd : public ::testing::Test
|
||||
{
|
||||
protected:
|
||||
using ADataType = std::tuple_element_t<0, Tuple>;
|
||||
using BDataType = std::tuple_element_t<1, Tuple>;
|
||||
using AccDataType = std::tuple_element_t<2, Tuple>;
|
||||
using D0DataType = std::tuple_element_t<3, Tuple>;
|
||||
using EDataType = std::tuple_element_t<4, Tuple>;
|
||||
using ALayout = std::tuple_element_t<5, Tuple>;
|
||||
using BLayout = std::tuple_element_t<6, Tuple>;
|
||||
using D0Layout = std::tuple_element_t<7, Tuple>;
|
||||
using ELayout = std::tuple_element_t<8, Tuple>;
|
||||
|
||||
constexpr static auto ProfileGemmAddImpl = ck::profiler::profile_gemm_add_impl<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
D0DataType,
|
||||
EDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
D0Layout,
|
||||
ELayout>;
|
||||
|
||||
virtual decltype(ProfileGemmAddImpl) GetImpl() { return ProfileGemmAddImpl; }
|
||||
|
||||
void Run()
|
||||
{
|
||||
std::vector<std::vector<ck::index_t>> lengths = {
|
||||
{16, 32, 64}, {2048, 4096, 8192}, {2048, 1024, 16}};
|
||||
|
||||
bool all_success = true;
|
||||
|
||||
for(auto length : lengths)
|
||||
{
|
||||
int M = length[0];
|
||||
int N = length[1];
|
||||
int K = length[2];
|
||||
int StrideA = ck::is_same_v<ALayout, Row> ? K : M;
|
||||
int StrideB = ck::is_same_v<BLayout, Row> ? N : K;
|
||||
int StrideD0 = ck::is_same_v<D0Layout, Row> ? N : M;
|
||||
int StrideE = ck::is_same_v<ELayout, Row> ? N : M;
|
||||
|
||||
all_success =
|
||||
all_success &
|
||||
GetImpl()(true, 1, false, false, M, N, K, StrideA, StrideB, StrideD0, StrideE);
|
||||
}
|
||||
|
||||
EXPECT_TRUE(all_success);
|
||||
}
|
||||
};
|
||||
|
||||
using KernelTypes = ::testing::Types<std::tuple<F16, I8, F32, F16, F16, Row, Row, Row, Row>,
|
||||
std::tuple<BF16, I8, F32, BF16, BF16, Row, Row, Row, Row>>;
|
||||
|
||||
TYPED_TEST_SUITE(TestGemmAdd, KernelTypes);
|
||||
TYPED_TEST(TestGemmAdd, Test_BF16FP16_INT8) { this->Run(); }
|
||||
69
test/gemm_add/test_gemm_bilinear_wmma.cpp
Normal file
69
test/gemm_add/test_gemm_bilinear_wmma.cpp
Normal file
@@ -0,0 +1,69 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ck/ck.hpp"
|
||||
#include "profiler/profile_gemm_bilinear_impl.hpp"
|
||||
#include "test_gemm_common.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmBilinear : public ::testing::Test
|
||||
{
|
||||
private:
|
||||
using ADataType = std::tuple_element_t<0, Tuple>;
|
||||
using BDataType = std::tuple_element_t<1, Tuple>;
|
||||
using AccDataType = std::tuple_element_t<2, Tuple>;
|
||||
using D0DataType = std::tuple_element_t<3, Tuple>;
|
||||
using EDataType = std::tuple_element_t<4, Tuple>;
|
||||
using ALayout = std::tuple_element_t<5, Tuple>;
|
||||
using BLayout = std::tuple_element_t<6, Tuple>;
|
||||
using D0Layout = std::tuple_element_t<7, Tuple>;
|
||||
using ELayout = std::tuple_element_t<8, Tuple>;
|
||||
|
||||
constexpr static auto ProfileGemmBilinearImpl =
|
||||
ck::profiler::profile_gemm_bilinear_impl<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
D0DataType,
|
||||
EDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
D0Layout,
|
||||
ELayout>;
|
||||
|
||||
public:
|
||||
void Run(TestMatrixSizes const& lengths)
|
||||
{
|
||||
bool all_success = true;
|
||||
|
||||
for(auto length : lengths)
|
||||
{
|
||||
int M = length[0];
|
||||
int N = length[1];
|
||||
int K = length[2];
|
||||
int StrideA = ck::is_same_v<ALayout, Row> ? K : M;
|
||||
int StrideB = ck::is_same_v<BLayout, Row> ? N : K;
|
||||
int StrideD0 = ck::is_same_v<D0Layout, Row> ? N : M;
|
||||
int StrideE = ck::is_same_v<ELayout, Row> ? N : M;
|
||||
|
||||
all_success =
|
||||
all_success &
|
||||
ProfileGemmBilinearImpl(
|
||||
1, 1, false, true, M, N, K, StrideA, StrideB, StrideD0, StrideE, 1.F, 1.F);
|
||||
}
|
||||
|
||||
EXPECT_TRUE(all_success);
|
||||
}
|
||||
};
|
||||
|
||||
using KernelTypes = ::testing::Types<std::tuple<F16, F16, F32, F16, F16, Row, Row, Row, Row>,
|
||||
std::tuple<F16, F16, F32, F16, F16, Row, Col, Row, Row>,
|
||||
std::tuple<F16, F16, F32, F16, F16, Col, Row, Row, Row>,
|
||||
std::tuple<F16, F16, F32, F16, F16, Col, Col, Row, Row>,
|
||||
std::tuple<I8, I8, I32, I8, I8, Row, Row, Row, Row>,
|
||||
std::tuple<I8, I8, I32, I8, I8, Row, Col, Row, Row>,
|
||||
std::tuple<I8, I8, I32, I8, I8, Col, Row, Row, Row>,
|
||||
std::tuple<I8, I8, I32, I8, I8, Col, Col, Row, Row>>;
|
||||
|
||||
TYPED_TEST_SUITE(TestGemmBilinear, KernelTypes);
|
||||
TYPED_TEST(TestGemmBilinear, Test) { this->Run(DefaultTestMatrixSizes); }
|
||||
146
test/gemm_add/test_gemm_common.hpp
Normal file
146
test/gemm_add/test_gemm_common.hpp
Normal file
@@ -0,0 +1,146 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using I8 = int8_t;
|
||||
using I32 = int32_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
using F8 = ck::f8_t;
|
||||
|
||||
// M, N, K
|
||||
using TestMatrixSizes = std::vector<std::vector<ck::index_t>>;
|
||||
|
||||
static const TestMatrixSizes DefaultTestMatrixSizes = {
|
||||
{16, 32, 64}, {512, 2048, 4096}, {2048, 1024, 16}};
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmCommon : public ::testing::Test
|
||||
{
|
||||
protected:
|
||||
using ADataType = std::tuple_element_t<0, Tuple>;
|
||||
using BDataType = std::tuple_element_t<1, Tuple>;
|
||||
using AccDataType = std::tuple_element_t<2, Tuple>;
|
||||
using EDataType = std::tuple_element_t<3, Tuple>;
|
||||
using ALayout = std::tuple_element_t<4, Tuple>;
|
||||
using BLayout = std::tuple_element_t<5, Tuple>;
|
||||
using ELayout = std::tuple_element_t<6, Tuple>;
|
||||
|
||||
using ProfileCall = bool (*const)(int, int, bool, bool, int, int, int, int, int, int);
|
||||
|
||||
virtual ProfileCall GetImpl() = 0;
|
||||
|
||||
void Run(const TestMatrixSizes& lengths = DefaultTestMatrixSizes)
|
||||
{
|
||||
bool all_success = true;
|
||||
|
||||
for(auto length : lengths)
|
||||
{
|
||||
int M = length[0];
|
||||
int N = length[1];
|
||||
int K = length[2];
|
||||
int StrideA = ck::is_same_v<ALayout, Row> ? K : M;
|
||||
int StrideB = ck::is_same_v<BLayout, Row> ? N : K;
|
||||
int StrideE = ck::is_same_v<ELayout, Row> ? N : M;
|
||||
|
||||
all_success =
|
||||
all_success & GetImpl()(1, 1, false, true, M, N, K, StrideA, StrideB, StrideE);
|
||||
}
|
||||
|
||||
EXPECT_TRUE(all_success);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmD0Common : public ::testing::Test
|
||||
{
|
||||
protected:
|
||||
using ADataType = std::tuple_element_t<0, Tuple>;
|
||||
using BDataType = std::tuple_element_t<1, Tuple>;
|
||||
using AccDataType = std::tuple_element_t<2, Tuple>;
|
||||
using D0DataType = std::tuple_element_t<3, Tuple>;
|
||||
using EDataType = std::tuple_element_t<4, Tuple>;
|
||||
using ALayout = std::tuple_element_t<5, Tuple>;
|
||||
using BLayout = std::tuple_element_t<6, Tuple>;
|
||||
using D0Layout = std::tuple_element_t<7, Tuple>;
|
||||
using ELayout = std::tuple_element_t<8, Tuple>;
|
||||
|
||||
using ProfileCall = bool (*const)(int, int, bool, bool, int, int, int, int, int, int, int);
|
||||
|
||||
virtual ProfileCall GetImpl() = 0;
|
||||
|
||||
void Run(const TestMatrixSizes& lengths = DefaultTestMatrixSizes)
|
||||
{
|
||||
bool all_success = true;
|
||||
|
||||
for(auto length : lengths)
|
||||
{
|
||||
int M = length[0];
|
||||
int N = length[1];
|
||||
int K = length[2];
|
||||
int StrideA = ck::is_same_v<ALayout, Row> ? K : M;
|
||||
int StrideB = ck::is_same_v<BLayout, Row> ? N : K;
|
||||
int StrideD0 = ck::is_same_v<D0Layout, Row> ? N : M;
|
||||
int StrideE = ck::is_same_v<ELayout, Row> ? N : M;
|
||||
|
||||
all_success =
|
||||
all_success &
|
||||
GetImpl()(1, 1, false, true, M, N, K, StrideA, StrideB, StrideD0, StrideE);
|
||||
}
|
||||
|
||||
EXPECT_TRUE(all_success);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmD0D1Common : public ::testing::Test
|
||||
{
|
||||
protected:
|
||||
using ADataType = std::tuple_element_t<0, Tuple>;
|
||||
using BDataType = std::tuple_element_t<1, Tuple>;
|
||||
using AccDataType = std::tuple_element_t<2, Tuple>;
|
||||
using D0DataType = std::tuple_element_t<3, Tuple>;
|
||||
using D1DataType = std::tuple_element_t<4, Tuple>;
|
||||
using EDataType = std::tuple_element_t<5, Tuple>;
|
||||
using ALayout = std::tuple_element_t<6, Tuple>;
|
||||
using BLayout = std::tuple_element_t<7, Tuple>;
|
||||
using D0Layout = std::tuple_element_t<8, Tuple>;
|
||||
using D1Layout = std::tuple_element_t<9, Tuple>;
|
||||
using ELayout = std::tuple_element_t<10, Tuple>;
|
||||
|
||||
using ProfileCall = bool (*const)(int, int, bool, bool, int, int, int, int, int, int, int, int);
|
||||
|
||||
virtual ProfileCall GetImpl() = 0;
|
||||
|
||||
void Run(const TestMatrixSizes& lengths = DefaultTestMatrixSizes)
|
||||
{
|
||||
bool all_success = true;
|
||||
|
||||
for(auto length : lengths)
|
||||
{
|
||||
int M = length[0];
|
||||
int N = length[1];
|
||||
int K = length[2];
|
||||
int StrideA = ck::is_same_v<ALayout, Row> ? K : M;
|
||||
int StrideB = ck::is_same_v<BLayout, Row> ? N : K;
|
||||
int StrideD0 = ck::is_same_v<D0Layout, Row> ? N : M;
|
||||
int StrideD1 = ck::is_same_v<D1Layout, Row> ? N : M;
|
||||
int StrideE = ck::is_same_v<ELayout, Row> ? N : M;
|
||||
|
||||
all_success =
|
||||
all_success &
|
||||
GetImpl()(
|
||||
1, 1, false, true, M, N, K, StrideA, StrideB, StrideD0, StrideD1, StrideE);
|
||||
}
|
||||
|
||||
EXPECT_TRUE(all_success);
|
||||
}
|
||||
};
|
||||
32
test/gemm_add/test_gemm_fastgelu_wmma.cpp
Normal file
32
test/gemm_add/test_gemm_fastgelu_wmma.cpp
Normal file
@@ -0,0 +1,32 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ck/ck.hpp"
|
||||
#include "profiler/profile_gemm_fastgelu_impl.hpp"
|
||||
#include "test_gemm_common.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmFastgelu : public TestGemmCommon<Tuple>
|
||||
{
|
||||
using ProfileCall = typename TestGemmCommon<Tuple>::ProfileCall;
|
||||
|
||||
ProfileCall GetImpl() override
|
||||
{
|
||||
return ck::profiler::profile_gemm_fastgelu_impl<typename TestGemmCommon<Tuple>::ADataType,
|
||||
typename TestGemmCommon<Tuple>::BDataType,
|
||||
typename TestGemmCommon<Tuple>::AccDataType,
|
||||
typename TestGemmCommon<Tuple>::EDataType,
|
||||
typename TestGemmCommon<Tuple>::ALayout,
|
||||
typename TestGemmCommon<Tuple>::BLayout,
|
||||
typename TestGemmCommon<Tuple>::ELayout>;
|
||||
}
|
||||
};
|
||||
|
||||
using KernelTypes = ::testing::Types<std::tuple<F16, F16, F32, F16, Row, Row, Row>,
|
||||
std::tuple<F16, F16, F32, F16, Row, Col, Row>,
|
||||
std::tuple<F16, F16, F32, F16, Col, Row, Row>,
|
||||
std::tuple<F16, F16, F32, F16, Col, Col, Row>>;
|
||||
|
||||
TYPED_TEST_SUITE(TestGemmFastgelu, KernelTypes);
|
||||
TYPED_TEST(TestGemmFastgelu, Test_BF16FP16) { this->Run(); }
|
||||
40
test/gemm_add/test_gemm_multiply_add_wmma.cpp
Normal file
40
test/gemm_add/test_gemm_multiply_add_wmma.cpp
Normal file
@@ -0,0 +1,40 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "test_gemm_common.hpp"
|
||||
#include "profiler/profile_gemm_multiply_add_impl.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmMultiplyAdd : public TestGemmD0D1Common<Tuple>
|
||||
{
|
||||
using ProfileCall = typename TestGemmD0D1Common<Tuple>::ProfileCall;
|
||||
|
||||
ProfileCall GetImpl() override
|
||||
{
|
||||
return ck::profiler::profile_gemm_multiply_add_impl<
|
||||
typename TestGemmD0D1Common<Tuple>::ADataType,
|
||||
typename TestGemmD0D1Common<Tuple>::BDataType,
|
||||
typename TestGemmD0D1Common<Tuple>::AccDataType,
|
||||
typename TestGemmD0D1Common<Tuple>::D0DataType,
|
||||
typename TestGemmD0D1Common<Tuple>::D1DataType,
|
||||
typename TestGemmD0D1Common<Tuple>::EDataType,
|
||||
typename TestGemmD0D1Common<Tuple>::ALayout,
|
||||
typename TestGemmD0D1Common<Tuple>::BLayout,
|
||||
typename TestGemmD0D1Common<Tuple>::D0Layout,
|
||||
typename TestGemmD0D1Common<Tuple>::D1Layout,
|
||||
typename TestGemmD0D1Common<Tuple>::ELayout>;
|
||||
}
|
||||
};
|
||||
|
||||
using KernelTypes = ::testing::Types<
|
||||
#ifdef CK_USE_WMMA_FP8
|
||||
std::tuple<F16, F8, F32, F32, F32, F16, Row, Col, Row, Row, Row>,
|
||||
std::tuple<F16, F8, F32, F32, F32, F16, Row, Row, Row, Row, Row>,
|
||||
#endif
|
||||
std::tuple<F16, F16, F32, F16, F16, F16, Row, Col, Row, Row, Row>,
|
||||
std::tuple<F16, F16, F32, F16, F16, F16, Row, Row, Row, Row, Row>>;
|
||||
|
||||
TYPED_TEST_SUITE(TestGemmMultiplyAdd, KernelTypes);
|
||||
// Due to F16 shuffle data type tests has to run with limited K size. Change instances to FP32?
|
||||
TYPED_TEST(TestGemmMultiplyAdd, Test) { this->Run({{16, 32, 64}, {2048, 1024, 256}}); }
|
||||
99
test/gemm_add/test_gemm_multiply_multiply_wmma.cpp
Normal file
99
test/gemm_add/test_gemm_multiply_multiply_wmma.cpp
Normal file
@@ -0,0 +1,99 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ck/ck.hpp"
|
||||
#include "profiler/profile_gemm_multiply_multiply_impl.hpp"
|
||||
|
||||
// Layout tags referenced by the test tuples.
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

// Element type shorthands referenced by the test tuples.
using F32  = float;
using F16  = ck::half_t;
using BF16 = ck::bhalf_t;
using F8   = ck::f8_t;
using I32  = int32_t;
using I8   = int8_t;
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmMultiplyMultiply : public ::testing::Test
|
||||
{
|
||||
private:
|
||||
using ADataType = std::tuple_element_t<0, Tuple>;
|
||||
using BDataType = std::tuple_element_t<1, Tuple>;
|
||||
using AccDataType = std::tuple_element_t<2, Tuple>;
|
||||
using D0DataType = std::tuple_element_t<3, Tuple>;
|
||||
using D1DataType = std::tuple_element_t<4, Tuple>;
|
||||
using EDataType = std::tuple_element_t<5, Tuple>;
|
||||
using ALayout = std::tuple_element_t<6, Tuple>;
|
||||
using BLayout = std::tuple_element_t<7, Tuple>;
|
||||
using D0Layout = std::tuple_element_t<8, Tuple>;
|
||||
using D1Layout = std::tuple_element_t<9, Tuple>;
|
||||
using ELayout = std::tuple_element_t<10, Tuple>;
|
||||
|
||||
constexpr static auto ProfileGemmMultiplyMultiplyImpl =
|
||||
ck::profiler::profile_gemm_multiply_multiply_impl<ADataType,
|
||||
BDataType,
|
||||
AccDataType, // ComputeDataType for
|
||||
// reference gemm
|
||||
AccDataType,
|
||||
D0DataType,
|
||||
D1DataType,
|
||||
EDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
D0Layout,
|
||||
D1Layout,
|
||||
ELayout>;
|
||||
|
||||
public:
|
||||
void Run()
|
||||
{
|
||||
std::vector<std::vector<ck::index_t>> lengths = {
|
||||
{16, 32, 64}, {512, 2048, 4096}, {2048, 1024, 16}};
|
||||
|
||||
bool all_success = true;
|
||||
|
||||
for(auto length : lengths)
|
||||
{
|
||||
int M = length[0];
|
||||
int N = length[1];
|
||||
int K = length[2];
|
||||
int StrideA = ck::is_same_v<ALayout, Row> ? K : M;
|
||||
int StrideB = ck::is_same_v<BLayout, Row> ? N : K;
|
||||
int StrideD0 = ck::is_same_v<D0Layout, Row> ? N : M;
|
||||
int StrideD1 = ck::is_same_v<D1Layout, Row> ? N : M;
|
||||
int StrideE = ck::is_same_v<ELayout, Row> ? N : M;
|
||||
|
||||
all_success = all_success & ProfileGemmMultiplyMultiplyImpl(1,
|
||||
1,
|
||||
false,
|
||||
true,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideD0,
|
||||
StrideD1,
|
||||
StrideE,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
0);
|
||||
}
|
||||
|
||||
EXPECT_TRUE(all_success);
|
||||
}
|
||||
};
|
||||
|
||||
using KernelTypes = ::testing::Types<
|
||||
#ifdef CK_USE_WMMA_FP8
|
||||
std::tuple<F8, F8, F32, F32, F32, F16, Row, Col, Row, Col, Row>,
|
||||
std::tuple<F8, F8, F32, F32, F32, BF16, Row, Col, Row, Col, Row>,
|
||||
#endif
|
||||
std::tuple<I8, I8, I32, F16, F16, F16, Row, Col, Row, Col, Row>,
|
||||
std::tuple<I8, I8, I32, F32, F32, BF16, Row, Col, Row, Col, Row>>;
|
||||
|
||||
TYPED_TEST_SUITE(TestGemmMultiplyMultiply, KernelTypes);
|
||||
TYPED_TEST(TestGemmMultiplyMultiply, Test) { this->Run(); }
|
||||
Reference in New Issue
Block a user