mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 03:19:48 +00:00
Move scaling by attn_scale to inside the main-loop
This commit is contained in:
@@ -450,6 +450,9 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
|
||||
tile_elementwise_inout(f_silu, pcomp_tile);
|
||||
|
||||
tile_elementwise_inout([&](auto& x) { x = x * type_convert<CompDataType>(scale_p); },
|
||||
pcomp_tile);
|
||||
|
||||
seqlen_k_curr += kN0;
|
||||
|
||||
if constexpr(kHasDropout)
|
||||
@@ -507,9 +510,6 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
};
|
||||
} while(seqlen_k_curr < seqlen_k_end);
|
||||
|
||||
tile_elementwise_inout([&](auto& x) { x = x * type_convert<GemmAccDataType>(scale_p); },
|
||||
o_acc);
|
||||
|
||||
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
|
||||
|
||||
return o_acc;
|
||||
|
||||
@@ -191,7 +191,7 @@ struct reference_hstu_attention
|
||||
|
||||
// SiLu element-wise
|
||||
for(CompDataType& elem : locals)
|
||||
elem = silu(elem);
|
||||
elem = silu(elem) * ck_tile::type_convert<CompDataType>(scale_p);
|
||||
|
||||
// second Gemm
|
||||
for(int k = 0; k < hdim_v; k++)
|
||||
@@ -221,8 +221,6 @@ struct reference_hstu_attention
|
||||
};
|
||||
};
|
||||
|
||||
dot_prod = dot_prod * ck_tile::type_convert<GemmAccDataType>(scale_p);
|
||||
|
||||
if constexpr(kIsJagged)
|
||||
o_batch_seq_nhead_hdim(0, seq_offsets[i_batch] + sq, i_head, k) =
|
||||
ck_tile::type_convert<InOutDataType>(dot_prod);
|
||||
|
||||
Reference in New Issue
Block a user