diff --git a/example/ck_tile/18_hstu_attention/README.md b/example/ck_tile/18_hstu_attention/README.md index bade486980..6e1f978657 100644 --- a/example/ck_tile/18_hstu_attention/README.md +++ b/example/ck_tile/18_hstu_attention/README.md @@ -54,6 +54,7 @@ .insert("minfull_len", "6", "sequence length at the end of the query sequence that should be included for attention") .insert("init_qkv", "0", "initialize q, k, v tensor from local files q.dat, k.dat and v.data") .insert("seed", "13579", "seed by the uniform or normal distribution generator") + .insert("alpha", "0", "scale factor of P=Q@K. 0 means equal to 1/sqrt(hdim)") .insert("save_mask", "1", "save the mask tensor to disk by the CPU validation codes") .insert("perf", "0", "weather measure execution time or not"); .insert("dump_output", "0", "dump both device and reference hstu attention outputs to files, only used when validation is true"); diff --git a/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp b/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp index 48fcba02ff..69e9530bc8 100644 --- a/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp +++ b/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp @@ -108,6 +108,7 @@ auto create_args(int argc, char* argv[]) .insert("context_len", "6", "sequence length at the begin of the query sequence the should be included for attention") .insert("minfull_len", "6", "sequence length at the end of the query sequence that should be included for attention") .insert("seed", "13579", "seed by the uniform or normal distribution generator") + .insert("alpha", "0", "scale factor of P=Q@K. 0 means equal to 1/sqrt(hdim)") .insert("init_qkv", "0", "initialize q, k, v tensor from local files q.dat, k.dat and v.data") .insert("save_mask", "0", "save the mask tensor to disk by the CPU validation codes") .insert("perf", "0", "weather measure execution time or not") @@ -221,6 +222,7 @@ bool run(const ck_tile::ArgParser& arg_parser) int contextual_seqlen = arg_parser.get_int("context_len"); int min_full_attn_seqlen = arg_parser.get_int("minfull_len"); + float alpha = arg_parser.get_float("alpha"); int seed = arg_parser.get_int("seed"); bool measure_perf = static_cast(arg_parser.get_int("perf")); bool dump_output = static_cast(arg_parser.get_int("dump_output")); @@ -391,7 +393,7 @@ bool run(const ck_tile::ArgParser& arg_parser) HstuAttentionFwdParams params; - float scale_s = 1.0f / std::sqrt(hdim_qk); + float scale_s = (alpha != 0.f) ? alpha : 1.0f / std::sqrt(hdim_qk); if(is_jagged) {