mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
990 lines
38 KiB
C++
990 lines
38 KiB
C++
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
// SPDX-License-Identifier: MIT
|
|
|
|
#include "example/ck_tile/01_fmha/fmha_bwd.hpp"
|
|
#include "example/ck_tile/01_fmha/fmha_bwd_runner.hpp"
|
|
|
|
#include "gtest/gtest.h"
|
|
|
|
#ifndef DataTypeConfig
|
|
#define DataTypeConfig FmhaBwdFp16 // or FmhaBwdBf16 / FmhaBwdFp32
|
|
#endif
|
|
|
|
using ::testing::Bool;
|
|
using ::testing::Combine;
|
|
using ::testing::TestWithParam;
|
|
using ::testing::Values;
|
|
using ::testing::ValuesIn;
|
|
|
|
// Per-data-type head-dimension configuration for the parameterized suites.
// Each entry is a {hdim_q, hdim_v} pair; hdim_v == -1 presumably means
// "same as hdim_q" — TODO confirm against fmha_bwd_runner.hpp.
template <typename T>
struct TestConfigs
{
    // Default head dimensions exercised for fp16/bf16 configs.
    static constexpr auto HDimValues = std::array{
        std::tuple{32, -1}, std::tuple{64, -1}, std::tuple{128, -1}, std::tuple{256, -1}};
};
// FP32 specialization: hdim 256 is omitted — presumably no fp32 kernel
// instance exists for it; verify against the instance list.
template <>
struct TestConfigs<FmhaBwdFp32>
{
    static constexpr auto HDimValues =
        std::array{std::tuple{32, -1}, std::tuple{64, -1}, std::tuple{128, -1}};
};
|
|
static auto HDimValues = ValuesIn(TestConfigs<DataTypeConfig>::HDimValues);
|
|
const auto ModeValues = ValuesIn(std::vector<mode_enum>{mode_enum::batch, mode_enum::group});
|
|
constexpr auto init_method = "uf";
|
|
|
|
// Random seed used for initializing input tensors. 0 for non-deterministic seed
|
|
CK_TILE_DECLARE_ENV_VAR(CK_TILE_TEST_SEED, uint64_t, 123456)
|
|
|
|
// Whether to run long tests (from smoke_test_fwd.sh)
|
|
CK_TILE_DECLARE_ENV_VAR_BOOL(CK_TILE_FMHA_LONG_TESTS)
|
|
|
|
const ck_tile::stream_config stream_config{
|
|
nullptr, // stream_id_
|
|
false, // time_kernel_
|
|
1, // log_level_
|
|
0, // cold_niters_
|
|
1, // nrepeat_
|
|
true, // is_gpu_timer_
|
|
false, // flush_cache_
|
|
1, // rotating_count_
|
|
};
|
|
|
|
// Per-case dimensions plus attention-mask spec.
// batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str
using FmhaBwdDimsMaskParam = std::tuple<int, int, int, int, int, std::string>;

// Full parameter tuple consumed by fmha_bwd_test() and the TEST_P bodies
// below; the field order must match every structured binding in this file.
using FmhaBwdTestParam = std::tuple< //
    mode_enum,                            // mode (batch or group)
    std::tuple<int, int>,                 // hdim_q, hdim_v
    std::tuple<bool, bool>,               // io_perm (i_perm, o_perm)
    std::string,                          // bias_str
    bool,                                 // use_dbias
    float,                                // p_drop
    std::tuple<uint64_t, uint64_t, bool>, // drop_seed, drop_offset, drop_prefs
    FmhaBwdDimsMaskParam,                 // dims + mask
    bool                                  // deterministic
    >;
|
|
void fmha_bwd_test(const FmhaBwdTestParam& param)
|
|
{
|
|
auto [mode, hdims, perm, bias_str, use_dbias, p_drop, drop_misc, dims_mask, det] = param;
|
|
auto [hdim_q, hdim_v] = hdims;
|
|
auto [i_perm, o_perm] = perm;
|
|
auto [drop_seed, drop_offset, drop_prefs] = drop_misc;
|
|
auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask;
|
|
|
|
auto result = fmha_bwd_run<DataTypeConfig>(
|
|
mode,
|
|
batch,
|
|
nhead,
|
|
nhead_k,
|
|
{seqlen_q},
|
|
{seqlen_k},
|
|
{-1},
|
|
{-1},
|
|
hdim_q,
|
|
hdim_v,
|
|
i_perm,
|
|
o_perm,
|
|
0, // scale
|
|
bias_str,
|
|
use_dbias,
|
|
p_drop,
|
|
drop_seed,
|
|
drop_offset,
|
|
drop_prefs,
|
|
mask_str,
|
|
det, // deterministic
|
|
init_method,
|
|
static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
|
|
1,
|
|
stream_config);
|
|
|
|
if(result == bwd_result::no_instance)
|
|
GTEST_SKIP() << "No instance for current parameters";
|
|
ASSERT_EQ(result, bwd_result::success);
|
|
}
|
|
|
|
// Test cases from example/ck_tile/01_fmha/script/smoke_test_bwd.sh
// Long-running smoke suite: only populated when CK_TILE_FMHA_LONG_TESTS is
// set; otherwise the mode generator is empty and zero cases are produced,
// which is why the suite is allowed to be uninstantiated.
class AllLong : public TestWithParam<FmhaBwdTestParam>
{
};
GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(AllLong);
INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaBwd,
                         AllLong,
                         Combine(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_FMHA_LONG_TESTS))
                                     ? ModeValues
                                     : ValuesIn(std::vector<mode_enum>{}),
                                 HDimValues,
                                 Values(std::tuple{true, true}, std::tuple{false, false}), // perm
                                 Values("n", "a"), // bias_str ("n" none, "a" alibi)
                                 Values(false),       // use_dbias
                                 Values(0.0f, 0.2f),  // p_drop
                                 Values(std::tuple{123, 1024, true}), // seed/offset/prefs
                                 // batch, nhead, nhead_k, seqlen_q, seqlen_k, mask
                                 Values(std::tuple{1, 4, 2, 259, -1, "0"},
                                        std::tuple{2, 2, -1, 516, 253, "0"},
                                        std::tuple{1, 4, 1, 500, 251, "1"},
                                        std::tuple{1, 2, -1, 900, 258, "2"},
                                        std::tuple{2, 1, -1, 987, 219, "t:128,30"},
                                        std::tuple{2, 3, 1, 244, 499, "b:4,35"}),
                                 Values(false) // deterministic
                                 ));
