Rename the reference interfaces and the files

This commit is contained in:
Qianfeng Zhang
2026-05-28 08:07:54 +00:00
parent e841981ddd
commit 333abddbae
3 changed files with 51 additions and 50 deletions

View File

@@ -4,7 +4,7 @@ set(EXAMPLE_HSTU_ATTENTION "tile_example_hstu_attention")
message("adding example ${EXAMPLE_HSTU_ATTENTION}")
file(GLOB INSTANCE_SRCS instances/*.cpp)
set(INTERFACES_SRCS hstu_attention_no_group_forward_bf16.cpp hstu_attention_no_group_forward_fp16.cpp hstu_attention_group_forward_bf16.cpp hstu_attention_group_forward_fp16.cpp)
add_executable(${EXAMPLE_HSTU_ATTENTION} EXCLUDE_FROM_ALL example_hstu_attention.cpp)
add_executable(${EXAMPLE_HSTU_ATTENTION} EXCLUDE_FROM_ALL example_hstu_attention_fwd.cpp)
target_include_directories(${EXAMPLE_HSTU_ATTENTION} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
target_sources(${EXAMPLE_HSTU_ATTENTION} PRIVATE ${INTERFACES_SRCS} ${INSTANCE_SRCS})

View File

@@ -25,7 +25,7 @@
#include "hstu_attention_fwd_type_config.hpp"
#include "hstu_attention_bool_switch.hpp"
#include "hstu_attention_params.hpp"
#include "reference_hstu_attention.hpp"
#include "reference_hstu_attention_fwd.hpp"
#include "hstu_attention_util.hpp"
#include "hstu_attention_api.hpp"
@@ -576,28 +576,28 @@ bool run_no_group_hstu(const ck_tile::ArgParser& arg_parser, bool is_jagged)
using CompDataType = typename HstuAttentionFwdTypeConfig<InOutDataType>::CompDataType;
BOOL_SWITCH_3(is_jagged, kIsJagged, use_softmax, kUseSoftmax, use_causal, kUseCausal, [&] {
ck_tile::reference_no_group_hstu_attention<InOutDataType,
GemmAccDataType,
CompDataType,
kIsJagged,
kUseSoftmax,
kUseCausal>::Run(is_cross_attention,
q_host,
k_host,
v_host,
o_host_ref,
mask_host,
num_batch,
scale_s,
attn_scale,
max_seqlen_q,
max_seqlen_kv,
seq_offsets_q,
seq_offsets_kv,
num_targets,
contextual_seqlen,
window_size,
min_full_attn_seqlen);
ck_tile::reference_no_group_hstu_attention_fwd<InOutDataType,
GemmAccDataType,
CompDataType,
kIsJagged,
kUseSoftmax,
kUseCausal>::Run(is_cross_attention,
q_host,
k_host,
v_host,
o_host_ref,
mask_host,
num_batch,
scale_s,
attn_scale,
max_seqlen_q,
max_seqlen_kv,
seq_offsets_q,
seq_offsets_kv,
num_targets,
contextual_seqlen,
window_size,
min_full_attn_seqlen);
});
ck_tile::HostTensor<InOutDataType> o_host(
@@ -1003,29 +1003,30 @@ bool run_group_hstu(const ck_tile::ArgParser& arg_parser, int num_group)
using CompDataType = typename HstuAttentionFwdTypeConfig<InOutDataType>::CompDataType;
BOOL_SWITCH_2(use_softmax, kUseSoftmax, use_causal, kUseCausal, [&] {
ck_tile::reference_group_hstu_attention<InOutDataType,
GemmAccDataType,
CompDataType,
kUseSoftmax,
kUseCausal>::Run(is_cross_attention,
q_host,
k_host,
v_host,
o_host_ref,
mask_host,
num_batch,
num_batch / num_group,
scale_s,
max_max_seqlen_q,
max_max_seqlen_kv,
seq_offsets_q,
seq_offsets_kv,
num_targets,
group_max_seqlens_q,
group_contextual_seqlens,
group_window_sizes,
group_min_full_attn_seqlens,
group_attn_scales);
ck_tile::reference_group_hstu_attention_fwd<
InOutDataType,
GemmAccDataType,
CompDataType,
kUseSoftmax,
kUseCausal>::Run(is_cross_attention,
q_host,
k_host,
v_host,
o_host_ref,
mask_host,
num_batch,
num_batch / num_group,
scale_s,
max_max_seqlen_q,
max_max_seqlen_kv,
seq_offsets_q,
seq_offsets_kv,
num_targets,
group_max_seqlens_q,
group_contextual_seqlens,
group_window_sizes,
group_min_full_attn_seqlens,
group_attn_scales);
});
ck_tile::HostTensor<InOutDataType> o_host(

View File

@@ -17,7 +17,7 @@
namespace ck_tile {
// clang-format off
// Reference implementation of HSTUAttention problem, which does the following from input tensors:
// Reference implementation of HSTUAttention forward problem, which does the following from input tensors:
// S[num_batch, num_head, seqlen, seqlen] = Q[num_batch, seqlen, num_head, hdim_qk] @ key^T[num_batch, seqlen, num_head, hdim_v]
// P[num_batch, num_head, seqlen, seqlen] = SiLU(Masking(S[num_batch, num_head, seqlen, seqlen]))
// O[num_batch, num_head, seqlen, hdim_v] = P[num_batch, num_head, seqlen, seqlen] @ value^T[num_batch, num_head, seqlen, hdim_v]
@@ -31,7 +31,7 @@ template <typename InOutDataType,
bool kIsJagged,
bool kUseSoftmax,
bool kUseCausal>
struct reference_no_group_hstu_attention
struct reference_no_group_hstu_attention_fwd
{
static void Run(bool is_cross_attention,
const HostTensor<InOutDataType>& q_batch_seq_nhead_hdim,
@@ -321,7 +321,7 @@ template <typename InOutDataType,
typename CompDataType,
bool kUseSoftmax,
bool kUseCausal>
struct reference_group_hstu_attention
struct reference_group_hstu_attention_fwd
{
static void
Run(bool is_cross_attention,