Files
composable_kernel/test/ck_tile/fmha/test_fmha_bwd.inc

348 lines
16 KiB
C++

// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
using ::testing::Bool;
using ::testing::Combine;
using ::testing::TestWithParam;
using ::testing::Values;
using ::testing::ValuesIn;
// Random seed used for initializing input tensors. 0 for non-deterministic seed.
// The variable is declared as uint64_t, but COMMON_ARGS casts the value to uint32_t
// before passing it on, so values above 32 bits are truncated.
CK_TILE_DECLARE_ENV_VAR(CK_TILE_TEST_SEED, uint64_t, 123456)
// Whether to run the long tests: when enabled, the AllLong suite (cases taken from
// smoke_test_bwd.sh) receives parameters; otherwise its first Combine axis is empty.
CK_TILE_DECLARE_ENV_VAR_BOOL(CK_TILE_FMHA_LONG_TESTS)
// Checks the value returned by fmha_bwd_run from inside a test body:
//   - bwd_result::no_instance  -> the test is skipped with an explanatory message;
//   - any other non-success    -> the test fails through ASSERT_EQ.
// The argument is parenthesised and evaluated exactly once, so any expression
// (including a direct function call) can be passed safely.
// GTEST_SKIP and ASSERT_EQ return from the enclosing function, so this macro must be
// used directly in a function returning void (e.g. a TEST_P body).
#define CHECK_RESULT(result)                                      \
    do                                                            \
    {                                                             \
        const auto fmha_bwd_status_ = (result);                   \
        if(fmha_bwd_status_ == bwd_result::no_instance)           \
        {                                                         \
            GTEST_SKIP() << "No instance for current parameters"; \
        }                                                         \
        ASSERT_EQ(fmha_bwd_status_, bwd_result::success);         \
    } while(0)
