[rocm-libraries] ROCm/rocm-libraries#7256 (commit 1fc20eb)

Skip numeric drop-out when PComputeWindow is a null_tile_window in Bl… (#7256)

The BlockDropout implementation already provides complete logic for
generating random numbers and applying dropout to the P tensor after the
first attention Gemm. It supports both 32x32 and 16x16 Warp-Gemm and
runs on both wave32 and wave64 architectures.

In some situations, however, we only need the block-level process to
generate random numbers, without also applying dropout to the vgpr tile
in the same pass. For example, xformers'
`test_mem_eff_attention.py::test_dropout_ck` requires the host reference
implementation of `attention forward with dropout` to use the same
random numbers as the device-side implementation, so that the two
results can be compared and verified. This calls for a standalone kernel
that only generates random numbers.
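
To make the verification requirement concrete, here is a minimal sketch of a host-side reference that applies dropout from random bytes generated elsewhere. It reuses the keep/scale rule visible in the diff (`randval <= p_undrop_in_uint8_t ? p * rp_undrop : 0`). The function name `reference_dropout`, the flat `std::vector` layout, and the `float` element type are assumptions made for illustration, not part of xformers or ck_tile.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical host reference (illustration only): applies dropout to a
// flattened P matrix using random bytes that were generated elsewhere,
// e.g. dumped by a device kernel that only produces random values.
std::vector<float> reference_dropout(const std::vector<float>& p,
                                     const std::vector<std::uint8_t>& randval,
                                     std::uint8_t p_undrop_in_uint8_t,
                                     float rp_undrop)
{
    std::vector<float> out(p.size());
    for(std::size_t i = 0; i < p.size(); ++i)
    {
        // Keep and rescale the element when its random byte is at or below
        // the threshold; otherwise zero it.
        out[i] = randval[i] <= p_undrop_in_uint8_t ? p[i] * rp_undrop : 0.0f;
    }
    return out;
}
```

Because the reference consumes the random bytes as an input, host and device outputs can only be compared when both sides see identical `randval` contents.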

This PR enables xformers' random_val generating kernel (in file
`ck_tiled_rand_uniform_kernel.h`) to rely entirely on BlockDropout's
`Run()` operator to generate random numbers for a
`[MPerBlock, NPerBlock]` tile during the tile iteration, so the
BlockDropout logic no longer needs to be replicated in the xformers
kernel.
Qianfeng
2026-05-13 09:42:28 +00:00
committed by assistant-librarian[bot]
parent 5c7b7ec3f1
commit acf3d65966


@@ -381,24 +381,28 @@ struct BlockDropout
                     store_tile(randval_dram_window, randval_store);
                 }
                 move_tile_window(randval_dram_window, {0, kNPerStep});
-                // Drop values of P based on the generated probabilities
-                constexpr auto randval_spans = decltype(randval)::get_distributed_spans();
-                sweep_tile_span(randval_spans[number<0>{}], [&](auto idx0) {
-                    sweep_tile_span(randval_spans[number<1>{}], [&](auto idx1) {
-                        constexpr auto p_idx0 =
-                            tile_distributed_index<i_m0 * MIterPerWarp +
-                                                   idx0.impl_.template at<0>()>{};
-                        constexpr auto p_idx1 =
-                            tile_distributed_index<i_n0,
-                                                   idx1.impl_.template at<1>(),
-                                                   idx1.impl_.template at<2>()>{};
-                        constexpr auto p_idx = ck_tile::make_tuple(p_idx0, p_idx1);
-                        constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1);
-                        p_compute(p_idx) = randval[r_idx] <= p_undrop_in_uint8_t
-                            ? p_compute[p_idx] * rp_undrop
-                            : PComputeDataType(0);
+                if constexpr(!is_null_tile_window_v<PComputeWindow>)
+                {
+                    // Drop values of P based on the generated probabilities
+                    constexpr auto randval_spans = decltype(randval)::get_distributed_spans();
+                    sweep_tile_span(randval_spans[number<0>{}], [&](auto idx0) {
+                        sweep_tile_span(randval_spans[number<1>{}], [&](auto idx1) {
+                            constexpr auto p_idx0 =
+                                tile_distributed_index<i_m0 * MIterPerWarp +
+                                                       idx0.impl_.template at<0>()>{};
+                            constexpr auto p_idx1 =
+                                tile_distributed_index<i_n0,
+                                                       idx1.impl_.template at<1>(),
+                                                       idx1.impl_.template at<2>()>{};
+                            constexpr auto p_idx = ck_tile::make_tuple(p_idx0, p_idx1);
+                            constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1);
+                            p_compute(p_idx) = randval[r_idx] <= p_undrop_in_uint8_t
+                                ? p_compute[p_idx] * rp_undrop
+                                : PComputeDataType(0);
+                        });
                     });
-                });
+                }
             });
             move_tile_window(randval_dram_window, {kMPerStep, -kNPerBlock});
         });
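
The pattern in this hunk, generating random values unconditionally and applying dropout only when a real P window is supplied, can be sketched in isolation as follows. This is a simplified illustration under stated assumptions: `null_window_tag`, `is_null_window_v`, `next_random_byte` and `run_dropout` are hypothetical stand-ins written for this note, not the ck_tile definitions of `null_tile_window` or `is_null_tile_window_v`, and the toy generator is not the one BlockDropout uses.

```cpp
#include <array>
#include <cstddef>
#include <cstdint>
#include <type_traits>

// Hypothetical stand-in for a "no window supplied" marker type.
struct null_window_tag
{
};

// Hypothetical trait: true when T (ignoring cv/ref) is the null marker.
template <typename T>
inline constexpr bool is_null_window_v =
    std::is_same_v<std::remove_cv_t<std::remove_reference_t<T>>, null_window_tag>;

// Toy generator (one linear congruential step); illustration only.
inline std::uint8_t next_random_byte(std::uint32_t& state)
{
    state = state * 1664525u + 1013904223u;
    return static_cast<std::uint8_t>(state >> 24);
}

// Always fills randval_out. Touches p only when PWindow is not the null
// marker; the discarded branch is never instantiated for null_window_tag,
// so the marker type does not need operator[].
template <std::size_t N, typename PWindow>
void run_dropout(std::uint32_t seed,
                 std::array<std::uint8_t, N>& randval_out,
                 [[maybe_unused]] PWindow& p,
                 [[maybe_unused]] std::uint8_t p_undrop_in_uint8_t,
                 [[maybe_unused]] float rp_undrop)
{
    std::uint32_t state = seed;
    for(std::size_t i = 0; i < N; ++i)
    {
        randval_out[i] = next_random_byte(state);
        if constexpr(!is_null_window_v<PWindow>)
        {
            p[i] = randval_out[i] <= p_undrop_in_uint8_t ? p[i] * rp_undrop : 0.0f;
        }
    }
}
```

Calling `run_dropout` once with a `null_window_tag` object and once with a `std::array<float, N>` under the same seed yields identical random bytes in both calls, which is the property a random-number-only kernel relies on.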