mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[CK_TILE][FMHA] Support microscaling (mxfp8 and mxfp4) on gfx950 (#4368) ## Motivation Microscaling types (mxfp8 and mxfp4) for fwd qr pipeline ## Technical Details The microscaling is used when quant scale mode is `BlockAttentionQuantScaleEnum::MX` and `Q/K/P/VDataType` are fp8/bf8/fp4. Supported features: * only "qr" pipeline is implemented * hdim 128 and 256 (smaller hdim are not possible due to restrictions of "qr" pipeline, but they can be computed using instances with padding) * both 32x32x64 and 16x16x128 scale MFMAs are supported * Q and K scales are applied in hdim, V scales - in seqlen dimension * column-major V only * batch and group mode * bias, Alibi (tested but no instances by default, just like fp8) * masking etc. Aiter PR with new API args: https://github.com/ROCm/aiter/pull/2008 ## Test Plan ``` ninja test_ck_tile_fmha_fwd_mxfp8 && bin/test_ck_tile_fmha_fwd_mxfp8 ninja test_ck_tile_fmha_fwd_mxfp4 && bin/test_ck_tile_fmha_fwd_mxfp4 ``` ## Test Result The tests must pass. ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
1212 lines
59 KiB
C++
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
// SPDX-License-Identifier: MIT
|
|
|
|
#include <algorithm>
|
|
#include <array>
|
|
#include <cmath>
|
|
#include <vector>
|
|
|
|
#include "example/ck_tile/01_fmha/fmha_fwd.hpp"
|
|
#include "example/ck_tile/01_fmha/fmha_fwd_runner.hpp"
|
|
|
|
#include "gtest/gtest.h"
|
|
|
|
#ifndef DataTypeConfig
|
|
#define DataTypeConfig FmhaFwdFp16 // or FmhaFwdBf16 / FmhaFwdFp8Bf16 / FmhaFwdFp32
|
|
#endif
|
|
|
|
using ::testing::Bool;
|
|
using ::testing::Combine;
|
|
using ::testing::TestWithParam;
|
|
using ::testing::Values;
|
|
using ::testing::ValuesIn;
|
|
|
|
// Per-data-type test configuration. This primary template covers the default
// data types (fp16/bf16). Specializations below narrow the parameter space to
// what each data type's kernel instances actually support.
template <typename T>
struct TestConfigs
{
    // (hdim_q, hdim_v) pairs exercised by the generic forward tests;
    // -1 for hdim_v means "same as hdim_q".
    static constexpr auto HDimValues = std::array{
        std::tuple{32, -1},
        std::tuple{64, -1},
        std::tuple{96, 128},
        std::tuple{128, -1},
        std::tuple{192, 128},
        std::tuple{192, -1},
        std::tuple{256, -1},
    };
    // hdim pairs for the split-KV test suites
    static constexpr auto SplitKVHDimValues = std::array{
        std::tuple{32, -1},
        std::tuple{64, -1},
        std::tuple{96, -1},
        std::tuple{128, -1},
        std::tuple{256, -1},
    };
    // hdim pairs for the append-KV test suites
    static constexpr auto AppendKVHDimValues = std::array{
        std::tuple{32, -1}, std::tuple{64, -1}, std::tuple{128, -1}, std::tuple{256, -1}};
    // execution modes covered by the parameterized suites
    static constexpr auto ModeValues = std::array{mode_enum::batch, mode_enum::group};
    // V-layout variants to test (row-major only by default)
    static constexpr auto IsVRowmajorValues = std::array{true};
    // quant-scale mode string forwarded to fmha_fwd_run ("n" = no quant scale)
    static constexpr auto qscale_str = "n";
    // default "store LSE" flag used by suites that don't vary it
    static constexpr bool def_lse = true;
    // default V layout used by suites that don't vary it
    static constexpr bool def_is_v_rowmajor = true;
    // tensor initialization method string forwarded to fmha_fwd_run
    static constexpr auto init_method = "uf";
    // Hook for data types that need sequence lengths rounded; identity here.
    static int adjust_seqlen(int seqlen) { return seqlen; }
};
|
|
|
|
// fp8-in / bf16-out configuration: fewer supported hdims, per-tensor quant
// scale, and no LSE by default.
template <>
struct TestConfigs<FmhaFwdFp8Bf16>
{
    static constexpr auto HDimValues =
        std::array{std::tuple{64, -1}, std::tuple{128, -1}, std::tuple{256, -1}};
    static constexpr auto SplitKVHDimValues = std::array{std::tuple{64, -1}, std::tuple{128, -1}};
    static constexpr auto AppendKVHDimValues = std::array{std::tuple{64, -1}, std::tuple{128, -1}};
    static constexpr auto ModeValues = std::array{mode_enum::batch, mode_enum::group};
    static constexpr auto IsVRowmajorValues = std::array{true};
    static constexpr auto qscale_str = "pt"; // per-tensor quant scale
    static constexpr bool def_lse = false;
    static constexpr bool def_is_v_rowmajor = true;
    static constexpr auto init_method = "3";
    // When there are no fp8 instances with padding, pad seqlen to avoid skipping most of the tests:
    // return ck_tile::integer_least_multiple(seqlen, 128);
    static int adjust_seqlen(int seqlen) { return seqlen; }
};
|
|
|
|
// Microscaling fp8 (mxfp8) configuration: only hdim 128/256, column-major V
// only, "mx" quant-scale mode. Split-KV and append-KV value sets are empty,
// which makes those parameterized suites instantiate zero cases.
template <>
struct TestConfigs<FmhaFwdMxFp8>
{
    static constexpr auto HDimValues = std::array{std::tuple{128, -1}, std::tuple{256, -1}};
    static constexpr auto SplitKVHDimValues = std::array<std::tuple<int, int>, 0>{};
    static constexpr auto AppendKVHDimValues = std::array<std::tuple<int, int>, 0>{};
    static constexpr auto ModeValues = std::array{mode_enum::batch, mode_enum::group};
    static constexpr auto IsVRowmajorValues = std::array{false}; // column-major V only
    static constexpr auto qscale_str = "mx"; // microscaling quant-scale mode
    static constexpr bool def_lse = true;
    static constexpr bool def_is_v_rowmajor = false;
    static constexpr auto init_method = "3";
    static int adjust_seqlen(int seqlen) { return seqlen; }
};
|
|
|
|
// Microscaling fp4 (mxfp4) configuration: same restrictions as mxfp8, plus
// sequence lengths are rounded up to an even number — presumably because fp4
// elements are packed two per byte (TODO confirm against the kernel's
// packing requirements).
template <>
struct TestConfigs<FmhaFwdMxFp4>
{
    static constexpr auto HDimValues = std::array{std::tuple{128, -1}, std::tuple{256, -1}};
    static constexpr auto SplitKVHDimValues = std::array<std::tuple<int, int>, 0>{};
    static constexpr auto AppendKVHDimValues = std::array<std::tuple<int, int>, 0>{};
    static constexpr auto ModeValues = std::array{mode_enum::batch, mode_enum::group};
    static constexpr auto IsVRowmajorValues = std::array{false}; // column-major V only
    static constexpr auto qscale_str = "mx"; // microscaling quant-scale mode
    static constexpr bool def_lse = true;
    static constexpr bool def_is_v_rowmajor = false;
    static constexpr auto init_method = "3";
    // Round positive seqlens up to a multiple of 2; negative values are
    // sentinels and pass through unchanged.
    static int adjust_seqlen(int seqlen)
    {
        return seqlen < 0 ? seqlen : ck_tile::integer_least_multiple(seqlen, 2);
    }
};
|
|
|
|
// fp32 configuration: broad hdim coverage (including 48) but no split-KV or
// append-KV cases (empty value sets skip those suites).
template <>
struct TestConfigs<FmhaFwdFp32>
{
    static constexpr auto HDimValues = std::array{
        std::tuple{32, -1},
        std::tuple{48, -1},
        std::tuple{64, -1},
        std::tuple{96, 128},
        std::tuple{128, -1},
        std::tuple{192, -1},
        std::tuple{256, -1},
    };
    static constexpr auto SplitKVHDimValues = std::array<std::tuple<int, int>, 0>{};
    static constexpr auto AppendKVHDimValues = std::array<std::tuple<int, int>, 0>{};
    static constexpr auto ModeValues = std::array{mode_enum::batch, mode_enum::group};
    static constexpr auto IsVRowmajorValues = std::array{true};
    static constexpr auto qscale_str = "n"; // no quant scale for fp32
    static constexpr bool def_lse = true;
    static constexpr bool def_is_v_rowmajor = true;
    static constexpr auto init_method = "uf";
    static int adjust_seqlen(int seqlen) { return seqlen; }
};
|
|
|
|
// Flatten the selected TestConfigs<DataTypeConfig> into file-scope values
// consumed by the parameter generators and test bodies below. DataTypeConfig
// is chosen at compile time via the macro at the top of this file.
static auto HDimValues = ValuesIn(TestConfigs<DataTypeConfig>::HDimValues);
static auto SplitKVHDimValues = ValuesIn(TestConfigs<DataTypeConfig>::SplitKVHDimValues);
static auto AppendKVHDimValues = ValuesIn(TestConfigs<DataTypeConfig>::AppendKVHDimValues);
static auto ModeValues = ValuesIn(TestConfigs<DataTypeConfig>::ModeValues);
static auto IsVRowmajorValues = ValuesIn(TestConfigs<DataTypeConfig>::IsVRowmajorValues);
constexpr static auto qscale_str = TestConfigs<DataTypeConfig>::qscale_str;
constexpr bool def_lse = TestConfigs<DataTypeConfig>::def_lse;
constexpr bool def_is_v_rowmajor = TestConfigs<DataTypeConfig>::def_is_v_rowmajor;
constexpr auto init_method = TestConfigs<DataTypeConfig>::init_method;
// Convenience forwarder so test bodies don't spell out the config type.
int adjust_seqlen(int seqlen) { return TestConfigs<DataTypeConfig>::adjust_seqlen(seqlen); }
|
|
|
|
// Random seed used for initializing input tensors. 0 for non-deterministic seed
CK_TILE_DECLARE_ENV_VAR(CK_TILE_TEST_SEED, uint64_t, 123456)

// Whether to run long tests (from smoke_test_fwd.sh)
CK_TILE_DECLARE_ENV_VAR_BOOL(CK_TILE_FMHA_LONG_TESTS)

// Common result check: a missing kernel instance for the current parameter
// combination is an expected condition and skips the test; anything other
// than success fails it.
#define CHECK_RESULT(result)                                      \
    do                                                            \
    {                                                             \
        if(result == fwd_result::no_instance)                     \
            GTEST_SKIP() << "No instance for current parameters"; \
        ASSERT_EQ(result, fwd_result::success);                   \
    } while(0)
|
|
|
|
// Stream configuration shared by all tests: default stream, no kernel timing,
// a single repetition, and no cache flushing/rotation.
const ck_tile::stream_config stream_config{
    nullptr, // stream_id_
    false,   // time_kernel_
    1,       // log_level_
    0,       // cold_niters_
    1,       // nrepeat_
    true,    // is_gpu_timer_
    false,   // flush_cache_
    1,       // rotating_count_
};

// Trailing arguments shared by most fmha_fwd_run calls: init method, RNG seed
// from the environment, then two positional ints and the stream config.
// NOTE(review): the negative tests below pass (0, 1 /*init_sink*/) for the two
// ints where this macro passes (1, 0) — confirm the intended parameter order
// against fmha_fwd_run's signature in fmha_fwd_runner.hpp.
#define COMMON_ARGS                                                                              \
    init_method, static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))), 1, 0, \
        stream_config
