mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 08:50:17 +00:00
Adding seed and offset pointer support to the philox random number generator. (#1523)
* Adding seed and offset pointer support to the philox random number generator. * Separating seed and offset pointer checks with different condition statements. * Changes include, adding support for device seed and offset pointers, union is used to store seed/offset values and device pointers to minimize device SGPRs. * Correcting a typo in the readme file * Re-format files using remod.py * Use STL type for API parameters * Use simpler struct design for drop_seed & drop_offset * Undo unnecessary changes * Sync kargs style for fmha_fwd.hpp/.cpp * Use templated union to reduce code * Use structured binding to make code more readable --------- Co-authored-by: Sudhir Kylasa <sukylasa@amd.com> Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
This commit is contained in:
@@ -122,6 +122,9 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("p_drop", "0", "0~1 probability of dropout")
|
||||
.insert("drop_seed", "1", "seed for random number generator")
|
||||
.insert("drop_offset", "0", "offset for random number generator")
|
||||
.insert("drop_prefs",
|
||||
"0",
|
||||
"seed and offset values are present on GPU; 0 - host, 1 - device/GPU")
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
|
||||
.insert(
|
||||
"rotary_dim", "0", "RoPE rotary dimension. rotary_dim <= 0 means not apply RoPE at all")
|
||||
@@ -442,6 +445,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
float p_drop = arg_parser.get_float("p_drop");
|
||||
uint64_t drop_seed = arg_parser.get_uint64("drop_seed");
|
||||
uint64_t drop_offset = arg_parser.get_uint64("drop_offset");
|
||||
bool drop_prefs = arg_parser.get_bool("drop_prefs");
|
||||
|
||||
if(p_drop < 0.0f || p_drop > 1.0f)
|
||||
{
|
||||
std::cerr << "The value of p_drop should be 0~1" << std::endl;
|
||||
@@ -756,6 +761,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
need_append_kvcache ? cache_seqlen_ks.size() * sizeof(int32_t) : 0);
|
||||
ck_tile::DeviceMem rotary_cos_buf(rotary_cos_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem rotary_sin_buf(rotary_sin_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem drop_seed_buf(drop_prefs ? sizeof(uint64_t) : 0);
|
||||
ck_tile::DeviceMem drop_offset_buf(drop_prefs ? sizeof(uint64_t) : 0);
|
||||
ck_tile::DeviceMem randval_buf(randval_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem block_table_buf(block_table_host.get_element_space_size_in_bytes());
|
||||
@@ -774,6 +781,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
cache_seqlen_k_buf.ToDevice(need_append_kvcache ? cache_seqlen_ks.data() : nullptr);
|
||||
rotary_cos_buf.ToDevice(rotary_cos_host.data());
|
||||
rotary_sin_buf.ToDevice(rotary_sin_host.data());
|
||||
drop_seed_buf.ToDevice(drop_prefs ? &drop_seed : nullptr);
|
||||
drop_offset_buf.ToDevice(drop_prefs ? &drop_offset : nullptr);
|
||||
alibi_slope_buf.ToDevice(alibi_slope_host.data());
|
||||
block_table_buf.ToDevice(block_table_host.data());
|
||||
cache_batch_idx_buf.ToDevice(cache_batch_idx_host.data());
|
||||
@@ -1013,9 +1022,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
args.nhead_stride_randval = nhead_stride_randval;
|
||||
args.batch_stride_randval = batch_stride_randval;
|
||||
|
||||
args.p_drop = p_drop;
|
||||
args.s_randval = s_randval;
|
||||
args.drop_seed_offset = std::tie(drop_seed, drop_offset);
|
||||
args.p_drop = p_drop;
|
||||
args.s_randval = s_randval;
|
||||
if(drop_prefs)
|
||||
{
|
||||
args.drop_seed_offset = std::make_pair(drop_seed_buf.GetDeviceBuffer(),
|
||||
drop_offset_buf.GetDeviceBuffer());
|
||||
}
|
||||
else
|
||||
{
|
||||
args.drop_seed_offset = std::make_pair(drop_seed, drop_offset);
|
||||
}
|
||||
}
|
||||
else if constexpr(std::is_same_v<fmha_fwd_splitkv_args, std::decay_t<decltype(args)>>)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user