// Stream configuration handed to every fmha_bwd_run call in this file (it is the last
// element of COMMON_ARGS). The members are initialised positionally; the trailing
// comments name the field each value is meant for.
// NOTE(review): the field order is taken from those comments — ck_tile::stream_config
// itself is not visible here, so confirm against its definition if it changes.
const ck_tile::stream_config stream_config{
nullptr, // stream_id_
false, // time_kernel_
1, // log_level_
0, // cold_niters_
1, // nrepeat_
true, // is_gpu_timer_
false, // flush_cache_
1, // rotating_count_
};
// Trailing arguments shared by every fmha_bwd_run call in this file, in order:
//   init_method   - declared outside this view;
//   seed          - the CK_TILE_TEST_SEED environment value cast to uint32_t;
//   1             - NOTE(review): presumably the "validate results" switch — confirm
//                   against the fmha_bwd_run signature;
//   stream_config - the file-level stream configuration object.
#define COMMON_ARGS \
init_method, static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))), 1, \
stream_config
// Builds a parameter generator that acts as an on/off switch for a Combine() axis:
// it yields the single value `true` when `condition` holds and no values at all
// otherwise (an empty axis makes the whole Combine() product empty).
auto EnableTestIf(bool condition)
{
    std::vector<bool> gate;
    if(condition)
    {
        gate.push_back(true);
    }
    return ValuesIn(gate);
}
// Fixture for the long-running suite. Parameter tuple, in order (names as unpacked in
// the TEST_P body):
//   bool                  - enable gate produced by EnableTestIf; its value is never read
//   std::tuple<int, int>  - (hdim_q, hdim_v)
//   bool                  - perm, used for both i_perm and o_perm
//   mode_enum             - mode
//   std::string           - bias_str
//   float                 - p_drop
//   std::tuple<int x5, std::string>
//                         - (batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str)
class AllLong : public TestWithParam<std::tuple<bool,
std::tuple<int, int>,
bool,
mode_enum,
std::string,
float,
std::tuple<int, int, int, int, int, std::string>>>
{
};
// The suite generates no tests unless CK_TILE_FMHA_LONG_TESTS is enabled, so tell
// GoogleTest that an AllLong suite without any generated tests is intentional.
GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(AllLong);
// Test cases from example/ck_tile/01_fmha/script/smoke_test_bwd.sh
// Combine() forms the cross product of the axes below; the axis order matches the
// AllLong parameter tuple. When CK_TILE_FMHA_LONG_TESTS is not enabled the first axis
// is empty, so no tests are generated.
// NOTE(review): -1 for nhead_k / seqlen_k is presumably a "use the default" sentinel
// interpreted by fmha_bwd_run, and the mask strings are parsed there too — confirm.
INSTANTIATE_TEST_SUITE_P(
TestCkTileFmhaBwd,
AllLong,
Combine(EnableTestIf(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_FMHA_LONG_TESTS))),
HDimValues, // (hdim_q, hdim_v); defined outside this view
Bool(), // perm (drives both i_perm and o_perm)
ModeValues, // mode; defined outside this view
Values("n", "a"), // bias_str
Values(0.0f, 0.2f), // p_drop
// (batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str)
Values(std::tuple{1, 4, 2, 259, -1, "0"},
std::tuple{2, 2, -1, 516, 253, "0"},
std::tuple{1, 4, 1, 500, 251, "1"},
std::tuple{1, 2, -1, 900, 258, "2"},
std::tuple{2, 1, -1, 987, 219, "t:128,30"},
std::tuple{2, 3, 1, 244, 499, "b:4,35"})));
// Runs fmha_bwd_run once per AllLong parameter combination. One permutation flag drives
// both i_perm and o_perm; dbias and deterministic mode are off; dropout uses the fixed
// seed/offset 123/1024 with drop_prefs enabled. A missing kernel instance skips the test.
TEST_P(AllLong, Test)
{
    // The leading element is only the EnableTestIf gate; its value is not used.
    auto [_, head_dims, permute, run_mode, bias, dropout_prob, problem] = GetParam();
    auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask]              = problem;
    auto [hdim_q, hdim_v]                                               = head_dims;

    auto outcome = fmha_bwd_run<DataTypeConfig>(run_mode,
                                                batch,
                                                nhead,
                                                nhead_k,
                                                {seqlen_q},
                                                {seqlen_k},
                                                hdim_q,
                                                hdim_v,
                                                /*i_perm=*/permute,
                                                /*o_perm=*/permute,
                                                /*scale=*/0,
                                                /*bias_str=*/bias,
                                                /*use_dbias=*/false,
                                                /*p_drop=*/dropout_prob,
                                                /*drop_seed=*/123,
                                                /*drop_offset=*/1024,
                                                /*drop_prefs=*/true,
                                                /*mask_str=*/mask,
                                                /*deterministic=*/false,
                                                COMMON_ARGS);
    CHECK_RESULT(outcome);
}
// Fixture for the head-dimension padding suite. Parameter tuple, in order (names as
// unpacked in the TEST_P body):
//   std::tuple<int, int>  - (hdim_q, hdim_v)
//   bool                  - perm, used for both i_perm and o_perm
//   mode_enum             - mode
//   std::tuple<int x5, std::string>
//                         - (batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str)
class HDimPadding
: public TestWithParam<std::tuple<std::tuple<int, int>,
bool,
mode_enum,
std::tuple<int, int, int, int, int, std::string>>>
{
};
// Cross product of explicit (hdim_q, hdim_v) pairs, both permutation settings, every
// mode in ModeValues, and the problem shapes below. Axis order matches the HDimPadding
// parameter tuple.
// NOTE(review): judging by the suite name, the head dims are chosen to require padding
// inside the kernels; -1 for nhead_k / seqlen_k is presumably a "use the default"
// sentinel interpreted by fmha_bwd_run — confirm both.
INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaBwd,
HDimPadding,
// (hdim_q, hdim_v)
Combine(Values(std::tuple{24, 48},
std::tuple{48, 48},
std::tuple{72, 72},
std::tuple{96, 96},
std::tuple{120, 160},
std::tuple{256, 108},
std::tuple{40, 64}),
Bool(), // perm (drives both i_perm and o_perm)
ModeValues, // mode; defined outside this view
// (batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str)
Values(std::tuple{1, 4, 2, 480, -1, "0"},
std::tuple{2, 2, -1, 300, 400, "t:64,64"},
std::tuple{1, 4, 1, 512, 201, "1"},
std::tuple{1, 2, -1, 900, 256, "0"},
std::tuple{2, 1, -1, 256, 256, "1"})));
// Runs fmha_bwd_run once per HDimPadding parameter combination with no bias, no
// dropout, and deterministic mode off. One permutation flag drives both i_perm and
// o_perm. A missing kernel instance skips the test.
TEST_P(HDimPadding, Test)
{
    auto [head_dims, permute, run_mode, problem]           = GetParam();
    auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask] = problem;
    auto [hdim_q, hdim_v]                                  = head_dims;

    auto outcome = fmha_bwd_run<DataTypeConfig>(run_mode,
                                                batch,
                                                nhead,
                                                nhead_k,
                                                {seqlen_q},
                                                {seqlen_k},
                                                hdim_q,
                                                hdim_v,
                                                /*i_perm=*/permute,
                                                /*o_perm=*/permute,
                                                /*scale=*/0,
                                                /*bias_str=*/"n",
                                                /*use_dbias=*/false,
                                                /*p_drop=*/0.0f,
                                                /*drop_seed=*/0,
                                                /*drop_offset=*/0,
                                                /*drop_prefs=*/false,
                                                /*mask_str=*/mask,
                                                /*deterministic=*/false,
                                                COMMON_ARGS);
    CHECK_RESULT(outcome);
}
// Fixture for the elementwise-bias suite. Parameter tuple, in order (names as unpacked
// in the TEST_P body):
//   std::tuple<int, int>  - (hdim_q, hdim_v)
//   bool                  - i_perm (o_perm is fixed to false in the test body)
//   mode_enum             - mode
//   std::string           - bias_str
//   bool                  - use_dbias
//   std::tuple<int x5, std::string>
//                         - (batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str)
class ElementwiseBias
: public TestWithParam<std::tuple<std::tuple<int, int>,
bool,
mode_enum,
std::string,
bool,
std::tuple<int, int, int, int, int, std::string>>>
{
};
// Cross product of head dims, i_perm, mode, three "e:<n>" bias strings, use_dbias
// on/off, and the problem shapes below. Axis order matches the ElementwiseBias
// parameter tuple.
// NOTE(review): the meaning of the number after "e:" and of -1 for nhead_k is decided
// inside fmha_bwd_run, which is not visible here — confirm there.
INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaBwd,
ElementwiseBias,
Combine(HDimValues, // (hdim_q, hdim_v); defined outside this view
Bool(), // layouts of bias and dbias are controlled by i_perm
ModeValues, // mode; defined outside this view
Values("e:0", "e:1", "e:2"), // bias_str
Bool(), // use_dbias
// (batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str)
Values(std::tuple{1, 4, 2, 1024, 100, "0"},
std::tuple{3, 2, -1, 128, 256, "2"},
std::tuple{2, 2, -1, 130, 499, "t:50,64"})));
// Runs fmha_bwd_run once per ElementwiseBias parameter combination. i_perm, the bias
// string and use_dbias come from the parameters; o_perm is always false and dropout
// probability is 0 (the seed/offset/prefs arguments are still passed as 123/1024/true).
// A missing kernel instance skips the test.
TEST_P(ElementwiseBias, Test)
{
    auto [head_dims, input_perm, run_mode, bias, with_dbias, problem] = GetParam();
    auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask]            = problem;
    auto [hdim_q, hdim_v]                                             = head_dims;

    auto outcome = fmha_bwd_run<DataTypeConfig>(run_mode,
                                                batch,
                                                nhead,
                                                nhead_k,
                                                {seqlen_q},
                                                {seqlen_k},
                                                hdim_q,
                                                hdim_v,
                                                /*i_perm=*/input_perm,
                                                /*o_perm=*/false,
                                                /*scale=*/0,
                                                /*bias_str=*/bias,
                                                /*use_dbias=*/with_dbias,
                                                /*p_drop=*/0.0f,
                                                /*drop_seed=*/123,
                                                /*drop_offset=*/1024,
                                                /*drop_prefs=*/true,
                                                /*mask_str=*/mask,
                                                /*deterministic=*/false,
                                                COMMON_ARGS);
    CHECK_RESULT(outcome);
}
// Fixture for the alibi-bias suite. Parameter tuple, in order (names as unpacked in the
// TEST_P body):
//   std::tuple<int, int>  - (hdim_q, hdim_v)
//   mode_enum             - mode
//   std::string           - bias_str
//   std::tuple<int x5>    - (batch, nhead, nhead_k, seqlen_q, seqlen_k)
//   std::string           - mask_str (a separate axis here, unlike the other suites)
class Alibi : public TestWithParam<std::tuple<std::tuple<int, int>,
mode_enum,
std::string,
std::tuple<int, int, int, int, int>,
std::string>>
{
};
// Cross product of head dims, mode, two "a:<n>" bias strings, the problem shapes, and
// five mask strings. Axis order matches the Alibi parameter tuple; the mask is its own
// axis, so every shape is run with every mask.
// NOTE(review): the bias and mask string syntax is interpreted inside fmha_bwd_run,
// which is not visible here.
INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaBwd,
Alibi,
Combine(HDimValues, // (hdim_q, hdim_v); defined outside this view
ModeValues, // mode; defined outside this view
Values("a:0", "a:1"), // bias_str
// (batch, nhead, nhead_k, seqlen_q, seqlen_k)
Values(std::tuple{1, 3, 3, 1024, 1000},
std::tuple{3, 5, 5, 128, 256},
std::tuple{2, 8, 4, 130, 320}),
Values("0", "t", "b", "t:50,64", "b:32,40"))); // mask_str
// Runs fmha_bwd_run once per Alibi parameter combination with i_perm and o_perm both
// true, no dbias, no dropout, and deterministic mode off. A missing kernel instance
// skips the test.
TEST_P(Alibi, Test)
{
    auto [head_dims, run_mode, bias, shape, mask]    = GetParam();
    auto [batch, nhead, nhead_k, seqlen_q, seqlen_k] = shape;
    auto [hdim_q, hdim_v]                            = head_dims;

    auto outcome = fmha_bwd_run<DataTypeConfig>(run_mode,
                                                batch,
                                                nhead,
                                                nhead_k,
                                                {seqlen_q},
                                                {seqlen_k},
                                                hdim_q,
                                                hdim_v,
                                                /*i_perm=*/true,
                                                /*o_perm=*/true,
                                                /*scale=*/0,
                                                /*bias_str=*/bias,
                                                /*use_dbias=*/false,
                                                /*p_drop=*/0.0f,
                                                /*drop_seed=*/0,
                                                /*drop_offset=*/0,
                                                /*drop_prefs=*/false,
                                                /*mask_str=*/mask,
                                                /*deterministic=*/false,
                                                COMMON_ARGS);
    CHECK_RESULT(outcome);
}
// Fixture for the dropout suite. Parameter tuple, in order (names as unpacked in the
// TEST_P body):
//   std::tuple<int, int>  - (hdim_q, hdim_v)
//   mode_enum             - mode
//   float                 - p_drop
//   std::tuple<uint64_t, uint64_t, bool>
//                         - (drop_seed, drop_offset, drop_prefs)
//   std::tuple<int x5, std::string>
//                         - (batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str)
class Dropout : public TestWithParam<std::tuple<std::tuple<int, int>,
mode_enum,
float,
std::tuple<uint64_t, uint64_t, bool>,
std::tuple<int, int, int, int, int, std::string>>>
{
};
// Cross product of head dims, mode, two dropout probabilities, two seed/offset/prefs
// configurations, and the problem shapes below. Axis order matches the Dropout
// parameter tuple. The second seed/offset pair does not fit in 32 bits, so it exercises
// the full uint64_t range of those parameters.
INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaBwd,
Dropout,
Combine(HDimValues, // (hdim_q, hdim_v); defined outside this view
ModeValues, // mode; defined outside this view
Values(0.123f, 0.5f), // p_drop
// (drop_seed, drop_offset, drop_prefs)
Values(std::tuple{10, 123, false},
std::tuple{34534564645, 7876878876864, true}),
// (batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str)
Values(std::tuple{2, 6, 2, 180, 512, "0"},
std::tuple{3, 2, 2, 256, 128, "1"},
std::tuple{4, 2, 1, 100, 768, "2"})));
// Runs fmha_bwd_run once per Dropout parameter combination. Dropout probability, seed,
// offset and prefs come from the parameters; i_perm and o_perm are true, scale is 0.1,
// and bias, dbias and deterministic mode are off. A missing kernel instance skips the
// test.
TEST_P(Dropout, Test)
{
    auto [head_dims, run_mode, dropout_prob, rng_setup, problem] = GetParam();
    auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask]       = problem;
    auto [seed_value, offset_value, use_prefs]                   = rng_setup;
    auto [hdim_q, hdim_v]                                        = head_dims;

    auto outcome = fmha_bwd_run<DataTypeConfig>(run_mode,
                                                batch,
                                                nhead,
                                                nhead_k,
                                                {seqlen_q},
                                                {seqlen_k},
                                                hdim_q,
                                                hdim_v,
                                                /*i_perm=*/true,
                                                /*o_perm=*/true,
                                                /*scale=*/0.1f,
                                                /*bias_str=*/"n",
                                                /*use_dbias=*/false,
                                                /*p_drop=*/dropout_prob,
                                                /*drop_seed=*/seed_value,
                                                /*drop_offset=*/offset_value,
                                                /*drop_prefs=*/use_prefs,
                                                /*mask_str=*/mask,
                                                /*deterministic=*/false,
                                                COMMON_ARGS);
    CHECK_RESULT(outcome);
}
// Fixture for the deterministic-mode suite. Parameter tuple, in order (names as
// unpacked in the TEST_P body):
//   std::tuple<int, int>  - (hdim_q, hdim_v)
//   bool                  - i_perm (o_perm is fixed to true in the test body)
//   mode_enum             - mode
//   std::tuple<int x5, std::string>
//                         - (batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str)
class Deterministic
: public TestWithParam<std::tuple<std::tuple<int, int>,
bool,
mode_enum,
std::tuple<int, int, int, int, int, std::string>>>
{
};
// Cross product of head dims, i_perm, mode, and the problem shapes below. Axis order
// matches the Deterministic parameter tuple.
INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaBwd,
Deterministic,
Combine(HDimValues, // (hdim_q, hdim_v); defined outside this view
Bool(), // i_perm
ModeValues, // mode; defined outside this view
// (batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str)
Values(std::tuple{2, 6, 2, 180, 512, "0"},
std::tuple{3, 3, 1, 256, 128, "1"},
std::tuple{4, 2, 2, 768, 100, "2"})));
// Runs fmha_bwd_run once per Deterministic parameter combination with the deterministic
// flag set to true. i_perm comes from the parameters, o_perm is always true, and bias,
// dbias and dropout are off. A missing kernel instance skips the test.
TEST_P(Deterministic, Test)
{
    auto [head_dims, input_perm, run_mode, problem]        = GetParam();
    auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask] = problem;
    auto [hdim_q, hdim_v]                                  = head_dims;

    auto outcome = fmha_bwd_run<DataTypeConfig>(run_mode,
                                                batch,
                                                nhead,
                                                nhead_k,
                                                {seqlen_q},
                                                {seqlen_k},
                                                hdim_q,
                                                hdim_v,
                                                /*i_perm=*/input_perm,
                                                /*o_perm=*/true,
                                                /*scale=*/0,
                                                /*bias_str=*/"n",
                                                /*use_dbias=*/false,
                                                /*p_drop=*/0.0f,
                                                /*drop_seed=*/0,
                                                /*drop_offset=*/0,
                                                /*drop_prefs=*/false,
                                                /*mask_str=*/mask,
                                                /*deterministic=*/true,
                                                COMMON_ARGS);
    CHECK_RESULT(outcome);
}