TEST_P(AllLong, DataTypeConfig) { fmha_bwd_test(GetParam()); }
|
|
|
|
// Head dimensions that are not multiples of the kernel tile sizes, exercising
// the hdim-padding code paths (e.g. 24, 72, 120, and mixed hdim_q != hdim_v).
class HDimPadding : public TestWithParam<FmhaBwdTestParam>
{
};
INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaBwd,
                         HDimPadding,
                         Combine(ModeValues,
                                 // {hdim_q, hdim_v} pairs off the power-of-two grid
                                 Values(std::tuple{24, 48},
                                        std::tuple{48, 48},
                                        std::tuple{72, 72},
                                        std::tuple{40, 88},
                                        std::tuple{96, 96},
                                        std::tuple{120, 160},
                                        std::tuple{256, 108},
                                        std::tuple{40, 64}),
                                 Values(std::tuple{true, true}, std::tuple{false, false}), // perm
                                 Values("n"),  // bias_str
                                 Values(false), // use_dbias
                                 Values(0.0f),  // p_drop
                                 Values(std::tuple{0, 0, false}), // seed/offset/prefs
                                 // batch, nhead, nhead_k, seqlen_q, seqlen_k, mask
                                 Values(std::tuple{1, 4, 2, 480, -1, "0"},
                                        std::tuple{2, 2, -1, 300, 400, "t:64,64"},
                                        std::tuple{1, 4, 1, 512, 201, "1"},
                                        std::tuple{1, 2, -1, 900, 256, "0"},
                                        std::tuple{2, 1, -1, 256, 256, "1"}),
                                 Values(false) // deterministic
                                 ));
TEST_P(HDimPadding, DataTypeConfig) { fmha_bwd_test(GetParam()); }
|
|
|
|
// Elementwise attention bias ("e:<rank>") with and without dbias output.
class ElementwiseBias : public TestWithParam<FmhaBwdTestParam>
{
};
INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaBwd,
                         ElementwiseBias,
                         Combine(ModeValues,
                                 HDimValues,
                                 // layouts of bias and dbias are controlled by i_perm
                                 Values(std::tuple{true, false}, std::tuple{false, false}),
                                 Values("e:0", "e:1", "e:2"), // bias_str variants
                                 Bool(),        // use_dbias: both on and off
                                 Values(0.0f),  // p_drop
                                 Values(std::tuple{0, 0, false}), // seed/offset/prefs
                                 // batch, nhead, nhead_k, seqlen_q, seqlen_k, mask
                                 Values(std::tuple{1, 4, 2, 1024, 100, "0"},
                                        std::tuple{3, 2, -1, 128, 256, "2"},
                                        std::tuple{2, 2, -1, 130, 499, "t:50,64"}),
                                 Values(false) // deterministic
                                 ));
TEST_P(ElementwiseBias, DataTypeConfig) { fmha_bwd_test(GetParam()); }
|
|
class Alibi : public TestWithParam<FmhaBwdTestParam>
|
|
{
|
|
};
|
|
|
|
INSTANTIATE_TEST_SUITE_P(
|
|
TestCkTileFmhaBwd,
|
|
Alibi,
|
|
Combine(ModeValues,
|
|
HDimValues,
|
|
Values(std::tuple{true, true}), // perm
|
|
Values("a:0", "a:1"),
|
|
Values(false), // use_dbias
|
|
Values(0.0f), // p_drop
|
|
Values(std::tuple{0, 0, false}), // seed/offset/prefs
|
|
ValuesIn([]() {
|
|
const std::array dims{
|
|
std::tuple{1, 3, 3, 1024, 1000},
|
|
std::tuple{3, 5, 5, 128, 256},
|
|
std::tuple{2, 8, 4, 130, 320},
|
|
};
|
|
const std::array mask_strs{"0", "t", "b", "t:50,64", "b:32,40"};
|
|
std::vector<FmhaBwdDimsMaskParam> dims_masks;
|
|
std::for_each(dims.begin(), dims.end(), [&](const auto& d) {
|
|
const auto& [b, h, hk, sq, sk] = d;
|
|
std::for_each(mask_strs.begin(), mask_strs.end(), [&](const auto& m) {
|
|
dims_masks.push_back(std::tuple{b, h, hk, sq, sk, m});
|
|
});
|
|
});
|
|
return dims_masks;
|
|
}()),
|
|
Values(false) // deterministic
|
|
));
|
|
TEST_P(Alibi, DataTypeConfig) { fmha_bwd_test(GetParam()); }
|
|
|
|
// Attention dropout: non-zero p_drop with both host-provided and
// device-preferred (drop_prefs) seed/offset delivery.
class Dropout : public TestWithParam<FmhaBwdTestParam>
{
};
INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaBwd,
                         Dropout,
                         Combine(ModeValues,
                                 HDimValues,
                                 Values(std::tuple{true, true}), // perm
                                 Values("n"),   // bias_str
                                 Values(false), // use_dbias
                                 Values(0.123f, 0.5f), // p_drop
                                 // seed/offset/prefs; the large literals exceed
                                 // 32 bits and convert to the uint64_t fields
                                 Values(std::tuple{10, 123, false},
                                        std::tuple{34534564645, 7876878876864, true}),
                                 // batch, nhead, nhead_k, seqlen_q, seqlen_k, mask
                                 Values(std::tuple{2, 6, 2, 180, 512, "0"},
                                        std::tuple{3, 2, 2, 256, 128, "1"},
                                        std::tuple{4, 2, 1, 100, 768, "2"}),
                                 Values(false) // deterministic
                                 ));

