mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 17:55:48 +00:00
[rocm-libraries] ROCm/rocm-libraries#7256 (commit 1fc20eb)
=?UTF-8?q?Skip=20numeric=20drop-out=20when=20PComputeWind?= =?UTF-8?q?ow=20is=20a=20null=5Ftile=5Fwindow=20in=20Bl=E2=80=A6=20(#7256)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The BlockDropout implementation already provides very complete logic for generating random numbers and executing dropout for the P tensor after first attention Gemm with capability to support both Warp-Gemm 32x32 and 16x16 as well as to run on both wave32 and wave64 arch. But in some situation, we only need the block-layer process to generate random numbers, rather than simultaneously execute dropout in real-time on the vgpr tile. For example, xformers' `test_mem_eff_attention.py::test_dropout_ck` requires the host reference implementation of `attention forward with dropout` to use the same random numbers to compare & verify the device side implementation of `attention forward with dropout`, so a standalone kernel to generate random numbers only is required. This PR will enable xformers's random_val generating kernel (in file `ck_tiled_rand_uniform_kernel.h`) to depend on BlockDropout's `Run()` operator completely to generate random numbers for a `[MPerBlock, NPerBlock]` tile during the tile iteration, no need to replicate the logic of BlockDropout in the xformers kernel
This commit is contained in:
committed by
assistant-librarian[bot]
parent
5c7b7ec3f1
commit
acf3d65966
@@ -381,24 +381,28 @@ struct BlockDropout
|
||||
store_tile(randval_dram_window, randval_store);
|
||||
}
|
||||
move_tile_window(randval_dram_window, {0, kNPerStep});
|
||||
// Drop values of P based on the generated probabilities
|
||||
constexpr auto randval_spans = decltype(randval)::get_distributed_spans();
|
||||
sweep_tile_span(randval_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(randval_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto p_idx0 =
|
||||
tile_distributed_index<i_m0 * MIterPerWarp +
|
||||
idx0.impl_.template at<0>()>{};
|
||||
constexpr auto p_idx1 =
|
||||
tile_distributed_index<i_n0,
|
||||
idx1.impl_.template at<1>(),
|
||||
idx1.impl_.template at<2>()>{};
|
||||
constexpr auto p_idx = ck_tile::make_tuple(p_idx0, p_idx1);
|
||||
constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1);
|
||||
p_compute(p_idx) = randval[r_idx] <= p_undrop_in_uint8_t
|
||||
? p_compute[p_idx] * rp_undrop
|
||||
: PComputeDataType(0);
|
||||
|
||||
if constexpr(!is_null_tile_window_v<PComputeWindow>)
|
||||
{
|
||||
// Drop values of P based on the generated probabilities
|
||||
constexpr auto randval_spans = decltype(randval)::get_distributed_spans();
|
||||
sweep_tile_span(randval_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(randval_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto p_idx0 =
|
||||
tile_distributed_index<i_m0 * MIterPerWarp +
|
||||
idx0.impl_.template at<0>()>{};
|
||||
constexpr auto p_idx1 =
|
||||
tile_distributed_index<i_n0,
|
||||
idx1.impl_.template at<1>(),
|
||||
idx1.impl_.template at<2>()>{};
|
||||
constexpr auto p_idx = ck_tile::make_tuple(p_idx0, p_idx1);
|
||||
constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1);
|
||||
p_compute(p_idx) = randval[r_idx] <= p_undrop_in_uint8_t
|
||||
? p_compute[p_idx] * rp_undrop
|
||||
: PComputeDataType(0);
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
});
|
||||
move_tile_window(randval_dram_window, {kMPerStep, -kNPerBlock});
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user