mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 21:51:28 +00:00
Implement device_grouped_gemm_fixed_nk_bias for RDNA4 ## Proposed changes Summary: - Modified implementation for grouped_gemm_fixed_nk_bias - FP16 WMMA examples - WMMA instances - Profiler for grouped_gemm_fixed_nk_bias - Add WMMA instances to existing tests **This PR depends on PR https://github.com/ROCm/rocm-libraries/pull/4299 and should be merged after it. Only the last 6 commits are in the scope of this PR.** ## Checklist Please put an `x` into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask. - [x] I have added tests relevant to the introduced functionality, and the unit tests are passing locally - [x] I have added the test to REGRESSION_TESTS list defined at the top of CMakeLists.txt in tests/CMakeLists.txt, **IF** the test takes more than 30 seconds to run. - [x] I have added inline documentation which enables the maintainers with understanding the motivation - [x] I have removed the stale documentation which is no longer relevant after this pull request - [ ] (If this change is user-facing) I have added release notes which provide the end users with a brief summary of the improvement from this pull request - [x] I have run `clang-format` on all changed files - [ ] Any dependent changes have been merged ## Discussion If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
46 lines
2.3 KiB
CMake
46 lines
2.3 KiB
CMake
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
# SPDX-License-Identifier: MIT
|
|
|
|
# Umbrella target for the grouped GEMM tests. Each test executable that is
# successfully created below is attached to it via add_dependencies(), so
# building `test_grouped_gemm` builds all of them.
add_custom_target(test_grouped_gemm)
|
|
|
|
# NOTE: We test for XDL/WMMA support here instead of relying on the usual pattern matching in the parent CMakeLists. This is necessary
|
|
# as these tests are universal and don't have "xdl" or "wmma" in their names to signify their target arch, but they will fail to link
|
|
# the instance library if there are no instances for the current arch.
|
|
if(CK_USE_XDL OR CK_USE_WMMA)
    # Architecture-agnostic grouped GEMM tests, described as
    # "<test target>=<instance library>" pairs. For every entry:
    #   * the gtest executable <test target> is built from <test target>.cpp;
    #   * if add_gtest_executable() reports success through `result` (0), the
    #     test is linked against `utility` and its instance library, and is
    #     registered as a dependency of the `test_grouped_gemm` umbrella target.
    # foreach() does not open a new variable scope, so `result` set by
    # add_gtest_executable() is visible here exactly as in straight-line code.
    foreach(grouped_gemm_test_spec IN ITEMS
            "test_grouped_gemm_splitk=device_grouped_gemm_instance"
            "test_grouped_gemm_fastgelu=device_grouped_gemm_fastgelu_instance"
            "test_grouped_gemm_fixed_nk=device_grouped_gemm_fixed_nk_instance"
            "test_grouped_gemm_fixed_nk_bias=device_grouped_gemm_bias_instance"
            "test_grouped_gemm_multi_abd_fixed_nk=device_grouped_gemm_fixed_nk_multi_abd_instance")
        # Split "<name>=<lib>" into a two-element CMake list.
        string(REPLACE "=" ";" grouped_gemm_test_pair "${grouped_gemm_test_spec}")
        list(GET grouped_gemm_test_pair 0 grouped_gemm_test_name)
        list(GET grouped_gemm_test_pair 1 grouped_gemm_test_lib)

        add_gtest_executable(${grouped_gemm_test_name} ${grouped_gemm_test_name}.cpp)
        if(result EQUAL 0)
            target_link_libraries(${grouped_gemm_test_name} PRIVATE utility ${grouped_gemm_test_lib})
            add_dependencies(test_grouped_gemm ${grouped_gemm_test_name})
        endif()
    endforeach()

    # Keep the loop's scratch variables out of the directory scope.
    unset(grouped_gemm_test_spec)
    unset(grouped_gemm_test_pair)
    unset(grouped_gemm_test_name)
    unset(grouped_gemm_test_lib)
endif()
|
|
|
|
# Grouped GEMM interface test. It sits outside the CK_USE_XDL/CK_USE_WMMA guard;
# its source file name contains "xdl", so it presumably relies on the usual
# name-based arch filtering mentioned above (NOTE(review): confirm in parent).
# On success (`result` == 0) it links `utility` + the grouped GEMM instances and
# joins the `test_grouped_gemm` umbrella target.
set(grouped_gemm_interface_test test_grouped_gemm_interface)

add_gtest_executable(${grouped_gemm_interface_test} test_grouped_gemm_interface_xdl.cpp)
if(result EQUAL 0)
    target_link_libraries(${grouped_gemm_interface_test} PRIVATE utility device_grouped_gemm_instance)
    add_dependencies(test_grouped_gemm ${grouped_gemm_interface_test})
endif()

unset(grouped_gemm_interface_test)
|