TEST_P(Dropout, DataTypeConfig) { fmha_bwd_test(GetParam()); }
|
|
|
|
// Deterministic-mode backward pass (bit-reproducible accumulation order).
class Deterministic : public TestWithParam<FmhaBwdTestParam>
{
};

INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaBwd,
                         Deterministic,
                         Combine(ModeValues,
                                 HDimValues,
                                 Values(std::tuple{false, true}, std::tuple{true, true}), // perm
                                 Values("n"),   // bias_str
                                 Values(false), // use_dbias
                                 Values(0.0f),  // p_drop
                                 Values(std::tuple{0, 0, false}), // seed/offset/prefs
                                 // batch, nhead, nhead_k, seqlen_q, seqlen_k, mask
                                 Values(std::tuple{2, 6, 2, 180, 512, "0"},
                                        std::tuple{3, 3, 1, 256, 128, "1"},
                                        std::tuple{4, 2, 2, 768, 100, "2"}),
                                 Values(true) // deterministic
                                 ));
TEST_P(Deterministic, DataTypeConfig) { fmha_bwd_test(GetParam()); }
|
|
|
|
// ============================================================================
|
|
// Q/KV Padding Tests - High Priority
|
|
// ============================================================================
|
|
|
|
// 1. BasicQPadding: Test Q padding only (K/V have no padding)
class BasicQPadding : public TestWithParam<FmhaBwdTestParam>
{
};

INSTANTIATE_TEST_SUITE_P(
    TestCkTileFmhaBwd,
    BasicQPadding,
    Combine(Values(mode_enum::group), // Only group mode supports padding
            HDimValues,
            Values(std::tuple{true, true}), // perm
            Values("n"),                    // no bias for basic test
            Values(false),                  // use_dbias
            Values(0.0f),                   // no dropout
            Values(std::tuple{0, 0, false}), // seed/offset/prefs
            ValuesIn([]() {
                // Define test cases with Q padding: seqlen_q < seqlen_qpad
                // Format: {batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str}
                // Note: Will set seqlen_qpad separately in the test
                std::vector<FmhaBwdDimsMaskParam> test_cases;

                // Small padding: logical length close to physical
                test_cases.push_back(std::tuple{2, 2, 2, 127, 128, "0"}); // Q: 127->128
                test_cases.push_back(std::tuple{3, 4, 2, 250, 256, "0"}); // Q: 250->256

                // Medium padding: ~20-30% padding
                test_cases.push_back(std::tuple{2, 2, 1, 180, 256, "0"}); // Q: 180->256
                test_cases.push_back(std::tuple{3, 3, 3, 350, 512, "1"}); // Q: 350->512, causal

                // Large padding: ~50% padding
                test_cases.push_back(std::tuple{2, 4, 2, 128, 256, "0"}); // Q: 128->256
                test_cases.push_back(std::tuple{2, 2, 2, 200, 512, "2"}); // Q: 200->512, causal

                return test_cases;
            }()),
            Values(false) // deterministic
            ));
|
|
|
|
// Q-only physical padding: every batch shares the same logical lengths; the
// Q physical length is rounded up while K/V are stored exactly.
TEST_P(BasicQPadding, DataTypeConfig)
{
    const auto [mode, hdims, perm, bias_str, use_dbias, p_drop, drop_misc, dims_mask, det] =
        GetParam();
    const auto [hdim_q, hdim_v]                     = hdims;
    const auto [i_perm, o_perm]                     = perm;
    const auto [drop_seed, drop_offset, drop_prefs] = drop_misc;
    const auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask;

    // Logical lengths, identical across batches.
    std::vector<ck_tile::index_t> seqlen_qs(batch, seqlen_q);
    std::vector<ck_tile::index_t> seqlen_ks(batch, seqlen_k);

    // Round `len` up to a multiple of `align` (shared helper instead of the
    // duplicated ((len + a - 1) / a) * a magic expressions).
    const auto round_up = [](ck_tile::index_t len, ck_tile::index_t align) {
        return ((len + align - 1) / align) * align;
    };

    // Physical Q length: 128-aligned for longer sequences, else 64-aligned.
    const ck_tile::index_t seqlen_qpad =
        seqlen_q > 256 ? round_up(seqlen_q, 128) : round_up(seqlen_q, 64);

    std::vector<ck_tile::index_t> seqlen_qpads(batch, seqlen_qpad);
    std::vector<ck_tile::index_t> seqlen_kpads(batch, seqlen_k); // No K padding

    const auto result = fmha_bwd_run<DataTypeConfig>(
        mode,
        batch,
        nhead,
        nhead_k,
        seqlen_qs,
        seqlen_ks,
        seqlen_qpads,
        seqlen_kpads,
        hdim_q,
        hdim_v,
        i_perm,
        o_perm,
        0, // scale
        bias_str,
        use_dbias,
        p_drop,
        drop_seed,
        drop_offset,
        drop_prefs,
        mask_str,
        det,
        init_method,
        static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
        1,
        stream_config);

    if(result == bwd_result::no_instance)
        GTEST_SKIP() << "No instance for Q padding with hdim_q=" << hdim_q;
    ASSERT_EQ(result, bwd_result::success);
}
|
|
|
|
// 2. BasicKVPadding: Test K/V padding only (Q has no padding)
class BasicKVPadding : public TestWithParam<FmhaBwdTestParam>
{
};

