Small update in reference hstu attention

Qianfeng Zhang
2025-09-13 06:42:46 +00:00
parent 798fc3cc0b
commit a5b7360862
2 changed files with 124 additions and 123 deletions


@@ -219,8 +219,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
     int window_size = arg_parser.get_int("local_len");
     bool use_local = (window_size > 0);
     int contextual_seqlen = arg_parser.get_int("context_len");
     int min_full_attn_seqlen = arg_parser.get_int("minfull_len");
@@ -516,26 +514,25 @@ bool run(const ck_tile::ArgParser& arg_parser)
     using GemmAccDataType = typename HstuAttentionFwdTypeConfig<InOutDataType>::GemmAccDataType;
     using CompDataType = typename HstuAttentionFwdTypeConfig<InOutDataType>::CompDataType;
-    BOOL_SWITCH_3(is_jagged, kIsJagged, use_causal, kUseCausal, use_local, kUseLocal, [&] {
+    BOOL_SWITCH_2(is_jagged, kIsJagged, use_causal, kUseCausal, [&] {
         ck_tile::reference_hstu_attention<InOutDataType,
                                           GemmAccDataType,
                                           CompDataType,
                                           kIsJagged,
-                                          kUseCausal,
-                                          kUseLocal>::Run(q_host,
-                                                          k_host,
-                                                          v_host,
-                                                          o_host_ref,
-                                                          mask_host,
-                                                          num_batch,
-                                                          scale_s,
-                                                          attn_scale,
-                                                          max_seqlen,
-                                                          seq_offsets,
-                                                          num_targets,
-                                                          window_size,
-                                                          contextual_seqlen,
-                                                          min_full_attn_seqlen);
+                                          kUseCausal>::Run(q_host,
+                                                           k_host,
+                                                           v_host,
+                                                           o_host_ref,
+                                                           mask_host,
+                                                           num_batch,
+                                                           scale_s,
+                                                           attn_scale,
+                                                           max_seqlen,
+                                                           seq_offsets,
+                                                           num_targets,
+                                                           contextual_seqlen,
+                                                           window_size,
+                                                           min_full_attn_seqlen);
     });
     ck_tile::HostTensor<InOutDataType> o_host(
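Two things change in the call above: the local-mask flag drops out of the compile-time dispatch (BOOL_SWITCH_3 becomes BOOL_SWITCH_2), and contextual_seqlen now precedes window_size, matching the reordered Run signature in the header below. For orientation, here is a minimal sketch of how a BOOL_SWITCH-style macro lifts a runtime bool into a constexpr constant; this assumes semantics similar to ck_tile's dispatch macros, not their exact definition.

#include <iostream>

// Hedged sketch of a BOOL_SWITCH-style macro (ck_tile's real macros may differ):
// branch once on the runtime condition, then expose it as a constexpr bool so
// the body can be instantiated with if constexpr.
#define BOOL_SWITCH_SKETCH(COND, CONST_NAME, ...) \
    [&] {                                         \
        if(COND)                                  \
        {                                         \
            constexpr bool CONST_NAME = true;     \
            return __VA_ARGS__();                 \
        }                                         \
        else                                      \
        {                                         \
            constexpr bool CONST_NAME = false;    \
            return __VA_ARGS__();                 \
        }                                         \
    }()

int main()
{
    bool use_local = true; // runtime flag, e.g. window_size > 0
    BOOL_SWITCH_SKETCH(use_local, kUseLocal, [&] {
        if constexpr(kUseLocal)
            std::cout << "local-mask instantiation\n";
        else
            std::cout << "non-local instantiation\n";
    });
}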


@@ -29,13 +29,9 @@ template <typename InOutDataType,
           typename GemmAccDataType,
           typename CompDataType,
           bool kIsJagged,
-          bool kUseCausal,
-          bool kUseLocal>
+          bool kUseCausal>
 struct reference_hstu_attention
 {
-    using HstuMask = typename HstuBlockMasking<kUseCausal, kUseLocal>::Type;
-    static constexpr bool kHasLocalMask = HstuMask::kUseLocal;
     static void Run(const HostTensor<InOutDataType>& q_batch_seq_nhead_hdim,
                     const HostTensor<InOutDataType>& k_batch_seq_nhead_hdim,
                     const HostTensor<InOutDataType>& v_batch_seq_nhead_hdim,
@@ -48,9 +44,9 @@ struct reference_hstu_attention
                     std::vector<int> seq_offsets,
                     std::vector<int> num_targets,     // define masking length at the end of token
                                                       // sequence to be excluded for attention
-                    int max_attn_len,                 // define the diagonal local window size
                     int contextual_seqlen,            // define masking length at the begin of query token
                                                       // sequence to be included for attention
+                    int window_size,                  // define the diagonal local window size
                     int min_full_attn_seqlen)         // define masking length at the end of query token
                                                       // sequence which is included for full attention
     {
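The hunk that follows moves the local-mask decision into the body of Run: a BOOL_SWITCH on window_size > 0 replaces the former kUseLocal template parameter, and the user-supplied min_full_attn_seqlen is clamped to seqlen - num_target (the non-target length) before the local mask is built. Below is a minimal sketch of that pattern, with a hypothetical two-bool trait standing in for HstuBlockMasking; the real ck_tile trait, mask types, and factory functions may differ.

#include <algorithm>
#include <iostream>
#include <type_traits>

// Hypothetical mask types standing in for ck_tile's block-mask classes.
struct NonLocalMask { static constexpr bool kUseLocal = false; };
struct LocalMask    { static constexpr bool kUseLocal = true; };

// Hypothetical trait mirroring the shape of HstuBlockMasking<kUseCausal, kUseLocal>.
template <bool kUseCausal, bool kUseLocal>
struct BlockMaskingSketch { using Type = NonLocalMask; };
template <bool kUseCausal>
struct BlockMaskingSketch<kUseCausal, true> { using Type = LocalMask; };

// Runtime flag -> compile-time constant -> mask type, as Run() now does internally.
template <bool kUseCausal>
void run_sketch(int window_size, int seqlen, int num_target, int min_full_attn_seqlen)
{
    auto body = [&](auto has_local) {
        constexpr bool kHasLocal = decltype(has_local)::value;
        using MaskType = typename BlockMaskingSketch<kUseCausal, kHasLocal>::Type;
        // The adjustment before make_hstu_block_mask_with_local() in the hunk
        // below amounts to this clamp:
        int full_tail = std::min(min_full_attn_seqlen, seqlen - num_target);
        std::cout << "local=" << MaskType::kUseLocal << " full_tail=" << full_tail << '\n';
    };
    if(window_size > 0)
        body(std::true_type{});
    else
        body(std::false_type{});
}

int main()
{
    run_sketch<true>(/*window_size=*/128, /*seqlen=*/1024, /*num_target=*/16,
                     /*min_full_attn_seqlen=*/64);
}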
@@ -112,122 +108,130 @@ struct reference_hstu_attention
         int num_target = num_targets.empty() ? 0 : num_targets[i_batch];
-        HstuMask mask = [&]() {
-            if constexpr(kHasLocalMask)
-                // need adjust the min_full_attn_seqlen passed to the HstuBlockMask() if the
-                // user passed min_full_attn_seqlen is bigger than max_uih_len
-                if(seqlen - num_target > min_full_attn_seqlen)
-                    return ck_tile::make_hstu_block_mask_with_local<HstuMask>(
-                        true,
-                        seqlen,
-                        contextual_seqlen,
-                        num_target,
-                        max_attn_len,
-                        min_full_attn_seqlen);
-                else
-                    return ck_tile::make_hstu_block_mask_with_local<HstuMask>(true,
-                                                                              seqlen,
-                                                                              contextual_seqlen,
-                                                                              num_target,
-                                                                              max_attn_len,
-                                                                              seqlen -
-                                                                                  num_target);
-            else
-                return ck_tile::make_hstu_block_mask_without_local<HstuMask>(
-                    seqlen, contextual_seqlen, num_target);
-        }();
-        if(save_mask)
-        {
-            // initialize the mask
-            for(int sq = 0; sq < max_seqlen; sq++)
-                for(int sk = 0; sk < max_seqlen; sk++)
-                    mask_batch_nhead_seq_seq(i_batch, i_head, sq, sk) =
-                        static_cast<int8_t>(mask.IsTokenPairInsideMask(sq, sk));
-        }
-        // for all rows in the batch
-        for(int sq = 0; sq < seqlen; sq++)
-        {
-            std::vector<CompDataType> locals;
-            // for all cols in the batch
-            for(int sk = 0; sk < seqlen; sk++)
-            {
-                if(mask.IsTokenPairInsideMask(sq, sk))
-                {
-                    GemmAccDataType dot_prod = 0.f;
-                    for(int k = 0; k < hdim_qk; k++)
-                    {
-                        if constexpr(kIsJagged)
-                        {
-                            InOutDataType qreg =
-                                q_batch_seq_nhead_hdim(0, seq_offsets[i_batch] + sq, i_head, k);
-                            InOutDataType kreg =
-                                k_batch_seq_nhead_hdim(0, seq_offsets[i_batch] + sk, i_head, k);
-                            dot_prod += ck_tile::type_convert<GemmAccDataType>(qreg) *
-                                        ck_tile::type_convert<GemmAccDataType>(kreg);
-                        }
-                        else
-                        {
-                            InOutDataType qreg = q_batch_seq_nhead_hdim(i_batch, sq, i_head, k);
-                            InOutDataType kreg = k_batch_seq_nhead_hdim(i_batch, sk, i_head, k);
-                            dot_prod += ck_tile::type_convert<GemmAccDataType>(qreg) *
-                                        ck_tile::type_convert<GemmAccDataType>(kreg);
-                        };
-                    }
-                    locals.push_back(ck_tile::type_convert<CompDataType>(dot_prod) *
-                                     ck_tile::type_convert<CompDataType>(alpha));
-                }
-                else
-                    locals.push_back(ck_tile::type_convert<CompDataType>(0.0f));
-            };
-            // SiLu element-wise
-            for(CompDataType& elem : locals)
-                elem = silu(elem);
-            float scale_p = attn_scale ? attn_scale : 1.0f / static_cast<float>(max_seqlen);
-            // second Gemm
-            for(int k = 0; k < hdim_v; k++)
-            {
-                GemmAccDataType dot_prod = 0.f;
-                for(int sk = 0; sk < seqlen; sk++)
-                {
-                    if constexpr(kIsJagged)
-                    {
-                        InOutDataType preg = ck_tile::type_convert<InOutDataType>(locals[sk]);
-                        InOutDataType vreg =
-                            v_batch_seq_nhead_hdim(0, seq_offsets[i_batch] + sk, i_head, k);
-                        dot_prod += ck_tile::type_convert<GemmAccDataType>(preg) *
-                                    ck_tile::type_convert<GemmAccDataType>(vreg);
-                    }
-                    else
-                    {
-                        InOutDataType preg = ck_tile::type_convert<InOutDataType>(locals[sk]);
-                        InOutDataType vreg = v_batch_seq_nhead_hdim(i_batch, sk, i_head, k);
-                        dot_prod += ck_tile::type_convert<GemmAccDataType>(preg) *
-                                    ck_tile::type_convert<GemmAccDataType>(vreg);
-                    };
-                };
-                dot_prod = dot_prod * ck_tile::type_convert<GemmAccDataType>(scale_p);
-                if constexpr(kIsJagged)
-                    o_batch_seq_nhead_hdim(0, seq_offsets[i_batch] + sq, i_head, k) =
-                        ck_tile::type_convert<InOutDataType>(dot_prod);
-                else
-                    o_batch_seq_nhead_hdim(i_batch, sq, i_head, k) =
-                        ck_tile::type_convert<InOutDataType>(dot_prod);
-            };
-        };
+        float scale_p = attn_scale ? attn_scale : 1.0f / static_cast<float>(max_seqlen);
+        BOOL_SWITCH(window_size > 0, kHasLocal, [&] {
+            using HstuMaskType = typename HstuBlockMasking<kUseCausal, kHasLocal>::Type;
+            HstuMaskType mask = [&]() {
+                if constexpr(kHasLocal)
+                    // need adjust the min_full_attn_seqlen passed to the HstuBlockMask() if the
+                    // user passed min_full_attn_seqlen is bigger than max_uih_len
+                    if(seqlen - num_target > min_full_attn_seqlen)
+                        return ck_tile::make_hstu_block_mask_with_local<HstuMaskType>(
+                            true,
+                            seqlen,
+                            contextual_seqlen,
+                            num_target,
+                            window_size,
+                            min_full_attn_seqlen);
+                    else
+                        return ck_tile::make_hstu_block_mask_with_local<HstuMaskType>(
+                            true,
+                            seqlen,
+                            contextual_seqlen,
+                            num_target,
+                            window_size,
+                            seqlen - num_target);
+                else
+                    return ck_tile::make_hstu_block_mask_without_local<HstuMaskType>(
+                        seqlen, contextual_seqlen, num_target);
+            }();
+            if(save_mask)
+            {
+                // initialize the mask
+                for(int sq = 0; sq < max_seqlen; sq++)
+                    for(int sk = 0; sk < max_seqlen; sk++)
+                        mask_batch_nhead_seq_seq(i_batch, i_head, sq, sk) =
+                            static_cast<int8_t>(mask.IsTokenPairInsideMask(sq, sk));
+            }
+            // for all rows in the batch
+            for(int sq = 0; sq < seqlen; sq++)
+            {
+                std::vector<CompDataType> locals;
+                // for all cols in the batch
+                for(int sk = 0; sk < seqlen; sk++)
+                {
+                    if(mask.IsTokenPairInsideMask(sq, sk))
+                    {
+                        GemmAccDataType dot_prod = 0.f;
+                        for(int k = 0; k < hdim_qk; k++)
+                        {
+                            if constexpr(kIsJagged)
+                            {
+                                InOutDataType qreg = q_batch_seq_nhead_hdim(
+                                    0, seq_offsets[i_batch] + sq, i_head, k);
+                                InOutDataType kreg = k_batch_seq_nhead_hdim(
+                                    0, seq_offsets[i_batch] + sk, i_head, k);
+                                dot_prod += ck_tile::type_convert<GemmAccDataType>(qreg) *
+                                            ck_tile::type_convert<GemmAccDataType>(kreg);
+                            }
+                            else
+                            {
+                                InOutDataType qreg =
+                                    q_batch_seq_nhead_hdim(i_batch, sq, i_head, k);
+                                InOutDataType kreg =
+                                    k_batch_seq_nhead_hdim(i_batch, sk, i_head, k);
+                                dot_prod += ck_tile::type_convert<GemmAccDataType>(qreg) *
+                                            ck_tile::type_convert<GemmAccDataType>(kreg);
+                            };
+                        }
+                        locals.push_back(ck_tile::type_convert<CompDataType>(dot_prod) *
+                                         ck_tile::type_convert<CompDataType>(alpha));
+                    }
+                    else
+                        locals.push_back(ck_tile::type_convert<CompDataType>(0.0f));
+                };
+                // SiLu element-wise
+                for(CompDataType& elem : locals)
+                    elem = silu(elem);
+                // second Gemm
+                for(int k = 0; k < hdim_v; k++)
+                {
+                    GemmAccDataType dot_prod = 0.f;
+                    for(int sk = 0; sk < seqlen; sk++)
+                    {
+                        if constexpr(kIsJagged)
+                        {
+                            InOutDataType preg =
+                                ck_tile::type_convert<InOutDataType>(locals[sk]);
+                            InOutDataType vreg =
+                                v_batch_seq_nhead_hdim(0, seq_offsets[i_batch] + sk, i_head, k);
+                            dot_prod += ck_tile::type_convert<GemmAccDataType>(preg) *
+                                        ck_tile::type_convert<GemmAccDataType>(vreg);
+                        }
+                        else
+                        {
+                            InOutDataType preg =
+                                ck_tile::type_convert<InOutDataType>(locals[sk]);
+                            InOutDataType vreg = v_batch_seq_nhead_hdim(i_batch, sk, i_head, k);
+                            dot_prod += ck_tile::type_convert<GemmAccDataType>(preg) *
+                                        ck_tile::type_convert<GemmAccDataType>(vreg);
+                        };
+                    };
+                    dot_prod = dot_prod * ck_tile::type_convert<GemmAccDataType>(scale_p);
+                    if constexpr(kIsJagged)
+                        o_batch_seq_nhead_hdim(0, seq_offsets[i_batch] + sq, i_head, k) =
+                            ck_tile::type_convert<InOutDataType>(dot_prod);
+                    else
+                        o_batch_seq_nhead_hdim(i_batch, sq, i_head, k) =
+                            ck_tile::type_convert<InOutDataType>(dot_prod);
+                };
+            };
+        });
     };
make_ParallelTensorFunctor(f, num_batch, num_head)(std::thread::hardware_concurrency());
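Taken together, per batch and head the reference computes a masked first GEMM p = alpha * Q * K^T, applies SiLU element-wise, then runs a second GEMM against V scaled by scale_p = attn_scale ? attn_scale : 1 / max_seqlen. Below is a hedged, dense, single-head sketch of that flow with the mask abstracted as a predicate; it is simplified from the templated, jagged-aware code above and is not ck_tile's implementation.

#include <cmath>
#include <functional>
#include <vector>

// SiLU, applied element-wise between the two GEMMs; silu(0) == 0, so leaving
// masked-out entries at zero matches the reference's push of 0 followed by SiLU.
inline float silu(float x) { return x / (1.0f + std::exp(-x)); }

// O = scale_p * silu(alpha * Q K^T) V over mask-admitted (sq, sk) pairs.
// Q, K: [seqlen][hdim_qk]; V: [seqlen][hdim_v]; O must be pre-sized to
// [seqlen][hdim_v]. inside_mask stands in for mask.IsTokenPairInsideMask.
void hstu_attention_ref_sketch(const std::vector<std::vector<float>>& Q,
                               const std::vector<std::vector<float>>& K,
                               const std::vector<std::vector<float>>& V,
                               std::vector<std::vector<float>>& O,
                               const std::function<bool(int, int)>& inside_mask,
                               float alpha,
                               float attn_scale,
                               int max_seqlen)
{
    const int seqlen  = static_cast<int>(Q.size());
    const int hdim_qk = seqlen ? static_cast<int>(Q[0].size()) : 0;
    const int hdim_v  = seqlen ? static_cast<int>(V[0].size()) : 0;
    const float scale_p =
        attn_scale != 0.f ? attn_scale : 1.0f / static_cast<float>(max_seqlen);
    for(int sq = 0; sq < seqlen; sq++)
    {
        // first Gemm row, scaled by alpha and passed through SiLU
        std::vector<float> locals(seqlen, 0.f);
        for(int sk = 0; sk < seqlen; sk++)
            if(inside_mask(sq, sk))
            {
                float dot = 0.f;
                for(int k = 0; k < hdim_qk; k++)
                    dot += Q[sq][k] * K[sk][k];
                locals[sk] = silu(alpha * dot);
            }
        // second Gemm row, scaled by scale_p
        for(int k = 0; k < hdim_v; k++)
        {
            float acc = 0.f;
            for(int sk = 0; sk < seqlen; sk++)
                acc += locals[sk] * V[sk][k];
            O[sq][k] = acc * scale_p;
        }
    }
}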