diff --git a/example/ck_tile/18_hstu_attention/README.md b/example/ck_tile/18_hstu_attention/README.md index 3d52ba825d..7c0a680540 100644 --- a/example/ck_tile/18_hstu_attention/README.md +++ b/example/ck_tile/18_hstu_attention/README.md @@ -48,6 +48,7 @@ .insert("max_seqlen", "0", "max uih_seqlen, can be ignored, or else must be equal or bigger than the maximum of all uih seqlens") .insert("targets", "16", "sequence length at the end of query/key token sequence that should be excluded from attention") .insert("max_target", "0", "max target, can be ignored, or else must be equal of bigger than the maximum of all targets") + .insert("softmax", "0", "use softmax or not") .insert("causal", "1", "enable causal mask or not") .insert("local_len", "5", "length of the diagonal window for enabling masking, value 0 to disable") .insert("context_len", "6", "sequence length at the begin of the query sequence the should be included for attention") diff --git a/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp b/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp index 685957594e..6f9e1b9391 100644 --- a/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp +++ b/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp @@ -104,6 +104,7 @@ auto create_args(int argc, char* argv[]) .insert("max_seqlen", "0", "max uih_seqlen, can be ignored, or else must be equal or bigger than the maximum of all uih seqlens") .insert("targets", "", "sequence length at the end of query/key token sequence that should be excluded from attention") .insert("max_target", "0", "max target, can be ignored, or else must be equal of bigger than the maximum of all targets") + .insert("softmax", "0", "use softmax or not") .insert("causal", "1", "enable causal mask or not") .insert("local_len", "5", "length of the diagonal window for enabling masking, value 0 to disable") .insert("context_len", "6", "sequence length at the begin of the query sequence the should be included for 
attention") @@ -216,6 +217,7 @@ bool run(const ck_tile::ArgParser& arg_parser) int num_head = arg_parser.get_int("nhead"); int hdim_qk = arg_parser.get_int("hdim_qk"); int hdim_v = arg_parser.get_int("hdim_v"); + bool use_softmax = static_cast(arg_parser.get_int("softmax")); bool use_causal = static_cast(arg_parser.get_int("causal")); int window_size = arg_parser.get_int("local_len"); @@ -434,6 +436,7 @@ bool run(const ck_tile::ArgParser& arg_parser) params.nhead_stride_bias = 0; params.nhead_stride_o = o_host_ref.get_strides()[2]; params.num_targets_ptr = num_targets.empty() ? nullptr : num_targets_dev.GetDeviceBuffer(); + params.use_softmax = use_softmax; params.use_causal = use_causal; params.window_size = window_size; params.contextual_seqlen = contextual_seqlen; @@ -473,6 +476,7 @@ bool run(const ck_tile::ArgParser& arg_parser) params.batch_stride_bias = 0; params.batch_stride_o = o_host_ref.get_strides()[0]; params.num_targets_ptr = num_targets.empty() ? nullptr : num_targets_dev.GetDeviceBuffer(); + params.use_softmax = use_softmax; params.use_causal = use_causal; params.window_size = window_size; params.contextual_seqlen = contextual_seqlen; @@ -513,11 +517,12 @@ bool run(const ck_tile::ArgParser& arg_parser) using GemmAccDataType = typename HstuAttentionFwdTypeConfig::GemmAccDataType; using CompDataType = typename HstuAttentionFwdTypeConfig::CompDataType; - BOOL_SWITCH_2(is_jagged, kIsJagged, use_causal, kUseCausal, [&] { + BOOL_SWITCH_3(is_jagged, kIsJagged, use_softmax, kUseSoftmax, use_causal, kUseCausal, [&] { ck_tile::reference_hstu_attention::Run(q_host, k_host, v_host, diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_bf16.cpp b/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_bf16.cpp index 764fa96c42..066adf61dc 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_bf16.cpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_bf16.cpp @@ -17,12 +17,14 @@ void 
hstu_attention_batched_forward_bf16(HstuAttentionFwdParams& param, hipStrea const bool use_causal = param.use_causal; BOOL_SWITCH_3(has_bias, kHasBias, has_dropout, kHasDropout, use_causal, kUseCausal, [&] { HDIM_SWITCH(param.hdim_qk, param.hdim_v, MaxK, [&] { - run_batched_forward_causal_softmax_bias_dropout_dispatch(param, stream); + BOOL_SWITCH(param.use_softmax, kUseSoftmax, [&] { + run_batched_forward_causal_softmax_bias_dropout_dispatch(param, stream); + }); }); }); }; diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_fp16.cpp b/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_fp16.cpp index 75575e151e..58db583131 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_fp16.cpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_fp16.cpp @@ -17,12 +17,14 @@ void hstu_attention_batched_forward_fp16(HstuAttentionFwdParams& param, hipStrea const bool use_causal = param.use_causal; BOOL_SWITCH_3(has_bias, kHasBias, has_dropout, kHasDropout, use_causal, kUseCausal, [&] { HDIM_SWITCH(param.hdim_qk, param.hdim_v, MaxK, [&] { - run_batched_forward_causal_softmax_bias_dropout_dispatch(param, stream); + BOOL_SWITCH(param.use_softmax, kUseSoftmax, [&] { + run_batched_forward_causal_softmax_bias_dropout_dispatch(param, stream); + }); }); }); }; diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp index c601f8fb80..d6d91746a5 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp @@ -245,7 +245,7 @@ struct HstuAttentionFwdKernel seq_stride_v, seq_stride_o, num_head, - -scale_s, + scale_s, attn_scale ? 
attn_scale : 1.0f / static_cast(seqlen), // max_seqlen contextual_seqlen, window_size, @@ -320,7 +320,7 @@ struct HstuAttentionFwdKernel hdim_v, -1, // seqlen will be updated by another pointer num_head, - -scale_s, + scale_s, attn_scale ? attn_scale : 1.0f / static_cast(max_seqlen), contextual_seqlen, window_size, diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp index caad20e33f..345b1bb68f 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp @@ -156,6 +156,8 @@ struct HstuAttentionFwdPipelineQRKSVS kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], "wrong!"); + constexpr bool kUseSoftmax = Problem::kUseSoftmax; + constexpr index_t k1_loops = kN0 / kK1; constexpr auto NumKVLdsBuffers = Policy::template GetNumKVLdsBuffers(); @@ -173,6 +175,16 @@ struct HstuAttentionFwdPipelineQRKSVS SaccBlockTileType sacc_tile; PcompBlockTileType pcomp_tile; + // reduction function for softmax + const auto f_max = [](auto e0, auto e1) { return max(e0, e1); }; + const auto f_sum = [](auto e0, auto e1) { return e0 + e1; }; + + using MLBlockTileType = decltype(block_tile_reduce( + PcompBlockTileType{}, sequence<1>{}, f_max, CompDataType{0})); + + auto m = MLBlockTileType{}; + auto l = MLBlockTileType{}; + using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile()); OaccBlockTileType o_acc; @@ -279,15 +291,26 @@ struct HstuAttentionFwdPipelineQRKSVS // reduction function for softmax const auto f_silu = [&](CompDataType& x) { - const auto neg_one = ck_tile::type_convert(-1.0f); + const auto one = ck_tile::type_convert(1.0f); if constexpr(std::is_same_v) { - x = x * __builtin_amdgcn_rcpf(neg_one - __expf(x)); + x = x * __builtin_amdgcn_rcpf(one + __expf(-x)); } else { - x = x / (neg_one - exp(x)); + x = x / (one + exp(-x)); + } + }; + + const auto f_exp = [&](CompDataType x) { + 
if constexpr(std::is_same_v) + { + return __expf(x); + } + else + { + return exp(x); } }; @@ -353,6 +376,12 @@ struct HstuAttentionFwdPipelineQRKSVS }); clear_tile(o_acc); + + if constexpr(kUseSoftmax) + { + set_tile(m, -numeric::infinity()); + clear_tile(l); + }; }; q_tile = tile_elementwise_in(q_element_func, q_tile); @@ -409,41 +438,141 @@ struct HstuAttentionFwdPipelineQRKSVS tile_elementwise_inout( [&scale_s, &bias_element_func](auto& x, const auto& y) { - x = x * scale_s - type_convert(bias_element_func(y)); + x = x * scale_s + type_convert(bias_element_func(y)); }, pcomp_tile, bias_tile); - move_tile_window(bias_dram_window, {0, kK1}); + move_tile_window(bias_dram_window, {0, kN0}); } else { tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, pcomp_tile); } - if(!mask.IsFullTileInsideMask( - q_origin.at(number<0>{}), seqlen_k_curr, number{}, number{})) + if constexpr(!kUseSoftmax) { - constexpr auto p_spans = PcompBlockTileType::get_distributed_spans(); - sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { - sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { - const auto tile_idx = get_x_indices_from_distributed_indices( - pcomp_tile.get_tile_distribution(), make_tuple(idx0, idx1)); + if(!mask.IsFullTileInsideMask( + q_origin.at(number<0>{}), seqlen_k_curr, number{}, number{})) + { + constexpr auto p_spans = PcompBlockTileType::get_distributed_spans(); + sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { + const auto tile_idx = get_x_indices_from_distributed_indices( + pcomp_tile.get_tile_distribution(), make_tuple(idx0, idx1)); - const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); - const auto col = seqlen_k_curr + tile_idx.at(number<1>{}); - constexpr auto i_j_idx = make_tuple(idx0, idx1); + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = seqlen_k_curr + tile_idx.at(number<1>{}); + constexpr auto i_j_idx = 
make_tuple(idx0, idx1); - if(!mask.IsTokenPairInsideMask(row, col)) - pcomp_tile(i_j_idx) = static_cast(0.0f); + if(!mask.IsTokenPairInsideMask(row, col)) + { + pcomp_tile(i_j_idx) = type_convert(0.0f); + }; + }); }); - }); + } + + tile_elementwise_inout(f_silu, pcomp_tile); + + tile_elementwise_inout( + [&](auto& x) { x = x * type_convert(scale_p); }, pcomp_tile); } + else + { + if(!mask.IsFullTileInsideMask( + q_origin.at(number<0>{}), seqlen_k_curr, number{}, number{})) + { + constexpr auto p_spans = PcompBlockTileType::get_distributed_spans(); + sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { + const auto tile_idx = get_x_indices_from_distributed_indices( + pcomp_tile.get_tile_distribution(), make_tuple(idx0, idx1)); - tile_elementwise_inout(f_silu, pcomp_tile); + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = seqlen_k_curr + tile_idx.at(number<1>{}); + constexpr auto i_j_idx = make_tuple(idx0, idx1); - tile_elementwise_inout([&](auto& x) { x = x * type_convert(scale_p); }, - pcomp_tile); + if(!mask.IsTokenPairInsideMask(row, col) || col >= seqlen_k_end) + { + pcomp_tile(i_j_idx) = -numeric::infinity(); + }; + }); + }); + } + else + { + constexpr auto p_spans = PcompBlockTileType::get_distributed_spans(); + sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { + const auto tile_idx = get_x_indices_from_distributed_indices( + pcomp_tile.get_tile_distribution(), make_tuple(idx0, idx1)); + + const auto col = seqlen_k_curr + tile_idx.at(number<1>{}); + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + if(col >= seqlen_k_end) + { + pcomp_tile(i_j_idx) = -numeric::infinity(); + }; + }); + }); + }; + + auto m_local = block_tile_reduce( + pcomp_tile, sequence<1>{}, f_max, -numeric::infinity()); + block_tile_reduce_sync(m_local, f_max, bool_constant{}); + + const auto m_old = m; + + tile_elementwise_inout( 
+ [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); + + constexpr auto p_spans = decltype(pcomp_tile)::get_distributed_spans(); + sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + + if(m[i_idx] == -numeric::infinity()) + { + sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + pcomp_tile(i_j_idx) = type_convert(0.0f); + }); + } + else + { + sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + pcomp_tile(i_j_idx) = f_exp(pcomp_tile[i_j_idx] - m[i_idx]); + }); + } + }); + + auto rowsum_p = block_tile_reduce( + pcomp_tile, sequence<1>{}, f_sum, CompDataType{0}); + + block_tile_reduce_sync(rowsum_p, f_sum, bool_constant{}); + + // adjust o_acc[] according to the update between m and m_old + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + + if(m[i_idx] == -numeric::infinity()) + { + l(i_idx) = rowsum_p[i_idx]; + } + else + { + const auto tmp = f_exp(m_old[i_idx] - m[i_idx]); + l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx]; + sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + o_acc(i_j_idx) *= tmp; + }); + } + }); + }; seqlen_k_curr += kN0; @@ -502,6 +631,23 @@ struct HstuAttentionFwdPipelineQRKSVS }; } while(seqlen_k_curr < seqlen_k_end); + if constexpr(kUseSoftmax) + { + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + + sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + if(m[i_idx] == -numeric::infinity()) + o_acc(i_j_idx) = 0.0f; + else + o_acc(i_j_idx) *= 1.0f / l[i_idx]; + }); + }); + }; + o_acc = tile_elementwise_in(o_acc_element_func, 
o_acc); return o_acc; diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_bf16.cpp b/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_bf16.cpp index 0d07914ff3..d5d86fe833 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_bf16.cpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_bf16.cpp @@ -17,12 +17,14 @@ void hstu_attention_jagged_forward_bf16(HstuAttentionFwdParams& param, hipStream const bool use_causal = param.use_causal; BOOL_SWITCH_3(has_bias, kHasBias, has_dropout, kHasDropout, use_causal, kUseCausal, [&] { HDIM_SWITCH(param.hdim_qk, param.hdim_v, MaxK, [&] { - run_jagged_forward_causal_softmax_bias_dropout_dispatch(param, stream); + BOOL_SWITCH(param.use_softmax, kUseSoftmax, [&] { + run_jagged_forward_causal_softmax_bias_dropout_dispatch(param, stream); + }); }); }); }; diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_fp16.cpp b/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_fp16.cpp index 6cf51f502e..9980f7078a 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_fp16.cpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_fp16.cpp @@ -17,12 +17,14 @@ void hstu_attention_jagged_forward_fp16(HstuAttentionFwdParams& param, hipStream const bool use_causal = param.use_causal; BOOL_SWITCH_3(has_bias, kHasBias, has_dropout, kHasDropout, use_causal, kUseCausal, [&] { HDIM_SWITCH(param.hdim_qk, param.hdim_v, MaxK, [&] { - run_jagged_forward_causal_softmax_bias_dropout_dispatch(param, stream); + BOOL_SWITCH(param.use_softmax, kUseSoftmax, [&] { + run_jagged_forward_causal_softmax_bias_dropout_dispatch(param, stream); + }); }); }); }; diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_params.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_params.hpp index 5f3a3ef1a1..2e85e11971 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_params.hpp +++ 
b/example/ck_tile/18_hstu_attention/hstu_attention_params.hpp @@ -52,6 +52,8 @@ struct HstuAttentionFwdParams ck_tile::index_t contextual_seqlen; ck_tile::index_t min_full_attn_seqlen; + bool use_softmax; + float p_drop; uint64_t philox_seed; uint64_t philox_offset; diff --git a/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp b/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp index 4b35ef3884..798f73c4c5 100644 --- a/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp +++ b/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp @@ -29,6 +29,7 @@ template struct reference_hstu_attention { @@ -150,6 +151,11 @@ struct reference_hstu_attention // for all rows in the batch for(int sq = 0; sq < seqlen; sq++) { + CompDataType m = + -ck_tile::numeric::infinity(); // max value of the row + CompDataType l = + ck_tile::type_convert(0.0f); // sum of exp(x-m) of the row + // std::vector locals; // for all cols in the batch @@ -186,12 +192,41 @@ struct reference_hstu_attention ck_tile::type_convert(alpha)); } else - locals.push_back(ck_tile::type_convert(0.0f)); + { + if constexpr(!kUseSoftmax) + locals.push_back(ck_tile::type_convert(0.0f)); + else + locals.push_back(-ck_tile::numeric::infinity()); + }; }; - // SiLu element-wise - for(CompDataType& elem : locals) - elem = silu(elem) * ck_tile::type_convert(scale_p); + if constexpr(!kUseSoftmax) + { + // SiLu element-wise + for(CompDataType& elem : locals) + elem = silu(elem) * ck_tile::type_convert(scale_p); + } + else + { + for(CompDataType elem : locals) + m = ck_tile::max(m, elem); + + if(m == -ck_tile::numeric::infinity()) + { + for(CompDataType& elem : locals) + elem = ck_tile::type_convert(0.0f); + } + else + { + // stabilized sum of exp() + for(CompDataType elem : locals) + l += std::exp(elem - m); + + // normalization + for(CompDataType& elem : locals) + elem = std::exp(elem - m) / l; + } + }; // second Gemm for(int k = 0; k < hdim_v; k++) diff --git 
a/example/ck_tile/18_hstu_attention/scripts/test_hstu_softmax_attention.sh b/example/ck_tile/18_hstu_attention/scripts/test_hstu_softmax_attention.sh new file mode 100644 index 0000000000..bee42b45ea --- /dev/null +++ b/example/ck_tile/18_hstu_attention/scripts/test_hstu_softmax_attention.sh @@ -0,0 +1,62 @@ +#!/bin/bash + +BUILD=build +EXE="$BUILD/bin/tile_example_hstu_attention -softmax=1" + +attn_scale=0 +if [ $# -ge 1 ]; then + attn_scale=$1 +fi + +ndist=0 + +if [ $# -ge 2 ]; then + ndist=$2 +fi + +for dtype in "fp16" "bf16"; do + set -x + + ## no masking batched + $EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=256 -causal=0 -local_len=0 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist + + ## no masking jagged + $EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=0 -local_len=0 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist + + ## batched causal + $EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=256 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist + + ## jagged causal + $EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist + + ## batched causal+local + $EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=256 -causal=1 -local_len=5 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist + + ## jagged causal+local + $EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist + + ## batched causal+local+context + $EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 
-hdim_qk=128 -hdim_v=128 -seqlens=256 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist + + ## jagged causal+local+context + $EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist + + ## batched causal+local+context+target + $EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=256 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 -attn_scale=$attn_scale -norm_dist=$ndist + + ## jagged causal+local+context+target + $EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 -attn_scale=$attn_scale -norm_dist=$ndist + + ## jagged no-causal+local+context+target + $EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=0 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 -attn_scale=$attn_scale -norm_dist=$ndist + + ## jagged causal+local+target (minfull_len > max_uih_len) + $EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=0 -minfull_len=290 -targets=8 -attn_scale=$attn_scale -norm_dist=$ndist + + ## jagged causal+local+context+target (minfull_len > max_uih_len) + $EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=8 -minfull_len=290 -targets=8 -attn_scale=$attn_scale -norm_dist=$ndist + + ## jagged no-causal+local+context+target (minfull_len > max_uih_len) + $EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=0 -local_len=5 -context_len=3 -minfull_len=290 -targets=8 -attn_scale=$attn_scale -norm_dist=$ndist + set +x +done