mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 11:30:02 +00:00
[CK_TILE] Add gtests for FMHA (#2744)
* Improve random number generation
* use different seed for each input (Q, K, V...);
* use deterministic generation of:
* seqstart_q/k (for group mode);
* block_table (for paged-kvcache);
* cache_batch_idx (for kvcache);
* Extract arg_parser-related code from run functions to use them as tests
* Split examples into main programs and fmha runners, build instances separately
* Add dummy tests that use instances and runners
* Fix a missed corner case of f32->f8 conversion
When a value is < min f8 denormal but > min f8 denormal / 2, it must be
rounded to the min f8 denormal (i.e. 0b1), not to 0.
* Fix incorrect fp8 scales for P and O in validation code
DataTypeConfig was incorrectly compared with fp8_t.
* Add host generation of dropout random values and use it for validation
Previously host validation (reference_batched_dropout) used random
numbers generated by BlockDropout of the kernel, meaning that incorrect
generation on device (bad distribution, repeated numbers, too many zeros,
etc.) would not trigger any validation errors.
* Implement tests from smoke_test_bwd.sh
* Return result as enum to distinguish failure and missing instance
* Add tests for bwd features: bias, alibi, dropout
* Implement tests from smoke_test_fwd.sh
* Pass seqlen_q/k as vectors to fwd and bwd runners
* Add tests for fwd features: bias, alibi, dropout
* Add tests for pagedkv and splitkv
* Fix conditions when to use splitkv and pagedkv kernels
splitkv was executed only when use_kvcache, which equals (need_append_kvcache || use_cache_batch_idx || 0 < page_block_size).
In the SplitKV tests: the regular fwd kernel was executed if use_cache_batch_idx was not requested even when num_splitkv > 1.
In the AppendKV tests: the pagedkv kernel was executed but it often failed to find an instance.
* Add tests for appendkv
* Use is_v_rowmajor = true because there are no instances with column layout anymore
* Split public and private compile options for instances
Tests and examples need to know only about CK_TILE_FMHA_FWD_*_API.
* Improve parsing validation in bias and mask
* Pass bias as string for consistency with mask
* Catch parsing and other exceptions
* Add bwd test for deterministic flag
* Initialize fp8 tensors (-init=ufq) similarly to uf
* Fix splitkv/pagedkv invocation: use padded sk when seqlen_k_ptr is not null
seqlen_k cannot be used to determine padding when seqlen_k_ptr is
provided. The actual seqlen_k is taken from seqlen_k_ptr[b].
Even seqlen_k values (% bn0 == 0) use padded seqlen_k while seqlen_k_ptr
may contain arbitrary values.
In the example or tests this produces incorrect results with appendkv
(for example, -d=32 -s=1 -s_k=64 -s_knew=7 -vlayout=c -b=8).
* Fix use_pagedkv value when kvcache = true but page_block_size = 0
In this case block_table_ptr is nullptr which is accessed in the kernel.
* Clean up bwd tests
* Unify fwd tests for f16/bf16 and fp8
* Use better explicit instantiation declaration for fmha_bwd<2>
* Use the same seed for all tests, allow overriding it with an env variable
* Undo clang-format of one irrelevant file
For some reason my local clang-format-18 and the one in CI work differently.
* Do not build instances and tests on unsupported archs
* Build instance libraries as OBJECT library
* CI: Enable sccache for HIP
There are source files with LANGUAGE HIP, they need
-DCMAKE_HIP_COMPILER_LAUNCHER=sccache
* Add tests to REGRESSION_TESTS
* Fix OOB accesses in deterministic bwd due to incorrectly assumed kN0
The runner assumes kN0 = (hdim_q <= 128) ? 128 : 64 but there are
smaller tiles (for tr_load or fp32). This can create too small dq_acc_buf.
* Pass CK_TILE_FMHA_FWD_*_API as INTERFACE compile options
The instances don't actually depend on them, only examples and tests do.
Passing these definitions as INTERFACE allows changing FMHA_FWD_ENABLE_APIS
without recompiling instances that are already in ccache.
* Fix formatting and names
[ROCm/composable_kernel commit: ec006bb8e0]
@@ -38,6 +38,11 @@ set(REGRESSION_TESTS
     test_conv_tensor_rearrange
     test_gemm_mx
     test_ck_tile_batched_transpose
+    test_ck_tile_fmha_bwd_bf16
+    test_ck_tile_fmha_bwd_fp16
+    test_ck_tile_fmha_fwd_bf16
+    test_ck_tile_fmha_fwd_fp16
+    test_ck_tile_fmha_fwd_fp8
 )

 function(add_test_executable TEST_NAME)
@@ -25,3 +25,4 @@ add_subdirectory(utility)
 add_subdirectory(reduce)
 add_subdirectory(epilogue)
 add_subdirectory(atomic_add_op)
+add_subdirectory(fmha)
@@ -94,7 +94,9 @@ TYPED_TEST(ConvertTest, ToFp8)
     EXPECT_EQ(c(+ck_tile::numeric<SrcT>::denorm_min()), 0b0'0000'000);
     EXPECT_EQ(c(-ck_tile::numeric<SrcT>::denorm_min()), 0b1'0000'000);

-    // All values smaller than min f8 subnormal must be converted to f8 zero
+    // All values <= min f8 subnormal/2 must be converted to f8 zero
+    EXPECT_EQ(c(+0.001953125f * 0.6f), 0b0'0000'001);
+    EXPECT_EQ(c(-0.001953125f * 0.6f), 0b1'0000'001);
     constexpr int src_min_subnorm_exp =
         -(ck_tile::numeric_traits<SrcT>::bias + ck_tile::numeric_traits<SrcT>::mant - 1);
     constexpr int dst_min_subnorm_exp =
@@ -176,7 +178,9 @@ TYPED_TEST(ConvertTest, ToFp8)
     EXPECT_EQ(c(+ck_tile::numeric<SrcT>::denorm_min()), 0b0'0000'000);
     EXPECT_EQ(c(-ck_tile::numeric<SrcT>::denorm_min()), 0b0'0000'000);

-    // All values smaller than min f8 subnormal must be converted to f8 zero
+    // All values <= min f8 subnormal/2 must be converted to f8 zero
+    EXPECT_EQ(c(+0.0009765625f * 0.6f), 0b0'0000'001);
+    EXPECT_EQ(c(-0.0009765625f * 0.6f), 0b1'0000'001);
     constexpr int src_min_subnorm_exp =
         -(ck_tile::numeric_traits<SrcT>::bias + ck_tile::numeric_traits<SrcT>::mant - 1);
     constexpr int dst_min_subnorm_exp =
@@ -282,7 +286,9 @@ TYPED_TEST(ConvertTest, ToBf8)
     EXPECT_EQ(c(+ck_tile::numeric<SrcT>::denorm_min()), 0b0'00000'00);
     EXPECT_EQ(c(-ck_tile::numeric<SrcT>::denorm_min()), 0b1'00000'00);

-    // All values smaller than min f8 subnormal must be converted to f8 zero
+    // All values <= min f8 subnormal/2 must be converted to f8 zero
+    EXPECT_EQ(c(+1.52587890625e-05f * 0.6f), 0b0'0000'001);
+    EXPECT_EQ(c(-1.52587890625e-05f * 0.6f), 0b1'0000'001);
     constexpr int src_min_subnorm_exp =
         -(ck_tile::numeric_traits<SrcT>::bias + ck_tile::numeric_traits<SrcT>::mant - 1);
     constexpr int dst_min_subnorm_exp =
@@ -373,7 +379,9 @@ TYPED_TEST(ConvertTest, ToBf8)
     EXPECT_EQ(c(+ck_tile::numeric<SrcT>::denorm_min()), 0b0'00000'00);
     EXPECT_EQ(c(-ck_tile::numeric<SrcT>::denorm_min()), 0b0'00000'00);

-    // All values smaller than min f8 subnormal must be converted to f8 zero
+    // All values <= min f8 subnormal/2 must be converted to f8 zero
+    EXPECT_EQ(c(+7.62939453125e-06f * 0.6f), 0b0'0000'001);
+    EXPECT_EQ(c(-7.62939453125e-06f * 0.6f), 0b1'0000'001);
     constexpr int src_min_subnorm_exp =
         -(ck_tile::numeric_traits<SrcT>::bias + ck_tile::numeric_traits<SrcT>::mant - 1);
     constexpr int dst_min_subnorm_exp =
test/ck_tile/fmha/CMakeLists.txt (new file, 31 lines)
@@ -0,0 +1,31 @@
# Keep in sync with example/ck_tile/01_fmha/CMakeLists.txt
if(NOT SUPPORTED_GPU_TARGETS MATCHES "gfx9")
    return()
endif()

set(FMHA_BWD_INSTANCES "tile_fmha_bwd_instances")
set(FMHA_FWD_INSTANCES "tile_fmha_fwd_instances")

add_gtest_executable(test_ck_tile_fmha_bwd_bf16 test_fmha_bwd_bf16.cpp)
target_link_libraries(test_ck_tile_fmha_bwd_bf16 PRIVATE ${FMHA_BWD_INSTANCES})

add_gtest_executable(test_ck_tile_fmha_bwd_fp16 test_fmha_bwd_fp16.cpp)
target_link_libraries(test_ck_tile_fmha_bwd_fp16 PRIVATE ${FMHA_BWD_INSTANCES})

add_gtest_executable(test_ck_tile_fmha_fwd_bf16 test_fmha_fwd_bf16.cpp)
target_link_libraries(test_ck_tile_fmha_fwd_bf16 PRIVATE ${FMHA_FWD_INSTANCES})

add_gtest_executable(test_ck_tile_fmha_fwd_fp16 test_fmha_fwd_fp16.cpp)
target_link_libraries(test_ck_tile_fmha_fwd_fp16 PRIVATE ${FMHA_FWD_INSTANCES})

add_gtest_executable(test_ck_tile_fmha_fwd_fp8 test_fmha_fwd_fp8.cpp)
target_link_libraries(test_ck_tile_fmha_fwd_fp8 PRIVATE ${FMHA_FWD_INSTANCES})

add_custom_target(test_ck_tile_fmha
    DEPENDS
        test_ck_tile_fmha_bwd_bf16
        test_ck_tile_fmha_bwd_fp16
        test_ck_tile_fmha_fwd_bf16
        test_ck_tile_fmha_fwd_fp16
        test_ck_tile_fmha_fwd_fp8
)
test/ck_tile/fmha/test_fmha_bwd.inc (new file, 344 lines)
@@ -0,0 +1,344 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

using ::testing::Bool;
using ::testing::Combine;
using ::testing::TestWithParam;
using ::testing::Values;
using ::testing::ValuesIn;

// Random seed used for initializing input tensors. 0 for non-deterministic seed
CK_TILE_DECLARE_ENV_VAR(CK_TILE_TEST_SEED, uint64_t, 123456)

// Whether to run long tests (from smoke_test_bwd.sh)
CK_TILE_DECLARE_ENV_VAR_BOOL(CK_TILE_FMHA_LONG_TESTS)

#define CHECK_RESULT(result)                                          \
    do                                                                \
    {                                                                 \
        if(result == bwd_result::no_instance)                         \
            GTEST_SKIP() << "No instance for current parameters";     \
        ASSERT_EQ(result, bwd_result::success);                       \
    } while(0)

const ck_tile::stream_config stream_config{
    nullptr, // stream_id_
    false,   // time_kernel_
    1,       // log_level_
    0,       // cold_niters_
    1,       // nrepeat_
    true,    // is_gpu_timer_
    false,   // flush_cache_
    1,       // rotating_count_
};

#define COMMON_ARGS \
    init_method, static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))), 1, \
        stream_config

