mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 18:42:06 +00:00
Update to partially reduce the register spilling
This commit is contained in:
@@ -543,13 +543,13 @@ struct HstuAttentionFwdKernel
|
||||
|
||||
HstuMask mask = [&]() {
|
||||
if constexpr(kHasMask)
|
||||
return HstuMask{kargs.window_size,
|
||||
return HstuMask{kargs.seqlen,
|
||||
kargs.contextual_seqlen,
|
||||
kargs.min_full_attn_seqlen,
|
||||
kargs.seqlen,
|
||||
num_target};
|
||||
num_target,
|
||||
kargs.window_size,
|
||||
kargs.min_full_attn_seqlen};
|
||||
else
|
||||
return HstuMask{0, kargs.contextual_seqlen, 0, kargs.seqlen, num_target};
|
||||
return HstuMask{kargs.seqlen, kargs.contextual_seqlen, num_target};
|
||||
}();
|
||||
|
||||
// for simplicity, batch stride we just modify the pointer
|
||||
|
||||
@@ -248,12 +248,10 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
auto s_acc = SaccBlockTileType{};
|
||||
|
||||
// reduction function for softmax
|
||||
const auto f_silu = [](CompDataType x) {
|
||||
const auto f_silu = [](CompDataType& x) {
|
||||
auto one = ck_tile::type_convert<CompDataType>(1.0f);
|
||||
|
||||
auto sigmod_val = one / (one + exp(-x));
|
||||
|
||||
return sigmod_val * x;
|
||||
return x = x / (one + exp(-x));
|
||||
};
|
||||
|
||||
using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile());
|
||||
@@ -405,7 +403,7 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
|
||||
auto s = cast_tile<CompDataType>(s_acc); // S{j}
|
||||
|
||||
s = tile_elementwise_in(f_silu, s);
|
||||
tile_elementwise_inout(f_silu, s);
|
||||
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
|
||||
@@ -12,22 +12,32 @@ struct HstuBlockMasking
|
||||
{
|
||||
static constexpr bool IsMasking = (kUseCausal || kUseLocal);
|
||||
|
||||
int max_attn_len;
|
||||
int contextual_seqlen;
|
||||
int min_full_attn_seqlen;
|
||||
int max_uih_len;
|
||||
|
||||
CK_TILE_HOST_DEVICE HstuBlockMasking(int max_attn_len_,
|
||||
int max_attn_len;
|
||||
int min_full_attn_seqlen;
|
||||
|
||||
CK_TILE_HOST_DEVICE HstuBlockMasking(int seqlen_,
|
||||
int contextual_seqlen_,
|
||||
int min_full_attn_seqlen_,
|
||||
int seqlen_,
|
||||
int num_target)
|
||||
int num_target,
|
||||
int max_attn_len_,
|
||||
int min_full_attn_seqlen_)
|
||||
{
|
||||
max_uih_len = seqlen_;
|
||||
contextual_seqlen = contextual_seqlen_;
|
||||
|
||||
max_attn_len = max_attn_len_;
|
||||
contextual_seqlen = contextual_seqlen_;
|
||||
min_full_attn_seqlen = min_full_attn_seqlen_;
|
||||
|
||||
max_uih_len = seqlen_;
|
||||
max_uih_len -= contextual_seqlen > 0 ? contextual_seqlen - 1 : 0;
|
||||
max_uih_len -= num_target;
|
||||
};
|
||||
|
||||
CK_TILE_HOST_DEVICE HstuBlockMasking(int seqlen_, int contextual_seqlen_, int num_target)
|
||||
{
|
||||
max_uih_len = seqlen_;
|
||||
contextual_seqlen = contextual_seqlen_;
|
||||
|
||||
max_uih_len -= contextual_seqlen > 0 ? contextual_seqlen - 1 : 0;
|
||||
max_uih_len -= num_target;
|
||||
|
||||
@@ -106,9 +106,9 @@ struct reference_hstu_attention
|
||||
HstuMask mask = [&]() {
|
||||
if constexpr(kHasMask)
|
||||
return HstuMask{
|
||||
max_attn_len, contextual_seqlen, min_full_attn_seqlen, seqlen, num_target};
|
||||
seqlen, contextual_seqlen, num_target, max_attn_len, min_full_attn_seqlen};
|
||||
else
|
||||
return HstuMask{0, contextual_seqlen, 0, seqlen, num_target};
|
||||
return HstuMask{seqlen, contextual_seqlen, num_target};
|
||||
}();
|
||||
|
||||
// for all rows in the batch
|
||||
|
||||
Reference in New Issue
Block a user