|
|
|
|
auto EnableTestIf(bool condition)
|
|
{
|
|
return ValuesIn(condition ? std::vector<bool>{true} : std::vector<bool>{});
|
|
}
|
|
|
|
// Long-running smoke tests, only instantiated when the CK_TILE_FMHA_LONG_TESTS
// env var is enabled. Param tuple:
//   (enabled, (hdim_q, hdim_v), perm, is_v_rowmajor, mode, lse, bias_str,
//    p_drop, (batch, nhead, nhead_k, hdim_q_override, hdim_v_override,
//             seqlen_q, seqlen_k, seqlen_kpad, mask_str))
class AllLong : public TestWithParam<
                    std::tuple<bool,
                               std::tuple<int, int>,
                               bool,
                               bool,
                               mode_enum,
                               bool,
                               std::string,
                               float,
                               std::tuple<int, int, int, int, int, int, int, int, std::string>>>
{
};

GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(AllLong);

// Test cases from example/ck_tile/01_fmha/script/smoke_test_fwd.sh

INSTANTIATE_TEST_SUITE_P(
    TestCkTileFmhaFwd,
    AllLong,
    Combine(EnableTestIf(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_FMHA_LONG_TESTS))),
            HDimValues,
            Bool(),
            IsVRowmajorValues,
            ModeValues,
            Bool(),
            Values("n", "e", "a"),
            Values(0.0f, 0.2f),
            Values(std::tuple{2, 2, 1, 16, -1, 55, 256, -1, "0"},
                   std::tuple{1, 3, -1, -1, -1, 100, 51, -1, "0"},
                   std::tuple{2, 1, -1, 16, -1, 99, 256, -1, "1"},
                   std::tuple{1, 2, 1, -1, -1, 1024, 256, -1, "2"},
                   std::tuple{2, 1, -1, -1, 24, 3, 99, -1, "2"},
                   std::tuple{3, 2, 1, -1, -1, 200, 520, -1, "t:128,30"},
                   std::tuple{2, 1, -1, -1, -1, 99, 32, -1, "b:4,35"},
                   std::tuple{1, 2, 1, -1, -1, 33, 0, -1, "2"},
                   std::tuple{1, 2, 1, -1, -1, 1, 10, 32, "2"})));

