Add max_seqlen as divider in SiLU
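Summary: this commit normalizes the HSTU attention activation by sequence length. Wherever the code applies the SiLU activation (host reference form: silu(x) = x / (1 + exp(-x))) to the attention scores, the result is now additionally divided by max_seqlen. Most of the diff is plumbing: max_seqlen is added to the kernel argument structs, threaded from the example runner and the dispatch layer into the pipeline, and finally used inside the device and host-reference SiLU lambdas. The hunks below follow that order.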
@@ -449,6 +449,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
         o_host_ref,
         num_batch,
         1.0f / std::sqrt(params.hdim_qk),
+        is_jagged ? max_seqlen : seqlen,
         seq_offsets,
         num_targets,
         window_size,
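In jagged (variable-length) mode the reference needs the true maximum sequence length as the divider, while in fixed-length mode every sequence has length seqlen, so seqlen doubles as max_seqlen. Below is a minimal sketch of how max_seqlen could be derived on the host in jagged mode; the helper name and the exclusive-prefix-sum layout of seq_offsets are assumptions for illustration, not taken from the repository.

#include <algorithm>
#include <vector>

// Hypothetical helper: seq_offsets is assumed to hold exclusive prefix
// sums of per-batch sequence lengths, so batch b spans
// [seq_offsets[b], seq_offsets[b + 1]).
int compute_max_seqlen(const std::vector<int>& seq_offsets)
{
    int max_seqlen = 0;
    for(std::size_t b = 0; b + 1 < seq_offsets.size(); ++b)
        max_seqlen = std::max(max_seqlen, seq_offsets[b + 1] - seq_offsets[b]);
    return max_seqlen;
}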
@@ -141,6 +141,8 @@ struct HstuAttentionFwdKernel
     ck_tile::index_t batch_stride_k;
     ck_tile::index_t batch_stride_v;
     ck_tile::index_t batch_stride_o;
+
+    ck_tile::index_t max_seqlen;
 };

 struct HstuAttentionFwdJaggModeKargs : HstuAttentionFwdCommonKargs,
@@ -155,6 +157,7 @@ struct HstuAttentionFwdKernel
                                        HstuAttentionFwdEmptyKargs<2>>
     {
         const int32_t* seq_offsets_ptr;
+        ck_tile::index_t max_seqlen;
     };

     using Kargs = std::
@@ -219,7 +222,8 @@ struct HstuAttentionFwdKernel
                          batch_stride_q,
                          batch_stride_k,
                          batch_stride_v,
-                         batch_stride_o};
+                         batch_stride_o,
+                         seqlen}; // max_seqlen

         if constexpr(kHasBias)
         {
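Note the batch-mode kargs constructor above passes seqlen in the max_seqlen slot (hence the trailing comment): with fixed-length batches the maximum sequence length is simply the common seqlen, mirroring the is_jagged ternary in the runner.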
@@ -319,6 +323,7 @@ struct HstuAttentionFwdKernel
                           const void* bias_ptr,
                           void* o_ptr,
                           const void* seq_offsets_ptr,
+                          ck_tile::index_t max_seqlen,
                           ck_tile::index_t hdim_qk,
                           ck_tile::index_t hdim_v,
                           ck_tile::index_t num_head,
@@ -362,7 +367,8 @@ struct HstuAttentionFwdKernel
                  {}, // placeholder for bias
                  {}, // placeholder for mask
                  {}, // placeholder for dropout
-                 reinterpret_cast<const int32_t*>(seq_offsets_ptr)};
+                 reinterpret_cast<const int32_t*>(seq_offsets_ptr),
+                 max_seqlen};

         if constexpr(kHasBias)
         {
@@ -393,6 +399,7 @@ struct HstuAttentionFwdKernel
                           const void* bias_ptr,
                           void* o_ptr,
                           const void* seq_offsets_ptr,
+                          ck_tile::index_t max_seqlen,
                           ck_tile::index_t hdim_qk,
                           ck_tile::index_t hdim_v,
                           ck_tile::index_t num_head,
@@ -421,6 +428,7 @@ struct HstuAttentionFwdKernel
                       bias_ptr,
                       o_ptr,
                       seq_offsets_ptr,
+                      max_seqlen,
                       hdim_qk,
                       hdim_v,
                       num_head,
@@ -732,6 +740,7 @@ struct HstuAttentionFwdKernel
                            bias_dram_window,
                            mask,
                            kargs.scale_s,
+                           kargs.max_seqlen,
                            smem_ptr,
                            dropout);
        }();
@@ -133,6 +133,7 @@ struct HstuAttentionFwdPipelineQRKSVS
               const OAccElementFunction& o_acc_element_func,
               HstuMask mask,
               float scale_s,
+              index_t max_seqlen, // used by silu
               void* smem_ptr,
               DropoutType& dropout) const
    {
@@ -232,16 +233,17 @@ struct HstuAttentionFwdPipelineQRKSVS
        statically_indexed_array<PcompBlockTileType, k1_loops> pcomp_tiles;

        // reduction function for softmax
-       const auto f_silu = [](CompDataType& x) {
+       const auto f_silu = [&](CompDataType& x) {
            const auto neg_one = ck_tile::type_convert<CompDataType>(-1.0f);

            if constexpr(std::is_same_v<CompDataType, float>)
            {
-               x = x * __builtin_amdgcn_rcpf(neg_one - __expf(x));
+               x = x * __builtin_amdgcn_rcpf(neg_one - __expf(x)) *
+                   __builtin_amdgcn_rcpf(static_cast<CompDataType>(max_seqlen));
            }
            else
            {
-               x = x / (neg_one - exp(x));
+               x = x / (neg_one - exp(x)) / static_cast<CompDataType>(max_seqlen);
            }
        };

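The float fast path avoids a hardware divide twice over: __builtin_amdgcn_rcpf is the AMDGPU fast-reciprocal builtin, so the new 1/max_seqlen factor is folded in as a second reciprocal-multiply rather than a float division. Below is a standalone sketch of the same arithmetic, assuming a HIP translation unit; the function name is illustrative, and the (neg_one - __expf(x)) form follows the pipeline's sign convention for its pre-scaled input rather than the textbook silu(x) = x / (1 + exp(-x)).

#include <hip/hip_runtime.h>

// Illustrative device function, not the repository's implementation.
__device__ float silu_over_max_seqlen(float x, int max_seqlen)
{
    // x * rcp(-1 - e^x), as in the pipeline's float branch above ...
    float y = x * __builtin_amdgcn_rcpf(-1.0f - __expf(x));
    // ... then the commit's extra 1 / max_seqlen factor, again as an
    // approximate reciprocal-multiply instead of a division.
    return y * __builtin_amdgcn_rcpf(static_cast<float>(max_seqlen));
}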
@@ -477,6 +479,7 @@ struct HstuAttentionFwdPipelineQRKSVS
               const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
               HstuMask mask,
               float scale_s,
+              index_t max_seqlen,
               void* smem_ptr,
               DropoutType& dropout) const
    {
@@ -493,6 +496,7 @@ struct HstuAttentionFwdPipelineQRKSVS
                          identity{},
                          mask,
                          scale_s,
+                         max_seqlen,
                          smem_ptr,
                          dropout);
    }
@@ -92,6 +92,7 @@ struct jagged_forward_causal_local_bias_dropout_dispatch
            param.bias_ptr,
            param.o_ptr,
            param.seq_offsets_ptr,
+           param.max_seqlen,
            param.hdim_qk,
            param.hdim_v,
            param.num_head,
@@ -42,6 +42,7 @@ struct reference_hstu_attention
                HostTensor<InOutDataType>& o_batch_seq_nhead_hdim,
                int num_batch,
                float alpha,
+               int max_seqlen,
                std::vector<int> seq_offsets,
                std::vector<int> num_targets, // define masking length at the end of token
                                              // sequence to be excluded for attention
@@ -89,10 +90,10 @@ struct reference_hstu_attention
        // check num_tagets
        assert(num_tagets.empty() || num_targets.size() == num_batch);

-       auto silu = [](CompDataType x) {
+       auto silu = [&](CompDataType x) {
            const auto one = ck_tile::type_convert<CompDataType>(1.0f);

-           return x / (one + std::exp(-x));
+           return x / (one + std::exp(-x)) / ck_tile::type_convert<CompDataType>(max_seqlen);
        };

        auto f = [&](auto i_batch, auto i_head) {
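The host reference applies the same normalization in plain C++, keeping the verification path in sync with the device pipeline. A minimal standalone sketch of the updated reference activation, with ck_tile::type_convert replaced by a static_cast for illustration:

#include <cmath>

// Standalone sketch of the reference lambda above (float instantiation):
// silu(x) = x / (1 + exp(-x)), now additionally divided by max_seqlen.
float reference_silu(float x, int max_seqlen)
{
    const float one = 1.0f;
    return x / (one + std::exp(-x)) / static_cast<float>(max_seqlen);
}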