Move scaling by attn_scale to inside the main-loop

This commit is contained in:
Qianfeng Zhang
2025-10-15 09:24:44 +00:00
parent bbda3f6f1c
commit 1a8f2f21fb
2 changed files with 4 additions and 6 deletions

View File

@@ -450,6 +450,9 @@ struct HstuAttentionFwdPipelineQRKSVS
tile_elementwise_inout(f_silu, pcomp_tile);
tile_elementwise_inout([&](auto& x) { x = x * type_convert<CompDataType>(scale_p); },
pcomp_tile);
seqlen_k_curr += kN0;
if constexpr(kHasDropout)
@@ -507,9 +510,6 @@ struct HstuAttentionFwdPipelineQRKSVS
};
} while(seqlen_k_curr < seqlen_k_end);
tile_elementwise_inout([&](auto& x) { x = x * type_convert<GemmAccDataType>(scale_p); },
o_acc);
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
return o_acc;

View File

@@ -191,7 +191,7 @@ struct reference_hstu_attention
// SiLu element-wise
for(CompDataType& elem : locals)
elem = silu(elem);
elem = silu(elem) * ck_tile::type_convert<CompDataType>(scale_p);
// second Gemm
for(int k = 0; k < hdim_v; k++)
@@ -221,8 +221,6 @@ struct reference_hstu_attention
};
};
dot_prod = dot_prod * ck_tile::type_convert<GemmAccDataType>(scale_p);
if constexpr(kIsJagged)
o_batch_seq_nhead_hdim(0, seq_offsets[i_batch] + sq, i_head, k) =
ck_tile::type_convert<InOutDataType>(dot_prod);