TEST_P(AllLong, DataTypeConfig)
{
    auto [_, hdims, perm, is_v_rowmajor, mode, lse, bias_str, p_drop, dims_mask] = GetParam();
    auto [hdim_q, hdim_v] = hdims;
    auto [batch, nhead, nhead_k, hdim_q_, hdim_v_, seqlen_q, seqlen_k, seqlen_kpad, mask_str] =
        dims_mask;

    // Per-case hdim overrides (-1 = keep the suite-wide hdim pair).
    hdim_q = hdim_q_ == -1 ? hdim_q : hdim_q_;
    hdim_v = hdim_v_ == -1 ? hdim_v : hdim_v_;

    auto result = fmha_fwd_run<DataTypeConfig>(mode,
                                               batch,
                                               nhead,
                                               nhead_k,
                                               {adjust_seqlen(seqlen_q)},
                                               {adjust_seqlen(seqlen_k)},
                                               hdim_q,
                                               hdim_v,
                                               0,             // seqlen_knew
                                               {-1},          // seqlen_qpads
                                               {seqlen_kpad}, // seqlen_kpads
                                               {},            // q_eff_lens_per_batch
                                               {},            // kv_eff_lens_per_batch
                                               0,             // rotary_dim
                                               perm,          // i_perm
                                               perm,          // o_perm
                                               0,             // scale_s
                                               0,             // logits_soft_cap
                                               is_v_rowmajor, // is_v_rowmajor
                                               lse,           // lse
                                               0,             // page_block_size
                                               false,         // use_cache_batch_idx
                                               bias_str,      // bias_str
                                               p_drop,        // p_drop
                                               123,           // drop_seed
                                               1024,          // drop_offset
                                               false,         // drop_prefs
                                               mask_str,      // mask_str
                                               qscale_str,
                                               true, // is_rotary_interleaved
                                               1,    // num_splits
                                               COMMON_ARGS);
    CHECK_RESULT(result);
}
|
|
|
|
// ---------------------------------------------------------------
|
|
// Negative tests: padding not supported with appendkv/splitkv/pagedkv
|
|
// ---------------------------------------------------------------
|
|
|
|
#if CK_TILE_FMHA_FWD_APPENDKV_API
|
|
TEST(TestCkTileFmhaFwd, AppendKvWithBatchEffLensShouldFail)
|
|
{
|
|
// batch mode effective lengths simulate padding
|
|
auto result = fmha_fwd_run<DataTypeConfig>(
|
|
mode_enum::batch,
|
|
2, // batch
|
|
4, // nhead
|
|
-1, // nhead_k
|
|
{128}, // seqlen_qs
|
|
{128}, // seqlen_ks
|
|
64, // hdim_q
|
|
64, // hdim_v
|
|
32, // seqlen_knew -> triggers appendkv
|
|
{}, // seqlen_qpads
|
|
{}, // seqlen_kpads
|
|
{100, 120}, // q_eff_lens_per_batch
|
|
{90, 110}, // kv_eff_lens_per_batch
|
|
0, // rotary_dim
|
|
true, // i_perm
|
|
true, // o_perm
|
|
0, // scale_s
|
|
0, // logits_soft_cap
|
|
def_is_v_rowmajor,
|
|
def_lse,
|
|
0, // page_block_size
|
|
false, // use_cache_batch_idx
|
|
"n", // bias
|
|
0.0f, // p_drop
|
|
0, // drop_seed
|
|
0, // drop_offset
|
|
false, // drop_prefs
|
|
"0", // mask
|
|
qscale_str,
|
|
true, // is_rotary_interleaved
|
|
1, // num_splits
|
|
init_method,
|
|
static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
|
|
0,
|
|
1, // init_sink
|
|
stream_config);
|
|
ASSERT_EQ(result, fwd_result::invalid_args);
|
|
}
|
|
#endif
|
|
|
|
#if CK_TILE_FMHA_FWD_SPLITKV_API
|
|
TEST(TestCkTileFmhaFwd, SplitKvWithGroupPaddingShouldFail)
|
|
{
|
|
// group mode physical padding
|
|
auto result = fmha_fwd_run<DataTypeConfig>(
|
|
mode_enum::group,
|
|
2, // batch
|
|
4, // nhead
|
|
-1, // nhead_k
|
|
{96, 120}, // seqlen_qs logical
|
|
{96, 120}, // seqlen_ks logical
|
|
64, // hdim_q
|
|
64, // hdim_v
|
|
0, // seqlen_knew
|
|
{128, 128}, // seqlen_qpads
|
|
{128, 128}, // seqlen_kpads
|
|
{}, // q_eff
|
|
{}, // kv_eff
|
|
0, // rotary_dim
|
|
true, // i_perm
|
|
true, // o_perm
|
|
0, // scale_s
|
|
0, // logits_soft_cap
|
|
def_is_v_rowmajor,
|
|
def_lse,
|
|
0, // page_block_size
|
|
false, // use_cache_batch_idx
|
|
"n", // bias
|
|
0.0f,
|
|
0,
|
|
0,
|
|
false,
|
|
"0",
|
|
qscale_str,
|
|
true,
|
|
2, // num_splits (>1 triggers splitkv)
|
|
init_method,
|
|
static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
|
|
0,
|
|
1, // init_sink
|
|
stream_config);
|
|
ASSERT_EQ(result, fwd_result::invalid_args);
|
|
}
|
|
#endif
|
|
|
|
#if CK_TILE_FMHA_FWD_PAGEDKV_API
|
|
TEST(TestCkTileFmhaFwd, PagedKvWithGroupPaddingShouldFail)
|
|
{
|
|
auto result = fmha_fwd_run<DataTypeConfig>(
|
|
mode_enum::group,
|
|
2,
|
|
4,
|
|
-1,
|
|
{80, 100},
|
|
{80, 100},
|
|
64,
|
|
64,
|
|
0, // seqlen_knew
|
|
{96, 128}, // seqlen_qpads
|
|
{96, 128}, // seqlen_kpads
|
|
{},
|
|
{},
|
|
0,
|
|
true,
|
|
true,
|
|
0,
|
|
0,
|
|
def_is_v_rowmajor,
|
|
def_lse,
|
|
128, // page_block_size triggers pagedkv
|
|
false,
|
|
"n",
|
|
0.0f,
|
|
0,
|
|
0,
|
|
false,
|
|
"0",
|
|
qscale_str,
|
|
true,
|
|
1,
|
|
init_method,
|
|
static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
|
|
0,
|
|
1, // init_sink
|
|
stream_config);
|
|
ASSERT_EQ(result, fwd_result::invalid_args);
|
|
}
|
|
#endif
|
|
|
|
// Tests head dimensions that are not natively supported and must be handled
// by instances with hdim padding. Param tuple:
//   ((hdim_q, hdim_v), perm, is_v_rowmajor, mode,
//    (batch, nhead, nhead_k, seqlen_q, seqlen_k, seqlen_kpad, mask_str))
class HDimPadding
    : public TestWithParam<std::tuple<std::tuple<int, int>,
                                      bool,
                                      bool,
                                      mode_enum,
                                      std::tuple<int, int, int, int, int, int, std::string>>>
{
};

INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaFwd,
                         HDimPadding,
                         Combine(Values(std::tuple{24, 48},
                                        std::tuple{120, 160},
                                        std::tuple{256, 108},
                                        std::tuple{40, 64}),
                                 Bool(),
                                 IsVRowmajorValues,
                                 ModeValues,
                                 Values(std::tuple{1, 4, 2, 480, -1, -1, "0"},
                                        std::tuple{2, 2, -1, 300, 400, 512, "t:64,64"},
                                        std::tuple{1, 4, 1, 512, 201, 256, "1"},
                                        std::tuple{1, 2, -1, 900, 256, -1, "0"},
                                        std::tuple{2, 1, -1, 256, 256, -1, "1"})));

TEST_P(HDimPadding, DataTypeConfig)
{
    auto [hdims, perm, is_v_rowmajor, mode, dims_mask] = GetParam();
    auto [hdim_q, hdim_v] = hdims;
    auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, seqlen_kpad, mask_str] = dims_mask;

    auto result = fmha_fwd_run<DataTypeConfig>(mode,
                                               batch,
                                               nhead,
                                               nhead_k,
                                               {adjust_seqlen(seqlen_q)},
                                               {adjust_seqlen(seqlen_k)},
                                               hdim_q,
                                               hdim_v,
                                               0,             // seqlen_knew
                                               {-1},          // seqlen_qpads
                                               {seqlen_kpad}, // seqlen_kpads
                                               {},            // q_eff_lens_per_batch
                                               {},            // kv_eff_lens_per_batch
                                               0,             // rotary_dim
                                               perm,          // i_perm
                                               perm,          // o_perm
                                               0,             // scale_s
                                               0,             // logits_soft_cap
                                               is_v_rowmajor, // is_v_rowmajor
                                               def_lse,       // lse
                                               0,             // page_block_size
                                               false,         // use_cache_batch_idx
                                               "n",           // bias_str
                                               0.0f,          // p_drop
                                               0,             // drop_seed
                                               0,             // drop_offset
                                               false,         // drop_prefs
                                               mask_str,      // mask_str
                                               qscale_str,
                                               true, // is_rotary_interleaved
                                               1,    // num_splits
                                               COMMON_ARGS);
    CHECK_RESULT(result);
}
|
|
|
|
// Tests elementwise bias variants ("e:0".."e:2" select the bias layout).
// Param tuple: ((hdim_q, hdim_v), i_perm, mode, bias_str,
//               (batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str))
class ElementwiseBias
    : public TestWithParam<std::tuple<std::tuple<int, int>,
                                      bool,
                                      mode_enum,
                                      std::string,
                                      std::tuple<int, int, int, int, int, std::string>>>
{
};

INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaFwd,
                         ElementwiseBias,
                         Combine(HDimValues,
                                 Bool(), // layout of bias is controlled by i_perm
                                 ModeValues,
                                 Values("e:0", "e:1", "e:2"),
                                 Values(std::tuple{1, 4, 2, 1024, 100, "0"},
                                        std::tuple{3, 2, -1, 128, 256, "2"},
                                        std::tuple{2, 2, -1, 130, 499, "t:50,64"})));

TEST_P(ElementwiseBias, DataTypeConfig)
{
    auto [hdims, i_perm, mode, bias_str, dims_mask] = GetParam();
    auto [hdim_q, hdim_v] = hdims;
    auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask;

    auto result = fmha_fwd_run<DataTypeConfig>(mode,
                                               batch,
                                               nhead,
                                               nhead_k,
                                               {adjust_seqlen(seqlen_q)},
                                               {adjust_seqlen(seqlen_k)},
                                               hdim_q,
                                               hdim_v,
                                               0,    // seqlen_knew
                                               {-1}, // seqlen_qpads
                                               {-1}, // seqlen_kpads
                                               {},   // q_eff_lens_per_batch
                                               {},   // kv_eff_lens_per_batch
                                               0,    // rotary_dim
                                               i_perm,            // i_perm
                                               false,             // o_perm
                                               0,                 // scale_s
                                               0,                 // logits_soft_cap
                                               def_is_v_rowmajor, // is_v_rowmajor
                                               def_lse,           // lse
                                               0,                 // page_block_size
                                               false,             // use_cache_batch_idx
                                               bias_str,          // bias_str
                                               0.0f,              // p_drop
                                               0,                 // drop_seed
                                               0,                 // drop_offset
                                               false,             // drop_prefs
                                               mask_str,          // mask_str
                                               qscale_str,
                                               true, // is_rotary_interleaved
                                               1,    // num_splits
                                               COMMON_ARGS);
    CHECK_RESULT(result);
}
|
|
|
|
// Tests Alibi positional bias ("a:0"/"a:1" select the Alibi variant),
// crossed with several masking modes. Param tuple:
//   ((hdim_q, hdim_v), mode, bias_str,
//    (batch, nhead, nhead_k, seqlen_q, seqlen_k), mask_str)
class Alibi : public TestWithParam<std::tuple<std::tuple<int, int>,
                                              mode_enum,
                                              std::string,
                                              std::tuple<int, int, int, int, int>,
                                              std::string>>
{
};

INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaFwd,
                         Alibi,
                         Combine(HDimValues,
                                 ModeValues,
                                 Values("a:0", "a:1"),
                                 Values(std::tuple{1, 3, 3, 1024, 1000},
                                        std::tuple{3, 5, 5, 128, 256},
                                        std::tuple{2, 8, 2, 300, 355}),
                                 Values("0", "t", "b", "t:50,64", "b:32,40")));

TEST_P(Alibi, DataTypeConfig)
{
    auto [hdims, mode, bias_str, dims, mask_str] = GetParam();
    auto [hdim_q, hdim_v] = hdims;
    auto [batch, nhead, nhead_k, seqlen_q, seqlen_k] = dims;

    auto result = fmha_fwd_run<DataTypeConfig>(mode,
                                               batch,
                                               nhead,
                                               nhead_k,
                                               {adjust_seqlen(seqlen_q)},
                                               {adjust_seqlen(seqlen_k)},
                                               hdim_q,
                                               hdim_v,
                                               0,    // seqlen_knew
                                               {-1}, // seqlen_qpads
                                               {-1}, // seqlen_kpads
                                               {},   // q_eff_lens_per_batch
                                               {},   // kv_eff_lens_per_batch
                                               0,    // rotary_dim
                                               true, // i_perm
                                               true, // o_perm
                                               0,    // scale_s
                                               0,    // logits_soft_cap
                                               def_is_v_rowmajor, // is_v_rowmajor
                                               def_lse,           // lse
                                               0,                 // page_block_size
                                               false,             // use_cache_batch_idx
                                               bias_str,          // bias_str
                                               0.0f,              // p_drop
                                               0,                 // drop_seed
                                               0,                 // drop_offset
                                               false,             // drop_prefs
                                               mask_str,          // mask_str
                                               qscale_str,
                                               true, // is_rotary_interleaved
                                               1,    // num_splits
                                               COMMON_ARGS);
    CHECK_RESULT(result);
}
|
|
|
|
// Tests attention dropout with various drop probabilities and RNG seed/offset
// combinations (drop_prefs selects whether seed/offset are passed by pointer).
// Param tuple: ((hdim_q, hdim_v), mode, p_drop,
//               (drop_seed, drop_offset, drop_prefs),
//               (batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str))
class Dropout : public TestWithParam<std::tuple<std::tuple<int, int>,
                                                mode_enum,
                                                float,
                                                std::tuple<uint64_t, uint64_t, bool>,
                                                std::tuple<int, int, int, int, int, std::string>>>
{
};

INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaFwd,
                         Dropout,
                         Combine(HDimValues,
                                 ModeValues,
                                 Values(0.123f, 0.5f),
                                 Values(std::tuple{10, 123, false},
                                        std::tuple{34534564645, 7876878876864, true}),
                                 Values(std::tuple{2, 4, 2, 280, 512, "0"},
                                        std::tuple{3, 2, 2, 256, 128, "1"},
                                        std::tuple{4, 3, 1, 100, 768, "2"})));

TEST_P(Dropout, DataTypeConfig)
{
    auto [hdims, mode, p_drop, drop_seed_offset_prefs, dims_mask] = GetParam();
    auto [hdim_q, hdim_v] = hdims;
    auto [drop_seed, drop_offset, drop_prefs] = drop_seed_offset_prefs;
    auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask;

    auto result = fmha_fwd_run<DataTypeConfig>(mode,
                                               batch,
                                               nhead,
                                               nhead_k,
                                               {adjust_seqlen(seqlen_q)},
                                               {adjust_seqlen(seqlen_k)},
                                               hdim_q,
                                               hdim_v,
                                               0,     // seqlen_knew
                                               {-1},  // seqlen_qpads
                                               {-1},  // seqlen_kpads
                                               {},    // q_eff_lens_per_batch
                                               {},    // kv_eff_lens_per_batch
                                               0,     // rotary_dim
                                               false, // i_perm
                                               false, // o_perm
                                               0,     // scale_s
                                               0,     // logits_soft_cap
                                               def_is_v_rowmajor, // is_v_rowmajor
                                               def_lse,           // lse
                                               0,                 // page_block_size
                                               false,             // use_cache_batch_idx
                                               "n",               // bias_str
                                               p_drop,            // p_drop
                                               drop_seed,         // drop_seed
                                               drop_offset,       // drop_offset
                                               drop_prefs,        // drop_prefs
                                               mask_str,          // mask_str
                                               qscale_str,
                                               true, // is_rotary_interleaved
                                               1,    // num_splits
                                               COMMON_ARGS);
    CHECK_RESULT(result);
}
|
|
|
|
#if CK_TILE_FMHA_FWD_PAGEDKV_API
|
|
|
|
class PagedKV : public TestWithParam<std::tuple<std::tuple<int, int>,
|
|
bool,
|
|
bool,
|
|
mode_enum,
|
|
int,
|
|
std::tuple<int, int, int, int, int, std::string>>>
|
|
{
|
|
};
|
|
|
|
GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(PagedKV);
|
|
|
|
INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaFwd,
|
|
PagedKV,
|
|
Combine(SplitKVHDimValues,
|
|
Bool(), // layouts of k and v are controlled by i_perm
|
|
IsVRowmajorValues, // layout of v is controlled by is_v_rowmajor
|
|
ModeValues,
|
|
Values(128, 256),
|
|
Values(std::tuple{2, 3, 1, 200, 1024, "0"},
|
|
std::tuple{3, 2, -1, 128, 768, "2"},
|
|
std::tuple{2, 2, -1, 230, 899, "t:50,64"})));
|
|
|
|
TEST_P(PagedKV, DataTypeConfig)
|
|
{
|
|
auto [hdims, i_perm, is_v_rowmajor, mode, page_block_size, dims_mask] = GetParam();
|
|
auto [hdim_q, hdim_v] = hdims;
|
|
auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask;
|
|
|
|
auto result = fmha_fwd_run<DataTypeConfig>(mode,
|
|
batch,
|
|
nhead,
|
|
nhead_k,
|
|
{adjust_seqlen(seqlen_q)},
|
|
{adjust_seqlen(seqlen_k)},
|
|
hdim_q,
|
|
hdim_v,
|
|
0, // seqlen_knew
|
|
{-1}, // seqlen_qpads
|
|
{-1}, // seqlen_kpads
|
|
{}, // q_eff_lens_per_batch
|
|
{}, // kv_eff_lens_per_batch
|
|
0, // rotary_dim
|
|
i_perm, // i_perm
|
|
false, // o_perm
|
|
0, // scale_s
|
|
0, // logits_soft_cap
|
|
is_v_rowmajor, // is_v_rowmajor
|
|
false, // lse
|
|
page_block_size, // page_block_size
|
|
false, // use_cache_batch_idx
|
|
"n", // bias_str
|
|
0.0f, // p_drop
|
|
0, // drop_seed
|
|
0, // drop_offset
|
|
false, // drop_prefs
|
|
mask_str, // mask_str
|
|
qscale_str,
|
|
true, // is_rotary_interleaved
|
|
1, // num_splits
|
|
COMMON_ARGS);
|
|
CHECK_RESULT(result);
|
|
}
|
|
|
|
#endif // CK_TILE_FMHA_FWD_PAGEDKV_API
|
|
|
|
#if CK_TILE_FMHA_FWD_SPLITKV_API
|
|
|
|
class SplitKV : public TestWithParam<std::tuple<std::tuple<int, int>,
|
|
bool,
|
|
bool,
|
|
std::tuple<mode_enum, bool>,
|
|
int,
|
|
std::tuple<int, int, int, int, int, std::string>>>
|
|
{
|
|
};
|
|
|
|
GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(SplitKV);
|
|
|
|
INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaFwd,
|
|
SplitKV,
|
|
Combine(SplitKVHDimValues,
|
|
Bool(), // layouts of k and v are controlled by i_perm
|
|
IsVRowmajorValues, // layout of v is controlled by is_v_rowmajor
|
|
Values(std::tuple{mode_enum::batch, false},
|
|
std::tuple{mode_enum::batch, true},
|
|
std::tuple{mode_enum::group, false}),
|
|
Values(3, 4),
|
|
Values(std::tuple{4, 3, 1, 200, 1024, "0"},
|
|
std::tuple{2, 2, -1, 512, 2000, "0"},
|
|
std::tuple{3, 2, -1, 230, 899, "t:128,128"})));
|
|
|
|
TEST_P(SplitKV, DataTypeConfig)
|
|
{
|
|
auto [hdims, i_perm, is_v_rowmajor, mode_use_cache_batch_idx, num_splits, dims_mask] =
|
|
GetParam();
|
|
auto [hdim_q, hdim_v] = hdims;
|
|
auto [mode, use_cache_batch_idx] = mode_use_cache_batch_idx;
|
|
auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask;
|
|
|
|
auto result = fmha_fwd_run<DataTypeConfig>(mode,
|
|
batch,
|
|
nhead,
|
|
nhead_k,
|
|
{adjust_seqlen(seqlen_q)},
|
|
{adjust_seqlen(seqlen_k)},
|
|
hdim_q,
|
|
hdim_v,
|
|
0, // seqlen_knew
|
|
{-1}, // seqlen_qpads
|
|
{-1}, // seqlen_kpads
|
|
{}, // q_eff_lens_per_batch
|
|
{}, // kv_eff_lens_per_batch
|
|
0, // rotary_dim
|
|
i_perm, // i_perm
|
|
false, // o_perm
|
|
0, // scale_s
|
|
0, // logits_soft_cap
|
|
is_v_rowmajor, // is_v_rowmajor
|
|
def_lse, // lse
|
|
0, // page_block_size
|
|
use_cache_batch_idx, // use_cache_batch_idx
|
|
"n", // bias_str
|
|
0.0f, // p_drop
|
|
0, // drop_seed
|
|
0, // drop_offset
|
|
false, // drop_prefs
|
|
mask_str, // mask_str
|
|
qscale_str,
|
|
true, // is_rotary_interleaved
|
|
num_splits, // num_splits
|
|
COMMON_ARGS);
|
|
CHECK_RESULT(result);
|
|
}
|
|
|
|
#endif // CK_TILE_FMHA_FWD_SPLITKV_API
|
|
|
|
#if CK_TILE_FMHA_FWD_APPENDKV_API
|
|
|
|
// Tests the append-KV path (batch mode only). seqlen_knew == -1 is replaced
// with seqlen_k, i.e. "append the entire K/V". Param tuple:
//   ((hdim_q, hdim_v), i_perm, is_v_rowmajor,
//    (page_block_size, use_cache_batch_idx), seqlen_knew,
//    (batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str))
class AppendKV : public TestWithParam<std::tuple<std::tuple<int, int>,
                                                 bool,
                                                 bool,
                                                 std::tuple<int, bool>,
                                                 int,
                                                 std::tuple<int, int, int, int, int, std::string>>>
{
};

