Files
composable_kernel/test/ck_tile/fmha/CMakeLists.txt
Illia Silin 17e4c8eac9 [rocm-libraries] ROCm/rocm-libraries#4883 (commit 56347bb)
[CK] Disable test_fmha_fwd_fp8fp16 on gfx90a by default (#4883)

## Motivation

Since gfx90a has no native support for the FP8 data type, all FP8 tests
should be disabled there by default.

## Technical Details

The test_fmha_fwd_fp8fp16 test is the last failing CK test on gfx90a with
the staging compiler.

## Test Plan

<!-- Explain any relevant testing done to verify this PR. -->

## Test Result

<!-- Briefly summarize test outcomes. -->

## Submission Checklist

- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
2026-02-26 02:09:06 +00:00

78 lines
3.0 KiB
CMake

# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# Umbrella test target name; it is also used as a label on every test below.
set(TEST_NAME "test_ck_tile_fmha")
# Library targets that provide the FMHA forward/backward kernel instances.
# The tests link against them and scan their SOURCES to decide which data
# types have instances to test.
set(FMHA_FWD_INSTANCES "tile_fmha_fwd_instances")
set(FMHA_BWD_INSTANCES "tile_fmha_bwd_instances")
# add_gtest_fwd(<test_group>)
#
# Registers one gtest executable per FMHA forward data type, named
# <test_group>_<type>. Each is built from test_fmha_fwd.cpp with
# DataTypeConfig set to the matching FmhaFwd* type and linked against
# ${FMHA_FWD_INSTANCES}. A type is skipped when the instances target has no
# source whose path contains "_<type>_". Finally a custom target named
# <test_group> is created that depends on every test executable added.
#
# Candidate data types are chosen from GPU_TARGETS:
#   - gfx94x/gfx95x/gfx12, or gfx90a with CK_USE_FP8_ON_UNSUPPORTED_ARCH:
#       fp16, bf16, fp8bf16, fp32
#   - gfx90a (default) or gfx11: fp16, bf16, fp32 (no FP8)
#   - anything else: no forward tests are registered
function(add_gtest_fwd test_group)
    # Start empty so a V_TYPES from an enclosing scope cannot leak in when
    # GPU_TARGETS matches neither branch below.
    set(V_TYPES)
    if((GPU_TARGETS MATCHES "gfx90a" AND CK_USE_FP8_ON_UNSUPPORTED_ARCH) OR GPU_TARGETS MATCHES "gfx9[45]|gfx12")
        set(V_TYPES "fp16" "bf16" "fp8bf16" "fp32")
    elseif(GPU_TARGETS MATCHES "gfx90a|gfx11")
        # gfx90a only reaches this branch when CK_USE_FP8_ON_UNSUPPORTED_ARCH
        # is off (otherwise the branch above is taken), so FP8 is excluded.
        set(V_TYPES "fp16" "bf16" "fp32")
    else()
        message(STATUS "No FMHA FWD data types enabled for GPU_TARGETS '${GPU_TARGETS}'")
    endif()

    # Per-type value passed to the test source as DataTypeConfig.
    set(CPP_TYPE_fp16 "FmhaFwdFp16")
    set(CPP_TYPE_bf16 "FmhaFwdBf16")
    set(CPP_TYPE_fp8bf16 "FmhaFwdFp8Bf16")
    set(CPP_TYPE_fp32 "FmhaFwdFp32")

    set(sources)
    if(TARGET ${FMHA_FWD_INSTANCES})
        get_target_property(sources ${FMHA_FWD_INSTANCES} SOURCES)
        message(VERBOSE "${FMHA_FWD_INSTANCES} SOURCES ${sources}")
    endif()

    set(all_tests)
    foreach(type ${V_TYPES})
        set(name "${test_group}_${type}")
        # Only build a test when some instance source is tagged "_<type>_".
        if(NOT sources MATCHES "_${type}_")
            message(STATUS "No FMHA FWD instances for ${type}, skip ${name}")
            continue()
        endif()
        add_gtest_executable(${name} test_fmha_fwd.cpp)
        # get_test_property yields "NOTFOUND" when the test has no labels yet;
        # drop it so it is not registered as a literal label.
        get_test_property(${name} LABELS COMMON_LABELS)
        if("${COMMON_LABELS}" STREQUAL "NOTFOUND")
            set(COMMON_LABELS)
        endif()
        list(APPEND COMMON_LABELS ${TEST_NAME} ${test_group})
        set_tests_properties(${name} PROPERTIES LABELS "${COMMON_LABELS}")
        target_compile_definitions(${name} PRIVATE DataTypeConfig=${CPP_TYPE_${type}})
        target_link_libraries(${name} PRIVATE ${FMHA_FWD_INSTANCES})
        list(APPEND all_tests ${name})
    endforeach()
    message(STATUS "FMHA FWD tests: ${all_tests}")

    # Group target; target-level edges go through add_dependencies, which
    # requires at least one dependency.
    add_custom_target(${test_group})
    if(all_tests)
        add_dependencies(${test_group} ${all_tests})
    endif()
endfunction()
# add_gtest_bwd(<test_group>)
#
# Registers one gtest executable per FMHA backward data type (fp16, bf16,
# fp32), named <test_group>_<type>. Each is built from test_fmha_bwd.cpp with
# DataTypeConfig set to the matching FmhaBwd* type and linked against
# ${FMHA_BWD_INSTANCES}. A type is skipped when the instances target has no
# source whose path contains "_<type>_". Finally a custom target named
# <test_group> is created that depends on every test executable added.
function(add_gtest_bwd test_group)
    set(V_TYPES "fp16" "bf16" "fp32")

    # Per-type value passed to the test source as DataTypeConfig.
    set(CPP_TYPE_fp16 "FmhaBwdFp16")
    set(CPP_TYPE_bf16 "FmhaBwdBf16")
    set(CPP_TYPE_fp32 "FmhaBwdFp32")

    set(sources)
    if(TARGET ${FMHA_BWD_INSTANCES})
        get_target_property(sources ${FMHA_BWD_INSTANCES} SOURCES)
        message(VERBOSE "${FMHA_BWD_INSTANCES} SOURCES ${sources}")
    endif()

    set(all_tests)
    foreach(type ${V_TYPES})
        set(name "${test_group}_${type}")
        # Only build a test when some instance source is tagged "_<type>_".
        if(NOT sources MATCHES "_${type}_")
            message(STATUS "No FMHA BWD instances for ${type}, skip ${name}")
            continue()
        endif()
        add_gtest_executable(${name} test_fmha_bwd.cpp)
        # get_test_property yields "NOTFOUND" when the test has no labels yet;
        # drop it so it is not registered as a literal label.
        get_test_property(${name} LABELS COMMON_LABELS)
        if("${COMMON_LABELS}" STREQUAL "NOTFOUND")
            set(COMMON_LABELS)
        endif()
        list(APPEND COMMON_LABELS ${TEST_NAME} ${test_group})
        set_tests_properties(${name} PROPERTIES LABELS "${COMMON_LABELS}")
        target_compile_definitions(${name} PRIVATE DataTypeConfig=${CPP_TYPE_${type}})
        target_link_libraries(${name} PRIVATE ${FMHA_BWD_INSTANCES})
        list(APPEND all_tests ${name})
    endforeach()
    message(STATUS "FMHA BWD tests: ${all_tests}")

    # Group target; target-level edges go through add_dependencies, which
    # requires at least one dependency.
    add_custom_target(${test_group})
    if(all_tests)
        add_dependencies(${test_group} ${all_tests})
    endif()
endfunction()
# Register the forward and backward FMHA test groups; each call always
# creates its own group target (<TEST_NAME>_fwd / <TEST_NAME>_bwd).
add_gtest_fwd(${TEST_NAME}_fwd)
add_gtest_bwd(${TEST_NAME}_bwd)
# Umbrella target that builds both groups. add_dependencies is the
# documented way to express target-level dependencies.
add_custom_target(${TEST_NAME})
add_dependencies(${TEST_NAME} ${TEST_NAME}_fwd ${TEST_NAME}_bwd)