mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
[CK_TILE] FMHA Tests Enhancement (#2945)
* fmha-gtest-wip
* Thanks Copilot!
[ROCm/composable_kernel commit: b6036bc76a]
This commit is contained in:
@@ -5,35 +5,50 @@ endif()
|
||||
|
||||
set(FMHA_BWD_INSTANCES "tile_fmha_bwd_instances")
|
||||
set(FMHA_FWD_INSTANCES "tile_fmha_fwd_instances")
|
||||
set(TEST_NAME "test_ck_tile_fmha")
|
||||
|
||||
add_gtest_executable(test_ck_tile_fmha_bwd_fp32 test_fmha_bwd_fp32.cpp)
|
||||
target_link_libraries(test_ck_tile_fmha_bwd_fp32 PRIVATE ${FMHA_BWD_INSTANCES})
|
||||
# add_gtest_fwd(<test_group>)
#
# For every entry of V_TYPES, registers a gtest executable named
# "<test_group>_<type>" built from test_fmha_fwd.cpp, defines the preprocessor
# symbol DataTypeConfig to the matching CPP_TYPE_<type> value, links it against
# ${FMHA_FWD_INSTANCES}, and finally creates a custom target <test_group> that
# depends on all of those executables.
#
# TEST_NAME and FMHA_FWD_INSTANCES are not parameters; they are read from the
# calling scope.
#
# NOTE(review): this text is a rendered diff. The add_gtest_executable /
# target_link_libraries rows for test_ck_tile_fmha_bwd_bf16 inside this body
# look like removed-side rows interleaved with the new function - confirm
# against the real CMakeLists.txt before treating them as part of the function.
function(add_gtest_fwd test_group)
|
||||
set(V_TYPES "fp16" "bf16" "fp8" "fp32")
|
||||
set(CPP_TYPE_fp16 "FmhaFwdFp16")
|
||||
set(CPP_TYPE_bf16 "FmhaFwdBf16")
|
||||
set(CPP_TYPE_fp8 "FmhaFwdFp8")
|
||||
set(CPP_TYPE_fp32 "FmhaFwdFp32")
|
||||
|
||||
add_gtest_executable(test_ck_tile_fmha_bwd_bf16 test_fmha_bwd_bf16.cpp)
|
||||
target_link_libraries(test_ck_tile_fmha_bwd_bf16 PRIVATE ${FMHA_BWD_INSTANCES})
|
||||
# Start from an empty list; it collects the per-type target names.
set(all_tests)
|
||||
foreach(type ${V_TYPES})
|
||||
set(name "${test_group}_${type}")
|
||||
add_gtest_executable(${name} test_fmha_fwd.cpp)
|
||||
# Read the labels already attached to the test so they can be extended rather
# than overwritten.
# NOTE(review): get_test_property stores "NOTFOUND" when LABELS is unset;
# presumably add_gtest_executable always sets labels - confirm, otherwise
# "NOTFOUND" ends up as a label.
# NOTE(review): assumes add_gtest_executable always creates both the target and
# the test named ${name} - confirm it cannot skip registration.
get_test_property(${name} LABELS COMMON_LABELS)
|
||||
set_tests_properties(${name} PROPERTIES LABELS "${COMMON_LABELS};${TEST_NAME};${test_group}")
|
||||
# Nested expansion: ${CPP_TYPE_${type}} resolves to the C++ config type name
# set above for the current ${type}.
target_compile_definitions(${name} PRIVATE DataTypeConfig=${CPP_TYPE_${type}})
|
||||
target_link_libraries(${name} PRIVATE ${FMHA_FWD_INSTANCES})
|
||||
list(APPEND all_tests ${name})
|
||||
endforeach()
|
||||
message(STATUS "FMHA FWD tests: ${all_tests}")
|
||||
add_custom_target(${test_group} DEPENDS ${all_tests})
|
||||
endfunction()
|
||||
|
||||
add_gtest_executable(test_ck_tile_fmha_bwd_fp16 test_fmha_bwd_fp16.cpp)
|
||||
target_link_libraries(test_ck_tile_fmha_bwd_fp16 PRIVATE ${FMHA_BWD_INSTANCES})
|
||||
# add_gtest_bwd(<test_group>)
#
# For every entry of V_TYPES, registers a gtest executable named
# "<test_group>_<type>" built from test_fmha_bwd.cpp, defines the preprocessor
# symbol DataTypeConfig to the matching CPP_TYPE_<type> value, links it against
# ${FMHA_BWD_INSTANCES}, and finally creates a custom target <test_group> that
# depends on all of those executables.
#
# TEST_NAME and FMHA_BWD_INSTANCES are not parameters; they are read from the
# calling scope.
#
# NOTE(review): this text is a rendered diff. The add_gtest_executable /
# target_link_libraries rows for test_ck_tile_fmha_fwd_fp32 inside this body
# look like removed-side rows interleaved with the new function - confirm
# against the real CMakeLists.txt before treating them as part of the function.
function(add_gtest_bwd test_group)
|
||||
set(V_TYPES "fp16" "bf16" "fp32")
|
||||
set(CPP_TYPE_fp16 "FmhaBwdFp16")
|
||||
set(CPP_TYPE_bf16 "FmhaBwdBf16")
|
||||
set(CPP_TYPE_fp32 "FmhaBwdFp32")
|
||||
|
||||
add_gtest_executable(test_ck_tile_fmha_fwd_fp32 test_fmha_fwd_fp32.cpp)
|
||||
target_link_libraries(test_ck_tile_fmha_fwd_fp32 PRIVATE ${FMHA_FWD_INSTANCES})
|
||||
# Start from an empty list; it collects the per-type target names.
set(all_tests)
|
||||
foreach(type ${V_TYPES})
|
||||
set(name "${test_group}_${type}")
|
||||
add_gtest_executable(${name} test_fmha_bwd.cpp)
|
||||
# Read the labels already attached to the test so they can be extended rather
# than overwritten.
# NOTE(review): get_test_property stores "NOTFOUND" when LABELS is unset;
# presumably add_gtest_executable always sets labels - confirm, otherwise
# "NOTFOUND" ends up as a label.
# NOTE(review): assumes add_gtest_executable always creates both the target and
# the test named ${name} - confirm it cannot skip registration.
get_test_property(${name} LABELS COMMON_LABELS)
|
||||
set_tests_properties(${name} PROPERTIES LABELS "${COMMON_LABELS};${TEST_NAME};${test_group}")
|
||||
# Nested expansion: ${CPP_TYPE_${type}} resolves to the C++ config type name
# set above for the current ${type}.
target_compile_definitions(${name} PRIVATE DataTypeConfig=${CPP_TYPE_${type}})
|
||||
target_link_libraries(${name} PRIVATE ${FMHA_BWD_INSTANCES})
|
||||
list(APPEND all_tests ${name})
|
||||
endforeach()
|
||||
message(STATUS "FMHA BWD tests: ${all_tests}")
|
||||
add_custom_target(${test_group} DEPENDS ${all_tests})
|
||||
endfunction()
|
||||
|
||||
add_gtest_executable(test_ck_tile_fmha_fwd_bf16 test_fmha_fwd_bf16.cpp)
|
||||
target_link_libraries(test_ck_tile_fmha_fwd_bf16 PRIVATE ${FMHA_FWD_INSTANCES})
|
||||
|
||||
add_gtest_executable(test_ck_tile_fmha_fwd_fp16 test_fmha_fwd_fp16.cpp)
|
||||
target_link_libraries(test_ck_tile_fmha_fwd_fp16 PRIVATE ${FMHA_FWD_INSTANCES})
|
||||
|
||||
add_gtest_executable(test_ck_tile_fmha_fwd_fp8 test_fmha_fwd_fp8.cpp)
|
||||
target_link_libraries(test_ck_tile_fmha_fwd_fp8 PRIVATE ${FMHA_FWD_INSTANCES})
|
||||
|
||||
add_custom_target(test_ck_tile_fmha
|
||||
DEPENDS
|
||||
test_ck_tile_fmha_bwd_fp32
|
||||
test_ck_tile_fmha_bwd_bf16
|
||||
test_ck_tile_fmha_bwd_fp16
|
||||
test_ck_tile_fmha_fwd_fp32
|
||||
test_ck_tile_fmha_fwd_bf16
|
||||
test_ck_tile_fmha_fwd_fp16
|
||||
test_ck_tile_fmha_fwd_fp8
|
||||
)
|
||||
add_gtest_fwd(${TEST_NAME}_fwd)
|
||||
add_gtest_bwd(${TEST_NAME}_bwd)
|
||||
add_custom_target(${TEST_NAME} DEPENDS ${TEST_NAME}_fwd ${TEST_NAME}_bwd)
|
||||
|
||||
248
test/ck_tile/fmha/test_fmha_bwd.cpp
Normal file
248
test/ck_tile/fmha/test_fmha_bwd.cpp
Normal file
@@ -0,0 +1,248 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "example/ck_tile/01_fmha/fmha_bwd.hpp"
|
||||
#include "example/ck_tile/01_fmha/fmha_bwd_runner.hpp"
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#ifndef DataTypeConfig
|
||||
#define DataTypeConfig FmhaBwdFp16 // or FmhaBwdBf16 / FmhaBwdFp32
|
||||
#endif
|
||||
|
||||
using ::testing::Bool;
|
||||
using ::testing::Combine;
|
||||
using ::testing::TestWithParam;
|
||||
using ::testing::Values;
|
||||
using ::testing::ValuesIn;
|
||||
|
||||
// Per-data-type test configuration, selected through TestConfigs<DataTypeConfig>.
// Primary template: used for every config type that has no explicit
// specialization. HDimValues lists the (hdim_q, hdim_v) pairs that the
// file-scope HDimValues generator iterates over.
// NOTE(review): the meaning of -1 for hdim_v is not visible here (presumably
// "same as hdim_q") - confirm in fmha_bwd_run.
template <typename T>
|
||||
struct TestConfigs
|
||||
{
|
||||
static constexpr auto HDimValues = std::array{
|
||||
std::tuple{32, -1}, std::tuple{64, -1}, std::tuple{128, -1}, std::tuple{256, -1}};
|
||||
};
|
||||
// Specialization for FmhaBwdFp32: the same (hdim_q, hdim_v) list as the primary
// template but without the 256 entry.
// NOTE(review): presumably there are no fp32 backward instances for hdim 256 -
// confirm against the instance list.
template <>
|
||||
struct TestConfigs<FmhaBwdFp32>
|
||||
{
|
||||
static constexpr auto HDimValues =
|
||||
std::array{std::tuple{32, -1}, std::tuple{64, -1}, std::tuple{128, -1}};
|
||||
};
|
||||
static auto HDimValues = ValuesIn(TestConfigs<DataTypeConfig>::HDimValues);
|
||||
const auto ModeValues = ValuesIn(std::vector{mode_enum::batch, mode_enum::group});
|
||||
constexpr auto init_method = "uf";
|
||||
|
||||
// Random seed used for initializing input tensors. 0 for non-deterministic seed
|
||||
CK_TILE_DECLARE_ENV_VAR(CK_TILE_TEST_SEED, uint64_t, 123456)
|
||||
|
||||
// Whether to run long tests (from smoke_test_bwd.sh)
|
||||
CK_TILE_DECLARE_ENV_VAR_BOOL(CK_TILE_FMHA_LONG_TESTS)
|
||||
|
||||
const ck_tile::stream_config stream_config{
|
||||
nullptr, // stream_id_
|
||||
false, // time_kernel_
|
||||
1, // log_level_
|
||||
0, // cold_niters_
|
||||
1, // nrepeat_
|
||||
true, // is_gpu_timer_
|
||||
false, // flush_cache_
|
||||
1, // rotating_count_
|
||||
};
|
||||
|
||||
// batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str
|
||||
using FmhaBwdDimsMaskParam = std::tuple<int, int, int, int, int, std::string>;
|
||||
using FmhaBwdTestParam = std::tuple< //
|
||||
mode_enum, // mode
|
||||
std::tuple<int, int>, // hdim_q, hdim_v
|
||||
std::tuple<bool, bool>, // io_perm
|
||||
std::string, // bias_str
|
||||
bool, // use_dbias
|
||||
float, // p_drop
|
||||
std::tuple<uint64_t, uint64_t, bool>, // drop_seed, drop_offset, drop_prefs
|
||||
FmhaBwdDimsMaskParam,
|
||||
bool // deterministic
|
||||
>;
|
||||
// Shared body of every parameterized backward test in this file.
//
// Unpacks one FmhaBwdTestParam, forwards the values to
// fmha_bwd_run<DataTypeConfig>, and interprets the result:
//   - bwd_result::no_instance -> the test is skipped via GTEST_SKIP()
//   - anything other than bwd_result::success -> the test fails (ASSERT_EQ)
//
// param: tuple laid out as documented on FmhaBwdTestParam.
void fmha_bwd_test(const FmhaBwdTestParam& param)
|
||||
{
|
||||
// Structured bindings copy each field out of the tuple (and its nested tuples).
auto [mode, hdims, perm, bias_str, use_dbias, p_drop, drop_misc, dims_mask, det] = param;
|
||||
auto [hdim_q, hdim_v] = hdims;
|
||||
auto [i_perm, o_perm] = perm;
|
||||
auto [drop_seed, drop_offset, drop_prefs] = drop_misc;
|
||||
auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask;
|
||||
|
||||
auto result = fmha_bwd_run<DataTypeConfig>(
|
||||
mode,
|
||||
batch,
|
||||
nhead,
|
||||
nhead_k,
|
||||
// seqlen_q / seqlen_k are passed as single-element brace lists.
// NOTE(review): presumably the runner accepts a list of sequence lengths -
// confirm against fmha_bwd_run's signature.
{seqlen_q},
|
||||
{seqlen_k},
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
i_perm,
|
||||
o_perm,
|
||||
0, // scale
|
||||
bias_str,
|
||||
use_dbias,
|
||||
p_drop,
|
||||
drop_seed,
|
||||
drop_offset,
|
||||
drop_prefs,
|
||||
mask_str,
|
||||
det, // deterministic
|
||||
init_method,
|
||||
// CK_TILE_TEST_SEED is declared as uint64_t; the cast keeps only the low 32 bits.
static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
|
||||
// NOTE(review): the meaning of this literal 1 is not visible in this file -
// confirm against fmha_bwd_run's parameter list.
1,
|
||||
stream_config);
|
||||
|
||||
// GTEST_SKIP() returns from this function, so the assertion below is only
// reached when an instance exists.
if(result == bwd_result::no_instance)
|
||||
GTEST_SKIP() << "No instance for current parameters";
|
||||
ASSERT_EQ(result, bwd_result::success);
|
||||
}
|
||||
|
||||
// Test cases from example/ck_tile/01_fmha/script/smoke_test_bwd.sh
|
||||
// Fixture for the long-running sweep. Its instantiation uses an empty mode list
// unless the CK_TILE_FMHA_LONG_TESTS environment variable is enabled, so by
// default it generates no test cases.
class AllLong : public TestWithParam<FmhaBwdTestParam>
|
||||
{
|
||||
};
|
||||
GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(AllLong);
|
||||
INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaBwd,
|
||||
AllLong,
|
||||
Combine(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_FMHA_LONG_TESTS))
|
||||
? ModeValues
|
||||
: ValuesIn(std::vector<mode_enum>{}),
|
||||
HDimValues,
|
||||
Values(std::tuple{true, true}, std::tuple{false, false}), // perm
|
||||
Values("n", "a"),
|
||||
Values(false), // use_dbias
|
||||
Values(0.0f, 0.2f), // p_drop
|
||||
Values(std::tuple{123, 1024, true}), // seed/offset/prefs
|
||||
Values(std::tuple{1, 4, 2, 259, -1, "0"},
|
||||
std::tuple{2, 2, -1, 516, 253, "0"},
|
||||
std::tuple{1, 4, 1, 500, 251, "1"},
|
||||
std::tuple{1, 2, -1, 900, 258, "2"},
|
||||
std::tuple{2, 1, -1, 987, 219, "t:128,30"},
|
||||
std::tuple{2, 3, 1, 244, 499, "b:4,35"}),
|
||||
Values(false) // deterministic
|
||||
));
|
||||
TEST_P(AllLong, DataTypeConfig) { fmha_bwd_test(GetParam()); }
|
||||
|
||||
// Fixture for cases that use an explicit list of (hdim_q, hdim_v) pairs such as
// (24, 48) or (120, 160) instead of the per-type HDimValues list.
// NOTE(review): the name suggests these sizes exercise head-dim padding -
// confirm against the kernel instance list.
class HDimPadding : public TestWithParam<FmhaBwdTestParam>
|
||||
{
|
||||
};
|
||||
INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaBwd,
|
||||
HDimPadding,
|
||||
Combine(ModeValues,
|
||||
Values(std::tuple{24, 48},
|
||||
std::tuple{48, 48},
|
||||
std::tuple{72, 72},
|
||||
std::tuple{96, 96},
|
||||
std::tuple{120, 160},
|
||||
std::tuple{256, 108},
|
||||
std::tuple{40, 64}),
|
||||
Values(std::tuple{true, true}, std::tuple{false, false}), // perm
|
||||
Values("n"), // bias_str
|
||||
Values(false), // use_dbias
|
||||
Values(0.0f), // p_drop
|
||||
Values(std::tuple{0, 0, false}), // seed/offset/prefs
|
||||
Values(std::tuple{1, 4, 2, 480, -1, "0"},
|
||||
std::tuple{2, 2, -1, 300, 400, "t:64,64"},
|
||||
std::tuple{1, 4, 1, 512, 201, "1"},
|
||||
std::tuple{1, 2, -1, 900, 256, "0"},
|
||||
std::tuple{2, 1, -1, 256, 256, "1"}),
|
||||
Values(false) // deterministic
|
||||
));
|
||||
TEST_P(HDimPadding, DataTypeConfig) { fmha_bwd_test(GetParam()); }
|
||||
|
||||
// Fixture for cases run with the bias strings "e:0", "e:1" and "e:2", each with
// use_dbias both off and on.
class ElementwiseBias : public TestWithParam<FmhaBwdTestParam>
|
||||
{
|
||||
};
|
||||
INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaBwd,
|
||||
ElementwiseBias,
|
||||
Combine(ModeValues,
|
||||
HDimValues,
|
||||
// layouts of bias and dbias are controlled by i_perm
|
||||
Values(std::tuple{true, false}, std::tuple{false, false}),
|
||||
Values("e:0", "e:1", "e:2"),
|
||||
Bool(), // use_dbias
|
||||
Values(0.0f), // p_drop
|
||||
Values(std::tuple{0, 0, false}), // seed/offset/prefs
|
||||
Values(std::tuple{1, 4, 2, 1024, 100, "0"},
|
||||
std::tuple{3, 2, -1, 128, 256, "2"},
|
||||
std::tuple{2, 2, -1, 130, 499, "t:50,64"}),
|
||||
Values(false) // deterministic
|
||||
));
|
||||
TEST_P(ElementwiseBias, DataTypeConfig) { fmha_bwd_test(GetParam()); }
|
||||
// Fixture for cases run with the bias strings "a:0" and "a:1", crossed with
// several dimension sets and mask strings.
class Alibi : public TestWithParam<FmhaBwdTestParam>
|
||||
{
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
TestCkTileFmhaBwd,
|
||||
Alibi,
|
||||
Combine(ModeValues,
|
||||
HDimValues,
|
||||
Values(std::tuple{true, true}), // perm
|
||||
Values("a:0", "a:1"),
|
||||
Values(false), // use_dbias
|
||||
Values(0.0f), // p_drop
|
||||
Values(std::tuple{0, 0, false}), // seed/offset/prefs
|
||||
ValuesIn([]() {
|
||||
const std::array dims{
|
||||
std::tuple{1, 3, 3, 1024, 1000},
|
||||
std::tuple{3, 5, 5, 128, 256},
|
||||
std::tuple{2, 8, 4, 130, 320},
|
||||
};
|
||||
const std::array mask_strs{"0", "t", "b", "t:50,64", "b:32,40"};
|
||||
std::vector<FmhaBwdDimsMaskParam> dims_masks;
|
||||
std::for_each(dims.begin(), dims.end(), [&](const auto& d) {
|
||||
const auto& [b, h, hk, sq, sk] = d;
|
||||
std::for_each(mask_strs.begin(), mask_strs.end(), [&](const auto& m) {
|
||||
dims_masks.push_back(std::tuple{b, h, hk, sq, sk, m});
|
||||
});
|
||||
});
|
||||
return dims_masks;
|
||||
}()),
|
||||
Values(false) // deterministic
|
||||
));
|
||||
TEST_P(Alibi, DataTypeConfig) { fmha_bwd_test(GetParam()); }
|
||||
|
||||
// Fixture for cases run with non-zero p_drop (0.123 and 0.5) and two different
// (drop_seed, drop_offset, drop_prefs) triples.
class Dropout : public TestWithParam<FmhaBwdTestParam>
|
||||
{
|
||||
};
|
||||
INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaBwd,
|
||||
Dropout,
|
||||
Combine(ModeValues,
|
||||
HDimValues,
|
||||
Values(std::tuple{true, true}), // perm
|
||||
Values("n"), // bias_str
|
||||
Values(false), // use_dbias
|
||||
Values(0.123f, 0.5f), // p_drop
|
||||
Values(std::tuple{10, 123, false}, // seed/offset/prefs
|
||||
std::tuple{34534564645, 7876878876864, true}),
|
||||
Values(std::tuple{2, 6, 2, 180, 512, "0"},
|
||||
std::tuple{3, 2, 2, 256, 128, "1"},
|
||||
std::tuple{4, 2, 1, 100, 768, "2"}),
|
||||
Values(false) // deterministic
|
||||
));
|
||||
|
||||
TEST_P(Dropout, DataTypeConfig) { fmha_bwd_test(GetParam()); }
|
||||
|
||||
// Fixture for cases run with the trailing `deterministic` parameter set to true
// (every other fixture in this file passes false).
class Deterministic : public TestWithParam<FmhaBwdTestParam>
|
||||
{
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaBwd,
|
||||
Deterministic,
|
||||
Combine(ModeValues,
|
||||
HDimValues,
|
||||
Values(std::tuple{false, true}, std::tuple{true, true}), // perm
|
||||
Values("n"), // bias_str
|
||||
Values(false), // use_dbias
|
||||
Values(0.0f), // p_drop
|
||||
Values(std::tuple{0, 0, false}), // seed/offset/prefs
|
||||
Values(std::tuple{2, 6, 2, 180, 512, "0"},
|
||||
std::tuple{3, 3, 1, 256, 128, "1"},
|
||||
std::tuple{4, 2, 2, 768, 100, "2"}),
|
||||
Values(true) // deterministic
|
||||
));
|
||||
TEST_P(Deterministic, DataTypeConfig) { fmha_bwd_test(GetParam()); }
|
||||
@@ -1,347 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
using ::testing::Bool;
|
||||
using ::testing::Combine;
|
||||
using ::testing::TestWithParam;
|
||||
using ::testing::Values;
|
||||
using ::testing::ValuesIn;
|
||||
|
||||
// Random seed used for initializing input tensors. 0 for non-deterministic seed
|
||||
CK_TILE_DECLARE_ENV_VAR(CK_TILE_TEST_SEED, uint64_t, 123456)
|
||||
|
||||
// Whether to run long tests (from smoke_test_fwd.sh)
|
||||
CK_TILE_DECLARE_ENV_VAR_BOOL(CK_TILE_FMHA_LONG_TESTS)
|
||||
|
||||
#define CHECK_RESULT(result) \
|
||||
do \
|
||||
{ \
|
||||
if(result == bwd_result::no_instance) \
|
||||
GTEST_SKIP() << "No instance for current parameters"; \
|
||||
ASSERT_EQ(result, bwd_result::success); \
|
||||
} while(0)
|
||||
|
||||
const ck_tile::stream_config stream_config{
|
||||
nullptr, // stream_id_
|
||||
false, // time_kernel_
|
||||
1, // log_level_
|
||||
0, // cold_niters_
|
||||
1, // nrepeat_
|
||||
true, // is_gpu_timer_
|
||||
false, // flush_cache_
|
||||
1, // rotating_count_
|
||||
};
|
||||
|
||||
#define COMMON_ARGS \
|
||||
init_method, static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))), 1, \
|
||||
stream_config
|
||||
|
||||
auto EnableTestIf(bool condition)
|
||||
{
|
||||
return ValuesIn(condition ? std::vector<bool>{true} : std::vector<bool>{});
|
||||
}
|
||||
|
||||
class AllLong : public TestWithParam<std::tuple<bool,
|
||||
std::tuple<int, int>,
|
||||
bool,
|
||||
mode_enum,
|
||||
std::string,
|
||||
float,
|
||||
std::tuple<int, int, int, int, int, std::string>>>
|
||||
{
|
||||
};
|
||||
|
||||
GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(AllLong);
|
||||
|
||||
// Test cases from example/ck_tile/01_fmha/script/smoke_test_bwd.sh
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
TestCkTileFmhaBwd,
|
||||
AllLong,
|
||||
Combine(EnableTestIf(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_FMHA_LONG_TESTS))),
|
||||
HDimValues,
|
||||
Bool(),
|
||||
ModeValues,
|
||||
Values("n", "a"),
|
||||
Values(0.0f, 0.2f),
|
||||
Values(std::tuple{1, 4, 2, 259, -1, "0"},
|
||||
std::tuple{2, 2, -1, 516, 253, "0"},
|
||||
std::tuple{1, 4, 1, 500, 251, "1"},
|
||||
std::tuple{1, 2, -1, 900, 258, "2"},
|
||||
std::tuple{2, 1, -1, 987, 219, "t:128,30"},
|
||||
std::tuple{2, 3, 1, 244, 499, "b:4,35"})));
|
||||
|
||||
TEST_P(AllLong, Test)
|
||||
{
|
||||
auto [_, hdims, perm, mode, bias_str, p_drop, dims_mask] = GetParam();
|
||||
auto [hdim_q, hdim_v] = hdims;
|
||||
auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask;
|
||||
|
||||
auto result = fmha_bwd_run<DataTypeConfig>(mode,
|
||||
batch,
|
||||
nhead,
|
||||
nhead_k,
|
||||
{seqlen_q},
|
||||
{seqlen_k},
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
perm, // i_perm
|
||||
perm, // o_perm
|
||||
0, // scale
|
||||
bias_str, // bias_str
|
||||
false, // use_dbias
|
||||
p_drop, // p_drop
|
||||
123, // drop_seed
|
||||
1024, // drop_offset
|
||||
true, // drop_prefs
|
||||
mask_str, // mask_str
|
||||
false, // deterministic
|
||||
COMMON_ARGS);
|
||||
CHECK_RESULT(result);
|
||||
}
|
||||
|
||||
class HDimPadding
|
||||
: public TestWithParam<std::tuple<std::tuple<int, int>,
|
||||
bool,
|
||||
mode_enum,
|
||||
std::tuple<int, int, int, int, int, std::string>>>
|
||||
{
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaBwd,
|
||||
HDimPadding,
|
||||
Combine(Values(std::tuple{24, 48},
|
||||
std::tuple{48, 48},
|
||||
std::tuple{72, 72},
|
||||
std::tuple{96, 96},
|
||||
std::tuple{120, 160},
|
||||
std::tuple{256, 108},
|
||||
std::tuple{40, 64}),
|
||||
Bool(),
|
||||
ModeValues,
|
||||
Values(std::tuple{1, 4, 2, 480, -1, "0"},
|
||||
std::tuple{2, 2, -1, 300, 400, "t:64,64"},
|
||||
std::tuple{1, 4, 1, 512, 201, "1"},
|
||||
std::tuple{1, 2, -1, 900, 256, "0"},
|
||||
std::tuple{2, 1, -1, 256, 256, "1"})));
|
||||
|
||||
TEST_P(HDimPadding, Test)
|
||||
{
|
||||
auto [hdims, perm, mode, dims_mask] = GetParam();
|
||||
auto [hdim_q, hdim_v] = hdims;
|
||||
auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask;
|
||||
|
||||
auto result = fmha_bwd_run<DataTypeConfig>(mode,
|
||||
batch,
|
||||
nhead,
|
||||
nhead_k,
|
||||
{seqlen_q},
|
||||
{seqlen_k},
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
perm, // i_perm
|
||||
perm, // o_perm
|
||||
0, // scale
|
||||
"n", // bias_str
|
||||
false, // use_dbias
|
||||
0.0f, // p_drop
|
||||
0, // drop_seed
|
||||
0, // drop_offset
|
||||
false, // drop_prefs
|
||||
mask_str, // mask_str
|
||||
false, // deterministic
|
||||
COMMON_ARGS);
|
||||
CHECK_RESULT(result);
|
||||
}
|
||||
|
||||
class ElementwiseBias
|
||||
: public TestWithParam<std::tuple<std::tuple<int, int>,
|
||||
bool,
|
||||
mode_enum,
|
||||
std::string,
|
||||
bool,
|
||||
std::tuple<int, int, int, int, int, std::string>>>
|
||||
{
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaBwd,
|
||||
ElementwiseBias,
|
||||
Combine(HDimValues,
|
||||
Bool(), // layouts of bias and dbias are controlled by i_perm
|
||||
ModeValues,
|
||||
Values("e:0", "e:1", "e:2"),
|
||||
Bool(),
|
||||
Values(std::tuple{1, 4, 2, 1024, 100, "0"},
|
||||
std::tuple{3, 2, -1, 128, 256, "2"},
|
||||
std::tuple{2, 2, -1, 130, 499, "t:50,64"})));
|
||||
|
||||
TEST_P(ElementwiseBias, Test)
|
||||
{
|
||||
auto [hdims, i_perm, mode, bias_str, use_dbias, dims_mask] = GetParam();
|
||||
auto [hdim_q, hdim_v] = hdims;
|
||||
auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask;
|
||||
|
||||
auto result = fmha_bwd_run<DataTypeConfig>(mode,
|
||||
batch,
|
||||
nhead,
|
||||
nhead_k,
|
||||
{seqlen_q},
|
||||
{seqlen_k},
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
i_perm, // i_perm
|
||||
false, // o_perm
|
||||
0, // scale
|
||||
bias_str, // bias_str
|
||||
use_dbias, // use_dbias
|
||||
0.0f, // p_drop
|
||||
123, // drop_seed
|
||||
1024, // drop_offset
|
||||
true, // drop_prefs
|
||||
mask_str, // mask_str
|
||||
false, // deterministic
|
||||
COMMON_ARGS);
|
||||
CHECK_RESULT(result);
|
||||
}
|
||||
|
||||
class Alibi : public TestWithParam<std::tuple<std::tuple<int, int>,
|
||||
mode_enum,
|
||||
std::string,
|
||||
std::tuple<int, int, int, int, int>,
|
||||
std::string>>
|
||||
{
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaBwd,
|
||||
Alibi,
|
||||
Combine(HDimValues,
|
||||
ModeValues,
|
||||
Values("a:0", "a:1"),
|
||||
Values(std::tuple{1, 3, 3, 1024, 1000},
|
||||
std::tuple{3, 5, 5, 128, 256},
|
||||
std::tuple{2, 8, 4, 130, 320}),
|
||||
Values("0", "t", "b", "t:50,64", "b:32,40")));
|
||||
|
||||
TEST_P(Alibi, Test)
|
||||
{
|
||||
auto [hdims, mode, bias_str, dims, mask_str] = GetParam();
|
||||
auto [hdim_q, hdim_v] = hdims;
|
||||
auto [batch, nhead, nhead_k, seqlen_q, seqlen_k] = dims;
|
||||
|
||||
auto result = fmha_bwd_run<DataTypeConfig>(mode,
|
||||
batch,
|
||||
nhead,
|
||||
nhead_k,
|
||||
{seqlen_q},
|
||||
{seqlen_k},
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
true, // i_perm
|
||||
true, // o_perm
|
||||
0, // scale
|
||||
bias_str, // bias_str
|
||||
false, // use_dbias
|
||||
0.0f, // p_drop
|
||||
0, // drop_seed
|
||||
0, // drop_offset
|
||||
false, // drop_prefs
|
||||
mask_str, // mask_str
|
||||
false, // deterministic
|
||||
COMMON_ARGS);
|
||||
CHECK_RESULT(result);
|
||||
}
|
||||
|
||||
class Dropout : public TestWithParam<std::tuple<std::tuple<int, int>,
|
||||
mode_enum,
|
||||
float,
|
||||
std::tuple<uint64_t, uint64_t, bool>,
|
||||
std::tuple<int, int, int, int, int, std::string>>>
|
||||
{
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaBwd,
|
||||
Dropout,
|
||||
Combine(HDimValues,
|
||||
ModeValues,
|
||||
Values(0.123f, 0.5f),
|
||||
Values(std::tuple{10, 123, false},
|
||||
std::tuple{34534564645, 7876878876864, true}),
|
||||
Values(std::tuple{2, 6, 2, 180, 512, "0"},
|
||||
std::tuple{3, 2, 2, 256, 128, "1"},
|
||||
std::tuple{4, 2, 1, 100, 768, "2"})));
|
||||
|
||||
TEST_P(Dropout, Test)
|
||||
{
|
||||
auto [hdims, mode, p_drop, drop_seed_offset_prefs, dims_mask] = GetParam();
|
||||
auto [hdim_q, hdim_v] = hdims;
|
||||
auto [drop_seed, drop_offset, drop_prefs] = drop_seed_offset_prefs;
|
||||
auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask;
|
||||
|
||||
auto result = fmha_bwd_run<DataTypeConfig>(mode,
|
||||
batch,
|
||||
nhead,
|
||||
nhead_k,
|
||||
{seqlen_q},
|
||||
{seqlen_k},
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
true, // i_perm
|
||||
true, // o_perm
|
||||
0.1f, // scale
|
||||
"n", // bias_str
|
||||
false, // use_dbias
|
||||
p_drop, // p_drop
|
||||
drop_seed, // drop_seed
|
||||
drop_offset, // drop_offset
|
||||
drop_prefs, // drop_prefs
|
||||
mask_str, // mask_str
|
||||
false, // deterministic
|
||||
COMMON_ARGS);
|
||||
CHECK_RESULT(result);
|
||||
}
|
||||
|
||||
class Deterministic
|
||||
: public TestWithParam<std::tuple<std::tuple<int, int>,
|
||||
bool,
|
||||
mode_enum,
|
||||
std::tuple<int, int, int, int, int, std::string>>>
|
||||
{
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaBwd,
|
||||
Deterministic,
|
||||
Combine(HDimValues,
|
||||
Bool(),
|
||||
ModeValues,
|
||||
Values(std::tuple{2, 6, 2, 180, 512, "0"},
|
||||
std::tuple{3, 3, 1, 256, 128, "1"},
|
||||
std::tuple{4, 2, 2, 768, 100, "2"})));
|
||||
|
||||
TEST_P(Deterministic, Test)
|
||||
{
|
||||
auto [hdims, i_perm, mode, dims_mask] = GetParam();
|
||||
auto [hdim_q, hdim_v] = hdims;
|
||||
auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask;
|
||||
|
||||
auto result = fmha_bwd_run<DataTypeConfig>(mode,
|
||||
batch,
|
||||
nhead,
|
||||
nhead_k,
|
||||
{seqlen_q},
|
||||
{seqlen_k},
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
i_perm, // i_perm
|
||||
true, // o_perm
|
||||
0, // scale
|
||||
"n", // bias_str
|
||||
false, // use_dbias
|
||||
0.0f, // p_drop
|
||||
0, // drop_seed
|
||||
0, // drop_offset
|
||||
false, // drop_prefs
|
||||
mask_str, // mask_str
|
||||
true, // deterministic
|
||||
COMMON_ARGS);
|
||||
CHECK_RESULT(result);
|
||||
}
|
||||
@@ -1,21 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "example/ck_tile/01_fmha/fmha_bwd.hpp"
|
||||
#include "example/ck_tile/01_fmha/fmha_bwd_runner.hpp"
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
using DataTypeConfig = FmhaBwdBf16;
|
||||
|
||||
using ::testing::Values;
|
||||
using ::testing::ValuesIn;
|
||||
|
||||
const auto HDimValues =
|
||||
Values(std::tuple{32, -1}, std::tuple{64, -1}, std::tuple{128, -1}, std::tuple{256, -1});
|
||||
|
||||
const auto ModeValues = Values(mode_enum::batch, mode_enum::group);
|
||||
|
||||
constexpr auto init_method = "uf";
|
||||
|
||||
#include "test_fmha_bwd.inc"
|
||||
@@ -1,21 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "example/ck_tile/01_fmha/fmha_bwd.hpp"
|
||||
#include "example/ck_tile/01_fmha/fmha_bwd_runner.hpp"
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
using DataTypeConfig = FmhaBwdFp16;
|
||||
|
||||
using ::testing::Values;
|
||||
using ::testing::ValuesIn;
|
||||
|
||||
const auto HDimValues =
|
||||
Values(std::tuple{32, -1}, std::tuple{64, -1}, std::tuple{128, -1}, std::tuple{256, -1});
|
||||
|
||||
const auto ModeValues = Values(mode_enum::batch, mode_enum::group);
|
||||
|
||||
constexpr auto init_method = "uf";
|
||||
|
||||
#include "test_fmha_bwd.inc"
|
||||
@@ -1,20 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "example/ck_tile/01_fmha/fmha_bwd.hpp"
|
||||
#include "example/ck_tile/01_fmha/fmha_bwd_runner.hpp"
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
using DataTypeConfig = FmhaBwdFp32;
|
||||
|
||||
using ::testing::Values;
|
||||
using ::testing::ValuesIn;
|
||||
|
||||
const auto HDimValues = Values(std::tuple{32, -1}, std::tuple{64, -1}, std::tuple{128, -1});
|
||||
|
||||
const auto ModeValues = Values(mode_enum::batch, mode_enum::group);
|
||||
|
||||
const std::string init_method = "uf";
|
||||
|
||||
#include "test_fmha_bwd.inc"
|
||||
@@ -1,12 +1,104 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "example/ck_tile/01_fmha/fmha_fwd.hpp"
|
||||
#include "example/ck_tile/01_fmha/fmha_fwd_runner.hpp"
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#ifndef DataTypeConfig
|
||||
#define DataTypeConfig FmhaFwdFp16 // or FmhaFwdBf16 / FmhaFwdFp8 / FmhaFwdFp32
|
||||
#endif
|
||||
|
||||
using ::testing::Bool;
|
||||
using ::testing::Combine;
|
||||
using ::testing::TestWithParam;
|
||||
using ::testing::Values;
|
||||
using ::testing::ValuesIn;
|
||||
|
||||
// Per-data-type test configuration for the forward tests, selected through
// TestConfigs<DataTypeConfig>. Primary template: used for every config type that
// has no explicit specialization.
//
// The *Values arrays are wrapped by the file-scope generators of the same names;
// the bool constants and adjust_seqlen are re-exported at file scope as well.
// NOTE(review): the meaning of -1 as the second tuple element is not visible
// here (presumably "hdim_v same as hdim_q") - confirm in fmha_fwd_run.
template <typename T>
|
||||
struct TestConfigs
|
||||
{
|
||||
// Head-dimension pairs for the general test suites.
static constexpr auto HDimValues = std::array{
|
||||
std::tuple{32, -1},
|
||||
std::tuple{64, -1},
|
||||
std::tuple{96, 128},
|
||||
std::tuple{128, -1},
|
||||
std::tuple{192, 128},
|
||||
std::tuple{192, -1},
|
||||
std::tuple{256, -1},
|
||||
};
|
||||
// Head-dimension pairs exposed as SplitKVHDimValues.
static constexpr auto SplitKVHDimValues = std::array{
|
||||
std::tuple{32, -1},
|
||||
std::tuple{64, -1},
|
||||
std::tuple{96, -1},
|
||||
std::tuple{128, -1},
|
||||
std::tuple{256, -1},
|
||||
};
|
||||
// Head-dimension pairs exposed as AppendKVHDimValues.
static constexpr auto AppendKVHDimValues = std::array{
|
||||
std::tuple{32, -1}, std::tuple{64, -1}, std::tuple{128, -1}, std::tuple{256, -1}};
|
||||
static constexpr auto ModeValues = std::array{mode_enum::batch, mode_enum::group};
|
||||
static constexpr auto IsVRowmajorValues = std::array{false, true};
|
||||
// Defaults re-exported at file scope under the same names; where they are
// consumed is outside this view.
static constexpr bool squant = false;
|
||||
static constexpr bool def_lse = true;
|
||||
static constexpr bool def_is_v_rowmajor = true;
|
||||
// Identity: the primary template leaves sequence lengths unchanged.
static int adjust_seqlen(int seqlen) { return seqlen; }
|
||||
};
|
||||
// Specialization for FmhaFwdFp8: only head dims 64 and 128, batch mode only,
// IsVRowmajorValues limited to false, squant enabled, def_lse disabled, and
// sequence lengths rounded up to a multiple of 128.
template <>
|
||||
struct TestConfigs<FmhaFwdFp8>
|
||||
{
|
||||
// Currently there are no fp8 instances for splitkv, pagedkv by default (the tests pass if such
|
||||
// instances are added), however the corresponding tests are not disabled (they will be skipped)
|
||||
// in case such instances will be added in the future.
|
||||
|
||||
static constexpr auto HDimValues = std::array{std::tuple{64, -1}, std::tuple{128, -1}};
|
||||
static constexpr auto SplitKVHDimValues = std::array{std::tuple{64, -1}, std::tuple{128, -1}};
|
||||
static constexpr auto AppendKVHDimValues = std::array{std::tuple{64, -1}, std::tuple{128, -1}};
|
||||
// There are no fp8 instances with seqlen padding (mode_enum::group requires it)
|
||||
static constexpr auto ModeValues = std::array{mode_enum::batch};
|
||||
static constexpr auto IsVRowmajorValues = std::array{false};
|
||||
static constexpr bool squant = true;
|
||||
static constexpr bool def_lse = false;
|
||||
static constexpr bool def_is_v_rowmajor = true;
|
||||
// Rounds seqlen up via ck_tile::integer_least_multiple(seqlen, 128).
static int adjust_seqlen(int seqlen)
|
||||
{
|
||||
// There are no fp8 instances with padding, pad seqlen to avoid skipping most of the tests
|
||||
return ck_tile::integer_least_multiple(seqlen, 128);
|
||||
}
|
||||
};
|
||||
// Specialization for FmhaFwdFp32: its own head-dimension list (adds 48, drops
// the (192, 128) pair), empty SplitKV / AppendKV lists, and IsVRowmajorValues
// limited to true.
// NOTE(review): the empty lists presumably mean there are no fp32 splitkv /
// appendkv instances - confirm against the instance list.
template <>
|
||||
struct TestConfigs<FmhaFwdFp32>
|
||||
{
|
||||
static constexpr auto HDimValues = std::array{
|
||||
std::tuple{32, -1},
|
||||
std::tuple{48, -1},
|
||||
std::tuple{64, -1},
|
||||
std::tuple{96, 128},
|
||||
std::tuple{128, -1},
|
||||
std::tuple{192, -1},
|
||||
std::tuple{256, -1},
|
||||
};
|
||||
// Zero-length arrays: the generators built from them produce no values.
static constexpr auto SplitKVHDimValues = std::array<std::tuple<int, int>, 0>{};
|
||||
static constexpr auto AppendKVHDimValues = std::array<std::tuple<int, int>, 0>{};
|
||||
static constexpr auto ModeValues = std::array{mode_enum::batch, mode_enum::group};
|
||||
static constexpr auto IsVRowmajorValues = std::array{true};
|
||||
static constexpr bool squant = false;
|
||||
static constexpr bool def_lse = true;
|
||||
static constexpr bool def_is_v_rowmajor = true;
|
||||
// Identity: sequence lengths are left unchanged for fp32.
static int adjust_seqlen(int seqlen) { return seqlen; }
|
||||
};
|
||||
|
||||
static auto HDimValues = ValuesIn(TestConfigs<DataTypeConfig>::HDimValues);
|
||||
static auto SplitKVHDimValues = ValuesIn(TestConfigs<DataTypeConfig>::SplitKVHDimValues);
|
||||
static auto AppendKVHDimValues = ValuesIn(TestConfigs<DataTypeConfig>::AppendKVHDimValues);
|
||||
static auto ModeValues = ValuesIn(TestConfigs<DataTypeConfig>::ModeValues);
|
||||
static auto IsVRowmajorValues = ValuesIn(TestConfigs<DataTypeConfig>::IsVRowmajorValues);
|
||||
constexpr bool squant = TestConfigs<DataTypeConfig>::squant;
|
||||
constexpr bool def_lse = TestConfigs<DataTypeConfig>::def_lse;
|
||||
constexpr bool def_is_v_rowmajor = TestConfigs<DataTypeConfig>::def_is_v_rowmajor;
|
||||
// File-scope shorthand that forwards to TestConfigs<DataTypeConfig>::adjust_seqlen.
int adjust_seqlen(int seqlen) { return TestConfigs<DataTypeConfig>::adjust_seqlen(seqlen); }
|
||||
constexpr auto init_method = "uf";
|
||||
|
||||
// Random seed used for initializing input tensors. 0 for non-deterministic seed
|
||||
CK_TILE_DECLARE_ENV_VAR(CK_TILE_TEST_SEED, uint64_t, 123456)
|
||||
|
||||
@@ -79,7 +171,7 @@ INSTANTIATE_TEST_SUITE_P(
|
||||
std::tuple{1, 2, 1, -1, -1, 33, 0, -1, "2"},
|
||||
std::tuple{1, 2, 1, -1, -1, 1, 10, 32, "2"})));
|
||||
|
||||
TEST_P(AllLong, Test)
|
||||
TEST_P(AllLong, DataTypeConfig)
|
||||
{
|
||||
auto [_, hdims, perm, is_v_rowmajor, mode, lse, bias_str, p_drop, dims_mask] = GetParam();
|
||||
auto [hdim_q, hdim_v] = hdims;
|
||||
@@ -283,7 +375,7 @@ INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaFwd,
|
||||
std::tuple{1, 2, -1, 900, 256, -1, "0"},
|
||||
std::tuple{2, 1, -1, 256, 256, -1, "1"})));
|
||||
|
||||
TEST_P(HDimPadding, Test)
|
||||
TEST_P(HDimPadding, DataTypeConfig)
|
||||
{
|
||||
auto [hdims, perm, is_v_rowmajor, mode, dims_mask] = GetParam();
|
||||
auto [hdim_q, hdim_v] = hdims;
|
||||
@@ -343,7 +435,7 @@ INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaFwd,
|
||||
std::tuple{3, 2, -1, 128, 256, "2"},
|
||||
std::tuple{2, 2, -1, 130, 499, "t:50,64"})));
|
||||
|
||||
TEST_P(ElementwiseBias, Test)
|
||||
TEST_P(ElementwiseBias, DataTypeConfig)
|
||||
{
|
||||
auto [hdims, i_perm, mode, bias_str, dims_mask] = GetParam();
|
||||
auto [hdim_q, hdim_v] = hdims;
|
||||
@@ -402,7 +494,7 @@ INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaFwd,
|
||||
std::tuple{2, 8, 2, 300, 355}),
|
||||
Values("0", "t", "b", "t:50,64", "b:32,40")));
|
||||
|
||||
TEST_P(Alibi, Test)
|
||||
TEST_P(Alibi, DataTypeConfig)
|
||||
{
|
||||
auto [hdims, mode, bias_str, dims, mask_str] = GetParam();
|
||||
auto [hdim_q, hdim_v] = hdims;
|
||||
@@ -462,7 +554,7 @@ INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaFwd,
|
||||
std::tuple{3, 2, 2, 256, 128, "1"},
|
||||
std::tuple{4, 3, 1, 100, 768, "2"})));
|
||||
|
||||
TEST_P(Dropout, Test)
|
||||
TEST_P(Dropout, DataTypeConfig)
|
||||
{
|
||||
auto [hdims, mode, p_drop, drop_seed_offset_prefs, dims_mask] = GetParam();
|
||||
auto [hdim_q, hdim_v] = hdims;
|
||||
@@ -528,7 +620,7 @@ INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaFwd,
|
||||
std::tuple{3, 2, -1, 128, 768, "2"},
|
||||
std::tuple{2, 2, -1, 230, 899, "t:50,64"})));
|
||||
|
||||
TEST_P(PagedKV, Test)
|
||||
TEST_P(PagedKV, DataTypeConfig)
|
||||
{
|
||||
auto [hdims, i_perm, is_v_rowmajor, mode, page_block_size, dims_mask] = GetParam();
|
||||
auto [hdim_q, hdim_v] = hdims;
|
||||
@@ -597,7 +689,7 @@ INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaFwd,
|
||||
std::tuple{2, 2, -1, 512, 2000, "0"},
|
||||
std::tuple{3, 2, -1, 230, 899, "t:128,128"})));
|
||||
|
||||
TEST_P(SplitKV, Test)
|
||||
TEST_P(SplitKV, DataTypeConfig)
|
||||
{
|
||||
auto [hdims, i_perm, is_v_rowmajor, mode_use_cache_batch_idx, num_splits, dims_mask] =
|
||||
GetParam();
|
||||
@@ -668,7 +760,7 @@ INSTANTIATE_TEST_SUITE_P(
|
||||
|
||||
GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(AppendKV);
|
||||
|
||||
TEST_P(AppendKV, Test)
|
||||
TEST_P(AppendKV, DataTypeConfig)
|
||||
{
|
||||
auto [hdims,
|
||||
i_perm,
|
||||
@@ -745,7 +837,7 @@ INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaFwd,
|
||||
std::tuple{1, 2, 1, 128, 55, "0"},
|
||||
std::tuple{3, 4, 2, 72, 128, "1"})));
|
||||
|
||||
TEST_P(AppendKVRoPE, Test)
|
||||
TEST_P(AppendKVRoPE, DataTypeConfig)
|
||||
{
|
||||
auto [_, hdims, i_perm, is_v_rowmajor, rotary, seqlen_knew, dims_mask] = GetParam();
|
||||
auto [hdim_q, hdim_v] = hdims;
|
||||
@@ -1017,7 +1109,7 @@ static const std::vector<PaddingParam> kPaddingParams = BuildPaddingParams();
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaFwd_Padding, PaddingCases, ValuesIn(kPaddingParams));
|
||||
|
||||
TEST_P(PaddingCases, Test)
|
||||
TEST_P(PaddingCases, DataTypeConfig)
|
||||
{
|
||||
if constexpr(std::is_same_v<DataTypeConfig, FmhaFwdFp8>)
|
||||
{
|
||||
@@ -1,44 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "example/ck_tile/01_fmha/fmha_fwd.hpp"
|
||||
#include "example/ck_tile/01_fmha/fmha_fwd_runner.hpp"
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include <tuple>
|
||||
#include <string>
|
||||
|
||||
using ::testing::Values;
|
||||
|
||||
using DataTypeConfig = FmhaFwdBf16;
|
||||
|
||||
const auto HDimValues = Values(std::tuple{32, -1},
|
||||
std::tuple{64, -1},
|
||||
std::tuple{96, 128},
|
||||
std::tuple{128, -1},
|
||||
std::tuple{192, 128},
|
||||
std::tuple{192, -1},
|
||||
std::tuple{256, -1});
|
||||
|
||||
const auto SplitKVHDimValues = Values(std::tuple{32, -1},
|
||||
std::tuple{64, -1},
|
||||
std::tuple{96, -1},
|
||||
std::tuple{128, -1},
|
||||
std::tuple{256, -1});
|
||||
|
||||
const auto AppendKVHDimValues =
|
||||
Values(std::tuple{32, -1}, std::tuple{64, -1}, std::tuple{128, -1}, std::tuple{256, -1});
|
||||
|
||||
const auto ModeValues = Values(mode_enum::batch, mode_enum::group);
|
||||
|
||||
const auto IsVRowmajorValues = Values(false, true);
|
||||
|
||||
const bool squant = false;
|
||||
const std::string init_method = "uf";
|
||||
const bool def_lse = true;
|
||||
const bool def_is_v_rowmajor = true;
|
||||
|
||||
int adjust_seqlen(int seqlen) { return seqlen; }
|
||||
|
||||
#include "test_fmha_fwd.inc"
|
||||
@@ -1,44 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "example/ck_tile/01_fmha/fmha_fwd.hpp"
|
||||
#include "example/ck_tile/01_fmha/fmha_fwd_runner.hpp"
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include <tuple>
|
||||
#include <string>
|
||||
|
||||
using ::testing::Values;
|
||||
|
||||
using DataTypeConfig = FmhaFwdFp16;
|
||||
|
||||
const auto HDimValues = Values(std::tuple{32, -1},
|
||||
std::tuple{64, -1},
|
||||
std::tuple{96, 128},
|
||||
std::tuple{128, -1},
|
||||
std::tuple{192, 128},
|
||||
std::tuple{192, -1},
|
||||
std::tuple{256, -1});
|
||||
|
||||
const auto SplitKVHDimValues = Values(std::tuple{32, -1},
|
||||
std::tuple{64, -1},
|
||||
std::tuple{96, -1},
|
||||
std::tuple{128, -1},
|
||||
std::tuple{256, -1});
|
||||
|
||||
const auto AppendKVHDimValues =
|
||||
Values(std::tuple{32, -1}, std::tuple{64, -1}, std::tuple{128, -1}, std::tuple{256, -1});
|
||||
|
||||
const auto ModeValues = Values(mode_enum::batch, mode_enum::group);
|
||||
|
||||
const auto IsVRowmajorValues = Values(false, true);
|
||||
|
||||
const bool squant = false;
|
||||
const std::string init_method = "uf";
|
||||
const bool def_lse = true;
|
||||
const bool def_is_v_rowmajor = true;
|
||||
|
||||
int adjust_seqlen(int seqlen) { return seqlen; }
|
||||
|
||||
#include "test_fmha_fwd.inc"
|
||||
@@ -1,39 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "example/ck_tile/01_fmha/fmha_fwd.hpp"
|
||||
#include "example/ck_tile/01_fmha/fmha_fwd_runner.hpp"
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include <tuple>
|
||||
#include <string>
|
||||
|
||||
using ::testing::Values;
|
||||
|
||||
using DataTypeConfig = FmhaFwdFp32;
|
||||
|
||||
const auto HDimValues = Values(std::tuple{32, -1},
|
||||
std::tuple{48, -1},
|
||||
std::tuple{64, -1},
|
||||
std::tuple{96, 128},
|
||||
std::tuple{128, -1},
|
||||
std::tuple{192, -1},
|
||||
std::tuple{256, -1});
|
||||
|
||||
const auto SplitKVHDimValues = Values();
|
||||
|
||||
const auto AppendKVHDimValues = Values();
|
||||
|
||||
const auto ModeValues = Values(mode_enum::batch, mode_enum::group);
|
||||
|
||||
const auto IsVRowmajorValues = Values(true);
|
||||
|
||||
const bool squant = false;
|
||||
const std::string init_method = "uf";
|
||||
const bool def_lse = true;
|
||||
const bool def_is_v_rowmajor = true;
|
||||
|
||||
int adjust_seqlen(int seqlen) { return seqlen; }
|
||||
|
||||
#include "test_fmha_fwd.inc"
|
||||
@@ -1,42 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "example/ck_tile/01_fmha/fmha_fwd.hpp"
|
||||
#include "example/ck_tile/01_fmha/fmha_fwd_runner.hpp"
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include <tuple>
|
||||
#include <string>
|
||||
|
||||
using ::testing::Values;
|
||||
|
||||
using DataTypeConfig = FmhaFwdFp8;
|
||||
|
||||
// Currently there are no fp8 instances for splitkv, pagedkv by default (the tests pass if such
|
||||
// instances are added), however the corresponding tests are not disabled (they will be skipped)
|
||||
// in case such instances will be added in the future.
|
||||
|
||||
const auto HDimValues = Values(std::tuple{64, -1}, std::tuple{128, -1});
|
||||
|
||||
const auto SplitKVHDimValues = Values(std::tuple{64, -1}, std::tuple{128, -1});
|
||||
|
||||
const auto AppendKVHDimValues = Values(std::tuple{64, -1}, std::tuple{128, -1});
|
||||
|
||||
// There are no fp8 instances with seqlen padding (mode_enum::group requires it)
|
||||
const auto ModeValues = Values(mode_enum::batch);
|
||||
|
||||
const auto IsVRowmajorValues = Values(false);
|
||||
|
||||
const auto squant = true;
|
||||
const std::string init_method = "uf";
|
||||
const bool def_lse = false;
|
||||
const bool def_is_v_rowmajor = true;
|
||||
|
||||
int adjust_seqlen(int seqlen)
|
||||
{
|
||||
// There are no fp8 instances with padding, pad seqlen to avoid skipping most of the tests
|
||||
return ck_tile::integer_least_multiple(seqlen, 128);
|
||||
}
|
||||
|
||||
#include "test_fmha_fwd.inc"
|
||||
Reference in New Issue
Block a user