[rocm-libraries] ROCm/rocm-libraries#7272 (commit d02f3c0)

[ck_tile][fmha_bwd] Fix sink_host OOB in group mode reference runner (#7272)

## Summary

In `fmha_bwd_runner.hpp`, the `sink_host` `HostTensor` is allocated with first dimension `shape_batch` (= 1 in group mode), but the reference forward loop accesses `sink_host(wb, i_h)` with `wb ∈ [0, batch-1]`. For any `wb >= 1` this is an out-of-bounds heap read, silently corrupting the reference forward math chain (`lse_host`, `o_host`) and turning the bwd-side `d_sink_head_acc` reference into non-deterministic garbage.

`HostTensor::operator()` does not bounds-check, so the OOB read is not caught at runtime. It manifests as intermittent `tile_example_fmha_bwd` failures (25–67% fail rate) when `-sink_grad=1` is combined with `-mode=1` (group mode), reporting bit-exact but spurious `max_err` values like 4.27 / 14.6.

## Fix

One-line fix: allocate `sink_host` with `batch` (the real per-batch dimension) instead of `shape_batch`, mirroring how the loop indexes `sink_host`.

```diff
-    sink_grad ? std::array<ck_tile::index_t, 2>{shape_batch, nhead}
+    sink_grad ? std::array<ck_tile::index_t, 2>{batch, nhead}
```

## Repro

```shell
tile_example_fmha_bwd -b=2 -h=2 -s=516 -s_k=253 -prec=bf16 -d=72 \
  -bias=n -dbias=0 -p_drop=0 -iperm=1 -operm=1 -deterministic=0 \
  -v=3 -mode=1 -kname=1 -sink_grad=1
```

## Verification

- 0/30 fail on the repro config after the fix
- Baselines (before the fix):
  - sink=1, mask=n: 25% fail rate (p ≈ 1.8e-4)
  - sink=1, mask=t: 67% fail rate (p ≈ 6e-15)

## Attribution

Shape bug introduced together with `sink_grad` in #5504. Unrelated to #6914, which is a fwd-only fix on a different code path.

## Submission Checklist

- [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in: commit 5c7b7ec3f1 (parent 6989cf800c), authored by Linjun-AMD on 2026-05-13 08:49:13 +00:00, committed by assistant-librarian[bot]. 2 changed files with 81 additions and 1 deletion.


```diff
@@ -264,7 +264,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
     ck_tile::HostTensor<LSEDataType> lse_host(
         std::array<ck_tile::index_t, 3>{shape_batch, nhead, shape_seqlen_q});
     ck_tile::HostTensor<LSEDataType> sink_host(
-        sink_grad ? std::array<ck_tile::index_t, 2>{shape_batch, nhead}
+        sink_grad ? std::array<ck_tile::index_t, 2>{batch, nhead}
                   : std::array<ck_tile::index_t, 2>{1, 1} /* dummy when sink is disabled */);
     if(sink_grad)
     {
```


```diff
@@ -995,3 +995,83 @@ TEST_P(MultiBatchPadding, DataTypeConfig)
         GTEST_SKIP() << "No instance for multi-batch padding";
     ASSERT_EQ(result, bwd_result::success);
 }
+// ============================================================================
+// Regression test for sink_host group-mode OOB fix (PR #7272)
+// ----------------------------------------------------------------------------
+// Bug: in group mode, fmha_bwd_runner.hpp allocated sink_host with first
+// dimension shape_batch (=1) but the fwd reference loop iterates wb in
+// [0, batch-1], causing out-of-bounds reads of heap garbage when batch > 1.
+//
+// Repro condition: sink_grad=true AND mode=group AND batch>=2.
+// Without the fix, the fwd reference computes a poisoned LSE and the bwd
+// validation fails non-deterministically (~25-67% failure rate observed
+// across 30 trial runs at b=2,h=2,s=516,s_k=253,d=72,bf16,mask=no).
+// With the fix (1-line change shape_batch -> batch on line 267 of
+// fmha_bwd_runner.hpp), all 30 runs PASS.
+//
+// This test exercises the fixed code path; a regression that re-introduces
+// the OOB will be detected as flaky/failing validation in CI.
+// ============================================================================
+class SinkGradGroupMode : public TestWithParam<FmhaBwdTestParam>
+{
+};
+INSTANTIATE_TEST_SUITE_P(
+    TestCkTileFmhaBwd,
+    SinkGradGroupMode,
+    Combine(Values(mode_enum::group),        // group mode required to hit OOB
+            Values(std::tuple{72, -1},       // hdim covered by repro command
+                   std::tuple{64, -1},
+                   std::tuple{128, -1}),
+            Values(std::tuple{true, true}),  // perm matching repro
+            Values("n"),                     // bias=n matching repro
+            Values(false),                   // use_dbias
+            Values(0.0f),                    // no dropout
+            Values(std::tuple{0, 0, false}), // seed/offset/prefs
+            Values(std::tuple{2, 2, -1, 516, 253, "0"},  // exact repro config
+                   std::tuple{2, 2, -1, 516, 253, "1"},  // + causal top-left
+                   std::tuple{2, 2, -1, 516, 253, "2"},  // + causal bottom-right
+                   std::tuple{3, 4, 2, 259, -1, "0"},    // larger batch, square
+                   std::tuple{4, 2, -1, 200, 180, "0"}), // batch=4 stress
+            Values(false)                    // deterministic
+            ));
+TEST_P(SinkGradGroupMode, DataTypeConfig)
+{
+    auto [mode, hdims, perm, bias_str, use_dbias, p_drop, drop_misc, dims_mask, det] =
+        GetParam();
+    auto [hdim_q, hdim_v]                     = hdims;
+    auto [i_perm, o_perm]                     = perm;
+    auto [drop_seed, drop_offset, drop_prefs] = drop_misc;
+    auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask;
+    auto result = fmha_bwd_run<DataTypeConfig>(
+        mode,
+        batch,
+        nhead,
+        nhead_k,
+        {seqlen_q},
+        {seqlen_k},
+        {-1},
+        {-1},
+        hdim_q,
+        hdim_v,
+        i_perm,
+        o_perm,
+        0, // scale
+        bias_str,
+        use_dbias,
+        p_drop,
+        drop_seed,
+        drop_offset,
+        drop_prefs,
+        mask_str,
+        true, // sink_grad: critical to trigger sink_host alloc/access path
+        det,
+        init_method,
+        static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
+        1,
+        stream_config);
+    if(result == bwd_result::no_instance)
+        GTEST_SKIP() << "No instance for sink_grad group-mode regression";
+    ASSERT_EQ(result, bwd_result::success);
+}
```