Add support for softmax in hstu attention

This commit is contained in:
Qianfeng Zhang
2025-10-16 16:02:45 +00:00
parent a874839dc2
commit d1505786f8
11 changed files with 311 additions and 52 deletions

View File

@@ -48,6 +48,7 @@
.insert("max_seqlen", "0", "max uih_seqlen, can be ignored, or else must be equal or bigger than the maximum of all uih seqlens")
.insert("targets", "16", "sequence length at the end of query/key token sequence that should be excluded from attention")
.insert("max_target", "0", "max target, can be ignored, or else must be equal or bigger than the maximum of all targets")
.insert("softmax", "0", "use softmax or not")
.insert("causal", "1", "enable causal mask or not")
.insert("local_len", "5", "length of the diagonal window for enabling masking, value 0 to disable")
.insert("context_len", "6", "sequence length at the beginning of the query sequence that should be included for attention")

View File

@@ -104,6 +104,7 @@ auto create_args(int argc, char* argv[])
.insert("max_seqlen", "0", "max uih_seqlen, can be ignored, or else must be equal or bigger than the maximum of all uih seqlens")
.insert("targets", "", "sequence length at the end of query/key token sequence that should be excluded from attention")
.insert("max_target", "0", "max target, can be ignored, or else must be equal or bigger than the maximum of all targets")
.insert("softmax", "0", "use softmax or not")
.insert("causal", "1", "enable causal mask or not")
.insert("local_len", "5", "length of the diagonal window for enabling masking, value 0 to disable")
.insert("context_len", "6", "sequence length at the beginning of the query sequence that should be included for attention")
@@ -216,6 +217,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
int num_head = arg_parser.get_int("nhead");
int hdim_qk = arg_parser.get_int("hdim_qk");
int hdim_v = arg_parser.get_int("hdim_v");
bool use_softmax = static_cast<bool>(arg_parser.get_int("softmax"));
bool use_causal = static_cast<bool>(arg_parser.get_int("causal"));
int window_size = arg_parser.get_int("local_len");
@@ -434,6 +436,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
params.nhead_stride_bias = 0;
params.nhead_stride_o = o_host_ref.get_strides()[2];
params.num_targets_ptr = num_targets.empty() ? nullptr : num_targets_dev.GetDeviceBuffer();
params.use_softmax = use_softmax;
params.use_causal = use_causal;
params.window_size = window_size;
params.contextual_seqlen = contextual_seqlen;
@@ -473,6 +476,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
params.batch_stride_bias = 0;
params.batch_stride_o = o_host_ref.get_strides()[0];
params.num_targets_ptr = num_targets.empty() ? nullptr : num_targets_dev.GetDeviceBuffer();
params.use_softmax = use_softmax;
params.use_causal = use_causal;
params.window_size = window_size;
params.contextual_seqlen = contextual_seqlen;
@@ -513,11 +517,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
using GemmAccDataType = typename HstuAttentionFwdTypeConfig<InOutDataType>::GemmAccDataType;
using CompDataType = typename HstuAttentionFwdTypeConfig<InOutDataType>::CompDataType;
BOOL_SWITCH_2(is_jagged, kIsJagged, use_causal, kUseCausal, [&] {
BOOL_SWITCH_3(is_jagged, kIsJagged, use_softmax, kUseSoftmax, use_causal, kUseCausal, [&] {
ck_tile::reference_hstu_attention<InOutDataType,
GemmAccDataType,
CompDataType,
kIsJagged,
kUseSoftmax,
kUseCausal>::Run(q_host,
k_host,
v_host,

View File

@@ -17,12 +17,14 @@ void hstu_attention_batched_forward_bf16(HstuAttentionFwdParams& param, hipStrea
const bool use_causal = param.use_causal;
BOOL_SWITCH_3(has_bias, kHasBias, has_dropout, kHasDropout, use_causal, kUseCausal, [&] {
HDIM_SWITCH(param.hdim_qk, param.hdim_v, MaxK, [&] {
run_batched_forward_causal_softmax_bias_dropout_dispatch<ck_tile::bf16_t,
kUseCausal,
false, // using softmax
kHasBias,
kHasDropout,
MaxK>(param, stream);
BOOL_SWITCH(param.use_softmax, kUseSoftmax, [&] {
run_batched_forward_causal_softmax_bias_dropout_dispatch<ck_tile::bf16_t,
kUseCausal,
kUseSoftmax,
kHasBias,
kHasDropout,
MaxK>(param, stream);
});
});
});
};

View File

@@ -17,12 +17,14 @@ void hstu_attention_batched_forward_fp16(HstuAttentionFwdParams& param, hipStrea
const bool use_causal = param.use_causal;
BOOL_SWITCH_3(has_bias, kHasBias, has_dropout, kHasDropout, use_causal, kUseCausal, [&] {
HDIM_SWITCH(param.hdim_qk, param.hdim_v, MaxK, [&] {
run_batched_forward_causal_softmax_bias_dropout_dispatch<ck_tile::fp16_t,
kUseCausal,
false, // using softmax
kHasBias,
kHasDropout,
MaxK>(param, stream);
BOOL_SWITCH(param.use_softmax, kUseSoftmax, [&] {
run_batched_forward_causal_softmax_bias_dropout_dispatch<ck_tile::fp16_t,
kUseCausal,
kUseSoftmax,
kHasBias,
kHasDropout,
MaxK>(param, stream);
});
});
});
};

View File

@@ -245,7 +245,7 @@ struct HstuAttentionFwdKernel
seq_stride_v,
seq_stride_o,
num_head,
-scale_s,
scale_s,
attn_scale ? attn_scale : 1.0f / static_cast<float>(seqlen), // max_seqlen
contextual_seqlen,
window_size,
@@ -320,7 +320,7 @@ struct HstuAttentionFwdKernel
hdim_v,
-1, // seqlen will be updated by another pointer
num_head,
-scale_s,
scale_s,
attn_scale ? attn_scale : 1.0f / static_cast<float>(max_seqlen),
contextual_seqlen,
window_size,

View File

@@ -156,6 +156,8 @@ struct HstuAttentionFwdPipelineQRKSVS
kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"wrong!");
constexpr bool kUseSoftmax = Problem::kUseSoftmax;
constexpr index_t k1_loops = kN0 / kK1;
constexpr auto NumKVLdsBuffers = Policy::template GetNumKVLdsBuffers<Problem>();
@@ -173,6 +175,16 @@ struct HstuAttentionFwdPipelineQRKSVS
SaccBlockTileType sacc_tile;
PcompBlockTileType pcomp_tile;
// reduction function for softmax
const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
using MLBlockTileType = decltype(block_tile_reduce<CompDataType>(
PcompBlockTileType{}, sequence<1>{}, f_max, CompDataType{0}));
auto m = MLBlockTileType{};
auto l = MLBlockTileType{};
using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile());
OaccBlockTileType o_acc;
@@ -279,15 +291,26 @@ struct HstuAttentionFwdPipelineQRKSVS
// element-wise activation helpers (SiLU and exp) used below
const auto f_silu = [&](CompDataType& x) {
const auto neg_one = ck_tile::type_convert<CompDataType>(-1.0f);
const auto one = ck_tile::type_convert<CompDataType>(1.0f);
if constexpr(std::is_same_v<CompDataType, float>)
{
x = x * __builtin_amdgcn_rcpf(neg_one - __expf(x));
x = x * __builtin_amdgcn_rcpf(one + __expf(-x));
}
else
{
x = x / (neg_one - exp(x));
x = x / (one + exp(-x));
}
};
const auto f_exp = [&](CompDataType x) {
if constexpr(std::is_same_v<CompDataType, float>)
{
return __expf(x);
}
else
{
return exp(x);
}
};
@@ -353,6 +376,12 @@ struct HstuAttentionFwdPipelineQRKSVS
});
clear_tile(o_acc);
if constexpr(kUseSoftmax)
{
set_tile(m, -numeric<CompDataType>::infinity());
clear_tile(l);
};
};
q_tile = tile_elementwise_in(q_element_func, q_tile);
@@ -409,41 +438,141 @@ struct HstuAttentionFwdPipelineQRKSVS
tile_elementwise_inout(
[&scale_s, &bias_element_func](auto& x, const auto& y) {
x = x * scale_s - type_convert<CompDataType>(bias_element_func(y));
x = x * scale_s + type_convert<CompDataType>(bias_element_func(y));
},
pcomp_tile,
bias_tile);
move_tile_window(bias_dram_window, {0, kK1});
move_tile_window(bias_dram_window, {0, kN0});
}
else
{
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, pcomp_tile);
}
if(!mask.IsFullTileInsideMask(
q_origin.at(number<0>{}), seqlen_k_curr, number<kN0>{}, number<kM0>{}))
if constexpr(!kUseSoftmax)
{
constexpr auto p_spans = PcompBlockTileType::get_distributed_spans();
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices(
pcomp_tile.get_tile_distribution(), make_tuple(idx0, idx1));
if(!mask.IsFullTileInsideMask(
q_origin.at(number<0>{}), seqlen_k_curr, number<kN0>{}, number<kM0>{}))
{
constexpr auto p_spans = PcompBlockTileType::get_distributed_spans();
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices(
pcomp_tile.get_tile_distribution(), make_tuple(idx0, idx1));
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = seqlen_k_curr + tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1);
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = seqlen_k_curr + tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1);
if(!mask.IsTokenPairInsideMask(row, col))
pcomp_tile(i_j_idx) = static_cast<CompDataType>(0.0f);
if(!mask.IsTokenPairInsideMask(row, col))
{
pcomp_tile(i_j_idx) = type_convert<CompDataType>(0.0f);
};
});
});
});
}
tile_elementwise_inout(f_silu, pcomp_tile);
tile_elementwise_inout(
[&](auto& x) { x = x * type_convert<CompDataType>(scale_p); }, pcomp_tile);
}
else
{
if(!mask.IsFullTileInsideMask(
q_origin.at(number<0>{}), seqlen_k_curr, number<kN0>{}, number<kM0>{}))
{
constexpr auto p_spans = PcompBlockTileType::get_distributed_spans();
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices(
pcomp_tile.get_tile_distribution(), make_tuple(idx0, idx1));
tile_elementwise_inout(f_silu, pcomp_tile);
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = seqlen_k_curr + tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1);
tile_elementwise_inout([&](auto& x) { x = x * type_convert<CompDataType>(scale_p); },
pcomp_tile);
if(!mask.IsTokenPairInsideMask(row, col) || col >= seqlen_k_end)
{
pcomp_tile(i_j_idx) = -numeric<CompDataType>::infinity();
};
});
});
}
else
{
constexpr auto p_spans = PcompBlockTileType::get_distributed_spans();
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices(
pcomp_tile.get_tile_distribution(), make_tuple(idx0, idx1));
const auto col = seqlen_k_curr + tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1);
if(col >= seqlen_k_end)
{
pcomp_tile(i_j_idx) = -numeric<CompDataType>::infinity();
};
});
});
};
auto m_local = block_tile_reduce<CompDataType>(
pcomp_tile, sequence<1>{}, f_max, -numeric<CompDataType>::infinity());
block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});
const auto m_old = m;
tile_elementwise_inout(
[](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local);
constexpr auto p_spans = decltype(pcomp_tile)::get_distributed_spans();
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
if(m[i_idx] == -numeric<CompDataType>::infinity())
{
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
pcomp_tile(i_j_idx) = type_convert<CompDataType>(0.0f);
});
}
else
{
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
pcomp_tile(i_j_idx) = f_exp(pcomp_tile[i_j_idx] - m[i_idx]);
});
}
});
auto rowsum_p = block_tile_reduce<CompDataType>(
pcomp_tile, sequence<1>{}, f_sum, CompDataType{0});
block_tile_reduce_sync(rowsum_p, f_sum, bool_constant<false>{});
// adjust o_acc[] according to the update between m and m_old
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
if(m[i_idx] == -numeric<CompDataType>::infinity())
{
l(i_idx) = rowsum_p[i_idx];
}
else
{
const auto tmp = f_exp(m_old[i_idx] - m[i_idx]);
l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx];
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
o_acc(i_j_idx) *= tmp;
});
}
});
};
seqlen_k_curr += kN0;
@@ -502,6 +631,23 @@ struct HstuAttentionFwdPipelineQRKSVS
};
} while(seqlen_k_curr < seqlen_k_end);
if constexpr(kUseSoftmax)
{
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
if(m[i_idx] == -numeric<CompDataType>::infinity())
o_acc(i_j_idx) = 0.0f;
else
o_acc(i_j_idx) *= 1.0f / l[i_idx];
});
});
};
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
return o_acc;

