Add attn_scale MakeKargs() parameter support and update in example, reference codes

This commit is contained in:
Qianfeng Zhang
2025-08-03 03:33:08 +00:00
parent 7c9032d2cf
commit f27d8cefb7
8 changed files with 44 additions and 21 deletions

View File

@@ -54,7 +54,8 @@
.insert("minfull_len", "6", "sequence length at the end of the query sequence that should be included for attention")
.insert("init_qkv", "0", "initialize q, k, v tensor from local files q.dat, k.dat and v.dat")
.insert("seed", "13579", "seed by the uniform or normal distribution generator")
.insert("alpha", "0", "scale factor of P=Q@K. 0 means equal to 1/sqrt(hdim)")
.insert("alpha", "0", "scale factor of S=Q@K. 0 means equal to 1/sqrt(hdim)")
.insert("attn_scale", "0", "scale factor of SiLU(Q@K). 0 means using 1/max_seqlen for scaling")
.insert("save_mask", "1", "save the mask tensor to disk by the CPU validation codes")
.insert("perf", "0", "whether to measure execution time or not");
.insert("dump_output", "0", "dump both device and reference hstu attention outputs to files, only used when validation is true");

View File

@@ -108,7 +108,8 @@ auto create_args(int argc, char* argv[])
.insert("context_len", "6", "sequence length at the beginning of the query sequence that should be included for attention")
.insert("minfull_len", "6", "sequence length at the end of the query sequence that should be included for attention")
.insert("seed", "13579", "seed by the uniform or normal distribution generator")
.insert("alpha", "0", "scale factor of P=Q@K. 0 means equal to 1/sqrt(hdim)")
.insert("alpha", "0", "scale factor of S=Q@K. 0 means equal to 1/sqrt(hdim)")
.insert("attn_scale", "0", "scale factor of SiLU(Q@K). 0 means using 1/max_seqlen for scaling")
.insert("init_qkv", "0", "initialize q, k, v tensor from local files q.dat, k.dat and v.dat")
.insert("save_mask", "0", "save the mask tensor to disk by the CPU validation codes")
.insert("perf", "0", "whether to measure execution time or not")
@@ -223,6 +224,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
int min_full_attn_seqlen = arg_parser.get_int("minfull_len");
float alpha = arg_parser.get_float("alpha");
float attn_scale = arg_parser.get_float("attn_scale");
int seed = arg_parser.get_int("seed");
bool measure_perf = static_cast<bool>(arg_parser.get_int("perf"));
bool dump_output = static_cast<bool>(arg_parser.get_int("dump_output"));
@@ -412,6 +414,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
params.hdim_v = hdim_v;
params.num_head = num_head;
params.scale_s = scale_s;
params.attn_scale = attn_scale;
params.seq_stride_q = q_host.get_strides()[1];
params.seq_stride_k = k_host.get_strides()[1];
params.seq_stride_v = v_host.get_strides()[1];
@@ -445,6 +448,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
params.hdim_v = hdim_v;
params.num_head = num_head;
params.scale_s = scale_s;
params.attn_scale = attn_scale;
params.seq_stride_q = q_host.get_strides()[1];
params.seq_stride_k = k_host.get_strides()[1];
params.seq_stride_v = v_host.get_strides()[1];
@@ -514,6 +518,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
mask_host,
num_batch,
scale_s,
attn_scale,
max_seqlen,
seq_offsets,
num_targets,

View File

@@ -104,6 +104,7 @@ struct batched_forward_causal_local_bias_dropout_dispatch
param.hdim_v,
param.num_head,
param.scale_s,
param.attn_scale,
param.seq_stride_q,
param.seq_stride_k,
param.seq_stride_v,

View File

