From 83f29243dfd395f5debeaa518d30326c457c1bd8 Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Sat, 29 Mar 2025 12:55:04 +0000 Subject: [PATCH] fix the jagged mode tensor access in reference_hstu_attention --- example/ck_tile/18_hstu_attention/README.md | 21 +++--- .../example_hstu_attention.cpp | 27 +++---- .../reference_hstu_attention.hpp | 73 +++++++++++++------ 3 files changed, 77 insertions(+), 44 deletions(-) diff --git a/example/ck_tile/18_hstu_attention/README.md b/example/ck_tile/18_hstu_attention/README.md index 3ce1f27a14..9c6713fae4 100644 --- a/example/ck_tile/18_hstu_attention/README.md +++ b/example/ck_tile/18_hstu_attention/README.md @@ -29,7 +29,8 @@ ## test/verify ``` bash - #> build/bin/tile_example_hstu_attention -v=1 -prec=fp16 -b=10 -nidx=9 -nhead=4 -hsizeq=64 -hsizev=64 -seqq=13 -seqk=512 -init=u -seed=123 -perf=0 -maskmax=0 + #> build/bin/tile_example_hstu_attention -v=1 -prec=bf16 -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlen=750,730,733,860,870,788,760,821,833,779 -targets=5,5,6,6,5,6,5,6,4,6 + -causal=1 -local_len=5 -context_len=6 -minfull_len=6 #> . example/ck_tile/07_hstu_attention/test_hstu_attention.sh ``` @@ -38,16 +39,18 @@ ``` C++ arg_parser.insert("v", "1", "weather do CPU validation or not") .insert("prec", "fp16", "data type. fp16/bf16") + .insert("jagged", "0", "q/k/v batched sequence is jagged or not") .insert("b", "12", "batch size") - .insert("nidx", "9", "number of indices for accessing the batches") .insert("nhead", "4", "number of heads") - .insert("hsizeq", "64", "headdim size of Q/K") - .insert("hsizev", "64", "headdim size of V/O") - .insert("seqq", "13", "length of the sequence dimension of query tensor") - .insert("seqv", "1024", "length of the sequence dimension of key tensor") - .insert("init", "u", "init method for input tensor values, u, uniform random float values, n, normalized random float values") + .insert("hdim_qk", "64", "headdim size of Q/K") + .insert("hdim_v", "64", "headdim size of V/O") + .insert("seqlen", "400", "seqlen of single or all batches for query and key/value tensor") + .insert("targets", "16", "sequence length at the end of query/key token sequence that should be excluded from attention") + .insert("causal", "1", "enable causal mask or not") + .insert("local_len", "5", "length of the diagonal window for enabling masking, value 0 to disable") + .insert("context_len", "6", "sequence length at the begin of the query sequence the should be included for attention") + .insert("minfull_len", "6", "sequence length at the end of the query sequence that should be included for attention") .insert("seed", "13579", "seed by the uniform or normal distribution generator") - .insert("perf", "0", "weather measure execution time or not") - .insert("maskmax", "0", "used to set mask values to random [0, maskmax), maskmax should in [0, 128], 0 means set all values to 1"); + .insert("perf", "0", "weather measure execution time or not"); ``` diff --git a/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp b/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp index 9432819199..1c46b7f846 100644 --- a/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp +++ b/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp @@ -216,22 +216,23 @@ bool run(const ck_tile::ArgParser& arg_parser) using GemmAccDataType = typename HSTUAttentionTypeConfig::GemmAccDataType; using SMComputeDataType = typename HSTUAttentionTypeConfig::SMComputeDataType; - BOOL_SWITCH_2(use_causal, USE_CAUSAL_, use_local, USE_LOCAL_, [&] { + BOOL_SWITCH_3(is_jagged, kIsJagged, use_causal, kUseCausal, use_local, kUseLocal, [&] { ck_tile::reference_hstu_attention::Run(q_host, - k_host, - v_host, - o_host_ref, - num_batch, - 1.0f, - seq_offsets, - num_targets, - max_attn_len, - contextual_seq_len, - min_full_seq_len); + kIsJagged, + kUseCausal, + kUseLocal>::Run(q_host, + k_host, + v_host, + o_host_ref, + num_batch, + 1.0f, + seq_offsets, + num_targets, + max_attn_len, + contextual_seq_len, + min_full_seq_len); }); return 0; } diff --git a/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp b/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp index d512aedcd3..27cd729a66 100644 --- a/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp +++ b/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp @@ -27,8 +27,9 @@ namespace ck_tile { template + bool kIsJagged, + bool kUseCausal, + bool kUseLocal> struct reference_hstu_attention { struct hstu_mask @@ -49,15 +50,15 @@ struct reference_hstu_attention max_uih_len = max_uih_len_; }; - bool IsPixelInsideMask(int row, int col) + bool IsTokenPairInsideMask(int row, int col) { if(row < contextual_seq_len) return true; bool result = false; - if constexpr(use_local) + if constexpr(kUseLocal) { - if constexpr(use_causal) + if constexpr(kUseCausal) result = (row >= col) && (row - col <= max_attn_len); else result = std::abs(row - col) <= max_attn_len; @@ -67,7 +68,7 @@ struct reference_hstu_attention } else { - if constexpr(use_causal) + if constexpr(kUseCausal) result = (row >= col); }; @@ -90,12 +91,10 @@ struct reference_hstu_attention int min_full_attn_seq_len) // define masking length at the end of query token // sequence which is included for full attention { - bool is_jagged = !seq_offsets.empty(); - - if(is_jagged) + if constexpr(kIsJagged) { // check the number of batches - assert(seq_offsets.size() == num_batch + 1); + assert(!seq_offsets.empty() && seq_offsets.size() == num_batch + 1); assert(q_batch_seq_nhead_hdim.get_lengths()[0] == 1); assert(k_batch_seq_nhead_hdim.get_lengths()[0] == 1); assert(v_batch_seq_nhead_hdim.get_lengths()[0] == 1); @@ -103,6 +102,7 @@ struct reference_hstu_attention } else { + assert(seq_offsets.empty()); assert(q_batch_seq_nhead_hdim.get_lengths()[0] == num_batch); assert(k_batch_seq_nhead_hdim.get_lengths()[0] == num_batch); assert(v_batch_seq_nhead_hdim.get_lengths()[0] == num_batch); @@ -140,7 +140,7 @@ struct reference_hstu_attention assert(num_targets.size() == num_batch); auto f = [&](auto i_batch, auto i_head) { - int seqlen = is_jagged ? (seq_offsets[i_batch + 1] - seq_offsets[i_batch]) + int seqlen = kIsJagged ? (seq_offsets[i_batch + 1] - seq_offsets[i_batch]) : q_batch_seq_nhead_hdim.get_lengths()[1]; int max_uih_len = seqlen; @@ -161,16 +161,29 @@ struct reference_hstu_attention // for all cols in the batch for(int sk = 0; sk < max_uih_len; sk++) { - if(mask.IsPixelInsideMask(sq, sk)) + if(mask.IsTokenPairInsideMask(sq, sk)) { GemmAccDataType dot_prod = 0.f; for(int k = 0; k < hdim_qk; k++) { - InOutDataType qreg = q_batch_seq_nhead_hdim(i_batch, sq, i_head, k); - InOutDataType kreg = k_batch_seq_nhead_hdim(i_batch, sk, i_head, k); + if constexpr(kIsJagged) + { + InOutDataType qreg = + q_batch_seq_nhead_hdim(0, seq_offsets[i_batch] + sq, i_head, k); + InOutDataType kreg = + k_batch_seq_nhead_hdim(0, seq_offsets[i_batch] + sk, i_head, k); - dot_prod += ck_tile::type_convert(qreg) * - ck_tile::type_convert(kreg); + dot_prod += ck_tile::type_convert(qreg) * + ck_tile::type_convert(kreg); + } + else + { + InOutDataType qreg = q_batch_seq_nhead_hdim(i_batch, sq, i_head, k); + InOutDataType kreg = k_batch_seq_nhead_hdim(i_batch, sk, i_head, k); + + dot_prod += ck_tile::type_convert(qreg) * + ck_tile::type_convert(kreg); + }; } locals.push_back(ck_tile::type_convert(dot_prod) * @@ -191,15 +204,31 @@ struct reference_hstu_attention for(int sk = 0; sk < max_uih_len; sk++) { - InOutDataType preg = ck_tile::type_convert(locals[sk]); - InOutDataType vreg = v_batch_seq_nhead_hdim(i_batch, sk, i_head, k); + if constexpr(kIsJagged) + { + InOutDataType preg = ck_tile::type_convert(locals[sk]); + InOutDataType vreg = + v_batch_seq_nhead_hdim(0, seq_offsets[i_batch] + sk, i_head, k); - dot_prod += ck_tile::type_convert(preg) * - ck_tile::type_convert(vreg); + dot_prod += ck_tile::type_convert(preg) * + ck_tile::type_convert(vreg); + } + else + { + InOutDataType preg = ck_tile::type_convert(locals[sk]); + InOutDataType vreg = v_batch_seq_nhead_hdim(i_batch, sk, i_head, k); + + dot_prod += ck_tile::type_convert(preg) * + ck_tile::type_convert(vreg); + }; }; - o_batch_seq_nhead_hdim(i_batch, sq, i_head, k) = - ck_tile::type_convert(dot_prod); + if constexpr(kIsJagged) + o_batch_seq_nhead_hdim(0, seq_offsets[i_batch] + sq, i_head, k) = + ck_tile::type_convert(dot_prod); + else + o_batch_seq_nhead_hdim(i_batch, sq, i_head, k) = + ck_tile::type_convert(dot_prod); }; }; };