auto EnableTestIf(bool condition)
{
    return ValuesIn(condition ? std::vector<bool>{true} : std::vector<bool>{});
}
class AllLong : public TestWithParam<std::tuple<bool,
                                                std::tuple<int, int>,
                                                bool,
                                                mode_enum,
                                                std::string,
                                                float,
                                                std::tuple<int, int, int, int, int, std::string>>>
{
};

GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(AllLong);

// Test cases from example/ck_tile/01_fmha/script/smoke_test_bwd.sh

INSTANTIATE_TEST_SUITE_P(
    TestCkTileFmhaBwd,
    AllLong,
    Combine(EnableTestIf(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_FMHA_LONG_TESTS))),
            HDimValues,
            Bool(),
            ModeValues,
            Values("n", "a"),
            Values(0.0f, 0.2f),
            Values(std::tuple{1, 4, 2, 259, -1, "0"},
                   std::tuple{2, 2, -1, 516, 253, "0"},
                   std::tuple{1, 4, 1, 500, 251, "1"},
                   std::tuple{1, 2, -1, 900, 258, "2"},
                   std::tuple{2, 1, -1, 987, 219, "t:128,30"},
                   std::tuple{2, 3, 1, 244, 499, "b:4,35"})));

TEST_P(AllLong, Test)
{
    auto [_, hdims, perm, mode, bias_str, p_drop, dims_mask] = GetParam();
    auto [hdim_q, hdim_v] = hdims;
    auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask;

    auto result = fmha_bwd_run<DataTypeConfig>(mode,
                                               batch,
                                               nhead,
                                               nhead_k,
                                               {seqlen_q},
                                               {seqlen_k},
                                               hdim_q,
                                               hdim_v,
                                               perm,     // i_perm
                                               perm,     // o_perm
                                               0,        // scale
                                               bias_str, // bias_str
                                               false,    // use_dbias
                                               p_drop,   // p_drop
                                               123,      // drop_seed
                                               1024,     // drop_offset
                                               true,     // drop_prefs
                                               mask_str, // mask_str
                                               false,    // deterministic
                                               COMMON_ARGS);
    CHECK_RESULT(result);
}
class HDimPadding
    : public TestWithParam<std::tuple<std::tuple<int, int>,
                                      bool,
                                      mode_enum,
                                      std::tuple<int, int, int, int, int, std::string>>>
{
};

INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaBwd,
                         HDimPadding,
                         Combine(Values(std::tuple{24, 48},
                                        std::tuple{120, 160},
                                        std::tuple{256, 108},
                                        std::tuple{40, 64}),
                                 Bool(),
                                 ModeValues,
                                 Values(std::tuple{1, 4, 2, 480, -1, "0"},
                                        std::tuple{2, 2, -1, 300, 400, "t:64,64"},
                                        std::tuple{1, 4, 1, 512, 201, "1"},
                                        std::tuple{1, 2, -1, 900, 256, "0"},
                                        std::tuple{2, 1, -1, 256, 256, "1"})));

TEST_P(HDimPadding, Test)
{
    auto [hdims, perm, mode, dims_mask] = GetParam();
    auto [hdim_q, hdim_v] = hdims;
    auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask;

    auto result = fmha_bwd_run<DataTypeConfig>(mode,
                                               batch,
                                               nhead,
                                               nhead_k,
                                               {seqlen_q},
                                               {seqlen_k},
                                               hdim_q,
                                               hdim_v,
                                               perm,     // i_perm
                                               perm,     // o_perm
                                               0,        // scale
                                               "n",      // bias_str
                                               false,    // use_dbias
                                               0.0f,     // p_drop
                                               0,        // drop_seed
                                               0,        // drop_offset
                                               false,    // drop_prefs
                                               mask_str, // mask_str
                                               false,    // deterministic
                                               COMMON_ARGS);
    CHECK_RESULT(result);
}

class ElementwiseBias
    : public TestWithParam<std::tuple<std::tuple<int, int>,
                                      bool,
                                      mode_enum,
                                      std::string,
                                      bool,
                                      std::tuple<int, int, int, int, int, std::string>>>
{
};

INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaBwd,
                         ElementwiseBias,
                         Combine(HDimValues,
                                 Bool(), // layouts of bias and dbias are controlled by i_perm
                                 ModeValues,
                                 Values("e:0", "e:1", "e:2"),
                                 Bool(),
                                 Values(std::tuple{1, 4, 2, 1024, 100, "0"},
                                        std::tuple{3, 2, -1, 128, 256, "2"},
                                        std::tuple{2, 2, -1, 130, 499, "t:50,64"})));

TEST_P(ElementwiseBias, Test)
{
    auto [hdims, i_perm, mode, bias_str, use_dbias, dims_mask] = GetParam();
    auto [hdim_q, hdim_v] = hdims;
    auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask;

    auto result = fmha_bwd_run<DataTypeConfig>(mode,
                                               batch,
                                               nhead,
                                               nhead_k,
                                               {seqlen_q},
                                               {seqlen_k},
                                               hdim_q,
                                               hdim_v,
                                               i_perm,    // i_perm
                                               false,     // o_perm
                                               0,         // scale
                                               bias_str,  // bias_str
                                               use_dbias, // use_dbias
                                               0.0f,      // p_drop
                                               123,       // drop_seed
                                               1024,      // drop_offset
                                               true,      // drop_prefs
                                               mask_str,  // mask_str
                                               false,     // deterministic
                                               COMMON_ARGS);
    CHECK_RESULT(result);
}
class Alibi : public TestWithParam<std::tuple<std::tuple<int, int>,
                                              mode_enum,
                                              std::string,
                                              std::tuple<int, int, int, int, int>,
                                              std::string>>
{
};

INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaBwd,
                         Alibi,
                         Combine(HDimValues,
                                 ModeValues,
                                 Values("a:0", "a:1"),
                                 Values(std::tuple{1, 3, 3, 1024, 1000},
                                        std::tuple{3, 5, 5, 128, 256},
                                        std::tuple{2, 8, 4, 130, 320}),
                                 Values("0", "t", "b", "t:50,64", "b:32,40")));

TEST_P(Alibi, Test)
{
    auto [hdims, mode, bias_str, dims, mask_str] = GetParam();
    auto [hdim_q, hdim_v] = hdims;
    auto [batch, nhead, nhead_k, seqlen_q, seqlen_k] = dims;

    auto result = fmha_bwd_run<DataTypeConfig>(mode,
                                               batch,
                                               nhead,
                                               nhead_k,
                                               {seqlen_q},
                                               {seqlen_k},
                                               hdim_q,
                                               hdim_v,
                                               true,     // i_perm
                                               true,     // o_perm
                                               0,        // scale
                                               bias_str, // bias_str
                                               false,    // use_dbias
                                               0.0f,     // p_drop
                                               0,        // drop_seed
                                               0,        // drop_offset
                                               false,    // drop_prefs
                                               mask_str, // mask_str
                                               false,    // deterministic
                                               COMMON_ARGS);
    CHECK_RESULT(result);
}

class Dropout : public TestWithParam<std::tuple<std::tuple<int, int>,
                                                mode_enum,
                                                float,
                                                std::tuple<uint64_t, uint64_t, bool>,
                                                std::tuple<int, int, int, int, int, std::string>>>
{
};

INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaBwd,
                         Dropout,
                         Combine(HDimValues,
                                 ModeValues,
                                 Values(0.123f, 0.5f),
                                 Values(std::tuple{10, 123, false},
                                        std::tuple{34534564645, 7876878876864, true}),
                                 Values(std::tuple{2, 6, 2, 180, 512, "0"},
                                        std::tuple{3, 2, 2, 256, 128, "1"},
                                        std::tuple{4, 2, 1, 100, 768, "2"})));