View File

@@ -17,12 +17,14 @@ void hstu_attention_jagged_forward_bf16(HstuAttentionFwdParams& param, hipStream
const bool use_causal = param.use_causal;
BOOL_SWITCH_3(has_bias, kHasBias, has_dropout, kHasDropout, use_causal, kUseCausal, [&] {
HDIM_SWITCH(param.hdim_qk, param.hdim_v, MaxK, [&] {
run_jagged_forward_causal_softmax_bias_dropout_dispatch<ck_tile::bf16_t,
kUseCausal,
false, // using softmax
kHasBias,
kHasDropout,
MaxK>(param, stream);
BOOL_SWITCH(param.use_softmax, kUseSoftmax, [&] {
run_jagged_forward_causal_softmax_bias_dropout_dispatch<ck_tile::bf16_t,
kUseCausal,
kUseSoftmax,
kHasBias,
kHasDropout,
MaxK>(param, stream);
});
});
});
};

View File

@@ -17,12 +17,14 @@ void hstu_attention_jagged_forward_fp16(HstuAttentionFwdParams& param, hipStream
const bool use_causal = param.use_causal;
BOOL_SWITCH_3(has_bias, kHasBias, has_dropout, kHasDropout, use_causal, kUseCausal, [&] {
HDIM_SWITCH(param.hdim_qk, param.hdim_v, MaxK, [&] {
run_jagged_forward_causal_softmax_bias_dropout_dispatch<ck_tile::fp16_t,
kUseCausal,
false, // using softmax
kHasBias,
kHasDropout,
MaxK>(param, stream);
BOOL_SWITCH(param.use_softmax, kUseSoftmax, [&] {
run_jagged_forward_causal_softmax_bias_dropout_dispatch<ck_tile::fp16_t,
kUseCausal,
kUseSoftmax,
kHasBias,
kHasDropout,
MaxK>(param, stream);
});
});
});
};

