Add an example parameter `alpha` to ease testing

This commit is contained in:
Qianfeng Zhang
2025-05-30 08:47:55 +00:00
parent 781cba355a
commit 832747c58d
2 changed files with 4 additions and 1 deletions

View File

@@ -54,6 +54,7 @@
.insert("minfull_len", "6", "sequence length at the end of the query sequence that should be included for attention")
.insert("init_qkv", "0", "initialize q, k, v tensor from local files q.dat, k.dat and v.data")
.insert("seed", "13579", "seed by the uniform or normal distribution generator")
+.insert("alpha", "0", "scale factor of P=Q@K. 0 means equal to 1/sqrt(hdim)")
.insert("save_mask", "1", "save the mask tensor to disk by the CPU validation codes")
.insert("perf", "0", "weather measure execution time or not");
.insert("dump_output", "0", "dump both device and reference hstu attention outputs to files, only used when validation is true");

View File

@@ -108,6 +108,7 @@ auto create_args(int argc, char* argv[])
.insert("context_len", "6", "sequence length at the begin of the query sequence the should be included for attention")
.insert("minfull_len", "6", "sequence length at the end of the query sequence that should be included for attention")
.insert("seed", "13579", "seed by the uniform or normal distribution generator")
+.insert("alpha", "0", "scale factor of P=Q@K. 0 means equal to 1/sqrt(hdim)")
.insert("init_qkv", "0", "initialize q, k, v tensor from local files q.dat, k.dat and v.data")
.insert("save_mask", "0", "save the mask tensor to disk by the CPU validation codes")
.insert("perf", "0", "weather measure execution time or not")
@@ -221,6 +222,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
int contextual_seqlen = arg_parser.get_int("context_len");
int min_full_attn_seqlen = arg_parser.get_int("minfull_len");
+float alpha = arg_parser.get_float("alpha");
int seed = arg_parser.get_int("seed");
bool measure_perf = static_cast<bool>(arg_parser.get_int("perf"));
bool dump_output = static_cast<bool>(arg_parser.get_int("dump_output"));
@@ -391,7 +393,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
HstuAttentionFwdParams params;
-float scale_s = 1.0f / std::sqrt(hdim_qk);
+float scale_s = (alpha != 0.f) ? alpha : 1.0f / std::sqrt(hdim_qk);
if(is_jagged)
{