INSTANTIATE_TEST_SUITE_P(
    TestCkTileFmhaBwd,
    BasicKVPadding,
    Combine(Values(mode_enum::group), // only group mode supports padding
            HDimValues,
            Values(std::tuple{true, true}),  // perm
            Values("n"),                     // bias_str
            Values(false),                   // use_dbias
            Values(0.0f),                    // p_drop
            Values(std::tuple{0, 0, false}), // seed/offset/prefs
            ValuesIn([]() {
                // {batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str};
                // the TEST body pads seqlen_k up, Q stays exact.
                std::vector<FmhaBwdDimsMaskParam> test_cases;

                // Small K/V padding
                test_cases.push_back(std::tuple{2, 2, 2, 128, 127, "0"}); // K: 127->128
                test_cases.push_back(std::tuple{3, 4, 2, 256, 250, "0"}); // K: 250->256

                // Medium K/V padding
                test_cases.push_back(std::tuple{2, 2, 1, 256, 180, "0"}); // K: 180->256
                test_cases.push_back(std::tuple{3, 3, 3, 512, 350, "1"}); // K: 350->512

                // Large K/V padding
                test_cases.push_back(std::tuple{2, 4, 2, 256, 128, "0"}); // K: 128->256
                test_cases.push_back(std::tuple{2, 2, 2, 512, 200, "2"}); // K: 200->512

                return test_cases;
            }()),
            Values(false))); // deterministic
|
|
|
|
// K/V-only physical padding: Q tensors are stored exactly (physical ==
// logical) while the K/V physical length is rounded up — the mirror image of
// the BasicQPadding case.
TEST_P(BasicKVPadding, DataTypeConfig)
{
    const auto [mode, hdims, perm, bias_str, use_dbias, p_drop, drop_misc, dims_mask, det] =
        GetParam();
    const auto [hdim_q, hdim_v]                     = hdims;
    const auto [i_perm, o_perm]                     = perm;
    const auto [drop_seed, drop_offset, drop_prefs] = drop_misc;
    const auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask;

    const std::vector<ck_tile::index_t> seqlen_qs(batch, seqlen_q);
    const std::vector<ck_tile::index_t> seqlen_ks(batch, seqlen_k);

    // Q stored unpadded: physical length equals the logical length.
    const std::vector<ck_tile::index_t> seqlen_qpads(batch, seqlen_q);

    // K/V physical length: 128-aligned beyond 256, otherwise 64-aligned.
    const ck_tile::index_t seqlen_kpad =
        seqlen_k > 256 ? ((seqlen_k + 127) / 128) * 128 : ((seqlen_k + 63) / 64) * 64;
    const std::vector<ck_tile::index_t> seqlen_kpads(batch, seqlen_kpad);

    const auto result = fmha_bwd_run<DataTypeConfig>(
        mode,
        batch,
        nhead,
        nhead_k,
        seqlen_qs,
        seqlen_ks,
        seqlen_qpads,
        seqlen_kpads,
        hdim_q,
        hdim_v,
        i_perm,
        o_perm,
        0, // scale
        bias_str,
        use_dbias,
        p_drop,
        drop_seed,
        drop_offset,
        drop_prefs,
        mask_str,
        det,
        init_method,
        static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
        1,
        stream_config);

    if(result == bwd_result::no_instance)
        GTEST_SKIP() << "No instance for K/V padding with hdim_q=" << hdim_q;
    ASSERT_EQ(result, bwd_result::success);
}
|
|
|
|
// 3. QKVPadding: Test both Q and K/V padding simultaneously
class QKVPadding : public TestWithParam<FmhaBwdTestParam>
{
};

INSTANTIATE_TEST_SUITE_P(
    TestCkTileFmhaBwd,
    QKVPadding,
    Combine(Values(mode_enum::group), // only group mode supports padding
            HDimValues,
            Values(std::tuple{true, true}),  // perm
            Values("n"),                     // bias_str
            Values(false),                   // use_dbias
            Values(0.0f),                    // p_drop
            Values(std::tuple{0, 0, false}), // seed/offset/prefs
            ValuesIn([]() {
                // {batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str};
                // the TEST body pads both Q and K up independently.
                std::vector<FmhaBwdDimsMaskParam> test_cases;

                // Both Q and K have small padding
                test_cases.push_back(std::tuple{2, 2, 2, 120, 125, "0"}); // Q:120->128, K:125->128

                // Both Q and K have medium padding
                test_cases.push_back(std::tuple{2, 4, 2, 180, 200, "0"}); // Q:180->256, K:200->256
                test_cases.push_back(std::tuple{3, 3, 3, 300, 350, "1"}); // Q:300->320, K:350->384

                // Both Q and K have large padding
                test_cases.push_back(std::tuple{2, 2, 1, 150, 180, "0"}); // Q:150->256, K:180->256
                test_cases.push_back(std::tuple{2, 4, 2, 256, 300, "2"}); // Q:256->384, K:300->384

                // Asymmetric padding (Q more padded than K)
                test_cases.push_back(std::tuple{2, 2, 2, 100, 200, "0"}); // Q:100->128, K:200->256

                // Asymmetric padding (K more padded than Q)
                test_cases.push_back(std::tuple{2, 3, 1, 200, 100, "0"}); // Q:200->256, K:100->128

                return test_cases;
            }()),
            Values(false))); // deterministic
