diff --git a/example/ck_tile/18_hstu_attention/README.md b/example/ck_tile/18_hstu_attention/README.md index 9c6713fae4..b9f9dafe89 100644 --- a/example/ck_tile/18_hstu_attention/README.md +++ b/example/ck_tile/18_hstu_attention/README.md @@ -51,6 +51,7 @@ .insert("context_len", "6", "sequence length at the begin of the query sequence the should be included for attention") .insert("minfull_len", "6", "sequence length at the end of the query sequence that should be included for attention") .insert("seed", "13579", "seed by the uniform or normal distribution generator") + .insert("save_mask", "1", "save the mask tensor to disk by the CPU validation codes") .insert("perf", "0", "weather measure execution time or not"); ``` diff --git a/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp b/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp index 3bac90d1a4..d07dba89c3 100644 --- a/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp +++ b/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp @@ -40,7 +40,7 @@ void dumpBufferToFile(const char* fileName, T* data, size_t dataNumItems) { outFile.write(reinterpret_cast(data), dataNumItems * sizeof(T)); outFile.close(); - printf("Wrote output to file %s\n", fileName); + printf("Write output to file %s\n", fileName); } else { @@ -84,6 +84,7 @@ auto create_args(int argc, char* argv[]) .insert("context_len", "6", "sequence length at the begin of the query sequence the should be included for attention") .insert("minfull_len", "6", "sequence length at the end of the query sequence that should be included for attention") .insert("seed", "13579", "seed by the uniform or normal distribution generator") + .insert("save_mask", "1", "save the mask tensor to disk by the CPU validation codes") .insert("perf", "0", "weather measure execution time or not"); // clang-format on @@ -197,6 +198,8 @@ bool run(const ck_tile::ArgParser& arg_parser) int seed = arg_parser.get_int("seed"); bool measure_perf = static_cast(arg_parser.get_int("perf")); + bool save_mask = static_cast(arg_parser.get_int("save_mask")); + std::string str_of_targets = arg_parser.get_str("targets"); std::vector num_targets = get_integers_from_string(str_of_targets); @@ -260,7 +263,8 @@ bool run(const ck_tile::ArgParser& arg_parser) else { assert(1 == seq_lengths.size()); - seqlen = seq_lengths[0]; + seqlen = seq_lengths[0]; + max_seqlen = seqlen; if(!num_targets.empty()) { @@ -310,6 +314,10 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::HostTensor o_host_ref( std::array{batches_for_alloc, seqlen, num_head, hdim_v}); + ck_tile::HostTensor mask_host( + save_mask ? std::array{num_batch, num_head, max_seqlen, max_seqlen} + : std::array{1, 1, 1, 1}); + ck_tile::FillNormalDistribution{0.f, 1.f, seed}(q_host); ck_tile::FillNormalDistribution{0.f, 1.f, seed}(k_host); ck_tile::FillNormalDistribution{0.f, 1.f, seed}(v_host); @@ -447,6 +455,7 @@ bool run(const ck_tile::ArgParser& arg_parser) k_host, v_host, o_host_ref, + mask_host, num_batch, 1.0f / std::sqrt(params.hdim_qk), is_jagged ? max_seqlen : seqlen, @@ -465,6 +474,9 @@ bool run(const ck_tile::ArgParser& arg_parser) // dumpBufferToFile("output_dev.dat", o_host.data(), o_host.get_element_space_size()); // dumpBufferToFile("output_host.dat", o_host_ref.data(), o_host.get_element_space_size()); + if(save_mask) + dumpBufferToFile("ck_hstu_mask.dat", mask_host.data(), mask_host.get_element_space_size()); + auto [rtol, atol] = get_elimit(); res = ck_tile::check_err( diff --git a/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp b/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp index ce02f57e73..2319a893e6 100644 --- a/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp +++ b/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp @@ -40,6 +40,7 @@ struct reference_hstu_attention const HostTensor& k_batch_seq_nhead_hdim, const HostTensor& v_batch_seq_nhead_hdim, HostTensor& o_batch_seq_nhead_hdim, + HostTensor& mask_batch_nhead_seq_seq, int num_batch, float alpha, int max_seqlen, @@ -87,6 +88,14 @@ struct reference_hstu_attention assert(hdim_qk == k_batch_seq_nhead_hdim.get_lengths()[3]); assert(hdim_v == o_batch_seq_nhead_hdim.get_lengths()[3]); + bool save_mask = false; + + if(static_cast(mask_batch_nhead_seq_seq.get_lengths()[0]) == num_batch && + static_cast(mask_batch_nhead_seq_seq.get_lengths()[1]) == num_head && + static_cast(mask_batch_nhead_seq_seq.get_lengths()[2]) == max_seqlen && + static_cast(mask_batch_nhead_seq_seq.get_lengths()[3]) == max_seqlen) + save_mask = true; + // check num_tagets assert(num_tagets.empty() || num_targets.size() == num_batch); @@ -111,6 +120,15 @@ struct reference_hstu_attention seqlen, contextual_seqlen, num_target); }(); + if(save_mask) + { + // initialize the mask + for(int sq = 0; sq < max_seqlen; sq++) + for(int sk = 0; sk < max_seqlen; sk++) + mask_batch_nhead_seq_seq(i_batch, i_head, sq, sk) = + static_cast(mask.IsTokenPairInsideMask(sq, sk)); + } + // for all rows in the batch for(int sq = 0; sq < seqlen; sq++) {