INSTANTIATE_TEST_SUITE_P(
    TestCkTileFmhaFwd,
    AppendKV,
    Combine(AppendKVHDimValues,
            Bool(),            // layouts of k and v are controlled by i_perm
            IsVRowmajorValues, // layout of v is controlled by is_v_rowmajor
            ValuesIn({std::tuple{0, true}, std::tuple{0, false}, std::tuple{128, false}}),
            Values(1, 64, -1),
            Values(std::tuple{3, 3, -1, 60, 129, "t:32,32"},
                   std::tuple{3, 2, 2, 256, 256, "0"},
                   std::tuple{2, 3, 1, 264, 265, "1"},
                   std::tuple{4, 4, 2, 71, 64, "1"})));

GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(AppendKV);

TEST_P(AppendKV, DataTypeConfig)
{
    auto [hdims,
          i_perm,
          is_v_rowmajor,
          page_block_size_use_cache_batch_idx,
          seqlen_knew,
          dims_mask] = GetParam();
    auto [hdim_q, hdim_v] = hdims;
    auto [page_block_size, use_cache_batch_idx] = page_block_size_use_cache_batch_idx;
    auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask;

    // -1 means "append seqlen_k new tokens"
    seqlen_knew = seqlen_knew == -1 ? seqlen_k : seqlen_knew;

    auto result = fmha_fwd_run<DataTypeConfig>(mode_enum::batch,
                                               batch,
                                               nhead,
                                               nhead_k,
                                               {adjust_seqlen(seqlen_q)},
                                               {adjust_seqlen(seqlen_k)},
                                               hdim_q,
                                               hdim_v,
                                               seqlen_knew, // seqlen_knew
                                               {-1},        // seqlen_qpads
                                               {-1},        // seqlen_kpads
                                               {},          // q_eff_lens_per_batch
                                               {},          // kv_eff_lens_per_batch
                                               0,           // rotary_dim
                                               i_perm,      // i_perm
                                               true,        // o_perm
                                               0,           // scale_s
                                               0,           // logits_soft_cap
                                               is_v_rowmajor,       // is_v_rowmajor
                                               def_lse,             // lse
                                               page_block_size,     // page_block_size
                                               use_cache_batch_idx, // use_cache_batch_idx
                                               "n",                 // bias_str
                                               0.0f,                // p_drop
                                               0,                   // drop_seed
                                               0,                   // drop_offset
                                               false,               // drop_prefs
                                               mask_str,            // mask_str
                                               qscale_str,
                                               false, // is_rotary_interleaved
                                               1,     // num_splits
                                               COMMON_ARGS);
    CHECK_RESULT(result);
}
|
|
|
|
// Tests append-KV combined with rotary position embedding (RoPE). Disabled
// for fp8 via the leading EnableTestIf axis. rotary_dim == -1 maps to hdim_q
// (full-width rotation); seqlen_knew == -1 maps to seqlen_k. Param tuple:
//   (enabled, (hdim_q, hdim_v), i_perm, is_v_rowmajor,
//    (rotary_dim, is_rotary_interleaved), seqlen_knew,
//    (batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str))
class AppendKVRoPE
    : public TestWithParam<std::tuple<bool,
                                      std::tuple<int, int>,
                                      bool,
                                      bool,
                                      std::tuple<int, bool>,
                                      int,
                                      std::tuple<int, int, int, int, int, std::string>>>
{
};

GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(AppendKVRoPE);

INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaFwd,
                         AppendKVRoPE,
                         Combine(EnableTestIf(!std::is_same_v<DataTypeConfig, FmhaFwdFp8Bf16>),
                                 AppendKVHDimValues,
                                 Bool(), // layouts of k and v are controlled by i_perm
                                 IsVRowmajorValues, // layout of v is controlled by is_v_rowmajor
                                 Values(std::tuple{0, false},
                                        std::tuple{16, true},
                                        std::tuple{32, false},
                                        std::tuple{-1, true}),
                                 Values(16, 50, -1),
                                 Values(std::tuple{2, 3, -1, 60, 129, "t:32,32"},
                                        std::tuple{1, 2, 1, 128, 55, "0"},
                                        std::tuple{3, 4, 2, 72, 128, "1"})));