View File

@@ -52,6 +52,8 @@ struct HstuAttentionFwdParams
ck_tile::index_t contextual_seqlen;
ck_tile::index_t min_full_attn_seqlen;
bool use_softmax;
float p_drop;
uint64_t philox_seed;
uint64_t philox_offset;

View File

@@ -29,6 +29,7 @@ template <typename InOutDataType,
typename GemmAccDataType,
typename CompDataType,
bool kIsJagged,
bool kUseSoftmax,
bool kUseCausal>
struct reference_hstu_attention
{
@@ -150,6 +151,11 @@ struct reference_hstu_attention
// for all rows in the batch
for(int sq = 0; sq < seqlen; sq++)
{
CompDataType m =
-ck_tile::numeric<CompDataType>::infinity(); // max value of the row
CompDataType l =
ck_tile::type_convert<CompDataType>(0.0f); // sum of exp(x-m) of the row
//
std::vector<CompDataType> locals;
// for all cols in the batch
@@ -186,12 +192,41 @@ struct reference_hstu_attention
ck_tile::type_convert<CompDataType>(alpha));
}
else
locals.push_back(ck_tile::type_convert<CompDataType>(0.0f));
{
if constexpr(!kUseSoftmax)
locals.push_back(ck_tile::type_convert<CompDataType>(0.0f));
else
locals.push_back(-ck_tile::numeric<CompDataType>::infinity());
};
};
// SiLu element-wise
for(CompDataType& elem : locals)
elem = silu(elem) * ck_tile::type_convert<CompDataType>(scale_p);
if constexpr(!kUseSoftmax)
{
// SiLu element-wise
for(CompDataType& elem : locals)
elem = silu(elem) * ck_tile::type_convert<CompDataType>(scale_p);
}
else
{
for(CompDataType elem : locals)
m = ck_tile::max(m, elem);
if(m == -ck_tile::numeric<CompDataType>::infinity())
{
for(CompDataType& elem : locals)
elem = ck_tile::type_convert<CompDataType>(0.0f);
}
else
{
// stabilized sum of exp()
for(CompDataType elem : locals)
l += std::exp(elem - m);
// normalization
for(CompDataType& elem : locals)
elem = std::exp(elem - m) / l;
}
};
// second Gemm
for(int k = 0; k < hdim_v; k++)

