Support fp8 dynamic quantization for fmha (#3206)

* Support qscale for dynamic quant, remove static quant

* Support hdim=256

* Remove bias test case for fp8

---------

Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
Co-authored-by: asleepzzz <hanwen.chang@amd.com>
This commit is contained in:
rocking
2025-11-24 16:28:25 +08:00
committed by GitHub
parent 096f0a3b23
commit 5948dbffe4
17 changed files with 369 additions and 280 deletions

View File

@@ -7,7 +7,7 @@
#include "gtest/gtest.h"
#ifndef DataTypeConfig
#define DataTypeConfig FmhaFwdFp16 // or FmhaFwdBf16 / FmhaFwdFp8 / FmhaFwdFp32
#define DataTypeConfig FmhaFwdFp16 // or FmhaFwdBf16 / FmhaFwdFp8Bf16 / FmhaFwdFp32
#endif
using ::testing::Bool;
@@ -39,13 +39,14 @@ struct TestConfigs
std::tuple{32, -1}, std::tuple{64, -1}, std::tuple{128, -1}, std::tuple{256, -1}};
static constexpr auto ModeValues = std::array{mode_enum::batch, mode_enum::group};
static constexpr auto IsVRowmajorValues = std::array{true};
static constexpr bool squant = false;
static constexpr auto qscale_str = "n";
static constexpr bool def_lse = true;
static constexpr bool def_is_v_rowmajor = true;
static int adjust_seqlen(int seqlen) { return seqlen; }
};
template <>
struct TestConfigs<FmhaFwdFp8>
struct TestConfigs<FmhaFwdFp8Bf16>
{
static constexpr auto HDimValues =
std::array{std::tuple{64, -1}, std::tuple{128, -1}, std::tuple{256, -1}};
@@ -53,13 +54,14 @@ struct TestConfigs<FmhaFwdFp8>
static constexpr auto AppendKVHDimValues = std::array{std::tuple{64, -1}, std::tuple{128, -1}};
static constexpr auto ModeValues = std::array{mode_enum::batch, mode_enum::group};
static constexpr auto IsVRowmajorValues = std::array{true};
static constexpr bool squant = true;
static constexpr auto qscale_str = "pt";
static constexpr bool def_lse = false;
static constexpr bool def_is_v_rowmajor = true;
// When there are no fp8 instances with padding, pad seqlen to avoid skipping most of the tests:
// return ck_tile::integer_least_multiple(seqlen, 128);
static int adjust_seqlen(int seqlen) { return seqlen; }
};
template <>
struct TestConfigs<FmhaFwdFp32>
{
@@ -76,7 +78,7 @@ struct TestConfigs<FmhaFwdFp32>
static constexpr auto AppendKVHDimValues = std::array<std::tuple<int, int>, 0>{};
static constexpr auto ModeValues = std::array{mode_enum::batch, mode_enum::group};
static constexpr auto IsVRowmajorValues = std::array{true};
static constexpr bool squant = false;
static constexpr auto qscale_str = "n";
static constexpr bool def_lse = true;
static constexpr bool def_is_v_rowmajor = true;
static int adjust_seqlen(int seqlen) { return seqlen; }
@@ -87,7 +89,7 @@ static auto SplitKVHDimValues = ValuesIn(TestConfigs<DataTypeConfig>::SplitKV
static auto AppendKVHDimValues = ValuesIn(TestConfigs<DataTypeConfig>::AppendKVHDimValues);
static auto ModeValues = ValuesIn(TestConfigs<DataTypeConfig>::ModeValues);
static auto IsVRowmajorValues = ValuesIn(TestConfigs<DataTypeConfig>::IsVRowmajorValues);
constexpr bool squant = TestConfigs<DataTypeConfig>::squant;
constexpr static auto qscale_str = TestConfigs<DataTypeConfig>::qscale_str;
constexpr bool def_lse = TestConfigs<DataTypeConfig>::def_lse;
constexpr bool def_is_v_rowmajor = TestConfigs<DataTypeConfig>::def_is_v_rowmajor;
int adjust_seqlen(int seqlen) { return TestConfigs<DataTypeConfig>::adjust_seqlen(seqlen); }
@@ -203,7 +205,7 @@ TEST_P(AllLong, DataTypeConfig)
1024, // drop_offset
false, // drop_prefs
mask_str, // mask_str
squant,
qscale_str,
true, // is_rotary_interleaved
1, // num_splits
COMMON_ARGS);
@@ -247,7 +249,7 @@ TEST(TestCkTileFmhaFwd, AppendKvWithBatchEffLensShouldFail)
0, // drop_offset
false, // drop_prefs
"0", // mask
squant,
qscale_str,
true, // is_rotary_interleaved
1, // num_splits
init_method,
@@ -291,7 +293,7 @@ TEST(TestCkTileFmhaFwd, SplitKvWithGroupPaddingShouldFail)
0,
false,
"0",
squant,
qscale_str,
true,
2, // num_splits (>1 triggers splitkv)
init_method,
@@ -334,7 +336,7 @@ TEST(TestCkTileFmhaFwd, PagedKvWithGroupPaddingShouldFail)
0,
false,
"0",
squant,
qscale_str,
true,
1,
init_method,
@@ -403,7 +405,7 @@ TEST_P(HDimPadding, DataTypeConfig)
0, // drop_offset
false, // drop_prefs
mask_str, // mask_str
squant,
qscale_str,
true, // is_rotary_interleaved
1, // num_splits
COMMON_ARGS);
@@ -463,7 +465,7 @@ TEST_P(ElementwiseBias, DataTypeConfig)
0, // drop_offset
false, // drop_prefs
mask_str, // mask_str
squant,
qscale_str,
true, // is_rotary_interleaved
1, // num_splits
COMMON_ARGS);
@@ -522,7 +524,7 @@ TEST_P(Alibi, DataTypeConfig)
0, // drop_offset
false, // drop_prefs
mask_str, // mask_str
squant,
qscale_str,
true, // is_rotary_interleaved
1, // num_splits
COMMON_ARGS);
@@ -583,7 +585,7 @@ TEST_P(Dropout, DataTypeConfig)
drop_offset, // drop_offset
drop_prefs, // drop_prefs
mask_str, // mask_str
squant,
qscale_str,
true, // is_rotary_interleaved
1, // num_splits
COMMON_ARGS);
@@ -648,7 +650,7 @@ TEST_P(PagedKV, DataTypeConfig)
0, // drop_offset
false, // drop_prefs
mask_str, // mask_str
squant,
qscale_str,
true, // is_rotary_interleaved
1, // num_splits
COMMON_ARGS);
@@ -719,7 +721,7 @@ TEST_P(SplitKV, DataTypeConfig)
0, // drop_offset
false, // drop_prefs
mask_str, // mask_str
squant,
qscale_str,
true, // is_rotary_interleaved
num_splits, // num_splits
COMMON_ARGS);
@@ -796,7 +798,7 @@ TEST_P(AppendKV, DataTypeConfig)
0, // drop_offset
false, // drop_prefs
mask_str, // mask_str
squant,
qscale_str,
false, // is_rotary_interleaved
1, // num_splits
COMMON_ARGS);
@@ -818,7 +820,7 @@ GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(AppendKVRoPE);
INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaFwd,
AppendKVRoPE,
Combine(EnableTestIf(!std::is_same_v<DataTypeConfig, FmhaFwdFp8>),
Combine(EnableTestIf(!std::is_same_v<DataTypeConfig, FmhaFwdFp8Bf16>),
AppendKVHDimValues,
Bool(), // layouts of k and v are controlled by i_perm
IsVRowmajorValues, // layout of v is controlled by is_v_rowmajor
@@ -869,7 +871,7 @@ TEST_P(AppendKVRoPE, DataTypeConfig)
0, // drop_offset
false, // drop_prefs
mask_str, // mask_str
squant,
qscale_str,
is_rotary_interleaved, // is_rotary_interleaved
1, // num_splits
COMMON_ARGS);
@@ -1105,7 +1107,7 @@ INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaFwd_Padding, PaddingCases, ValuesIn(kPadd
TEST_P(PaddingCases, DataTypeConfig)
{
if constexpr(std::is_same_v<DataTypeConfig, FmhaFwdFp8>)
if constexpr(std::is_same_v<DataTypeConfig, FmhaFwdFp8Bf16>)
{
GTEST_SKIP() << "Skip for fp8";
}
@@ -1162,7 +1164,7 @@ TEST_P(PaddingCases, DataTypeConfig)
0, // drop_offset
false, // drop_prefs
mask_str, // mask_str
squant,
qscale_str,
true, // is_rotary_interleaved
1, // num_splits
COMMON_ARGS);