|
|
|
|
// Simultaneous Q and K/V physical padding: both lengths are rounded up
// independently using the same alignment policy.
TEST_P(QKVPadding, DataTypeConfig)
{
    const auto [mode, hdims, perm, bias_str, use_dbias, p_drop, drop_misc, dims_mask, det] =
        GetParam();
    const auto [hdim_q, hdim_v]                     = hdims;
    const auto [i_perm, o_perm]                     = perm;
    const auto [drop_seed, drop_offset, drop_prefs] = drop_misc;
    const auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask;

    std::vector<ck_tile::index_t> seqlen_qs(batch, seqlen_q);
    std::vector<ck_tile::index_t> seqlen_ks(batch, seqlen_k);

    // Round `len` up to a multiple of `align`; replaces the four duplicated
    // ((len + a - 1) / a) * a expressions in the original.
    const auto round_up = [](ck_tile::index_t len, ck_tile::index_t align) {
        return ((len + align - 1) / align) * align;
    };

    // Physical lengths: 128-aligned beyond 256, otherwise 64-aligned.
    const ck_tile::index_t seqlen_qpad =
        seqlen_q > 256 ? round_up(seqlen_q, 128) : round_up(seqlen_q, 64);
    const ck_tile::index_t seqlen_kpad =
        seqlen_k > 256 ? round_up(seqlen_k, 128) : round_up(seqlen_k, 64);

    std::vector<ck_tile::index_t> seqlen_qpads(batch, seqlen_qpad);
    std::vector<ck_tile::index_t> seqlen_kpads(batch, seqlen_kpad);

    const auto result = fmha_bwd_run<DataTypeConfig>(
        mode,
        batch,
        nhead,
        nhead_k,
        seqlen_qs,
        seqlen_ks,
        seqlen_qpads,
        seqlen_kpads,
        hdim_q,
        hdim_v,
        i_perm,
        o_perm,
        0, // scale
        bias_str,
        use_dbias,
        p_drop,
        drop_seed,
        drop_offset,
        drop_prefs,
        mask_str,
        det,
        init_method,
        static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
        1,
        stream_config);

    if(result == bwd_result::no_instance)
        GTEST_SKIP() << "No instance for Q+K/V padding with hdim_q=" << hdim_q;
    ASSERT_EQ(result, bwd_result::success);
}
|
|
|
|
// 4. ZeroLengthPadding: Test zero-length sequences with padding
// NOTE(review): the per-case descriptions below are approximate — the TEST
// body derives the actual per-batch lengths from these seeds with its own
// pattern logic; confirm the descriptions against that loop.
class ZeroLengthPadding : public TestWithParam<FmhaBwdTestParam>
{
};

INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaBwd,
                         ZeroLengthPadding,
                         Combine(Values(mode_enum::group),
                                 Values(std::tuple{64, -1},
                                        std::tuple{128, -1}), // Limited hdim for edge cases
                                 Values(std::tuple{true, true}),  // perm
                                 Values("n"),                     // bias_str
                                 Values(false),                   // use_dbias
                                 Values(0.0f),                    // p_drop
                                 Values(std::tuple{0, 0, false}), // seed/offset/prefs
                                 Values(
                                     // Test case 1: First batch has zero Q length
                                     std::tuple{3, 2, 2, 0, 128, "0"},
                                     // Test case 2: Middle batch has zero Q length (multi-batch)
                                     std::tuple{3, 2, 1, 100, 128, "0"},
                                     // Test case 3: Last batch has zero Q length
                                     std::tuple{3, 3, 3, 150, 200, "0"},
                                     // Test case 4: Zero K length (first batch)
                                     std::tuple{3, 2, 2, 128, 0, "0"},
                                     // Test case 5: Mixed zero lengths with padding
                                     std::tuple{4, 2, 2, 80, 100, "0"}),
                                 Values(false))); // deterministic
