mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 18:17:44 +00:00
Adding seed and offset pointer support to the philox random number generator. (#1523)
* Adding seed and offset pointer support to the philox random number generator.
* Separating seed and offset pointer checks with different condition statements.
* Changes include, adding support for device seed and offset pointers, union is used to store seed/offset values and device pointers to minimize device SGPRs.
* Correcting a typo in the readme file
* Re-format files using remod.py
* Use STL type for API parameters
* Use simpler struct design for drop_seed & drop_offset
* Undo unnecessary changes
* Sync kargs style for fmha_fwd.hpp/.cpp
* Use templated union to reduce code
* Use structured binding to make code more readable
---------
Co-authored-by: Sudhir Kylasa <sukylasa@amd.com>
Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
[ROCm/composable_kernel commit: c24fae2346]
This commit is contained in:
@@ -70,8 +70,13 @@ args:
|
||||
-seed random seed used for initializing input tensors. 0 for non-deterministic seed (default:11939)
|
||||
-warmup number of iterations before benchmark the kernel (default:5)
|
||||
-repeat number of iterations to benchmark the kernel (default:20)
|
||||
-drop_seed seed for the random number generator for the dropout layer, default is 1
|
||||
-drop_offset offset for the dropout layer which is used during random number generation, default is 0
|
||||
-drop_prefs flag to indicate `drop_seed` and `drop_offset` values if present on the GPU, default is 0, 0 - host, 1 - GPU
|
||||
```
|
||||
Example: `./bin/tile_example_fmha_fwd -b=1 -h=16 -s=16384 -d=128` will run a fmha case with batch=1, nhead=16, sequence length=16384, hdim=128, fp16 case.
|
||||
Example 1: `./bin/tile_example_fmha_fwd -b=1 -h=16 -s=16384 -d=128` will run a fmha case with batch=1, nhead=16, sequence length=16384, hdim=128, fp16 case.
|
||||
Example 2: `./bin/tile_example_fmha_fwd -b=1 -h=8 -s=16384 -d=64 -drop_prefs=1 -drop_seed=10 -drop_offset=1234` will run a fmha case with
|
||||
batch=1, nhead=8, sequence length=16384, hdim=64, drop_seed=0 (in GPU memory), drop_offset=1234 (in GPU memory) fp16 case
|
||||
|
||||
## support features
|
||||
Currently we are still in rapid development stage, so more features/optimizations will be coming soon.
|
||||
|
||||
@@ -85,6 +85,9 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("p_drop", "0", "0~1 probability of dropout")
|
||||
.insert("drop_seed", "1", "seed for random number generator")
|
||||
.insert("drop_offset", "0", "offset for random number generator")
|
||||
.insert("drop_prefs",
|
||||
"0",
|
||||
"seed and offset values are present on GPU; 0 - host, 1 - device/GPU")
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
|
||||
.insert("warmup", "5", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "20", "number of iterations to benchmark the kernel")
|
||||
@@ -158,6 +161,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
float p_drop = arg_parser.get_float("p_drop");
|
||||
uint64_t drop_seed = arg_parser.get_uint64("drop_seed");
|
||||
uint64_t drop_offset = arg_parser.get_uint64("drop_offset");
|
||||
bool drop_prefs = arg_parser.get_bool("drop_prefs");
|
||||
|
||||
if(use_dbias && bias.type != bias_enum::elementwise_bias)
|
||||
{
|
||||
std::cerr << "dbias only exists when bias type is elementwise" << std::endl;
|
||||
@@ -381,6 +386,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::DeviceMem dbias_buf(dbias_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
|
||||
ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t));
|
||||
ck_tile::DeviceMem drop_seed_buf(drop_prefs ? sizeof(uint64_t) : 0);
|
||||
ck_tile::DeviceMem drop_offset_buf(drop_prefs ? sizeof(uint64_t) : 0);
|
||||
ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem dq_acc_buf(dq_acc_host.get_element_space_size_in_bytes());
|
||||
|
||||
@@ -391,6 +398,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
do_buf.ToDevice(do_host.data());
|
||||
seqstart_q.ToDevice(seqstart_q_host.data());
|
||||
seqstart_k.ToDevice(seqstart_k_host.data());
|
||||
drop_seed_buf.ToDevice(drop_prefs ? &drop_seed : nullptr);
|
||||
drop_offset_buf.ToDevice(drop_prefs ? &drop_offset : nullptr);
|
||||
alibi_slope_buf.ToDevice(alibi_slope_host.data());
|
||||
|
||||
// clang-format off
|
||||
@@ -472,6 +481,18 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
const ck_tile::index_t split_stride_dq_acc =
|
||||
(shape_batch * nhead * shape_seqlen_q * hdim_q);
|
||||
|
||||
const auto drop_seed_offset = [&]() -> decltype(fmha_bwd_args::drop_seed_offset) {
|
||||
if(drop_prefs)
|
||||
{
|
||||
return std::make_pair(drop_seed_buf.GetDeviceBuffer(),
|
||||
drop_offset_buf.GetDeviceBuffer());
|
||||
}
|
||||
else
|
||||
{
|
||||
return std::make_pair(drop_seed, drop_offset);
|
||||
}
|
||||
}();
|
||||
|
||||
return fmha_bwd_args{q_buf.GetDeviceBuffer(),
|
||||
k_buf.GetDeviceBuffer(),
|
||||
v_buf.GetDeviceBuffer(),
|
||||
@@ -545,7 +566,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
static_cast<ck_tile::index_t>(mask.type),
|
||||
p_drop,
|
||||
p_undrop,
|
||||
{drop_seed, drop_offset}};
|
||||
drop_seed_offset};
|
||||
}();
|
||||
|
||||
float ave_time = fmha_bwd(fmha_traits, fmha_args, stream_config);
|
||||
|
||||
@@ -9,7 +9,10 @@
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "mask.hpp"
|
||||
#include "bias.hpp"
|
||||
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
#include <variant>
|
||||
|
||||
template <typename DataType>
|
||||
struct FmhaBwdTypeConfig;
|
||||
@@ -135,7 +138,8 @@ struct fmha_bwd_args
|
||||
ck_tile::index_t mask_type;
|
||||
float p_drop;
|
||||
float p_undrop;
|
||||
std::tuple<uint64_t, uint64_t> drop_seed_offset;
|
||||
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
|
||||
drop_seed_offset;
|
||||
};
|
||||
|
||||
template <typename FmhaBwdDQDKDVKernel>
|
||||
|
||||
@@ -122,6 +122,9 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("p_drop", "0", "0~1 probability of dropout")
|
||||
.insert("drop_seed", "1", "seed for random number generator")
|
||||
.insert("drop_offset", "0", "offset for random number generator")
|
||||
.insert("drop_prefs",
|
||||
"0",
|
||||
"seed and offset values are present on GPU; 0 - host, 1 - device/GPU")
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
|
||||
.insert(
|
||||
"rotary_dim", "0", "RoPE rotary dimension. rotary_dim <= 0 means not apply RoPE at all")
|
||||
@@ -442,6 +445,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
float p_drop = arg_parser.get_float("p_drop");
|
||||
uint64_t drop_seed = arg_parser.get_uint64("drop_seed");
|
||||
uint64_t drop_offset = arg_parser.get_uint64("drop_offset");
|
||||
bool drop_prefs = arg_parser.get_bool("drop_prefs");
|
||||
|
||||
if(p_drop < 0.0f || p_drop > 1.0f)
|
||||
{
|
||||
std::cerr << "The value of p_drop should be 0~1" << std::endl;
|
||||
@@ -756,6 +761,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
need_append_kvcache ? cache_seqlen_ks.size() * sizeof(int32_t) : 0);
|
||||
ck_tile::DeviceMem rotary_cos_buf(rotary_cos_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem rotary_sin_buf(rotary_sin_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem drop_seed_buf(drop_prefs ? sizeof(uint64_t) : 0);
|
||||
ck_tile::DeviceMem drop_offset_buf(drop_prefs ? sizeof(uint64_t) : 0);
|
||||
ck_tile::DeviceMem randval_buf(randval_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem block_table_buf(block_table_host.get_element_space_size_in_bytes());
|
||||
@@ -774,6 +781,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
cache_seqlen_k_buf.ToDevice(need_append_kvcache ? cache_seqlen_ks.data() : nullptr);
|
||||
rotary_cos_buf.ToDevice(rotary_cos_host.data());
|
||||
rotary_sin_buf.ToDevice(rotary_sin_host.data());
|
||||
drop_seed_buf.ToDevice(drop_prefs ? &drop_seed : nullptr);
|
||||
drop_offset_buf.ToDevice(drop_prefs ? &drop_offset : nullptr);
|
||||
alibi_slope_buf.ToDevice(alibi_slope_host.data());
|
||||
block_table_buf.ToDevice(block_table_host.data());
|
||||
cache_batch_idx_buf.ToDevice(cache_batch_idx_host.data());
|
||||
@@ -1013,9 +1022,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
args.nhead_stride_randval = nhead_stride_randval;
|
||||
args.batch_stride_randval = batch_stride_randval;
|
||||
|
||||
args.p_drop = p_drop;
|
||||
args.s_randval = s_randval;
|
||||
args.drop_seed_offset = std::tie(drop_seed, drop_offset);
|
||||
args.p_drop = p_drop;
|
||||
args.s_randval = s_randval;
|
||||
if(drop_prefs)
|
||||
{
|
||||
args.drop_seed_offset = std::make_pair(drop_seed_buf.GetDeviceBuffer(),
|
||||
drop_offset_buf.GetDeviceBuffer());
|
||||
}
|
||||
else
|
||||
{
|
||||
args.drop_seed_offset = std::make_pair(drop_seed, drop_offset);
|
||||
}
|
||||
}
|
||||
else if constexpr(std::is_same_v<fmha_fwd_splitkv_args, std::decay_t<decltype(args)>>)
|
||||
{
|
||||
|
||||
@@ -13,6 +13,8 @@
|
||||
#include "rotary.hpp"
|
||||
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
#include <variant>
|
||||
|
||||
template <typename DataType>
|
||||
struct FmhaFwdTypeConfig;
|
||||
@@ -144,7 +146,9 @@ struct fmha_fwd_args
|
||||
|
||||
float p_drop;
|
||||
bool s_randval;
|
||||
std::tuple<uint64_t, uint64_t> drop_seed_offset;
|
||||
|
||||
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
|
||||
drop_seed_offset;
|
||||
};
|
||||
|
||||
struct fmha_fwd_splitkv_args
|
||||
|
||||
@@ -6,8 +6,11 @@
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
|
||||
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
#include <variant>
|
||||
|
||||
// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
|
||||
// S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
|
||||
@@ -194,11 +197,23 @@ struct FmhaBwdDQDKDVKernel
|
||||
ck_tile::GenericAttentionMaskEnum mask_type;
|
||||
};
|
||||
|
||||
struct FmhaBwdCommonDropoutKargs
|
||||
struct FmhaBwdDropoutSeedOffset
|
||||
{
|
||||
void init_dropout(const float p_drop,
|
||||
const std::tuple<uint64_t, uint64_t>& drop_seed_offset,
|
||||
const float raw_scale)
|
||||
template <typename T>
|
||||
union ValueOrPointer
|
||||
{
|
||||
T val;
|
||||
const T* ptr;
|
||||
};
|
||||
|
||||
ValueOrPointer<uint64_t> drop_seed;
|
||||
ValueOrPointer<uint64_t> drop_offset;
|
||||
bool is_drop_seed_offset_from_host;
|
||||
};
|
||||
|
||||
struct FmhaBwdCommonDropoutKargs : FmhaBwdDropoutSeedOffset
|
||||
{
|
||||
void init_dropout(float p_drop, uint64_t seed, uint64_t offset, float raw_scale)
|
||||
{
|
||||
float p_undrop = 1.0 - p_drop;
|
||||
p_undrop_in_uint8_t =
|
||||
@@ -206,23 +221,41 @@ struct FmhaBwdDQDKDVKernel
|
||||
rp_undrop = 1.0 / p_undrop;
|
||||
scale_rp_undrop = rp_undrop * raw_scale;
|
||||
|
||||
drop_seed = std::get<0>(drop_seed_offset);
|
||||
drop_offset = std::get<1>(drop_seed_offset);
|
||||
this->drop_seed.val = seed;
|
||||
this->drop_offset.val = offset;
|
||||
this->is_drop_seed_offset_from_host = true;
|
||||
}
|
||||
|
||||
void init_dropout(float p_drop,
|
||||
const uint64_t* seed_ptr,
|
||||
const uint64_t* offset_ptr,
|
||||
float raw_scale)
|
||||
{
|
||||
float p_undrop = 1.0 - p_drop;
|
||||
p_undrop_in_uint8_t =
|
||||
uint8_t(std::floor(p_undrop * std::numeric_limits<uint8_t>::max()));
|
||||
rp_undrop = 1.0 / p_undrop;
|
||||
scale_rp_undrop = rp_undrop * raw_scale;
|
||||
|
||||
this->drop_seed.ptr = seed_ptr;
|
||||
this->drop_offset.ptr = offset_ptr;
|
||||
this->is_drop_seed_offset_from_host = false;
|
||||
}
|
||||
|
||||
float rp_undrop = 1;
|
||||
float scale_rp_undrop = 1;
|
||||
uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
|
||||
uint64_t drop_seed = 1;
|
||||
uint64_t drop_offset = 0;
|
||||
void* rand_val_ptr = nullptr;
|
||||
|
||||
ck_tile::index_t stride_randval = 0;
|
||||
ck_tile::index_t nhead_stride_randval = 0;
|
||||
};
|
||||
|
||||
struct FmhaBwdBatchModeDropoutKargs : FmhaBwdCommonDropoutKargs
|
||||
{
|
||||
ck_tile::index_t batch_stride_randval = 0;
|
||||
};
|
||||
|
||||
struct FmhaBwdDeterministicKargs
|
||||
{
|
||||
ck_tile::index_t split_stride_dq_acc = 0;
|
||||
@@ -327,7 +360,8 @@ struct FmhaBwdDQDKDVKernel
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t mask_type,
|
||||
float p_drop,
|
||||
const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
|
||||
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
|
||||
drop_seed_offset)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
k_ptr,
|
||||
@@ -405,7 +439,20 @@ struct FmhaBwdDQDKDVKernel
|
||||
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
kargs.init_dropout(p_drop, drop_seed_offset, scale);
|
||||
if(drop_seed_offset.index() == 0) // seed & offset come from host
|
||||
{
|
||||
const auto& [seed, offset] = std::get<0>(drop_seed_offset);
|
||||
kargs.init_dropout(p_drop, seed, offset, scale);
|
||||
}
|
||||
else // seed & offset come from device
|
||||
{
|
||||
const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
|
||||
kargs.init_dropout(p_drop,
|
||||
reinterpret_cast<const uint64_t*>(seed_ptr),
|
||||
reinterpret_cast<const uint64_t*>(offset_ptr),
|
||||
scale);
|
||||
}
|
||||
|
||||
if constexpr(kIsStoreRandval)
|
||||
{
|
||||
kargs.rand_val_ptr = rand_val_ptr;
|
||||
@@ -471,7 +518,8 @@ struct FmhaBwdDQDKDVKernel
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t mask_type,
|
||||
float p_drop,
|
||||
const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
|
||||
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
|
||||
drop_seed_offset)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
k_ptr,
|
||||
@@ -539,7 +587,20 @@ struct FmhaBwdDQDKDVKernel
|
||||
}
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
kargs.init_dropout(p_drop, drop_seed_offset, scale);
|
||||
if(drop_seed_offset.index() == 0) // seed & offset come from host
|
||||
{
|
||||
const auto& [seed, offset] = std::get<0>(drop_seed_offset);
|
||||
kargs.init_dropout(p_drop, seed, offset, scale);
|
||||
}
|
||||
else // seed & offset come from device
|
||||
{
|
||||
const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
|
||||
kargs.init_dropout(p_drop,
|
||||
reinterpret_cast<const uint64_t*>(seed_ptr),
|
||||
reinterpret_cast<const uint64_t*>(offset_ptr),
|
||||
scale);
|
||||
}
|
||||
|
||||
if constexpr(kIsStoreRandval)
|
||||
{
|
||||
kargs.rand_val_ptr = rand_val_ptr;
|
||||
@@ -958,8 +1019,10 @@ struct FmhaBwdDQDKDVKernel
|
||||
return FmhaDropout{i_batch_,
|
||||
i_nhead_,
|
||||
kargs.num_head_q,
|
||||
kargs.drop_seed,
|
||||
kargs.drop_offset,
|
||||
kargs.is_drop_seed_offset_from_host ? kargs.drop_seed.val
|
||||
: *kargs.drop_seed.ptr,
|
||||
kargs.is_drop_seed_offset_from_host ? kargs.drop_offset.val
|
||||
: *kargs.drop_offset.ptr,
|
||||
kargs.rp_undrop,
|
||||
kargs.p_undrop_in_uint8_t};
|
||||
}
|
||||
|
||||
@@ -6,8 +6,11 @@
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
|
||||
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
#include <variant>
|
||||
|
||||
// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
|
||||
// S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
|
||||
@@ -170,29 +173,55 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t batch_stride_lse = 0;
|
||||
};
|
||||
|
||||
struct FmhaFwdCommonDropoutKargs
|
||||
struct FmhaFwdDropoutSeedOffset
|
||||
{
|
||||
void init_dropout(const float p_drop,
|
||||
const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
|
||||
template <typename T>
|
||||
union ValueOrPointer
|
||||
{
|
||||
T val;
|
||||
const T* ptr;
|
||||
};
|
||||
|
||||
ValueOrPointer<uint64_t> drop_seed;
|
||||
ValueOrPointer<uint64_t> drop_offset;
|
||||
bool is_drop_seed_offset_from_host;
|
||||
};
|
||||
|
||||
struct FmhaFwdCommonDropoutKargs : FmhaFwdDropoutSeedOffset
|
||||
{
|
||||
void init_dropout(float p_drop, uint64_t seed, uint64_t offset)
|
||||
{
|
||||
float p_undrop = 1.0 - p_drop;
|
||||
p_undrop_in_uint8_t =
|
||||
uint8_t(std::floor(p_undrop * std::numeric_limits<uint8_t>::max()));
|
||||
rp_undrop = 1.0 / p_undrop;
|
||||
|
||||
drop_seed = std::get<0>(drop_seed_offset);
|
||||
drop_offset = std::get<1>(drop_seed_offset);
|
||||
this->drop_seed.val = seed;
|
||||
this->drop_offset.val = offset;
|
||||
this->is_drop_seed_offset_from_host = true;
|
||||
}
|
||||
|
||||
void init_dropout(float p_drop, const uint64_t* seed_ptr, const uint64_t* offset_ptr)
|
||||
{
|
||||
float p_undrop = 1.0 - p_drop;
|
||||
p_undrop_in_uint8_t =
|
||||
uint8_t(std::floor(p_undrop * std::numeric_limits<uint8_t>::max()));
|
||||
rp_undrop = 1.0 / p_undrop;
|
||||
|
||||
this->drop_seed.ptr = seed_ptr;
|
||||
this->drop_offset.ptr = offset_ptr;
|
||||
this->is_drop_seed_offset_from_host = false;
|
||||
}
|
||||
|
||||
float rp_undrop = 1;
|
||||
uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
|
||||
bool is_store_randval = false;
|
||||
uint64_t drop_seed = 1;
|
||||
uint64_t drop_offset = 0;
|
||||
void* rand_val_ptr = nullptr;
|
||||
|
||||
ck_tile::index_t stride_randval = 0;
|
||||
ck_tile::index_t nhead_stride_randval = 0;
|
||||
};
|
||||
|
||||
struct FmhaFwdBatchModeDropoutKargs : FmhaFwdCommonDropoutKargs
|
||||
{
|
||||
ck_tile::index_t batch_stride_randval = 0;
|
||||
@@ -278,7 +307,8 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t mask_type,
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
|
||||
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
|
||||
drop_seed_offset)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
k_ptr,
|
||||
@@ -344,7 +374,19 @@ struct FmhaFwdKernel
|
||||
}
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
kargs.init_dropout(p_drop, drop_seed_offset);
|
||||
if(drop_seed_offset.index() == 0) // seed & offset come from host
|
||||
{
|
||||
const auto& [seed, offset] = std::get<0>(drop_seed_offset);
|
||||
kargs.init_dropout(p_drop, seed, offset);
|
||||
}
|
||||
else // seed & offset come from device
|
||||
{
|
||||
const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
|
||||
kargs.init_dropout(p_drop,
|
||||
reinterpret_cast<const uint64_t*>(seed_ptr),
|
||||
reinterpret_cast<const uint64_t*>(offset_ptr));
|
||||
}
|
||||
|
||||
kargs.rand_val_ptr = rand_val_ptr;
|
||||
kargs.stride_randval = stride_randval;
|
||||
kargs.nhead_stride_randval = nhead_stride_randval;
|
||||
@@ -392,7 +434,8 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t mask_type,
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
|
||||
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
|
||||
drop_seed_offset)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
k_ptr,
|
||||
@@ -455,7 +498,19 @@ struct FmhaFwdKernel
|
||||
}
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
kargs.init_dropout(p_drop, drop_seed_offset);
|
||||
if(drop_seed_offset.index() == 0) // seed & offset come from host
|
||||
{
|
||||
const auto& [seed, offset] = std::get<0>(drop_seed_offset);
|
||||
kargs.init_dropout(p_drop, seed, offset);
|
||||
}
|
||||
else // seed & offset come from device
|
||||
{
|
||||
const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
|
||||
kargs.init_dropout(p_drop,
|
||||
reinterpret_cast<const uint64_t*>(seed_ptr),
|
||||
reinterpret_cast<const uint64_t*>(offset_ptr));
|
||||
}
|
||||
|
||||
kargs.rand_val_ptr = rand_val_ptr;
|
||||
kargs.stride_randval = stride_randval;
|
||||
kargs.nhead_stride_randval = nhead_stride_randval;
|
||||
@@ -748,8 +803,10 @@ struct FmhaFwdKernel
|
||||
return BlockDropout{i_batch_,
|
||||
i_nhead_,
|
||||
kargs.num_head_q,
|
||||
kargs.drop_seed,
|
||||
kargs.drop_offset,
|
||||
kargs.is_drop_seed_offset_from_host ? kargs.drop_seed.val
|
||||
: *kargs.drop_seed.ptr,
|
||||
kargs.is_drop_seed_offset_from_host ? kargs.drop_offset.val
|
||||
: *kargs.drop_offset.ptr,
|
||||
kargs.rp_undrop,
|
||||
kargs.p_undrop_in_uint8_t,
|
||||
kargs.is_store_randval};
|
||||
|
||||
Reference in New Issue
Block a user