From 6f048f54dc08d5c10abe9dd05e2b2e546544d060 Mon Sep 17 00:00:00 2001 From: kylasa Date: Fri, 4 Oct 2024 11:48:47 -0700 Subject: [PATCH] Adding seed and offset pointer support to the philox random number generator. (#1523) * Adding seed and offset pointer support to the philox random number generator. * Separating seed and offset pointer checks with different condition statements. * Changes include, adding support for device seed and offset pointers, union is used to store seed/offset values and device pointers to minimize device SGPRs. * Correcting a typo in the readme file * Re-format files using remod.py * Use STL type for API parameters * Use simpler struct design for drop_seed & drop_offset * Undo unnecessary changes * Sync kargs style for fmha_fwd.hpp/.cpp * Use templated union to reduce code * Use structured binding to make code more readable --------- Co-authored-by: Sudhir Kylasa Co-authored-by: Po Yen Chen [ROCm/composable_kernel commit: c24fae234600aa2863e945d072e6f5b3aec2a6b2] --- example/ck_tile/01_fmha/README.md | 7 +- example/ck_tile/01_fmha/fmha_bwd.cpp | 23 ++++- example/ck_tile/01_fmha/fmha_bwd.hpp | 6 +- example/ck_tile/01_fmha/fmha_fwd.cpp | 23 ++++- example/ck_tile/01_fmha/fmha_fwd.hpp | 6 +- .../ops/fmha/kernel/fmha_bwd_kernel.hpp | 91 ++++++++++++++++--- .../ops/fmha/kernel/fmha_fwd_kernel.hpp | 83 ++++++++++++++--- 7 files changed, 205 insertions(+), 34 deletions(-) diff --git a/example/ck_tile/01_fmha/README.md b/example/ck_tile/01_fmha/README.md index 0bb5408772..0803d54d66 100644 --- a/example/ck_tile/01_fmha/README.md +++ b/example/ck_tile/01_fmha/README.md @@ -70,8 +70,13 @@ args: -seed random seed used for initializing input tensors. 
0 for non-deterministic seed (default:11939) -warmup number of iterations before benchmark the kernel (default:5) -repeat number of iterations to benchmark the kernel (default:20) + -drop_seed seed for the random number generator for the dropout layer, default is 1 +-drop_offset offset for the dropout layer which is used during random number generation, default is 0 + -drop_prefs flag to indicate whether `drop_seed` and `drop_offset` values are present on the GPU; default is 0 (0 - host, 1 - GPU) ``` -Example: `./bin/tile_example_fmha_fwd -b=1 -h=16 -s=16384 -d=128` will run a fmha case with batch=1, nhead=16, sequence length=16384, hdim=128, fp16 case. +Example 1: `./bin/tile_example_fmha_fwd -b=1 -h=16 -s=16384 -d=128` will run a fmha case with batch=1, nhead=16, sequence length=16384, hdim=128, fp16 case. +Example 2: `./bin/tile_example_fmha_fwd -b=1 -h=8 -s=16384 -d=64 -drop_prefs=1 -drop_seed=10 -drop_offset=1234` will run a fmha case with + batch=1, nhead=8, sequence length=16384, hdim=64, drop_seed=10 (in GPU memory), drop_offset=1234 (in GPU memory), fp16 case ## support features Currently we are still in rapid development stage, so more features/optimizations will be coming soon. 
diff --git a/example/ck_tile/01_fmha/fmha_bwd.cpp b/example/ck_tile/01_fmha/fmha_bwd.cpp index c2f554f6cc..2d76627a72 100644 --- a/example/ck_tile/01_fmha/fmha_bwd.cpp +++ b/example/ck_tile/01_fmha/fmha_bwd.cpp @@ -85,6 +85,9 @@ auto create_args(int argc, char* argv[]) .insert("p_drop", "0", "0~1 probability of dropout") .insert("drop_seed", "1", "seed for random number generator") .insert("drop_offset", "0", "offset for random number generator") + .insert("drop_prefs", + "0", + "seed and offset values are present on GPU; 0 - host, 1 - device/GPU") .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") .insert("warmup", "5", "number of iterations before benchmark the kernel") .insert("repeat", "20", "number of iterations to benchmark the kernel") @@ -158,6 +161,8 @@ bool run(const ck_tile::ArgParser& arg_parser) float p_drop = arg_parser.get_float("p_drop"); uint64_t drop_seed = arg_parser.get_uint64("drop_seed"); uint64_t drop_offset = arg_parser.get_uint64("drop_offset"); + bool drop_prefs = arg_parser.get_bool("drop_prefs"); + if(use_dbias && bias.type != bias_enum::elementwise_bias) { std::cerr << "dbias only exists when bias type is elementwise" << std::endl; @@ -381,6 +386,8 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::DeviceMem dbias_buf(dbias_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t)); ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t)); + ck_tile::DeviceMem drop_seed_buf(drop_prefs ? sizeof(uint64_t) : 0); + ck_tile::DeviceMem drop_offset_buf(drop_prefs ? 
sizeof(uint64_t) : 0); ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem dq_acc_buf(dq_acc_host.get_element_space_size_in_bytes()); @@ -391,6 +398,8 @@ bool run(const ck_tile::ArgParser& arg_parser) do_buf.ToDevice(do_host.data()); seqstart_q.ToDevice(seqstart_q_host.data()); seqstart_k.ToDevice(seqstart_k_host.data()); + drop_seed_buf.ToDevice(drop_prefs ? &drop_seed : nullptr); + drop_offset_buf.ToDevice(drop_prefs ? &drop_offset : nullptr); alibi_slope_buf.ToDevice(alibi_slope_host.data()); // clang-format off @@ -472,6 +481,18 @@ bool run(const ck_tile::ArgParser& arg_parser) const ck_tile::index_t split_stride_dq_acc = (shape_batch * nhead * shape_seqlen_q * hdim_q); + const auto drop_seed_offset = [&]() -> decltype(fmha_bwd_args::drop_seed_offset) { + if(drop_prefs) + { + return std::make_pair(drop_seed_buf.GetDeviceBuffer(), + drop_offset_buf.GetDeviceBuffer()); + } + else + { + return std::make_pair(drop_seed, drop_offset); + } + }(); + return fmha_bwd_args{q_buf.GetDeviceBuffer(), k_buf.GetDeviceBuffer(), v_buf.GetDeviceBuffer(), @@ -545,7 +566,7 @@ bool run(const ck_tile::ArgParser& arg_parser) static_cast(mask.type), p_drop, p_undrop, - {drop_seed, drop_offset}}; + drop_seed_offset}; }(); float ave_time = fmha_bwd(fmha_traits, fmha_args, stream_config); diff --git a/example/ck_tile/01_fmha/fmha_bwd.hpp b/example/ck_tile/01_fmha/fmha_bwd.hpp index aea42515dc..3b21a3257f 100644 --- a/example/ck_tile/01_fmha/fmha_bwd.hpp +++ b/example/ck_tile/01_fmha/fmha_bwd.hpp @@ -9,7 +9,10 @@ #include "ck_tile/ops/epilogue.hpp" #include "mask.hpp" #include "bias.hpp" + #include +#include +#include template struct FmhaBwdTypeConfig; @@ -135,7 +138,8 @@ struct fmha_bwd_args ck_tile::index_t mask_type; float p_drop; float p_undrop; - std::tuple drop_seed_offset; + std::variant, std::pair> + drop_seed_offset; }; template diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index 
b9cb9a1ec2..6d519a7ea8 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -122,6 +122,9 @@ auto create_args(int argc, char* argv[]) .insert("p_drop", "0", "0~1 probability of dropout") .insert("drop_seed", "1", "seed for random number generator") .insert("drop_offset", "0", "offset for random number generator") + .insert("drop_prefs", + "0", + "seed and offset values are present on GPU; 0 - host, 1 - device/GPU") .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") .insert( "rotary_dim", "0", "RoPE rotary dimension. rotary_dim <= 0 means not apply RoPE at all") @@ -442,6 +445,8 @@ bool run(const ck_tile::ArgParser& arg_parser) float p_drop = arg_parser.get_float("p_drop"); uint64_t drop_seed = arg_parser.get_uint64("drop_seed"); uint64_t drop_offset = arg_parser.get_uint64("drop_offset"); + bool drop_prefs = arg_parser.get_bool("drop_prefs"); + if(p_drop < 0.0f || p_drop > 1.0f) { std::cerr << "The value of p_drop should be 0~1" << std::endl; @@ -756,6 +761,8 @@ bool run(const ck_tile::ArgParser& arg_parser) need_append_kvcache ? cache_seqlen_ks.size() * sizeof(int32_t) : 0); ck_tile::DeviceMem rotary_cos_buf(rotary_cos_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem rotary_sin_buf(rotary_sin_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem drop_seed_buf(drop_prefs ? sizeof(uint64_t) : 0); + ck_tile::DeviceMem drop_offset_buf(drop_prefs ? sizeof(uint64_t) : 0); ck_tile::DeviceMem randval_buf(randval_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem block_table_buf(block_table_host.get_element_space_size_in_bytes()); @@ -774,6 +781,8 @@ bool run(const ck_tile::ArgParser& arg_parser) cache_seqlen_k_buf.ToDevice(need_append_kvcache ? 
cache_seqlen_ks.data() : nullptr); rotary_cos_buf.ToDevice(rotary_cos_host.data()); rotary_sin_buf.ToDevice(rotary_sin_host.data()); + drop_seed_buf.ToDevice(drop_prefs ? &drop_seed : nullptr); + drop_offset_buf.ToDevice(drop_prefs ? &drop_offset : nullptr); alibi_slope_buf.ToDevice(alibi_slope_host.data()); block_table_buf.ToDevice(block_table_host.data()); cache_batch_idx_buf.ToDevice(cache_batch_idx_host.data()); @@ -1013,9 +1022,17 @@ bool run(const ck_tile::ArgParser& arg_parser) args.nhead_stride_randval = nhead_stride_randval; args.batch_stride_randval = batch_stride_randval; - args.p_drop = p_drop; - args.s_randval = s_randval; - args.drop_seed_offset = std::tie(drop_seed, drop_offset); + args.p_drop = p_drop; + args.s_randval = s_randval; + if(drop_prefs) + { + args.drop_seed_offset = std::make_pair(drop_seed_buf.GetDeviceBuffer(), + drop_offset_buf.GetDeviceBuffer()); + } + else + { + args.drop_seed_offset = std::make_pair(drop_seed, drop_offset); + } } else if constexpr(std::is_same_v>) { diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 5dcad7907f..251e61bc76 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -13,6 +13,8 @@ #include "rotary.hpp" #include +#include +#include template struct FmhaFwdTypeConfig; @@ -144,7 +146,9 @@ struct fmha_fwd_args float p_drop; bool s_randval; - std::tuple drop_seed_offset; + + std::variant, std::pair> + drop_seed_offset; }; struct fmha_fwd_splitkv_args diff --git a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp index 167494b193..c5858a20f7 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp @@ -6,8 +6,11 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/common.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" + #include #include +#include +#include // S[seqlen_q, seqlen_k] = 
Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q] // S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1] @@ -194,11 +197,23 @@ struct FmhaBwdDQDKDVKernel ck_tile::GenericAttentionMaskEnum mask_type; }; - struct FmhaBwdCommonDropoutKargs + struct FmhaBwdDropoutSeedOffset { - void init_dropout(const float p_drop, - const std::tuple& drop_seed_offset, - const float raw_scale) + template + union ValueOrPointer + { + T val; + const T* ptr; + }; + + ValueOrPointer drop_seed; + ValueOrPointer drop_offset; + bool is_drop_seed_offset_from_host; + }; + + struct FmhaBwdCommonDropoutKargs : FmhaBwdDropoutSeedOffset + { + void init_dropout(float p_drop, uint64_t seed, uint64_t offset, float raw_scale) { float p_undrop = 1.0 - p_drop; p_undrop_in_uint8_t = @@ -206,23 +221,41 @@ struct FmhaBwdDQDKDVKernel rp_undrop = 1.0 / p_undrop; scale_rp_undrop = rp_undrop * raw_scale; - drop_seed = std::get<0>(drop_seed_offset); - drop_offset = std::get<1>(drop_seed_offset); + this->drop_seed.val = seed; + this->drop_offset.val = offset; + this->is_drop_seed_offset_from_host = true; } + + void init_dropout(float p_drop, + const uint64_t* seed_ptr, + const uint64_t* offset_ptr, + float raw_scale) + { + float p_undrop = 1.0 - p_drop; + p_undrop_in_uint8_t = + uint8_t(std::floor(p_undrop * std::numeric_limits::max())); + rp_undrop = 1.0 / p_undrop; + scale_rp_undrop = rp_undrop * raw_scale; + + this->drop_seed.ptr = seed_ptr; + this->drop_offset.ptr = offset_ptr; + this->is_drop_seed_offset_from_host = false; + } + float rp_undrop = 1; float scale_rp_undrop = 1; uint8_t p_undrop_in_uint8_t = std::numeric_limits::max(); - uint64_t drop_seed = 1; - uint64_t drop_offset = 0; void* rand_val_ptr = nullptr; ck_tile::index_t stride_randval = 0; ck_tile::index_t nhead_stride_randval = 0; }; + struct FmhaBwdBatchModeDropoutKargs : FmhaBwdCommonDropoutKargs { ck_tile::index_t batch_stride_randval = 0; }; + struct FmhaBwdDeterministicKargs { ck_tile::index_t split_stride_dq_acc = 0; @@ -327,7 +360,8 @@ 
struct FmhaBwdDQDKDVKernel ck_tile::index_t window_size_right, ck_tile::index_t mask_type, float p_drop, - const std::tuple& drop_seed_offset) + std::variant, std::pair> + drop_seed_offset) { Kargs kargs{{q_ptr, k_ptr, @@ -405,7 +439,20 @@ struct FmhaBwdDQDKDVKernel if constexpr(kHasDropout) { - kargs.init_dropout(p_drop, drop_seed_offset, scale); + if(drop_seed_offset.index() == 0) // seed & offset come from host + { + const auto& [seed, offset] = std::get<0>(drop_seed_offset); + kargs.init_dropout(p_drop, seed, offset, scale); + } + else // seed & offset come from device + { + const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset); + kargs.init_dropout(p_drop, + reinterpret_cast(seed_ptr), + reinterpret_cast(offset_ptr), + scale); + } + if constexpr(kIsStoreRandval) { kargs.rand_val_ptr = rand_val_ptr; @@ -471,7 +518,8 @@ struct FmhaBwdDQDKDVKernel ck_tile::index_t window_size_right, ck_tile::index_t mask_type, float p_drop, - const std::tuple& drop_seed_offset) + std::variant, std::pair> + drop_seed_offset) { Kargs kargs{{q_ptr, k_ptr, @@ -539,7 +587,20 @@ struct FmhaBwdDQDKDVKernel } if constexpr(kHasDropout) { - kargs.init_dropout(p_drop, drop_seed_offset, scale); + if(drop_seed_offset.index() == 0) // seed & offset come from host + { + const auto& [seed, offset] = std::get<0>(drop_seed_offset); + kargs.init_dropout(p_drop, seed, offset, scale); + } + else // seed & offset come from device + { + const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset); + kargs.init_dropout(p_drop, + reinterpret_cast(seed_ptr), + reinterpret_cast(offset_ptr), + scale); + } + if constexpr(kIsStoreRandval) { kargs.rand_val_ptr = rand_val_ptr; @@ -958,8 +1019,10 @@ struct FmhaBwdDQDKDVKernel return FmhaDropout{i_batch_, i_nhead_, kargs.num_head_q, - kargs.drop_seed, - kargs.drop_offset, + kargs.is_drop_seed_offset_from_host ? kargs.drop_seed.val + : *kargs.drop_seed.ptr, + kargs.is_drop_seed_offset_from_host ? 
kargs.drop_offset.val + : *kargs.drop_offset.ptr, kargs.rp_undrop, kargs.p_undrop_in_uint8_t}; } diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 49ef7bf6d9..adabda165c 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -6,8 +6,11 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/common.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" + #include #include +#include +#include // S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q] // S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1] @@ -170,29 +173,55 @@ struct FmhaFwdKernel ck_tile::index_t batch_stride_lse = 0; }; - struct FmhaFwdCommonDropoutKargs + struct FmhaFwdDropoutSeedOffset { - void init_dropout(const float p_drop, - const std::tuple& drop_seed_offset) + template + union ValueOrPointer + { + T val; + const T* ptr; + }; + + ValueOrPointer drop_seed; + ValueOrPointer drop_offset; + bool is_drop_seed_offset_from_host; + }; + + struct FmhaFwdCommonDropoutKargs : FmhaFwdDropoutSeedOffset + { + void init_dropout(float p_drop, uint64_t seed, uint64_t offset) { float p_undrop = 1.0 - p_drop; p_undrop_in_uint8_t = uint8_t(std::floor(p_undrop * std::numeric_limits::max())); rp_undrop = 1.0 / p_undrop; - drop_seed = std::get<0>(drop_seed_offset); - drop_offset = std::get<1>(drop_seed_offset); + this->drop_seed.val = seed; + this->drop_offset.val = offset; + this->is_drop_seed_offset_from_host = true; } + + void init_dropout(float p_drop, const uint64_t* seed_ptr, const uint64_t* offset_ptr) + { + float p_undrop = 1.0 - p_drop; + p_undrop_in_uint8_t = + uint8_t(std::floor(p_undrop * std::numeric_limits::max())); + rp_undrop = 1.0 / p_undrop; + + this->drop_seed.ptr = seed_ptr; + this->drop_offset.ptr = offset_ptr; + this->is_drop_seed_offset_from_host = false; + } + float rp_undrop = 1; uint8_t p_undrop_in_uint8_t = 
std::numeric_limits::max(); bool is_store_randval = false; - uint64_t drop_seed = 1; - uint64_t drop_offset = 0; void* rand_val_ptr = nullptr; ck_tile::index_t stride_randval = 0; ck_tile::index_t nhead_stride_randval = 0; }; + struct FmhaFwdBatchModeDropoutKargs : FmhaFwdCommonDropoutKargs { ck_tile::index_t batch_stride_randval = 0; @@ -278,7 +307,8 @@ struct FmhaFwdKernel ck_tile::index_t mask_type, float p_drop, bool s_randval, - const std::tuple& drop_seed_offset) + std::variant, std::pair> + drop_seed_offset) { Kargs kargs{{q_ptr, k_ptr, @@ -344,7 +374,19 @@ struct FmhaFwdKernel } if constexpr(kHasDropout) { - kargs.init_dropout(p_drop, drop_seed_offset); + if(drop_seed_offset.index() == 0) // seed & offset come from host + { + const auto& [seed, offset] = std::get<0>(drop_seed_offset); + kargs.init_dropout(p_drop, seed, offset); + } + else // seed & offset come from device + { + const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset); + kargs.init_dropout(p_drop, + reinterpret_cast(seed_ptr), + reinterpret_cast(offset_ptr)); + } + kargs.rand_val_ptr = rand_val_ptr; kargs.stride_randval = stride_randval; kargs.nhead_stride_randval = nhead_stride_randval; @@ -392,7 +434,8 @@ struct FmhaFwdKernel ck_tile::index_t mask_type, float p_drop, bool s_randval, - const std::tuple& drop_seed_offset) + std::variant, std::pair> + drop_seed_offset) { Kargs kargs{{q_ptr, k_ptr, @@ -455,7 +498,19 @@ struct FmhaFwdKernel } if constexpr(kHasDropout) { - kargs.init_dropout(p_drop, drop_seed_offset); + if(drop_seed_offset.index() == 0) // seed & offset come from host + { + const auto& [seed, offset] = std::get<0>(drop_seed_offset); + kargs.init_dropout(p_drop, seed, offset); + } + else // seed & offset come from device + { + const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset); + kargs.init_dropout(p_drop, + reinterpret_cast(seed_ptr), + reinterpret_cast(offset_ptr)); + } + kargs.rand_val_ptr = rand_val_ptr; kargs.stride_randval = stride_randval; 
kargs.nhead_stride_randval = nhead_stride_randval; @@ -748,8 +803,10 @@ struct FmhaFwdKernel return BlockDropout{i_batch_, i_nhead_, kargs.num_head_q, - kargs.drop_seed, - kargs.drop_offset, + kargs.is_drop_seed_offset_from_host ? kargs.drop_seed.val + : *kargs.drop_seed.ptr, + kargs.is_drop_seed_offset_from_host ? kargs.drop_offset.val + : *kargs.drop_offset.ptr, kargs.rp_undrop, kargs.p_undrop_in_uint8_t, kargs.is_store_randval};