diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp index 4c5c50ea44..ef4981ecb9 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp @@ -91,8 +91,8 @@ struct HstuAttentionFwdKernel ck_tile::index_t seq_stride_o; ck_tile::index_t num_head; - float scale_s; - ck_tile::index_t max_seqlen; + float scale_s; // scaling value exerted on the immediate Q@K result + float scale_p; // scaling value exerted on the SiLU result ck_tile::index_t contextual_seqlen; }; @@ -124,8 +124,8 @@ struct HstuAttentionFwdKernel ck_tile::index_t seqlen; ck_tile::index_t num_head; - float scale_s; - ck_tile::index_t max_seqlen; + float scale_s; // scaling value exerted on the immediate Q@K result + float scale_p; // scaling value exerted on the SiLU result ck_tile::index_t contextual_seqlen; }; @@ -257,11 +257,11 @@ struct HstuAttentionFwdKernel seq_stride_o, num_head, -scale_s, - seqlen, // max_seqlen - contextual_seqlen}, // args for common karg - {}, // placeholder for mask - {}, // placeholder for bias - {}, // placeholder for dropout + 1.0f / static_cast(seqlen), // max_seqlen + contextual_seqlen}, // args for common karg + {}, // placeholder for mask + {}, // placeholder for bias + {}, // placeholder for dropout }; if constexpr(kHasLocalMask) @@ -404,7 +404,7 @@ struct HstuAttentionFwdKernel -1, // seqlen will be updated by another pointer num_head, -scale_s, - max_seqlen, + 1.0f / static_cast(max_seqlen), contextual_seqlen}, // args for common karg {}, // placeholder for mask {}, // placeholder for bias @@ -850,7 +850,7 @@ struct HstuAttentionFwdKernel bias_dram_window, mask, kargs.scale_s, - kargs.max_seqlen, + kargs.scale_p, smem_ptr, dropout); }(); diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp index 31049ca955..1a01151972 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp @@ -137,8 +137,8 @@ struct HstuAttentionFwdPipelineQRKSVS const PComputeElementFunction& p_compute_element_func, const OAccElementFunction& o_acc_element_func, HstuMask mask, - float scale_s, - index_t max_seqlen, + float scale_s, // scaling value exerted on the immediate Q@K result + float scale_p, // scaling value exerted on the SiLu result void* smem_ptr, DropoutType& dropout) const { @@ -569,12 +569,8 @@ struct HstuAttentionFwdPipelineQRKSVS }); } while(seqlen_k_curr < seqlen_k_end); - tile_elementwise_inout( - [&](auto& x) { - x = x * type_convert( - __builtin_amdgcn_rcpf(static_cast(max_seqlen))); - }, - o_acc); + tile_elementwise_inout([&](auto& x) { x = x * type_convert(scale_p); }, + o_acc); o_acc = tile_elementwise_in(o_acc_element_func, o_acc); @@ -591,8 +587,8 @@ struct HstuAttentionFwdPipelineQRKSVS const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile HstuMask mask, - float scale_s, - int max_seqlen, + float scale_s, // scaling value exerted on the immediate Q@K result + float scale_p, // scaling value exerted on the SiLU result void* smem_ptr, DropoutType& dropout) const { @@ -609,7 +605,7 @@ struct HstuAttentionFwdPipelineQRKSVS identity{}, mask, scale_s, - max_seqlen, + scale_p, smem_ptr, dropout); }