mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-16 19:09:59 +00:00
Small update in reference hstu attention
This commit is contained in:
@@ -219,8 +219,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
int window_size = arg_parser.get_int("local_len");
|
||||
|
||||
bool use_local = (window_size > 0);
|
||||
|
||||
int contextual_seqlen = arg_parser.get_int("context_len");
|
||||
int min_full_attn_seqlen = arg_parser.get_int("minfull_len");
|
||||
|
||||
@@ -516,26 +514,25 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
using GemmAccDataType = typename HstuAttentionFwdTypeConfig<InOutDataType>::GemmAccDataType;
|
||||
using CompDataType = typename HstuAttentionFwdTypeConfig<InOutDataType>::CompDataType;
|
||||
|
||||
BOOL_SWITCH_3(is_jagged, kIsJagged, use_causal, kUseCausal, use_local, kUseLocal, [&] {
|
||||
BOOL_SWITCH_2(is_jagged, kIsJagged, use_causal, kUseCausal, [&] {
|
||||
ck_tile::reference_hstu_attention<InOutDataType,
|
||||
GemmAccDataType,
|
||||
CompDataType,
|
||||
kIsJagged,
|
||||
kUseCausal,
|
||||
kUseLocal>::Run(q_host,
|
||||
k_host,
|
||||
v_host,
|
||||
o_host_ref,
|
||||
mask_host,
|
||||
num_batch,
|
||||
scale_s,
|
||||
attn_scale,
|
||||
max_seqlen,
|
||||
seq_offsets,
|
||||
num_targets,
|
||||
window_size,
|
||||
contextual_seqlen,
|
||||
min_full_attn_seqlen);
|
||||
kUseCausal>::Run(q_host,
|
||||
k_host,
|
||||
v_host,
|
||||
o_host_ref,
|
||||
mask_host,
|
||||
num_batch,
|
||||
scale_s,
|
||||
attn_scale,
|
||||
max_seqlen,
|
||||
seq_offsets,
|
||||
num_targets,
|
||||
contextual_seqlen,
|
||||
window_size,
|
||||
min_full_attn_seqlen);
|
||||
});
|
||||
|
||||
ck_tile::HostTensor<InOutDataType> o_host(
|
||||
|
||||
@@ -29,13 +29,9 @@ template <typename InOutDataType,
|
||||
typename GemmAccDataType,
|
||||
typename CompDataType,
|
||||
bool kIsJagged,
|
||||
bool kUseCausal,
|
||||
bool kUseLocal>
|
||||
bool kUseCausal>
|
||||
struct reference_hstu_attention
|
||||
{
|
||||
using HstuMask = typename HstuBlockMasking<kUseCausal, kUseLocal>::Type;
|
||||
static constexpr bool kHasLocalMask = HstuMask::kUseLocal;
|
||||
|
||||
static void Run(const HostTensor<InOutDataType>& q_batch_seq_nhead_hdim,
|
||||
const HostTensor<InOutDataType>& k_batch_seq_nhead_hdim,
|
||||
const HostTensor<InOutDataType>& v_batch_seq_nhead_hdim,
|
||||
@@ -48,9 +44,9 @@ struct reference_hstu_attention
|
||||
std::vector<int> seq_offsets,
|
||||
std::vector<int> num_targets, // define masking length at the end of token
|
||||
// sequence to be excluded for attention
|
||||
int max_attn_len, // define the diagonal local window size
|
||||
int contextual_seqlen, // define masking length at the begin of query token
|
||||
// sequence to be included for attention
|
||||
int window_size, // define the diagonal local window size
|
||||
int min_full_attn_seqlen) // define masking length at the end of query token
|
||||
// sequence which is included for full attention
|
||||
{
|
||||
@@ -112,122 +108,130 @@ struct reference_hstu_attention
|
||||
|
||||
int num_target = num_targets.empty() ? 0 : num_targets[i_batch];
|
||||
|
||||
HstuMask mask = [&]() {
|
||||
if constexpr(kHasLocalMask)
|
||||
// need adjust the min_full_attn_seqlen passed to the HstuBlockMask() if the
|
||||
// user passed min_full_attn_seqlen is bigger than max_uih_len
|
||||
if(seqlen - num_target > min_full_attn_seqlen)
|
||||
return ck_tile::make_hstu_block_mask_with_local<HstuMask>(
|
||||
true,
|
||||
seqlen,
|
||||
contextual_seqlen,
|
||||
num_target,
|
||||
max_attn_len,
|
||||
min_full_attn_seqlen);
|
||||
float scale_p = attn_scale ? attn_scale : 1.0f / static_cast<float>(max_seqlen);
|
||||
|
||||
BOOL_SWITCH(window_size > 0, kHasLocal, [&] {
|
||||
using HstuMaskType = typename HstuBlockMasking<kUseCausal, kHasLocal>::Type;
|
||||
|
||||
HstuMaskType mask = [&]() {
|
||||
if constexpr(kHasLocal)
|
||||
// need adjust the min_full_attn_seqlen passed to the HstuBlockMask() if the
|
||||
// user passed min_full_attn_seqlen is bigger than max_uih_len
|
||||
if(seqlen - num_target > min_full_attn_seqlen)
|
||||
return ck_tile::make_hstu_block_mask_with_local<HstuMaskType>(
|
||||
true,
|
||||
seqlen,
|
||||
contextual_seqlen,
|
||||
num_target,
|
||||
window_size,
|
||||
min_full_attn_seqlen);
|
||||
else
|
||||
return ck_tile::make_hstu_block_mask_with_local<HstuMaskType>(
|
||||
true,
|
||||
seqlen,
|
||||
contextual_seqlen,
|
||||
num_target,
|
||||
window_size,
|
||||
seqlen - num_target);
|
||||
else
|
||||
return ck_tile::make_hstu_block_mask_with_local<HstuMask>(true,
|
||||
seqlen,
|
||||
contextual_seqlen,
|
||||
num_target,
|
||||
max_attn_len,
|
||||
seqlen -
|
||||
num_target);
|
||||
else
|
||||
return ck_tile::make_hstu_block_mask_without_local<HstuMask>(
|
||||
seqlen, contextual_seqlen, num_target);
|
||||
}();
|
||||
return ck_tile::make_hstu_block_mask_without_local<HstuMaskType>(
|
||||
seqlen, contextual_seqlen, num_target);
|
||||
}();
|
||||
|
||||
if(save_mask)
|
||||
{
|
||||
// initialize the mask
|
||||
for(int sq = 0; sq < max_seqlen; sq++)
|
||||
for(int sk = 0; sk < max_seqlen; sk++)
|
||||
mask_batch_nhead_seq_seq(i_batch, i_head, sq, sk) =
|
||||
static_cast<int8_t>(mask.IsTokenPairInsideMask(sq, sk));
|
||||
}
|
||||
|
||||
// for all rows in the batch
|
||||
for(int sq = 0; sq < seqlen; sq++)
|
||||
{
|
||||
std::vector<CompDataType> locals;
|
||||
|
||||
// for all cols in the batch
|
||||
for(int sk = 0; sk < seqlen; sk++)
|
||||
if(save_mask)
|
||||
{
|
||||
if(mask.IsTokenPairInsideMask(sq, sk))
|
||||
// initialize the mask
|
||||
for(int sq = 0; sq < max_seqlen; sq++)
|
||||
for(int sk = 0; sk < max_seqlen; sk++)
|
||||
mask_batch_nhead_seq_seq(i_batch, i_head, sq, sk) =
|
||||
static_cast<int8_t>(mask.IsTokenPairInsideMask(sq, sk));
|
||||
}
|
||||
|
||||
// for all rows in the batch
|
||||
for(int sq = 0; sq < seqlen; sq++)
|
||||
{
|
||||
std::vector<CompDataType> locals;
|
||||
|
||||
// for all cols in the batch
|
||||
for(int sk = 0; sk < seqlen; sk++)
|
||||
{
|
||||
if(mask.IsTokenPairInsideMask(sq, sk))
|
||||
{
|
||||
GemmAccDataType dot_prod = 0.f;
|
||||
for(int k = 0; k < hdim_qk; k++)
|
||||
{
|
||||
if constexpr(kIsJagged)
|
||||
{
|
||||
InOutDataType qreg = q_batch_seq_nhead_hdim(
|
||||
0, seq_offsets[i_batch] + sq, i_head, k);
|
||||
InOutDataType kreg = k_batch_seq_nhead_hdim(
|
||||
0, seq_offsets[i_batch] + sk, i_head, k);
|
||||
|
||||
dot_prod += ck_tile::type_convert<GemmAccDataType>(qreg) *
|
||||
ck_tile::type_convert<GemmAccDataType>(kreg);
|
||||
}
|
||||
else
|
||||
{
|
||||
InOutDataType qreg =
|
||||
q_batch_seq_nhead_hdim(i_batch, sq, i_head, k);
|
||||
InOutDataType kreg =
|
||||
k_batch_seq_nhead_hdim(i_batch, sk, i_head, k);
|
||||
|
||||
dot_prod += ck_tile::type_convert<GemmAccDataType>(qreg) *
|
||||
ck_tile::type_convert<GemmAccDataType>(kreg);
|
||||
};
|
||||
}
|
||||
|
||||
locals.push_back(ck_tile::type_convert<CompDataType>(dot_prod) *
|
||||
ck_tile::type_convert<CompDataType>(alpha));
|
||||
}
|
||||
else
|
||||
locals.push_back(ck_tile::type_convert<CompDataType>(0.0f));
|
||||
};
|
||||
|
||||
// SiLu element-wise
|
||||
for(CompDataType& elem : locals)
|
||||
elem = silu(elem);
|
||||
|
||||
// second Gemm
|
||||
for(int k = 0; k < hdim_v; k++)
|
||||
{
|
||||
GemmAccDataType dot_prod = 0.f;
|
||||
for(int k = 0; k < hdim_qk; k++)
|
||||
|
||||
for(int sk = 0; sk < seqlen; sk++)
|
||||
{
|
||||
if constexpr(kIsJagged)
|
||||
{
|
||||
InOutDataType qreg =
|
||||
q_batch_seq_nhead_hdim(0, seq_offsets[i_batch] + sq, i_head, k);
|
||||
InOutDataType kreg =
|
||||
k_batch_seq_nhead_hdim(0, seq_offsets[i_batch] + sk, i_head, k);
|
||||
InOutDataType preg =
|
||||
ck_tile::type_convert<InOutDataType>(locals[sk]);
|
||||
InOutDataType vreg =
|
||||
v_batch_seq_nhead_hdim(0, seq_offsets[i_batch] + sk, i_head, k);
|
||||
|
||||
dot_prod += ck_tile::type_convert<GemmAccDataType>(qreg) *
|
||||
ck_tile::type_convert<GemmAccDataType>(kreg);
|
||||
dot_prod += ck_tile::type_convert<GemmAccDataType>(preg) *
|
||||
ck_tile::type_convert<GemmAccDataType>(vreg);
|
||||
}
|
||||
else
|
||||
{
|
||||
InOutDataType qreg = q_batch_seq_nhead_hdim(i_batch, sq, i_head, k);
|
||||
InOutDataType kreg = k_batch_seq_nhead_hdim(i_batch, sk, i_head, k);
|
||||
InOutDataType preg =
|
||||
ck_tile::type_convert<InOutDataType>(locals[sk]);
|
||||
InOutDataType vreg = v_batch_seq_nhead_hdim(i_batch, sk, i_head, k);
|
||||
|
||||
dot_prod += ck_tile::type_convert<GemmAccDataType>(qreg) *
|
||||
ck_tile::type_convert<GemmAccDataType>(kreg);
|
||||
dot_prod += ck_tile::type_convert<GemmAccDataType>(preg) *
|
||||
ck_tile::type_convert<GemmAccDataType>(vreg);
|
||||
};
|
||||
}
|
||||
|
||||
locals.push_back(ck_tile::type_convert<CompDataType>(dot_prod) *
|
||||
ck_tile::type_convert<CompDataType>(alpha));
|
||||
}
|
||||
else
|
||||
locals.push_back(ck_tile::type_convert<CompDataType>(0.0f));
|
||||
};
|
||||
|
||||
// SiLu element-wise
|
||||
for(CompDataType& elem : locals)
|
||||
elem = silu(elem);
|
||||
|
||||
float scale_p = attn_scale ? attn_scale : 1.0f / static_cast<float>(max_seqlen);
|
||||
|
||||
// second Gemm
|
||||
for(int k = 0; k < hdim_v; k++)
|
||||
{
|
||||
GemmAccDataType dot_prod = 0.f;
|
||||
|
||||
for(int sk = 0; sk < seqlen; sk++)
|
||||
{
|
||||
if constexpr(kIsJagged)
|
||||
{
|
||||
InOutDataType preg = ck_tile::type_convert<InOutDataType>(locals[sk]);
|
||||
InOutDataType vreg =
|
||||
v_batch_seq_nhead_hdim(0, seq_offsets[i_batch] + sk, i_head, k);
|
||||
|
||||
dot_prod += ck_tile::type_convert<GemmAccDataType>(preg) *
|
||||
ck_tile::type_convert<GemmAccDataType>(vreg);
|
||||
}
|
||||
else
|
||||
{
|
||||
InOutDataType preg = ck_tile::type_convert<InOutDataType>(locals[sk]);
|
||||
InOutDataType vreg = v_batch_seq_nhead_hdim(i_batch, sk, i_head, k);
|
||||
|
||||
dot_prod += ck_tile::type_convert<GemmAccDataType>(preg) *
|
||||
ck_tile::type_convert<GemmAccDataType>(vreg);
|
||||
};
|
||||
|
||||
dot_prod = dot_prod * ck_tile::type_convert<GemmAccDataType>(scale_p);
|
||||
|
||||
if constexpr(kIsJagged)
|
||||
o_batch_seq_nhead_hdim(0, seq_offsets[i_batch] + sq, i_head, k) =
|
||||
ck_tile::type_convert<InOutDataType>(dot_prod);
|
||||
else
|
||||
o_batch_seq_nhead_hdim(i_batch, sq, i_head, k) =
|
||||
ck_tile::type_convert<InOutDataType>(dot_prod);
|
||||
};
|
||||
|
||||
dot_prod = dot_prod * ck_tile::type_convert<GemmAccDataType>(scale_p);
|
||||
|
||||
if constexpr(kIsJagged)
|
||||
o_batch_seq_nhead_hdim(0, seq_offsets[i_batch] + sq, i_head, k) =
|
||||
ck_tile::type_convert<InOutDataType>(dot_prod);
|
||||
else
|
||||
o_batch_seq_nhead_hdim(i_batch, sq, i_head, k) =
|
||||
ck_tile::type_convert<InOutDataType>(dot_prod);
|
||||
};
|
||||
};
|
||||
});
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f, num_batch, num_head)(std::thread::hardware_concurrency());
|
||||
|
||||
Reference in New Issue
Block a user