composable_kernel/test/ck_tile/fmha/test_fmha_bwd.cpp
Anton Gorenko 1e77695fe8 [CK_TILE] Support WMMA (gfx12) in FMHA (#2528)
* Pass hdim to tile_example_fmha_fwd in fp8 tests

* Add WMMA support to fwd FMHA pipelines

* Tune tile sizes a bit for less spilling

fp16 hdim 256 is still quite slow

* Fix Q grad tile distribution for warp size = 32 and hdim >= 256

With AccDataType = float and warp size = 32, K0 becomes 0, so a K repeat is required to correctly distribute the tile.

* Use code based on BlockDropout in BlockDropoutBwd

* Fix split KV combine kernel for gfx12 (warp size 32) and make it more universal

* Fix LSE LDS tensor descriptors: kMaxSplits and kM0 were swapped; it worked on gfx9
  because both equal 8, while on gfx12 they are 8 and 4;
* Fix Oacc LDS tensor descriptor: it was transposed even though its shape is [4 * kM0, kN1];
  it worked on gfx9 because 4 * kM0 == kN1 == 32;
* Removing these hidden dependencies makes it possible to support:
    * any number of warps (power-of-2), not only 4;
    * kN1 = 16, not only 32;
    * any number of splits;

* Rename ids like o_acc_4 and Oacc4 to eliminate confusion: kNumWarps doesn't have to be 4 now

* Replace hard-coded kN1 in dispatch code with the requested tile size

* Add gfx12-specific tile sizes for split KV

* Pass GPU architecture to kernel generation scripts

This is still a temporary solution.

* Build and run FMHA CI tests for gfx12

* Fix issue after merging

* Fix bwd tile sizes

The current pipelines always read only one K tile and one V tile, which
requires bk0 == bhdq and bk2 == bhdv (kK0 == kQKHeaddim and
kK2 == kVHeaddim).
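
A rough sketch of the constraint (hypothetical template and parameter names, not the
pipelines' actual code):

    template <int kK0, int kQKHeaddim, int kK2, int kVHeaddim>
    struct BwdSingleTileCheck
    {
        static_assert(kK0 == kQKHeaddim,
                      "K is read as a single tile, so kK0 must equal kQKHeaddim");
        static_assert(kK2 == kVHeaddim,
                      "V is read as a single tile, so kK2 must equal kVHeaddim");
    };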

* Use hardware f32->f8 on gfx12, remove v_perm

__builtin_amdgcn_perm is not needed because
__builtin_amdgcn_cvt_pk_fp8_f32 allows specifying which word (16 bits of
the 32-bit dword) is used to store the results (two f8 values).
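
A hedged sketch of the packing pattern (hypothetical helper, not the pipeline's actual
code); the builtin takes two floats, the previously packed dword and a word selector:

    // Sketch: pack four f32 values into one dword of fp8 on hardware with
    // v_cvt_pk_fp8_f32 (e.g. gfx12). The last argument selects which 16-bit
    // word of `packed` receives the two converted values, so no
    // __builtin_amdgcn_perm shuffle is needed afterwards.
    __device__ inline int pack_fp8x4_f32(float x0, float x1, float x2, float x3)
    {
        int packed = 0;
        packed = __builtin_amdgcn_cvt_pk_fp8_f32(x0, x1, packed, false); // low word
        packed = __builtin_amdgcn_cvt_pk_fp8_f32(x2, x3, packed, true);  // high word
        return packed;
    }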

* Update changelog

* Add WMMA support to pagedkv

* Fix scripts after rebasing

* Support 16x16 (MFMA, WMMA) and 32x32 (MFMA) tiles in fwd and bwd BlockDropout

Add comments with dropout implementation details

Fix performance regression of fwd+dropout

    * Remove some usage of type punning (reinterpret_cast with ref or ptr) in Philox;
    * "scalarize" seed and offset; they may come either from kernel args or from device memory
      (presumably loaded with vector loads).

    These changes help the compiler produce more optimal code and reduce register spilling.

Use WarpGemmDispatcher instead of explicit WarpGemmMfma... to get CWarpDstrEncoding

Use code based on BlockDropout in BlockDropoutBwd

Refactor BlockDropout (fwd)

Implement BlockDropout (fwd) for WMMA

    Originally BlockDropout only supported 32x32 tiles (IsWG32 = true);
    this version supports 16x16 tiles.
    If MPerBlock > MWarp * 16, it can generate numbers for two 16x16 tiles,
    similar to BlockDropoutBwd.
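
    As a rough illustration (hypothetical constants, not the block dropout's actual fields),
    the number of 16x16 row tiles each warp covers along M is:

        // Sketch only: with MPerBlock = 64 and MWarp = 2, each warp covers
        // MPerBlock / (MWarp * 16) = 2 row tiles, so random numbers are
        // generated for two 16x16 tiles per iteration.
        constexpr int MPerBlock     = 64;
        constexpr int MWarp         = 2;
        constexpr int MTilesPerWarp = MPerBlock / (MWarp * 16); // == 2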

Implement BlockDropoutBwd for WMMA

Remove MakeRandValLds* functions unused in BlockDropoutBwd

Remove unused Run overload from BlockDropoutBwd

* Fix regression with philox seed and offset when they exceed 32-bit int

__builtin_amdgcn_readfirstlane works with 32-bit values; seed and offset
are 64-bit, so they get truncated.
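
A hedged sketch of the general fix pattern (hypothetical helper; the actual fix was later
reverted in favor of the one already in develop):

    #include <cstdint>

    // Sketch: broadcast a 64-bit value from the first active lane by splitting it
    // into two 32-bit halves, since __builtin_amdgcn_readfirstlane only operates
    // on 32-bit values.
    __device__ inline uint64_t readfirstlane_u64(uint64_t v)
    {
        const uint32_t lo = __builtin_amdgcn_readfirstlane(static_cast<uint32_t>(v));
        const uint32_t hi = __builtin_amdgcn_readfirstlane(static_cast<uint32_t>(v >> 32));
        return (static_cast<uint64_t>(hi) << 32) | lo;
    }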

* Fix names after cherry-picking

* Fix selection of a fallback tile based on bm0

The assumption that the largest bm0 == 128 is not always true for
current fp32 tiles.

* Do not use filters related to qr_async_trload

They disable tiles/pipelines which are valid for gfx12.

* Use different dstr encoding when C is transposed

* Do not call GetQKBlockGemm (and hence WarpGemmDispatcher) in host code

Some WarpGemmDispatcher instantiations are defined only
for specific archs and are undefined on host.
Calculations related to sched barriers are moved from the pipeline's public
fields into its operator().

* Fix incorrect name WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution

The correct name is WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution
because it is 32x32x16 with IterateK = 2, so K = 32; also, all tiles used
in the codegen scripts are 32, 32, 32.

* Generalize usages of WarpGemmDispatcher for MFMA and WMMA

WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution is still
used explicitly because of swizzle factor = 4.

* Mark has_load_tr as maybe_unused

There is no transpose loading for RDNA.

* Remove CK_TILE_USE_MFMA/WMMA from fmha-related code

* Detect BlockSize on host based on warp size of the current device

If kBlockSize == kNumWarps * get_warp_size(), the kernel is launched with
kBlockSize / 2 on warp-32 devices, because get_warp_size() always returns 64 on the host.
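
A hedged sketch of how the device's warp size can be queried on the host (hypothetical
helper, not the repository's actual code):

    #include <hip/hip_runtime.h>

    // Sketch: query the warp size of the active device at runtime instead of
    // relying on get_warp_size(), which always evaluates to 64 in host code.
    inline int current_device_warp_size()
    {
        int device = 0;
        (void)hipGetDevice(&device);
        hipDeviceProp_t props{};
        (void)hipGetDeviceProperties(&props, device);
        return props.warpSize; // 32 on RDNA (e.g. gfx12), 64 on CDNA (gfx9)
    }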

* Fix calculation of grid size for combine kernel with warp size = 32

* Add missing includes and header

* Support multiple archs in one binary for fwd

* Support multiple archs in one binary for fwd_splitkv, fwd_appendkv, pagedkv_prefill

* Support multiple archs in one binary for bwd

* trload kernels are compiled only for gfx950;
* instances with padding are checked after instances without padding so
  they can be used as fallbacks (similarly to fwd);

* Extract common code from register_traits

* Revert "Fix regression with philox seed and offset when they exceed 32-bit int"

This simplifies merging; the proper fix is already in develop.

* Support new numerical d paddings in trait ordering checks

* Build fp32 tests only on gfx9

* Do not use hardcoded M0 = 64 for dot bwd kernel

* Use textwrap.indent from standard library

* Make fp8 pipelines on gfx12 consistent with gfx9

* Update tests for current pipelines

* Make ninja check more responsive in CI

ninja buffers its output, so this job looks like it is hanging.

* Support fp8fp32 by limiting O vector size

The fp32 output type requires storing 8 * sizeof(float) = 32 bytes,
which is not implemented (here 8 is the number of C values per lane for
v_wmma_f32_16x16x16...).
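
A hedged sketch of the kind of cap involved; the 16-byte store width below is an
assumption for illustration, not a value taken from the change:

    // Sketch: with 8 C values per lane (v_wmma_f32_16x16x16...), a full-width fp32
    // store would need 8 * sizeof(float) = 32 bytes per lane; capping the O vector
    // size keeps each store within an assumed 16-byte limit.
    template <typename ODataType>
    constexpr int limited_o_vector_size()
    {
        constexpr int c_values_per_lane = 8;
        constexpr int max_store_bytes   = 16; // assumption for illustration
        constexpr int by_width = max_store_bytes / static_cast<int>(sizeof(ODataType));
        return c_values_per_lane < by_width ? c_values_per_lane : by_width;
    }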

* Remove unused cmake options

* Unify including amd_buffer_addressing.hpp/_builtins.hpp

* Temporarily use amd_buffer_addressing.hpp on >=gfx10

amd_buffer_addressing_builtins.hpp uses inline asm for loads/stores,
which is not compatible with >=gfx10:
 * 1 scalar register for exec masks instead of 2,
 * gfx12 uses different instruction names, etc.

* Update asm in bf16 conversions to work with warp 32

* Do not generate splitkv/appendkv with vlayout=col for consistency with fwd

* Add arch tags to kernels/host funcs, compile for each arch separately

* Add kM0 to fmha_bwd_dot_do_o kernel name to match filename

* Add workaround for miscompilation of bwd with padded hdim

SWDEV-559729: v_wmma instructions can be incorrectly placed in divergent
branches used to store padded tensors (when some lanes are inactive due
to padding). Inline asm with dummy dependencies on the tensors' VGPRs
prevents the compiler from doing this.
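
A hedged sketch of the general technique (not the exact workaround used in the change):

    // Sketch: an empty inline-asm statement with a "+v" constraint creates an
    // artificial dependency on a value held in a VGPR, which discourages the
    // compiler from moving surrounding instructions (e.g. v_wmma) across it
    // into divergent branches.
    __device__ inline void fake_vgpr_dependency(float& x)
    {
        asm volatile("" : "+v"(x));
    }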

* Fix add_gtest_executable for absolute paths

Some tests (like gemm_tile_engine) pass absolute paths to source files.
In CI the branch name is part of the root directory, and if the branch name
contains "wmma", "xdl", etc., files can be incorrectly excluded.

* Run only hdim 128 smoke tests for fp8fp32

There are no instances for hdim 64 and 256.

* Format py with ruff to simplify merging develop

* Fix incorrect var name

* Codegen for gfx9,gfx950 when --targets is not specified

Aiter and PyTorch require changes to pass their targets to the codegen scripts.
With this temporary solution the files are generated, but not all of them
actually have to be built (depending on the --offload-arch= used).

* Combine arch-related values into ArchTrait

This more centralized approach removes duplication of various formatting templates.

* Try a workaround for Jenkins error "groovyjarjarasm.asm.MethodTooLargeException: Method too large"

Some code is extracted into a function.
2025-10-29 13:31:08 -07:00

// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "example/ck_tile/01_fmha/fmha_bwd.hpp"
#include "example/ck_tile/01_fmha/fmha_bwd_runner.hpp"
#include "gtest/gtest.h"
#ifndef DataTypeConfig
#define DataTypeConfig FmhaBwdFp16 // or FmhaBwdBf16 / FmhaBwdFp32
#endif
using ::testing::Bool;
using ::testing::Combine;
using ::testing::TestWithParam;
using ::testing::Values;
using ::testing::ValuesIn;
template <typename T>
struct TestConfigs
{
static constexpr auto HDimValues = std::array{
std::tuple{32, -1}, std::tuple{64, -1}, std::tuple{128, -1}, std::tuple{256, -1}};
};
template <>
struct TestConfigs<FmhaBwdFp32>
{
static constexpr auto HDimValues =
std::array{std::tuple{32, -1}, std::tuple{64, -1}, std::tuple{128, -1}};
};
static auto HDimValues = ValuesIn(TestConfigs<DataTypeConfig>::HDimValues);
const auto ModeValues = ValuesIn(std::vector<mode_enum>{mode_enum::batch, mode_enum::group});
constexpr auto init_method = "uf";
// Random seed used for initializing input tensors. 0 for non-deterministic seed
CK_TILE_DECLARE_ENV_VAR(CK_TILE_TEST_SEED, uint64_t, 123456)
// Whether to run long tests (from smoke_test_fwd.sh)
CK_TILE_DECLARE_ENV_VAR_BOOL(CK_TILE_FMHA_LONG_TESTS)
const ck_tile::stream_config stream_config{
nullptr, // stream_id_
false, // time_kernel_
1, // log_level_
0, // cold_niters_
1, // nrepeat_
true, // is_gpu_timer_
false, // flush_cache_
1, // rotating_count_
};
// batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str
using FmhaBwdDimsMaskParam = std::tuple<int, int, int, int, int, std::string>;
using FmhaBwdTestParam = std::tuple< //
mode_enum, // mode
std::tuple<int, int>, // hdim_q, hdim_v
std::tuple<bool, bool>, // io_perm
std::string, // bias_str
bool, // use_dbias
float, // p_drop
std::tuple<uint64_t, uint64_t, bool>, // drop_seed, drop_offset, drop_prefs
FmhaBwdDimsMaskParam,
bool // deterministic
>;
void fmha_bwd_test(const FmhaBwdTestParam& param)
{
auto [mode, hdims, perm, bias_str, use_dbias, p_drop, drop_misc, dims_mask, det] = param;
auto [hdim_q, hdim_v] = hdims;
auto [i_perm, o_perm] = perm;
auto [drop_seed, drop_offset, drop_prefs] = drop_misc;
auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask;
auto result = fmha_bwd_run<DataTypeConfig>(
mode,
batch,
nhead,
nhead_k,
{seqlen_q},
{seqlen_k},
{-1},
{-1},
hdim_q,
hdim_v,
i_perm,
o_perm,
0, // scale
bias_str,
use_dbias,
p_drop,
drop_seed,
drop_offset,
drop_prefs,
mask_str,
det, // deterministic
init_method,
static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
1,
stream_config);
if(result == bwd_result::no_instance)
GTEST_SKIP() << "No instance for current parameters";
ASSERT_EQ(result, bwd_result::success);
}
// Test cases from example/ck_tile/01_fmha/script/smoke_test_bwd.sh
class AllLong : public TestWithParam<FmhaBwdTestParam>
{
};
GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(AllLong);
INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaBwd,
AllLong,
Combine(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_FMHA_LONG_TESTS))
? ModeValues
: ValuesIn(std::vector<mode_enum>{}),
HDimValues,
Values(std::tuple{true, true}, std::tuple{false, false}), // perm
Values("n", "a"),
Values(false), // use_dbias
Values(0.0f, 0.2f), // p_drop
Values(std::tuple{123, 1024, true}), // seed/offset/prefs
Values(std::tuple{1, 4, 2, 259, -1, "0"},
std::tuple{2, 2, -1, 516, 253, "0"},
std::tuple{1, 4, 1, 500, 251, "1"},
std::tuple{1, 2, -1, 900, 258, "2"},
std::tuple{2, 1, -1, 987, 219, "t:128,30"},
std::tuple{2, 3, 1, 244, 499, "b:4,35"}),
Values(false) // deterministic
));
TEST_P(AllLong, DataTypeConfig) { fmha_bwd_test(GetParam()); }
class HDimPadding : public TestWithParam<FmhaBwdTestParam>
{
};
INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaBwd,
HDimPadding,
Combine(ModeValues,
Values(std::tuple{24, 48},
std::tuple{48, 48},
std::tuple{72, 72},
std::tuple{40, 88},
std::tuple{96, 96},
std::tuple{120, 160},
std::tuple{256, 108},
std::tuple{40, 64}),
Values(std::tuple{true, true}, std::tuple{false, false}), // perm
Values("n"), // bias_str
Values(false), // use_dbias
Values(0.0f), // p_drop
Values(std::tuple{0, 0, false}), // seed/offset/prefs
Values(std::tuple{1, 4, 2, 480, -1, "0"},
std::tuple{2, 2, -1, 300, 400, "t:64,64"},
std::tuple{1, 4, 1, 512, 201, "1"},
std::tuple{1, 2, -1, 900, 256, "0"},
std::tuple{2, 1, -1, 256, 256, "1"}),
Values(false) // deterministic
));
TEST_P(HDimPadding, DataTypeConfig) { fmha_bwd_test(GetParam()); }
class ElementwiseBias : public TestWithParam<FmhaBwdTestParam>
{
};
INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaBwd,
ElementwiseBias,
Combine(ModeValues,
HDimValues,
// layouts of bias and dbias are controlled by i_perm
Values(std::tuple{true, false}, std::tuple{false, false}),
Values("e:0", "e:1", "e:2"),
Bool(), // use_dbias
Values(0.0f), // p_drop
Values(std::tuple{0, 0, false}), // seed/offset/prefs
Values(std::tuple{1, 4, 2, 1024, 100, "0"},
std::tuple{3, 2, -1, 128, 256, "2"},
std::tuple{2, 2, -1, 130, 499, "t:50,64"}),
Values(false) // deterministic
));
TEST_P(ElementwiseBias, DataTypeConfig) { fmha_bwd_test(GetParam()); }
class Alibi : public TestWithParam<FmhaBwdTestParam>
{
};
INSTANTIATE_TEST_SUITE_P(
TestCkTileFmhaBwd,
Alibi,
Combine(ModeValues,
HDimValues,
Values(std::tuple{true, true}), // perm
Values("a:0", "a:1"),
Values(false), // use_dbias
Values(0.0f), // p_drop
Values(std::tuple{0, 0, false}), // seed/offset/prefs
ValuesIn([]() {
const std::array dims{
std::tuple{1, 3, 3, 1024, 1000},
std::tuple{3, 5, 5, 128, 256},
std::tuple{2, 8, 4, 130, 320},
};
const std::array mask_strs{"0", "t", "b", "t:50,64", "b:32,40"};
std::vector<FmhaBwdDimsMaskParam> dims_masks;
std::for_each(dims.begin(), dims.end(), [&](const auto& d) {
const auto& [b, h, hk, sq, sk] = d;
std::for_each(mask_strs.begin(), mask_strs.end(), [&](const auto& m) {
dims_masks.push_back(std::tuple{b, h, hk, sq, sk, m});
});
});
return dims_masks;
}()),
Values(false) // deterministic
));
TEST_P(Alibi, DataTypeConfig) { fmha_bwd_test(GetParam()); }
class Dropout : public TestWithParam<FmhaBwdTestParam>
{
};
INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaBwd,
Dropout,
Combine(ModeValues,
HDimValues,
Values(std::tuple{true, true}), // perm
Values("n"), // bias_str
Values(false), // use_dbias
Values(0.123f, 0.5f), // p_drop
Values(std::tuple{10, 123, false}, // seed/offset/prefs
std::tuple{34534564645, 7876878876864, true}),
Values(std::tuple{2, 6, 2, 180, 512, "0"},
std::tuple{3, 2, 2, 256, 128, "1"},
std::tuple{4, 2, 1, 100, 768, "2"}),
Values(false) // deterministic
));
TEST_P(Dropout, DataTypeConfig) { fmha_bwd_test(GetParam()); }
class Deterministic : public TestWithParam<FmhaBwdTestParam>
{
};
INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaBwd,
Deterministic,
Combine(ModeValues,
HDimValues,
Values(std::tuple{false, true}, std::tuple{true, true}), // perm
Values("n"), // bias_str
Values(false), // use_dbias
Values(0.0f), // p_drop
Values(std::tuple{0, 0, false}), // seed/offset/prefs
Values(std::tuple{2, 6, 2, 180, 512, "0"},
std::tuple{3, 3, 1, 256, 128, "1"},
std::tuple{4, 2, 2, 768, 100, "2"}),
Values(true) // deterministic
));
TEST_P(Deterministic, DataTypeConfig) { fmha_bwd_test(GetParam()); }
// ============================================================================
// Q/KV Padding Tests - High Priority
// ============================================================================
// 1. BasicQPadding: Test Q padding only (K/V have no padding)
class BasicQPadding : public TestWithParam<FmhaBwdTestParam>
{
};
INSTANTIATE_TEST_SUITE_P(
TestCkTileFmhaBwd,
BasicQPadding,
Combine(Values(mode_enum::group), // Only group mode supports padding
HDimValues,
Values(std::tuple{true, true}), // perm
Values("n"), // no bias for basic test
Values(false), // use_dbias
Values(0.0f), // no dropout
Values(std::tuple{0, 0, false}), // seed/offset/prefs
ValuesIn([]() {
// Define test cases with Q padding: seqlen_q < seqlen_qpad
// Format: {batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str}
// Note: Will set seqlen_qpad separately in the test
std::vector<FmhaBwdDimsMaskParam> test_cases;
// Small padding: logical length close to physical
test_cases.push_back(std::tuple{2, 2, 2, 127, 128, "0"}); // Q: 127->128
test_cases.push_back(std::tuple{3, 4, 2, 250, 256, "0"}); // Q: 250->256
// Medium padding: ~20-30% padding
test_cases.push_back(std::tuple{2, 2, 1, 180, 256, "0"}); // Q: 180->256
test_cases.push_back(std::tuple{3, 3, 3, 350, 512, "1"}); // Q: 350->512, causal
// Large padding: ~50% padding
test_cases.push_back(std::tuple{2, 4, 2, 128, 256, "0"}); // Q: 128->256
test_cases.push_back(std::tuple{2, 2, 2, 200, 512, "2"}); // Q: 200->512, causal
return test_cases;
}()),
Values(false) // deterministic
));
TEST_P(BasicQPadding, DataTypeConfig)
{
auto [mode, hdims, perm, bias_str, use_dbias, p_drop, drop_misc, dims_mask, det] = GetParam();
auto [hdim_q, hdim_v] = hdims;
auto [i_perm, o_perm] = perm;
auto [drop_seed, drop_offset, drop_prefs] = drop_misc;
auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask;
// Set up Q padding: physical length larger than logical
std::vector<ck_tile::index_t> seqlen_qs(batch, seqlen_q);
std::vector<ck_tile::index_t> seqlen_ks(batch, seqlen_k);
// Calculate physical Q length (padded)
ck_tile::index_t seqlen_qpad = ((seqlen_q + 63) / 64) * 64; // Round up to multiple of 64
if(seqlen_q > 256)
seqlen_qpad = ((seqlen_q + 127) / 128) * 128; // Larger alignment for longer sequences
std::vector<ck_tile::index_t> seqlen_qpads(batch, seqlen_qpad);
std::vector<ck_tile::index_t> seqlen_kpads(batch, seqlen_k); // No K padding
auto result = fmha_bwd_run<DataTypeConfig>(
mode,
batch,
nhead,
nhead_k,
seqlen_qs,
seqlen_ks,
seqlen_qpads,
seqlen_kpads,
hdim_q,
hdim_v,
i_perm,
o_perm,
0, // scale
bias_str,
use_dbias,
p_drop,
drop_seed,
drop_offset,
drop_prefs,
mask_str,
det,
init_method,
static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
1,
stream_config);
if(result == bwd_result::no_instance)
GTEST_SKIP() << "No instance for Q padding with hdim_q=" << hdim_q;
ASSERT_EQ(result, bwd_result::success);
}
// 2. BasicKVPadding: Test K/V padding only (Q has no padding)
class BasicKVPadding : public TestWithParam<FmhaBwdTestParam>
{
};
INSTANTIATE_TEST_SUITE_P(
TestCkTileFmhaBwd,
BasicKVPadding,
Combine(Values(mode_enum::group),
HDimValues,
Values(std::tuple{true, true}),
Values("n"),
Values(false),
Values(0.0f),
Values(std::tuple{0, 0, false}),
ValuesIn([]() {
std::vector<FmhaBwdDimsMaskParam> test_cases;
// Small K/V padding
test_cases.push_back(std::tuple{2, 2, 2, 128, 127, "0"}); // K: 127->128
test_cases.push_back(std::tuple{3, 4, 2, 256, 250, "0"}); // K: 250->256
// Medium K/V padding
test_cases.push_back(std::tuple{2, 2, 1, 256, 180, "0"}); // K: 180->256
test_cases.push_back(std::tuple{3, 3, 3, 512, 350, "1"}); // K: 350->512
// Large K/V padding
test_cases.push_back(std::tuple{2, 4, 2, 256, 128, "0"}); // K: 128->256
test_cases.push_back(std::tuple{2, 2, 2, 512, 200, "2"}); // K: 200->512
return test_cases;
}()),
Values(false)));
TEST_P(BasicKVPadding, DataTypeConfig)
{
auto [mode, hdims, perm, bias_str, use_dbias, p_drop, drop_misc, dims_mask, det] = GetParam();
auto [hdim_q, hdim_v] = hdims;
auto [i_perm, o_perm] = perm;
auto [drop_seed, drop_offset, drop_prefs] = drop_misc;
auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask;
std::vector<ck_tile::index_t> seqlen_qs(batch, seqlen_q);
std::vector<ck_tile::index_t> seqlen_ks(batch, seqlen_k);
// No Q padding
std::vector<ck_tile::index_t> seqlen_qpads(batch, seqlen_q);
// Set up K/V padding
ck_tile::index_t seqlen_kpad = ((seqlen_k + 63) / 64) * 64;
if(seqlen_k > 256)
seqlen_kpad = ((seqlen_k + 127) / 128) * 128;
std::vector<ck_tile::index_t> seqlen_kpads(batch, seqlen_kpad);
auto result = fmha_bwd_run<DataTypeConfig>(
mode,
batch,
nhead,
nhead_k,
seqlen_qs,
seqlen_ks,
seqlen_qpads,
seqlen_kpads,
hdim_q,
hdim_v,
i_perm,
o_perm,
0,
bias_str,
use_dbias,
p_drop,
drop_seed,
drop_offset,
drop_prefs,
mask_str,
det,
init_method,
static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
1,
stream_config);
if(result == bwd_result::no_instance)
GTEST_SKIP() << "No instance for K/V padding with hdim_q=" << hdim_q;
ASSERT_EQ(result, bwd_result::success);
}
// 3. QKVPadding: Test both Q and K/V padding simultaneously
class QKVPadding : public TestWithParam<FmhaBwdTestParam>
{
};
INSTANTIATE_TEST_SUITE_P(
TestCkTileFmhaBwd,
QKVPadding,
Combine(Values(mode_enum::group),
HDimValues,
Values(std::tuple{true, true}),
Values("n"),
Values(false),
Values(0.0f),
Values(std::tuple{0, 0, false}),
ValuesIn([]() {
std::vector<FmhaBwdDimsMaskParam> test_cases;
// Both Q and K have small padding
test_cases.push_back(std::tuple{2, 2, 2, 120, 125, "0"}); // Q:120->128, K:125->128
// Both Q and K have medium padding
test_cases.push_back(std::tuple{2, 4, 2, 180, 200, "0"}); // Q:180->256, K:200->256
test_cases.push_back(std::tuple{3, 3, 3, 300, 350, "1"}); // Q:300->384, K:350->384
// Both Q and K have large padding
test_cases.push_back(std::tuple{2, 2, 1, 150, 180, "0"}); // Q:150->256, K:180->256
test_cases.push_back(std::tuple{2, 4, 2, 256, 300, "2"}); // Q:256->384, K:300->384
// Asymmetric padding (Q more padded than K)
test_cases.push_back(std::tuple{2, 2, 2, 100, 200, "0"}); // Q:100->128, K:200->256
// Asymmetric padding (K more padded than Q)
test_cases.push_back(std::tuple{2, 3, 1, 200, 100, "0"}); // Q:200->256, K:100->128
return test_cases;
}()),
Values(false)));
TEST_P(QKVPadding, DataTypeConfig)
{
auto [mode, hdims, perm, bias_str, use_dbias, p_drop, drop_misc, dims_mask, det] = GetParam();
auto [hdim_q, hdim_v] = hdims;
auto [i_perm, o_perm] = perm;
auto [drop_seed, drop_offset, drop_prefs] = drop_misc;
auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask;
std::vector<ck_tile::index_t> seqlen_qs(batch, seqlen_q);
std::vector<ck_tile::index_t> seqlen_ks(batch, seqlen_k);
// Set up both Q and K/V padding
ck_tile::index_t seqlen_qpad = ((seqlen_q + 63) / 64) * 64;
if(seqlen_q > 256)
seqlen_qpad = ((seqlen_q + 127) / 128) * 128;
ck_tile::index_t seqlen_kpad = ((seqlen_k + 63) / 64) * 64;
if(seqlen_k > 256)
seqlen_kpad = ((seqlen_k + 127) / 128) * 128;
std::vector<ck_tile::index_t> seqlen_qpads(batch, seqlen_qpad);
std::vector<ck_tile::index_t> seqlen_kpads(batch, seqlen_kpad);
auto result = fmha_bwd_run<DataTypeConfig>(
mode,
batch,
nhead,
nhead_k,
seqlen_qs,
seqlen_ks,
seqlen_qpads,
seqlen_kpads,
hdim_q,
hdim_v,
i_perm,
o_perm,
0,
bias_str,
use_dbias,
p_drop,
drop_seed,
drop_offset,
drop_prefs,
mask_str,
det,
init_method,
static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
1,
stream_config);
if(result == bwd_result::no_instance)
GTEST_SKIP() << "No instance for Q+K/V padding with hdim_q=" << hdim_q;
ASSERT_EQ(result, bwd_result::success);
}
// 4. ZeroLengthPadding: Test zero-length sequences with padding
class ZeroLengthPadding : public TestWithParam<FmhaBwdTestParam>
{
};
INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaBwd,
ZeroLengthPadding,
Combine(Values(mode_enum::group),
Values(std::tuple{64, -1},
std::tuple{128, -1}), // Limited hdim for edge cases
Values(std::tuple{true, true}),
Values("n"),
Values(false),
Values(0.0f),
Values(std::tuple{0, 0, false}),
Values(
// Test case 1: First batch has zero Q length
std::tuple{3, 2, 2, 0, 128, "0"},
// Test case 2: Middle batch has zero Q length (multi-batch)
std::tuple{3, 2, 1, 100, 128, "0"},
// Test case 3: Last batch has zero Q length
std::tuple{3, 3, 3, 150, 200, "0"},
// Test case 4: Zero K length (first batch)
std::tuple{3, 2, 2, 128, 0, "0"},
// Test case 5: Mixed zero lengths with padding
std::tuple{4, 2, 2, 80, 100, "0"}),
Values(false)));
TEST_P(ZeroLengthPadding, DataTypeConfig)
{
auto [mode, hdims, perm, bias_str, use_dbias, p_drop, drop_misc, dims_mask, det] = GetParam();
auto [hdim_q, hdim_v] = hdims;
auto [i_perm, o_perm] = perm;
auto [drop_seed, drop_offset, drop_prefs] = drop_misc;
auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask;
// Create varied sequence lengths with some zero-length sequences
std::vector<ck_tile::index_t> seqlen_qs;
std::vector<ck_tile::index_t> seqlen_ks;
std::vector<ck_tile::index_t> seqlen_qpads;
std::vector<ck_tile::index_t> seqlen_kpads;
for(int b = 0; b < batch; ++b)
{
// Create pattern with zero-length sequences
ck_tile::index_t q_len, k_len;
if(seqlen_q == 0 && b == 1) // Middle batch zero Q
{
q_len = (b == 1) ? 0 : ((b == 0) ? 100 : 80);
k_len = seqlen_k;
}
else if(seqlen_k == 0 && b == 0) // First batch zero K
{
q_len = seqlen_q;
k_len = (b == 0) ? 0 : 100;
}
else
{
// Varied lengths
q_len = (b == 0 && seqlen_q == 0) ? 0 : (seqlen_q + b * 10);
k_len = seqlen_k + b * 15;
}
seqlen_qs.push_back(q_len);
seqlen_ks.push_back(k_len);
// Add padding for non-zero lengths
ck_tile::index_t qpad = (q_len == 0) ? 0 : ((q_len + 63) / 64) * 64;
ck_tile::index_t kpad = (k_len == 0) ? 0 : ((k_len + 63) / 64) * 64;
seqlen_qpads.push_back(qpad);
seqlen_kpads.push_back(kpad);
}
auto result = fmha_bwd_run<DataTypeConfig>(
mode,
batch,
nhead,
nhead_k,
seqlen_qs,
seqlen_ks,
seqlen_qpads,
seqlen_kpads,
hdim_q,
hdim_v,
i_perm,
o_perm,
0,
bias_str,
use_dbias,
p_drop,
drop_seed,
drop_offset,
drop_prefs,
mask_str,
det,
init_method,
static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
1,
stream_config);
if(result == bwd_result::no_instance)
GTEST_SKIP() << "No instance for zero-length padding";
ASSERT_EQ(result, bwd_result::success);
}
// ============================================================================
// Q/KV Padding Tests - Medium Priority
// ============================================================================
// 5. VariedPaddingRatios: Test different padding ratios (waste ratios)
class VariedPaddingRatios : public TestWithParam<FmhaBwdTestParam>
{
};
INSTANTIATE_TEST_SUITE_P(
TestCkTileFmhaBwd,
VariedPaddingRatios,
Combine(Values(mode_enum::group),
HDimValues,
Values(std::tuple{true, true}),
Values("n"),
Values(false),
Values(0.0f),
Values(std::tuple{0, 0, false}),
ValuesIn([]() {
std::vector<FmhaBwdDimsMaskParam> test_cases;
// Minimal waste: ~1-5% padding (logical ≈ physical - small delta)
test_cases.push_back(
std::tuple{2, 2, 2, 127, 127, "0"}); // Q:127->128 (~0.8%), K:127->128
test_cases.push_back(
std::tuple{2, 4, 2, 252, 250, "0"}); // Q:252->256 (~1.6%), K:250->256
test_cases.push_back(std::tuple{2, 2, 1, 509, 505, "1"}); // Q:509->512, K:505->512
// Low waste: ~10-20% padding
test_cases.push_back(
std::tuple{2, 3, 3, 220, 210, "0"}); // Q:220->256 (~16%), K:210->256
test_cases.push_back(
std::tuple{3, 2, 2, 440, 420, "0"}); // Q:440->512 (~16%), K:420->512
test_cases.push_back(std::tuple{2, 4, 2, 350, 340, "1"}); // Q:350->384, K:340->384
// Medium waste: ~30-40% padding
test_cases.push_back(
std::tuple{2, 2, 2, 180, 170, "0"}); // Q:180->256 (~42%), K:170->256
test_cases.push_back(
std::tuple{2, 3, 1, 320, 310, "0"}); // Q:320->384 (~20%), K:310->384
test_cases.push_back(std::tuple{3, 2, 2, 350, 340, "2"}); // Q:350->384, K:340->384
// High waste: ~50%+ padding
test_cases.push_back(
std::tuple{2, 2, 2, 130, 130, "0"}); // Q:130->256 (~97%), K:130->256
test_cases.push_back(
std::tuple{2, 4, 2, 260, 260, "0"}); // Q:260->384 (~48%), K:260->384
test_cases.push_back(
std::tuple{2, 2, 1, 200, 200, "1"}); // Q:200->256 (~28%), K:200->256
// Extreme waste: very small logical vs large physical
test_cases.push_back(std::tuple{2, 2, 2, 65, 70, "0"}); // Q:65->128, K:70->128
test_cases.push_back(std::tuple{2, 3, 3, 100, 90, "0"}); // Q:100->128, K:90->128
return test_cases;
}()),
Values(false)));
TEST_P(VariedPaddingRatios, DataTypeConfig)
{
auto [mode, hdims, perm, bias_str, use_dbias, p_drop, drop_misc, dims_mask, det] = GetParam();
auto [hdim_q, hdim_v] = hdims;
auto [i_perm, o_perm] = perm;
auto [drop_seed, drop_offset, drop_prefs] = drop_misc;
auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask;
std::vector<ck_tile::index_t> seqlen_qs(batch, seqlen_q);
std::vector<ck_tile::index_t> seqlen_ks(batch, seqlen_k);
// Calculate padding based on common alignment strategies
auto calc_pad = [](ck_tile::index_t len) -> ck_tile::index_t {
if(len <= 64)
return 64;
else if(len <= 128)
return 128;
else if(len <= 256)
return 256;
else if(len <= 384)
return 384;
else if(len <= 512)
return 512;
else
return ((len + 127) / 128) * 128;
};
std::vector<ck_tile::index_t> seqlen_qpads(batch, calc_pad(seqlen_q));
std::vector<ck_tile::index_t> seqlen_kpads(batch, calc_pad(seqlen_k));
auto result = fmha_bwd_run<DataTypeConfig>(
mode,
batch,
nhead,
nhead_k,
seqlen_qs,
seqlen_ks,
seqlen_qpads,
seqlen_kpads,
hdim_q,
hdim_v,
i_perm,
o_perm,
0,
bias_str,
use_dbias,
p_drop,
drop_seed,
drop_offset,
drop_prefs,
mask_str,
det,
init_method,
static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
1,
stream_config);
if(result == bwd_result::no_instance)
GTEST_SKIP() << "No instance for varied padding ratios";
ASSERT_EQ(result, bwd_result::success);
}
// 6. PaddingWithMask: Test padding combined with various mask types
class PaddingWithMask : public TestWithParam<FmhaBwdTestParam>
{
};
INSTANTIATE_TEST_SUITE_P(
TestCkTileFmhaBwd,
PaddingWithMask,
Combine(Values(mode_enum::group),
Values(std::tuple{64, -1}, std::tuple{128, -1}), // Focus on common sizes
Values(std::tuple{true, true}),
Values("n"),
Values(false),
Values(0.0f),
Values(std::tuple{0, 0, false}),
ValuesIn([]() {
std::vector<FmhaBwdDimsMaskParam> test_cases;
// No mask with padding (baseline)
test_cases.push_back(std::tuple{2, 2, 2, 200, 180, "0"});
// Causal mask (top-left) with Q padding
test_cases.push_back(std::tuple{2, 2, 2, 200, 256, "1"}); // Q padded, K exact
test_cases.push_back(std::tuple{2, 4, 2, 180, 200, "t"}); // Both padded, causal
// Causal mask (bottom-right) with K/V padding
test_cases.push_back(std::tuple{2, 2, 1, 256, 180, "2"}); // K padded, Q exact
test_cases.push_back(
std::tuple{2, 3, 3, 200, 180, "b"}); // Both padded, bottom-right
// Sliding window attention with padding
test_cases.push_back(std::tuple{2, 2, 2, 200, 190, "t:64,32"}); // SWA + padding
test_cases.push_back(std::tuple{2, 4, 2, 180, 170, "b:32,64"}); // SWA + padding
test_cases.push_back(std::tuple{3, 2, 1, 220, 210, "t:100,50"}); // Larger window
// Sliding window with asymmetric padding
test_cases.push_back(std::tuple{2, 2, 2, 150, 250, "t:80,40"}); // Q more padded
test_cases.push_back(std::tuple{2, 3, 3, 250, 150, "b:50,70"}); // K more padded
// Mixed scenarios
test_cases.push_back(std::tuple{2, 4, 2, 190, 185, "t:50,50"}); // Symmetric window
test_cases.push_back(std::tuple{3, 2, 2, 300, 280, "1"}); // Multi-batch causal
return test_cases;
}()),
Values(false)));
TEST_P(PaddingWithMask, DataTypeConfig)
{
auto [mode, hdims, perm, bias_str, use_dbias, p_drop, drop_misc, dims_mask, det] = GetParam();
auto [hdim_q, hdim_v] = hdims;
auto [i_perm, o_perm] = perm;
auto [drop_seed, drop_offset, drop_prefs] = drop_misc;
auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask;
std::vector<ck_tile::index_t> seqlen_qs(batch, seqlen_q);
std::vector<ck_tile::index_t> seqlen_ks(batch, seqlen_k);
// Apply padding
ck_tile::index_t seqlen_qpad = ((seqlen_q + 63) / 64) * 64;
ck_tile::index_t seqlen_kpad = ((seqlen_k + 63) / 64) * 64;
if(seqlen_q > 256)
seqlen_qpad = ((seqlen_q + 127) / 128) * 128;
if(seqlen_k > 256)
seqlen_kpad = ((seqlen_k + 127) / 128) * 128;
std::vector<ck_tile::index_t> seqlen_qpads(batch, seqlen_qpad);
std::vector<ck_tile::index_t> seqlen_kpads(batch, seqlen_kpad);
auto result = fmha_bwd_run<DataTypeConfig>(
mode,
batch,
nhead,
nhead_k,
seqlen_qs,
seqlen_ks,
seqlen_qpads,
seqlen_kpads,
hdim_q,
hdim_v,
i_perm,
o_perm,
0,
bias_str,
use_dbias,
p_drop,
drop_seed,
drop_offset,
drop_prefs,
mask_str,
det,
init_method,
static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
1,
stream_config);
if(result == bwd_result::no_instance)
GTEST_SKIP() << "No instance for padding with mask";
ASSERT_EQ(result, bwd_result::success);
}
// 7. MultiBatchPadding: Test multiple batches with different padding configurations
class MultiBatchPadding : public TestWithParam<FmhaBwdTestParam>
{
};
INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaBwd,
MultiBatchPadding,
Combine(Values(mode_enum::group),
Values(std::tuple{64, -1}, std::tuple{128, -1}),
Values(std::tuple{true, true}),
Values("n"),
Values(false),
Values(0.0f),
Values(std::tuple{0, 0, false}),
Values(
// 3 batches with varied Q/K lengths and padding
std::tuple{3, 2, 2, 150, 200, "0"},
// 4 batches with different patterns
std::tuple{4, 3, 3, 180, 220, "0"},
// 5 batches with mixed scenarios
std::tuple{5, 2, 1, 120, 160, "1"},
// 3 batches with causal mask
std::tuple{3, 4, 2, 200, 180, "t"},
// 4 batches with sliding window
std::tuple{4, 2, 2, 160, 140, "t:50,30"}),
Values(false)));
TEST_P(MultiBatchPadding, DataTypeConfig)
{
auto [mode, hdims, perm, bias_str, use_dbias, p_drop, drop_misc, dims_mask, det] = GetParam();
auto [hdim_q, hdim_v] = hdims;
auto [i_perm, o_perm] = perm;
auto [drop_seed, drop_offset, drop_prefs] = drop_misc;
auto [batch, nhead, nhead_k, base_seqlen_q, base_seqlen_k, mask_str] = dims_mask;
// Create varied sequence lengths for each batch
std::vector<ck_tile::index_t> seqlen_qs;
std::vector<ck_tile::index_t> seqlen_ks;
std::vector<ck_tile::index_t> seqlen_qpads;
std::vector<ck_tile::index_t> seqlen_kpads;
for(int b = 0; b < batch; ++b)
{
// Generate varied lengths across batches
// Pattern: decreasing, increasing, or random variation
ck_tile::index_t q_len, k_len;
switch(b % 3)
{
case 0: // Decreasing
q_len = base_seqlen_q - b * 20;
k_len = base_seqlen_k - b * 25;
break;
case 1: // Increasing
q_len = base_seqlen_q + b * 15;
k_len = base_seqlen_k + b * 20;
break;
case 2: // Mixed
q_len = base_seqlen_q + (b % 2 == 0 ? 10 : -10) * b;
k_len = base_seqlen_k + (b % 2 == 0 ? -15 : 15) * b;
break;
}
// Ensure positive lengths
q_len = std::max<ck_tile::index_t>(64, q_len);
k_len = std::max<ck_tile::index_t>(64, k_len);
seqlen_qs.push_back(q_len);
seqlen_ks.push_back(k_len);
// Calculate different padding strategies per batch
ck_tile::index_t qpad, kpad;
if(b % 4 == 0)
{
// Tight padding (minimal waste)
qpad = ((q_len + 31) / 32) * 32;
kpad = ((k_len + 31) / 32) * 32;
}
else if(b % 4 == 1)
{
// Medium padding
qpad = ((q_len + 63) / 64) * 64;
kpad = ((k_len + 63) / 64) * 64;
}
else if(b % 4 == 2)
{
// Loose padding
qpad = ((q_len + 127) / 128) * 128;
kpad = ((k_len + 127) / 128) * 128;
}
else
{
// Mixed: Q tight, K loose
qpad = ((q_len + 31) / 32) * 32;
kpad = ((k_len + 127) / 128) * 128;
}
seqlen_qpads.push_back(qpad);
seqlen_kpads.push_back(kpad);
}
auto result = fmha_bwd_run<DataTypeConfig>(
mode,
batch,
nhead,
nhead_k,
seqlen_qs,
seqlen_ks,
seqlen_qpads,
seqlen_kpads,
hdim_q,
hdim_v,
i_perm,
o_perm,
0,
bias_str,
use_dbias,
p_drop,
drop_seed,
drop_offset,
drop_prefs,
mask_str,
det,
init_method,
static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
1,
stream_config);
if(result == bwd_result::no_instance)
GTEST_SKIP() << "No instance for multi-batch padding";
ASSERT_EQ(result, bwd_result::success);
}