|
|
|
|
// Zero-length sequences mixed with padded ones: per-batch lengths are derived
// from the seed case, some of them zero, and zero-length batches get zero
// physical padding. The derivation logic is preserved exactly; the only code
// change is brace-initializing q_len/k_len, which the original left
// uninitialized (robustness; silences -Wmaybe-uninitialized).
TEST_P(ZeroLengthPadding, DataTypeConfig)
{
    const auto [mode, hdims, perm, bias_str, use_dbias, p_drop, drop_misc, dims_mask, det] =
        GetParam();
    const auto [hdim_q, hdim_v]                     = hdims;
    const auto [i_perm, o_perm]                     = perm;
    const auto [drop_seed, drop_offset, drop_prefs] = drop_misc;
    const auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask;

    // Create varied sequence lengths with some zero-length sequences
    std::vector<ck_tile::index_t> seqlen_qs;
    std::vector<ck_tile::index_t> seqlen_ks;
    std::vector<ck_tile::index_t> seqlen_qpads;
    std::vector<ck_tile::index_t> seqlen_kpads;

    for(int b = 0; b < batch; ++b)
    {
        // Per-batch lengths; every branch below assigns both, the init is a
        // safety net only.
        ck_tile::index_t q_len{0};
        ck_tile::index_t k_len{0};

        if(seqlen_q == 0 && b == 1) // Middle batch zero Q
        {
            q_len = (b == 1) ? 0 : ((b == 0) ? 100 : 80);
            k_len = seqlen_k;
        }
        else if(seqlen_k == 0 && b == 0) // First batch zero K
        {
            q_len = seqlen_q;
            k_len = (b == 0) ? 0 : 100;
        }
        else
        {
            // Varied lengths (also yields q_len == 0 for b == 0 when the
            // seed case has seqlen_q == 0).
            q_len = (b == 0 && seqlen_q == 0) ? 0 : (seqlen_q + b * 10);
            k_len = seqlen_k + b * 15;
        }

        seqlen_qs.push_back(q_len);
        seqlen_ks.push_back(k_len);

        // Add padding for non-zero lengths; zero-length batches stay at 0.
        ck_tile::index_t qpad = (q_len == 0) ? 0 : ((q_len + 63) / 64) * 64;
        ck_tile::index_t kpad = (k_len == 0) ? 0 : ((k_len + 63) / 64) * 64;

        seqlen_qpads.push_back(qpad);
        seqlen_kpads.push_back(kpad);
    }

    const auto result = fmha_bwd_run<DataTypeConfig>(
        mode,
        batch,
        nhead,
        nhead_k,
        seqlen_qs,
        seqlen_ks,
        seqlen_qpads,
        seqlen_kpads,
        hdim_q,
        hdim_v,
        i_perm,
        o_perm,
        0, // scale
        bias_str,
        use_dbias,
        p_drop,
        drop_seed,
        drop_offset,
        drop_prefs,
        mask_str,
        det,
        init_method,
        static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
        1,
        stream_config);

    if(result == bwd_result::no_instance)
        GTEST_SKIP() << "No instance for zero-length padding";
    ASSERT_EQ(result, bwd_result::success);
}
|
|
|
|
// ============================================================================
|
|
// Q/KV Padding Tests - Medium Priority
|
|
// ============================================================================
|
|
|
|
// 5. VariedPaddingRatios: Test different padding ratios (waste ratios)
class VariedPaddingRatios : public TestWithParam<FmhaBwdTestParam>
{
};

INSTANTIATE_TEST_SUITE_P(
    TestCkTileFmhaBwd,
    VariedPaddingRatios,
    Combine(Values(mode_enum::group),
            HDimValues,
            Values(std::tuple{true, true}),  // perm
            Values("n"),                     // bias_str
            Values(false),                   // use_dbias
            Values(0.0f),                    // p_drop
            Values(std::tuple{0, 0, false}), // seed/offset/prefs
            ValuesIn([]() {
                // {batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str};
                // the TEST body buckets each length with calc_pad.
                std::vector<FmhaBwdDimsMaskParam> test_cases;

                // Minimal waste: ~1-5% padding (logical ≈ physical - small delta)
                test_cases.push_back(
                    std::tuple{2, 2, 2, 127, 127, "0"}); // Q:127->128 (~0.8%), K:127->128
                test_cases.push_back(
                    std::tuple{2, 4, 2, 252, 250, "0"}); // Q:252->256 (~1.6%), K:250->256
                test_cases.push_back(std::tuple{2, 2, 1, 509, 505, "1"}); // Q:509->512, K:505->512

                // Low waste: ~10-20% padding
                test_cases.push_back(
                    std::tuple{2, 3, 3, 220, 210, "0"}); // Q:220->256 (~16%), K:210->256
                test_cases.push_back(
                    std::tuple{3, 2, 2, 440, 420, "0"}); // Q:440->512 (~16%), K:420->512
                test_cases.push_back(std::tuple{2, 4, 2, 350, 340, "1"}); // Q:350->384, K:340->384

                // Medium waste: ~30-40% padding
                test_cases.push_back(
                    std::tuple{2, 2, 2, 180, 170, "0"}); // Q:180->256 (~42%), K:170->256
                test_cases.push_back(
                    std::tuple{2, 3, 1, 320, 310, "0"}); // Q:320->384 (~20%), K:310->384
                test_cases.push_back(std::tuple{3, 2, 2, 350, 340, "2"}); // Q:350->512, K:340->512

                // High waste: ~50%+ padding
                test_cases.push_back(
                    std::tuple{2, 2, 2, 130, 130, "0"}); // Q:130->256 (~97%), K:130->256
                test_cases.push_back(
                    std::tuple{2, 4, 2, 260, 260, "0"}); // Q:260->512 (~97%), K:260->512
                test_cases.push_back(
                    std::tuple{2, 2, 1, 200, 200, "1"}); // Q:200->256 (~28%), K:200->256

                // Extreme waste: very small logical vs large physical
                test_cases.push_back(std::tuple{2, 2, 2, 65, 70, "0"});  // Q:65->128, K:70->128
                test_cases.push_back(std::tuple{2, 3, 3, 100, 90, "0"}); // Q:100->128, K:90->128

                return test_cases;
            }()),
            Values(false))); // deterministic