TEST_P(Dropout, Test)
{
    auto [hdims, mode, p_drop, drop_seed_offset_prefs, dims_mask] = GetParam();
    auto [hdim_q, hdim_v] = hdims;
    auto [drop_seed, drop_offset, drop_prefs] = drop_seed_offset_prefs;
    auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask;

    auto result = fmha_bwd_run<DataTypeConfig>(mode,
                                               batch,
                                               nhead,
                                               nhead_k,
                                               {seqlen_q},
                                               {seqlen_k},
                                               hdim_q,
                                               hdim_v,
                                               true,        // i_perm
                                               true,        // o_perm
                                               0.1f,        // scale
                                               "n",         // bias_str
                                               false,       // use_dbias
                                               p_drop,      // p_drop
                                               drop_seed,   // drop_seed
                                               drop_offset, // drop_offset
                                               drop_prefs,  // drop_prefs
                                               mask_str,    // mask_str
                                               false,       // deterministic
                                               COMMON_ARGS);
    CHECK_RESULT(result);
}
class Deterministic
    : public TestWithParam<std::tuple<std::tuple<int, int>,
                                      bool,
                                      mode_enum,
                                      std::tuple<int, int, int, int, int, std::string>>>
{
};

INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaBwd,
                         Deterministic,
                         Combine(HDimValues,
                                 Bool(),
                                 ModeValues,
                                 Values(std::tuple{2, 6, 2, 180, 512, "0"},
                                        std::tuple{3, 3, 1, 256, 128, "1"},
                                        std::tuple{4, 2, 2, 768, 100, "2"})));

TEST_P(Deterministic, Test)
{
    auto [hdims, i_perm, mode, dims_mask] = GetParam();
    auto [hdim_q, hdim_v] = hdims;
    auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask;

    auto result = fmha_bwd_run<DataTypeConfig>(mode,
                                               batch,
                                               nhead,
                                               nhead_k,
                                               {seqlen_q},
                                               {seqlen_k},
                                               hdim_q,
                                               hdim_v,
                                               i_perm,   // i_perm
                                               true,     // o_perm
                                               0,        // scale
                                               "n",      // bias_str
                                               false,    // use_dbias
                                               0.0f,     // p_drop
                                               0,        // drop_seed
                                               0,        // drop_offset
                                               false,    // drop_prefs
                                               mask_str, // mask_str
                                               true,     // deterministic
                                               COMMON_ARGS);
    CHECK_RESULT(result);
}
test/ck_tile/fmha/test_fmha_bwd_bf16.cpp (new file, 21 lines)
@@ -0,0 +1,21 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#include "example/ck_tile/01_fmha/fmha_bwd.hpp"
#include "example/ck_tile/01_fmha/fmha_bwd_runner.hpp"

#include "gtest/gtest.h"

using DataTypeConfig = FmhaBwdBf16;

using ::testing::Values;
using ::testing::ValuesIn;

const auto HDimValues =
    Values(std::tuple{32, -1}, std::tuple{64, -1}, std::tuple{128, -1}, std::tuple{256, -1});

const auto ModeValues = Values(mode_enum::batch, mode_enum::group);

constexpr std::string init_method = "uf";

#include "test_fmha_bwd.inc"
test/ck_tile/fmha/test_fmha_bwd_fp16.cpp (new file, 21 lines)
@@ -0,0 +1,21 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#include "example/ck_tile/01_fmha/fmha_bwd.hpp"
#include "example/ck_tile/01_fmha/fmha_bwd_runner.hpp"

#include "gtest/gtest.h"

using DataTypeConfig = FmhaBwdFp16;

using ::testing::Values;
using ::testing::ValuesIn;

const auto HDimValues =
    Values(std::tuple{32, -1}, std::tuple{64, -1}, std::tuple{128, -1}, std::tuple{256, -1});

const auto ModeValues = Values(mode_enum::batch, mode_enum::group);

constexpr std::string init_method = "uf";

#include "test_fmha_bwd.inc"
test/ck_tile/fmha/test_fmha_fwd.inc (new file, 628 lines)
@@ -0,0 +1,628 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

using ::testing::Bool;
using ::testing::Combine;
using ::testing::TestWithParam;
using ::testing::Values;
using ::testing::ValuesIn;

// Random seed used for initializing input tensors. 0 for non-deterministic seed
CK_TILE_DECLARE_ENV_VAR(CK_TILE_TEST_SEED, uint64_t, 123456)

// Whether to run long tests (from smoke_test_fwd.sh)
CK_TILE_DECLARE_ENV_VAR_BOOL(CK_TILE_FMHA_LONG_TESTS)

#define CHECK_RESULT(result)                                          \
    do                                                                \
    {                                                                 \
        if(result == fwd_result::no_instance)                         \
            GTEST_SKIP() << "No instance for current parameters";     \
        ASSERT_EQ(result, fwd_result::success);                       \
    } while(0)

const ck_tile::stream_config stream_config{
    nullptr, // stream_id_
    false,   // time_kernel_
    1,       // log_level_
    0,       // cold_niters_
    1,       // nrepeat_
    true,    // is_gpu_timer_
    false,   // flush_cache_
    1,       // rotating_count_
};

// range_q, range_k, range_v, range_p, range_o, squant
#define QUANT_ARGS 1, 1, 1, 1, 1, squant

#define COMMON_ARGS \
    init_method, static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))), 1, \
        stream_config

auto EnableTestIf(bool condition)
{
    return ValuesIn(condition ? std::vector<bool>{true} : std::vector<bool>{});
}
class AllLong : public TestWithParam<
                    std::tuple<bool,
                               std::tuple<int, int>,
                               bool,
                               bool,
                               mode_enum,
                               bool,
                               std::string,
                               float,
                               std::tuple<int, int, int, int, int, int, int, int, std::string>>>
{
};

GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(AllLong);

// Test cases from example/ck_tile/01_fmha/script/smoke_test_fwd.sh

