Add max_seqlen as divisor in SiLU

Qianfeng Zhang
2025-05-06 16:16:23 +00:00
parent 374e0626e6
commit 72d55d1b40
5 changed files with 23 additions and 7 deletions
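In short: the commit divides the SiLU attention activation by max_seqlen, on both the device pipeline and the host reference, and threads max_seqlen through the kernel arguments to get it there. A minimal host-side sketch of the normalized activation (plain C++; the function name silu_normalized and the standalone setting are illustrative, not part of this commit):

// Length-normalized SiLU: silu_norm(x) = x / (1 + exp(-x)) / max_seqlen,
// matching the host reference below; the device pipeline applies the same
// 1/max_seqlen factor via the AMD fast-reciprocal intrinsic.
#include <cmath>
#include <cstdio>

static float silu_normalized(float x, int max_seqlen)
{
    return x / (1.0f + std::exp(-x)) / static_cast<float>(max_seqlen);
}

int main()
{
    std::printf("%f\n", silu_normalized(2.0f, 1));   // plain SiLU: ~1.761594
    std::printf("%f\n", silu_normalized(2.0f, 256)); // normalized: ~0.006881
    return 0;
}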

View File

@@ -449,6 +449,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
o_host_ref,
num_batch,
1.0f / std::sqrt(params.hdim_qk),
+ is_jagged ? max_seqlen : seqlen,
seq_offsets,
num_targets,
window_size,
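Note the divisor selection above: jagged inputs pass the batch-wide max_seqlen, while fixed-length batches pass seqlen itself, since there every sequence is exactly seqlen long. The batch-mode kernel args below make the same choice when they put seqlen in the max_seqlen slot.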

View File

@@ -141,6 +141,8 @@ struct HstuAttentionFwdKernel
ck_tile::index_t batch_stride_k;
ck_tile::index_t batch_stride_v;
ck_tile::index_t batch_stride_o;
+ ck_tile::index_t max_seqlen;
};
struct HstuAttentionFwdJaggModeKargs : HstuAttentionFwdCommonKargs,
@@ -155,6 +157,7 @@ struct HstuAttentionFwdKernel
HstuAttentionFwdEmptyKargs<2>>
{
const int32_t* seq_offsets_ptr;
+ ck_tile::index_t max_seqlen;
};
using Kargs = std::
@@ -219,7 +222,8 @@ struct HstuAttentionFwdKernel
batch_stride_q,
batch_stride_k,
batch_stride_v,
- batch_stride_o};
+ batch_stride_o,
+ seqlen}; // batch mode: all sequences are seqlen long, so max_seqlen == seqlen
if constexpr(kHasBias)
{
@@ -319,6 +323,7 @@ struct HstuAttentionFwdKernel
const void* bias_ptr,
void* o_ptr,
const void* seq_offsets_ptr,
+ ck_tile::index_t max_seqlen,
ck_tile::index_t hdim_qk,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head,
@@ -362,7 +367,8 @@ struct HstuAttentionFwdKernel
{}, // placeholder for bias
{}, // placeholder for mask
{}, // placeholder for dropout
- reinterpret_cast<const int32_t*>(seq_offsets_ptr)};
+ reinterpret_cast<const int32_t*>(seq_offsets_ptr),
+ max_seqlen};
if constexpr(kHasBias)
{
@@ -393,6 +399,7 @@ struct HstuAttentionFwdKernel
const void* bias_ptr,
void* o_ptr,
const void* seq_offsets_ptr,
+ ck_tile::index_t max_seqlen,
ck_tile::index_t hdim_qk,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head,
@@ -421,6 +428,7 @@ struct HstuAttentionFwdKernel
bias_ptr,
o_ptr,
seq_offsets_ptr,
+ max_seqlen,
hdim_qk,
hdim_v,
num_head,
@@ -732,6 +740,7 @@ struct HstuAttentionFwdKernel
bias_dram_window,
mask,
kargs.scale_s,
+ kargs.max_seqlen,
smem_ptr,
dropout);
}();

View File

@@ -133,6 +133,7 @@ struct HstuAttentionFwdPipelineQRKSVS
const OAccElementFunction& o_acc_element_func,
HstuMask mask,
float scale_s,
+ index_t max_seqlen, // used as divisor in SiLU
void* smem_ptr,
DropoutType& dropout) const
{
@@ -232,16 +233,17 @@ struct HstuAttentionFwdPipelineQRKSVS
statically_indexed_array<PcompBlockTileType, k1_loops> pcomp_tiles;
// elementwise SiLU activation (applied in place of softmax)
- const auto f_silu = [](CompDataType& x) {
+ const auto f_silu = [&](CompDataType& x) { // capture max_seqlen
const auto neg_one = ck_tile::type_convert<CompDataType>(-1.0f);
if constexpr(std::is_same_v<CompDataType, float>)
{
- x = x * __builtin_amdgcn_rcpf(neg_one - __expf(x));
+ x = x * __builtin_amdgcn_rcpf(neg_one - __expf(x)) *
+ __builtin_amdgcn_rcpf(static_cast<CompDataType>(max_seqlen));
}
else
{
- x = x / (neg_one - exp(x));
+ x = x / (neg_one - exp(x)) / static_cast<CompDataType>(max_seqlen);
}
};
@@ -477,6 +479,7 @@ struct HstuAttentionFwdPipelineQRKSVS
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
HstuMask mask,
float scale_s,
+ index_t max_seqlen,
void* smem_ptr,
DropoutType& dropout) const
{
@@ -493,6 +496,7 @@ struct HstuAttentionFwdPipelineQRKSVS
identity{},
mask,
scale_s,
+ max_seqlen,
smem_ptr,
dropout);
}
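The float fast path above multiplies by __builtin_amdgcn_rcpf(max_seqlen) instead of dividing, using the AMD GCN hardware reciprocal approximation; the trade is a small loss of precision for avoiding a divide per element. A host-side sketch of the same structure, where fast_rcp is an illustrative stand-in for the intrinsic (on the host, plain 1.0f/x):

#include <cmath>
#include <cstdio>

// Stand-in for __builtin_amdgcn_rcpf: on device this is a single
// reciprocal-approximation instruction, not an exact divide.
static float fast_rcp(float x) { return 1.0f / x; }

static float silu_fast_path(float x, int max_seqlen)
{
    const float neg_one = -1.0f;
    // One divide per element becomes two reciprocal multiplies,
    // mirroring the shape of the pipeline change above.
    return x * fast_rcp(neg_one - std::exp(x)) * fast_rcp(static_cast<float>(max_seqlen));
}

int main()
{
    std::printf("%f\n", silu_fast_path(-2.0f, 256));
    return 0;
}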

View File

@@ -92,6 +92,7 @@ struct jagged_forward_causal_local_bias_dropout_dispatch
param.bias_ptr,
param.o_ptr,
param.seq_offsets_ptr,
+ param.max_seqlen,
param.hdim_qk,
param.hdim_v,
param.num_head,

View File

@@ -42,6 +42,7 @@ struct reference_hstu_attention
HostTensor<InOutDataType>& o_batch_seq_nhead_hdim,
int num_batch,
float alpha,
+ int max_seqlen,
std::vector<int> seq_offsets,
std::vector<int> num_targets, // masking length at the end of each token
// sequence, excluded from attention
@@ -89,10 +90,10 @@ struct reference_hstu_attention
// check num_targets
assert(num_targets.empty() || num_targets.size() == num_batch);
- auto silu = [](CompDataType x) {
+ auto silu = [&](CompDataType x) { // capture max_seqlen
const auto one = ck_tile::type_convert<CompDataType>(1.0f);
- return x / (one + std::exp(-x));
+ return x / (one + std::exp(-x)) / ck_tile::type_convert<CompDataType>(max_seqlen);
};
auto f = [&](auto i_batch, auto i_head) {
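(The 1/max_seqlen factor lands in the device pipeline and in this host reference in the same commit; since the runner validates kernel output against o_host_ref, changing only one side would make that comparison fail.)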