mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 10:37:44 +00:00
Add attn_scale parameter support to MakeKargs() and update the example and reference code accordingly
This commit is contained in:
@@ -54,7 +54,8 @@
|
||||
.insert("minfull_len", "6", "sequence length at the end of the query sequence that should be included for attention")
|
||||
.insert("init_qkv", "0", "initialize q, k, v tensor from local files q.dat, k.dat and v.data")
|
||||
.insert("seed", "13579", "seed by the uniform or normal distribution generator")
|
||||
.insert("alpha", "0", "scale factor of P=Q@K. 0 means equal to 1/sqrt(hdim)")
|
||||
.insert("alpha", "0", "scale factor of S=Q@K. 0 means equal to 1/sqrt(hdim)")
|
||||
.insert("attn_scale", "0", "scale factor of SiLu(Q@K), 0 means using 1/max_seqlen for scaling")
|
||||
.insert("save_mask", "1", "save the mask tensor to disk by the CPU validation codes")
|
||||
.insert("perf", "0", "weather measure execution time or not");
|
||||
.insert("dump_output", "0", "dump both device and reference hstu attention outputs to files, only used when validation is true");
|
||||
|
||||
@@ -108,7 +108,8 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("context_len", "6", "sequence length at the begin of the query sequence the should be included for attention")
|
||||
.insert("minfull_len", "6", "sequence length at the end of the query sequence that should be included for attention")
|
||||
.insert("seed", "13579", "seed by the uniform or normal distribution generator")
|
||||
.insert("alpha", "0", "scale factor of P=Q@K. 0 means equal to 1/sqrt(hdim)")
|
||||
.insert("alpha", "0", "scale factor of S=Q@K. 0 means equal to 1/sqrt(hdim)")
|
||||
.insert("attn_scale", "0", "scale factor of SiLU(Q@K). 0 means using 1/max_seqlen for scaling")
|
||||
.insert("init_qkv", "0", "initialize q, k, v tensor from local files q.dat, k.dat and v.data")
|
||||
.insert("save_mask", "0", "save the mask tensor to disk by the CPU validation codes")
|
||||
.insert("perf", "0", "weather measure execution time or not")
|
||||
@@ -223,6 +224,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
int min_full_attn_seqlen = arg_parser.get_int("minfull_len");
|
||||
|
||||
float alpha = arg_parser.get_float("alpha");
|
||||
float attn_scale = arg_parser.get_float("attn_scale");
|
||||
int seed = arg_parser.get_int("seed");
|
||||
bool measure_perf = static_cast<bool>(arg_parser.get_int("perf"));
|
||||
bool dump_output = static_cast<bool>(arg_parser.get_int("dump_output"));
|
||||
@@ -412,6 +414,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
params.hdim_v = hdim_v;
|
||||
params.num_head = num_head;
|
||||
params.scale_s = scale_s;
|
||||
params.attn_scale = attn_scale;
|
||||
params.seq_stride_q = q_host.get_strides()[1];
|
||||
params.seq_stride_k = k_host.get_strides()[1];
|
||||
params.seq_stride_v = v_host.get_strides()[1];
|
||||
@@ -445,6 +448,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
params.hdim_v = hdim_v;
|
||||
params.num_head = num_head;
|
||||
params.scale_s = scale_s;
|
||||
params.attn_scale = attn_scale;
|
||||
params.seq_stride_q = q_host.get_strides()[1];
|
||||
params.seq_stride_k = k_host.get_strides()[1];
|
||||
params.seq_stride_v = v_host.get_strides()[1];
|
||||
@@ -514,6 +518,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
mask_host,
|
||||
num_batch,
|
||||
scale_s,
|
||||
attn_scale,
|
||||
max_seqlen,
|
||||
seq_offsets,
|
||||
num_targets,
|
||||
|
||||
@@ -104,6 +104,7 @@ struct batched_forward_causal_local_bias_dropout_dispatch
|
||||
param.hdim_v,
|
||||
param.num_head,
|
||||
param.scale_s,
|
||||
param.attn_scale,
|
||||
param.seq_stride_q,
|
||||
param.seq_stride_k,
|
||||
param.seq_stride_v,
|
||||
|
||||
@@ -212,6 +212,7 @@ struct HstuAttentionFwdKernel
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head,
|
||||
float scale_s,
|
||||
float attn_scale,
|
||||
ck_tile::index_t seq_stride_q,
|
||||
ck_tile::index_t seq_stride_k,
|
||||
ck_tile::index_t seq_stride_v,
|
||||
@@ -257,11 +258,11 @@ struct HstuAttentionFwdKernel
|
||||
seq_stride_o,
|
||||
num_head,
|
||||
-scale_s,
|
||||
1.0f / static_cast<float>(seqlen), // max_seqlen
|
||||
contextual_seqlen}, // args for common karg
|
||||
{}, // placeholder for mask
|
||||
{}, // placeholder for bias
|
||||
{}, // placeholder for dropout
|
||||
attn_scale ? attn_scale : 1.0f / static_cast<float>(seqlen), // max_seqlen
|
||||
contextual_seqlen}, // args for common karg
|
||||
{}, // placeholder for mask
|
||||
{}, // placeholder for bias
|
||||
{}, // placeholder for dropout
|
||||
};
|
||||
|
||||
if constexpr(kHasLocalMask)
|
||||
@@ -298,6 +299,7 @@ struct HstuAttentionFwdKernel
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head,
|
||||
float scale_s,
|
||||
float attn_scale,
|
||||
ck_tile::index_t seq_stride_q,
|
||||
ck_tile::index_t seq_stride_k,
|
||||
ck_tile::index_t seq_stride_v,
|
||||
@@ -331,6 +333,7 @@ struct HstuAttentionFwdKernel
|
||||
hdim_v,
|
||||
num_head,
|
||||
scale_s,
|
||||
attn_scale,
|
||||
seq_stride_q,
|
||||
seq_stride_k,
|
||||
seq_stride_v,
|
||||
@@ -367,6 +370,7 @@ struct HstuAttentionFwdKernel
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head,
|
||||
float scale_s,
|
||||
float attn_scale,
|
||||
ck_tile::index_t seq_stride_q,
|
||||
ck_tile::index_t seq_stride_k,
|
||||
ck_tile::index_t seq_stride_v,
|
||||
@@ -404,7 +408,7 @@ struct HstuAttentionFwdKernel
|
||||
-1, // seqlen will be updated by another pointer
|
||||
num_head,
|
||||
-scale_s,
|
||||
1.0f / static_cast<float>(max_seqlen),
|
||||
attn_scale ? attn_scale : 1.0f / static_cast<float>(max_seqlen),
|
||||
contextual_seqlen}, // args for common karg
|
||||
{}, // placeholder for mask
|
||||
{}, // placeholder for bias
|
||||
@@ -445,6 +449,7 @@ struct HstuAttentionFwdKernel
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head,
|
||||
float scale_s,
|
||||
float attn_scale,
|
||||
ck_tile::index_t seq_stride_q,
|
||||
ck_tile::index_t seq_stride_k,
|
||||
ck_tile::index_t seq_stride_v,
|
||||
@@ -474,6 +479,7 @@ struct HstuAttentionFwdKernel
|
||||
hdim_v,
|
||||
num_head,
|
||||
scale_s,
|
||||
attn_scale,
|
||||
seq_stride_q,
|
||||
seq_stride_k,
|
||||
seq_stride_v,
|
||||
|
||||
@@ -97,6 +97,7 @@ struct jagged_forward_causal_local_bias_dropout_dispatch
|
||||
param.hdim_v,
|
||||
param.num_head,
|
||||
param.scale_s,
|
||||
param.attn_scale,
|
||||
param.seq_stride_q,
|
||||
param.seq_stride_k,
|
||||
param.seq_stride_v,
|
||||
|
||||
@@ -23,7 +23,8 @@ struct HstuAttentionFwdParams
|
||||
ck_tile::index_t hdim_qk;
|
||||
ck_tile::index_t hdim_v;
|
||||
ck_tile::index_t num_head;
|
||||
float scale_s;
|
||||
float scale_s; // scaling factor exerted on the immediate Q@K result
|
||||
float attn_scale; // scaling factor exerted on the SiLU result
|
||||
|
||||
ck_tile::index_t seq_stride_q;
|
||||
ck_tile::index_t seq_stride_k;
|
||||
|
||||
@@ -43,6 +43,7 @@ struct reference_hstu_attention
|
||||
HostTensor<int8_t>& mask_batch_nhead_seq_seq,
|
||||
int num_batch,
|
||||
float alpha,
|
||||
float attn_scale,
|
||||
int max_seqlen,
|
||||
std::vector<int> seq_offsets,
|
||||
std::vector<int> num_targets, // define masking length at the end of token
|
||||
@@ -177,6 +178,8 @@ struct reference_hstu_attention
|
||||
for(CompDataType& elem : locals)
|
||||
elem = silu(elem);
|
||||
|
||||
float scale_p = attn_scale ? attn_scale : 1.0f / static_cast<float>(max_seqlen);
|
||||
|
||||
// second Gemm
|
||||
for(int k = 0; k < hdim_v; k++)
|
||||
{
|
||||
@@ -203,7 +206,7 @@ struct reference_hstu_attention
|
||||
};
|
||||
};
|
||||
|
||||
dot_prod = dot_prod / ck_tile::type_convert<GemmAccDataType>(max_seqlen);
|
||||
dot_prod = dot_prod * ck_tile::type_convert<GemmAccDataType>(scale_p);
|
||||
|
||||
if constexpr(kIsJagged)
|
||||
o_batch_seq_nhead_hdim(0, seq_offsets[i_batch] + sq, i_head, k) =
|
||||
|
||||
@@ -3,41 +3,46 @@
|
||||
BUILD=build
|
||||
EXE=$BUILD/bin/tile_example_hstu_attention
|
||||
|
||||
attn_scale=0
|
||||
if [ $# -ge 1 ]; then
|
||||
attn_scale=$1
|
||||
fi
|
||||
|
||||
for dtype in "fp16" "bf16"; do
|
||||
set -x
|
||||
|
||||
## no masking batched
|
||||
$EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=256 -causal=0 -local_len=0 -context_len=0 -minfull_len=0 -targets=0
|
||||
$EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=256 -causal=0 -local_len=0 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=$attn_scale
|
||||
|
||||
## no masking jagged
|
||||
$EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=0 -local_len=0 -context_len=0 -minfull_len=0 -targets=0
|
||||
$EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=0 -local_len=0 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=$attn_scale
|
||||
|
||||
## batched causal
|
||||
$EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=256 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=0
|
||||
$EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=256 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=$attn_scale
|
||||
|
||||
## jagged causal
|
||||
$EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=0
|
||||
$EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=$attn_scale
|
||||
|
||||
## batched causal+local
|
||||
$EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=256 -causal=1 -local_len=5 -context_len=0 -minfull_len=0 -targets=0
|
||||
$EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=256 -causal=1 -local_len=5 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=$attn_scale
|
||||
|
||||
## jagged causal+local
|
||||
$EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=0 -minfull_len=0 -targets=0
|
||||
$EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=$attn_scale
|
||||
|
||||
## batched causal+local+context
|
||||
$EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=256 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=0
|
||||
$EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=256 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=0 -attn_scale=$attn_scale
|
||||
|
||||
## jagged causal+local+context
|
||||
$EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=0
|
||||
$EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=0 -attn_scale=$attn_scale
|
||||
|
||||
## batched causal+local+context+target
|
||||
$EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=256 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8
|
||||
$EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=256 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 -attn_scale=$attn_scale
|
||||
|
||||
## jagged causal+local+context+target
|
||||
$EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8
|
||||
$EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 -attn_scale=$attn_scale
|
||||
|
||||
## jagged no-causal+local+context+target
|
||||
$EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=0 -local_len=5 -context_len=8 -minfull_len=7 -targets=8
|
||||
$EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=0 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 -attn_scale=$attn_scale
|
||||
|
||||
set +x
|
||||
done
|
||||
|
||||
Reference in New Issue
Block a user