mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 12:17:00 +00:00
Rename the reference interfaces and the files
This commit is contained in:
@@ -4,7 +4,7 @@ set(EXAMPLE_HSTU_ATTENTION "tile_example_hstu_attention")
|
||||
message("adding example ${EXAMPLE_HSTU_ATTENTION}")
|
||||
file(GLOB INSTANCE_SRCS instances/*.cpp)
|
||||
set(INTERFACES_SRCS hstu_attention_no_group_forward_bf16.cpp hstu_attention_no_group_forward_fp16.cpp hstu_attention_group_forward_bf16.cpp hstu_attention_group_forward_fp16.cpp)
|
||||
add_executable(${EXAMPLE_HSTU_ATTENTION} EXCLUDE_FROM_ALL example_hstu_attention.cpp)
|
||||
add_executable(${EXAMPLE_HSTU_ATTENTION} EXCLUDE_FROM_ALL example_hstu_attention_fwd.cpp)
|
||||
target_include_directories(${EXAMPLE_HSTU_ATTENTION} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
target_sources(${EXAMPLE_HSTU_ATTENTION} PRIVATE ${INTERFACES_SRCS} ${INSTANCE_SRCS})
|
||||
|
||||
|
||||
@@ -25,7 +25,7 @@
|
||||
#include "hstu_attention_fwd_type_config.hpp"
|
||||
#include "hstu_attention_bool_switch.hpp"
|
||||
#include "hstu_attention_params.hpp"
|
||||
#include "reference_hstu_attention.hpp"
|
||||
#include "reference_hstu_attention_fwd.hpp"
|
||||
|
||||
#include "hstu_attention_util.hpp"
|
||||
#include "hstu_attention_api.hpp"
|
||||
@@ -576,28 +576,28 @@ bool run_no_group_hstu(const ck_tile::ArgParser& arg_parser, bool is_jagged)
|
||||
using CompDataType = typename HstuAttentionFwdTypeConfig<InOutDataType>::CompDataType;
|
||||
|
||||
BOOL_SWITCH_3(is_jagged, kIsJagged, use_softmax, kUseSoftmax, use_causal, kUseCausal, [&] {
|
||||
ck_tile::reference_no_group_hstu_attention<InOutDataType,
|
||||
GemmAccDataType,
|
||||
CompDataType,
|
||||
kIsJagged,
|
||||
kUseSoftmax,
|
||||
kUseCausal>::Run(is_cross_attention,
|
||||
q_host,
|
||||
k_host,
|
||||
v_host,
|
||||
o_host_ref,
|
||||
mask_host,
|
||||
num_batch,
|
||||
scale_s,
|
||||
attn_scale,
|
||||
max_seqlen_q,
|
||||
max_seqlen_kv,
|
||||
seq_offsets_q,
|
||||
seq_offsets_kv,
|
||||
num_targets,
|
||||
contextual_seqlen,
|
||||
window_size,
|
||||
min_full_attn_seqlen);
|
||||
ck_tile::reference_no_group_hstu_attention_fwd<InOutDataType,
|
||||
GemmAccDataType,
|
||||
CompDataType,
|
||||
kIsJagged,
|
||||
kUseSoftmax,
|
||||
kUseCausal>::Run(is_cross_attention,
|
||||
q_host,
|
||||
k_host,
|
||||
v_host,
|
||||
o_host_ref,
|
||||
mask_host,
|
||||
num_batch,
|
||||
scale_s,
|
||||
attn_scale,
|
||||
max_seqlen_q,
|
||||
max_seqlen_kv,
|
||||
seq_offsets_q,
|
||||
seq_offsets_kv,
|
||||
num_targets,
|
||||
contextual_seqlen,
|
||||
window_size,
|
||||
min_full_attn_seqlen);
|
||||
});
|
||||
|
||||
ck_tile::HostTensor<InOutDataType> o_host(
|
||||
@@ -1003,29 +1003,30 @@ bool run_group_hstu(const ck_tile::ArgParser& arg_parser, int num_group)
|
||||
using CompDataType = typename HstuAttentionFwdTypeConfig<InOutDataType>::CompDataType;
|
||||
|
||||
BOOL_SWITCH_2(use_softmax, kUseSoftmax, use_causal, kUseCausal, [&] {
|
||||
ck_tile::reference_group_hstu_attention<InOutDataType,
|
||||
GemmAccDataType,
|
||||
CompDataType,
|
||||
kUseSoftmax,
|
||||
kUseCausal>::Run(is_cross_attention,
|
||||
q_host,
|
||||
k_host,
|
||||
v_host,
|
||||
o_host_ref,
|
||||
mask_host,
|
||||
num_batch,
|
||||
num_batch / num_group,
|
||||
scale_s,
|
||||
max_max_seqlen_q,
|
||||
max_max_seqlen_kv,
|
||||
seq_offsets_q,
|
||||
seq_offsets_kv,
|
||||
num_targets,
|
||||
group_max_seqlens_q,
|
||||
group_contextual_seqlens,
|
||||
group_window_sizes,
|
||||
group_min_full_attn_seqlens,
|
||||
group_attn_scales);
|
||||
ck_tile::reference_group_hstu_attention_fwd<
|
||||
InOutDataType,
|
||||
GemmAccDataType,
|
||||
CompDataType,
|
||||
kUseSoftmax,
|
||||
kUseCausal>::Run(is_cross_attention,
|
||||
q_host,
|
||||
k_host,
|
||||
v_host,
|
||||
o_host_ref,
|
||||
mask_host,
|
||||
num_batch,
|
||||
num_batch / num_group,
|
||||
scale_s,
|
||||
max_max_seqlen_q,
|
||||
max_max_seqlen_kv,
|
||||
seq_offsets_q,
|
||||
seq_offsets_kv,
|
||||
num_targets,
|
||||
group_max_seqlens_q,
|
||||
group_contextual_seqlens,
|
||||
group_window_sizes,
|
||||
group_min_full_attn_seqlens,
|
||||
group_attn_scales);
|
||||
});
|
||||
|
||||
ck_tile::HostTensor<InOutDataType> o_host(
|
||||
@@ -17,7 +17,7 @@
|
||||
namespace ck_tile {
|
||||
|
||||
// clang-format off
|
||||
// Reference implementation of HSTUAttention problem, which does the following from input tensors:
|
||||
// Reference implementation of HSTUAttention forward problem, which does the following from input tensors:
|
||||
// S[num_batch, num_head, seqlen, seqlen] = Q[num_batch, seqlen, num_head, hdim_qk] @ key^T[num_batch, seqlen, num_head, hdim_v]
|
||||
// P[num_batch, num_head, seqlen, seqlen] = SiLU(Masking(S[num_batch, num_head, seqlen, seqlen]))
|
||||
// O[num_batch, num_head, seqlen, hdim_v] = P[num_batch, num_head, seqlen, seqlen] @ value^T[num_batch, num_head, seqlen, hdim_v]
|
||||
@@ -31,7 +31,7 @@ template <typename InOutDataType,
|
||||
bool kIsJagged,
|
||||
bool kUseSoftmax,
|
||||
bool kUseCausal>
|
||||
struct reference_no_group_hstu_attention
|
||||
struct reference_no_group_hstu_attention_fwd
|
||||
{
|
||||
static void Run(bool is_cross_attention,
|
||||
const HostTensor<InOutDataType>& q_batch_seq_nhead_hdim,
|
||||
@@ -321,7 +321,7 @@ template <typename InOutDataType,
|
||||
typename CompDataType,
|
||||
bool kUseSoftmax,
|
||||
bool kUseCausal>
|
||||
struct reference_group_hstu_attention
|
||||
struct reference_group_hstu_attention_fwd
|
||||
{
|
||||
static void
|
||||
Run(bool is_cross_attention,
|
||||
Reference in New Issue
Block a user