INSTANTIATE_TEST_SUITE_P(
    TestCkTileFmhaFwd,
    AllLong,
    Combine(EnableTestIf(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_FMHA_LONG_TESTS))),
            HDimValues,
            Bool(),
            IsVRowmajorValues,
            ModeValues,
            Bool(),
            Values("n", "e", "a"),
            Values(0.0f, 0.2f),
            Values(std::tuple{2, 2, 1, 16, -1, 55, 256, -1, "0"},
                   std::tuple{1, 3, -1, -1, -1, 100, 51, -1, "0"},
                   std::tuple{2, 1, -1, 16, -1, 99, 256, -1, "1"},
                   std::tuple{1, 2, 1, -1, -1, 1024, 256, -1, "2"},
                   std::tuple{2, 1, -1, -1, 24, 3, 99, -1, "2"},
                   std::tuple{3, 2, 1, -1, -1, 200, 520, -1, "t:128,30"},
                   std::tuple{2, 1, -1, -1, -1, 99, 32, -1, "b:4,35"},
                   std::tuple{1, 2, 1, -1, -1, 33, 0, -1, "2"},
                   std::tuple{1, 2, 1, -1, -1, 1, 10, 32, "2"})));

TEST_P(AllLong, Test)
{
    auto [_, hdims, perm, is_v_rowmajor, mode, lse, bias_str, p_drop, dims_mask] = GetParam();
    auto [hdim_q, hdim_v] = hdims;
    auto [batch, nhead, nhead_k, hdim_q_, hdim_v_, seqlen_q, seqlen_k, seqlen_kpad, mask_str] =
        dims_mask;

    hdim_q = hdim_q_ == -1 ? hdim_q : hdim_q_;
    hdim_v = hdim_v_ == -1 ? hdim_v : hdim_v_;

    auto result = fmha_fwd_run<DataTypeConfig>(mode,
                                               batch,
                                               nhead,
                                               nhead_k,
                                               {adjust_seqlen(seqlen_q)},
                                               {adjust_seqlen(seqlen_k)},
                                               hdim_q,
                                               hdim_v,
                                               0,             // seqlen_knew
                                               {seqlen_kpad}, // seqlen_kpads
                                               0,             // rotary_dim
                                               perm,          // i_perm
                                               perm,          // o_perm
                                               0,             // scale_s
                                               0,             // logits_soft_cap
                                               is_v_rowmajor, // is_v_rowmajor
                                               lse,           // lse
                                               0,             // page_block_size
                                               false,         // use_cache_batch_idx
                                               bias_str,      // bias_str
                                               p_drop,        // p_drop
                                               123,           // drop_seed
                                               1024,          // drop_offset
                                               false,         // drop_prefs
                                               mask_str,      // mask_str
                                               QUANT_ARGS,
                                               true, // is_rotary_interleaved
                                               1,    // num_splits
                                               COMMON_ARGS);
    CHECK_RESULT(result);
}
class HDimPadding
    : public TestWithParam<std::tuple<std::tuple<int, int>,
                                      bool,
                                      bool,
                                      mode_enum,
                                      std::tuple<int, int, int, int, int, int, std::string>>>
{
};

INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaFwd,
                         HDimPadding,
                         Combine(Values(std::tuple{24, 48},
                                        std::tuple{120, 160},
                                        std::tuple{256, 108},
                                        std::tuple{40, 64}),
                                 Bool(),
                                 IsVRowmajorValues,
                                 ModeValues,
                                 Values(std::tuple{1, 4, 2, 480, -1, -1, "0"},
                                        std::tuple{2, 2, -1, 300, 400, 512, "t:64,64"},
                                        std::tuple{1, 4, 1, 512, 201, 256, "1"},
                                        std::tuple{1, 2, -1, 900, 256, -1, "0"},
                                        std::tuple{2, 1, -1, 256, 256, -1, "1"})));

TEST_P(HDimPadding, Test)
{
    auto [hdims, perm, is_v_rowmajor, mode, dims_mask] = GetParam();
    auto [hdim_q, hdim_v] = hdims;
    auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, seqlen_kpad, mask_str] = dims_mask;

    auto result = fmha_fwd_run<DataTypeConfig>(mode,
                                               batch,
                                               nhead,
                                               nhead_k,
                                               {adjust_seqlen(seqlen_q)},
                                               {adjust_seqlen(seqlen_k)},
                                               hdim_q,
                                               hdim_v,
                                               0,             // seqlen_knew
                                               {seqlen_kpad}, // seqlen_kpads
                                               0,             // rotary_dim
                                               perm,          // i_perm
                                               perm,          // o_perm
                                               0,             // scale_s
                                               0,             // logits_soft_cap
                                               is_v_rowmajor, // is_v_rowmajor
                                               def_lse,       // lse
                                               0,             // page_block_size
                                               false,         // use_cache_batch_idx
                                               "n",           // bias_str
                                               0.0f,          // p_drop
                                               0,             // drop_seed
                                               0,             // drop_offset
                                               false,         // drop_prefs
                                               mask_str,      // mask_str
                                               QUANT_ARGS,
                                               true, // is_rotary_interleaved
                                               1,    // num_splits
                                               COMMON_ARGS);
    CHECK_RESULT(result);
}
class ElementwiseBias
    : public TestWithParam<std::tuple<std::tuple<int, int>,
                                      bool,
                                      mode_enum,
                                      std::string,
                                      std::tuple<int, int, int, int, int, std::string>>>
{
};

INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaFwd,
                         ElementwiseBias,
                         Combine(HDimValues,
                                 Bool(), // layout of bias is controlled by i_perm
                                 ModeValues,
                                 Values("e:0", "e:1", "e:2"),
                                 Values(std::tuple{1, 4, 2, 1024, 100, "0"},
                                        std::tuple{3, 2, -1, 128, 256, "2"},
                                        std::tuple{2, 2, -1, 130, 499, "t:50,64"})));

TEST_P(ElementwiseBias, Test)
{
    auto [hdims, i_perm, mode, bias_str, dims_mask] = GetParam();
    auto [hdim_q, hdim_v] = hdims;
    auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask;

    auto result = fmha_fwd_run<DataTypeConfig>(mode,
                                               batch,
                                               nhead,
                                               nhead_k,
                                               {adjust_seqlen(seqlen_q)},
                                               {adjust_seqlen(seqlen_k)},
                                               hdim_q,
                                               hdim_v,
                                               0,                 // seqlen_knew
                                               {-1},              // seqlen_kpads
                                               0,                 // rotary_dim
                                               i_perm,            // i_perm
                                               false,             // o_perm
                                               0,                 // scale_s
                                               0,                 // logits_soft_cap
                                               def_is_v_rowmajor, // is_v_rowmajor
                                               def_lse,           // lse
                                               0,                 // page_block_size
                                               false,             // use_cache_batch_idx
                                               bias_str,          // bias_str
                                               0.0f,              // p_drop
                                               0,                 // drop_seed
                                               0,                 // drop_offset
                                               false,             // drop_prefs
                                               mask_str,          // mask_str
                                               QUANT_ARGS,
                                               true, // is_rotary_interleaved
                                               1,    // num_splits
                                               COMMON_ARGS);
    CHECK_RESULT(result);
}
class Alibi : public TestWithParam<std::tuple<std::tuple<int, int>,
                                              mode_enum,
                                              std::string,
                                              std::tuple<int, int, int, int, int>,
                                              std::string>>
{
};

INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaFwd,
                         Alibi,
                         Combine(HDimValues,
                                 ModeValues,
                                 Values("a:0", "a:1"),
                                 Values(std::tuple{1, 3, 3, 1024, 1000},
                                        std::tuple{3, 5, 5, 128, 256},
                                        std::tuple{2, 8, 2, 300, 355}),
                                 Values("0", "t", "b", "t:50,64", "b:32,40")));

TEST_P(Alibi, Test)
{
    auto [hdims, mode, bias_str, dims, mask_str] = GetParam();
    auto [hdim_q, hdim_v] = hdims;
    auto [batch, nhead, nhead_k, seqlen_q, seqlen_k] = dims;

    auto result = fmha_fwd_run<DataTypeConfig>(mode,
                                               batch,
                                               nhead,
                                               nhead_k,
                                               {adjust_seqlen(seqlen_q)},
                                               {adjust_seqlen(seqlen_k)},
                                               hdim_q,
                                               hdim_v,
                                               0,                 // seqlen_knew
                                               {-1},              // seqlen_kpads
                                               0,                 // rotary_dim
                                               true,              // i_perm
                                               true,              // o_perm
                                               0,                 // scale_s
                                               0,                 // logits_soft_cap
                                               def_is_v_rowmajor, // is_v_rowmajor
                                               def_lse,           // lse
                                               0,                 // page_block_size
                                               false,             // use_cache_batch_idx
                                               bias_str,          // bias_str
                                               0.0f,              // p_drop
                                               0,                 // drop_seed
                                               0,                 // drop_offset
                                               false,             // drop_prefs
                                               mask_str,          // mask_str
                                               QUANT_ARGS,
                                               true, // is_rotary_interleaved
                                               1,    // num_splits
                                               COMMON_ARGS);
    CHECK_RESULT(result);
}
class Dropout : public TestWithParam<std::tuple<std::tuple<int, int>,
                                                mode_enum,
                                                float,
                                                std::tuple<uint64_t, uint64_t, bool>,
                                                std::tuple<int, int, int, int, int, std::string>>>
{
};

INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaFwd,
                         Dropout,
                         Combine(HDimValues,
                                 ModeValues,
                                 Values(0.123f, 0.5f),
                                 Values(std::tuple{10, 123, false},
                                        std::tuple{34534564645, 7876878876864, true}),
                                 Values(std::tuple{2, 4, 2, 280, 512, "0"},
                                        std::tuple{3, 2, 2, 256, 128, "1"},
                                        std::tuple{4, 3, 1, 100, 768, "2"})));

TEST_P(Dropout, Test)
{
    auto [hdims, mode, p_drop, drop_seed_offset_prefs, dims_mask] = GetParam();
    auto [hdim_q, hdim_v] = hdims;
    auto [drop_seed, drop_offset, drop_prefs] = drop_seed_offset_prefs;
    auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask;

    auto result = fmha_fwd_run<DataTypeConfig>(mode,
                                               batch,
                                               nhead,
                                               nhead_k,
                                               {adjust_seqlen(seqlen_q)},
                                               {adjust_seqlen(seqlen_k)},
                                               hdim_q,
                                               hdim_v,
                                               0,                 // seqlen_knew
                                               {-1},              // seqlen_kpads
                                               0,                 // rotary_dim
                                               false,             // i_perm
                                               false,             // o_perm
                                               0,                 // scale_s
                                               0,                 // logits_soft_cap
                                               def_is_v_rowmajor, // is_v_rowmajor
                                               def_lse,           // lse
                                               0,                 // page_block_size
                                               false,             // use_cache_batch_idx
                                               "n",               // bias_str
                                               p_drop,            // p_drop
                                               drop_seed,         // drop_seed
                                               drop_offset,       // drop_offset
                                               drop_prefs,        // drop_prefs
                                               mask_str,          // mask_str
                                               QUANT_ARGS,
                                               true, // is_rotary_interleaved
                                               1,    // num_splits
                                               COMMON_ARGS);
    CHECK_RESULT(result);
}
#if CK_TILE_FMHA_FWD_PAGEDKV_API
|
||||
|
||||
class PagedKV : public TestWithParam<std::tuple<std::tuple<int, int>,
|
||||
bool,
|
||||
bool,
|
||||
mode_enum,
|
||||
int,
|
||||
std::tuple<int, int, int, int, int, std::string>>>
|
||||
{
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaFwd,
|
||||
PagedKV,
|
||||
Combine(SplitKVHDimValues,
|
||||
Bool(), // layouts of k and v are controlled by i_perm
|
||||
IsVRowmajorValues, // layout of v is controlled by is_v_rowmajor
|
||||
ModeValues,
|
||||
Values(128, 256),
|
||||
Values(std::tuple{2, 3, 1, 200, 1024, "0"},
|
||||
std::tuple{3, 2, -1, 128, 768, "2"},
|
||||
std::tuple{2, 2, -1, 230, 899, "t:50,64"})));
|
||||
|
||||
TEST_P(PagedKV, Test)
|
||||
{
|
||||
auto [hdims, i_perm, is_v_rowmajor, mode, page_block_size, dims_mask] = GetParam();
|
||||
auto [hdim_q, hdim_v] = hdims;
|
||||
auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask;
|
||||
|
||||
auto result = fmha_fwd_run<DataTypeConfig>(mode,
|
||||
batch,
|
||||
nhead,
|
||||
nhead_k,
|
||||
{adjust_seqlen(seqlen_q)},
|
||||
{adjust_seqlen(seqlen_k)},
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
0, // seqlen_knew
|
||||
{-1}, // seqlen_kpads
|
||||
0, // rotary_dim
|
||||
i_perm, // i_perm
|
||||
false, // o_perm
|
||||
0, // scale_s
|
||||
0, // logits_soft_cap
|
||||
is_v_rowmajor, // is_v_rowmajor
|
||||
def_lse, // lse
|
||||
page_block_size, // page_block_size
|
||||
false, // use_cache_batch_idx
|
||||
"n", // bias_str
|
||||
0.0f, // p_drop
|
||||
0, // drop_seed
|
||||
0, // drop_offset
|
||||
false, // drop_prefs
|
||||
mask_str, // mask_str
|
||||
QUANT_ARGS,
|
||||
true, // is_rotary_interleaved
|
||||
1, // num_splits
|
||||
COMMON_ARGS);
|
||||
CHECK_RESULT(result);
|
||||
}
|
||||
|
||||
#endif // CK_TILE_FMHA_FWD_PAGEDKV_API