|
|
|
|
// Varied padding-waste ratios: each logical length is bucketed up to the next
// fixed physical size (64/128/256/384/512, then 128-aligned beyond that).
TEST_P(VariedPaddingRatios, DataTypeConfig)
{
    const auto [mode, hdims, perm, bias_str, use_dbias, p_drop, drop_misc, dims_mask, det] =
        GetParam();
    const auto [hdim_q, hdim_v]                     = hdims;
    const auto [i_perm, o_perm]                     = perm;
    const auto [drop_seed, drop_offset, drop_prefs] = drop_misc;
    const auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask;

    const std::vector<ck_tile::index_t> seqlen_qs(batch, seqlen_q);
    const std::vector<ck_tile::index_t> seqlen_ks(batch, seqlen_k);

    // Smallest bucket that holds `len`; beyond the largest bucket, round up
    // to a multiple of 128.
    const auto calc_pad = [](ck_tile::index_t len) -> ck_tile::index_t {
        for(ck_tile::index_t bucket : {64, 128, 256, 384, 512})
        {
            if(len <= bucket)
                return bucket;
        }
        return ((len + 127) / 128) * 128;
    };

    const std::vector<ck_tile::index_t> seqlen_qpads(batch, calc_pad(seqlen_q));
    const std::vector<ck_tile::index_t> seqlen_kpads(batch, calc_pad(seqlen_k));

    const auto result = fmha_bwd_run<DataTypeConfig>(
        mode,
        batch,
        nhead,
        nhead_k,
        seqlen_qs,
        seqlen_ks,
        seqlen_qpads,
        seqlen_kpads,
        hdim_q,
        hdim_v,
        i_perm,
        o_perm,
        0, // scale
        bias_str,
        use_dbias,
        p_drop,
        drop_seed,
        drop_offset,
        drop_prefs,
        mask_str,
        det,
        init_method,
        static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
        1,
        stream_config);

    if(result == bwd_result::no_instance)
        GTEST_SKIP() << "No instance for varied padding ratios";
    ASSERT_EQ(result, bwd_result::success);
}
|
|
|
|
// 6. PaddingWithMask: Test padding combined with various mask types
class PaddingWithMask : public TestWithParam<FmhaBwdTestParam>
{
};

INSTANTIATE_TEST_SUITE_P(
    TestCkTileFmhaBwd,
    PaddingWithMask,
    Combine(Values(mode_enum::group),
            Values(std::tuple{64, -1}, std::tuple{128, -1}), // Focus on common sizes
            Values(std::tuple{true, true}),  // perm
            Values("n"),                     // bias_str
            Values(false),                   // use_dbias
            Values(0.0f),                    // p_drop
            Values(std::tuple{0, 0, false}), // seed/offset/prefs
            ValuesIn([]() {
                // {batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str}
                std::vector<FmhaBwdDimsMaskParam> test_cases;

                // No mask with padding (baseline)
                test_cases.push_back(std::tuple{2, 2, 2, 200, 180, "0"});

                // Causal mask (top-left) with Q padding
                test_cases.push_back(std::tuple{2, 2, 2, 200, 256, "1"}); // Q padded, K exact
                test_cases.push_back(std::tuple{2, 4, 2, 180, 200, "t"}); // Both padded, causal

                // Causal mask (bottom-right) with K/V padding
                test_cases.push_back(std::tuple{2, 2, 1, 256, 180, "2"}); // K padded, Q exact
                test_cases.push_back(
                    std::tuple{2, 3, 3, 200, 180, "b"}); // Both padded, bottom-right

                // Sliding window attention with padding
                test_cases.push_back(std::tuple{2, 2, 2, 200, 190, "t:64,32"});  // SWA + padding
                test_cases.push_back(std::tuple{2, 4, 2, 180, 170, "b:32,64"});  // SWA + padding
                test_cases.push_back(std::tuple{3, 2, 1, 220, 210, "t:100,50"}); // Larger window

                // Sliding window with asymmetric padding
                test_cases.push_back(std::tuple{2, 2, 2, 150, 250, "t:80,40"}); // Q more padded
                test_cases.push_back(std::tuple{2, 3, 3, 250, 150, "b:50,70"}); // K more padded

                // Mixed scenarios
                test_cases.push_back(std::tuple{2, 4, 2, 190, 185, "t:50,50"}); // Symmetric window
                test_cases.push_back(std::tuple{3, 2, 2, 300, 280, "1"});       // Multi-batch causal

                return test_cases;
            }()),
            Values(false))); // deterministic
|
|
|
|
// Physical Q and K/V padding combined with the mask string from the case:
// masks must respect logical (not physical) lengths.
TEST_P(PaddingWithMask, DataTypeConfig)
{
    const auto [mode, hdims, perm, bias_str, use_dbias, p_drop, drop_misc, dims_mask, det] =
        GetParam();
    const auto [hdim_q, hdim_v]                     = hdims;
    const auto [i_perm, o_perm]                     = perm;
    const auto [drop_seed, drop_offset, drop_prefs] = drop_misc;
    const auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask;

    std::vector<ck_tile::index_t> seqlen_qs(batch, seqlen_q);
    std::vector<ck_tile::index_t> seqlen_ks(batch, seqlen_k);

    // Round `len` up to a multiple of `align`; replaces the duplicated
    // ((len + a - 1) / a) * a expressions in the original.
    const auto round_up = [](ck_tile::index_t len, ck_tile::index_t align) {
        return ((len + align - 1) / align) * align;
    };

    // Apply padding: 128-aligned beyond 256, otherwise 64-aligned.
    const ck_tile::index_t seqlen_qpad =
        seqlen_q > 256 ? round_up(seqlen_q, 128) : round_up(seqlen_q, 64);
    const ck_tile::index_t seqlen_kpad =
        seqlen_k > 256 ? round_up(seqlen_k, 128) : round_up(seqlen_k, 64);

    std::vector<ck_tile::index_t> seqlen_qpads(batch, seqlen_qpad);
    std::vector<ck_tile::index_t> seqlen_kpads(batch, seqlen_kpad);

    const auto result = fmha_bwd_run<DataTypeConfig>(
        mode,
        batch,
        nhead,
        nhead_k,
        seqlen_qs,
        seqlen_ks,
        seqlen_qpads,
        seqlen_kpads,
        hdim_q,
        hdim_v,
        i_perm,
        o_perm,
        0, // scale
        bias_str,
        use_dbias,
        p_drop,
        drop_seed,
        drop_offset,
        drop_prefs,
        mask_str,
        det,
        init_method,
        static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
        1,
        stream_config);

    if(result == bwd_result::no_instance)
        GTEST_SKIP() << "No instance for padding with mask";
    ASSERT_EQ(result, bwd_result::success);
}
|
|
|
|
// 7. MultiBatchPadding: Test multiple batches with different padding configurations
// The seed lengths here are bases; the TEST body varies them per batch and
// rotates through several padding alignments.
class MultiBatchPadding : public TestWithParam<FmhaBwdTestParam>
{
};

INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaBwd,
                         MultiBatchPadding,
                         Combine(Values(mode_enum::group),
                                 Values(std::tuple{64, -1}, std::tuple{128, -1}), // hdims
                                 Values(std::tuple{true, true}),  // perm
                                 Values("n"),                     // bias_str
                                 Values(false),                   // use_dbias
                                 Values(0.0f),                    // p_drop
                                 Values(std::tuple{0, 0, false}), // seed/offset/prefs
                                 Values(
                                     // 3 batches with varied Q/K lengths and padding
                                     std::tuple{3, 2, 2, 150, 200, "0"},
                                     // 4 batches with different patterns
                                     std::tuple{4, 3, 3, 180, 220, "0"},
                                     // 5 batches with mixed scenarios
                                     std::tuple{5, 2, 1, 120, 160, "1"},
                                     // 3 batches with causal mask
                                     std::tuple{3, 4, 2, 200, 180, "t"},
                                     // 4 batches with sliding window
                                     std::tuple{4, 2, 2, 160, 140, "t:50,30"}),
                                 Values(false))); // deterministic
|
|
|
|
// Multiple batches, each with a different logical length (derived from the
// base lengths) and a different padding alignment (32/64/128/mixed, rotated
// by batch index). Improvements over the original, behavior preserved:
// q_len/k_len/qpad/kpad are brace-initialized instead of left uninitialized
// (the b % 3 switch had no default, triggering -Wmaybe-uninitialized), the
// padding if/else chain is a switch with a default, and the repeated
// round-up expressions use one documented helper.
TEST_P(MultiBatchPadding, DataTypeConfig)
{
    const auto [mode, hdims, perm, bias_str, use_dbias, p_drop, drop_misc, dims_mask, det] =
        GetParam();
    const auto [hdim_q, hdim_v]                     = hdims;
    const auto [i_perm, o_perm]                     = perm;
    const auto [drop_seed, drop_offset, drop_prefs] = drop_misc;
    const auto [batch, nhead, nhead_k, base_seqlen_q, base_seqlen_k, mask_str] = dims_mask;

    // Round `len` up to a multiple of `align`.
    const auto round_up = [](ck_tile::index_t len, ck_tile::index_t align) {
        return ((len + align - 1) / align) * align;
    };

    // Create varied sequence lengths for each batch
    std::vector<ck_tile::index_t> seqlen_qs;
    std::vector<ck_tile::index_t> seqlen_ks;
    std::vector<ck_tile::index_t> seqlen_qpads;
    std::vector<ck_tile::index_t> seqlen_kpads;

    for(int b = 0; b < batch; ++b)
    {
        // Vary logical lengths across batches; initialized as a safety net,
        // every case below assigns both.
        ck_tile::index_t q_len{0};
        ck_tile::index_t k_len{0};

        switch(b % 3)
        {
        case 0: // Decreasing
            q_len = base_seqlen_q - b * 20;
            k_len = base_seqlen_k - b * 25;
            break;
        case 1: // Increasing
            q_len = base_seqlen_q + b * 15;
            k_len = base_seqlen_k + b * 20;
            break;
        case 2: // Mixed
            q_len = base_seqlen_q + (b % 2 == 0 ? 10 : -10) * b;
            k_len = base_seqlen_k + (b % 2 == 0 ? -15 : 15) * b;
            break;
        }

        // Ensure positive lengths
        q_len = std::max<ck_tile::index_t>(64, q_len);
        k_len = std::max<ck_tile::index_t>(64, k_len);

        seqlen_qs.push_back(q_len);
        seqlen_ks.push_back(k_len);

        // Rotate through different padding strategies per batch.
        ck_tile::index_t qpad{0};
        ck_tile::index_t kpad{0};

        switch(b % 4)
        {
        case 0: // Tight padding (minimal waste)
            qpad = round_up(q_len, 32);
            kpad = round_up(k_len, 32);
            break;
        case 1: // Medium padding
            qpad = round_up(q_len, 64);
            kpad = round_up(k_len, 64);
            break;
        case 2: // Loose padding
            qpad = round_up(q_len, 128);
            kpad = round_up(k_len, 128);
            break;
        default: // Mixed: Q tight, K loose
            qpad = round_up(q_len, 32);
            kpad = round_up(k_len, 128);
            break;
        }

        seqlen_qpads.push_back(qpad);
        seqlen_kpads.push_back(kpad);
    }

    const auto result = fmha_bwd_run<DataTypeConfig>(
        mode,
        batch,
        nhead,
        nhead_k,
        seqlen_qs,
        seqlen_ks,
        seqlen_qpads,
        seqlen_kpads,
        hdim_q,
        hdim_v,
        i_perm,
        o_perm,
        0, // scale
        bias_str,
        use_dbias,
        p_drop,
        drop_seed,
        drop_offset,
        drop_prefs,
        mask_str,
        det,
        init_method,
        static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
        1,
        stream_config);

    if(result == bwd_result::no_instance)
        GTEST_SKIP() << "No instance for multi-batch padding";
    ASSERT_EQ(result, bwd_result::success);
}
|