diff --git a/example/ck_tile/18_hstu_attention/README.md b/example/ck_tile/18_hstu_attention/README.md index 6e1f978657..0d844822b6 100644 --- a/example/ck_tile/18_hstu_attention/README.md +++ b/example/ck_tile/18_hstu_attention/README.md @@ -54,7 +54,8 @@ .insert("minfull_len", "6", "sequence length at the end of the query sequence that should be included for attention") .insert("init_qkv", "0", "initialize q, k, v tensor from local files q.dat, k.dat and v.data") .insert("seed", "13579", "seed by the uniform or normal distribution generator") - .insert("alpha", "0", "scale factor of P=Q@K. 0 means equal to 1/sqrt(hdim)") + .insert("alpha", "0", "scale factor of S=Q@K. 0 means equal to 1/sqrt(hdim)") + .insert("attn_scale", "0", "scale factor of SiLu(Q@K), 0 means using 1/max_seqlen for scaling") .insert("save_mask", "1", "save the mask tensor to disk by the CPU validation codes") .insert("perf", "0", "weather measure execution time or not"); .insert("dump_output", "0", "dump both device and reference hstu attention outputs to files, only used when validation is true"); diff --git a/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp b/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp index f50f7d8d88..f1d0e5c461 100644 --- a/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp +++ b/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp @@ -108,7 +108,8 @@ auto create_args(int argc, char* argv[]) .insert("context_len", "6", "sequence length at the begin of the query sequence the should be included for attention") .insert("minfull_len", "6", "sequence length at the end of the query sequence that should be included for attention") .insert("seed", "13579", "seed by the uniform or normal distribution generator") - .insert("alpha", "0", "scale factor of P=Q@K. 0 means equal to 1/sqrt(hdim)") + .insert("alpha", "0", "scale factor of S=Q@K. 0 means equal to 1/sqrt(hdim)") + .insert("attn_scale", "0", "scale factor of SiLU(Q@K). 0 means using 1/max_seqlen for scaling") .insert("init_qkv", "0", "initialize q, k, v tensor from local files q.dat, k.dat and v.data") .insert("save_mask", "0", "save the mask tensor to disk by the CPU validation codes") .insert("perf", "0", "weather measure execution time or not") @@ -223,6 +224,7 @@ bool run(const ck_tile::ArgParser& arg_parser) int min_full_attn_seqlen = arg_parser.get_int("minfull_len"); float alpha = arg_parser.get_float("alpha"); + float attn_scale = arg_parser.get_float("attn_scale"); int seed = arg_parser.get_int("seed"); bool measure_perf = static_cast(arg_parser.get_int("perf")); bool dump_output = static_cast(arg_parser.get_int("dump_output")); @@ -412,6 +414,7 @@ bool run(const ck_tile::ArgParser& arg_parser) params.hdim_v = hdim_v; params.num_head = num_head; params.scale_s = scale_s; + params.attn_scale = attn_scale; params.seq_stride_q = q_host.get_strides()[1]; params.seq_stride_k = k_host.get_strides()[1]; params.seq_stride_v = v_host.get_strides()[1]; @@ -445,6 +448,7 @@ bool run(const ck_tile::ArgParser& arg_parser) params.hdim_v = hdim_v; params.num_head = num_head; params.scale_s = scale_s; + params.attn_scale = attn_scale; params.seq_stride_q = q_host.get_strides()[1]; params.seq_stride_k = k_host.get_strides()[1]; params.seq_stride_v = v_host.get_strides()[1]; @@ -514,6 +518,7 @@ bool run(const ck_tile::ArgParser& arg_parser) mask_host, num_batch, scale_s, + attn_scale, max_seqlen, seq_offsets, num_targets, diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_dispatch.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_dispatch.hpp index 36bb6b261d..c39ed59ca7 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_dispatch.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_dispatch.hpp @@ -104,6 +104,7 @@ struct batched_forward_causal_local_bias_dropout_dispatch param.hdim_v, param.num_head, param.scale_s, + param.attn_scale, param.seq_stride_q, param.seq_stride_k, param.seq_stride_v, diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp index ef4981ecb9..dae068448a 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp @@ -212,6 +212,7 @@ struct HstuAttentionFwdKernel ck_tile::index_t hdim_v, ck_tile::index_t num_head, float scale_s, + float attn_scale, ck_tile::index_t seq_stride_q, ck_tile::index_t seq_stride_k, ck_tile::index_t seq_stride_v, @@ -257,11 +258,11 @@ struct HstuAttentionFwdKernel seq_stride_o, num_head, -scale_s, - 1.0f / static_cast(seqlen), // max_seqlen - contextual_seqlen}, // args for common karg - {}, // placeholder for mask - {}, // placeholder for bias - {}, // placeholder for dropout + attn_scale ? attn_scale : 1.0f / static_cast(seqlen), // max_seqlen + contextual_seqlen}, // args for common karg + {}, // placeholder for mask + {}, // placeholder for bias + {}, // placeholder for dropout }; if constexpr(kHasLocalMask) @@ -298,6 +299,7 @@ struct HstuAttentionFwdKernel ck_tile::index_t hdim_v, ck_tile::index_t num_head, float scale_s, + float attn_scale, ck_tile::index_t seq_stride_q, ck_tile::index_t seq_stride_k, ck_tile::index_t seq_stride_v, @@ -331,6 +333,7 @@ struct HstuAttentionFwdKernel hdim_v, num_head, scale_s, + attn_scale, seq_stride_q, seq_stride_k, seq_stride_v, @@ -367,6 +370,7 @@ struct HstuAttentionFwdKernel ck_tile::index_t hdim_v, ck_tile::index_t num_head, float scale_s, + float attn_scale, ck_tile::index_t seq_stride_q, ck_tile::index_t seq_stride_k, ck_tile::index_t seq_stride_v, @@ -404,7 +408,7 @@ struct HstuAttentionFwdKernel -1, // seqlen will be updated by another pointer num_head, -scale_s, - 1.0f / static_cast(max_seqlen), + attn_scale ? attn_scale : 1.0f / static_cast(max_seqlen), contextual_seqlen}, // args for common karg {}, // placeholder for mask {}, // placeholder for bias @@ -445,6 +449,7 @@ struct HstuAttentionFwdKernel ck_tile::index_t hdim_v, ck_tile::index_t num_head, float scale_s, + float attn_scale, ck_tile::index_t seq_stride_q, ck_tile::index_t seq_stride_k, ck_tile::index_t seq_stride_v, @@ -474,6 +479,7 @@ struct HstuAttentionFwdKernel hdim_v, num_head, scale_s, + attn_scale, seq_stride_q, seq_stride_k, seq_stride_v, diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_dispatch.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_dispatch.hpp index 676ecc3e50..e3f6e00f79 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_dispatch.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_dispatch.hpp @@ -97,6 +97,7 @@ struct jagged_forward_causal_local_bias_dropout_dispatch param.hdim_v, param.num_head, param.scale_s, + param.attn_scale, param.seq_stride_q, param.seq_stride_k, param.seq_stride_v, diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_params.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_params.hpp index d76c8857e4..168d3b4781 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_params.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_params.hpp @@ -23,7 +23,8 @@ struct HstuAttentionFwdParams ck_tile::index_t hdim_qk; ck_tile::index_t hdim_v; ck_tile::index_t num_head; - float scale_s; + float scale_s; // scaling factor exerted on the immediate Q@K result + float attn_scale; // scaling factor exerted on the SiLU result ck_tile::index_t seq_stride_q; ck_tile::index_t seq_stride_k; diff --git a/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp b/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp index 645f53aebd..96324946b2 100644 --- a/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp +++ b/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp @@ -43,6 +43,7 @@ struct reference_hstu_attention HostTensor& mask_batch_nhead_seq_seq, int num_batch, float alpha, + float attn_scale, int max_seqlen, std::vector seq_offsets, std::vector num_targets, // define masking length at the end of token @@ -177,6 +178,8 @@ struct reference_hstu_attention for(CompDataType& elem : locals) elem = silu(elem); + float scale_p = attn_scale ? attn_scale : 1.0f / static_cast(max_seqlen); + // second Gemm for(int k = 0; k < hdim_v; k++) { @@ -203,7 +206,7 @@ struct reference_hstu_attention }; }; - dot_prod = dot_prod / ck_tile::type_convert(max_seqlen); + dot_prod = dot_prod * ck_tile::type_convert(scale_p); if constexpr(kIsJagged) o_batch_seq_nhead_hdim(0, seq_offsets[i_batch] + sq, i_head, k) = diff --git a/example/ck_tile/18_hstu_attention/scripts/test_hstu_attention.sh b/example/ck_tile/18_hstu_attention/scripts/test_hstu_attention.sh index 26bf65c40a..7f1fd3f4ca 100644 --- a/example/ck_tile/18_hstu_attention/scripts/test_hstu_attention.sh +++ b/example/ck_tile/18_hstu_attention/scripts/test_hstu_attention.sh @@ -3,41 +3,46 @@ BUILD=build EXE=$BUILD/bin/tile_example_hstu_attention +attn_scale=0 +if [ $# -ge 1 ]; then + attn_scale=$1 +fi + for dtype in "fp16" "bf16"; do set -x ## no masking batched - $EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=256 -causal=0 -local_len=0 -context_len=0 -minfull_len=0 -targets=0 + $EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=256 -causal=0 -local_len=0 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=$attn_scale ## no masking jagged - $EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=0 -local_len=0 -context_len=0 -minfull_len=0 -targets=0 + $EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=0 -local_len=0 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=$attn_scale ## batched causal - $EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=256 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=0 + $EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=256 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=$attn_scale ## jagged causal - $EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=0 + $EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=$attn_scale ## batched causal+local - $EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=256 -causal=1 -local_len=5 -context_len=0 -minfull_len=0 -targets=0 + $EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=256 -causal=1 -local_len=5 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=$attn_scale ## jagged causal+local - $EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=0 -minfull_len=0 -targets=0 + $EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=$attn_scale ## batched causal+local+context - $EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=256 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=0 + $EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=256 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=0 -attn_scale=$attn_scale ## jagged causal+local+context - $EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=0 + $EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=0 -attn_scale=$attn_scale ## batched causal+local+context+target - $EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=256 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 + $EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=256 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 -attn_scale=$attn_scale ## jagged causal+local+context+target - $EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 + $EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 -attn_scale=$attn_scale ## jagged no-causal+local+context+target - $EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=0 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 + $EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=0 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 -attn_scale=$attn_scale set +x done