diff --git a/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp b/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp index dfbdb34427..e04cf76a40 100644 --- a/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp +++ b/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp @@ -32,6 +32,22 @@ extern void hstu_attention_batched_forward_bf16(HstuAttentionFwdParams& param, h extern void hstu_attention_jagged_forward_fp16(HstuAttentionFwdParams& param, hipStream_t stream); extern void hstu_attention_jagged_forward_bf16(HstuAttentionFwdParams& param, hipStream_t stream); +template +void dumpBufferToFile(const char* fileName, T* data, size_t dataNumItems) +{ + std::ofstream outFile(fileName, std::ios::binary); + if(outFile) + { + outFile.write(reinterpret_cast(data), dataNumItems * sizeof(T)); + outFile.close(); + printf("Wrote output to file %s\n", fileName); + } + else + { + printf("Could not open file %s for writing\n", fileName); + } +} + template std::ostream& operator<<(std::ostream& os, const std::vector& v) { @@ -424,6 +440,9 @@ bool run(const ck_tile::ArgParser& arg_parser) o_dev.FromDevice(o_host.data()); + dumpBufferToFile("output_dev.dat", o_host.data(), o_host.get_element_space_size()); + dumpBufferToFile("output_host.dat", o_host_ref.data(), o_host.get_element_space_size()); + auto [rtol, atol] = get_elimit(); res = ck_tile::check_err( diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_dispatch.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_dispatch.hpp index d7c68e7929..01523bacea 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_dispatch.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_dispatch.hpp @@ -118,8 +118,8 @@ struct batched_forward_causal_local_bias_dropout_dispatch param.batch_stride_bias, param.batch_stride_o, param.num_targets_ptr, - param.window_size, param.contextual_seqlen, + param.window_size, param.min_full_attn_seqlen, param.p_drop, param.philox_seed, diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp index 87ebbb941b..1f76900000 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp @@ -383,7 +383,7 @@ struct HstuAttentionFwdPipelineQRKSVS { const auto k_origin = k_dram_block_window.get_window_origin(); set_tile_if(s_acc, type_convert(0), [&](auto tile_idx) { - if(i_loop < num_loops - 1) + if(q_origin.at(number<0>{}) + kM0 <= mask.max_uih_len && i_loop < num_loops - 1) return false; const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_dispatch.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_dispatch.hpp index 5fe497d666..7a689a93a6 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_dispatch.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_dispatch.hpp @@ -107,8 +107,8 @@ struct jagged_forward_causal_local_bias_dropout_dispatch param.nhead_stride_bias, param.nhead_stride_o, param.num_targets_ptr, - param.window_size, param.contextual_seqlen, + param.window_size, param.min_full_attn_seqlen, param.p_drop, param.philox_seed, diff --git a/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp b/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp index 85bffdcf9e..d6307128e0 100644 --- a/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp @@ -29,12 +29,13 @@ struct HstuBlockMasking max_uih_len = seqlen_; - max_uih_len -= contextual_seqlen - 1; + max_uih_len -= contextual_seqlen > 0 ? contextual_seqlen - 1 : 0; max_uih_len -= num_target; }; // to get the loop length along X axis, return index:[start, end), end-start=length // use this if need loop over X axis tile by tile (eg. seqlen_k loop-over) + // i_y is the start offset of the current tile along the seqlen_q dimension template CK_TILE_HOST_DEVICE constexpr auto GetTileRangeAlongX(index_t i_y, number, number) const @@ -45,7 +46,7 @@ struct HstuBlockMasking } else { - if(contextual_seqlen > 0 && (i_y < contextual_seqlen)) + if(i_y < contextual_seqlen) return ck_tile::make_tuple(0, max_uih_len); if constexpr(kUseCausal && !kUseLocal) @@ -101,10 +102,10 @@ struct HstuBlockMasking if constexpr(kUseCausal) result = (row >= col) && (row - col <= max_attn_len); else - result = std::abs(row - col) <= max_attn_len; + result = abs(row - col) <= max_attn_len; if(min_full_attn_seqlen > 0) - result = result || (row >= max_uih_len - min_full_attn_seqlen); + result = (row >= max_uih_len - min_full_attn_seqlen); } else { diff --git a/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp b/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp index 96e7922eff..f87b7c2ff8 100644 --- a/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp +++ b/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp @@ -33,6 +33,9 @@ template struct reference_hstu_attention { + using HstuMask = HstuBlockMasking; + static constexpr bool kHasMask = kUseCausal || kUseLocal; + static void Run(const HostTensor& q_batch_seq_nhead_hdim, const HostTensor& k_batch_seq_nhead_hdim, const HostTensor& v_batch_seq_nhead_hdim, @@ -100,8 +103,13 @@ struct reference_hstu_attention int num_target = num_targets.empty() ? 0 : num_targets[i_batch]; - HstuBlockMasking mask{ - max_attn_len, contextual_seqlen, min_full_attn_seqlen, seqlen, num_target}; + HstuMask mask = [&]() { + if constexpr(kHasMask) + return HstuMask{ + max_attn_len, contextual_seqlen, min_full_attn_seqlen, seqlen, num_target}; + else + return HstuMask{0, contextual_seqlen, 0, seqlen, num_target}; + }(); // for all rows in the batch for(int sq = 0; sq < seqlen; sq++)