#if CK_TILE_FMHA_FWD_SPLITKV_API

class SplitKV : public TestWithParam<std::tuple<std::tuple<int, int>,
                                                bool,
                                                bool,
                                                std::tuple<mode_enum, bool>,
                                                int,
                                                std::tuple<int, int, int, int, int, std::string>>>
{
};

INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaFwd,
                         SplitKV,
                         Combine(SplitKVHDimValues,
                                 Bool(),            // layouts of k and v are controlled by i_perm
                                 IsVRowmajorValues, // layout of v is controlled by is_v_rowmajor
                                 Values(std::tuple{mode_enum::batch, false},
                                        std::tuple{mode_enum::batch, true},
                                        std::tuple{mode_enum::group, false}),
                                 Values(3, 4),
                                 Values(std::tuple{4, 3, 1, 200, 1024, "0"},
                                        std::tuple{2, 2, -1, 512, 2000, "0"},
                                        std::tuple{3, 2, -1, 230, 899, "t:128,128"})));

TEST_P(SplitKV, Test)
{
    auto [hdims, i_perm, is_v_rowmajor, mode_use_cache_batch_idx, num_splits, dims_mask] =
        GetParam();
    auto [hdim_q, hdim_v] = hdims;
    auto [mode, use_cache_batch_idx] = mode_use_cache_batch_idx;
    auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask;

    auto result = fmha_fwd_run<DataTypeConfig>(mode,
                                               batch,
                                               nhead,
                                               nhead_k,
                                               {adjust_seqlen(seqlen_q)},
                                               {adjust_seqlen(seqlen_k)},
                                               hdim_q,
                                               hdim_v,
                                               0,                   // seqlen_knew
                                               {-1},                // seqlen_kpads
                                               0,                   // rotary_dim
                                               i_perm,              // i_perm
                                               false,               // o_perm
                                               0,                   // scale_s
                                               0,                   // logits_soft_cap
                                               is_v_rowmajor,       // is_v_rowmajor
                                               def_lse,             // lse
                                               0,                   // page_block_size
                                               use_cache_batch_idx, // use_cache_batch_idx
                                               "n",                 // bias_str
                                               0.0f,                // p_drop
                                               0,                   // drop_seed
                                               0,                   // drop_offset
                                               false,               // drop_prefs
                                               mask_str,            // mask_str
                                               QUANT_ARGS,
                                               true,       // is_rotary_interleaved
                                               num_splits, // num_splits
                                               COMMON_ARGS);
    CHECK_RESULT(result);
}

#endif // CK_TILE_FMHA_FWD_SPLITKV_API

#if CK_TILE_FMHA_FWD_APPENDKV_API

class AppendKV : public TestWithParam<std::tuple<std::tuple<int, int>,
                                                 bool,
                                                 bool,
                                                 std::tuple<int, bool>,
                                                 int,
                                                 std::tuple<int, int, int, int, int, std::string>>>
{
};

INSTANTIATE_TEST_SUITE_P(
    TestCkTileFmhaFwd,
    AppendKV,
    Combine(AppendKVHDimValues,
            Bool(),            // layouts of k and v are controlled by i_perm
            IsVRowmajorValues, // layout of v is controlled by is_v_rowmajor
            ValuesIn({std::tuple{0, true}, std::tuple{0, false}, std::tuple{128, false}}),
            Values(1, 64, -1),
            Values(std::tuple{3, 3, -1, 60, 129, "t:32,32"},
                   std::tuple{3, 2, 2, 256, 256, "0"},
                   std::tuple{2, 3, 1, 264, 265, "1"},
                   std::tuple{4, 4, 2, 71, 64, "1"})));

TEST_P(AppendKV, Test)
{
    auto [hdims,
          i_perm,
          is_v_rowmajor,
          page_block_size_use_cache_batch_idx,
          seqlen_knew,
          dims_mask] = GetParam();
    auto [hdim_q, hdim_v] = hdims;
    auto [page_block_size, use_cache_batch_idx] = page_block_size_use_cache_batch_idx;
    auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask;

    seqlen_knew = seqlen_knew == -1 ? seqlen_k : seqlen_knew;

    auto result = fmha_fwd_run<DataTypeConfig>(mode_enum::batch,
                                               batch,
                                               nhead,
                                               nhead_k,
                                               {adjust_seqlen(seqlen_q)},
                                               {adjust_seqlen(seqlen_k)},
                                               hdim_q,
                                               hdim_v,
                                               seqlen_knew,         // seqlen_knew
                                               {-1},                // seqlen_kpads
                                               0,                   // rotary_dim
                                               i_perm,              // i_perm
                                               true,                // o_perm
                                               0,                   // scale_s
                                               0,                   // logits_soft_cap
                                               is_v_rowmajor,       // is_v_rowmajor
                                               def_lse,             // lse
                                               page_block_size,     // page_block_size
                                               use_cache_batch_idx, // use_cache_batch_idx
                                               "n",                 // bias_str
                                               0.0f,                // p_drop
                                               0,                   // drop_seed
                                               0,                   // drop_offset
                                               false,               // drop_prefs
                                               mask_str,            // mask_str
                                               QUANT_ARGS,
                                               false, // is_rotary_interleaved
                                               1,     // num_splits
                                               COMMON_ARGS);
    CHECK_RESULT(result);
}

class AppendKVRoPE
    : public TestWithParam<std::tuple<bool,
                                      std::tuple<int, int>,
                                      bool,
                                      bool,
                                      std::tuple<int, bool>,
                                      int,
                                      std::tuple<int, int, int, int, int, std::string>>>
{
};

GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(AppendKVRoPE);

INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaFwd,
                         AppendKVRoPE,
                         Combine(EnableTestIf(!std::is_same_v<DataTypeConfig, FmhaFwdFp8>),
                                 AppendKVHDimValues,
                                 Bool(),            // layouts of k and v are controlled by i_perm
                                 IsVRowmajorValues, // layout of v is controlled by is_v_rowmajor
                                 Values(std::tuple{0, false},
                                        std::tuple{16, true},
                                        std::tuple{32, false},
                                        std::tuple{-1, true}),
                                 Values(16, 50, -1),
                                 Values(std::tuple{2, 3, -1, 60, 129, "t:32,32"},
                                        std::tuple{1, 2, 1, 128, 55, "0"},
                                        std::tuple{3, 4, 2, 72, 128, "1"})));

