composable_kernel/test/ck_tile/fmha/test_fmha_fwd.inc
Anton Gorenko 1edd250115 [CK_TILE] Support f32 in FMHA (fwd and bwd) (#2836)
* Support 16x16 (MFMA, WMMA) and 32x32 (MFMA) tiles in fwd and bwd BlockDropout

Add comments with dropout implementation details

Fix performance regression of fwd+dropout

    * Remove some usage of type punning (reinterpret_cast with ref or ptr) in Philox;
    * "scalarize" seed and offset; they may come either from kernel args or from device memory
      (presumably loaded with vector loads).

    These changes help the compiler produce better code and reduce register spilling.
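
    A minimal illustrative sketch (hypothetical names, not the actual Philox
    code) of the well-defined replacement for pointer-based punning:

        #include <cstdint>

        // Split a 64-bit seed into 32-bit halves with shifts instead of
        // reinterpret_cast; this is well-defined and easy to keep in registers.
        inline void split_u64(uint64_t v, uint32_t& lo, uint32_t& hi)
        {
            lo = static_cast<uint32_t>(v);
            hi = static_cast<uint32_t>(v >> 32);
        }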

Use WarpGemmDispatcher instead of explicit WarpGemmMfma... to get CWarpDstrEncoding

Use code based on BlockDropout in BlockDropoutBwd

Refactor BlockDropout (fwd)

Implement BlockDropout (fwd) for WMMA

    Originally BlockDropout only supported 32x32 tiles (IsWG32 = true);
    this version supports 16x16 tiles as well.
    If MPerBlock > MWarp * 16, it can generate numbers for two 16x16 tiles,
    similar to BlockDropoutBwd.

Implement BlockDropoutBwd for WMMA

Remove MakeRandValLds* functions unused in BlockDropoutBwd

Remove unused Run overload from BlockDropoutBwd

* Fix regression with philox seed and offset when they exceed 32-bit int

__builtin_amdgcn_readfirstlane works on 32-bit values; seed and offset
are 64-bit, so they were getting truncated.
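
A sketch of the fix under illustrative names (a 64-bit wrapper around the
32-bit builtin; the actual helper in the sources may differ):

    #include <cstdint>

    // Broadcast each 32-bit half separately, then reassemble the 64-bit value.
    __device__ uint64_t readfirstlane_u64(uint64_t v)
    {
        uint32_t lo = __builtin_amdgcn_readfirstlane(static_cast<uint32_t>(v));
        uint32_t hi = __builtin_amdgcn_readfirstlane(static_cast<uint32_t>(v >> 32));
        return (static_cast<uint64_t>(hi) << 32) | lo;
    }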

* Add F32 MFMA warp gemms

* Support f32 in fwd FMHA

* Implement transpose_vectors for 4-byte types (float)

* Fix unexpected implicit f32->uint32 cast in buffer_store<4>

__builtin_amdgcn_raw_buffer_store_b32 expects an unsigned int, but a float was
passed and implicitly converted (a value conversion, not a bit reinterpretation).
The mbuf_t types in the other buffer_store<> overloads are changed for consistency.
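
A standalone host-side sketch of the distinction (illustrative, not the kernel
code):

    #include <cstdint>
    #include <cstdio>
    #include <cstring>

    int main()
    {
        float x = 1.5f;
        uint32_t as_value = static_cast<uint32_t>(x); // value conversion -> 1
        uint32_t as_bits;
        std::memcpy(&as_bits, &x, sizeof(x)); // bit pattern -> 0x3FC00000
        std::printf("value: %u, bits: 0x%08X\n", as_value, as_bits);
        return 0;
    }

The fix stores the bit pattern, which is what a buffer store of an f32 value
must do.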

* Support F32 in bwd FMHA

hdim = 256 is disabled for now because it uses too much memory on gfx90a

* Support Headdim = 48 (divisible by 16) in fwd

* Add fp32-specific receipts (800 and 801)

* Tune fwd tiles

* Tune bwd tiles

* Use small tiles only for small seqlen_q

* Fix after rebasing

* Fix selection of a fallback tile based on bm0

The assumption that the largest bm0 == 128 is not always true for
current fp32 tiles.

* Remove constraints and adjust filtering for fp32

Custom constraints are no longer needed because the smallest tile
is now selected automatically based on seqlen_q.
The filters related to qr_async_trload were disabling valid fp32 tiles.

* Add fp32 tests

* Make splitkv and appendkv compile for fp32 only

There are no instances yet, but the API must still compile when only fp32 is
requested.

* Remove unimportant f32 instances

* Add test_ck_tile_fmha_*_fp32 to REGRESSION_TESTS

* Replace magic numbers with a constant, improve comments for dropout

* Update changelog

* Fix condition that dq_acc must be set to zero when mask is used

The change was introduced in #2799

* Replace warp_uniform with recently added amd_wave_read_first_lane

* Add hdim = 96 and 192 to fwd

// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
using ::testing::Bool;
using ::testing::Combine;
using ::testing::TestWithParam;
using ::testing::Values;
using ::testing::ValuesIn;
// Random seed used for initializing input tensors. 0 for non-deterministic seed
CK_TILE_DECLARE_ENV_VAR(CK_TILE_TEST_SEED, uint64_t, 123456)
// Whether to run long tests (from smoke_test_fwd.sh)
CK_TILE_DECLARE_ENV_VAR_BOOL(CK_TILE_FMHA_LONG_TESTS)
#define CHECK_RESULT(result) \
do \
{ \
if(result == fwd_result::no_instance) \
GTEST_SKIP() << "No instance for current parameters"; \
ASSERT_EQ(result, fwd_result::success); \
} while(0)
const ck_tile::stream_config stream_config{
nullptr, // stream_id_
false, // time_kernel_
1, // log_level_
0, // cold_niters_
1, // nrepeat_
true, // is_gpu_timer_
false, // flush_cache_
1, // rotating_count_
};
#define COMMON_ARGS \
init_method, static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))), 1, \
stream_config
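// Yields a single-element sequence {true} when condition holds, otherwise an
// empty sequence. Combine() with an empty axis produces an empty Cartesian
// product, so the suite instantiates no tests (see
// GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST below).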
auto EnableTestIf(bool condition)
{
return ValuesIn(condition ? std::vector<bool>{true} : std::vector<bool>{});
}
class AllLong : public TestWithParam<
std::tuple<bool,
std::tuple<int, int>,
bool,
bool,
mode_enum,
bool,
std::string,
float,
std::tuple<int, int, int, int, int, int, int, int, std::string>>>
{
};
GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(AllLong);
// Test cases from example/ck_tile/01_fmha/script/smoke_test_fwd.sh
INSTANTIATE_TEST_SUITE_P(
TestCkTileFmhaFwd,
AllLong,
Combine(EnableTestIf(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_FMHA_LONG_TESTS))),
HDimValues,
Bool(),
IsVRowmajorValues,
ModeValues,
Bool(),
Values("n", "e", "a"),
Values(0.0f, 0.2f),
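// {batch, nhead, nhead_k, hdim_q, hdim_v, seqlen_q, seqlen_k, seqlen_kpad, mask}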
Values(std::tuple{2, 2, 1, 16, -1, 55, 256, -1, "0"},
std::tuple{1, 3, -1, -1, -1, 100, 51, -1, "0"},
std::tuple{2, 1, -1, 16, -1, 99, 256, -1, "1"},
std::tuple{1, 2, 1, -1, -1, 1024, 256, -1, "2"},
std::tuple{2, 1, -1, -1, 24, 3, 99, -1, "2"},
std::tuple{3, 2, 1, -1, -1, 200, 520, -1, "t:128,30"},
std::tuple{2, 1, -1, -1, -1, 99, 32, -1, "b:4,35"},
std::tuple{1, 2, 1, -1, -1, 33, 0, -1, "2"},
std::tuple{1, 2, 1, -1, -1, 1, 10, 32, "2"})));
TEST_P(AllLong, Test)
{
auto [_, hdims, perm, is_v_rowmajor, mode, lse, bias_str, p_drop, dims_mask] = GetParam();
auto [hdim_q, hdim_v] = hdims;
auto [batch, nhead, nhead_k, hdim_q_, hdim_v_, seqlen_q, seqlen_k, seqlen_kpad, mask_str] =
dims_mask;
hdim_q = hdim_q_ == -1 ? hdim_q : hdim_q_;
hdim_v = hdim_v_ == -1 ? hdim_v : hdim_v_;
auto result = fmha_fwd_run<DataTypeConfig>(mode,
batch,
nhead,
nhead_k,
{adjust_seqlen(seqlen_q)},
{adjust_seqlen(seqlen_k)},
hdim_q,
hdim_v,
0, // seqlen_knew
{-1}, // seqlen_qpads
{seqlen_kpad}, // seqlen_kpads
{}, // q_eff_lens_per_batch
{}, // kv_eff_lens_per_batch
0, // rotary_dim
perm, // i_perm
perm, // o_perm
0, // scale_s
0, // logits_soft_cap
is_v_rowmajor, // is_v_rowmajor
lse, // lse
0, // page_block_size
false, // use_cache_batch_idx
bias_str, // bias_str
p_drop, // p_drop
123, // drop_seed
1024, // drop_offset
false, // drop_prefs
mask_str, // mask_str
squant,
true, // is_rotary_interleaved
1, // num_splits
COMMON_ARGS);
CHECK_RESULT(result);
}
// ---------------------------------------------------------------
// Negative tests: padding not supported with appendkv/splitkv/pagedkv
// ---------------------------------------------------------------
#if CK_TILE_FMHA_FWD_APPENDKV_API
TEST(TestCkTileFmhaFwd, AppendKvWithBatchEffLensShouldFail)
{
// batch mode effective lengths simulate padding
auto result = fmha_fwd_run<DataTypeConfig>(
mode_enum::batch,
2, // batch
4, // nhead
-1, // nhead_k
{128}, // seqlen_qs
{128}, // seqlen_ks
64, // hdim_q
64, // hdim_v
32, // seqlen_knew -> triggers appendkv
{}, // seqlen_qpads
{}, // seqlen_kpads
{100, 120}, // q_eff_lens_per_batch
{90, 110}, // kv_eff_lens_per_batch
0, // rotary_dim
true, // i_perm
true, // o_perm
0, // scale_s
0, // logits_soft_cap
def_is_v_rowmajor,
def_lse,
0, // page_block_size
false, // use_cache_batch_idx
"n", // bias
0.0f, // p_drop
0, // drop_seed
0, // drop_offset
false, // drop_prefs
"0", // mask
squant,
true, // is_rotary_interleaved
1, // num_splits
init_method,
static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
0,
stream_config);
ASSERT_EQ(result, fwd_result::invalid_args);
}
#endif
#if CK_TILE_FMHA_FWD_SPLITKV_API
TEST(TestCkTileFmhaFwd, SplitKvWithGroupPaddingShouldFail)
{
// group mode physical padding
auto result = fmha_fwd_run<DataTypeConfig>(
mode_enum::group,
2, // batch
4, // nhead
-1, // nhead_k
{96, 120}, // seqlen_qs logical
{96, 120}, // seqlen_ks logical
64, // hdim_q
64, // hdim_v
0, // seqlen_knew
{128, 128}, // seqlen_qpads
{128, 128}, // seqlen_kpads
{}, // q_eff
{}, // kv_eff
0, // rotary_dim
true, // i_perm
true, // o_perm
0, // scale_s
0, // logits_soft_cap
def_is_v_rowmajor,
def_lse,
0, // page_block_size
false, // use_cache_batch_idx
"n", // bias
0.0f,
0,
0,
false,
"0",
squant,
true,
2, // num_splits (>1 triggers splitkv)
init_method,
static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
0,
stream_config);
ASSERT_EQ(result, fwd_result::invalid_args);
}
#endif
#if CK_TILE_FMHA_FWD_PAGEDKV_API
TEST(TestCkTileFmhaFwd, PagedKvWithGroupPaddingShouldFail)
{
auto result = fmha_fwd_run<DataTypeConfig>(
mode_enum::group,
2,
4,
-1,
{80, 100},
{80, 100},
64,
64,
0, // seqlen_knew
{96, 128}, // seqlen_qpads
{96, 128}, // seqlen_kpads
{},
{},
0,
true,
true,
0,
0,
def_is_v_rowmajor,
def_lse,
128, // page_block_size triggers pagedkv
false,
"n",
0.0f,
0,
0,
false,
"0",
squant,
true,
1,
init_method,
static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
0,
stream_config);
ASSERT_EQ(result, fwd_result::invalid_args);
}
#endif
class HDimPadding
: public TestWithParam<std::tuple<std::tuple<int, int>,
bool,
bool,
mode_enum,
std::tuple<int, int, int, int, int, int, std::string>>>
{
};
INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaFwd,
HDimPadding,
Combine(Values(std::tuple{24, 48},
std::tuple{120, 160},
std::tuple{256, 108},
std::tuple{40, 64}),
Bool(),
IsVRowmajorValues,
ModeValues,
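// {batch, nhead, nhead_k, seqlen_q, seqlen_k, seqlen_kpad, mask}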
Values(std::tuple{1, 4, 2, 480, -1, -1, "0"},
std::tuple{2, 2, -1, 300, 400, 512, "t:64,64"},
std::tuple{1, 4, 1, 512, 201, 256, "1"},
std::tuple{1, 2, -1, 900, 256, -1, "0"},
std::tuple{2, 1, -1, 256, 256, -1, "1"})));
TEST_P(HDimPadding, Test)
{
auto [hdims, perm, is_v_rowmajor, mode, dims_mask] = GetParam();
auto [hdim_q, hdim_v] = hdims;
auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, seqlen_kpad, mask_str] = dims_mask;
auto result = fmha_fwd_run<DataTypeConfig>(mode,
batch,
nhead,
nhead_k,
{adjust_seqlen(seqlen_q)},
{adjust_seqlen(seqlen_k)},
hdim_q,
hdim_v,
0, // seqlen_knew
{-1}, // seqlen_qpads
{seqlen_kpad}, // seqlen_kpads
{}, // q_eff_lens_per_batch
{}, // kv_eff_lens_per_batch
0, // rotary_dim
perm, // i_perm
perm, // o_perm
0, // scale_s
0, // logits_soft_cap
is_v_rowmajor, // is_v_rowmajor
def_lse, // lse
0, // page_block_size
false, // use_cache_batch_idx
"n", // bias_str
0.0f, // p_drop
0, // drop_seed
0, // drop_offset
false, // drop_prefs
mask_str, // mask_str
squant,
true, // is_rotary_interleaved
1, // num_splits
COMMON_ARGS);
CHECK_RESULT(result);
}
class ElementwiseBias
: public TestWithParam<std::tuple<std::tuple<int, int>,
bool,
mode_enum,
std::string,
std::tuple<int, int, int, int, int, std::string>>>
{
};
INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaFwd,
ElementwiseBias,
Combine(HDimValues,
Bool(), // layout of bias is controlled by i_perm
ModeValues,
Values("e:0", "e:1", "e:2"),
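// {batch, nhead, nhead_k, seqlen_q, seqlen_k, mask}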
Values(std::tuple{1, 4, 2, 1024, 100, "0"},
std::tuple{3, 2, -1, 128, 256, "2"},
std::tuple{2, 2, -1, 130, 499, "t:50,64"})));
TEST_P(ElementwiseBias, Test)
{
auto [hdims, i_perm, mode, bias_str, dims_mask] = GetParam();
auto [hdim_q, hdim_v] = hdims;
auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask;
auto result = fmha_fwd_run<DataTypeConfig>(mode,
batch,
nhead,
nhead_k,
{adjust_seqlen(seqlen_q)},
{adjust_seqlen(seqlen_k)},
hdim_q,
hdim_v,
0, // seqlen_knew
{-1}, // seqlen_qpads
{-1}, // seqlen_kpads
{}, // q_eff_lens_per_batch
{}, // kv_eff_lens_per_batch
0, // rotary_dim
i_perm, // i_perm
false, // o_perm
0, // scale_s
0, // logits_soft_cap
def_is_v_rowmajor, // is_v_rowmajor
def_lse, // lse
0, // page_block_size
false, // use_cache_batch_idx
bias_str, // bias_str
0.0f, // p_drop
0, // drop_seed
0, // drop_offset
false, // drop_prefs
mask_str, // mask_str
squant,
true, // is_rotary_interleaved
1, // num_splits
COMMON_ARGS);
CHECK_RESULT(result);
}
class Alibi : public TestWithParam<std::tuple<std::tuple<int, int>,
mode_enum,
std::string,
std::tuple<int, int, int, int, int>,
std::string>>
{
};
INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaFwd,
Alibi,
Combine(HDimValues,
ModeValues,
Values("a:0", "a:1"),
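// {batch, nhead, nhead_k, seqlen_q, seqlen_k}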
Values(std::tuple{1, 3, 3, 1024, 1000},
std::tuple{3, 5, 5, 128, 256},
std::tuple{2, 8, 2, 300, 355}),
Values("0", "t", "b", "t:50,64", "b:32,40")));
TEST_P(Alibi, Test)
{
auto [hdims, mode, bias_str, dims, mask_str] = GetParam();
auto [hdim_q, hdim_v] = hdims;
auto [batch, nhead, nhead_k, seqlen_q, seqlen_k] = dims;
auto result = fmha_fwd_run<DataTypeConfig>(mode,
batch,
nhead,
nhead_k,
{adjust_seqlen(seqlen_q)},
{adjust_seqlen(seqlen_k)},
hdim_q,
hdim_v,
0, // seqlen_knew
{-1}, // seqlen_qpads
{-1}, // seqlen_kpads
{}, // q_eff_lens_per_batch
{}, // kv_eff_lens_per_batch
0, // rotary_dim
true, // i_perm
true, // o_perm
0, // scale_s
0, // logits_soft_cap
def_is_v_rowmajor, // is_v_rowmajor
def_lse, // lse
0, // page_block_size
false, // use_cache_batch_idx
bias_str, // bias_str
0.0f, // p_drop
0, // drop_seed
0, // drop_offset
false, // drop_prefs
mask_str, // mask_str
squant,
true, // is_rotary_interleaved
1, // num_splits
COMMON_ARGS);
CHECK_RESULT(result);
}
class Dropout : public TestWithParam<std::tuple<std::tuple<int, int>,
mode_enum,
float,
std::tuple<uint64_t, uint64_t, bool>,
std::tuple<int, int, int, int, int, std::string>>>
{
};
INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaFwd,
Dropout,
Combine(HDimValues,
ModeValues,
Values(0.123f, 0.5f),
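// {drop_seed, drop_offset, drop_prefs}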
Values(std::tuple{10, 123, false},
std::tuple{34534564645, 7876878876864, true}),
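// {batch, nhead, nhead_k, seqlen_q, seqlen_k, mask}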
Values(std::tuple{2, 4, 2, 280, 512, "0"},
std::tuple{3, 2, 2, 256, 128, "1"},
std::tuple{4, 3, 1, 100, 768, "2"})));
TEST_P(Dropout, Test)
{
auto [hdims, mode, p_drop, drop_seed_offset_prefs, dims_mask] = GetParam();
auto [hdim_q, hdim_v] = hdims;
auto [drop_seed, drop_offset, drop_prefs] = drop_seed_offset_prefs;
auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask;
auto result = fmha_fwd_run<DataTypeConfig>(mode,
batch,
nhead,
nhead_k,
{adjust_seqlen(seqlen_q)},
{adjust_seqlen(seqlen_k)},
hdim_q,
hdim_v,
0, // seqlen_knew
{-1}, // seqlen_qpads
{-1}, // seqlen_kpads
{}, // q_eff_lens_per_batch
{}, // kv_eff_lens_per_batch
0, // rotary_dim
false, // i_perm
false, // o_perm
0, // scale_s
0, // logits_soft_cap
def_is_v_rowmajor, // is_v_rowmajor
def_lse, // lse
0, // page_block_size
false, // use_cache_batch_idx
"n", // bias_str
p_drop, // p_drop
drop_seed, // drop_seed
drop_offset, // drop_offset
drop_prefs, // drop_prefs
mask_str, // mask_str
squant,
true, // is_rotary_interleaved
1, // num_splits
COMMON_ARGS);
CHECK_RESULT(result);
}
#if CK_TILE_FMHA_FWD_PAGEDKV_API
class PagedKV : public TestWithParam<std::tuple<std::tuple<int, int>,
bool,
bool,
mode_enum,
int,
std::tuple<int, int, int, int, int, std::string>>>
{
};
GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(PagedKV);
INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaFwd,
PagedKV,
Combine(SplitKVHDimValues,
Bool(), // layouts of k and v are controlled by i_perm
IsVRowmajorValues, // layout of v is controlled by is_v_rowmajor
ModeValues,
Values(128, 256),
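// {batch, nhead, nhead_k, seqlen_q, seqlen_k, mask}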
Values(std::tuple{2, 3, 1, 200, 1024, "0"},
std::tuple{3, 2, -1, 128, 768, "2"},
std::tuple{2, 2, -1, 230, 899, "t:50,64"})));
TEST_P(PagedKV, Test)
{
auto [hdims, i_perm, is_v_rowmajor, mode, page_block_size, dims_mask] = GetParam();
auto [hdim_q, hdim_v] = hdims;
auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask;
auto result = fmha_fwd_run<DataTypeConfig>(mode,
batch,
nhead,
nhead_k,
{adjust_seqlen(seqlen_q)},
{adjust_seqlen(seqlen_k)},
hdim_q,
hdim_v,
0, // seqlen_knew
{-1}, // seqlen_qpads
{-1}, // seqlen_kpads
{}, // q_eff_lens_per_batch
{}, // kv_eff_lens_per_batch
0, // rotary_dim
i_perm, // i_perm
false, // o_perm
0, // scale_s
0, // logits_soft_cap
is_v_rowmajor, // is_v_rowmajor
false, // lse
page_block_size, // page_block_size
false, // use_cache_batch_idx
"n", // bias_str
0.0f, // p_drop
0, // drop_seed
0, // drop_offset
false, // drop_prefs
mask_str, // mask_str
squant,
true, // is_rotary_interleaved
1, // num_splits
COMMON_ARGS);
CHECK_RESULT(result);
}
#endif // CK_TILE_FMHA_FWD_PAGEDKV_API
#if CK_TILE_FMHA_FWD_SPLITKV_API
class SplitKV : public TestWithParam<std::tuple<std::tuple<int, int>,
bool,
bool,
std::tuple<mode_enum, bool>,
int,
std::tuple<int, int, int, int, int, std::string>>>
{
};
GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(SplitKV);
INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaFwd,
SplitKV,
Combine(SplitKVHDimValues,
Bool(), // layouts of k and v are controlled by i_perm
IsVRowmajorValues, // layout of v is controlled by is_v_rowmajor
Values(std::tuple{mode_enum::batch, false},
std::tuple{mode_enum::batch, true},
std::tuple{mode_enum::group, false}),
Values(3, 4),
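// {batch, nhead, nhead_k, seqlen_q, seqlen_k, mask}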
Values(std::tuple{4, 3, 1, 200, 1024, "0"},
std::tuple{2, 2, -1, 512, 2000, "0"},
std::tuple{3, 2, -1, 230, 899, "t:128,128"})));
TEST_P(SplitKV, Test)
{
auto [hdims, i_perm, is_v_rowmajor, mode_use_cache_batch_idx, num_splits, dims_mask] =
GetParam();
auto [hdim_q, hdim_v] = hdims;
auto [mode, use_cache_batch_idx] = mode_use_cache_batch_idx;
auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask;
auto result = fmha_fwd_run<DataTypeConfig>(mode,
batch,
nhead,
nhead_k,
{adjust_seqlen(seqlen_q)},
{adjust_seqlen(seqlen_k)},
hdim_q,
hdim_v,
0, // seqlen_knew
{-1}, // seqlen_qpads
{-1}, // seqlen_kpads
{}, // q_eff_lens_per_batch
{}, // kv_eff_lens_per_batch
0, // rotary_dim
i_perm, // i_perm
false, // o_perm
0, // scale_s
0, // logits_soft_cap
is_v_rowmajor, // is_v_rowmajor
def_lse, // lse
0, // page_block_size
use_cache_batch_idx, // use_cache_batch_idx
"n", // bias_str
0.0f, // p_drop
0, // drop_seed
0, // drop_offset
false, // drop_prefs
mask_str, // mask_str
squant,
true, // is_rotary_interleaved
num_splits, // num_splits
COMMON_ARGS);
CHECK_RESULT(result);
}
#endif // CK_TILE_FMHA_FWD_SPLITKV_API
#if CK_TILE_FMHA_FWD_APPENDKV_API
class AppendKV : public TestWithParam<std::tuple<std::tuple<int, int>,
bool,
bool,
std::tuple<int, bool>,
int,
std::tuple<int, int, int, int, int, std::string>>>
{
};
INSTANTIATE_TEST_SUITE_P(
TestCkTileFmhaFwd,
AppendKV,
Combine(AppendKVHDimValues,
Bool(), // layouts of k and v are controlled by i_perm
IsVRowmajorValues, // layout of v is controlled by is_v_rowmajor
ValuesIn({std::tuple{0, true}, std::tuple{0, false}, std::tuple{128, false}}),
Values(1, 64, -1),
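// {batch, nhead, nhead_k, seqlen_q, seqlen_k, mask}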
Values(std::tuple{3, 3, -1, 60, 129, "t:32,32"},
std::tuple{3, 2, 2, 256, 256, "0"},
std::tuple{2, 3, 1, 264, 265, "1"},
std::tuple{4, 4, 2, 71, 64, "1"})));
GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(AppendKV);
TEST_P(AppendKV, Test)
{
auto [hdims,
i_perm,
is_v_rowmajor,
page_block_size_use_cache_batch_idx,
seqlen_knew,
dims_mask] = GetParam();
auto [hdim_q, hdim_v] = hdims;
auto [page_block_size, use_cache_batch_idx] = page_block_size_use_cache_batch_idx;
auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask;
seqlen_knew = seqlen_knew == -1 ? seqlen_k : seqlen_knew;
auto result = fmha_fwd_run<DataTypeConfig>(mode_enum::batch,
batch,
nhead,
nhead_k,
{adjust_seqlen(seqlen_q)},
{adjust_seqlen(seqlen_k)},
hdim_q,
hdim_v,
seqlen_knew, // seqlen_knew
{-1}, // seqlen_qpads
{-1}, // seqlen_kpads
{}, // q_eff_lens_per_batch
{}, // kv_eff_lens_per_batch
0, // rotary_dim
i_perm, // i_perm
true, // o_perm
0, // scale_s
0, // logits_soft_cap
is_v_rowmajor, // is_v_rowmajor
def_lse, // lse
page_block_size, // page_block_size
use_cache_batch_idx, // use_cache_batch_idx
"n", // bias_str
0.0f, // p_drop
0, // drop_seed
0, // drop_offset
false, // drop_prefs
mask_str, // mask_str
squant,
false, // is_rotary_interleaved
1, // num_splits
COMMON_ARGS);
CHECK_RESULT(result);
}
class AppendKVRoPE
: public TestWithParam<std::tuple<bool,
std::tuple<int, int>,
bool,
bool,
std::tuple<int, bool>,
int,
std::tuple<int, int, int, int, int, std::string>>>
{
};
GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(AppendKVRoPE);
INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaFwd,
AppendKVRoPE,
Combine(EnableTestIf(!std::is_same_v<DataTypeConfig, FmhaFwdFp8>),
AppendKVHDimValues,
Bool(), // layouts of k and v are controlled by i_perm
IsVRowmajorValues, // layout of v is controlled by is_v_rowmajor
Values(std::tuple{0, false},
std::tuple{16, true},
std::tuple{32, false},
std::tuple{-1, true}),
Values(16, 50, -1),
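// {batch, nhead, nhead_k, seqlen_q, seqlen_k, mask}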
Values(std::tuple{2, 3, -1, 60, 129, "t:32,32"},
std::tuple{1, 2, 1, 128, 55, "0"},
std::tuple{3, 4, 2, 72, 128, "1"})));
TEST_P(AppendKVRoPE, Test)
{
auto [_, hdims, i_perm, is_v_rowmajor, rotary, seqlen_knew, dims_mask] = GetParam();
auto [hdim_q, hdim_v] = hdims;
auto [rotary_dim, is_rotary_interleaved] = rotary;
auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask;
rotary_dim = rotary_dim == -1 ? hdim_q : rotary_dim;
seqlen_knew = seqlen_knew == -1 ? seqlen_k : seqlen_knew;
auto result = fmha_fwd_run<DataTypeConfig>(mode_enum::batch,
batch,
nhead,
nhead_k,
{adjust_seqlen(seqlen_q)},
{adjust_seqlen(seqlen_k)},
hdim_q,
hdim_v,
seqlen_knew, // seqlen_knew
{-1}, // seqlen_qpads
{-1}, // seqlen_kpads
{}, // q_eff_lens_per_batch
{}, // kv_eff_lens_per_batch
rotary_dim, // rotary_dim
i_perm, // i_perm
true, // o_perm
0, // scale_s
0, // logits_soft_cap
is_v_rowmajor, // is_v_rowmajor
true, // lse
0, // page_block_size
false, // use_cache_batch_idx
"n", // bias_str
0.0f, // p_drop
0, // drop_seed
0, // drop_offset
false, // drop_prefs
mask_str, // mask_str
squant,
is_rotary_interleaved, // is_rotary_interleaved
1, // num_splits
COMMON_ARGS);
CHECK_RESULT(result);
}
#endif // CK_TILE_FMHA_FWD_APPENDKV_API
// ---------------------------------------------------------------
// Parameterized padding tests (batch & group) using Combine+Values
// ---------------------------------------------------------------
// Headers for containers / algorithms used by the padding param builder;
// they must precede the first use of std::vector below.
#include <vector>
#include <array>
#include <cmath>
#include <algorithm>
using PaddingParam = std::tuple<mode_enum, // mode
int, // batch
int, // nhead
int, // nhead_k
std::vector<int>, // seqlen_qs (logical)
std::vector<int>, // seqlen_ks (logical)
std::vector<int>, // seqlen_qpads (physical padded lengths)
std::vector<int>, // seqlen_kpads (physical padded lengths)
std::vector<int>, // q_eff_lens
std::vector<int>, // kv_eff_lens
bool, // i_perm
bool, // o_perm
std::string>; // mask_str
class PaddingCases : public TestWithParam<PaddingParam>
{
};
GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(PaddingCases);
// Build padding test params programmatically to enforce constraints
static std::vector<PaddingParam> BuildPaddingParams()
{
std::vector<PaddingParam> params;
// mask variants to cover
const std::vector<std::string> mask_variants{"0", "t:50,64", "b:32,40"};
const std::vector<std::string> mask_variants_reduced{"0", "t:50,64"}; // used for trimmed sets
// Representative ratio pairs (q_ratio, k_ratio) to avoid explosion
const std::vector<std::pair<double, double>> ratio_pairs_full{
{1.0, 1.0}, // both full
{1.0, 0.5}, // q full, k half
{0.5, 1.0}, // q half, k full
};
const std::vector<std::pair<double, double>> ratio_pairs_reduced{{1.0, 1.0}, {0.5, 1.0}};
// candidate physical seqlens for batch mode (single value) & for group mode (per batch)
const std::vector<int> physical_lengths_full{64, 128, 256};
const std::vector<int> physical_lengths_reduced{64};
// batch sizes to sample
const std::vector<int> batch_sizes{1, 4};
// --------------------------------------------------------------------
// Head configuration space (cover MHA, GQA, MQA)
// - Standard MHA: nhead_k == -1 (treated internally as nhead)
// - GQA: nhead_k > 0 and nhead % nhead_k == 0, nhead_k < nhead
// - MQA: nhead_k == 1
// We choose (9, -1), (9, 3), (9, 1) so that divisibility holds. Full
// combinatorics only applied to the first (standard) configuration to
// avoid test explosion.
// --------------------------------------------------------------------
struct HeadCfg
{
int nhead;
int nhead_k; // -1 for standard; else must divide nhead
bool full; // whether to use full coverage sets
};
const std::vector<HeadCfg> head_cfgs = {
{9, -1, true}, // MHA full
{9, 3, false}, // GQA reduced (nhead/nhead_k=3)
{9, 1, false} // MQA reduced
};
// Helper to clamp and ensure >=1
auto logical_len = [](int physical, double ratio) {
int v = static_cast<int>(std::round(physical * ratio));
v = std::max(1, std::min(v, physical));
return v;
};
// Iterate over head configurations
for(const auto& hc : head_cfgs)
{
const auto& ratio_pairs = hc.full ? ratio_pairs_full : ratio_pairs_reduced;
const auto& phys_lengths_batch = hc.full ? physical_lengths_full : physical_lengths_reduced;
const auto& phys_lengths_group_q = phys_lengths_batch; // reuse
const auto& phys_lengths_group_k = phys_lengths_batch; // reuse
const auto& masks = hc.full ? mask_variants : mask_variants_reduced;
// -----------------
// Batch mode params (effective lengths only)
// -----------------
for(int b : batch_sizes)
{
for(int phys_qkv : phys_lengths_batch)
{
for(const auto& rkpair : ratio_pairs)
{
double rq = rkpair.first;
double rk = rkpair.second;
std::vector<int> q_eff(b), kv_eff(b);
int log_q = logical_len(phys_qkv, rq);
int log_k = logical_len(phys_qkv, rk);
for(int i = 0; i < b; ++i)
{
q_eff[i] = log_q;
kv_eff[i] = log_k;
}
for(const auto& mask : masks)
{
params.emplace_back(PaddingParam{mode_enum::batch,
b,
hc.nhead,
hc.nhead_k,
{phys_qkv}, // seqlen_qs
{phys_qkv}, // seqlen_ks
{}, // seqlen_qpads
{}, // seqlen_kpads
q_eff,
kv_eff,
true,
true,
mask});
}
}
// Single-token logical length case (both q & k = 1)
for(const auto& mask : masks)
{
std::vector<int> q_eff(b, 1), kv_eff(b, 1);
params.emplace_back(PaddingParam{mode_enum::batch,
b,
hc.nhead,
hc.nhead_k,
{phys_qkv},
{phys_qkv},
{},
{},
q_eff,
kv_eff,
true,
true,
mask});
}
}
}
// -----------------
// Group mode params (physical padding + logical variants)
// -----------------
for(int b : batch_sizes)
{
for(int phys_q : phys_lengths_group_q)
{
for(int phys_k : phys_lengths_group_k)
{
for(const auto& rkpair : ratio_pairs)
{
double rq = rkpair.first;
double rk = rkpair.second;
std::vector<int> seqlen_qs(b), seqlen_ks(b), seqlen_qpads(b),
seqlen_kpads(b);
for(int i = 0; i < b; ++i)
{
seqlen_qpads[i] = phys_q;
seqlen_kpads[i] = phys_k;
seqlen_qs[i] = logical_len(phys_q, rq);
seqlen_ks[i] = logical_len(phys_k, rk);
}
std::array<std::pair<std::vector<int>, std::vector<int>>, 3> pad_variants{
std::pair{seqlen_qpads, seqlen_kpads}, // both
std::pair{seqlen_qpads, seqlen_ks}, // only q padding
std::pair{seqlen_qs, seqlen_kpads} // only kv padding
};
for(const auto& mask : masks)
{
for(const auto& pv : pad_variants)
{
params.emplace_back(PaddingParam{mode_enum::group,
b,
hc.nhead,
hc.nhead_k,
seqlen_qs,
seqlen_ks,
pv.first,
pv.second,
{},
{},
true,
true,
mask});
}
}
}
// Single-token logical length case
for(const auto& mask : masks)
{
std::vector<int> seqlen_qs(b, 1), seqlen_ks(b, 1);
std::vector<int> seqlen_qpads(b, phys_q), seqlen_kpads(b, phys_k);
// both padding variant only (others degenerate)
params.emplace_back(PaddingParam{mode_enum::group,
b,
hc.nhead,
hc.nhead_k,
seqlen_qs,
seqlen_ks,
seqlen_qpads,
seqlen_kpads,
{},
{},
true,
true,
mask});
}
}
}
}
}
return params;
}
static const std::vector<PaddingParam> kPaddingParams = BuildPaddingParams();
INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaFwd_Padding, PaddingCases, ValuesIn(kPaddingParams));
TEST_P(PaddingCases, Test)
{
if constexpr(std::is_same_v<DataTypeConfig, FmhaFwdFp8>)
{
GTEST_SKIP() << "Skip for fp8";
}
auto [mode,
batch,
nhead,
nhead_k,
seqlen_qs,
seqlen_ks,
seqlen_qpads,
seqlen_kpads,
q_eff_lens,
kv_eff_lens,
i_perm,
o_perm,
mask_str] = GetParam();
// For batch mode we wrap single logical lengths with adjust_seqlen.
std::vector<int> adj_qs =
(mode == mode_enum::batch) ? std::vector<int>{adjust_seqlen(seqlen_qs.at(0))} : seqlen_qs;
std::vector<int> adj_ks =
(mode == mode_enum::batch) ? std::vector<int>{adjust_seqlen(seqlen_ks.at(0))} : seqlen_ks;
const int hdim_q = 64;
const int hdim_v = 64;
const int seqlen_knew = 0;
auto result = fmha_fwd_run<DataTypeConfig>(mode,
batch,
nhead,
nhead_k,
adj_qs,
adj_ks,
hdim_q,
hdim_v,
seqlen_knew, // seqlen_knew
seqlen_qpads, // seqlen_qpads
seqlen_kpads, // seqlen_kpads
q_eff_lens, // q_eff_lens_per_batch
kv_eff_lens, // kv_eff_lens_per_batch
0, // rotary_dim
i_perm, // i_perm
o_perm, // o_perm
0, // scale_s
0, // logits_soft_cap
def_is_v_rowmajor,
def_lse, // lse
0, // page_block_size
false, // use_cache_batch_idx
"n", // bias_str
0.0f, // p_drop
0, // drop_seed
0, // drop_offset
false, // drop_prefs
mask_str, // mask_str
squant,
true, // is_rotary_interleaved
1, // num_splits
COMMON_ARGS);
CHECK_RESULT(result);
}