Update to partially reduce register spilling

This commit is contained in:
Qianfeng Zhang
2025-04-15 07:22:09 +00:00
parent c2e6ab8516
commit fff13b6c76
4 changed files with 28 additions and 20 deletions

View File

@@ -543,13 +543,13 @@ struct HstuAttentionFwdKernel
HstuMask mask = [&]() {
if constexpr(kHasMask)
return HstuMask{kargs.window_size,
return HstuMask{kargs.seqlen,
kargs.contextual_seqlen,
kargs.min_full_attn_seqlen,
kargs.seqlen,
num_target};
num_target,
kargs.window_size,
kargs.min_full_attn_seqlen};
else
return HstuMask{0, kargs.contextual_seqlen, 0, kargs.seqlen, num_target};
return HstuMask{kargs.seqlen, kargs.contextual_seqlen, num_target};
}();
// for simplicity, batch stride we just modify the pointer

View File

@@ -248,12 +248,10 @@ struct HstuAttentionFwdPipelineQRKSVS
auto s_acc = SaccBlockTileType{};
// reduction function for softmax
const auto f_silu = [](CompDataType x) {
const auto f_silu = [](CompDataType& x) {
auto one = ck_tile::type_convert<CompDataType>(1.0f);
auto sigmod_val = one / (one + exp(-x));
return sigmod_val * x;
return x = x / (one + exp(-x));
};
using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile());
@@ -405,7 +403,7 @@ struct HstuAttentionFwdPipelineQRKSVS
auto s = cast_tile<CompDataType>(s_acc); // S{j}
s = tile_elementwise_in(f_silu, s);
tile_elementwise_inout(f_silu, s);
if constexpr(kHasDropout)
{

View File

@@ -12,22 +12,32 @@ struct HstuBlockMasking
{
static constexpr bool IsMasking = (kUseCausal || kUseLocal);
int max_attn_len;
int contextual_seqlen;
int min_full_attn_seqlen;
int max_uih_len;
CK_TILE_HOST_DEVICE HstuBlockMasking(int max_attn_len_,
int max_attn_len;
int min_full_attn_seqlen;
CK_TILE_HOST_DEVICE HstuBlockMasking(int seqlen_,
int contextual_seqlen_,
int min_full_attn_seqlen_,
int seqlen_,
int num_target)
int num_target,
int max_attn_len_,
int min_full_attn_seqlen_)
{
max_uih_len = seqlen_;
contextual_seqlen = contextual_seqlen_;
max_attn_len = max_attn_len_;
contextual_seqlen = contextual_seqlen_;
min_full_attn_seqlen = min_full_attn_seqlen_;
max_uih_len = seqlen_;
max_uih_len -= contextual_seqlen > 0 ? contextual_seqlen - 1 : 0;
max_uih_len -= num_target;
};
CK_TILE_HOST_DEVICE HstuBlockMasking(int seqlen_, int contextual_seqlen_, int num_target)
{
max_uih_len = seqlen_;
contextual_seqlen = contextual_seqlen_;
max_uih_len -= contextual_seqlen > 0 ? contextual_seqlen - 1 : 0;
max_uih_len -= num_target;

View File

@@ -106,9 +106,9 @@ struct reference_hstu_attention
HstuMask mask = [&]() {
if constexpr(kHasMask)
return HstuMask{
max_attn_len, contextual_seqlen, min_full_attn_seqlen, seqlen, num_target};
seqlen, contextual_seqlen, num_target, max_attn_len, min_full_attn_seqlen};
else
return HstuMask{0, contextual_seqlen, 0, seqlen, num_target};
return HstuMask{seqlen, contextual_seqlen, num_target};
}();
// for all rows in the batch