diff --git a/test/ck_tile/fmha/CMakeLists.txt b/test/ck_tile/fmha/CMakeLists.txt index 8e5cce4c0b..ca7b7b6324 100644 --- a/test/ck_tile/fmha/CMakeLists.txt +++ b/test/ck_tile/fmha/CMakeLists.txt @@ -5,35 +5,50 @@ endif() set(FMHA_BWD_INSTANCES "tile_fmha_bwd_instances") set(FMHA_FWD_INSTANCES "tile_fmha_fwd_instances") +set(TEST_NAME "test_ck_tile_fmha") -add_gtest_executable(test_ck_tile_fmha_bwd_fp32 test_fmha_bwd_fp32.cpp) -target_link_libraries(test_ck_tile_fmha_bwd_fp32 PRIVATE ${FMHA_BWD_INSTANCES}) +function(add_gtest_fwd test_group) + set(V_TYPES "fp16" "bf16" "fp8" "fp32") + set(CPP_TYPE_fp16 "FmhaFwdFp16") + set(CPP_TYPE_bf16 "FmhaFwdBf16") + set(CPP_TYPE_fp8 "FmhaFwdFp8") + set(CPP_TYPE_fp32 "FmhaFwdFp32") -add_gtest_executable(test_ck_tile_fmha_bwd_bf16 test_fmha_bwd_bf16.cpp) -target_link_libraries(test_ck_tile_fmha_bwd_bf16 PRIVATE ${FMHA_BWD_INSTANCES}) + set(all_tests) + foreach(type ${V_TYPES}) + set(name "${test_group}_${type}") + add_gtest_executable(${name} test_fmha_fwd.cpp) + get_test_property(${name} LABELS COMMON_LABELS) + set_tests_properties(${name} PROPERTIES LABELS "${COMMON_LABELS};${TEST_NAME};${test_group}") + target_compile_definitions(${name} PRIVATE DataTypeConfig=${CPP_TYPE_${type}}) + target_link_libraries(${name} PRIVATE ${FMHA_FWD_INSTANCES}) + list(APPEND all_tests ${name}) + endforeach() + message(STATUS "FMHA FWD tests: ${all_tests}") + add_custom_target(${test_group} DEPENDS ${all_tests}) +endfunction() -add_gtest_executable(test_ck_tile_fmha_bwd_fp16 test_fmha_bwd_fp16.cpp) -target_link_libraries(test_ck_tile_fmha_bwd_fp16 PRIVATE ${FMHA_BWD_INSTANCES}) +function(add_gtest_bwd test_group) + set(V_TYPES "fp16" "bf16" "fp32") + set(CPP_TYPE_fp16 "FmhaBwdFp16") + set(CPP_TYPE_bf16 "FmhaBwdBf16") + set(CPP_TYPE_fp32 "FmhaBwdFp32") -add_gtest_executable(test_ck_tile_fmha_fwd_fp32 test_fmha_fwd_fp32.cpp) -target_link_libraries(test_ck_tile_fmha_fwd_fp32 PRIVATE ${FMHA_FWD_INSTANCES}) + set(all_tests) + foreach(type ${V_TYPES}) + set(name "${test_group}_${type}") + add_gtest_executable(${name} test_fmha_bwd.cpp) + get_test_property(${name} LABELS COMMON_LABELS) + set_tests_properties(${name} PROPERTIES LABELS "${COMMON_LABELS};${TEST_NAME};${test_group}") + target_compile_definitions(${name} PRIVATE DataTypeConfig=${CPP_TYPE_${type}}) + target_link_libraries(${name} PRIVATE ${FMHA_BWD_INSTANCES}) + list(APPEND all_tests ${name}) + endforeach() + message(STATUS "FMHA BWD tests: ${all_tests}") + add_custom_target(${test_group} DEPENDS ${all_tests}) +endfunction() -add_gtest_executable(test_ck_tile_fmha_fwd_bf16 test_fmha_fwd_bf16.cpp) -target_link_libraries(test_ck_tile_fmha_fwd_bf16 PRIVATE ${FMHA_FWD_INSTANCES}) -add_gtest_executable(test_ck_tile_fmha_fwd_fp16 test_fmha_fwd_fp16.cpp) -target_link_libraries(test_ck_tile_fmha_fwd_fp16 PRIVATE ${FMHA_FWD_INSTANCES}) - -add_gtest_executable(test_ck_tile_fmha_fwd_fp8 test_fmha_fwd_fp8.cpp) -target_link_libraries(test_ck_tile_fmha_fwd_fp8 PRIVATE ${FMHA_FWD_INSTANCES}) - -add_custom_target(test_ck_tile_fmha - DEPENDS - test_ck_tile_fmha_bwd_fp32 - test_ck_tile_fmha_bwd_bf16 - test_ck_tile_fmha_bwd_fp16 - test_ck_tile_fmha_fwd_fp32 - test_ck_tile_fmha_fwd_bf16 - test_ck_tile_fmha_fwd_fp16 - test_ck_tile_fmha_fwd_fp8 -) +add_gtest_fwd(${TEST_NAME}_fwd) +add_gtest_bwd(${TEST_NAME}_bwd) +add_custom_target(${TEST_NAME} DEPENDS ${TEST_NAME}_fwd ${TEST_NAME}_bwd) diff --git a/test/ck_tile/fmha/test_fmha_bwd.cpp b/test/ck_tile/fmha/test_fmha_bwd.cpp new file mode 100644 index 0000000000..190cdd6452 --- /dev/null +++ b/test/ck_tile/fmha/test_fmha_bwd.cpp @@ -0,0 +1,248 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "example/ck_tile/01_fmha/fmha_bwd.hpp" +#include "example/ck_tile/01_fmha/fmha_bwd_runner.hpp" + +#include "gtest/gtest.h" + +#ifndef DataTypeConfig +#define DataTypeConfig FmhaBwdFp16 // or FmhaBwdBf16 / FmhaBwdFp32 +#endif + +using ::testing::Bool; +using ::testing::Combine; +using ::testing::TestWithParam; +using ::testing::Values; +using ::testing::ValuesIn; + +template +struct TestConfigs +{ + static constexpr auto HDimValues = std::array{ + std::tuple{32, -1}, std::tuple{64, -1}, std::tuple{128, -1}, std::tuple{256, -1}}; +}; +template <> +struct TestConfigs +{ + static constexpr auto HDimValues = + std::array{std::tuple{32, -1}, std::tuple{64, -1}, std::tuple{128, -1}}; +}; +static auto HDimValues = ValuesIn(TestConfigs::HDimValues); +const auto ModeValues = ValuesIn(std::vector{mode_enum::batch, mode_enum::group}); +constexpr auto init_method = "uf"; + +// Random seed used for initializing input tensors. 0 for non-deterministic seed +CK_TILE_DECLARE_ENV_VAR(CK_TILE_TEST_SEED, uint64_t, 123456) + +// Whether to run long tests (from smoke_test_fwd.sh) +CK_TILE_DECLARE_ENV_VAR_BOOL(CK_TILE_FMHA_LONG_TESTS) + +const ck_tile::stream_config stream_config{ + nullptr, // stream_id_ + false, // time_kernel_ + 1, // log_level_ + 0, // cold_niters_ + 1, // nrepeat_ + true, // is_gpu_timer_ + false, // flush_cache_ + 1, // rotating_count_ +}; + +// batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str +using FmhaBwdDimsMaskParam = std::tuple; +using FmhaBwdTestParam = std::tuple< // + mode_enum, // mode + std::tuple, // hdim_q, hdim_v + std::tuple, // io_perm + std::string, // bias_str + bool, // use_dbias + float, // p_drop + std::tuple, // drop_seed, drop_offset, drop_prefs + FmhaBwdDimsMaskParam, + bool // deterministic + >; +void fmha_bwd_test(const FmhaBwdTestParam& param) +{ + auto [mode, hdims, perm, bias_str, use_dbias, p_drop, drop_misc, dims_mask, det] = param; + auto [hdim_q, hdim_v] = hdims; + auto [i_perm, o_perm] = perm; + auto [drop_seed, drop_offset, drop_prefs] = drop_misc; + auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask; + + auto result = fmha_bwd_run( + mode, + batch, + nhead, + nhead_k, + {seqlen_q}, + {seqlen_k}, + hdim_q, + hdim_v, + i_perm, + o_perm, + 0, // scale + bias_str, + use_dbias, + p_drop, + drop_seed, + drop_offset, + drop_prefs, + mask_str, + det, // deterministic + init_method, + static_cast(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))), + 1, + stream_config); + + if(result == bwd_result::no_instance) + GTEST_SKIP() << "No instance for current parameters"; + ASSERT_EQ(result, bwd_result::success); +} + +// Test cases from example/ck_tile/01_fmha/script/smoke_test_bwd.sh +class AllLong : public TestWithParam +{ +}; +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(AllLong); +INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaBwd, + AllLong, + Combine(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_FMHA_LONG_TESTS)) + ? ModeValues + : ValuesIn(std::vector{}), + HDimValues, + Values(std::tuple{true, true}, std::tuple{false, false}), // perm + Values("n", "a"), + Values(false), // use_dbias + Values(0.0f, 0.2f), // p_drop + Values(std::tuple{123, 1024, true}), // seed/offset/prefs + Values(std::tuple{1, 4, 2, 259, -1, "0"}, + std::tuple{2, 2, -1, 516, 253, "0"}, + std::tuple{1, 4, 1, 500, 251, "1"}, + std::tuple{1, 2, -1, 900, 258, "2"}, + std::tuple{2, 1, -1, 987, 219, "t:128,30"}, + std::tuple{2, 3, 1, 244, 499, "b:4,35"}), + Values(false) // deterministic + )); +TEST_P(AllLong, DataTypeConfig) { fmha_bwd_test(GetParam()); } + +class HDimPadding : public TestWithParam +{ +}; +INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaBwd, + HDimPadding, + Combine(ModeValues, + Values(std::tuple{24, 48}, + std::tuple{48, 48}, + std::tuple{72, 72}, + std::tuple{96, 96}, + std::tuple{120, 160}, + std::tuple{256, 108}, + std::tuple{40, 64}), + Values(std::tuple{true, true}, std::tuple{false, false}), // perm + Values("n"), // bias_str + Values(false), // use_dbias + Values(0.0f), // p_drop + Values(std::tuple{0, 0, false}), // seed/offset/prefs + Values(std::tuple{1, 4, 2, 480, -1, "0"}, + std::tuple{2, 2, -1, 300, 400, "t:64,64"}, + std::tuple{1, 4, 1, 512, 201, "1"}, + std::tuple{1, 2, -1, 900, 256, "0"}, + std::tuple{2, 1, -1, 256, 256, "1"}), + Values(false) // deterministic + )); +TEST_P(HDimPadding, DataTypeConfig) { fmha_bwd_test(GetParam()); } + +class ElementwiseBias : public TestWithParam +{ +}; +INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaBwd, + ElementwiseBias, + Combine(ModeValues, + HDimValues, + // layouts of bias and dbias are controlled by i_perm + Values(std::tuple{true, false}, std::tuple{false, false}), + Values("e:0", "e:1", "e:2"), + Bool(), // use_dbias + Values(0.0f), // p_drop + Values(std::tuple{0, 0, false}), // seed/offset/prefs + Values(std::tuple{1, 4, 2, 1024, 100, "0"}, + std::tuple{3, 2, -1, 128, 256, "2"}, + std::tuple{2, 2, -1, 130, 499, "t:50,64"}), + Values(false) // deterministic + )); +TEST_P(ElementwiseBias, DataTypeConfig) { fmha_bwd_test(GetParam()); } +class Alibi : public TestWithParam +{ +}; + +INSTANTIATE_TEST_SUITE_P( + TestCkTileFmhaBwd, + Alibi, + Combine(ModeValues, + HDimValues, + Values(std::tuple{true, true}), // perm + Values("a:0", "a:1"), + Values(false), // use_dbias + Values(0.0f), // p_drop + Values(std::tuple{0, 0, false}), // seed/offset/prefs + ValuesIn([]() { + const std::array dims{ + std::tuple{1, 3, 3, 1024, 1000}, + std::tuple{3, 5, 5, 128, 256}, + std::tuple{2, 8, 4, 130, 320}, + }; + const std::array mask_strs{"0", "t", "b", "t:50,64", "b:32,40"}; + std::vector dims_masks; + std::for_each(dims.begin(), dims.end(), [&](const auto& d) { + const auto& [b, h, hk, sq, sk] = d; + std::for_each(mask_strs.begin(), mask_strs.end(), [&](const auto& m) { + dims_masks.push_back(std::tuple{b, h, hk, sq, sk, m}); + }); + }); + return dims_masks; + }()), + Values(false) // deterministic + )); +TEST_P(Alibi, DataTypeConfig) { fmha_bwd_test(GetParam()); } + +class Dropout : public TestWithParam +{ +}; +INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaBwd, + Dropout, + Combine(ModeValues, + HDimValues, + Values(std::tuple{true, true}), // perm + Values("n"), // bias_str + Values(false), // use_dbias + Values(0.123f, 0.5f), // p_drop + Values(std::tuple{10, 123, false}, // seed/offset/prefs + std::tuple{34534564645, 7876878876864, true}), + Values(std::tuple{2, 6, 2, 180, 512, "0"}, + std::tuple{3, 2, 2, 256, 128, "1"}, + std::tuple{4, 2, 1, 100, 768, "2"}), + Values(false) // deterministic + )); + +TEST_P(Dropout, DataTypeConfig) { fmha_bwd_test(GetParam()); } + +class Deterministic : public TestWithParam +{ +}; + +INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaBwd, + Deterministic, + Combine(ModeValues, + HDimValues, + Values(std::tuple{false, true}, std::tuple{true, true}), // perm + Values("n"), // bias_str + Values(false), // use_dbias + Values(0.0f), // p_drop + Values(std::tuple{0, 0, false}), // seed/offset/prefs + Values(std::tuple{2, 6, 2, 180, 512, "0"}, + std::tuple{3, 3, 1, 256, 128, "1"}, + std::tuple{4, 2, 2, 768, 100, "2"}), + Values(true) // deterministic + )); +TEST_P(Deterministic, DataTypeConfig) { fmha_bwd_test(GetParam()); } diff --git a/test/ck_tile/fmha/test_fmha_bwd.inc b/test/ck_tile/fmha/test_fmha_bwd.inc deleted file mode 100644 index 704b5c7bf7..0000000000 --- a/test/ck_tile/fmha/test_fmha_bwd.inc +++ /dev/null @@ -1,347 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -using ::testing::Bool; -using ::testing::Combine; -using ::testing::TestWithParam; -using ::testing::Values; -using ::testing::ValuesIn; - -// Random seed used for initializing input tensors. 0 for non-deterministic seed -CK_TILE_DECLARE_ENV_VAR(CK_TILE_TEST_SEED, uint64_t, 123456) - -// Whether to run long tests (from smoke_test_fwd.sh) -CK_TILE_DECLARE_ENV_VAR_BOOL(CK_TILE_FMHA_LONG_TESTS) - -#define CHECK_RESULT(result) \ - do \ - { \ - if(result == bwd_result::no_instance) \ - GTEST_SKIP() << "No instance for current parameters"; \ - ASSERT_EQ(result, bwd_result::success); \ - } while(0) - -const ck_tile::stream_config stream_config{ - nullptr, // stream_id_ - false, // time_kernel_ - 1, // log_level_ - 0, // cold_niters_ - 1, // nrepeat_ - true, // is_gpu_timer_ - false, // flush_cache_ - 1, // rotating_count_ -}; - -#define COMMON_ARGS \ - init_method, static_cast(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))), 1, \ - stream_config - -auto EnableTestIf(bool condition) -{ - return ValuesIn(condition ? std::vector{true} : std::vector{}); -} - -class AllLong : public TestWithParam, - bool, - mode_enum, - std::string, - float, - std::tuple>> -{ -}; - -GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(AllLong); - -// Test cases from example/ck_tile/01_fmha/script/smoke_test_bwd.sh - -INSTANTIATE_TEST_SUITE_P( - TestCkTileFmhaBwd, - AllLong, - Combine(EnableTestIf(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_FMHA_LONG_TESTS))), - HDimValues, - Bool(), - ModeValues, - Values("n", "a"), - Values(0.0f, 0.2f), - Values(std::tuple{1, 4, 2, 259, -1, "0"}, - std::tuple{2, 2, -1, 516, 253, "0"}, - std::tuple{1, 4, 1, 500, 251, "1"}, - std::tuple{1, 2, -1, 900, 258, "2"}, - std::tuple{2, 1, -1, 987, 219, "t:128,30"}, - std::tuple{2, 3, 1, 244, 499, "b:4,35"}))); - -TEST_P(AllLong, Test) -{ - auto [_, hdims, perm, mode, bias_str, p_drop, dims_mask] = GetParam(); - auto [hdim_q, hdim_v] = hdims; - auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask; - - auto result = fmha_bwd_run(mode, - batch, - nhead, - nhead_k, - {seqlen_q}, - {seqlen_k}, - hdim_q, - hdim_v, - perm, // i_perm - perm, // o_perm - 0, // scale - bias_str, // bias_str - false, // use_dbias - p_drop, // p_drop - 123, // drop_seed - 1024, // drop_offset - true, // drop_prefs - mask_str, // mask_str - false, // deterministic - COMMON_ARGS); - CHECK_RESULT(result); -} - -class HDimPadding - : public TestWithParam, - bool, - mode_enum, - std::tuple>> -{ -}; - -INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaBwd, - HDimPadding, - Combine(Values(std::tuple{24, 48}, - std::tuple{48, 48}, - std::tuple{72, 72}, - std::tuple{96, 96}, - std::tuple{120, 160}, - std::tuple{256, 108}, - std::tuple{40, 64}), - Bool(), - ModeValues, - Values(std::tuple{1, 4, 2, 480, -1, "0"}, - std::tuple{2, 2, -1, 300, 400, "t:64,64"}, - std::tuple{1, 4, 1, 512, 201, "1"}, - std::tuple{1, 2, -1, 900, 256, "0"}, - std::tuple{2, 1, -1, 256, 256, "1"}))); - -TEST_P(HDimPadding, Test) -{ - auto [hdims, perm, mode, dims_mask] = GetParam(); - auto [hdim_q, hdim_v] = hdims; - auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask; - - auto result = fmha_bwd_run(mode, - batch, - nhead, - nhead_k, - {seqlen_q}, - {seqlen_k}, - hdim_q, - hdim_v, - perm, // i_perm - perm, // o_perm - 0, // scale - "n", // bias_str - false, // use_dbias - 0.0f, // p_drop - 0, // drop_seed - 0, // drop_offset - false, // drop_prefs - mask_str, // mask_str - false, // deterministic - COMMON_ARGS); - CHECK_RESULT(result); -} - -class ElementwiseBias - : public TestWithParam, - bool, - mode_enum, - std::string, - bool, - std::tuple>> -{ -}; - -INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaBwd, - ElementwiseBias, - Combine(HDimValues, - Bool(), // layouts of bias and dbias are controlled by i_perm - ModeValues, - Values("e:0", "e:1", "e:2"), - Bool(), - Values(std::tuple{1, 4, 2, 1024, 100, "0"}, - std::tuple{3, 2, -1, 128, 256, "2"}, - std::tuple{2, 2, -1, 130, 499, "t:50,64"}))); - -TEST_P(ElementwiseBias, Test) -{ - auto [hdims, i_perm, mode, bias_str, use_dbias, dims_mask] = GetParam(); - auto [hdim_q, hdim_v] = hdims; - auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask; - - auto result = fmha_bwd_run(mode, - batch, - nhead, - nhead_k, - {seqlen_q}, - {seqlen_k}, - hdim_q, - hdim_v, - i_perm, // i_perm - false, // o_perm - 0, // scale - bias_str, // bias_str - use_dbias, // use_dbias - 0.0f, // p_drop - 123, // drop_seed - 1024, // drop_offset - true, // drop_prefs - mask_str, // mask_str - false, // deterministic - COMMON_ARGS); - CHECK_RESULT(result); -} - -class Alibi : public TestWithParam, - mode_enum, - std::string, - std::tuple, - std::string>> -{ -}; - -INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaBwd, - Alibi, - Combine(HDimValues, - ModeValues, - Values("a:0", "a:1"), - Values(std::tuple{1, 3, 3, 1024, 1000}, - std::tuple{3, 5, 5, 128, 256}, - std::tuple{2, 8, 4, 130, 320}), - Values("0", "t", "b", "t:50,64", "b:32,40"))); - -TEST_P(Alibi, Test) -{ - auto [hdims, mode, bias_str, dims, mask_str] = GetParam(); - auto [hdim_q, hdim_v] = hdims; - auto [batch, nhead, nhead_k, seqlen_q, seqlen_k] = dims; - - auto result = fmha_bwd_run(mode, - batch, - nhead, - nhead_k, - {seqlen_q}, - {seqlen_k}, - hdim_q, - hdim_v, - true, // i_perm - true, // o_perm - 0, // scale - bias_str, // bias_str - false, // use_dbias - 0.0f, // p_drop - 0, // drop_seed - 0, // drop_offset - false, // drop_prefs - mask_str, // mask_str - false, // deterministic - COMMON_ARGS); - CHECK_RESULT(result); -} - -class Dropout : public TestWithParam, - mode_enum, - float, - std::tuple, - std::tuple>> -{ -}; - -INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaBwd, - Dropout, - Combine(HDimValues, - ModeValues, - Values(0.123f, 0.5f), - Values(std::tuple{10, 123, false}, - std::tuple{34534564645, 7876878876864, true}), - Values(std::tuple{2, 6, 2, 180, 512, "0"}, - std::tuple{3, 2, 2, 256, 128, "1"}, - std::tuple{4, 2, 1, 100, 768, "2"}))); - -TEST_P(Dropout, Test) -{ - auto [hdims, mode, p_drop, drop_seed_offset_prefs, dims_mask] = GetParam(); - auto [hdim_q, hdim_v] = hdims; - auto [drop_seed, drop_offset, drop_prefs] = drop_seed_offset_prefs; - auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask; - - auto result = fmha_bwd_run(mode, - batch, - nhead, - nhead_k, - {seqlen_q}, - {seqlen_k}, - hdim_q, - hdim_v, - true, // i_perm - true, // o_perm - 0.1f, // scale - "n", // bias_str - false, // use_dbias - p_drop, // p_drop - drop_seed, // drop_seed - drop_offset, // drop_offset - drop_prefs, // drop_prefs - mask_str, // mask_str - false, // deterministic - COMMON_ARGS); - CHECK_RESULT(result); -} - -class Deterministic - : public TestWithParam, - bool, - mode_enum, - std::tuple>> -{ -}; - -INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaBwd, - Deterministic, - Combine(HDimValues, - Bool(), - ModeValues, - Values(std::tuple{2, 6, 2, 180, 512, "0"}, - std::tuple{3, 3, 1, 256, 128, "1"}, - std::tuple{4, 2, 2, 768, 100, "2"}))); - -TEST_P(Deterministic, Test) -{ - auto [hdims, i_perm, mode, dims_mask] = GetParam(); - auto [hdim_q, hdim_v] = hdims; - auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask; - - auto result = fmha_bwd_run(mode, - batch, - nhead, - nhead_k, - {seqlen_q}, - {seqlen_k}, - hdim_q, - hdim_v, - i_perm, // i_perm - true, // o_perm - 0, // scale - "n", // bias_str - false, // use_dbias - 0.0f, // p_drop - 0, // drop_seed - 0, // drop_offset - false, // drop_prefs - mask_str, // mask_str - true, // deterministic - COMMON_ARGS); - CHECK_RESULT(result); -} diff --git a/test/ck_tile/fmha/test_fmha_bwd_bf16.cpp b/test/ck_tile/fmha/test_fmha_bwd_bf16.cpp deleted file mode 100644 index 077e45a10d..0000000000 --- a/test/ck_tile/fmha/test_fmha_bwd_bf16.cpp +++ /dev/null @@ -1,21 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "example/ck_tile/01_fmha/fmha_bwd.hpp" -#include "example/ck_tile/01_fmha/fmha_bwd_runner.hpp" - -#include "gtest/gtest.h" - -using DataTypeConfig = FmhaBwdBf16; - -using ::testing::Values; -using ::testing::ValuesIn; - -const auto HDimValues = - Values(std::tuple{32, -1}, std::tuple{64, -1}, std::tuple{128, -1}, std::tuple{256, -1}); - -const auto ModeValues = Values(mode_enum::batch, mode_enum::group); - -constexpr auto init_method = "uf"; - -#include "test_fmha_bwd.inc" diff --git a/test/ck_tile/fmha/test_fmha_bwd_fp16.cpp b/test/ck_tile/fmha/test_fmha_bwd_fp16.cpp deleted file mode 100644 index 86621b0494..0000000000 --- a/test/ck_tile/fmha/test_fmha_bwd_fp16.cpp +++ /dev/null @@ -1,21 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "example/ck_tile/01_fmha/fmha_bwd.hpp" -#include "example/ck_tile/01_fmha/fmha_bwd_runner.hpp" - -#include "gtest/gtest.h" - -using DataTypeConfig = FmhaBwdFp16; - -using ::testing::Values; -using ::testing::ValuesIn; - -const auto HDimValues = - Values(std::tuple{32, -1}, std::tuple{64, -1}, std::tuple{128, -1}, std::tuple{256, -1}); - -const auto ModeValues = Values(mode_enum::batch, mode_enum::group); - -constexpr auto init_method = "uf"; - -#include "test_fmha_bwd.inc" diff --git a/test/ck_tile/fmha/test_fmha_bwd_fp32.cpp b/test/ck_tile/fmha/test_fmha_bwd_fp32.cpp deleted file mode 100644 index 09010d4b22..0000000000 --- a/test/ck_tile/fmha/test_fmha_bwd_fp32.cpp +++ /dev/null @@ -1,20 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "example/ck_tile/01_fmha/fmha_bwd.hpp" -#include "example/ck_tile/01_fmha/fmha_bwd_runner.hpp" - -#include "gtest/gtest.h" - -using DataTypeConfig = FmhaBwdFp32; - -using ::testing::Values; -using ::testing::ValuesIn; - -const auto HDimValues = Values(std::tuple{32, -1}, std::tuple{64, -1}, std::tuple{128, -1}); - -const auto ModeValues = Values(mode_enum::batch, mode_enum::group); - -const std::string init_method = "uf"; - -#include "test_fmha_bwd.inc" diff --git a/test/ck_tile/fmha/test_fmha_fwd.inc b/test/ck_tile/fmha/test_fmha_fwd.cpp similarity index 92% rename from test/ck_tile/fmha/test_fmha_fwd.inc rename to test/ck_tile/fmha/test_fmha_fwd.cpp index ccca5cf969..6e4b547465 100644 --- a/test/ck_tile/fmha/test_fmha_fwd.inc +++ b/test/ck_tile/fmha/test_fmha_fwd.cpp @@ -1,12 +1,104 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +#include "example/ck_tile/01_fmha/fmha_fwd.hpp" +#include "example/ck_tile/01_fmha/fmha_fwd_runner.hpp" + +#include "gtest/gtest.h" + +#ifndef DataTypeConfig +#define DataTypeConfig FmhaFwdFp16 // or FmhaFwdBf16 / FmhaFwdFp8 / FmhaFwdFp32 +#endif + using ::testing::Bool; using ::testing::Combine; using ::testing::TestWithParam; using ::testing::Values; using ::testing::ValuesIn; +template +struct TestConfigs +{ + static constexpr auto HDimValues = std::array{ + std::tuple{32, -1}, + std::tuple{64, -1}, + std::tuple{96, 128}, + std::tuple{128, -1}, + std::tuple{192, 128}, + std::tuple{192, -1}, + std::tuple{256, -1}, + }; + static constexpr auto SplitKVHDimValues = std::array{ + std::tuple{32, -1}, + std::tuple{64, -1}, + std::tuple{96, -1}, + std::tuple{128, -1}, + std::tuple{256, -1}, + }; + static constexpr auto AppendKVHDimValues = std::array{ + std::tuple{32, -1}, std::tuple{64, -1}, std::tuple{128, -1}, std::tuple{256, -1}}; + static constexpr auto ModeValues = std::array{mode_enum::batch, mode_enum::group}; + static constexpr auto IsVRowmajorValues = std::array{false, true}; + static constexpr bool squant = false; + static constexpr bool def_lse = true; + static constexpr bool def_is_v_rowmajor = true; + static int adjust_seqlen(int seqlen) { return seqlen; } +}; +template <> +struct TestConfigs +{ + // Currently there are no fp8 instances for splitkv, pagedkv by default (the tests pass if such + // instances are added), however the corresponding tests are not disabled (they will be skipped) + // in case such instances will be added in the future. + + static constexpr auto HDimValues = std::array{std::tuple{64, -1}, std::tuple{128, -1}}; + static constexpr auto SplitKVHDimValues = std::array{std::tuple{64, -1}, std::tuple{128, -1}}; + static constexpr auto AppendKVHDimValues = std::array{std::tuple{64, -1}, std::tuple{128, -1}}; + // There are no fp8 instances with seqlen padding (mode_enum::group requires it) + static constexpr auto ModeValues = std::array{mode_enum::batch}; + static constexpr auto IsVRowmajorValues = std::array{false}; + static constexpr bool squant = true; + static constexpr bool def_lse = false; + static constexpr bool def_is_v_rowmajor = true; + static int adjust_seqlen(int seqlen) + { + // There are no fp8 instances with padding, pad seqlen to avoid skipping most of the tests + return ck_tile::integer_least_multiple(seqlen, 128); + } +}; +template <> +struct TestConfigs +{ + static constexpr auto HDimValues = std::array{ + std::tuple{32, -1}, + std::tuple{48, -1}, + std::tuple{64, -1}, + std::tuple{96, 128}, + std::tuple{128, -1}, + std::tuple{192, -1}, + std::tuple{256, -1}, + }; + static constexpr auto SplitKVHDimValues = std::array, 0>{}; + static constexpr auto AppendKVHDimValues = std::array, 0>{}; + static constexpr auto ModeValues = std::array{mode_enum::batch, mode_enum::group}; + static constexpr auto IsVRowmajorValues = std::array{true}; + static constexpr bool squant = false; + static constexpr bool def_lse = true; + static constexpr bool def_is_v_rowmajor = true; + static int adjust_seqlen(int seqlen) { return seqlen; } +}; + +static auto HDimValues = ValuesIn(TestConfigs::HDimValues); +static auto SplitKVHDimValues = ValuesIn(TestConfigs::SplitKVHDimValues); +static auto AppendKVHDimValues = ValuesIn(TestConfigs::AppendKVHDimValues); +static auto ModeValues = ValuesIn(TestConfigs::ModeValues); +static auto IsVRowmajorValues = ValuesIn(TestConfigs::IsVRowmajorValues); +constexpr bool squant = TestConfigs::squant; +constexpr bool def_lse = TestConfigs::def_lse; +constexpr bool def_is_v_rowmajor = TestConfigs::def_is_v_rowmajor; +int adjust_seqlen(int seqlen) { return TestConfigs::adjust_seqlen(seqlen); } +constexpr auto init_method = "uf"; + // Random seed used for initializing input tensors. 0 for non-deterministic seed CK_TILE_DECLARE_ENV_VAR(CK_TILE_TEST_SEED, uint64_t, 123456) @@ -79,7 +171,7 @@ INSTANTIATE_TEST_SUITE_P( std::tuple{1, 2, 1, -1, -1, 33, 0, -1, "2"}, std::tuple{1, 2, 1, -1, -1, 1, 10, 32, "2"}))); -TEST_P(AllLong, Test) +TEST_P(AllLong, DataTypeConfig) { auto [_, hdims, perm, is_v_rowmajor, mode, lse, bias_str, p_drop, dims_mask] = GetParam(); auto [hdim_q, hdim_v] = hdims; @@ -283,7 +375,7 @@ INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaFwd, std::tuple{1, 2, -1, 900, 256, -1, "0"}, std::tuple{2, 1, -1, 256, 256, -1, "1"}))); -TEST_P(HDimPadding, Test) +TEST_P(HDimPadding, DataTypeConfig) { auto [hdims, perm, is_v_rowmajor, mode, dims_mask] = GetParam(); auto [hdim_q, hdim_v] = hdims; @@ -343,7 +435,7 @@ INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaFwd, std::tuple{3, 2, -1, 128, 256, "2"}, std::tuple{2, 2, -1, 130, 499, "t:50,64"}))); -TEST_P(ElementwiseBias, Test) +TEST_P(ElementwiseBias, DataTypeConfig) { auto [hdims, i_perm, mode, bias_str, dims_mask] = GetParam(); auto [hdim_q, hdim_v] = hdims; @@ -402,7 +494,7 @@ INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaFwd, std::tuple{2, 8, 2, 300, 355}), Values("0", "t", "b", "t:50,64", "b:32,40"))); -TEST_P(Alibi, Test) +TEST_P(Alibi, DataTypeConfig) { auto [hdims, mode, bias_str, dims, mask_str] = GetParam(); auto [hdim_q, hdim_v] = hdims; @@ -462,7 +554,7 @@ INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaFwd, std::tuple{3, 2, 2, 256, 128, "1"}, std::tuple{4, 3, 1, 100, 768, "2"}))); -TEST_P(Dropout, Test) +TEST_P(Dropout, DataTypeConfig) { auto [hdims, mode, p_drop, drop_seed_offset_prefs, dims_mask] = GetParam(); auto [hdim_q, hdim_v] = hdims; @@ -528,7 +620,7 @@ INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaFwd, std::tuple{3, 2, -1, 128, 768, "2"}, std::tuple{2, 2, -1, 230, 899, "t:50,64"}))); -TEST_P(PagedKV, Test) +TEST_P(PagedKV, DataTypeConfig) { auto [hdims, i_perm, is_v_rowmajor, mode, page_block_size, dims_mask] = GetParam(); auto [hdim_q, hdim_v] = hdims; @@ -597,7 +689,7 @@ INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaFwd, std::tuple{2, 2, -1, 512, 2000, "0"}, std::tuple{3, 2, -1, 230, 899, "t:128,128"}))); -TEST_P(SplitKV, Test) +TEST_P(SplitKV, DataTypeConfig) { auto [hdims, i_perm, is_v_rowmajor, mode_use_cache_batch_idx, num_splits, dims_mask] = GetParam(); @@ -668,7 +760,7 @@ INSTANTIATE_TEST_SUITE_P( GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(AppendKV); -TEST_P(AppendKV, Test) +TEST_P(AppendKV, DataTypeConfig) { auto [hdims, i_perm, @@ -745,7 +837,7 @@ INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaFwd, std::tuple{1, 2, 1, 128, 55, "0"}, std::tuple{3, 4, 2, 72, 128, "1"}))); -TEST_P(AppendKVRoPE, Test) +TEST_P(AppendKVRoPE, DataTypeConfig) { auto [_, hdims, i_perm, is_v_rowmajor, rotary, seqlen_knew, dims_mask] = GetParam(); auto [hdim_q, hdim_v] = hdims; @@ -1017,7 +1109,7 @@ static const std::vector kPaddingParams = BuildPaddingParams(); INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaFwd_Padding, PaddingCases, ValuesIn(kPaddingParams)); -TEST_P(PaddingCases, Test) +TEST_P(PaddingCases, DataTypeConfig) { if constexpr(std::is_same_v) { diff --git a/test/ck_tile/fmha/test_fmha_fwd_bf16.cpp b/test/ck_tile/fmha/test_fmha_fwd_bf16.cpp deleted file mode 100644 index fbc6449a6a..0000000000 --- a/test/ck_tile/fmha/test_fmha_fwd_bf16.cpp +++ /dev/null @@ -1,44 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "example/ck_tile/01_fmha/fmha_fwd.hpp" -#include "example/ck_tile/01_fmha/fmha_fwd_runner.hpp" - -#include "gtest/gtest.h" - -#include -#include - -using ::testing::Values; - -using DataTypeConfig = FmhaFwdBf16; - -const auto HDimValues = Values(std::tuple{32, -1}, - std::tuple{64, -1}, - std::tuple{96, 128}, - std::tuple{128, -1}, - std::tuple{192, 128}, - std::tuple{192, -1}, - std::tuple{256, -1}); - -const auto SplitKVHDimValues = Values(std::tuple{32, -1}, - std::tuple{64, -1}, - std::tuple{96, -1}, - std::tuple{128, -1}, - std::tuple{256, -1}); - -const auto AppendKVHDimValues = - Values(std::tuple{32, -1}, std::tuple{64, -1}, std::tuple{128, -1}, std::tuple{256, -1}); - -const auto ModeValues = Values(mode_enum::batch, mode_enum::group); - -const auto IsVRowmajorValues = Values(false, true); - -const bool squant = false; -const std::string init_method = "uf"; -const bool def_lse = true; -const bool def_is_v_rowmajor = true; - -int adjust_seqlen(int seqlen) { return seqlen; } - -#include "test_fmha_fwd.inc" diff --git a/test/ck_tile/fmha/test_fmha_fwd_fp16.cpp b/test/ck_tile/fmha/test_fmha_fwd_fp16.cpp deleted file mode 100644 index abc2c44726..0000000000 --- a/test/ck_tile/fmha/test_fmha_fwd_fp16.cpp +++ /dev/null @@ -1,44 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "example/ck_tile/01_fmha/fmha_fwd.hpp" -#include "example/ck_tile/01_fmha/fmha_fwd_runner.hpp" - -#include "gtest/gtest.h" - -#include -#include - -using ::testing::Values; - -using DataTypeConfig = FmhaFwdFp16; - -const auto HDimValues = Values(std::tuple{32, -1}, - std::tuple{64, -1}, - std::tuple{96, 128}, - std::tuple{128, -1}, - std::tuple{192, 128}, - std::tuple{192, -1}, - std::tuple{256, -1}); - -const auto SplitKVHDimValues = Values(std::tuple{32, -1}, - std::tuple{64, -1}, - std::tuple{96, -1}, - std::tuple{128, -1}, - std::tuple{256, -1}); - -const auto AppendKVHDimValues = - Values(std::tuple{32, -1}, std::tuple{64, -1}, std::tuple{128, -1}, std::tuple{256, -1}); - -const auto ModeValues = Values(mode_enum::batch, mode_enum::group); - -const auto IsVRowmajorValues = Values(false, true); - -const bool squant = false; -const std::string init_method = "uf"; -const bool def_lse = true; -const bool def_is_v_rowmajor = true; - -int adjust_seqlen(int seqlen) { return seqlen; } - -#include "test_fmha_fwd.inc" diff --git a/test/ck_tile/fmha/test_fmha_fwd_fp32.cpp b/test/ck_tile/fmha/test_fmha_fwd_fp32.cpp deleted file mode 100644 index 00f1eb0629..0000000000 --- a/test/ck_tile/fmha/test_fmha_fwd_fp32.cpp +++ /dev/null @@ -1,39 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "example/ck_tile/01_fmha/fmha_fwd.hpp" -#include "example/ck_tile/01_fmha/fmha_fwd_runner.hpp" - -#include "gtest/gtest.h" - -#include -#include - -using ::testing::Values; - -using DataTypeConfig = FmhaFwdFp32; - -const auto HDimValues = Values(std::tuple{32, -1}, - std::tuple{48, -1}, - std::tuple{64, -1}, - std::tuple{96, 128}, - std::tuple{128, -1}, - std::tuple{192, -1}, - std::tuple{256, -1}); - -const auto SplitKVHDimValues = Values(); - -const auto AppendKVHDimValues = Values(); - -const auto ModeValues = Values(mode_enum::batch, mode_enum::group); - -const auto IsVRowmajorValues = Values(true); - -const bool squant = false; -const std::string init_method = "uf"; -const bool def_lse = true; -const bool def_is_v_rowmajor = true; - -int adjust_seqlen(int seqlen) { return seqlen; } - -#include "test_fmha_fwd.inc" diff --git a/test/ck_tile/fmha/test_fmha_fwd_fp8.cpp b/test/ck_tile/fmha/test_fmha_fwd_fp8.cpp deleted file mode 100644 index b99c304d1f..0000000000 --- a/test/ck_tile/fmha/test_fmha_fwd_fp8.cpp +++ /dev/null @@ -1,42 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "example/ck_tile/01_fmha/fmha_fwd.hpp" -#include "example/ck_tile/01_fmha/fmha_fwd_runner.hpp" - -#include "gtest/gtest.h" - -#include -#include - -using ::testing::Values; - -using DataTypeConfig = FmhaFwdFp8; - -// Currently there are no fp8 instances for splitkv, pagedkv by default (the tests pass if such -// instances are added), however the corresponding tests are not disabled (they will be skipped) -// in case such instances will be added in the future. - -const auto HDimValues = Values(std::tuple{64, -1}, std::tuple{128, -1}); - -const auto SplitKVHDimValues = Values(std::tuple{64, -1}, std::tuple{128, -1}); - -const auto AppendKVHDimValues = Values(std::tuple{64, -1}, std::tuple{128, -1}); - -// There are no fp8 instances with seqlen padding (mode_enum::group requires it) -const auto ModeValues = Values(mode_enum::batch); - -const auto IsVRowmajorValues = Values(false); - -const auto squant = true; -const std::string init_method = "uf"; -const bool def_lse = false; -const bool def_is_v_rowmajor = true; - -int adjust_seqlen(int seqlen) -{ - // There are no fp8 instances with padding, pad seqlen to avoid skipping most of the tests - return ck_tile::integer_least_multiple(seqlen, 128); -} - -#include "test_fmha_fwd.inc"