TEST_P(AppendKVRoPE, Test)
{
    auto [_, hdims, i_perm, is_v_rowmajor, rotary, seqlen_knew, dims_mask] = GetParam();
    auto [hdim_q, hdim_v] = hdims;
    auto [rotary_dim, is_rotary_interleaved] = rotary;
    auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask;

    rotary_dim  = rotary_dim == -1 ? hdim_q : rotary_dim;
    seqlen_knew = seqlen_knew == -1 ? seqlen_k : seqlen_knew;

    auto result = fmha_fwd_run<DataTypeConfig>(mode_enum::batch,
                                               batch,
                                               nhead,
                                               nhead_k,
                                               {adjust_seqlen(seqlen_q)},
                                               {adjust_seqlen(seqlen_k)},
                                               hdim_q,
                                               hdim_v,
                                               seqlen_knew,   // seqlen_knew
                                               {-1},          // seqlen_kpads
                                               rotary_dim,    // rotary_dim
                                               i_perm,        // i_perm
                                               true,          // o_perm
                                               0,             // scale_s
                                               0,             // logits_soft_cap
                                               is_v_rowmajor, // is_v_rowmajor
                                               true,          // lse
                                               0,             // page_block_size
                                               false,         // use_cache_batch_idx
                                               "n",           // bias_str
                                               0.0f,          // p_drop
                                               0,             // drop_seed
                                               0,             // drop_offset
                                               false,         // drop_prefs
                                               mask_str,      // mask_str
                                               QUANT_ARGS,
                                               is_rotary_interleaved, // is_rotary_interleaved
                                               1,                     // num_splits
                                               COMMON_ARGS);
    CHECK_RESULT(result);
}

#endif // CK_TILE_FMHA_FWD_APPENDKV_API
test/ck_tile/fmha/test_fmha_fwd_bf16.cpp (new file, 44 lines)
@@ -0,0 +1,44 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#include "example/ck_tile/01_fmha/fmha_fwd.hpp"
#include "example/ck_tile/01_fmha/fmha_fwd_runner.hpp"

#include "gtest/gtest.h"

#include <tuple>
#include <string>

using ::testing::Values;

using DataTypeConfig = FmhaFwdBf16;

const auto HDimValues = Values(std::tuple{32, -1},
                               std::tuple{64, -1},
                               std::tuple{96, 128},
                               std::tuple{128, -1},
                               std::tuple{192, 128},
                               std::tuple{192, -1},
                               std::tuple{256, -1});

const auto SplitKVHDimValues = Values(std::tuple{32, -1},
                                      std::tuple{64, -1},
                                      std::tuple{96, -1},
                                      std::tuple{128, -1},
                                      std::tuple{256, -1});

const auto AppendKVHDimValues =
    Values(std::tuple{32, -1}, std::tuple{64, -1}, std::tuple{128, -1}, std::tuple{256, -1});

const auto ModeValues = Values(mode_enum::batch, mode_enum::group);

const auto IsVRowmajorValues = Values(false, true);

const bool squant             = false;
const std::string init_method = "uf";
const bool def_lse            = true;
const bool def_is_v_rowmajor  = true;

int adjust_seqlen(int seqlen) { return seqlen; }

#include "test_fmha_fwd.inc"
test/ck_tile/fmha/test_fmha_fwd_fp16.cpp (new file, 44 lines)
@@ -0,0 +1,44 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#include "example/ck_tile/01_fmha/fmha_fwd.hpp"
#include "example/ck_tile/01_fmha/fmha_fwd_runner.hpp"

#include "gtest/gtest.h"

#include <tuple>
#include <string>

using ::testing::Values;

using DataTypeConfig = FmhaFwdFp16;

const auto HDimValues = Values(std::tuple{32, -1},
                               std::tuple{64, -1},
                               std::tuple{96, 128},
                               std::tuple{128, -1},
                               std::tuple{192, 128},
                               std::tuple{192, -1},
                               std::tuple{256, -1});

const auto SplitKVHDimValues = Values(std::tuple{32, -1},
                                      std::tuple{64, -1},
                                      std::tuple{96, -1},
                                      std::tuple{128, -1},
                                      std::tuple{256, -1});

const auto AppendKVHDimValues =
    Values(std::tuple{32, -1}, std::tuple{64, -1}, std::tuple{128, -1}, std::tuple{256, -1});

const auto ModeValues = Values(mode_enum::batch, mode_enum::group);

const auto IsVRowmajorValues = Values(false, true);

const bool squant             = false;
const std::string init_method = "uf";
const bool def_lse            = true;
const bool def_is_v_rowmajor  = true;

int adjust_seqlen(int seqlen) { return seqlen; }

#include "test_fmha_fwd.inc"
test/ck_tile/fmha/test_fmha_fwd_fp8.cpp (new file, 43 lines)
@@ -0,0 +1,43 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#include "example/ck_tile/01_fmha/fmha_fwd.hpp"
#include "example/ck_tile/01_fmha/fmha_fwd_runner.hpp"

#include "gtest/gtest.h"

#include <tuple>
#include <string>

using ::testing::Values;

using DataTypeConfig = FmhaFwdFp8;

// Currently there are no fp8 instances for splitkv and pagedkv by default, so the
// corresponding tests are skipped rather than disabled; they will run (and pass)
// if such instances are added in the future.

const auto HDimValues = Values(std::tuple{64, -1}, std::tuple{128, -1}, std::tuple{256, -1});

const auto SplitKVHDimValues = Values(std::tuple{64, -1}, std::tuple{128, -1}, std::tuple{256, -1});

const auto AppendKVHDimValues =
    Values(std::tuple{64, -1}, std::tuple{128, -1}, std::tuple{256, -1});

// There are no fp8 instances with seqlen padding (mode_enum::group requires it)
const auto ModeValues = Values(mode_enum::batch);

const auto IsVRowmajorValues = Values(false);

const bool squant             = true;
const std::string init_method = "ufq";
const bool def_lse            = false;
const bool def_is_v_rowmajor  = false;

int adjust_seqlen(int seqlen)
{
    // There are no fp8 instances with seqlen padding, so round seqlen up to a
    // multiple of 128 to avoid skipping most of the tests
    return ck_tile::integer_least_multiple(seqlen, 128);
}

#include "test_fmha_fwd.inc"