From 333abddbaeb3b7e9a476e5bdd046ff844ab66ffe Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Thu, 28 May 2026 08:07:54 +0000 Subject: [PATCH] Rename the reference interfaces and the files --- .../ck_tile/18_hstu_attention/CMakeLists.txt | 2 +- ...ion.cpp => example_hstu_attention_fwd.cpp} | 93 ++++++++++--------- ...n.hpp => reference_hstu_attention_fwd.hpp} | 6 +- 3 files changed, 51 insertions(+), 50 deletions(-) rename example/ck_tile/18_hstu_attention/{example_hstu_attention.cpp => example_hstu_attention_fwd.cpp} (94%) rename example/ck_tile/18_hstu_attention/{reference_hstu_attention.hpp => reference_hstu_attention_fwd.hpp} (99%) diff --git a/example/ck_tile/18_hstu_attention/CMakeLists.txt b/example/ck_tile/18_hstu_attention/CMakeLists.txt index 8d40e5903c..d7690ec66f 100644 --- a/example/ck_tile/18_hstu_attention/CMakeLists.txt +++ b/example/ck_tile/18_hstu_attention/CMakeLists.txt @@ -4,7 +4,7 @@ set(EXAMPLE_HSTU_ATTENTION "tile_example_hstu_attention") message("adding example ${EXAMPLE_HSTU_ATTENTION}") file(GLOB INSTANCE_SRCS instances/*.cpp) set(INTERFACES_SRCS hstu_attention_no_group_forward_bf16.cpp hstu_attention_no_group_forward_fp16.cpp hstu_attention_group_forward_bf16.cpp hstu_attention_group_forward_fp16.cpp) -add_executable(${EXAMPLE_HSTU_ATTENTION} EXCLUDE_FROM_ALL example_hstu_attention.cpp) +add_executable(${EXAMPLE_HSTU_ATTENTION} EXCLUDE_FROM_ALL example_hstu_attention_fwd.cpp) target_include_directories(${EXAMPLE_HSTU_ATTENTION} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) target_sources(${EXAMPLE_HSTU_ATTENTION} PRIVATE ${INTERFACES_SRCS} ${INSTANCE_SRCS}) diff --git a/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp b/example/ck_tile/18_hstu_attention/example_hstu_attention_fwd.cpp similarity index 94% rename from example/ck_tile/18_hstu_attention/example_hstu_attention.cpp rename to example/ck_tile/18_hstu_attention/example_hstu_attention_fwd.cpp index 51ddce2e87..f86168c0c6 100644 --- a/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp +++ b/example/ck_tile/18_hstu_attention/example_hstu_attention_fwd.cpp @@ -25,7 +25,7 @@ #include "hstu_attention_fwd_type_config.hpp" #include "hstu_attention_bool_switch.hpp" #include "hstu_attention_params.hpp" -#include "reference_hstu_attention.hpp" +#include "reference_hstu_attention_fwd.hpp" #include "hstu_attention_util.hpp" #include "hstu_attention_api.hpp" @@ -576,28 +576,28 @@ bool run_no_group_hstu(const ck_tile::ArgParser& arg_parser, bool is_jagged) using CompDataType = typename HstuAttentionFwdTypeConfig::CompDataType; BOOL_SWITCH_3(is_jagged, kIsJagged, use_softmax, kUseSoftmax, use_causal, kUseCausal, [&] { - ck_tile::reference_no_group_hstu_attention::Run(is_cross_attention, - q_host, - k_host, - v_host, - o_host_ref, - mask_host, - num_batch, - scale_s, - attn_scale, - max_seqlen_q, - max_seqlen_kv, - seq_offsets_q, - seq_offsets_kv, - num_targets, - contextual_seqlen, - window_size, - min_full_attn_seqlen); + ck_tile::reference_no_group_hstu_attention_fwd::Run(is_cross_attention, + q_host, + k_host, + v_host, + o_host_ref, + mask_host, + num_batch, + scale_s, + attn_scale, + max_seqlen_q, + max_seqlen_kv, + seq_offsets_q, + seq_offsets_kv, + num_targets, + contextual_seqlen, + window_size, + min_full_attn_seqlen); }); ck_tile::HostTensor o_host( @@ -1003,29 +1003,30 @@ bool run_group_hstu(const ck_tile::ArgParser& arg_parser, int num_group) using CompDataType = typename HstuAttentionFwdTypeConfig::CompDataType; BOOL_SWITCH_2(use_softmax, kUseSoftmax, use_causal, kUseCausal, [&] { - ck_tile::reference_group_hstu_attention::Run(is_cross_attention, - q_host, - k_host, - v_host, - o_host_ref, - mask_host, - num_batch, - num_batch / num_group, - scale_s, - max_max_seqlen_q, - max_max_seqlen_kv, - seq_offsets_q, - seq_offsets_kv, - num_targets, - group_max_seqlens_q, - group_contextual_seqlens, - group_window_sizes, - group_min_full_attn_seqlens, - group_attn_scales); + ck_tile::reference_group_hstu_attention_fwd< + InOutDataType, + GemmAccDataType, + CompDataType, + kUseSoftmax, + kUseCausal>::Run(is_cross_attention, + q_host, + k_host, + v_host, + o_host_ref, + mask_host, + num_batch, + num_batch / num_group, + scale_s, + max_max_seqlen_q, + max_max_seqlen_kv, + seq_offsets_q, + seq_offsets_kv, + num_targets, + group_max_seqlens_q, + group_contextual_seqlens, + group_window_sizes, + group_min_full_attn_seqlens, + group_attn_scales); }); ck_tile::HostTensor o_host( diff --git a/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp b/example/ck_tile/18_hstu_attention/reference_hstu_attention_fwd.hpp similarity index 99% rename from example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp rename to example/ck_tile/18_hstu_attention/reference_hstu_attention_fwd.hpp index f5d6e80bd0..3af31c62b9 100644 --- a/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp +++ b/example/ck_tile/18_hstu_attention/reference_hstu_attention_fwd.hpp @@ -17,7 +17,7 @@ namespace ck_tile { // clang-format off -// Reference implementation of HSTUAttention problem, which does the following from input tensors: +// Reference implementation of HSTUAttention forward problem, which does the following from input tensors: // S[num_batch, num_head, seqlen, seqlen] = Q[num_batch, seqlen, num_head, hdim_qk] @ key^T[num_batch, seqlen, num_head, hdim_v] // P[num_batch, num_head, seqlen, seqlen] = SiLU(Masking(S[num_batch, num_head, seqlen, seqlen])) // O[num_batch, num_head, seqlen, hdim_v] = P[num_batch, num_head, seqlen, seqlen] @ value^T[num_batch, num_head, seqlen, hdim_v] @@ -31,7 +31,7 @@ template -struct reference_no_group_hstu_attention +struct reference_no_group_hstu_attention_fwd { static void Run(bool is_cross_attention, const HostTensor& q_batch_seq_nhead_hdim, @@ -321,7 +321,7 @@ template -struct reference_group_hstu_attention +struct reference_group_hstu_attention_fwd { static void Run(bool is_cross_attention,