mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 03:19:48 +00:00
Add an example parameter, alpha, to ease testing
This commit is contained in:
@@ -54,6 +54,7 @@
|
||||
.insert("minfull_len", "6", "sequence length at the end of the query sequence that should be included for attention")
|
||||
.insert("init_qkv", "0", "initialize q, k, v tensor from local files q.dat, k.dat and v.data")
|
||||
.insert("seed", "13579", "seed by the uniform or normal distribution generator")
|
||||
.insert("alpha", "0", "scale factor of P=Q@K. 0 means equal to 1/sqrt(hdim)")
|
||||
.insert("save_mask", "1", "save the mask tensor to disk by the CPU validation codes")
|
||||
.insert("perf", "0", "weather measure execution time or not");
|
||||
.insert("dump_output", "0", "dump both device and reference hstu attention outputs to files, only used when validation is true");
|
||||
|
||||
@@ -108,6 +108,7 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("context_len", "6", "sequence length at the begin of the query sequence the should be included for attention")
|
||||
.insert("minfull_len", "6", "sequence length at the end of the query sequence that should be included for attention")
|
||||
.insert("seed", "13579", "seed by the uniform or normal distribution generator")
|
||||
.insert("alpha", "0", "scale factor of P=Q@K. 0 means equal to 1/sqrt(hdim)")
|
||||
.insert("init_qkv", "0", "initialize q, k, v tensor from local files q.dat, k.dat and v.data")
|
||||
.insert("save_mask", "0", "save the mask tensor to disk by the CPU validation codes")
|
||||
.insert("perf", "0", "weather measure execution time or not")
|
||||
@@ -221,6 +222,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
int contextual_seqlen = arg_parser.get_int("context_len");
|
||||
int min_full_attn_seqlen = arg_parser.get_int("minfull_len");
|
||||
|
||||
float alpha = arg_parser.get_float("alpha");
|
||||
int seed = arg_parser.get_int("seed");
|
||||
bool measure_perf = static_cast<bool>(arg_parser.get_int("perf"));
|
||||
bool dump_output = static_cast<bool>(arg_parser.get_int("dump_output"));
|
||||
@@ -391,7 +393,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
HstuAttentionFwdParams params;
|
||||
|
||||
float scale_s = 1.0f / std::sqrt(hdim_qk);
|
||||
float scale_s = (alpha != 0.f) ? alpha : 1.0f / std::sqrt(hdim_qk);
|
||||
|
||||
if(is_jagged)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user