mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 20:51:23 +00:00
Hacking ck_tile fmha Dropout facility (#1344)
* Add NullBlockDropout to be used when kHasDropout is false * Change to BlockDropout::Run() for forward to reduce conditional checkings * Re-format files --------- Co-authored-by: PoYen, Chen <PoYen.Chen@amd.com>
This commit is contained in:
@@ -744,29 +744,23 @@ struct FmhaFwdKernel
|
||||
}
|
||||
}();
|
||||
|
||||
// dropout
|
||||
float rp_undrop = 1;
|
||||
uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
|
||||
uint64_t drop_seed = 0;
|
||||
uint64_t drop_offset = 0;
|
||||
bool is_store_randval = false;
|
||||
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
rp_undrop = kargs.rp_undrop;
|
||||
p_undrop_in_uint8_t = kargs.p_undrop_in_uint8_t;
|
||||
drop_seed = kargs.drop_seed;
|
||||
drop_offset = kargs.drop_offset;
|
||||
is_store_randval = kargs.is_store_randval;
|
||||
}
|
||||
BlockDropout dropout(i_batch,
|
||||
i_nhead,
|
||||
kargs.num_head_q,
|
||||
drop_seed,
|
||||
drop_offset,
|
||||
rp_undrop,
|
||||
p_undrop_in_uint8_t,
|
||||
is_store_randval);
|
||||
auto dropout = [&]() {
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
return BlockDropout{i_batch,
|
||||
i_nhead,
|
||||
kargs.num_head_q,
|
||||
kargs.drop_seed,
|
||||
kargs.drop_offset,
|
||||
kargs.rp_undrop,
|
||||
kargs.p_undrop_in_uint8_t,
|
||||
kargs.is_store_randval};
|
||||
}
|
||||
else
|
||||
{
|
||||
return NullBlockDropout{};
|
||||
};
|
||||
}();
|
||||
|
||||
auto randval_dram_window = [&, i_nhead_ = i_nhead]() {
|
||||
constexpr auto randval_dram_window_lengths =
|
||||
|
||||
Reference in New Issue
Block a user