mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 10:37:44 +00:00
Detach HstuBlockMask from pipeline definition and construct the HstuBlockMask type in the kernel according to window_size
This commit is contained in:
@@ -26,17 +26,16 @@ HSTU_FORWARD_INSTANCE_TEMPLATE_INC = """
|
||||
"""
|
||||
|
||||
HSTU_FORWARD_INSTANCE_TEMPLATE = """
|
||||
{extern}template void run_{mode}_forward_causal_local_bias_dropout_dispatch<
|
||||
{extern}template void run_{mode}_forward_causal_bias_dropout_dispatch<
|
||||
{dtype},
|
||||
{has_causal},
|
||||
{has_local},
|
||||
{has_bias},
|
||||
{has_dropout},
|
||||
{max_k}>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
"""
|
||||
|
||||
HSTU_FORWARD_INSTANCE_FNAME = (
|
||||
"hstu_attention_{mode}_forward_{dtype_str}_{has_or_no_causal_str}_{has_or_no_local_str}_"
|
||||
"hstu_attention_{mode}_forward_{dtype_str}_{has_or_no_causal_str}_"
|
||||
"{has_or_no_bias_str}_{has_or_no_dropout_str}_{max_k_str}.cpp"
|
||||
)
|
||||
|
||||
@@ -49,11 +48,6 @@ BOOL_MAP_CAUSAL = {
|
||||
False: "no_causal",
|
||||
}
|
||||
|
||||
BOOL_MAP_LOCAL = {
|
||||
True: "has_local",
|
||||
False: "no_local",
|
||||
}
|
||||
|
||||
BOOL_MAP_BIAS = {
|
||||
True: "has_bias",
|
||||
False: "no_bias",
|
||||
@@ -84,7 +78,7 @@ MODE_NAME_MAP = {
|
||||
def create_forward_instances(instance_dir: Path, headdims: List) -> None:
|
||||
for mode in ["batched", "jagged"]:
|
||||
for dtype in ["fp16", "bf16"]:
|
||||
for has_causal, has_local in ([True, True], [True, False], [False, True], [False, False]):
|
||||
for has_causal in [True, False]:
|
||||
for has_bias in [True, False]:
|
||||
for has_dropout in [True, False]:
|
||||
for max_k in headdims:
|
||||
@@ -92,7 +86,6 @@ def create_forward_instances(instance_dir: Path, headdims: List) -> None:
|
||||
mode=mode,
|
||||
dtype_str=dtype,
|
||||
has_or_no_causal_str=BOOL_MAP_CAUSAL[has_causal],
|
||||
has_or_no_local_str=BOOL_MAP_LOCAL[has_local],
|
||||
has_or_no_bias_str=BOOL_MAP_BIAS[has_bias],
|
||||
has_or_no_dropout_str=BOOL_MAP_DROPOUT[has_dropout],
|
||||
max_k_str=INT_MAP_MAX_K[max_k],
|
||||
@@ -108,7 +101,6 @@ def create_forward_instances(instance_dir: Path, headdims: List) -> None:
|
||||
mode=mode,
|
||||
dtype=TYPE_CTYPE_MAP[dtype],
|
||||
has_causal=BOOL_MAP[has_causal],
|
||||
has_local=BOOL_MAP[has_local],
|
||||
has_bias=BOOL_MAP[has_bias],
|
||||
has_dropout=BOOL_MAP[has_dropout],
|
||||
max_k=max_k,
|
||||
@@ -140,14 +132,13 @@ def create_forward_instances_ref(instance_dir: Path, headdims: List) -> None:
|
||||
for max_k in headdims:
|
||||
for has_bias in [True, False]:
|
||||
for has_dropout in [True, False]:
|
||||
for has_causal, has_local in zip([True, False],[True, False]):
|
||||
for has_causal in [True, False]:
|
||||
forward_instance = (
|
||||
HSTU_FORWARD_INSTANCE_TEMPLATE.format(
|
||||
extern="extern ",
|
||||
mode=mode,
|
||||
dtype=TYPE_CTYPE_MAP[dtype],
|
||||
has_causal=BOOL_MAP[has_causal],
|
||||
has_local=BOOL_MAP[has_local],
|
||||
has_bias=BOOL_MAP[has_bias],
|
||||
has_dropout=BOOL_MAP[has_dropout],
|
||||
max_k=max_k,
|
||||
|
||||
@@ -17,24 +17,11 @@ void hstu_attention_batched_forward_bf16(HstuAttentionFwdParams& param, hipStrea
|
||||
const bool use_causal = param.use_causal;
|
||||
BOOL_SWITCH_3(has_bias, kHasBias, has_dropout, kHasDropout, use_causal, kUseCausal, [&] {
|
||||
HDIM_SWITCH(param.hdim_qk, param.hdim_v, MaxK, [&] {
|
||||
if(param.window_size > 0)
|
||||
{
|
||||
run_batched_forward_causal_local_bias_dropout_dispatch<ck_tile::bf16_t,
|
||||
kUseCausal,
|
||||
true,
|
||||
kHasBias,
|
||||
kHasDropout,
|
||||
MaxK>(param, stream);
|
||||
}
|
||||
else
|
||||
{
|
||||
run_batched_forward_causal_local_bias_dropout_dispatch<ck_tile::bf16_t,
|
||||
kUseCausal,
|
||||
false,
|
||||
kHasBias,
|
||||
kHasDropout,
|
||||
MaxK>(param, stream);
|
||||
};
|
||||
run_batched_forward_causal_bias_dropout_dispatch<ck_tile::bf16_t,
|
||||
kUseCausal,
|
||||
kHasBias,
|
||||
kHasDropout,
|
||||
MaxK>(param, stream);
|
||||
});
|
||||
});
|
||||
};
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
#include "hstu_attention_fwd_setting.hpp"
|
||||
#include "hstu_attention_params.hpp"
|
||||
#include "hstu_attention_hdim_switch.hpp"
|
||||
#include "hstu_block_masking.hpp"
|
||||
#include "hstu_attention_pipeline_problem.hpp"
|
||||
#include "hstu_attention_traits.hpp"
|
||||
#include "hstu_attention_fwd_pipeline.hpp"
|
||||
@@ -22,14 +21,12 @@
|
||||
|
||||
template <typename InOutDataType,
|
||||
bool kUseCausal,
|
||||
bool kUseLocal,
|
||||
bool kHasBias,
|
||||
bool kHasDropout,
|
||||
ck_tile::index_t MaxK>
|
||||
struct batched_forward_causal_local_bias_dropout_dispatch
|
||||
struct batched_forward_causal_bias_dropout_dispatch
|
||||
{
|
||||
using HstuAttentionTileSetting = typename HstuAttentionFwdTileSetting<MaxK>::Type;
|
||||
using HstuMask = typename ck_tile::HstuBlockMasking<kUseCausal, kUseLocal>::Type;
|
||||
|
||||
template <typename HstuTraits>
|
||||
using HstuPipelineProblemTemp = ck_tile::HstuAttentionFwdPipelineProblem<
|
||||
@@ -40,7 +37,7 @@ struct batched_forward_causal_local_bias_dropout_dispatch
|
||||
false, // kIsJagged
|
||||
kHasBias,
|
||||
kHasDropout,
|
||||
HstuMask,
|
||||
kUseCausal,
|
||||
HstuAttentionTileSetting,
|
||||
HstuTraits>;
|
||||
|
||||
@@ -140,17 +137,15 @@ struct batched_forward_causal_local_bias_dropout_dispatch
|
||||
|
||||
template <typename InOutDataType,
|
||||
bool kUseCausal,
|
||||
bool kUseLocal,
|
||||
bool kHasBias,
|
||||
bool kHasDropout,
|
||||
ck_tile::index_t MaxK>
|
||||
void run_batched_forward_causal_local_bias_dropout_dispatch(HstuAttentionFwdParams& param,
|
||||
hipStream_t stream)
|
||||
void run_batched_forward_causal_bias_dropout_dispatch(HstuAttentionFwdParams& param,
|
||||
hipStream_t stream)
|
||||
{
|
||||
batched_forward_causal_local_bias_dropout_dispatch<InOutDataType,
|
||||
kUseCausal,
|
||||
kUseLocal,
|
||||
kHasBias,
|
||||
kHasDropout,
|
||||
MaxK>::Run(param, stream);
|
||||
batched_forward_causal_bias_dropout_dispatch<InOutDataType,
|
||||
kUseCausal,
|
||||
kHasBias,
|
||||
kHasDropout,
|
||||
MaxK>::Run(param, stream);
|
||||
};
|
||||
|
||||
@@ -17,24 +17,11 @@ void hstu_attention_batched_forward_fp16(HstuAttentionFwdParams& param, hipStrea
|
||||
const bool use_causal = param.use_causal;
|
||||
BOOL_SWITCH_3(has_bias, kHasBias, has_dropout, kHasDropout, use_causal, kUseCausal, [&] {
|
||||
HDIM_SWITCH(param.hdim_qk, param.hdim_v, MaxK, [&] {
|
||||
if(param.window_size > 0)
|
||||
{
|
||||
run_batched_forward_causal_local_bias_dropout_dispatch<ck_tile::fp16_t,
|
||||
kUseCausal,
|
||||
true,
|
||||
kHasBias,
|
||||
kHasDropout,
|
||||
MaxK>(param, stream);
|
||||
}
|
||||
else
|
||||
{
|
||||
run_batched_forward_causal_local_bias_dropout_dispatch<ck_tile::fp16_t,
|
||||
kUseCausal,
|
||||
false,
|
||||
kHasBias,
|
||||
kHasDropout,
|
||||
MaxK>(param, stream);
|
||||
};
|
||||
run_batched_forward_causal_bias_dropout_dispatch<ck_tile::fp16_t,
|
||||
kUseCausal,
|
||||
kHasBias,
|
||||
kHasDropout,
|
||||
MaxK>(param, stream);
|
||||
});
|
||||
});
|
||||
};
|
||||
|
||||
@@ -41,15 +41,14 @@ struct HstuAttentionFwdKernel
|
||||
using BiasDataType = ck_tile::remove_cvref_t<typename HstuAttentionPipeline::BiasDataType>;
|
||||
using ODataType = ck_tile::remove_cvref_t<typename HstuAttentionPipeline::ODataType>;
|
||||
|
||||
static constexpr bool kIsJagged = HstuAttentionPipeline::kIsJagged;
|
||||
static constexpr bool kPadSeqLenQ = HstuAttentionPipeline::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = HstuAttentionPipeline::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQK = HstuAttentionPipeline::kPadHeadDimQK;
|
||||
static constexpr bool kPadHeadDimV = HstuAttentionPipeline::kPadHeadDimV;
|
||||
static constexpr auto kHasBias = HstuAttentionPipeline::kHasBias;
|
||||
static constexpr bool kHasDropout = HstuAttentionPipeline::kHasDropout;
|
||||
using HstuMask = ck_tile::remove_cvref_t<typename HstuAttentionPipeline::HstuMask>;
|
||||
static constexpr bool kHasLocalMask = HstuMask::kUseLocal;
|
||||
static constexpr bool kIsJagged = HstuAttentionPipeline::kIsJagged;
|
||||
static constexpr bool kPadSeqLenQ = HstuAttentionPipeline::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = HstuAttentionPipeline::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQK = HstuAttentionPipeline::kPadHeadDimQK;
|
||||
static constexpr bool kPadHeadDimV = HstuAttentionPipeline::kPadHeadDimV;
|
||||
static constexpr auto kHasBias = HstuAttentionPipeline::kHasBias;
|
||||
static constexpr bool kHasDropout = HstuAttentionPipeline::kHasDropout;
|
||||
static constexpr bool kHasCausalMask = HstuAttentionPipeline::kHasCausal;
|
||||
|
||||
template <ck_tile::index_t I> // to avoid duplicated base class problem, introduce an template
|
||||
// arg
|
||||
@@ -93,6 +92,8 @@ struct HstuAttentionFwdKernel
|
||||
float scale_p; // scaling value exerted on the SiLU result
|
||||
|
||||
ck_tile::index_t contextual_seqlen;
|
||||
ck_tile::index_t window_size;
|
||||
ck_tile::index_t min_full_attn_seqlen;
|
||||
};
|
||||
|
||||
struct HstuAttentionFwdJaggModeBaseKargs
|
||||
@@ -126,10 +127,6 @@ struct HstuAttentionFwdKernel
|
||||
float scale_p; // scaling value exerted on the SiLU result
|
||||
|
||||
ck_tile::index_t contextual_seqlen;
|
||||
};
|
||||
|
||||
struct HstuAttentionFwdMaskKargs
|
||||
{
|
||||
ck_tile::index_t window_size;
|
||||
ck_tile::index_t min_full_attn_seqlen;
|
||||
};
|
||||
@@ -170,9 +167,6 @@ struct HstuAttentionFwdKernel
|
||||
};
|
||||
|
||||
struct HstuAttentionFwdBatchModeKargs : HstuAttentionFwdBatchModeBaseKargs,
|
||||
std::conditional_t<kHasLocalMask,
|
||||
HstuAttentionFwdMaskKargs,
|
||||
HstuAttentionFwdEmptyKargs<0>>,
|
||||
std::conditional_t<kHasBias,
|
||||
HstuAttentionFwdBatchModeBiasKargs,
|
||||
HstuAttentionFwdEmptyKargs<1>>,
|
||||
@@ -183,9 +177,6 @@ struct HstuAttentionFwdKernel
|
||||
};
|
||||
|
||||
struct HstuAttentionFwdJaggModeKargs : HstuAttentionFwdJaggModeBaseKargs,
|
||||
std::conditional_t<kHasLocalMask,
|
||||
HstuAttentionFwdMaskKargs,
|
||||
HstuAttentionFwdEmptyKargs<0>>,
|
||||
std::conditional_t<kHasBias,
|
||||
HstuAttentionFwdCommonBiasKargs,
|
||||
HstuAttentionFwdEmptyKargs<1>>,
|
||||
@@ -258,17 +249,13 @@ struct HstuAttentionFwdKernel
|
||||
num_head,
|
||||
-scale_s,
|
||||
attn_scale ? attn_scale : 1.0f / static_cast<float>(seqlen), // max_seqlen
|
||||
contextual_seqlen}, // args for common karg
|
||||
{}, // placeholder for mask
|
||||
{}, // placeholder for bias
|
||||
{}, // placeholder for dropout
|
||||
contextual_seqlen,
|
||||
window_size,
|
||||
min_full_attn_seqlen}, // args for common karg
|
||||
{}, // placeholder for bias
|
||||
{}, // placeholder for dropout
|
||||
};
|
||||
|
||||
if constexpr(kHasLocalMask)
|
||||
{
|
||||
kargs.window_size = window_size;
|
||||
kargs.min_full_attn_seqlen = min_full_attn_seqlen;
|
||||
}
|
||||
if constexpr(kHasBias)
|
||||
{
|
||||
kargs.bias_ptr = bias_ptr;
|
||||
@@ -337,17 +324,13 @@ struct HstuAttentionFwdKernel
|
||||
num_head,
|
||||
-scale_s,
|
||||
attn_scale ? attn_scale : 1.0f / static_cast<float>(max_seqlen),
|
||||
contextual_seqlen}, // args for common karg
|
||||
{}, // placeholder for mask
|
||||
{}, // placeholder for bias
|
||||
{}, // placeholder for dropout
|
||||
contextual_seqlen,
|
||||
window_size,
|
||||
min_full_attn_seqlen}, // args for common karg
|
||||
{}, // placeholder for bias
|
||||
{}, // placeholder for dropout
|
||||
};
|
||||
|
||||
if constexpr(kHasLocalMask)
|
||||
{
|
||||
kargs.window_size = window_size;
|
||||
kargs.min_full_attn_seqlen = min_full_attn_seqlen;
|
||||
}
|
||||
if constexpr(kHasBias)
|
||||
{
|
||||
kargs.bias_ptr = bias_ptr;
|
||||
@@ -374,11 +357,8 @@ struct HstuAttentionFwdKernel
|
||||
ck_tile::index_t num_tile_in_seqlen =
|
||||
ck_tile::integer_divide_ceil(seqlen_, HstuAttentionPipeline::kM0);
|
||||
|
||||
if constexpr(kHasLocalMask)
|
||||
{
|
||||
if(has_minfull_attn_seqlen)
|
||||
num_tile_in_seqlen += 1;
|
||||
};
|
||||
if(has_minfull_attn_seqlen)
|
||||
num_tile_in_seqlen += 1;
|
||||
|
||||
if constexpr(HstuAttentionPipeline::kN1 < HstuAttentionPipeline::kSubQKHeaddim)
|
||||
{
|
||||
@@ -518,40 +498,34 @@ struct HstuAttentionFwdKernel
|
||||
bool is_tile_in_first_split = true;
|
||||
index_t i_m0;
|
||||
|
||||
if constexpr(kHasLocalMask)
|
||||
if(kargs.min_full_attn_seqlen > 0)
|
||||
{
|
||||
if(kargs.min_full_attn_seqlen > 0)
|
||||
// need consider for cases where min_full_attn_seqlen be bigger than max_uih_len
|
||||
if(kargs.seqlen - num_target > kargs.min_full_attn_seqlen)
|
||||
{
|
||||
// need consider for cases where min_full_attn_seqlen be bigger than max_uih_len
|
||||
if(kargs.seqlen - num_target > kargs.min_full_attn_seqlen)
|
||||
{
|
||||
seqlen_in_first_split = kargs.seqlen - num_target - kargs.min_full_attn_seqlen;
|
||||
seqlen_in_first_split = kargs.seqlen - num_target - kargs.min_full_attn_seqlen;
|
||||
|
||||
index_t num_tile_in_first_split = ck_tile::integer_divide_ceil(
|
||||
seqlen_in_first_split, HstuAttentionPipeline::kM0);
|
||||
index_t num_tile_in_first_split =
|
||||
ck_tile::integer_divide_ceil(seqlen_in_first_split, HstuAttentionPipeline::kM0);
|
||||
|
||||
is_tile_in_first_split = (i_tile_m < num_tile_in_first_split);
|
||||
is_tile_in_first_split = (i_tile_m < num_tile_in_first_split);
|
||||
|
||||
i_m0 =
|
||||
is_tile_in_first_split
|
||||
? __builtin_amdgcn_readfirstlane(i_tile_m * HstuAttentionPipeline::kM0)
|
||||
: __builtin_amdgcn_readfirstlane((i_tile_m - num_tile_in_first_split) *
|
||||
HstuAttentionPipeline::kM0) +
|
||||
seqlen_in_first_split;
|
||||
}
|
||||
else
|
||||
{
|
||||
seqlen_in_first_split = 0;
|
||||
is_tile_in_first_split = false;
|
||||
|
||||
// adjust the min_full_attn_seqlen to be passed to HstuBlockMask constructor
|
||||
kargs.min_full_attn_seqlen = kargs.seqlen - num_target;
|
||||
|
||||
i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * HstuAttentionPipeline::kM0);
|
||||
};
|
||||
i_m0 = is_tile_in_first_split
|
||||
? __builtin_amdgcn_readfirstlane(i_tile_m * HstuAttentionPipeline::kM0)
|
||||
: __builtin_amdgcn_readfirstlane((i_tile_m - num_tile_in_first_split) *
|
||||
HstuAttentionPipeline::kM0) +
|
||||
seqlen_in_first_split;
|
||||
}
|
||||
else
|
||||
{
|
||||
seqlen_in_first_split = 0;
|
||||
is_tile_in_first_split = false;
|
||||
|
||||
// adjust the min_full_attn_seqlen to be passed to HstuBlockMask constructor
|
||||
kargs.min_full_attn_seqlen = kargs.seqlen - num_target;
|
||||
|
||||
i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * HstuAttentionPipeline::kM0);
|
||||
};
|
||||
}
|
||||
else
|
||||
i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * HstuAttentionPipeline::kM0);
|
||||
@@ -563,19 +537,6 @@ struct HstuAttentionFwdKernel
|
||||
if(seqlen_q_in_ctrl <= i_m0)
|
||||
return;
|
||||
|
||||
HstuMask mask = [&]() {
|
||||
if constexpr(kHasLocalMask)
|
||||
return make_hstu_block_mask_with_local<HstuMask>(is_tile_in_first_split,
|
||||
kargs.seqlen,
|
||||
kargs.contextual_seqlen,
|
||||
num_target,
|
||||
kargs.window_size,
|
||||
kargs.min_full_attn_seqlen);
|
||||
else
|
||||
return make_hstu_block_mask_without_local<HstuMask>(
|
||||
kargs.seqlen, kargs.contextual_seqlen, num_target);
|
||||
}();
|
||||
|
||||
// for simplicity, batch stride we just modify the pointer
|
||||
const QKVDataType* q_ptr = reinterpret_cast<const QKVDataType*>(kargs.q_ptr) +
|
||||
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q +
|
||||
@@ -706,15 +667,44 @@ struct HstuAttentionFwdKernel
|
||||
}();
|
||||
|
||||
auto o_acc_tile = [&]() {
|
||||
return HstuAttentionPipeline{}(q_dram_window,
|
||||
k_dram_window,
|
||||
v_dram_window,
|
||||
bias_dram_window,
|
||||
mask,
|
||||
kargs.scale_s,
|
||||
kargs.scale_p,
|
||||
smem_ptr,
|
||||
dropout);
|
||||
if(kargs.window_size > 0)
|
||||
{
|
||||
using HstuMaskType = typename ck_tile::HstuBlockMasking<kHasCausalMask, true>::Type;
|
||||
const auto mask =
|
||||
make_hstu_block_mask_with_local<HstuMaskType>(is_tile_in_first_split,
|
||||
kargs.seqlen,
|
||||
kargs.contextual_seqlen,
|
||||
num_target,
|
||||
kargs.window_size,
|
||||
kargs.min_full_attn_seqlen);
|
||||
|
||||
return HstuAttentionPipeline{}(q_dram_window,
|
||||
k_dram_window,
|
||||
v_dram_window,
|
||||
bias_dram_window,
|
||||
mask,
|
||||
kargs.scale_s,
|
||||
kargs.scale_p,
|
||||
smem_ptr,
|
||||
dropout);
|
||||
}
|
||||
else
|
||||
{
|
||||
using HstuMaskType =
|
||||
typename ck_tile::HstuBlockMasking<kHasCausalMask, false>::Type;
|
||||
const auto mask = make_hstu_block_mask_without_local<HstuMaskType>(
|
||||
kargs.seqlen, kargs.contextual_seqlen, num_target);
|
||||
|
||||
return HstuAttentionPipeline{}(q_dram_window,
|
||||
k_dram_window,
|
||||
v_dram_window,
|
||||
bias_dram_window,
|
||||
mask,
|
||||
kargs.scale_s,
|
||||
kargs.scale_p,
|
||||
smem_ptr,
|
||||
dropout);
|
||||
};
|
||||
}();
|
||||
|
||||
// O DRAM and O DRAM window
|
||||
|
||||
@@ -21,7 +21,6 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
|
||||
using PDataType = remove_cvref_t<typename Problem::InOutDataType>;
|
||||
using ODataType = remove_cvref_t<typename Problem::InOutDataType>;
|
||||
using HstuMask = remove_cvref_t<typename Problem::HstuMask>;
|
||||
|
||||
using HstuAttentionTileSetting = remove_cvref_t<typename Problem::HstuAttentionTileSetting>;
|
||||
|
||||
@@ -40,6 +39,7 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
static constexpr bool kIsJagged = Problem::kIsJagged;
|
||||
static constexpr auto kHasBias = Problem::kHasBias;
|
||||
static constexpr bool kHasDropout = Problem::kHasDropout;
|
||||
static constexpr bool kHasCausal = Problem::kHasCausal;
|
||||
|
||||
static constexpr bool kPadSeqLenQ = Problem::Traits::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = Problem::Traits::kPadSeqLenK;
|
||||
@@ -118,7 +118,8 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
typename BiasElementFunction,
|
||||
typename SAccElementFunction,
|
||||
typename PComputeElementFunction,
|
||||
typename OAccElementFunction>
|
||||
typename OAccElementFunction,
|
||||
typename HstuMask>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile
|
||||
const QElementFunction& q_element_func,
|
||||
@@ -131,7 +132,7 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
const SAccElementFunction& s_acc_element_func,
|
||||
const PComputeElementFunction& p_compute_element_func,
|
||||
const OAccElementFunction& o_acc_element_func,
|
||||
HstuMask mask,
|
||||
HstuMask& mask,
|
||||
float scale_s, // scaling value exerted on the immediate Q@K result
|
||||
float scale_p, // scaling value exerted on the SiLu result
|
||||
void* smem_ptr,
|
||||
@@ -577,7 +578,8 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
template <typename QDramBlockWindowTmp,
|
||||
typename KDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp>
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename HstuMask>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
|
||||
|
||||
@@ -17,24 +17,11 @@ void hstu_attention_jagged_forward_bf16(HstuAttentionFwdParams& param, hipStream
|
||||
const bool use_causal = param.use_causal;
|
||||
BOOL_SWITCH_3(has_bias, kHasBias, has_dropout, kHasDropout, use_causal, kUseCausal, [&] {
|
||||
HDIM_SWITCH(param.hdim_qk, param.hdim_v, MaxK, [&] {
|
||||
if(param.window_size > 0)
|
||||
{
|
||||
run_jagged_forward_causal_local_bias_dropout_dispatch<ck_tile::bf16_t,
|
||||
kUseCausal,
|
||||
true,
|
||||
kHasBias,
|
||||
kHasDropout,
|
||||
MaxK>(param, stream);
|
||||
}
|
||||
else
|
||||
{
|
||||
run_jagged_forward_causal_local_bias_dropout_dispatch<ck_tile::bf16_t,
|
||||
kUseCausal,
|
||||
false,
|
||||
kHasBias,
|
||||
kHasDropout,
|
||||
MaxK>(param, stream);
|
||||
};
|
||||
run_jagged_forward_causal_bias_dropout_dispatch<ck_tile::bf16_t,
|
||||
kUseCausal,
|
||||
kHasBias,
|
||||
kHasDropout,
|
||||
MaxK>(param, stream);
|
||||
});
|
||||
});
|
||||
};
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
#include "hstu_attention_fwd_setting.hpp"
|
||||
#include "hstu_attention_params.hpp"
|
||||
#include "hstu_attention_hdim_switch.hpp"
|
||||
#include "hstu_block_masking.hpp"
|
||||
#include "hstu_attention_pipeline_problem.hpp"
|
||||
#include "hstu_attention_traits.hpp"
|
||||
#include "hstu_attention_fwd_pipeline.hpp"
|
||||
@@ -22,14 +21,12 @@
|
||||
|
||||
template <typename InOutDataType,
|
||||
bool kUseCausal,
|
||||
bool kUseLocal,
|
||||
bool kHasBias,
|
||||
bool kHasDropout,
|
||||
ck_tile::index_t MaxK>
|
||||
struct jagged_forward_causal_local_bias_dropout_dispatch
|
||||
struct jagged_forward_causal_bias_dropout_dispatch
|
||||
{
|
||||
using HstuAttentionTileSetting = typename HstuAttentionFwdTileSetting<MaxK>::Type;
|
||||
using HstuMask = typename ck_tile::HstuBlockMasking<kUseCausal, kUseLocal>::Type;
|
||||
|
||||
template <typename HstuTraits>
|
||||
using HstuPipelineProblemTemp = ck_tile::HstuAttentionFwdPipelineProblem<
|
||||
@@ -40,7 +37,7 @@ struct jagged_forward_causal_local_bias_dropout_dispatch
|
||||
true, // kIsJagged
|
||||
kHasBias,
|
||||
kHasDropout,
|
||||
HstuMask,
|
||||
kUseCausal,
|
||||
HstuAttentionTileSetting,
|
||||
HstuTraits>;
|
||||
|
||||
@@ -131,17 +128,15 @@ struct jagged_forward_causal_local_bias_dropout_dispatch
|
||||
|
||||
template <typename InOutDataType,
|
||||
bool kUseCausal,
|
||||
bool kUseLocal,
|
||||
bool kHasBias,
|
||||
bool kHasDropout,
|
||||
ck_tile::index_t MaxK>
|
||||
void run_jagged_forward_causal_local_bias_dropout_dispatch(HstuAttentionFwdParams& param,
|
||||
hipStream_t stream)
|
||||
void run_jagged_forward_causal_bias_dropout_dispatch(HstuAttentionFwdParams& param,
|
||||
hipStream_t stream)
|
||||
{
|
||||
jagged_forward_causal_local_bias_dropout_dispatch<InOutDataType,
|
||||
kUseCausal,
|
||||
kUseLocal,
|
||||
kHasBias,
|
||||
kHasDropout,
|
||||
MaxK>::Run(param, stream);
|
||||
jagged_forward_causal_bias_dropout_dispatch<InOutDataType,
|
||||
kUseCausal,
|
||||
kHasBias,
|
||||
kHasDropout,
|
||||
MaxK>::Run(param, stream);
|
||||
};
|
||||
|
||||
@@ -17,24 +17,11 @@ void hstu_attention_jagged_forward_fp16(HstuAttentionFwdParams& param, hipStream
|
||||
const bool use_causal = param.use_causal;
|
||||
BOOL_SWITCH_3(has_bias, kHasBias, has_dropout, kHasDropout, use_causal, kUseCausal, [&] {
|
||||
HDIM_SWITCH(param.hdim_qk, param.hdim_v, MaxK, [&] {
|
||||
if(param.window_size > 0)
|
||||
{
|
||||
run_jagged_forward_causal_local_bias_dropout_dispatch<ck_tile::fp16_t,
|
||||
kUseCausal,
|
||||
true,
|
||||
kHasBias,
|
||||
kHasDropout,
|
||||
MaxK>(param, stream);
|
||||
}
|
||||
else
|
||||
{
|
||||
run_jagged_forward_causal_local_bias_dropout_dispatch<ck_tile::fp16_t,
|
||||
kUseCausal,
|
||||
false,
|
||||
kHasBias,
|
||||
kHasDropout,
|
||||
MaxK>(param, stream);
|
||||
};
|
||||
run_jagged_forward_causal_bias_dropout_dispatch<ck_tile::fp16_t,
|
||||
kUseCausal,
|
||||
kHasBias,
|
||||
kHasDropout,
|
||||
MaxK>(param, stream);
|
||||
});
|
||||
});
|
||||
};
|
||||
|
||||
@@ -20,7 +20,7 @@ template <typename InOutDataType_,
|
||||
bool kIsJagged_,
|
||||
bool kHasBias_,
|
||||
bool kHasDropout_,
|
||||
typename HstuMask_, // encoding Causal and Local, contextual masking
|
||||
bool kHasCausal_,
|
||||
typename AttentionTileSetting_,
|
||||
typename Traits_>
|
||||
struct HstuAttentionFwdPipelineProblem
|
||||
@@ -41,8 +41,7 @@ struct HstuAttentionFwdPipelineProblem
|
||||
static constexpr bool kIsJagged = kIsJagged_;
|
||||
static constexpr bool kHasBias = kHasBias_;
|
||||
static constexpr bool kHasDropout = kHasDropout_;
|
||||
|
||||
using HstuMask = remove_cvref_t<HstuMask_>;
|
||||
static constexpr bool kHasCausal = kHasCausal_;
|
||||
|
||||
using HstuAttentionTileSetting = remove_cvref_t<AttentionTileSetting_>;
|
||||
|
||||
|
||||
@@ -9,10 +9,9 @@
|
||||
#include <ck_tile/core/numeric/half.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
128>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -9,10 +9,9 @@
|
||||
#include <ck_tile/core/numeric/half.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
256>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -9,10 +9,9 @@
|
||||
#include <ck_tile/core/numeric/half.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
64>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -9,10 +9,9 @@
|
||||
#include <ck_tile/core/numeric/half.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
128>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -9,10 +9,9 @@
|
||||
#include <ck_tile/core/numeric/half.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
256>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -9,10 +9,9 @@
|
||||
#include <ck_tile/core/numeric/half.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
true,
|
||||
64>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -1,18 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// The file is automatically generated, don't modify!
|
||||
// See the generator script
|
||||
// `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py`
|
||||
|
||||
#include <ck_tile/core/numeric/half.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
128>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -1,18 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// The file is automatically generated, don't modify!
|
||||
// See the generator script
|
||||
// `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py`
|
||||
|
||||
#include <ck_tile/core/numeric/half.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
true,
|
||||
256>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -1,18 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// The file is automatically generated, don't modify!
|
||||
// See the generator script
|
||||
// `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py`
|
||||
|
||||
#include <ck_tile/core/numeric/half.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
64>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -9,10 +9,9 @@
|
||||
#include <ck_tile/core/numeric/half.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
true,
|
||||
128>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -9,10 +9,9 @@
|
||||
#include <ck_tile/core/numeric/half.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
true,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
256>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -9,10 +9,9 @@
|
||||
#include <ck_tile/core/numeric/half.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
true,
|
||||
64>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -9,10 +9,9 @@
|
||||
#include <ck_tile/core/numeric/half.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
true,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
128>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -9,10 +9,9 @@
|
||||
#include <ck_tile/core/numeric/half.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
true,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
256>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -9,10 +9,9 @@
|
||||
#include <ck_tile/core/numeric/half.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
true,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
64>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -1,18 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// The file is automatically generated, don't modify!
|
||||
// See the generator script
|
||||
// `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py`
|
||||
|
||||
#include <ck_tile/core/numeric/half.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
128>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -1,18 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// The file is automatically generated, don't modify!
|
||||
// See the generator script
|
||||
// `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py`
|
||||
|
||||
#include <ck_tile/core/numeric/half.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
256>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -1,18 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// The file is automatically generated, don't modify!
|
||||
// See the generator script
|
||||
// `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py`
|
||||
|
||||
#include <ck_tile/core/numeric/half.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
64>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -9,194 +9,170 @@
|
||||
#include <ck_tile/core/numeric/half.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
extern template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
64>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
extern template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
64>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
extern template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
64>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
extern template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
64>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
extern template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
true,
|
||||
64>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
extern template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
64>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
extern template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
64>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
extern template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
64>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
extern template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
128>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
extern template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
128>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
extern template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
128>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
extern template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
128>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
extern template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
true,
|
||||
128>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
extern template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
128>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
extern template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
128>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
extern template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
128>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
extern template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
256>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
extern template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
256>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
extern template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
256>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
extern template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
256>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
extern template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
true,
|
||||
256>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
extern template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
256>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
extern template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
256>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
extern template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
256>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -9,10 +9,9 @@
|
||||
#include <ck_tile/core/numeric/half.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
128>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -9,10 +9,9 @@
|
||||
#include <ck_tile/core/numeric/half.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
256>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -9,9 +9,8 @@
|
||||
#include <ck_tile/core/numeric/half.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
true,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
@@ -9,10 +9,9 @@
|
||||
#include <ck_tile/core/numeric/half.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
128>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -9,10 +9,9 @@
|
||||
#include <ck_tile/core/numeric/half.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
256>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -9,10 +9,9 @@
|
||||
#include <ck_tile/core/numeric/half.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
64>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -1,18 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// The file is automatically generated, don't modify!
|
||||
// See the generator script
|
||||
// `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py`
|
||||
|
||||
#include <ck_tile/core/numeric/half.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
64>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -1,18 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// The file is automatically generated, don't modify!
|
||||
// See the generator script
|
||||
// `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py`
|
||||
|
||||
#include <ck_tile/core/numeric/half.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
128>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -1,18 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// The file is automatically generated, don't modify!
|
||||
// See the generator script
|
||||
// `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py`
|
||||
|
||||
#include <ck_tile/core/numeric/half.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
256>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -1,18 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// The file is automatically generated, don't modify!
|
||||
// See the generator script
|
||||
// `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py`
|
||||
|
||||
#include <ck_tile/core/numeric/half.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
true,
|
||||
128>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -1,18 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// The file is automatically generated, don't modify!
|
||||
// See the generator script
|
||||
// `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py`
|
||||
|
||||
#include <ck_tile/core/numeric/half.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
true,
|
||||
256>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -1,18 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// The file is automatically generated, don't modify!
|
||||
// See the generator script
|
||||
// `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py`
|
||||
|
||||
#include <ck_tile/core/numeric/half.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
true,
|
||||
64>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -9,10 +9,9 @@
|
||||
#include <ck_tile/core/numeric/half.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
128>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -9,10 +9,9 @@
|
||||
#include <ck_tile/core/numeric/half.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
256>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -9,10 +9,9 @@
|
||||
#include <ck_tile/core/numeric/half.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
true,
|
||||
64>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -9,10 +9,9 @@
|
||||
#include <ck_tile/core/numeric/half.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
128>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -9,10 +9,9 @@
|
||||
#include <ck_tile/core/numeric/half.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
256>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -9,10 +9,9 @@
|
||||
#include <ck_tile/core/numeric/half.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
64>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -1,18 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// The file is automatically generated, don't modify!
|
||||
// See the generator script
|
||||
// `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py`
|
||||
|
||||
#include <ck_tile/core/numeric/half.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
128>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -1,18 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// The file is automatically generated, don't modify!
|
||||
// See the generator script
|
||||
// `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py`
|
||||
|
||||
#include <ck_tile/core/numeric/half.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
256>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -1,18 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// The file is automatically generated, don't modify!
|
||||
// See the generator script
|
||||
// `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py`
|
||||
|
||||
#include <ck_tile/core/numeric/half.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
64>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -1,18 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// The file is automatically generated, don't modify!
|
||||
// See the generator script
|
||||
// `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py`
|
||||
|
||||
#include <ck_tile/core/numeric/half.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
128>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -1,18 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// The file is automatically generated, don't modify!
|
||||
// See the generator script
|
||||
// `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py`
|
||||
|
||||
#include <ck_tile/core/numeric/half.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
256>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -1,18 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// The file is automatically generated, don't modify!
|
||||
// See the generator script
|
||||
// `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py`
|
||||
|
||||
#include <ck_tile/core/numeric/half.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
64>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -1,18 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// The file is automatically generated, don't modify!
|
||||
// See the generator script
|
||||
// `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py`
|
||||
|
||||
#include <ck_tile/core/numeric/half.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
128>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -1,18 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// The file is automatically generated, don't modify!
|
||||
// See the generator script
|
||||
// `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py`
|
||||
|
||||
#include <ck_tile/core/numeric/half.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
256>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -1,18 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// The file is automatically generated, don't modify!
|
||||
// See the generator script
|
||||
// `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py`
|
||||
|
||||
#include <ck_tile/core/numeric/half.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
64>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -1,18 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// The file is automatically generated, don't modify!
|
||||
// See the generator script
|
||||
// `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py`
|
||||
|
||||
#include <ck_tile/core/numeric/half.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
128>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -1,18 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// The file is automatically generated, don't modify!
|
||||
// See the generator script
|
||||
// `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py`
|
||||
|
||||
#include <ck_tile/core/numeric/half.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
256>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -1,18 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// The file is automatically generated, don't modify!
|
||||
// See the generator script
|
||||
// `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py`
|
||||
|
||||
#include <ck_tile/core/numeric/half.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
64>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -9,10 +9,9 @@
|
||||
#include <ck_tile/core/numeric/bfloat16.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::fp16_t,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
128>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -9,10 +9,9 @@
|
||||
#include <ck_tile/core/numeric/bfloat16.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::fp16_t,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
256>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -9,10 +9,9 @@
|
||||
#include <ck_tile/core/numeric/bfloat16.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::fp16_t,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
64>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -9,10 +9,9 @@
|
||||
#include <ck_tile/core/numeric/bfloat16.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::fp16_t,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
128>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -9,10 +9,9 @@
|
||||
#include <ck_tile/core/numeric/bfloat16.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::fp16_t,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
256>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -9,10 +9,9 @@
|
||||
#include <ck_tile/core/numeric/bfloat16.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::fp16_t,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
64>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -1,18 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// The file is automatically generated, don't modify!
|
||||
// See the generator script
|
||||
// `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py`
|
||||
|
||||
#include <ck_tile/core/numeric/bfloat16.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
ck_tile::fp16_t,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
64>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -1,18 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// The file is automatically generated, don't modify!
|
||||
// See the generator script
|
||||
// `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py`
|
||||
|
||||
#include <ck_tile/core/numeric/bfloat16.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
ck_tile::fp16_t,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
128>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -1,18 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// The file is automatically generated, don't modify!
|
||||
// See the generator script
|
||||
// `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py`
|
||||
|
||||
#include <ck_tile/core/numeric/bfloat16.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
ck_tile::fp16_t,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
256>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -1,18 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// The file is automatically generated, don't modify!
|
||||
// See the generator script
|
||||
// `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py`
|
||||
|
||||
#include <ck_tile/core/numeric/bfloat16.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
ck_tile::fp16_t,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
64>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -9,10 +9,9 @@
|
||||
#include <ck_tile/core/numeric/bfloat16.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::fp16_t,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
true,
|
||||
128>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -9,10 +9,9 @@
|
||||
#include <ck_tile/core/numeric/bfloat16.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::fp16_t,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
true,
|
||||
256>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -9,10 +9,9 @@
|
||||
#include <ck_tile/core/numeric/bfloat16.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::fp16_t,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
true,
|
||||
64>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -9,10 +9,9 @@
|
||||
#include <ck_tile/core/numeric/bfloat16.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::fp16_t,
|
||||
true,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
128>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -9,10 +9,9 @@
|
||||
#include <ck_tile/core/numeric/bfloat16.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::fp16_t,
|
||||
true,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
256>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -9,10 +9,9 @@
|
||||
#include <ck_tile/core/numeric/bfloat16.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::fp16_t,
|
||||
true,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
64>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -1,18 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// The file is automatically generated, don't modify!
|
||||
// See the generator script
|
||||
// `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py`
|
||||
|
||||
#include <ck_tile/core/numeric/bfloat16.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
ck_tile::fp16_t,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
128>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -1,18 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// The file is automatically generated, don't modify!
|
||||
// See the generator script
|
||||
// `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py`
|
||||
|
||||
#include <ck_tile/core/numeric/bfloat16.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
ck_tile::fp16_t,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
256>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -1,18 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// The file is automatically generated, don't modify!
|
||||
// See the generator script
|
||||
// `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py`
|
||||
|
||||
#include <ck_tile/core/numeric/bfloat16.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
ck_tile::fp16_t,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
64>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -9,194 +9,170 @@
|
||||
#include <ck_tile/core/numeric/bfloat16.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
extern template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::fp16_t,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
64>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
extern template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::fp16_t,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
64>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
extern template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::fp16_t,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
64>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
extern template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::fp16_t,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
64>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
extern template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::fp16_t,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
true,
|
||||
64>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
extern template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::fp16_t,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
64>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
extern template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::fp16_t,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
64>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
extern template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::fp16_t,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
64>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
extern template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::fp16_t,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
128>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
extern template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::fp16_t,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
128>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
extern template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::fp16_t,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
128>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
extern template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::fp16_t,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
128>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
extern template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::fp16_t,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
true,
|
||||
128>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
extern template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::fp16_t,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
128>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
extern template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::fp16_t,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
128>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
extern template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::fp16_t,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
128>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
extern template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::fp16_t,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
256>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
extern template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::fp16_t,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
256>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
extern template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::fp16_t,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
256>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
extern template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::fp16_t,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
256>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
extern template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::fp16_t,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
true,
|
||||
256>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
extern template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::fp16_t,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
256>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
extern template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::fp16_t,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
256>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
extern template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::fp16_t,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
256>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -9,10 +9,9 @@
|
||||
#include <ck_tile/core/numeric/bfloat16.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::fp16_t,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
128>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -9,10 +9,9 @@
|
||||
#include <ck_tile/core/numeric/bfloat16.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::fp16_t,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
256>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -9,10 +9,9 @@
|
||||
#include <ck_tile/core/numeric/bfloat16.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::fp16_t,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
64>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -9,10 +9,9 @@
|
||||
#include <ck_tile/core/numeric/bfloat16.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::fp16_t,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
128>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -9,10 +9,9 @@
|
||||
#include <ck_tile/core/numeric/bfloat16.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::fp16_t,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
256>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -9,10 +9,9 @@
|
||||
#include <ck_tile/core/numeric/bfloat16.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::fp16_t,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
64>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -1,18 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// The file is automatically generated, don't modify!
|
||||
// See the generator script
|
||||
// `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py`
|
||||
|
||||
#include <ck_tile/core/numeric/bfloat16.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
ck_tile::fp16_t,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
128>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -1,18 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// The file is automatically generated, don't modify!
|
||||
// See the generator script
|
||||
// `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py`
|
||||
|
||||
#include <ck_tile/core/numeric/bfloat16.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
ck_tile::fp16_t,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
256>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -1,18 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// The file is automatically generated, don't modify!
|
||||
// See the generator script
|
||||
// `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py`
|
||||
|
||||
#include <ck_tile/core/numeric/bfloat16.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
ck_tile::fp16_t,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
true,
|
||||
128>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -1,18 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// The file is automatically generated, don't modify!
|
||||
// See the generator script
|
||||
// `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py`
|
||||
|
||||
#include <ck_tile/core/numeric/bfloat16.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
ck_tile::fp16_t,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
true,
|
||||
256>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -1,18 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// The file is automatically generated, don't modify!
|
||||
// See the generator script
|
||||
// `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py`
|
||||
|
||||
#include <ck_tile/core/numeric/bfloat16.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
ck_tile::fp16_t,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
true,
|
||||
64>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -9,10 +9,9 @@
|
||||
#include <ck_tile/core/numeric/bfloat16.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::fp16_t,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
true,
|
||||
128>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -9,10 +9,9 @@
|
||||
#include <ck_tile/core/numeric/bfloat16.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::fp16_t,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
true,
|
||||
256>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -9,10 +9,9 @@
|
||||
#include <ck_tile/core/numeric/bfloat16.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::fp16_t,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
true,
|
||||
64>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -9,10 +9,9 @@
|
||||
#include <ck_tile/core/numeric/bfloat16.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::fp16_t,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
128>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -9,10 +9,9 @@
|
||||
#include <ck_tile/core/numeric/bfloat16.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::fp16_t,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
256>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -9,10 +9,9 @@
|
||||
#include <ck_tile/core/numeric/bfloat16.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
template void run_batched_forward_causal_bias_dropout_dispatch<
|
||||
ck_tile::fp16_t,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
64>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -1,18 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// The file is automatically generated, don't modify!
|
||||
// See the generator script
|
||||
// `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py`
|
||||
|
||||
#include <ck_tile/core/numeric/bfloat16.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
ck_tile::fp16_t,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
128>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -1,18 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// The file is automatically generated, don't modify!
|
||||
// See the generator script
|
||||
// `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py`
|
||||
|
||||
#include <ck_tile/core/numeric/bfloat16.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
ck_tile::fp16_t,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
256>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -1,18 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// The file is automatically generated, don't modify!
|
||||
// See the generator script
|
||||
// `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py`
|
||||
|
||||
#include <ck_tile/core/numeric/bfloat16.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
ck_tile::fp16_t,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
64>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -1,18 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// The file is automatically generated, don't modify!
|
||||
// See the generator script
|
||||
// `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py`
|
||||
|
||||
#include <ck_tile/core/numeric/bfloat16.hpp>
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
template void run_batched_forward_causal_local_bias_dropout_dispatch<
|
||||
ck_tile::fp16_t,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
128>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user