diff --git a/include/ck_tile/ops/fmha/block/block_dropout.hpp b/include/ck_tile/ops/fmha/block/block_dropout.hpp index 37c1fe4805..78d68a482e 100644 --- a/include/ck_tile/ops/fmha/block/block_dropout.hpp +++ b/include/ck_tile/ops/fmha/block/block_dropout.hpp @@ -381,24 +381,28 @@ struct BlockDropout store_tile(randval_dram_window, randval_store); } move_tile_window(randval_dram_window, {0, kNPerStep}); - // Drop values of P based on the generated probabilities - constexpr auto randval_spans = decltype(randval)::get_distributed_spans(); - sweep_tile_span(randval_spans[number<0>{}], [&](auto idx0) { - sweep_tile_span(randval_spans[number<1>{}], [&](auto idx1) { - constexpr auto p_idx0 = - tile_distributed_index()>{}; - constexpr auto p_idx1 = - tile_distributed_index(), - idx1.impl_.template at<2>()>{}; - constexpr auto p_idx = ck_tile::make_tuple(p_idx0, p_idx1); - constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1); - p_compute(p_idx) = randval[r_idx] <= p_undrop_in_uint8_t - ? p_compute[p_idx] * rp_undrop - : PComputeDataType(0); + + if constexpr(!is_null_tile_window_v) + { + // Drop values of P based on the generated probabilities + constexpr auto randval_spans = decltype(randval)::get_distributed_spans(); + sweep_tile_span(randval_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(randval_spans[number<1>{}], [&](auto idx1) { + constexpr auto p_idx0 = + tile_distributed_index()>{}; + constexpr auto p_idx1 = + tile_distributed_index(), + idx1.impl_.template at<2>()>{}; + constexpr auto p_idx = ck_tile::make_tuple(p_idx0, p_idx1); + constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1); + p_compute(p_idx) = randval[r_idx] <= p_undrop_in_uint8_t + ? p_compute[p_idx] * rp_undrop + : PComputeDataType(0); + }); }); - }); + } }); move_tile_window(randval_dram_window, {kMPerStep, -kNPerBlock}); });