diff --git a/example/ck_tile/01_fmha/fmha_bwd_runner.hpp b/example/ck_tile/01_fmha/fmha_bwd_runner.hpp index f81ae34501..ac86bf4635 100644 --- a/example/ck_tile/01_fmha/fmha_bwd_runner.hpp +++ b/example/ck_tile/01_fmha/fmha_bwd_runner.hpp @@ -264,7 +264,7 @@ bwd_result fmha_bwd_run(mode_enum mode, ck_tile::HostTensor lse_host( std::array{shape_batch, nhead, shape_seqlen_q}); ck_tile::HostTensor sink_host( - sink_grad ? std::array{shape_batch, nhead} + sink_grad ? std::array{batch, nhead} : std::array{1, 1} /* dummy when sink is disabled */); if(sink_grad) { diff --git a/test/ck_tile/fmha/test_fmha_bwd.cpp b/test/ck_tile/fmha/test_fmha_bwd.cpp index 3aee76131e..8d90ad9143 100644 --- a/test/ck_tile/fmha/test_fmha_bwd.cpp +++ b/test/ck_tile/fmha/test_fmha_bwd.cpp @@ -995,3 +995,83 @@ TEST_P(MultiBatchPadding, DataTypeConfig) GTEST_SKIP() << "No instance for multi-batch padding"; ASSERT_EQ(result, bwd_result::success); } + +// ============================================================================ +// Regression test for sink_host group-mode OOB fix (PR #7272) +// ---------------------------------------------------------------------------- +// Bug: in group mode, fmha_bwd_runner.hpp allocated sink_host with first +// dimension shape_batch (=1) but the fwd reference loop iterates wb in +// [0, batch-1], causing out-of-bounds reads of heap garbage when batch > 1. +// +// Repro condition: sink_grad=true AND mode=group AND batch>=2. +// Without the fix, the fwd reference computes a poisoned LSE and the bwd +// validation fails non-deterministically (~25-67% failure rate observed +// across 30 trial runs at b=2,h=2,s=516,s_k=253,d=72,bf16,mask=no). +// With the fix (1-line change shape_batch -> batch on line 267 of +// fmha_bwd_runner.hpp), all 30 runs PASS. +// +// This test exercises the fixed code path; a regression that re-introduces +// the OOB will be detected as flaky/failing validation in CI. +// ============================================================================ +class SinkGradGroupMode : public TestWithParam +{ +}; +INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaBwd, + SinkGradGroupMode, + Combine(Values(mode_enum::group), // group mode required to hit OOB + Values(std::tuple{72, -1}, // hdim covered by repro command + std::tuple{64, -1}, + std::tuple{128, -1}), + Values(std::tuple{true, true}), // perm matching repro + Values("n"), // bias=n matching repro + Values(false), // use_dbias + Values(0.0f), // no dropout + Values(std::tuple{0, 0, false}), // seed/offset/prefs + Values(std::tuple{2, 2, -1, 516, 253, "0"}, // exact repro config + std::tuple{2, 2, -1, 516, 253, "1"}, // + causal top-left + std::tuple{ + 2, 2, -1, 516, 253, "2"}, // + causal bottom-right + std::tuple{3, 4, 2, 259, -1, "0"}, // larger batch, square + std::tuple{4, 2, -1, 200, 180, "0"}), // batch=4 stress + Values(false) // deterministic + )); +TEST_P(SinkGradGroupMode, DataTypeConfig) +{ + auto [mode, hdims, perm, bias_str, use_dbias, p_drop, drop_misc, dims_mask, det] = GetParam(); + auto [hdim_q, hdim_v] = hdims; + auto [i_perm, o_perm] = perm; + auto [drop_seed, drop_offset, drop_prefs] = drop_misc; + auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask; + + auto result = fmha_bwd_run( + mode, + batch, + nhead, + nhead_k, + {seqlen_q}, + {seqlen_k}, + {-1}, + {-1}, + hdim_q, + hdim_v, + i_perm, + o_perm, + 0, // scale + bias_str, + use_dbias, + p_drop, + drop_seed, + drop_offset, + drop_prefs, + mask_str, + true, // sink_grad: critical to trigger sink_host alloc/access path + det, + init_method, + static_cast(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))), + 1, + stream_config); + + if(result == bwd_result::no_instance) + GTEST_SKIP() << "No instance for sink_grad group-mode regression"; + ASSERT_EQ(result, bwd_result::success); +}