@@ -212,6 +212,7 @@ struct HstuAttentionFwdKernel
ck_tile::index_t hdim_v,
ck_tile::index_t num_head,
float scale_s,
float attn_scale,
ck_tile::index_t seq_stride_q,
ck_tile::index_t seq_stride_k,
ck_tile::index_t seq_stride_v,
@@ -257,11 +258,11 @@ struct HstuAttentionFwdKernel
seq_stride_o,
num_head,
-scale_s,
1.0f / static_cast<float>(seqlen), // max_seqlen
contextual_seqlen}, // args for common karg
{}, // placeholder for mask
{}, // placeholder for bias
{}, // placeholder for dropout
attn_scale ? attn_scale : 1.0f / static_cast<float>(seqlen), // max_seqlen
contextual_seqlen}, // args for common karg
{}, // placeholder for mask
{}, // placeholder for bias
{}, // placeholder for dropout
};
if constexpr(kHasLocalMask)
@@ -298,6 +299,7 @@ struct HstuAttentionFwdKernel
ck_tile::index_t hdim_v,
ck_tile::index_t num_head,
float scale_s,
float attn_scale,
ck_tile::index_t seq_stride_q,
ck_tile::index_t seq_stride_k,
ck_tile::index_t seq_stride_v,
@@ -331,6 +333,7 @@ struct HstuAttentionFwdKernel
hdim_v,
num_head,
scale_s,
attn_scale,
seq_stride_q,
seq_stride_k,
seq_stride_v,
@@ -367,6 +370,7 @@ struct HstuAttentionFwdKernel
ck_tile::index_t hdim_v,
ck_tile::index_t num_head,
float scale_s,
float attn_scale,
ck_tile::index_t seq_stride_q,
ck_tile::index_t seq_stride_k,
ck_tile::index_t seq_stride_v,
@@ -404,7 +408,7 @@ struct HstuAttentionFwdKernel
-1, // seqlen will be updated by another pointer
num_head,
-scale_s,
1.0f / static_cast<float>(max_seqlen),
attn_scale ? attn_scale : 1.0f / static_cast<float>(max_seqlen),
contextual_seqlen}, // args for common karg
{}, // placeholder for mask
{}, // placeholder for bias
@@ -445,6 +449,7 @@ struct HstuAttentionFwdKernel
ck_tile::index_t hdim_v,
ck_tile::index_t num_head,
float scale_s,
float attn_scale,
ck_tile::index_t seq_stride_q,
ck_tile::index_t seq_stride_k,
ck_tile::index_t seq_stride_v,
@@ -474,6 +479,7 @@ struct HstuAttentionFwdKernel
hdim_v,
num_head,
scale_s,
attn_scale,
seq_stride_q,
seq_stride_k,
seq_stride_v,

View File

@@ -97,6 +97,7 @@ struct jagged_forward_causal_local_bias_dropout_dispatch
param.hdim_v,
param.num_head,
param.scale_s,
param.attn_scale,
param.seq_stride_q,
param.seq_stride_k,
param.seq_stride_v,

View File

@@ -23,7 +23,8 @@ struct HstuAttentionFwdParams
ck_tile::index_t hdim_qk;
ck_tile::index_t hdim_v;
ck_tile::index_t num_head;
float scale_s;
float scale_s; // scaling factor exerted on the immediate Q@K result
float attn_scale; // scaling factor exerted on the SiLU result
ck_tile::index_t seq_stride_q;
ck_tile::index_t seq_stride_k;

View File

@@ -43,6 +43,7 @@ struct reference_hstu_attention
HostTensor<int8_t>& mask_batch_nhead_seq_seq,
int num_batch,
float alpha,
float attn_scale,
int max_seqlen,
std::vector<int> seq_offsets,
std::vector<int> num_targets, // define masking length at the end of token
@@ -177,6 +178,8 @@ struct reference_hstu_attention
for(CompDataType& elem : locals)
elem = silu(elem);
float scale_p = attn_scale ? attn_scale : 1.0f / static_cast<float>(max_seqlen);
// second Gemm
for(int k = 0; k < hdim_v; k++)
{
@@ -203,7 +206,7 @@ struct reference_hstu_attention
};
};
dot_prod = dot_prod / ck_tile::type_convert<GemmAccDataType>(max_seqlen);
dot_prod = dot_prod * ck_tile::type_convert<GemmAccDataType>(scale_p);
if constexpr(kIsJagged)
o_batch_seq_nhead_hdim(0, seq_offsets[i_batch] + sq, i_head, k) =

View File

@@ -3,41 +3,46 @@
BUILD=build
EXE=$BUILD/bin/tile_example_hstu_attention
attn_scale=0
if [ $# -ge 1 ]; then
attn_scale=$1
fi
for dtype in "fp16" "bf16"; do
set -x
## no masking batched
$EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=256 -causal=0 -local_len=0 -context_len=0 -minfull_len=0 -targets=0
$EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=256 -causal=0 -local_len=0 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=$attn_scale
## no masking jagged
$EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=0 -local_len=0 -context_len=0 -minfull_len=0 -targets=0
$EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=0 -local_len=0 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=$attn_scale
## batched causal
$EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=256 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=0
$EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=256 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=$attn_scale
## jagged causal
$EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=0
$EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=$attn_scale
## batched causal+local
$EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=256 -causal=1 -local_len=5 -context_len=0 -minfull_len=0 -targets=0
$EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=256 -causal=1 -local_len=5 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=$attn_scale
## jagged causal+local
$EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=0 -minfull_len=0 -targets=0
$EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=$attn_scale
## batched causal+local+context
$EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=256 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=0
$EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=256 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=0 -attn_scale=$attn_scale
## jagged causal+local+context
$EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=0
$EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=0 -attn_scale=$attn_scale
## batched causal+local+context+target
$EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=256 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8
$EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=256 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 -attn_scale=$attn_scale
## jagged causal+local+context+target
$EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8
$EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 -attn_scale=$attn_scale
## jagged no-causal+local+context+target
$EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=0 -local_len=5 -context_len=8 -minfull_len=7 -targets=8
$EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=0 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 -attn_scale=$attn_scale
set +x
done