mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 17:55:48 +00:00
[rocm-libraries] ROCm/rocm-libraries#7272 (commit d02f3c0)
[ck_tile][fmha_bwd] Fix sink_host OOB in group mode reference runner (#7272) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary In `fmha_bwd_runner.hpp`, the `sink_host` `HostTensor` is allocated with first dimension `shape_batch` (= 1 in group mode), but the reference forward loop accesses `sink_host(wb, i_h)` with `wb ∈ [0, batch-1]`. For any `wb >= 1` this is an out-of-bounds heap read, silently corrupting the reference forward math chain (`lse_host`, `o_host`) and turning the bwd-side `d_sink_head_acc` reference into non-deterministic garbage. `HostTensor::operator()` does not bounds check, so the OOB is not caught at runtime. This manifests as intermittent `tile_example_fmha_bwd` failures (25–67% fail rate) when `-sink_grad=1` is combined with `-mode=1` (group mode), with bit-exact but spurious `max_err` values like 4.27 / 14.6. ## Fix One-line: allocate `sink_host` with `batch` (the real per-batch dim) instead of `shape_batch`, mirroring how `sink_host` is accessed by the loop. ```diff - sink_grad ? std::array<ck_tile::index_t, 2>{shape_batch, nhead} + sink_grad ? std::array<ck_tile::index_t, 2>{batch, nhead} Repro tile_example_fmha_bwd -b=2 -h=2 -s=516 -s_k=253 -prec=bf16 -d=72 \ -bias=n -dbias=0 -p_drop=0 -iperm=1 -operm=1 -deterministic=0 \ -v=3 -mode=1 -kname=1 -sink_grad=1 Verification - 0/30 fail on the repro config after fix - Baselines (before fix): - sink=1, mask=n: 25% fail rate (p ≈ 1.8e-4) - sink=1, mask=t: 67% fail rate (p ≈ 6e-15) Attribution Shape bug introduced together with sink_grad in #5504. Unrelated to #6914 (which is a fwd-only fix on a different code path) ``` ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
6989cf800c
commit
5c7b7ec3f1
@@ -264,7 +264,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
ck_tile::HostTensor<LSEDataType> lse_host(
|
||||
std::array<ck_tile::index_t, 3>{shape_batch, nhead, shape_seqlen_q});
|
||||
ck_tile::HostTensor<LSEDataType> sink_host(
|
||||
sink_grad ? std::array<ck_tile::index_t, 2>{shape_batch, nhead}
|
||||
sink_grad ? std::array<ck_tile::index_t, 2>{batch, nhead}
|
||||
: std::array<ck_tile::index_t, 2>{1, 1} /* dummy when sink is disabled */);
|
||||
if(sink_grad)
|
||||
{
|
||||
|
||||
@@ -995,3 +995,83 @@ TEST_P(MultiBatchPadding, DataTypeConfig)
|
||||
GTEST_SKIP() << "No instance for multi-batch padding";
|
||||
ASSERT_EQ(result, bwd_result::success);
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Regression test for sink_host group-mode OOB fix (PR #7272)
|
||||
// ----------------------------------------------------------------------------
|
||||
// Bug: in group mode, fmha_bwd_runner.hpp allocated sink_host with first
|
||||
// dimension shape_batch (=1) but the fwd reference loop iterates wb in
|
||||
// [0, batch-1], causing out-of-bounds reads of heap garbage when batch > 1.
|
||||
//
|
||||
// Repro condition: sink_grad=true AND mode=group AND batch>=2.
|
||||
// Without the fix, the fwd reference computes a poisoned LSE and the bwd
|
||||
// validation fails non-deterministically (~25-67% failure rate observed
|
||||
// across 30 trial runs at b=2,h=2,s=516,s_k=253,d=72,bf16,mask=no).
|
||||
// With the fix (1-line change shape_batch -> batch on line 267 of
|
||||
// fmha_bwd_runner.hpp), all 30 runs PASS.
|
||||
//
|
||||
// This test exercises the fixed code path; a regression that re-introduces
|
||||
// the OOB will be detected as flaky/failing validation in CI.
|
||||
// ============================================================================
|
||||
class SinkGradGroupMode : public TestWithParam<FmhaBwdTestParam>
|
||||
{
|
||||
};
|
||||
INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaBwd,
|
||||
SinkGradGroupMode,
|
||||
Combine(Values(mode_enum::group), // group mode required to hit OOB
|
||||
Values(std::tuple{72, -1}, // hdim covered by repro command
|
||||
std::tuple{64, -1},
|
||||
std::tuple{128, -1}),
|
||||
Values(std::tuple{true, true}), // perm matching repro
|
||||
Values("n"), // bias=n matching repro
|
||||
Values(false), // use_dbias
|
||||
Values(0.0f), // no dropout
|
||||
Values(std::tuple{0, 0, false}), // seed/offset/prefs
|
||||
Values(std::tuple{2, 2, -1, 516, 253, "0"}, // exact repro config
|
||||
std::tuple{2, 2, -1, 516, 253, "1"}, // + causal top-left
|
||||
std::tuple{
|
||||
2, 2, -1, 516, 253, "2"}, // + causal bottom-right
|
||||
std::tuple{3, 4, 2, 259, -1, "0"}, // larger batch, square
|
||||
std::tuple{4, 2, -1, 200, 180, "0"}), // batch=4 stress
|
||||
Values(false) // deterministic
|
||||
));
|
||||
TEST_P(SinkGradGroupMode, DataTypeConfig)
|
||||
{
|
||||
auto [mode, hdims, perm, bias_str, use_dbias, p_drop, drop_misc, dims_mask, det] = GetParam();
|
||||
auto [hdim_q, hdim_v] = hdims;
|
||||
auto [i_perm, o_perm] = perm;
|
||||
auto [drop_seed, drop_offset, drop_prefs] = drop_misc;
|
||||
auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask;
|
||||
|
||||
auto result = fmha_bwd_run<DataTypeConfig>(
|
||||
mode,
|
||||
batch,
|
||||
nhead,
|
||||
nhead_k,
|
||||
{seqlen_q},
|
||||
{seqlen_k},
|
||||
{-1},
|
||||
{-1},
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
i_perm,
|
||||
o_perm,
|
||||
0, // scale
|
||||
bias_str,
|
||||
use_dbias,
|
||||
p_drop,
|
||||
drop_seed,
|
||||
drop_offset,
|
||||
drop_prefs,
|
||||
mask_str,
|
||||
true, // sink_grad: critical to trigger sink_host alloc/access path
|
||||
det,
|
||||
init_method,
|
||||
static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
|
||||
1,
|
||||
stream_config);
|
||||
|
||||
if(result == bwd_result::no_instance)
|
||||
GTEST_SKIP() << "No instance for sink_grad group-mode regression";
|
||||
ASSERT_EQ(result, bwd_result::success);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user