mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 21:21:22 +00:00
Support fp8 dynamic quantization for fmha (#3206)
* Support qscale for dynamic quant, remove static quant * Support hdim=256 * Remove bias test case for fp8 --------- Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> Co-authored-by: asleepzzz <hanwen.chang@amd.com>
This commit is contained in:
@@ -9,10 +9,10 @@ set(FMHA_FWD_INSTANCES "tile_fmha_fwd_instances")
|
||||
set(TEST_NAME "test_ck_tile_fmha")
|
||||
|
||||
function(add_gtest_fwd test_group)
|
||||
set(V_TYPES "fp16" "bf16" "fp8" "fp32")
|
||||
set(V_TYPES "fp16" "bf16" "fp8bf16" "fp32")
|
||||
set(CPP_TYPE_fp16 "FmhaFwdFp16")
|
||||
set(CPP_TYPE_bf16 "FmhaFwdBf16")
|
||||
set(CPP_TYPE_fp8 "FmhaFwdFp8")
|
||||
set(CPP_TYPE_fp8bf16 "FmhaFwdFp8Bf16")
|
||||
set(CPP_TYPE_fp32 "FmhaFwdFp32")
|
||||
|
||||
set(all_tests)
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#ifndef DataTypeConfig
|
||||
#define DataTypeConfig FmhaFwdFp16 // or FmhaFwdBf16 / FmhaFwdFp8 / FmhaFwdFp32
|
||||
#define DataTypeConfig FmhaFwdFp16 // or FmhaFwdBf16 / FmhaFwdFp8Bf16 / FmhaFwdFp32
|
||||
#endif
|
||||
|
||||
using ::testing::Bool;
|
||||
@@ -39,13 +39,14 @@ struct TestConfigs
|
||||
std::tuple{32, -1}, std::tuple{64, -1}, std::tuple{128, -1}, std::tuple{256, -1}};
|
||||
static constexpr auto ModeValues = std::array{mode_enum::batch, mode_enum::group};
|
||||
static constexpr auto IsVRowmajorValues = std::array{true};
|
||||
static constexpr bool squant = false;
|
||||
static constexpr auto qscale_str = "n";
|
||||
static constexpr bool def_lse = true;
|
||||
static constexpr bool def_is_v_rowmajor = true;
|
||||
static int adjust_seqlen(int seqlen) { return seqlen; }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TestConfigs<FmhaFwdFp8>
|
||||
struct TestConfigs<FmhaFwdFp8Bf16>
|
||||
{
|
||||
static constexpr auto HDimValues =
|
||||
std::array{std::tuple{64, -1}, std::tuple{128, -1}, std::tuple{256, -1}};
|
||||
@@ -53,13 +54,14 @@ struct TestConfigs<FmhaFwdFp8>
|
||||
static constexpr auto AppendKVHDimValues = std::array{std::tuple{64, -1}, std::tuple{128, -1}};
|
||||
static constexpr auto ModeValues = std::array{mode_enum::batch, mode_enum::group};
|
||||
static constexpr auto IsVRowmajorValues = std::array{true};
|
||||
static constexpr bool squant = true;
|
||||
static constexpr auto qscale_str = "pt";
|
||||
static constexpr bool def_lse = false;
|
||||
static constexpr bool def_is_v_rowmajor = true;
|
||||
// When there are no fp8 instances with padding, pad seqlen to avoid skipping most of the tests:
|
||||
// return ck_tile::integer_least_multiple(seqlen, 128);
|
||||
static int adjust_seqlen(int seqlen) { return seqlen; }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TestConfigs<FmhaFwdFp32>
|
||||
{
|
||||
@@ -76,7 +78,7 @@ struct TestConfigs<FmhaFwdFp32>
|
||||
static constexpr auto AppendKVHDimValues = std::array<std::tuple<int, int>, 0>{};
|
||||
static constexpr auto ModeValues = std::array{mode_enum::batch, mode_enum::group};
|
||||
static constexpr auto IsVRowmajorValues = std::array{true};
|
||||
static constexpr bool squant = false;
|
||||
static constexpr auto qscale_str = "n";
|
||||
static constexpr bool def_lse = true;
|
||||
static constexpr bool def_is_v_rowmajor = true;
|
||||
static int adjust_seqlen(int seqlen) { return seqlen; }
|
||||
@@ -87,7 +89,7 @@ static auto SplitKVHDimValues = ValuesIn(TestConfigs<DataTypeConfig>::SplitKV
|
||||
static auto AppendKVHDimValues = ValuesIn(TestConfigs<DataTypeConfig>::AppendKVHDimValues);
|
||||
static auto ModeValues = ValuesIn(TestConfigs<DataTypeConfig>::ModeValues);
|
||||
static auto IsVRowmajorValues = ValuesIn(TestConfigs<DataTypeConfig>::IsVRowmajorValues);
|
||||
constexpr bool squant = TestConfigs<DataTypeConfig>::squant;
|
||||
constexpr static auto qscale_str = TestConfigs<DataTypeConfig>::qscale_str;
|
||||
constexpr bool def_lse = TestConfigs<DataTypeConfig>::def_lse;
|
||||
constexpr bool def_is_v_rowmajor = TestConfigs<DataTypeConfig>::def_is_v_rowmajor;
|
||||
int adjust_seqlen(int seqlen) { return TestConfigs<DataTypeConfig>::adjust_seqlen(seqlen); }
|
||||
@@ -203,7 +205,7 @@ TEST_P(AllLong, DataTypeConfig)
|
||||
1024, // drop_offset
|
||||
false, // drop_prefs
|
||||
mask_str, // mask_str
|
||||
squant,
|
||||
qscale_str,
|
||||
true, // is_rotary_interleaved
|
||||
1, // num_splits
|
||||
COMMON_ARGS);
|
||||
@@ -247,7 +249,7 @@ TEST(TestCkTileFmhaFwd, AppendKvWithBatchEffLensShouldFail)
|
||||
0, // drop_offset
|
||||
false, // drop_prefs
|
||||
"0", // mask
|
||||
squant,
|
||||
qscale_str,
|
||||
true, // is_rotary_interleaved
|
||||
1, // num_splits
|
||||
init_method,
|
||||
@@ -291,7 +293,7 @@ TEST(TestCkTileFmhaFwd, SplitKvWithGroupPaddingShouldFail)
|
||||
0,
|
||||
false,
|
||||
"0",
|
||||
squant,
|
||||
qscale_str,
|
||||
true,
|
||||
2, // num_splits (>1 triggers splitkv)
|
||||
init_method,
|
||||
@@ -334,7 +336,7 @@ TEST(TestCkTileFmhaFwd, PagedKvWithGroupPaddingShouldFail)
|
||||
0,
|
||||
false,
|
||||
"0",
|
||||
squant,
|
||||
qscale_str,
|
||||
true,
|
||||
1,
|
||||
init_method,
|
||||
@@ -403,7 +405,7 @@ TEST_P(HDimPadding, DataTypeConfig)
|
||||
0, // drop_offset
|
||||
false, // drop_prefs
|
||||
mask_str, // mask_str
|
||||
squant,
|
||||
qscale_str,
|
||||
true, // is_rotary_interleaved
|
||||
1, // num_splits
|
||||
COMMON_ARGS);
|
||||
@@ -463,7 +465,7 @@ TEST_P(ElementwiseBias, DataTypeConfig)
|
||||
0, // drop_offset
|
||||
false, // drop_prefs
|
||||
mask_str, // mask_str
|
||||
squant,
|
||||
qscale_str,
|
||||
true, // is_rotary_interleaved
|
||||
1, // num_splits
|
||||
COMMON_ARGS);
|
||||
@@ -522,7 +524,7 @@ TEST_P(Alibi, DataTypeConfig)
|
||||
0, // drop_offset
|
||||
false, // drop_prefs
|
||||
mask_str, // mask_str
|
||||
squant,
|
||||
qscale_str,
|
||||
true, // is_rotary_interleaved
|
||||
1, // num_splits
|
||||
COMMON_ARGS);
|
||||
@@ -583,7 +585,7 @@ TEST_P(Dropout, DataTypeConfig)
|
||||
drop_offset, // drop_offset
|
||||
drop_prefs, // drop_prefs
|
||||
mask_str, // mask_str
|
||||
squant,
|
||||
qscale_str,
|
||||
true, // is_rotary_interleaved
|
||||
1, // num_splits
|
||||
COMMON_ARGS);
|
||||
@@ -648,7 +650,7 @@ TEST_P(PagedKV, DataTypeConfig)
|
||||
0, // drop_offset
|
||||
false, // drop_prefs
|
||||
mask_str, // mask_str
|
||||
squant,
|
||||
qscale_str,
|
||||
true, // is_rotary_interleaved
|
||||
1, // num_splits
|
||||
COMMON_ARGS);
|
||||
@@ -719,7 +721,7 @@ TEST_P(SplitKV, DataTypeConfig)
|
||||
0, // drop_offset
|
||||
false, // drop_prefs
|
||||
mask_str, // mask_str
|
||||
squant,
|
||||
qscale_str,
|
||||
true, // is_rotary_interleaved
|
||||
num_splits, // num_splits
|
||||
COMMON_ARGS);
|
||||
@@ -796,7 +798,7 @@ TEST_P(AppendKV, DataTypeConfig)
|
||||
0, // drop_offset
|
||||
false, // drop_prefs
|
||||
mask_str, // mask_str
|
||||
squant,
|
||||
qscale_str,
|
||||
false, // is_rotary_interleaved
|
||||
1, // num_splits
|
||||
COMMON_ARGS);
|
||||
@@ -818,7 +820,7 @@ GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(AppendKVRoPE);
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaFwd,
|
||||
AppendKVRoPE,
|
||||
Combine(EnableTestIf(!std::is_same_v<DataTypeConfig, FmhaFwdFp8>),
|
||||
Combine(EnableTestIf(!std::is_same_v<DataTypeConfig, FmhaFwdFp8Bf16>),
|
||||
AppendKVHDimValues,
|
||||
Bool(), // layouts of k and v are controlled by i_perm
|
||||
IsVRowmajorValues, // layout of v is controlled by is_v_rowmajor
|
||||
@@ -869,7 +871,7 @@ TEST_P(AppendKVRoPE, DataTypeConfig)
|
||||
0, // drop_offset
|
||||
false, // drop_prefs
|
||||
mask_str, // mask_str
|
||||
squant,
|
||||
qscale_str,
|
||||
is_rotary_interleaved, // is_rotary_interleaved
|
||||
1, // num_splits
|
||||
COMMON_ARGS);
|
||||
@@ -1105,7 +1107,7 @@ INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaFwd_Padding, PaddingCases, ValuesIn(kPadd
|
||||
|
||||
TEST_P(PaddingCases, DataTypeConfig)
|
||||
{
|
||||
if constexpr(std::is_same_v<DataTypeConfig, FmhaFwdFp8>)
|
||||
if constexpr(std::is_same_v<DataTypeConfig, FmhaFwdFp8Bf16>)
|
||||
{
|
||||
GTEST_SKIP() << "Skip for fp8";
|
||||
}
|
||||
@@ -1162,7 +1164,7 @@ TEST_P(PaddingCases, DataTypeConfig)
|
||||
0, // drop_offset
|
||||
false, // drop_prefs
|
||||
mask_str, // mask_str
|
||||
squant,
|
||||
qscale_str,
|
||||
true, // is_rotary_interleaved
|
||||
1, // num_splits
|
||||
COMMON_ARGS);
|
||||
|
||||
Reference in New Issue
Block a user