From acf3d65966921d91be1a44c63cf3f00a738e8da6 Mon Sep 17 00:00:00 2001 From: Qianfeng Date: Wed, 13 May 2026 09:42:28 +0000 Subject: [PATCH] [rocm-libraries] ROCm/rocm-libraries#7256 (commit 1fc20eb) =?UTF-8?q?Skip=20numeric=20drop-out=20when=20PComputeWind?= =?UTF-8?q?ow=20is=20a=20null=5Ftile=5Fwindow=20in=20Bl=E2=80=A6=20(#7256)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The BlockDropout implementation already provides very complete logic for generating random numbers and executing dropout for the P tensor after the first attention Gemm, with the capability to support both Warp-Gemm 32x32 and 16x16 as well as to run on both wave32 and wave64 architectures. But in some situations, we only need the block-layer process to generate random numbers, rather than simultaneously execute dropout in real-time on the vgpr tile. For example, xformers' `test_mem_eff_attention.py::test_dropout_ck` requires the host reference implementation of `attention forward with dropout` to use the same random numbers to compare & verify the device-side implementation of `attention forward with dropout`, so a standalone kernel to generate random numbers only is required. 
This PR will enable xformers' random_val generating kernel (in file `ck_tiled_rand_uniform_kernel.h`) to depend on BlockDropout's `Run()` operator completely to generate random numbers for a `[MPerBlock, NPerBlock]` tile during the tile iteration, with no need to replicate the logic of BlockDropout in the xformers kernel --- .../ck_tile/ops/fmha/block/block_dropout.hpp | 38 ++++++++++--------- 1 file changed, 21 insertions(+), 17 deletions(-) diff --git a/include/ck_tile/ops/fmha/block/block_dropout.hpp b/include/ck_tile/ops/fmha/block/block_dropout.hpp index 37c1fe4805..78d68a482e 100644 --- a/include/ck_tile/ops/fmha/block/block_dropout.hpp +++ b/include/ck_tile/ops/fmha/block/block_dropout.hpp @@ -381,24 +381,28 @@ struct BlockDropout store_tile(randval_dram_window, randval_store); } move_tile_window(randval_dram_window, {0, kNPerStep}); - // Drop values of P based on the generated probabilities - constexpr auto randval_spans = decltype(randval)::get_distributed_spans(); - sweep_tile_span(randval_spans[number<0>{}], [&](auto idx0) { - sweep_tile_span(randval_spans[number<1>{}], [&](auto idx1) { - constexpr auto p_idx0 = - tile_distributed_index()>{}; - constexpr auto p_idx1 = - tile_distributed_index(), - idx1.impl_.template at<2>()>{}; - constexpr auto p_idx = ck_tile::make_tuple(p_idx0, p_idx1); - constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1); - p_compute(p_idx) = randval[r_idx] <= p_undrop_in_uint8_t - ? 
p_compute[p_idx] * rp_undrop - : PComputeDataType(0); + + if constexpr(!is_null_tile_window_v) + { + // Drop values of P based on the generated probabilities + constexpr auto randval_spans = decltype(randval)::get_distributed_spans(); + sweep_tile_span(randval_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(randval_spans[number<1>{}], [&](auto idx1) { + constexpr auto p_idx0 = + tile_distributed_index()>{}; + constexpr auto p_idx1 = + tile_distributed_index(), + idx1.impl_.template at<2>()>{}; + constexpr auto p_idx = ck_tile::make_tuple(p_idx0, p_idx1); + constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1); + p_compute(p_idx) = randval[r_idx] <= p_undrop_in_uint8_t + ? p_compute[p_idx] * rp_undrop + : PComputeDataType(0); + }); }); - }); + } }); move_tile_window(randval_dram_window, {kMPerStep, -kNPerBlock}); });