# Mirror of https://github.com/ROCm/composable_kernel.git
# (synced 2026-04-19 14:29:05 +00:00)
#
# Implement device_grouped_gemm_fixed_nk_bias for RDNA4
# Summary:
#   - Modified implementation for grouped_gemm_fixed_nk_bias
#   - FP16 WMMA examples
#   - WMMA instances
#   - Profiler for grouped_gemm_fixed_nk_bias
#   - Add WMMA instances to existing tests
# NOTE: this change depends on https://github.com/ROCm/rocm-libraries/pull/4299
# and should be merged after it.
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
add_custom_target(example_grouped_gemm_xdl)
|
|
add_example_executable(example_grouped_gemm_xdl_fp32 grouped_gemm_xdl_fp32.cpp)
|
|
add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fp32)
|
|
|
|
add_example_executable(example_grouped_gemm_xdl_fp16 grouped_gemm_xdl_fp16.cpp)
|
|
add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fp16)
|
|
|
|
add_example_executable(example_grouped_gemm_multiple_d_dl_fp16 grouped_gemm_multiple_d_dl_fp16.cpp)
|
|
add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_multiple_d_dl_fp16)
|
|
|
|
add_example_executable(example_grouped_gemm_xdl_splitk_fp16 grouped_gemm_xdl_splitk_fp16.cpp)
|
|
add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_splitk_fp16)
|
|
|
|
add_example_executable(example_grouped_gemm_xdl_fixed_nk_fp16 grouped_gemm_xdl_fixed_nk_fp16.cpp)
|
|
add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_fp16)
|
|
|
|
add_example_executable(example_grouped_gemm_xdl_fixed_nk_bias_fp16 grouped_gemm_xdl_fixed_nk_bias_fp16.cpp)
|
|
add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_bias_fp16)
|
|
|
|
add_example_executable(example_grouped_gemm_xdl_bf16 grouped_gemm_xdl_bf16.cpp)
|
|
add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_bf16)
|
|
|
|
add_example_executable(example_grouped_gemm_xdl_int8 grouped_gemm_xdl_int8.cpp)
|
|
add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int8)
|
|
|
|
add_example_executable(example_grouped_gemm_xdl_fixed_nk_fp16_fp8 grouped_gemm_xdl_fixed_nk_fp16_fp8.cpp)
|
|
add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_fp16_fp8)
|
|
|
|
add_example_executable(example_grouped_gemm_multiple_d_xdl_fp16 grouped_gemm_multiple_d_xdl_fp16.cpp)
|
|
add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_multiple_d_xdl_fp16)
|
|
|
|
if(USE_BITINT_EXTENSION_INT4)
|
|
add_example_executable(example_grouped_gemm_xdl_int4 grouped_gemm_xdl_int4.cpp)
|
|
add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int4)
|
|
endif()
|
|
|
|
add_custom_target(example_grouped_gemm_wmma)
|
|
add_example_executable(example_grouped_gemm_wmma_splitk_fp16 grouped_gemm_wmma_splitk_fp16.cpp)
|
|
add_example_dependencies(example_grouped_gemm_wmma example_grouped_gemm_wmma_splitk_fp16)
|
|
|
|
add_example_executable(example_grouped_gemm_wmma_splitk_bf16 grouped_gemm_wmma_splitk_bf16.cpp)
|
|
add_example_dependencies(example_grouped_gemm_wmma example_grouped_gemm_wmma_splitk_bf16)
|
|
|
|
add_example_executable(example_grouped_gemm_multiple_d_wmma_fp16 grouped_gemm_multiple_d_wmma_fp16.cpp)
|
|
add_example_dependencies(example_grouped_gemm_wmma example_grouped_gemm_multiple_d_wmma_fp16)
|
|
|
|
add_example_executable(example_grouped_gemm_wmma_fixed_nk_fp16 grouped_gemm_wmma_fixed_nk_fp16.cpp)
|
|
add_example_dependencies(example_grouped_gemm_wmma example_grouped_gemm_wmma_fixed_nk_fp16)
|
|
|
|
add_example_executable(example_grouped_gemm_wmma_fixed_nk_bias_fp16 grouped_gemm_wmma_fixed_nk_bias_fp16.cpp)
|
|
add_example_dependencies(example_grouped_gemm_wmma example_grouped_gemm_wmma_fixed_nk_bias_fp16)
|
|
|
|
|
|
list(APPEND gpu_list_tf32 gfx942 gfx950)
|
|
set(target 0)
|
|
foreach(gpu IN LISTS GPU_TARGETS)
|
|
if(gpu IN_LIST gpu_list_tf32 AND target EQUAL 0)
|
|
add_example_executable(example_grouped_gemm_xdl_fp32_tf32 grouped_gemm_xdl_fp32_tf32.cpp)
|
|
add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fp32_tf32)
|
|
set(target 1)
|
|
endif()
|
|
endforeach()
|