From 733734553bfe72b9912937cb2f7be9f282fd2c67 Mon Sep 17 00:00:00 2001
From: Qianfeng Zhang
Date: Thu, 3 Apr 2025 14:44:36 +0000
Subject: [PATCH] Fix and change in example

---
Review notes (this section is not part of the commit message):
- The line structure and the template argument lists (<int>, <typename DataType>)
  were reconstructed from a whitespace-collapsed copy of this patch. Indentation of
  the context lines is approximate, so apply with --ignore-whitespace. The author
  e-mail address is absent from this copy.
- show_hstu_attention_fwd_param takes the parameters by const reference and prints
  the fields shared by the jagged and batched layouts once; the text written to
  the stream is unchanged.
- The num_targets validation loop in run() is bounded by num_targets.size() rather
  than seq_lengths.size(), so exactly the entries that exist are checked.
- New-side hunk offsets and the diffstat were recomputed for the amended hunks;
  the post-image id on the index line is still the one from the original commit.

 .../example_hstu_attention.cpp                | 74 ++++++++++++++++---
 1 file changed, 63 insertions(+), 11 deletions(-)

diff --git a/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp b/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp
index 7016746ebd..dfbdb34427 100644
--- a/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp
+++ b/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp
@@ -62,7 +62,7 @@ auto create_args(int argc, char* argv[])
         .insert("hdim_qk", "64", "headdim size of Q/K")
         .insert("hdim_v", "64", "headdim size of V/O")
         .insert("seqlen", "400", "seqlen of single or all batches for query and key/value tensor")
-        .insert("targets", "16", "sequence length at the end of query/key token sequence that should be excluded from attention")
+        .insert("targets", "", "sequence length at the end of query/key token sequence that should be excluded from attention")
         .insert("causal", "1", "enable causal mask or not")
         .insert("local_len", "5", "length of the diagonal window for enabling masking, value 0 to disable")
         .insert("context_len", "6", "sequence length at the begin of the query sequence the should be included for attention")
@@ -95,13 +95,42 @@ static std::vector<int> get_integers_from_string(std::string lengthsStr)
     };
 
     std::string sliceStr = lengthsStr.substr(pos);
-    int len = std::stoi(sliceStr);
 
-    lengths.push_back(len);
+    if(!sliceStr.empty())
+    {
+        int len = std::stoi(sliceStr);
+
+        lengths.push_back(len);
+    };
 
     return (lengths);
 };
 
+// Writes a human-readable dump of the forward-attention parameters in `param` to `os`.
+// The batch strides are printed only for batched (non-jagged) inputs.
+static void show_hstu_attention_fwd_param(std::ostream& os, const HstuAttentionFwdParams& param)
+{
+    os << (param.is_jagged ? "Jagged inputs used! " : "Batched inputs used! ") << std::endl;
+    os << "use causal: " << param.use_causal << std::endl;
+    os << "Num of batches: " << param.num_batch << std::endl;
+    os << "Num of heads: " << param.num_head << std::endl;
+    os << "QK hdim: " << param.hdim_qk << " V hdim: " << param.hdim_v << std::endl;
+    os << "Q/K/V/O seq stride: " << param.seq_stride_q << " " << param.seq_stride_k << " "
+       << param.seq_stride_v << " " << param.seq_stride_o << std::endl;
+    os << "Q/K/V/O nhead stride: " << param.nhead_stride_q << " " << param.nhead_stride_k << " "
+       << param.nhead_stride_v << " " << param.nhead_stride_o << std::endl;
+
+    if(!param.is_jagged)
+    {
+        os << "Q/K/V/O batch stride: " << param.batch_stride_q << " " << param.batch_stride_k
+           << " " << param.batch_stride_v << " " << param.batch_stride_o << std::endl;
+    };
+
+    os << "contextual_seqlen: " << param.contextual_seqlen << std::endl;
+    os << "window_size: " << param.window_size << std::endl;
+    os << "min_full_attn_seqlen: " << param.min_full_attn_seqlen << std::endl;
+};
+
 // threshold for different dtypes
 template <typename DataType>
 auto get_elimit()
@@ -152,9 +181,27 @@ bool run(const ck_tile::ArgParser& arg_parser)
     int seqlen = 0; // means total seq lengths for jagged
     int max_seqlen = 0;
 
+    // supplement the sequence of lengths to have num_batch values
+    if(!num_targets.empty() && static_cast<int>(num_targets.size()) < num_batch)
+    {
+        auto last_val = num_targets.back();
+
+        for(int i = num_targets.size(); i < num_batch; i++)
+            num_targets.push_back(last_val);
+    };
+
     if(is_jagged)
     {
-        assert(num_batch == seq_lengths.size());
+        assert(num_batch >= seq_lengths.size());
+
+        // supplement the sequence of lengths to have num_batch values
+        if(static_cast<int>(seq_lengths.size()) < num_batch)
+        {
+            auto last_len = seq_lengths.back();
+
+            for(int i = seq_lengths.size(); i < num_batch; i++)
+                seq_lengths.push_back(last_len);
+        };
 
         seq_offsets.push_back(0);
         for(size_t i = 0; i < seq_lengths.size(); i++)
@@ -190,10 +237,13 @@ bool run(const ck_tile::ArgParser& arg_parser)
 
         if(!num_targets.empty())
         {
-            assert(1 == num_targets.size());
+            assert(num_batch == num_targets.size());
 
-            assert(seqlen - num_targets[0] >= min_full_attn_seqlen);
-            assert(seqlen - num_targets[0] >= contextual_seqlen);
+            for(size_t i = 0; i < num_targets.size(); i++)
+            {
+                assert(seqlen - num_targets[i] >= min_full_attn_seqlen);
+                assert(seqlen - num_targets[i] >= contextual_seqlen);
+            };
         }
         else
         {
@@ -245,7 +295,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
         params.q_ptr = q_dev.GetDeviceBuffer();
         params.k_ptr = k_dev.GetDeviceBuffer();
         params.v_ptr = v_dev.GetDeviceBuffer();
-        params.bias_ptr = nullptr;
+        params.bias_ptr = nullptr; // bias is not supported at present
         params.o_ptr = o_dev.GetDeviceBuffer();
         params.hdim_qk = hdim_qk;
         params.hdim_v = hdim_v;
@@ -278,7 +328,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
         params.q_ptr = q_dev.GetDeviceBuffer();
         params.k_ptr = k_dev.GetDeviceBuffer();
         params.v_ptr = v_dev.GetDeviceBuffer();
-        params.bias_ptr = nullptr;
+        params.bias_ptr = nullptr; // bias is not supported at present
         params.o_ptr = o_dev.GetDeviceBuffer();
         params.hdim_qk = hdim_qk;
         params.hdim_v = hdim_v;
@@ -309,6 +359,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
         params.philox_offset = 0UL;
     };
 
+    show_hstu_attention_fwd_param(std::cout, params);
+
     hipStream_t stream;
     HIP_CHECK_ERROR(hipStreamCreate(&stream));
 
@@ -348,7 +400,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
             v_host,
             o_host_ref,
             num_batch,
-            1.0f,
+            1.0f / std::sqrt(params.hdim_qk),
             seq_offsets,
             num_targets,
             window_size,
@@ -393,7 +445,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
 
         auto ms = timer.duration() / 20.f;
 
-        std::cout << "Average execution time of the gather_attention operator is " << ms
+        std::cout << "Average execution time of the hstu_attention operator is " << ms
                   << " milli-seconds" << std::endl;
     }
 