View File

@@ -0,0 +1,62 @@
#!/bin/bash
# Smoke-test driver for the HSTU attention example with softmax enabled
# (-softmax=1 is forced on every invocation). Runs each masking
# configuration (none / causal / local / context / target, batched and
# jagged layouts) for both fp16 and bf16.
#
# Usage: <script> [attn_scale] [norm_dist]
#   attn_scale - forwarded as -attn_scale (default 0)
#   norm_dist  - forwarded as -norm_dist  (default 0)
BUILD=build
EXE="$BUILD/bin/tile_example_hstu_attention -softmax=1"
# Optional first positional argument overrides attn_scale.
attn_scale=0
if [ $# -ge 1 ]; then
attn_scale=$1
fi
# Optional second positional argument overrides norm_dist.
ndist=0
if [ $# -ge 2 ]; then
ndist=$2
fi
# Run the full matrix of cases for each precision.
for dtype in "fp16" "bf16"; do
# Echo each command as it runs so failures are easy to attribute.
set -x
## no masking batched
$EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=256 -causal=0 -local_len=0 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist
## no masking jagged
$EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=0 -local_len=0 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist
## batched causal
$EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=256 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist
## jagged causal
$EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist
## batched causal+local
$EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=256 -causal=1 -local_len=5 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist
## jagged causal+local
$EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist
## batched causal+local+context
$EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=256 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist
## jagged causal+local+context
$EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist
## batched causal+local+context+target
$EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=256 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 -attn_scale=$attn_scale -norm_dist=$ndist
## jagged causal+local+context+target
$EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 -attn_scale=$attn_scale -norm_dist=$ndist
## jagged no-causal+local+context+target
$EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=0 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 -attn_scale=$attn_scale -norm_dist=$ndist
## jagged causal+local+target (minfull_len > max_uih_len)
$EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=0 -minfull_len=290 -targets=8 -attn_scale=$attn_scale -norm_dist=$ndist
## jagged causal+local+context+target (minfull_len > max_uih_len)
$EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=8 -minfull_len=290 -targets=8 -attn_scale=$attn_scale -norm_dist=$ndist
## jagged no-causal+local+context+target (minfull_len > max_uih_len)
$EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=0 -local_len=5 -context_len=3 -minfull_len=290 -targets=8 -attn_scale=$attn_scale -norm_dist=$ndist
set +x
done