mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 20:51:23 +00:00
Adding seed and offset pointer support to the philox random number generator. (#1523)
* Adding seed and offset pointer support to the philox random number generator. * Separating seed and offset pointer checks with different condition statements. * Changes include, adding support for device seed and offset pointers, union is used to store seed/offset values and device pointers to minimize device SGPRs. * Correcting a typo in the readme file * Re-format files using remod.py * Use STL type for API parameters * Use simpler struct design for drop_seed & drop_offset * Undo unnecessary changes * Sync kargs style for fmha_fwd.hpp/.cpp * Use templated union to reduce code * Use structured binding to make code more readable --------- Co-authored-by: Sudhir Kylasa <sukylasa@amd.com> Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
This commit is contained in:
@@ -6,8 +6,11 @@
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
|
||||
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
#include <variant>
|
||||
|
||||
// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
|
||||
// S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
|
||||
@@ -194,11 +197,23 @@ struct FmhaBwdDQDKDVKernel
|
||||
ck_tile::GenericAttentionMaskEnum mask_type;
|
||||
};
|
||||
|
||||
struct FmhaBwdCommonDropoutKargs
|
||||
struct FmhaBwdDropoutSeedOffset
|
||||
{
|
||||
void init_dropout(const float p_drop,
|
||||
const std::tuple<uint64_t, uint64_t>& drop_seed_offset,
|
||||
const float raw_scale)
|
||||
template <typename T>
|
||||
union ValueOrPointer
|
||||
{
|
||||
T val;
|
||||
const T* ptr;
|
||||
};
|
||||
|
||||
ValueOrPointer<uint64_t> drop_seed;
|
||||
ValueOrPointer<uint64_t> drop_offset;
|
||||
bool is_drop_seed_offset_from_host;
|
||||
};
|
||||
|
||||
struct FmhaBwdCommonDropoutKargs : FmhaBwdDropoutSeedOffset
|
||||
{
|
||||
void init_dropout(float p_drop, uint64_t seed, uint64_t offset, float raw_scale)
|
||||
{
|
||||
float p_undrop = 1.0 - p_drop;
|
||||
p_undrop_in_uint8_t =
|
||||
@@ -206,23 +221,41 @@ struct FmhaBwdDQDKDVKernel
|
||||
rp_undrop = 1.0 / p_undrop;
|
||||
scale_rp_undrop = rp_undrop * raw_scale;
|
||||
|
||||
drop_seed = std::get<0>(drop_seed_offset);
|
||||
drop_offset = std::get<1>(drop_seed_offset);
|
||||
this->drop_seed.val = seed;
|
||||
this->drop_offset.val = offset;
|
||||
this->is_drop_seed_offset_from_host = true;
|
||||
}
|
||||
|
||||
void init_dropout(float p_drop,
|
||||
const uint64_t* seed_ptr,
|
||||
const uint64_t* offset_ptr,
|
||||
float raw_scale)
|
||||
{
|
||||
float p_undrop = 1.0 - p_drop;
|
||||
p_undrop_in_uint8_t =
|
||||
uint8_t(std::floor(p_undrop * std::numeric_limits<uint8_t>::max()));
|
||||
rp_undrop = 1.0 / p_undrop;
|
||||
scale_rp_undrop = rp_undrop * raw_scale;
|
||||
|
||||
this->drop_seed.ptr = seed_ptr;
|
||||
this->drop_offset.ptr = offset_ptr;
|
||||
this->is_drop_seed_offset_from_host = false;
|
||||
}
|
||||
|
||||
float rp_undrop = 1;
|
||||
float scale_rp_undrop = 1;
|
||||
uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
|
||||
uint64_t drop_seed = 1;
|
||||
uint64_t drop_offset = 0;
|
||||
void* rand_val_ptr = nullptr;
|
||||
|
||||
ck_tile::index_t stride_randval = 0;
|
||||
ck_tile::index_t nhead_stride_randval = 0;
|
||||
};
|
||||
|
||||
struct FmhaBwdBatchModeDropoutKargs : FmhaBwdCommonDropoutKargs
|
||||
{
|
||||
ck_tile::index_t batch_stride_randval = 0;
|
||||
};
|
||||
|
||||
struct FmhaBwdDeterministicKargs
|
||||
{
|
||||
ck_tile::index_t split_stride_dq_acc = 0;
|
||||
@@ -327,7 +360,8 @@ struct FmhaBwdDQDKDVKernel
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t mask_type,
|
||||
float p_drop,
|
||||
const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
|
||||
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
|
||||
drop_seed_offset)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
k_ptr,
|
||||
@@ -405,7 +439,20 @@ struct FmhaBwdDQDKDVKernel
|
||||
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
kargs.init_dropout(p_drop, drop_seed_offset, scale);
|
||||
if(drop_seed_offset.index() == 0) // seed & offset come from host
|
||||
{
|
||||
const auto& [seed, offset] = std::get<0>(drop_seed_offset);
|
||||
kargs.init_dropout(p_drop, seed, offset, scale);
|
||||
}
|
||||
else // seed & offset come from device
|
||||
{
|
||||
const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
|
||||
kargs.init_dropout(p_drop,
|
||||
reinterpret_cast<const uint64_t*>(seed_ptr),
|
||||
reinterpret_cast<const uint64_t*>(offset_ptr),
|
||||
scale);
|
||||
}
|
||||
|
||||
if constexpr(kIsStoreRandval)
|
||||
{
|
||||
kargs.rand_val_ptr = rand_val_ptr;
|
||||
@@ -471,7 +518,8 @@ struct FmhaBwdDQDKDVKernel
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t mask_type,
|
||||
float p_drop,
|
||||
const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
|
||||
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
|
||||
drop_seed_offset)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
k_ptr,
|
||||
@@ -539,7 +587,20 @@ struct FmhaBwdDQDKDVKernel
|
||||
}
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
kargs.init_dropout(p_drop, drop_seed_offset, scale);
|
||||
if(drop_seed_offset.index() == 0) // seed & offset come from host
|
||||
{
|
||||
const auto& [seed, offset] = std::get<0>(drop_seed_offset);
|
||||
kargs.init_dropout(p_drop, seed, offset, scale);
|
||||
}
|
||||
else // seed & offset come from device
|
||||
{
|
||||
const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
|
||||
kargs.init_dropout(p_drop,
|
||||
reinterpret_cast<const uint64_t*>(seed_ptr),
|
||||
reinterpret_cast<const uint64_t*>(offset_ptr),
|
||||
scale);
|
||||
}
|
||||
|
||||
if constexpr(kIsStoreRandval)
|
||||
{
|
||||
kargs.rand_val_ptr = rand_val_ptr;
|
||||
@@ -958,8 +1019,10 @@ struct FmhaBwdDQDKDVKernel
|
||||
return FmhaDropout{i_batch_,
|
||||
i_nhead_,
|
||||
kargs.num_head_q,
|
||||
kargs.drop_seed,
|
||||
kargs.drop_offset,
|
||||
kargs.is_drop_seed_offset_from_host ? kargs.drop_seed.val
|
||||
: *kargs.drop_seed.ptr,
|
||||
kargs.is_drop_seed_offset_from_host ? kargs.drop_offset.val
|
||||
: *kargs.drop_offset.ptr,
|
||||
kargs.rp_undrop,
|
||||
kargs.p_undrop_in_uint8_t};
|
||||
}
|
||||
|
||||
@@ -6,8 +6,11 @@
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
|
||||
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
#include <variant>
|
||||
|
||||
// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
|
||||
// S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
|
||||
@@ -170,29 +173,55 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t batch_stride_lse = 0;
|
||||
};
|
||||
|
||||
struct FmhaFwdCommonDropoutKargs
|
||||
struct FmhaFwdDropoutSeedOffset
|
||||
{
|
||||
void init_dropout(const float p_drop,
|
||||
const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
|
||||
template <typename T>
|
||||
union ValueOrPointer
|
||||
{
|
||||
T val;
|
||||
const T* ptr;
|
||||
};
|
||||
|
||||
ValueOrPointer<uint64_t> drop_seed;
|
||||
ValueOrPointer<uint64_t> drop_offset;
|
||||
bool is_drop_seed_offset_from_host;
|
||||
};
|
||||
|
||||
struct FmhaFwdCommonDropoutKargs : FmhaFwdDropoutSeedOffset
|
||||
{
|
||||
void init_dropout(float p_drop, uint64_t seed, uint64_t offset)
|
||||
{
|
||||
float p_undrop = 1.0 - p_drop;
|
||||
p_undrop_in_uint8_t =
|
||||
uint8_t(std::floor(p_undrop * std::numeric_limits<uint8_t>::max()));
|
||||
rp_undrop = 1.0 / p_undrop;
|
||||
|
||||
drop_seed = std::get<0>(drop_seed_offset);
|
||||
drop_offset = std::get<1>(drop_seed_offset);
|
||||
this->drop_seed.val = seed;
|
||||
this->drop_offset.val = offset;
|
||||
this->is_drop_seed_offset_from_host = true;
|
||||
}
|
||||
|
||||
void init_dropout(float p_drop, const uint64_t* seed_ptr, const uint64_t* offset_ptr)
|
||||
{
|
||||
float p_undrop = 1.0 - p_drop;
|
||||
p_undrop_in_uint8_t =
|
||||
uint8_t(std::floor(p_undrop * std::numeric_limits<uint8_t>::max()));
|
||||
rp_undrop = 1.0 / p_undrop;
|
||||
|
||||
this->drop_seed.ptr = seed_ptr;
|
||||
this->drop_offset.ptr = offset_ptr;
|
||||
this->is_drop_seed_offset_from_host = false;
|
||||
}
|
||||
|
||||
float rp_undrop = 1;
|
||||
uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
|
||||
bool is_store_randval = false;
|
||||
uint64_t drop_seed = 1;
|
||||
uint64_t drop_offset = 0;
|
||||
void* rand_val_ptr = nullptr;
|
||||
|
||||
ck_tile::index_t stride_randval = 0;
|
||||
ck_tile::index_t nhead_stride_randval = 0;
|
||||
};
|
||||
|
||||
struct FmhaFwdBatchModeDropoutKargs : FmhaFwdCommonDropoutKargs
|
||||
{
|
||||
ck_tile::index_t batch_stride_randval = 0;
|
||||
@@ -278,7 +307,8 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t mask_type,
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
|
||||
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
|
||||
drop_seed_offset)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
k_ptr,
|
||||
@@ -344,7 +374,19 @@ struct FmhaFwdKernel
|
||||
}
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
kargs.init_dropout(p_drop, drop_seed_offset);
|
||||
if(drop_seed_offset.index() == 0) // seed & offset come from host
|
||||
{
|
||||
const auto& [seed, offset] = std::get<0>(drop_seed_offset);
|
||||
kargs.init_dropout(p_drop, seed, offset);
|
||||
}
|
||||
else // seed & offset come from device
|
||||
{
|
||||
const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
|
||||
kargs.init_dropout(p_drop,
|
||||
reinterpret_cast<const uint64_t*>(seed_ptr),
|
||||
reinterpret_cast<const uint64_t*>(offset_ptr));
|
||||
}
|
||||
|
||||
kargs.rand_val_ptr = rand_val_ptr;
|
||||
kargs.stride_randval = stride_randval;
|
||||
kargs.nhead_stride_randval = nhead_stride_randval;
|
||||
@@ -392,7 +434,8 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t mask_type,
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
|
||||
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
|
||||
drop_seed_offset)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
k_ptr,
|
||||
@@ -455,7 +498,19 @@ struct FmhaFwdKernel
|
||||
}
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
kargs.init_dropout(p_drop, drop_seed_offset);
|
||||
if(drop_seed_offset.index() == 0) // seed & offset come from host
|
||||
{
|
||||
const auto& [seed, offset] = std::get<0>(drop_seed_offset);
|
||||
kargs.init_dropout(p_drop, seed, offset);
|
||||
}
|
||||
else // seed & offset come from device
|
||||
{
|
||||
const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
|
||||
kargs.init_dropout(p_drop,
|
||||
reinterpret_cast<const uint64_t*>(seed_ptr),
|
||||
reinterpret_cast<const uint64_t*>(offset_ptr));
|
||||
}
|
||||
|
||||
kargs.rand_val_ptr = rand_val_ptr;
|
||||
kargs.stride_randval = stride_randval;
|
||||
kargs.nhead_stride_randval = nhead_stride_randval;
|
||||
@@ -748,8 +803,10 @@ struct FmhaFwdKernel
|
||||
return BlockDropout{i_batch_,
|
||||
i_nhead_,
|
||||
kargs.num_head_q,
|
||||
kargs.drop_seed,
|
||||
kargs.drop_offset,
|
||||
kargs.is_drop_seed_offset_from_host ? kargs.drop_seed.val
|
||||
: *kargs.drop_seed.ptr,
|
||||
kargs.is_drop_seed_offset_from_host ? kargs.drop_offset.val
|
||||
: *kargs.drop_offset.ptr,
|
||||
kargs.rp_undrop,
|
||||
kargs.p_undrop_in_uint8_t,
|
||||
kargs.is_store_randval};
|
||||
|
||||
Reference in New Issue
Block a user