TEST_P(AppendKVRoPE, DataTypeConfig)
{
    auto [_, hdims, i_perm, is_v_rowmajor, rotary, seqlen_knew, dims_mask] = GetParam();
    auto [hdim_q, hdim_v] = hdims;
    auto [rotary_dim, is_rotary_interleaved] = rotary;
    auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask;

    // -1 sentinels: rotate the full head dim / append seqlen_k new tokens
    rotary_dim = rotary_dim == -1 ? hdim_q : rotary_dim;
    seqlen_knew = seqlen_knew == -1 ? seqlen_k : seqlen_knew;

    auto result = fmha_fwd_run<DataTypeConfig>(mode_enum::batch,
                                               batch,
                                               nhead,
                                               nhead_k,
                                               {adjust_seqlen(seqlen_q)},
                                               {adjust_seqlen(seqlen_k)},
                                               hdim_q,
                                               hdim_v,
                                               seqlen_knew, // seqlen_knew
                                               {-1},        // seqlen_qpads
                                               {-1},        // seqlen_kpads
                                               {},          // q_eff_lens_per_batch
                                               {},          // kv_eff_lens_per_batch
                                               rotary_dim,  // rotary_dim
                                               i_perm,      // i_perm
                                               true,        // o_perm
                                               0,           // scale_s
                                               0,           // logits_soft_cap
                                               is_v_rowmajor, // is_v_rowmajor
                                               true,          // lse
                                               0,             // page_block_size
                                               false,         // use_cache_batch_idx
                                               "n",           // bias_str
                                               0.0f,          // p_drop
                                               0,             // drop_seed
                                               0,             // drop_offset
                                               false,         // drop_prefs
                                               mask_str,      // mask_str
                                               qscale_str,
                                               is_rotary_interleaved, // is_rotary_interleaved
                                               1,                     // num_splits
                                               COMMON_ARGS);
    CHECK_RESULT(result);
}
|
|
|
|
#endif // CK_TILE_FMHA_FWD_APPENDKV_API
|
|
|
|
// ---------------------------------------------------------------
|
|
// Parameterized padding tests (batch & group) using Combine+Values
|
|
// ---------------------------------------------------------------
|
|
|
|
// ---------------------------------------------------------------
// Parameterized padding tests (batch & group) using Combine+Values
// ---------------------------------------------------------------

// Full parameter record for the padding test cases built programmatically in
// BuildPaddingParams() below. Batch mode uses effective lengths; group mode
// uses physical padded lengths.
using PaddingParam = std::tuple<mode_enum,        // mode
                                int,              // batch
                                int,              // nhead
                                int,              // nhead_k
                                std::vector<int>, // seqlen_qs (logical)
                                std::vector<int>, // seqlen_ks (logical)
                                std::vector<int>, // seqlen_qpads (physical padded lengths)
                                std::vector<int>, // seqlen_kpads (physical padded lengths)
                                std::vector<int>, // q_eff_lens
                                std::vector<int>, // kv_eff_lens
                                bool,             // i_perm
                                bool,             // o_perm
                                std::string>;     // mask_str

class PaddingCases : public TestWithParam<PaddingParam>
{
};

GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(PaddingCases);
|
|
|
|
// Build padding test params programmatically to enforce constraints
|
|
static std::vector<PaddingParam> BuildPaddingParams()
|
|
{
|
|
std::vector<PaddingParam> params;
|
|
|
|
if constexpr(ck_tile::is_any_of<DataTypeConfig, FmhaFwdFp8Bf16, FmhaFwdMxFp8, FmhaFwdMxFp4>::
|
|
value)
|
|
{
|
|
return params;
|
|
}
|
|
|
|
// mask variants to cover
|
|
const std::vector<std::string> mask_variants{"0", "t:50,64", "b:32,40"};
|
|
const std::vector<std::string> mask_variants_reduced{"0", "t:50,64"}; // used for trimmed sets
|
|
|
|
// Representative ratio pairs (q_ratio, k_ratio) to avoid explosion
|
|
const std::vector<std::pair<double, double>> ratio_pairs_full{
|
|
{1.0, 1.0}, // both full
|
|
{1.0, 0.5}, // q full, k half
|
|
{0.5, 1.0}, // q half, k full
|
|
};
|
|
const std::vector<std::pair<double, double>> ratio_pairs_reduced{{1.0, 1.0}, {0.5, 1.0}};
|
|
|
|
// candidate physical seqlens for batch mode (single value) & for group mode (per batch)
|
|
const std::vector<int> physical_lengths_full{64, 128, 256};
|
|
const std::vector<int> physical_lengths_reduced{64};
|
|
|
|
// batch sizes to sample
|
|
const std::vector<int> batch_sizes{1, 4};
|
|
// --------------------------------------------------------------------
|
|
// Head configuration space (cover MHA, GQA, MQA)
|
|
// - Standard MHA: nhead_k == -1 (treated internally as nhead)
|
|
// - GQA: nhead_k > 0 and nhead % nhead_k == 0, nhead_k < nhead
|
|
// - MQA: nhead_k == 1
|
|
// We choose (9, -1), (9, 3), (9, 1) so that divisibility holds. Full
|
|
// combinatorics only applied to the first (standard) configuration to
|
|
// avoid test explosion.
|
|
// --------------------------------------------------------------------
|
|
struct HeadCfg
|
|
{
|
|
int nhead;
|
|
int nhead_k; // -1 for standard; else must divide nhead
|
|
bool full; // whether to use full coverage sets
|
|
};
|
|
const std::vector<HeadCfg> head_cfgs = {
|
|
{9, -1, true}, // MHA full
|
|
{9, 3, false}, // GQA reduced (nhead/nhead_k=3)
|
|
{9, 1, false} // MQA reduced
|
|
};
|
|
|
|
// Helper to clamp and ensure >=1
|
|
auto logical_len = [](int physical, double ratio) {
|
|
int v = static_cast<int>(std::round(physical * ratio));
|
|
v = std::max(1, std::min(v, physical));
|
|
return v;
|
|
};
|
|
// Iterate over head configurations
|
|
for(const auto& hc : head_cfgs)
|
|
{
|
|
const auto& ratio_pairs = hc.full ? ratio_pairs_full : ratio_pairs_reduced;
|
|
const auto& phys_lengths_batch = hc.full ? physical_lengths_full : physical_lengths_reduced;
|
|
const auto& phys_lengths_group_q = phys_lengths_batch; // reuse
|
|
const auto& phys_lengths_group_k = phys_lengths_batch; // reuse
|
|
const auto& masks = hc.full ? mask_variants : mask_variants_reduced;
|
|
|
|
// -----------------
|
|
// Batch mode params (effective lengths only)
|
|
// -----------------
|
|
for(int b : batch_sizes)
|
|
{
|
|
for(int phys_qkv : phys_lengths_batch)
|
|
{
|
|
for(const auto& rkpair : ratio_pairs)
|
|
{
|
|
double rq = rkpair.first;
|
|
double rk = rkpair.second;
|
|
std::vector<int> q_eff(b), kv_eff(b);
|
|
int log_q = logical_len(phys_qkv, rq);
|
|
int log_k = logical_len(phys_qkv, rk);
|
|
for(int i = 0; i < b; ++i)
|
|
{
|
|
q_eff[i] = log_q;
|
|
kv_eff[i] = log_k;
|
|
}
|
|
for(const auto& mask : masks)
|
|
{
|
|
params.emplace_back(PaddingParam{mode_enum::batch,
|
|
b,
|
|
hc.nhead,
|
|
hc.nhead_k,
|
|
{phys_qkv}, // seqlen_qs
|
|
{phys_qkv}, // seqlen_ks
|
|
{}, // seqlen_qpads
|
|
{}, // seqlen_kpads
|
|
q_eff,
|
|
kv_eff,
|
|
true,
|
|
true,
|
|
mask});
|
|
}
|
|
}
|
|
// Single-token logical length case (both q & k = 1)
|
|
for(const auto& mask : masks)
|
|
{
|
|
std::vector<int> q_eff(b, 1), kv_eff(b, 1);
|
|
params.emplace_back(PaddingParam{mode_enum::batch,
|
|
b,
|
|
hc.nhead,
|
|
hc.nhead_k,
|
|
{phys_qkv},
|
|
{phys_qkv},
|
|
{},
|
|
{},
|
|
q_eff,
|
|
kv_eff,
|
|
true,
|
|
true,
|
|
mask});
|
|
}
|
|
}
|
|
}
|
|
|
|
// -----------------
|
|
// Group mode params (physical padding + logical variants)
|
|
// -----------------
|
|
for(int b : batch_sizes)
|
|
{
|
|
for(int phys_q : phys_lengths_group_q)
|
|
{
|
|
for(int phys_k : phys_lengths_group_k)
|
|
{
|
|
for(const auto& rkpair : ratio_pairs)
|
|
{
|
|
double rq = rkpair.first;
|
|
double rk = rkpair.second;
|
|
std::vector<int> seqlen_qs(b), seqlen_ks(b), seqlen_qpads(b),
|
|
seqlen_kpads(b);
|
|
for(int i = 0; i < b; ++i)
|
|
{
|
|
seqlen_qpads[i] = phys_q;
|
|
seqlen_kpads[i] = phys_k;
|
|
seqlen_qs[i] = logical_len(phys_q, rq);
|
|
seqlen_ks[i] = logical_len(phys_k, rk);
|
|
}
|
|
std::array<std::pair<std::vector<int>, std::vector<int>>, 3> pad_variants{
|
|
std::pair{seqlen_qpads, seqlen_kpads}, // both
|
|
std::pair{seqlen_qpads, seqlen_ks}, // only q padding
|
|
std::pair{seqlen_qs, seqlen_kpads} // only kv padding
|
|
};
|
|
for(const auto& mask : masks)
|
|
{
|
|
for(const auto& pv : pad_variants)
|
|
{
|
|
params.emplace_back(PaddingParam{mode_enum::group,
|
|
b,
|
|
hc.nhead,
|
|
hc.nhead_k,
|
|
seqlen_qs,
|
|
seqlen_ks,
|
|
pv.first,
|
|
pv.second,
|
|
{},
|
|
{},
|
|
true,
|
|
true,
|
|
mask});
|
|
}
|
|
}
|
|
}
|
|
// Single-token logical length case
|
|
for(const auto& mask : masks)
|
|
{
|
|
std::vector<int> seqlen_qs(b, 1), seqlen_ks(b, 1);
|
|
std::vector<int> seqlen_qpads(b, phys_q), seqlen_kpads(b, phys_k);
|
|
// both padding variant only (others degenerate)
|
|
params.emplace_back(PaddingParam{mode_enum::group,
|
|
b,
|
|
hc.nhead,
|
|
hc.nhead_k,
|
|
seqlen_qs,
|
|
seqlen_ks,
|
|
seqlen_qpads,
|
|
seqlen_kpads,
|
|
{},
|
|
{},
|
|
true,
|
|
true,
|
|
mask});
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
return params;
|
|
}
|
|
|
|
// Materialize the full padding-scenario matrix once at static-init time;
// gtest's ValuesIn() below keeps referring to this vector for the lifetime
// of the parameterized suite, so it must outlive all test runs.
static const std::vector<PaddingParam> kPaddingParams = BuildPaddingParams();
|
|
|
|
// Register one PaddingCases test instance per generated PaddingParam
// (batch/group mode x head config x sequence-length/padding/mask variant).
INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaFwd, PaddingCases, ValuesIn(kPaddingParams));
|
|
|
|
// Runs a single forward-attention pass for one generated padding scenario and
// verifies the result against the reference implementation via CHECK_RESULT.
TEST_P(PaddingCases, DataTypeConfig)
{
    // Unpack the scenario produced by BuildPaddingParams().
    auto [mode,
          batch,
          nhead,
          nhead_k,
          seqlen_qs,
          seqlen_ks,
          seqlen_qpads,
          seqlen_kpads,
          q_eff_lens,
          kv_eff_lens,
          i_perm,
          o_perm,
          mask_str] = GetParam();

    // Batch mode carries a single logical length that must be rounded via
    // adjust_seqlen(); group mode already provides per-batch vectors as-is.
    std::vector<int> run_seqlen_qs;
    std::vector<int> run_seqlen_ks;
    if(mode == mode_enum::batch)
    {
        run_seqlen_qs = {adjust_seqlen(seqlen_qs.at(0))};
        run_seqlen_ks = {adjust_seqlen(seqlen_ks.at(0))};
    }
    else
    {
        run_seqlen_qs = seqlen_qs;
        run_seqlen_ks = seqlen_ks;
    }

    // Fixed head dimensions for this suite; no appended KV tokens.
    constexpr int hdim_q      = 64;
    constexpr int hdim_v      = 64;
    constexpr int seqlen_knew = 0;

    const auto run_result = fmha_fwd_run<DataTypeConfig>(mode,
                                                         batch,
                                                         nhead,
                                                         nhead_k,
                                                         run_seqlen_qs,
                                                         run_seqlen_ks,
                                                         hdim_q,
                                                         hdim_v,
                                                         seqlen_knew, // seqlen_knew
                                                         seqlen_qpads, // seqlen_qpads
                                                         seqlen_kpads, // seqlen_kpads
                                                         q_eff_lens, // q_eff_lens_per_batch
                                                         kv_eff_lens, // kv_eff_lens_per_batch
                                                         0, // rotary_dim
                                                         i_perm, // i_perm
                                                         o_perm, // o_perm
                                                         0, // scale_s
                                                         0, // logits_soft_cap
                                                         def_is_v_rowmajor,
                                                         def_lse, // lse
                                                         0, // page_block_size
                                                         false, // use_cache_batch_idx
                                                         "n", // bias_str
                                                         0.0f, // p_drop
                                                         0, // drop_seed
                                                         0, // drop_offset
                                                         false, // drop_prefs
                                                         mask_str, // mask_str
                                                         qscale_str,
                                                         true, // is_rotary_interleaved
                                                         1, // num_splits
                                                         COMMON_ARGS);
    CHECK_RESULT(run_result);
}
|