mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 10:09:41 +00:00
Update to support grouped mode hstu attention
This commit is contained in:
@@ -3,7 +3,7 @@ set(EXAMPLE_HSTU_ATTENTION "tile_example_hstu_attention")
|
||||
# to be included in "make all/install/check"
|
||||
message("adding example ${EXAMPLE_HSTU_ATTENTION}")
|
||||
file(GLOB INSTANCE_SRCS instances/*.cpp)
|
||||
set(INTERFACES_SRCS hstu_attention_jagged_forward_bf16.cpp hstu_attention_jagged_forward_fp16.cpp hstu_attention_batched_forward_bf16.cpp hstu_attention_batched_forward_fp16.cpp)
|
||||
set(INTERFACES_SRCS hstu_attention_no_group_forward_bf16.cpp hstu_attention_no_group_forward_fp16.cpp hstu_attention_group_forward_bf16.cpp hstu_attention_group_forward_fp16.cpp)
|
||||
add_executable(${EXAMPLE_HSTU_ATTENTION} EXCLUDE_FROM_ALL example_hstu_attention.cpp)
|
||||
target_include_directories(${EXAMPLE_HSTU_ATTENTION} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
target_sources(${EXAMPLE_HSTU_ATTENTION} PRIVATE ${INTERFACES_SRCS} ${INSTANCE_SRCS})
|
||||
|
||||
@@ -29,27 +29,33 @@
|
||||
|
||||
``` C++
|
||||
arg_parser.insert("v", "1", "whether to do CPU validation or not")
|
||||
.insert("g", "1", "num of attention group, bigger than 1 indicating group hstu")
|
||||
.insert("prec", "fp16", "data type. fp16/bf16")
|
||||
.insert("jagged", "0", "q/k/v batched sequence is jagged or not")
|
||||
.insert("b", "12", "batch size")
|
||||
.insert("b", "12", "number of batches")
|
||||
.insert("nhead", "4", "number of heads")
|
||||
.insert("hdim_qk", "64", "headdim size of Q/K")
|
||||
.insert("hdim_v", "64", "headdim size of V/O")
|
||||
.insert("seqlens", "400", "seqlen of single or all batches for query and key/value tensor, actually allocated seqlen will include the target of each batch and context_len")
|
||||
.insert("seqlens", "400", "uih seqlen of single or all batches for query and key/value tensor, actually allocated seqlen will include the target of each batch and context_len")
|
||||
.insert("max_seqlen", "0", "max uih_seqlen, can be ignored, or else must be equal or bigger than the maximum of all uih seqlens")
|
||||
.insert("targets", "16", "sequence length at the end of query/key token sequence that should be excluded from attention")
|
||||
.insert("g_max_seqlens", "0", "max uih_seqlen, can be ignored, or else must be equal or bigger than the maximum of all uih seqlens")
|
||||
.insert("targets", "", "sequence length at the end of query/key token sequence that should be excluded from attention")
|
||||
.insert("max_target", "0", "max target, can be ignored, or else must be equal to or bigger than the maximum of all targets")
|
||||
.insert("causal", "1", "enable causal mask or not")
|
||||
.insert("local_len", "5", "length of the diagonal window for enabling masking, value 0 to disable")
|
||||
.insert("g_local_lens", "5,", "list of all group's length of the diagonal window for enabling masking, value 0 to disable")
|
||||
.insert("context_len", "6", "sequence length at the begin of the query sequence the should be included for attention")
|
||||
.insert("g_context_lens", "6,", "list of all group's sequence length at the begin of the query sequence that should be included for attention")
|
||||
.insert("minfull_len", "6", "sequence length at the end of the query sequence that should be included for attention")
|
||||
.insert("init_qkv", "0", "initialize q, k, v tensor from local files q.dat, k.dat and v.data")
|
||||
.insert("g_minfull_lens", "6", "list of all groups' sequence length at the end of the query sequence that should be included for attention")
|
||||
.insert("seed", "13579", "seed by the uniform or normal distribution generator")
|
||||
.insert("norm_dist", "0", "if true, initialize the data in normal distribution, or else in uniform distribution")
|
||||
.insert("alpha", "0", "scale factor of S=Q@K. 0 means equal to 1/sqrt(hdim)")
|
||||
.insert("attn_scale", "0", "scale factor of SiLu(Q@K), 0 means using 1/max_seqlen for scaling")
|
||||
.insert("save_mask", "1", "save the mask tensor to disk by the CPU validation codes")
|
||||
.insert("perf", "0", "weather measure execution time or not");
|
||||
.insert("attn_scale", "0", "scale factor of SiLU(Q@K). 0 means using 1/max_seqlen for scaling")
|
||||
.insert("g_attn_scales", "1.0,", "list of all groups' scale factors of SiLU(Q@K). 0 means using 1/max_seqlen of the group for scaling")
|
||||
.insert("init_qkv", "0", "initialize q, k, v tensor from local files q.dat, k.dat and v.dat")
|
||||
.insert("save_mask", "0", "save the mask tensor to disk by the CPU validation codes")
|
||||
.insert("perf", "0", "whether to measure execution time or not")
|
||||
.insert("dump_output", "0", "dump both device and reference hstu attention outputs to files, only used when validation is true");
|
||||
```
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include <random>
|
||||
#include <cassert>
|
||||
|
||||
#include <ck_tile/host/host_tensor.hpp>
|
||||
#include <ck_tile/host/fill.hpp>
|
||||
@@ -90,9 +91,10 @@ auto create_args(int argc, char* argv[])
|
||||
|
||||
// clang-format off
|
||||
arg_parser.insert("v", "1", "weather do CPU validation or not")
|
||||
.insert("g", "1", "num of attention group, bigger than 1 indicating group hstu")
|
||||
.insert("prec", "fp16", "data type. fp16/bf16")
|
||||
.insert("jagged", "0", "q/k/v batched sequence is jagged or not")
|
||||
.insert("b", "12", "batch size")
|
||||
.insert("b", "12", "number of batches")
|
||||
.insert("nhead", "4", "number of heads")
|
||||
.insert("hdim_qk", "64", "headdim size of Q/K")
|
||||
.insert("hdim_v", "64", "headdim size of V/O")
|
||||
@@ -100,17 +102,22 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("seqlens_kv", "", "uih seqlen of single or all batches for key/value tensor, actually allocated seqlen will include the target of each batch and context_len")
|
||||
.insert("max_seqlen", "0", "max uih_seqlen, can be ignored, or else must be equal or bigger than the maximum of all uih seqlens")
|
||||
.insert("max_seqlen_kv", "0", "max uih_seqlen_kv, can be ignored, or else must be equal or bigger than the maximum of all uih seqlens")
|
||||
.insert("g_max_seqlens", "0", "max uih_seqlen, can be ignored, or else must be equal or bigger than the maximum of all uih seqlens")
|
||||
.insert("targets", "", "sequence length at the end of query/key token sequence that should be excluded from attention")
|
||||
.insert("max_target", "0", "max target, can be ignored, or else must be equal of bigger than the maximum of all targets")
|
||||
.insert("softmax", "0", "use softmax or not")
|
||||
.insert("causal", "1", "enable causal mask or not")
|
||||
.insert("local_len", "5", "length of the diagonal window for enabling masking, value 0 to disable")
|
||||
.insert("g_local_lens", "5,", "list of all group's length of the diagonal window for enabling masking, value 0 to disable")
|
||||
.insert("context_len", "6", "sequence length at the begin of the query sequence the should be included for attention")
|
||||
.insert("g_context_lens", "6,", "list of all group's sequence length at the begin of the query sequence that should be included for attention")
|
||||
.insert("minfull_len", "6", "sequence length at the end of the query sequence that should be included for attention")
|
||||
.insert("g_minfull_lens", "6", "list of all groups's sequence length at the end of the query sequence that should be included for attention")
|
||||
.insert("seed", "13579", "seed by the uniform or normal distribution generator")
|
||||
.insert("norm_dist", "0", "if true, initialize the data in normal distribution, or else in uniform distribution")
|
||||
.insert("alpha", "0", "scale factor of S=Q@K. 0 means equal to 1/sqrt(hdim)")
|
||||
.insert("attn_scale", "0", "scale factor of SiLU(Q@K). 0 means using 1/max_seqlen for scaling")
|
||||
.insert("g_attn_scales", "1.0,", "list of all groups's scale factors of S=@@K. 0 means using 1/max_seqlen of the group for scaling")
|
||||
.insert("init_qkv", "0", "initialize q, k, v tensor from local files q.dat, k.dat and v.data")
|
||||
.insert("save_mask", "0", "save the mask tensor to disk by the CPU validation codes")
|
||||
.insert("perf", "0", "weather measure execution time or not")
|
||||
@@ -121,35 +128,66 @@ auto create_args(int argc, char* argv[])
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
static std::vector<int> get_integers_from_string(std::string lengthsStr)
|
||||
static std::vector<int> get_integers_from_string(std::string srcStr)
|
||||
{
|
||||
std::vector<int> lengths;
|
||||
std::vector<int> integers;
|
||||
std::size_t pos = 0;
|
||||
std::size_t new_pos;
|
||||
|
||||
new_pos = lengthsStr.find(',', pos);
|
||||
new_pos = srcStr.find(',', pos);
|
||||
while(new_pos != std::string::npos)
|
||||
{
|
||||
std::string sliceStr = lengthsStr.substr(pos, new_pos - pos);
|
||||
std::string sliceStr = srcStr.substr(pos, new_pos - pos);
|
||||
|
||||
int len = std::stoi(sliceStr);
|
||||
|
||||
lengths.push_back(len);
|
||||
integers.push_back(len);
|
||||
|
||||
pos = new_pos + 1;
|
||||
new_pos = lengthsStr.find(',', pos);
|
||||
new_pos = srcStr.find(',', pos);
|
||||
};
|
||||
|
||||
std::string sliceStr = lengthsStr.substr(pos);
|
||||
std::string sliceStr = srcStr.substr(pos);
|
||||
|
||||
if(!sliceStr.empty())
|
||||
{
|
||||
int len = std::stoi(sliceStr);
|
||||
|
||||
lengths.push_back(len);
|
||||
integers.push_back(len);
|
||||
};
|
||||
|
||||
return (lengths);
|
||||
return (integers);
|
||||
};
|
||||
|
||||
// Parse a comma-separated list of floating-point values (e.g. "1.0,0.5,") into a vector.
//
// The string is split on ','; every non-empty token is converted with std::stof and appended
// in order. Empty tokens are skipped wherever they occur: a trailing comma (as in the default
// "1.0,") and doubled commas (as in "1.0,,2.0") both yield no element. Passing an empty token
// to std::stof would throw std::invalid_argument.
//
// A non-empty token that is not a valid number still makes std::stof throw
// (std::invalid_argument / std::out_of_range); that is left to propagate to the caller.
static std::vector<float> get_floats_from_string(std::string srcStr)
{
    std::vector<float> values;

    std::size_t pos = 0;

    // pos == srcStr.size() is a valid (empty) final token position, hence "<=".
    while(pos <= srcStr.size())
    {
        const std::size_t comma_pos = srcStr.find(',', pos);
        const std::size_t token_end =
            (comma_pos == std::string::npos) ? srcStr.size() : comma_pos;

        // only convert non-empty tokens
        if(token_end > pos)
            values.push_back(std::stof(srcStr.substr(pos, token_end - pos)));

        // no further separator: the token just handled was the last one
        if(comma_pos == std::string::npos)
            break;

        pos = comma_pos + 1;
    }

    return values;
}
|
||||
|
||||
template <typename T>
|
||||
@@ -164,42 +202,6 @@ void supplement_array_by_last_element(std::vector<T>& arr, int target_num_elemen
|
||||
};
|
||||
};
|
||||
|
||||
static void show_hstu_attention_fwd_param(std::ostream& os, HstuAttentionFwdParams& param)
|
||||
{
|
||||
if(param.is_jagged)
|
||||
{
|
||||
os << "Jagged inputs used! " << std::endl;
|
||||
os << "use causal: " << param.use_causal << std::endl;
|
||||
os << "Num of batches: " << param.num_batch << std::endl;
|
||||
os << "Num of heads: " << param.num_head << std::endl;
|
||||
os << "QK hdim: " << param.hdim_qk << " V hdim: " << param.hdim_v << std::endl;
|
||||
os << "Q/K/V/O seq stride: " << param.seq_stride_q << " " << param.seq_stride_k << " "
|
||||
<< param.seq_stride_v << " " << param.seq_stride_o << std::endl;
|
||||
os << "Q/K/V/O nhead stride: " << param.nhead_stride_q << " " << param.nhead_stride_k << " "
|
||||
<< param.nhead_stride_v << " " << param.nhead_stride_o << std::endl;
|
||||
os << "contextual_seqlen: " << param.contextual_seqlen << std::endl;
|
||||
os << "window_size: " << param.window_size << std::endl;
|
||||
os << "min_full_attn_seqlen: " << param.min_full_attn_seqlen << std::endl;
|
||||
}
|
||||
else
|
||||
{
|
||||
os << "Batched inputs used! " << std::endl;
|
||||
os << "use causal: " << param.use_causal << std::endl;
|
||||
os << "Num of batches: " << param.num_batch << std::endl;
|
||||
os << "Num of heads: " << param.num_head << std::endl;
|
||||
os << "QK hdim: " << param.hdim_qk << " V hdim: " << param.hdim_v << std::endl;
|
||||
os << "Q/K/V/O seq stride: " << param.seq_stride_q << " " << param.seq_stride_k << " "
|
||||
<< param.seq_stride_v << " " << param.seq_stride_o << std::endl;
|
||||
os << "Q/K/V/O nhead stride: " << param.nhead_stride_q << " " << param.nhead_stride_k << " "
|
||||
<< param.nhead_stride_v << " " << param.nhead_stride_o << std::endl;
|
||||
os << "Q/K/V/O batch stride: " << param.batch_stride_q << " " << param.batch_stride_k << " "
|
||||
<< param.batch_stride_v << " " << param.batch_stride_o << std::endl;
|
||||
os << "contextual_seqlen: " << param.contextual_seqlen << std::endl;
|
||||
os << "window_size: " << param.window_size << std::endl;
|
||||
os << "min_full_attn_seqlen: " << param.min_full_attn_seqlen << std::endl;
|
||||
};
|
||||
};
|
||||
|
||||
// threshold for different dtypes
|
||||
template <typename DataType>
|
||||
auto get_elimit()
|
||||
@@ -219,10 +221,9 @@ auto get_elimit<ck_tile::bf16_t>()
|
||||
}
|
||||
|
||||
template <typename InOutDataType>
|
||||
bool run(const ck_tile::ArgParser& arg_parser)
|
||||
bool run_no_group_hstu(const ck_tile::ArgParser& arg_parser, bool is_jagged)
|
||||
{
|
||||
bool do_validation = static_cast<bool>(arg_parser.get_int("v"));
|
||||
bool is_jagged = static_cast<bool>(arg_parser.get_int("jagged"));
|
||||
int num_batch = arg_parser.get_int("b");
|
||||
int num_head = arg_parser.get_int("nhead");
|
||||
int hdim_qk = arg_parser.get_int("hdim_qk");
|
||||
@@ -230,11 +231,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
bool use_softmax = static_cast<bool>(arg_parser.get_int("softmax"));
|
||||
bool use_causal = static_cast<bool>(arg_parser.get_int("causal"));
|
||||
|
||||
int window_size = arg_parser.get_int("local_len");
|
||||
|
||||
int contextual_seqlen = arg_parser.get_int("context_len");
|
||||
int min_full_attn_seqlen = arg_parser.get_int("minfull_len");
|
||||
|
||||
float alpha = arg_parser.get_float("alpha");
|
||||
float attn_scale = arg_parser.get_float("attn_scale");
|
||||
int seed = arg_parser.get_int("seed");
|
||||
@@ -245,8 +241,15 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
bool save_mask = static_cast<bool>(arg_parser.get_int("save_mask"));
|
||||
bool initialize_qkv = static_cast<bool>(arg_parser.get_int("init_qkv"));
|
||||
|
||||
std::string str_of_targets = arg_parser.get_str("targets");
|
||||
std::vector<int> num_targets = get_integers_from_string(str_of_targets);
|
||||
std::string str_of_integers;
|
||||
|
||||
str_of_integers = arg_parser.get_str("targets");
|
||||
std::vector<int> num_targets = get_integers_from_string(str_of_integers);
|
||||
|
||||
int window_size = arg_parser.get_int("local_len");
|
||||
|
||||
int contextual_seqlen = arg_parser.get_int("context_len");
|
||||
int min_full_attn_seqlen = arg_parser.get_int("minfull_len");
|
||||
|
||||
std::string str_of_lengths_q = arg_parser.get_str("seqlens");
|
||||
std::vector<int> seq_lengths_q = get_integers_from_string(str_of_lengths_q);
|
||||
@@ -265,16 +268,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
bool is_cross_attention = false;
|
||||
|
||||
if(!num_targets.empty())
|
||||
{
|
||||
// supplement num_targets using the last input value if user-provided lengths not enough
|
||||
supplement_array_by_last_element(num_targets, num_batch);
|
||||
|
||||
// only consider num_batch values even if more values are provided by the user
|
||||
for(int i = 0; i < num_batch; i++)
|
||||
max_target = max(max_target, num_targets[i]);
|
||||
};
|
||||
|
||||
HSTU_CHECK(!seq_lengths_q.empty(), "sequence lengths of q shoud be defined!");
|
||||
|
||||
// assume seq_lengths_kv is same as seq_lengths_q if not defined, or else when
|
||||
@@ -293,15 +286,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
if(is_jagged)
|
||||
{
|
||||
// supplement seq_lengths_q using the last input value if user-provided lengths not
|
||||
// enough
|
||||
// supplement seq_lengths_q using the last input value if user-provided lengths not enough
|
||||
supplement_array_by_last_element(seq_lengths_q, num_batch);
|
||||
|
||||
// supplement seq_lengths_kv using the last input value if user-provided lengths not
|
||||
// enough
|
||||
// supplement seq_lengths_kv using the last input value if user-provided lengths not enough
|
||||
supplement_array_by_last_element(seq_lengths_kv, num_batch);
|
||||
|
||||
// only consider num_batch values even if more values are provided by the user
|
||||
for(int i = 0; i < num_batch; i++)
|
||||
{
|
||||
max_uih_seqlen_q = max(max_uih_seqlen_q, seq_lengths_q[i]);
|
||||
@@ -316,6 +306,16 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
max_uih_seqlen_kv = seq_lengths_kv[0];
|
||||
};
|
||||
|
||||
if(!num_targets.empty())
|
||||
{
|
||||
// supplement num_targets using the last input value if user-provided lengths not enough
|
||||
supplement_array_by_last_element(num_targets, num_batch);
|
||||
|
||||
// only consider num_batch values even if more values are provided by the user
|
||||
for(int i = 0; i < num_batch; i++)
|
||||
max_target = max(max_target, num_targets[i]);
|
||||
};
|
||||
|
||||
// the user input of max_uih_seqlen can either be ignored or be bigger than all uih_seqlens
|
||||
// the user input of max_target can either be ignored or be bigger than all targets
|
||||
HSTU_CHECK(input_max_uih_seqlen_q <= 0 || input_max_uih_seqlen_q >= max_uih_seqlen_q,
|
||||
@@ -386,7 +386,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
long total_flops = 0;
|
||||
|
||||
// estimate the total flops occurred, ignoring the scaling and SILu
|
||||
// estimate the total flops occurred, ignoring the scaling and SiLu
|
||||
if(is_jagged)
|
||||
{
|
||||
for(int i = 0; i < num_batch; i++)
|
||||
@@ -467,7 +467,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
if(!num_targets.empty())
|
||||
num_targets_dev.ToDevice(num_targets.data());
|
||||
|
||||
HstuAttentionFwdParams params;
|
||||
HstuAttentionNoGroupFwdParams params;
|
||||
|
||||
float scale_s = (alpha != 0.f) ? alpha : 1.0f / std::sqrt(hdim_qk);
|
||||
|
||||
@@ -552,26 +552,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
params.philox_offset = 0UL;
|
||||
};
|
||||
|
||||
// show_hstu_attention_fwd_param(std::cout, params);
|
||||
std::ignore = show_hstu_attention_fwd_param;
|
||||
|
||||
hipStream_t stream;
|
||||
|
||||
HIP_CHECK_ERROR(hipStreamCreate(&stream));
|
||||
|
||||
if constexpr(std::is_same<InOutDataType, ck_tile::fp16_t>::value)
|
||||
{
|
||||
if(is_jagged)
|
||||
hstu_attention_jagged_forward_fp16(params, stream);
|
||||
else
|
||||
hstu_attention_batched_forward_fp16(params, stream);
|
||||
hstu_attention_no_group_forward_fp16(params, stream);
|
||||
}
|
||||
else if constexpr(std::is_same<InOutDataType, ck_tile::bf16_t>::value)
|
||||
{
|
||||
if(is_jagged)
|
||||
hstu_attention_jagged_forward_bf16(params, stream);
|
||||
else
|
||||
hstu_attention_batched_forward_bf16(params, stream);
|
||||
hstu_attention_no_group_forward_bf16(params, stream);
|
||||
}
|
||||
else
|
||||
throw std::runtime_error("Other data type is not supported at present!");
|
||||
@@ -584,28 +575,28 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
using CompDataType = typename HstuAttentionFwdTypeConfig<InOutDataType>::CompDataType;
|
||||
|
||||
BOOL_SWITCH_3(is_jagged, kIsJagged, use_softmax, kUseSoftmax, use_causal, kUseCausal, [&] {
|
||||
ck_tile::reference_hstu_attention<InOutDataType,
|
||||
GemmAccDataType,
|
||||
CompDataType,
|
||||
kIsJagged,
|
||||
kUseSoftmax,
|
||||
kUseCausal>::Run(is_cross_attention,
|
||||
q_host,
|
||||
k_host,
|
||||
v_host,
|
||||
o_host_ref,
|
||||
mask_host,
|
||||
num_batch,
|
||||
scale_s,
|
||||
attn_scale,
|
||||
max_seqlen_q,
|
||||
max_seqlen_kv,
|
||||
seq_offsets_q,
|
||||
seq_offsets_kv,
|
||||
num_targets,
|
||||
contextual_seqlen,
|
||||
window_size,
|
||||
min_full_attn_seqlen);
|
||||
ck_tile::reference_no_group_hstu_attention<InOutDataType,
|
||||
GemmAccDataType,
|
||||
CompDataType,
|
||||
kIsJagged,
|
||||
kUseSoftmax,
|
||||
kUseCausal>::Run(is_cross_attention,
|
||||
q_host,
|
||||
k_host,
|
||||
v_host,
|
||||
o_host_ref,
|
||||
mask_host,
|
||||
num_batch,
|
||||
scale_s,
|
||||
attn_scale,
|
||||
max_seqlen_q,
|
||||
max_seqlen_kv,
|
||||
seq_offsets_q,
|
||||
seq_offsets_kv,
|
||||
num_targets,
|
||||
contextual_seqlen,
|
||||
window_size,
|
||||
min_full_attn_seqlen);
|
||||
});
|
||||
|
||||
ck_tile::HostTensor<InOutDataType> o_host(
|
||||
@@ -638,17 +629,385 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
if constexpr(std::is_same<InOutDataType, ck_tile::fp16_t>::value)
|
||||
{
|
||||
if(is_jagged)
|
||||
hstu_attention_jagged_forward_fp16(params, stream);
|
||||
else
|
||||
hstu_attention_batched_forward_fp16(params, stream);
|
||||
hstu_attention_no_group_forward_fp16(params, stream);
|
||||
}
|
||||
else if constexpr(std::is_same<InOutDataType, ck_tile::bf16_t>::value)
|
||||
{
|
||||
if(is_jagged)
|
||||
hstu_attention_jagged_forward_bf16(params, stream);
|
||||
else
|
||||
hstu_attention_batched_forward_bf16(params, stream);
|
||||
hstu_attention_no_group_forward_bf16(params, stream);
|
||||
}
|
||||
}
|
||||
timer.stop(stream);
|
||||
|
||||
auto ms = timer.duration() / 10.f;
|
||||
|
||||
std::cout << "Average execution time of the hstu_attention operation is " << ms
|
||||
<< " milli-seconds, estimated TFLOPS is "
|
||||
<< (static_cast<float>(total_flops) / ms) / 1.0e9 << std::endl;
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
template <typename InOutDataType>
|
||||
bool run_group_hstu(const ck_tile::ArgParser& arg_parser, int num_group)
|
||||
{
|
||||
bool do_validation = static_cast<bool>(arg_parser.get_int("v"));
|
||||
|
||||
int num_batch = arg_parser.get_int("b");
|
||||
|
||||
HSTU_CHECK(num_group > 1, "ru_group_hstu should only be called when num_group > 1 !");
|
||||
HSTU_CHECK(num_batch > 0 && num_batch % num_group == 0,
|
||||
"number of batches should be a multi-fold value of num_group!");
|
||||
|
||||
int num_batch_per_group = num_batch / num_group;
|
||||
|
||||
int num_head = arg_parser.get_int("nhead");
|
||||
int hdim_qk = arg_parser.get_int("hdim_qk");
|
||||
int hdim_v = arg_parser.get_int("hdim_v");
|
||||
bool use_softmax = static_cast<bool>(arg_parser.get_int("softmax"));
|
||||
bool use_causal = static_cast<bool>(arg_parser.get_int("causal"));
|
||||
float alpha = arg_parser.get_float("alpha");
|
||||
int seed = arg_parser.get_int("seed");
|
||||
bool use_normal_dist = arg_parser.get_int("norm_dist");
|
||||
bool measure_perf = static_cast<bool>(arg_parser.get_int("perf"));
|
||||
bool dump_output = static_cast<bool>(arg_parser.get_int("dump_output"));
|
||||
|
||||
bool save_mask = static_cast<bool>(arg_parser.get_int("save_mask"));
|
||||
bool initialize_qkv = static_cast<bool>(arg_parser.get_int("init_qkv"));
|
||||
|
||||
std::string str_of_integers;
|
||||
|
||||
str_of_integers = arg_parser.get_str("targets");
|
||||
std::vector<int> num_targets = get_integers_from_string(str_of_integers);
|
||||
|
||||
std::string str_of_lengths_q = arg_parser.get_str("seqlens");
|
||||
std::vector<int> seq_lengths_q = get_integers_from_string(str_of_lengths_q);
|
||||
|
||||
std::string str_of_lengths_kv = arg_parser.get_str("seqlens_kv");
|
||||
std::vector<int> seq_lengths_kv = get_integers_from_string(str_of_lengths_kv);
|
||||
|
||||
bool is_cross_attention = false;
|
||||
|
||||
HSTU_CHECK(!seq_lengths_q.empty(), "sequence lengths shoud be defined!");
|
||||
|
||||
// assume seq_lengths_kv is same as seq_lengths_q if not defined, or else when
|
||||
// seq_lengths_kv is explicitly defined, we think the input case is a cross_attention case
|
||||
if(seq_lengths_kv.empty())
|
||||
seq_lengths_kv = seq_lengths_q;
|
||||
else
|
||||
is_cross_attention = true;
|
||||
|
||||
str_of_integers = arg_parser.get_str("g_max_seqlens");
|
||||
std::vector<int> group_max_seqlens = get_integers_from_string(str_of_integers);
|
||||
|
||||
HSTU_CHECK(!group_max_seqlens.empty(), "group window sizes shoud be defined!");
|
||||
|
||||
str_of_integers = arg_parser.get_str("g_context_lens");
|
||||
std::vector<int> group_contextual_seqlens = get_integers_from_string(str_of_integers);
|
||||
|
||||
HSTU_CHECK(!group_contextual_seqlens.empty(), "group contextual seqlens shoud be defined!");
|
||||
|
||||
str_of_integers = arg_parser.get_str("g_local_lens");
|
||||
std::vector<int> group_window_sizes = get_integers_from_string(str_of_integers);
|
||||
|
||||
HSTU_CHECK(!group_window_sizes.empty(), "group window sizes shoud be defined!");
|
||||
|
||||
str_of_integers = arg_parser.get_str("g_minfull_lens");
|
||||
std::vector<int> group_min_full_attn_seqlens = get_integers_from_string(str_of_integers);
|
||||
HSTU_CHECK(!group_min_full_attn_seqlens.empty(),
|
||||
"group min_full_attn seqlens shoud be defined!");
|
||||
|
||||
std::string str_of_floats = arg_parser.get_str("g_attn_scales");
|
||||
std::vector<float> group_attn_scales = get_floats_from_string(str_of_floats);
|
||||
HSTU_CHECK(!group_attn_scales.empty(), "group attn_scales shoud be defined!");
|
||||
|
||||
// supplement seq_lengths_q using the last input value if user-provided lengths not enough
|
||||
supplement_array_by_last_element(seq_lengths_q, num_batch);
|
||||
|
||||
// supplement seq_lengths_kv using the last input value if user-provided lengths not enough
|
||||
supplement_array_by_last_element(seq_lengths_kv, num_batch);
|
||||
|
||||
if(!num_targets.empty())
|
||||
{
|
||||
// supplement num_targets using the last input value if user-provided lengths not enough
|
||||
supplement_array_by_last_element(num_targets, num_batch);
|
||||
};
|
||||
|
||||
// supplement group_max_seqlens using the last input value if user-provided lengths not enough
|
||||
supplement_array_by_last_element(group_max_seqlens, num_group);
|
||||
|
||||
// supplement group_contextual_seqlens using the last input value if user-provided lengths not
|
||||
// enough
|
||||
supplement_array_by_last_element(group_contextual_seqlens, num_group);
|
||||
|
||||
// supplement group_window_sizes using the last input value if user-provided lengths not enough
|
||||
supplement_array_by_last_element(group_window_sizes, num_group);
|
||||
|
||||
// supplement group_min_full_attn_seqlens using the last input value if user-provided lengths
|
||||
// not enough
|
||||
supplement_array_by_last_element(group_min_full_attn_seqlens, num_group);
|
||||
|
||||
// supplement group_attn_scales using the last input value if user-provided values not enough
|
||||
supplement_array_by_last_element(group_attn_scales, num_group);
|
||||
|
||||
int phy_seqlen_q = 0;
|
||||
int phy_seqlen_kv = 0;
|
||||
int max_max_seqlen = 0;
|
||||
|
||||
// only consider num_group values even if more values were provided by the user
|
||||
for(int i = 0; i < num_group; i++)
|
||||
{
|
||||
max_max_seqlen = max(max_max_seqlen, group_max_seqlens[i]);
|
||||
};
|
||||
|
||||
std::vector<int> seq_offsets_q;
|
||||
std::vector<int> seq_offsets_kv;
|
||||
|
||||
seq_offsets_q.push_back(0);
|
||||
|
||||
for(int i = 0; i < num_batch; i++)
|
||||
{
|
||||
int i_group = i / num_batch_per_group;
|
||||
int batch_seqlen =
|
||||
num_targets.empty()
|
||||
? seq_lengths_q[i] + group_contextual_seqlens[i_group]
|
||||
: seq_lengths_q[i] + num_targets[i] + group_contextual_seqlens[i_group];
|
||||
|
||||
phy_seqlen_q += batch_seqlen;
|
||||
seq_offsets_q.push_back(phy_seqlen_q);
|
||||
};
|
||||
|
||||
seq_offsets_kv.push_back(0);
|
||||
|
||||
for(int i = 0; i < num_batch; i++)
|
||||
{
|
||||
if(!is_cross_attention)
|
||||
{
|
||||
int i_group = i / num_batch_per_group;
|
||||
int batch_seqlen =
|
||||
num_targets.empty()
|
||||
? seq_lengths_kv[i] + group_contextual_seqlens[i_group]
|
||||
: seq_lengths_kv[i] + num_targets[i] + group_contextual_seqlens[i_group];
|
||||
|
||||
phy_seqlen_kv += batch_seqlen;
|
||||
seq_offsets_kv.push_back(phy_seqlen_kv);
|
||||
}
|
||||
else // for cross_attention, assume target_in_kv == false
|
||||
{
|
||||
int i_group = i / num_batch_per_group;
|
||||
int batch_seqlen = seq_lengths_kv[i] + group_contextual_seqlens[i_group];
|
||||
|
||||
phy_seqlen_kv += batch_seqlen;
|
||||
seq_offsets_kv.push_back(phy_seqlen_kv);
|
||||
}
|
||||
};
|
||||
|
||||
long total_flops = 0;
|
||||
|
||||
// estimate the total flops occurred, ignoring the scaling and SILu
|
||||
for(int i = 0; i < num_batch; i++)
|
||||
{
|
||||
int len_q = seq_offsets_q[i + 1] - seq_offsets_q[i];
|
||||
int len_kv = seq_offsets_kv[i + 1] - seq_offsets_kv[i];
|
||||
total_flops += (static_cast<long>(len_q) * len_kv * hdim_qk +
|
||||
static_cast<long>(len_q) * hdim_v * len_kv) *
|
||||
2;
|
||||
};
|
||||
|
||||
total_flops *= num_head;
|
||||
|
||||
int batches_for_alloc = 1;
|
||||
|
||||
ck_tile::HostTensor<InOutDataType> q_host(
|
||||
std::array<ck_tile::index_t, 4>{batches_for_alloc, phy_seqlen_q, num_head, hdim_qk});
|
||||
ck_tile::HostTensor<InOutDataType> k_host(
|
||||
std::array<ck_tile::index_t, 4>{batches_for_alloc, phy_seqlen_kv, num_head, hdim_qk});
|
||||
ck_tile::HostTensor<InOutDataType> v_host(
|
||||
std::array<ck_tile::index_t, 4>{batches_for_alloc, phy_seqlen_kv, num_head, hdim_v});
|
||||
ck_tile::HostTensor<InOutDataType> o_host_ref(
|
||||
std::array<ck_tile::index_t, 4>{batches_for_alloc, phy_seqlen_q, num_head, hdim_v});
|
||||
|
||||
ck_tile::HostTensor<int8_t> mask_host(
|
||||
save_mask
|
||||
? std::array<ck_tile::index_t, 4>{num_batch, num_head, max_max_seqlen, max_max_seqlen}
|
||||
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
|
||||
|
||||
if(!initialize_qkv)
|
||||
{
|
||||
if(use_normal_dist)
|
||||
{
|
||||
ck_tile::FillNormalDistribution<InOutDataType>{0.f, 1.f, seed}(q_host);
|
||||
ck_tile::FillNormalDistribution<InOutDataType>{0.f, 1.f, seed}(k_host);
|
||||
ck_tile::FillNormalDistribution<InOutDataType>{0.f, 1.f, seed}(v_host);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::FillUniformDistribution<InOutDataType>{-1.f, 1.f, seed}(q_host);
|
||||
ck_tile::FillUniformDistribution<InOutDataType>{-1.f, 1.f, seed}(k_host);
|
||||
ck_tile::FillUniformDistribution<InOutDataType>{-1.f, 1.f, seed}(v_host);
|
||||
};
|
||||
}
|
||||
else
|
||||
{
|
||||
readDataToBufferFromFile(q_host.data(), q_host.get_element_space_size(), "q.dat");
|
||||
readDataToBufferFromFile(k_host.data(), k_host.get_element_space_size(), "k.dat");
|
||||
readDataToBufferFromFile(v_host.data(), v_host.get_element_space_size(), "v.dat");
|
||||
};
|
||||
|
||||
ck_tile::DeviceMem q_dev(q_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem k_dev(k_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem v_dev(v_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem o_dev(o_host_ref.get_element_space_size_in_bytes());
|
||||
|
||||
ck_tile::DeviceMem seq_offsets_q_dev(seq_offsets_q.size() * sizeof(int));
|
||||
ck_tile::DeviceMem seq_offsets_kv_dev(seq_offsets_kv.size() * sizeof(int));
|
||||
ck_tile::DeviceMem num_targets_dev(num_targets.size() * sizeof(int));
|
||||
|
||||
q_dev.ToDevice(q_host.data());
|
||||
k_dev.ToDevice(k_host.data());
|
||||
v_dev.ToDevice(v_host.data());
|
||||
|
||||
seq_offsets_q_dev.ToDevice(seq_offsets_q.data());
|
||||
seq_offsets_kv_dev.ToDevice(seq_offsets_kv.data());
|
||||
if(!num_targets.empty())
|
||||
num_targets_dev.ToDevice(num_targets.data());
|
||||
|
||||
ck_tile::DeviceMem group_max_seqlens_dev(group_max_seqlens.size() * sizeof(int));
|
||||
ck_tile::DeviceMem group_contextual_seqlens_dev(group_contextual_seqlens.size() * sizeof(int));
|
||||
ck_tile::DeviceMem group_window_sizes_dev(group_window_sizes.size() * sizeof(int));
|
||||
ck_tile::DeviceMem group_min_full_attn_seqlens_dev(group_min_full_attn_seqlens.size() *
|
||||
sizeof(int));
|
||||
ck_tile::DeviceMem group_attn_scales_dev(group_attn_scales.size() * sizeof(float));
|
||||
|
||||
group_max_seqlens_dev.ToDevice(group_max_seqlens.data());
|
||||
group_contextual_seqlens_dev.ToDevice(group_contextual_seqlens.data());
|
||||
group_window_sizes_dev.ToDevice(group_window_sizes.data());
|
||||
group_min_full_attn_seqlens_dev.ToDevice(group_min_full_attn_seqlens.data());
|
||||
group_attn_scales_dev.ToDevice(group_attn_scales.data());
|
||||
|
||||
HstuAttentionGroupFwdParams params;
|
||||
|
||||
float scale_s = (alpha != 0.f) ? alpha : 1.0f / std::sqrt(hdim_qk);
|
||||
|
||||
params.is_cross_attention = is_cross_attention;
|
||||
params.num_batch = num_batch;
|
||||
params.num_group = num_group;
|
||||
params.seq_q_offsets_ptr = seq_offsets_q_dev.GetDeviceBuffer();
|
||||
params.seq_kv_offsets_ptr = seq_offsets_kv_dev.GetDeviceBuffer();
|
||||
params.max_seqlen = max_max_seqlen;
|
||||
params.q_ptr = q_dev.GetDeviceBuffer();
|
||||
params.k_ptr = k_dev.GetDeviceBuffer();
|
||||
params.v_ptr = v_dev.GetDeviceBuffer();
|
||||
params.bias_ptr = nullptr; // bias is not supported at present
|
||||
params.o_ptr = o_dev.GetDeviceBuffer();
|
||||
params.hdim_qk = hdim_qk;
|
||||
params.hdim_v = hdim_v;
|
||||
params.num_head = num_head;
|
||||
params.scale_s = scale_s;
|
||||
params.seq_stride_q = q_host.get_strides()[1];
|
||||
params.seq_stride_k = k_host.get_strides()[1];
|
||||
params.seq_stride_v = v_host.get_strides()[1];
|
||||
params.seq_stride_bias = 0;
|
||||
params.seq_stride_o = o_host_ref.get_strides()[1];
|
||||
params.nhead_stride_q = q_host.get_strides()[2];
|
||||
params.nhead_stride_k = k_host.get_strides()[2];
|
||||
params.nhead_stride_v = v_host.get_strides()[2];
|
||||
params.nhead_stride_bias = 0;
|
||||
params.nhead_stride_o = o_host_ref.get_strides()[2];
|
||||
params.num_targets_ptr = num_targets.empty() ? nullptr : num_targets_dev.GetDeviceBuffer();
|
||||
params.use_softmax = use_softmax;
|
||||
params.use_causal = use_causal;
|
||||
params.p_drop = 0.0f; // dropout is not supported at present
|
||||
params.philox_seed = 0UL;
|
||||
params.philox_offset = 0UL;
|
||||
params.group_max_seqlen_ptr = group_max_seqlens_dev.GetDeviceBuffer();
|
||||
params.group_contextual_seqlen_ptr = group_contextual_seqlens_dev.GetDeviceBuffer();
|
||||
params.group_window_size_ptr = group_window_sizes_dev.GetDeviceBuffer();
|
||||
params.group_min_full_attn_seqlen_ptr = group_min_full_attn_seqlens_dev.GetDeviceBuffer();
|
||||
params.group_attn_scale_ptr = group_attn_scales_dev.GetDeviceBuffer();
|
||||
|
||||
hipStream_t stream;
|
||||
|
||||
HIP_CHECK_ERROR(hipStreamCreate(&stream));
|
||||
|
||||
if constexpr(std::is_same<InOutDataType, ck_tile::fp16_t>::value)
|
||||
{
|
||||
hstu_attention_group_forward_fp16(params, stream);
|
||||
}
|
||||
else if constexpr(std::is_same<InOutDataType, ck_tile::bf16_t>::value)
|
||||
{
|
||||
hstu_attention_group_forward_bf16(params, stream);
|
||||
}
|
||||
else
|
||||
throw std::runtime_error("Other data type is not supported at present!");
|
||||
|
||||
bool res = true;
|
||||
|
||||
if(do_validation)
|
||||
{
|
||||
using GemmAccDataType = typename HstuAttentionFwdTypeConfig<InOutDataType>::GemmAccDataType;
|
||||
using CompDataType = typename HstuAttentionFwdTypeConfig<InOutDataType>::CompDataType;
|
||||
|
||||
BOOL_SWITCH_2(use_softmax, kUseSoftmax, use_causal, kUseCausal, [&] {
|
||||
ck_tile::reference_group_hstu_attention<InOutDataType,
|
||||
GemmAccDataType,
|
||||
CompDataType,
|
||||
kUseSoftmax,
|
||||
kUseCausal>::Run(is_cross_attention,
|
||||
q_host,
|
||||
k_host,
|
||||
v_host,
|
||||
o_host_ref,
|
||||
mask_host,
|
||||
num_batch,
|
||||
num_batch / num_group,
|
||||
scale_s,
|
||||
max_max_seqlen,
|
||||
seq_offsets_q,
|
||||
seq_offsets_kv,
|
||||
num_targets,
|
||||
group_max_seqlens,
|
||||
group_contextual_seqlens,
|
||||
group_window_sizes,
|
||||
group_min_full_attn_seqlens,
|
||||
group_attn_scales);
|
||||
});
|
||||
|
||||
ck_tile::HostTensor<InOutDataType> o_host(
|
||||
std::array<ck_tile::index_t, 4>{batches_for_alloc, phy_seqlen_q, num_head, hdim_v});
|
||||
|
||||
o_dev.FromDevice(o_host.data());
|
||||
|
||||
if(dump_output)
|
||||
{
|
||||
dumpBufferToFile("output_dev.dat", o_host.data(), o_host.get_element_space_size());
|
||||
dumpBufferToFile("output_host.dat", o_host_ref.data(), o_host.get_element_space_size());
|
||||
}
|
||||
|
||||
if(save_mask)
|
||||
dumpBufferToFile(
|
||||
"ck_hstu_mask.dat", mask_host.data(), mask_host.get_element_space_size());
|
||||
|
||||
auto [rtol, atol] = get_elimit<InOutDataType>();
|
||||
|
||||
res = ck_tile::check_err(
|
||||
o_host, o_host_ref, std::string("hstu_attention output error"), rtol, atol);
|
||||
};
|
||||
|
||||
if(measure_perf)
|
||||
{
|
||||
ck_tile::gpu_timer timer{};
|
||||
|
||||
timer.start(stream);
|
||||
for(int i = 0; i < 10; i++)
|
||||
{
|
||||
if constexpr(std::is_same<InOutDataType, ck_tile::fp16_t>::value)
|
||||
{
|
||||
hstu_attention_group_forward_fp16(params, stream);
|
||||
}
|
||||
else if constexpr(std::is_same<InOutDataType, ck_tile::bf16_t>::value)
|
||||
{
|
||||
hstu_attention_group_forward_bf16(params, stream);
|
||||
}
|
||||
}
|
||||
timer.stop(stream);
|
||||
@@ -672,15 +1031,33 @@ int main(int argc, char* argv[])
|
||||
return -1;
|
||||
}
|
||||
|
||||
int num_group = static_cast<int>(arg_parser.get_int("g"));
|
||||
const std::string data_type = arg_parser.get_str("prec");
|
||||
if(data_type == "fp16")
|
||||
|
||||
if(num_group > 1)
|
||||
{
|
||||
return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run_group_hstu<ck_tile::half_t>(arg_parser, num_group) ? 0 : -2;
|
||||
}
|
||||
else if(data_type == "bf16")
|
||||
{
|
||||
return run_group_hstu<ck_tile::bf16_t>(arg_parser, num_group) ? 0 : -2;
|
||||
}
|
||||
}
|
||||
else if(data_type == "bf16")
|
||||
else
|
||||
{
|
||||
return run<ck_tile::bf16_t>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
bool is_jagged = static_cast<bool>(arg_parser.get_int("jagged"));
|
||||
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run_no_group_hstu<ck_tile::half_t>(arg_parser, is_jagged) ? 0 : -2;
|
||||
}
|
||||
else if(data_type == "bf16")
|
||||
{
|
||||
return run_no_group_hstu<ck_tile::bf16_t>(arg_parser, is_jagged) ? 0 : -2;
|
||||
}
|
||||
};
|
||||
|
||||
return -3;
|
||||
}
|
||||
|
||||
@@ -32,7 +32,7 @@ HSTU_FORWARD_INSTANCE_TEMPLATE = """
|
||||
{use_softmax},
|
||||
{has_bias},
|
||||
{has_dropout},
|
||||
{max_k}>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
{max_k}>(HstuAttention{group_or_not}FwdParams& param, hipStream_t stream);
|
||||
"""
|
||||
|
||||
HSTU_FORWARD_INSTANCE_FNAME = (
|
||||
@@ -76,13 +76,14 @@ TYPE_FNAME_MAP = {
|
||||
"bf16": "half",
|
||||
}
|
||||
|
||||
MODE_NAME_MAP = {
|
||||
"batched": "Batched",
|
||||
"jagged": "Jagged",
|
||||
MODE_GROUP_OR_NOT_MAP = {
|
||||
"batched": "NoGroup",
|
||||
"jagged": "NoGroup",
|
||||
"group": "Group",
|
||||
}
|
||||
|
||||
def create_forward_instances(instance_dir: Path, headdims: List) -> None:
|
||||
for mode in ["batched", "jagged"]:
|
||||
for mode in ["batched", "jagged", "group"]:
|
||||
for dtype in ["fp16", "bf16"]:
|
||||
for has_causal in [True, False]:
|
||||
for use_softmax in [True, False]:
|
||||
@@ -113,7 +114,7 @@ def create_forward_instances(instance_dir: Path, headdims: List) -> None:
|
||||
has_bias=BOOL_MAP[has_bias],
|
||||
has_dropout=BOOL_MAP[has_dropout],
|
||||
max_k=max_k,
|
||||
cap_mode=MODE_NAME_MAP[mode],
|
||||
group_or_not=MODE_GROUP_OR_NOT_MAP[mode],
|
||||
)
|
||||
(instance_dir / fname).write_text(
|
||||
HSTU_COPYRIGHT_HEADER
|
||||
@@ -123,7 +124,7 @@ def create_forward_instances(instance_dir: Path, headdims: List) -> None:
|
||||
|
||||
|
||||
def create_forward_instances_ref(instance_dir: Path, headdims: List) -> None:
|
||||
for mode in ["batched", "jagged"]:
|
||||
for mode in ["batched", "jagged", "group"]:
|
||||
for dtype in ["fp16", "bf16"]:
|
||||
ref_fname = HSTU_INSTANCE_REF_FNAME.format(
|
||||
mode=mode,
|
||||
@@ -153,7 +154,7 @@ def create_forward_instances_ref(instance_dir: Path, headdims: List) -> None:
|
||||
has_bias=BOOL_MAP[has_bias],
|
||||
has_dropout=BOOL_MAP[has_dropout],
|
||||
max_k=max_k,
|
||||
cap_mode=MODE_NAME_MAP[mode],
|
||||
group_or_not=MODE_GROUP_OR_NOT_MAP[mode],
|
||||
)
|
||||
)
|
||||
file.write(forward_instance)
|
||||
|
||||
@@ -7,7 +7,11 @@
|
||||
|
||||
#include "hstu_attention_params.hpp"
|
||||
|
||||
extern void hstu_attention_batched_forward_fp16(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
extern void hstu_attention_batched_forward_bf16(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
extern void hstu_attention_jagged_forward_fp16(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
extern void hstu_attention_jagged_forward_bf16(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
extern void hstu_attention_no_group_forward_fp16(HstuAttentionNoGroupFwdParams& param,
|
||||
hipStream_t stream);
|
||||
extern void hstu_attention_no_group_forward_bf16(HstuAttentionNoGroupFwdParams& param,
|
||||
hipStream_t stream);
|
||||
extern void hstu_attention_group_forward_fp16(HstuAttentionGroupFwdParams& param,
|
||||
hipStream_t stream);
|
||||
extern void hstu_attention_group_forward_bf16(HstuAttentionGroupFwdParams& param,
|
||||
hipStream_t stream);
|
||||
|
||||
@@ -48,6 +48,7 @@ struct batched_forward_causal_softmax_bias_dropout_dispatch
|
||||
typename HstuAttentionFwdTypeConfig<InOutDataType>::CompDataType,
|
||||
typename HstuAttentionFwdTypeConfig<InOutDataType>::BiasDataType,
|
||||
kIsCrossAttention,
|
||||
false, // kUseGroup
|
||||
false, // kIsJagged
|
||||
kHasBias,
|
||||
kHasDropout,
|
||||
@@ -56,7 +57,7 @@ struct batched_forward_causal_softmax_bias_dropout_dispatch
|
||||
kUseTrLoad,
|
||||
HstuAttentionTileSetting>;
|
||||
|
||||
static void Run(HstuAttentionFwdParams& param, hipStream_t stream)
|
||||
static void Run(HstuAttentionNoGroupFwdParams& param, hipStream_t stream)
|
||||
{
|
||||
constexpr ck_tile::index_t occupancy = -1;
|
||||
|
||||
@@ -127,7 +128,7 @@ struct batched_forward_causal_softmax_bias_dropout_dispatch
|
||||
};
|
||||
|
||||
template <typename HstuKernel>
|
||||
static void RunWithKernel(HstuAttentionFwdParams& param, hipStream_t stream)
|
||||
static void RunWithKernel(HstuAttentionNoGroupFwdParams& param, hipStream_t stream)
|
||||
{
|
||||
const auto kargs = [&] {
|
||||
return HstuKernel::MakeKargs(param.q_ptr,
|
||||
@@ -185,7 +186,7 @@ template <typename InOutDataType,
|
||||
bool kHasBias,
|
||||
bool kHasDropout,
|
||||
ck_tile::index_t MaxK>
|
||||
void run_batched_forward_causal_softmax_bias_dropout_dispatch(HstuAttentionFwdParams& param,
|
||||
void run_batched_forward_causal_softmax_bias_dropout_dispatch(HstuAttentionNoGroupFwdParams& param,
|
||||
hipStream_t stream)
|
||||
{
|
||||
batched_forward_causal_softmax_bias_dropout_dispatch<InOutDataType,
|
||||
|
||||
@@ -40,6 +40,7 @@ struct HstuAttentionFwdKernel
|
||||
using ODataType = ck_tile::remove_cvref_t<typename HstuAttentionPipeline::ODataType>;
|
||||
|
||||
static constexpr bool kIsCrossAttention = HstuAttentionPipeline::Problem::kIsCrossAttention;
|
||||
static constexpr bool kUseGroup = HstuAttentionPipeline::Problem::kUseGroup;
|
||||
static constexpr bool kIsJagged = HstuAttentionPipeline::Problem::kIsJagged;
|
||||
static constexpr auto kHasBias = HstuAttentionPipeline::Problem::kHasBias;
|
||||
static constexpr bool kHasDropout = HstuAttentionPipeline::Problem::kHasDropout;
|
||||
@@ -60,7 +61,7 @@ struct HstuAttentionFwdKernel
|
||||
// kargs use aggregate initialization, so no constructor will be provided
|
||||
// use inheritance to minimize karg size
|
||||
// user need to use MakeKargs() function to create kargs.
|
||||
struct HstuAttentionFwdBatchModeBaseKargs
|
||||
struct HstuAttentionNoGroupBatchedFwdBaseKargs
|
||||
{
|
||||
ck_tile::index_t batch_stride_q;
|
||||
ck_tile::index_t batch_stride_k;
|
||||
@@ -98,7 +99,7 @@ struct HstuAttentionFwdKernel
|
||||
ck_tile::index_t min_full_attn_seqlen;
|
||||
};
|
||||
|
||||
struct HstuAttentionFwdJaggModeBaseKargs
|
||||
struct HstuAttentionNoGroupJaggedFwdBaseKargs
|
||||
{
|
||||
const int32_t* seq_q_offsets_ptr;
|
||||
const int32_t* seq_kv_offsets_ptr;
|
||||
@@ -135,6 +136,51 @@ struct HstuAttentionFwdKernel
|
||||
ck_tile::index_t min_full_attn_seqlen;
|
||||
};
|
||||
|
||||
struct HstuAttentionGroupFwdBaseKargs
|
||||
{
|
||||
ck_tile::index_t num_batch_per_group;
|
||||
|
||||
const int32_t* seq_q_offsets_ptr;
|
||||
const int32_t* seq_kv_offsets_ptr;
|
||||
|
||||
ck_tile::index_t seq_stride_q;
|
||||
ck_tile::index_t seq_stride_k;
|
||||
ck_tile::index_t seq_stride_v;
|
||||
ck_tile::index_t seq_stride_o;
|
||||
|
||||
const int32_t* num_targets_ptr;
|
||||
|
||||
const void* q_ptr;
|
||||
const void* k_ptr;
|
||||
const void* v_ptr;
|
||||
void* o_ptr;
|
||||
|
||||
ck_tile::index_t nhead_stride_q;
|
||||
ck_tile::index_t nhead_stride_k;
|
||||
ck_tile::index_t nhead_stride_v;
|
||||
ck_tile::index_t nhead_stride_o;
|
||||
|
||||
ck_tile::index_t hdim_qk;
|
||||
ck_tile::index_t hdim_v;
|
||||
|
||||
ck_tile::index_t seqlen_q;
|
||||
ck_tile::index_t seqlen_kv;
|
||||
|
||||
ck_tile::index_t num_head;
|
||||
float scale_s; // scaling value exerted on the immediate Q@K result
|
||||
float scale_p; // scaling value exerted on the SiLU result
|
||||
|
||||
int32_t contextual_seqlen; // to be set by the per-group contextual_seqlen
|
||||
int32_t window_size; // to be set by the per-group window_size
|
||||
int32_t min_full_attn_seqlen; // to be set by the per-group min_full_attn_seqlen
|
||||
|
||||
const int32_t* group_max_seqlen_ptr;
|
||||
const int32_t* group_contextual_seqlen_ptr;
|
||||
const int32_t* group_window_size_ptr;
|
||||
const int32_t* group_min_full_attn_seqlen_ptr;
|
||||
const float* group_attn_scale_ptr;
|
||||
};
|
||||
|
||||
struct HstuAttentionFwdCommonBiasKargs
|
||||
{
|
||||
const void* bias_ptr = nullptr;
|
||||
@@ -170,30 +216,48 @@ struct HstuAttentionFwdKernel
|
||||
uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
|
||||
};
|
||||
|
||||
struct HstuAttentionFwdBatchModeKargs : HstuAttentionFwdBatchModeBaseKargs,
|
||||
std::conditional_t<kHasBias,
|
||||
HstuAttentionFwdBatchModeBiasKargs,
|
||||
HstuAttentionFwdEmptyKargs<1>>,
|
||||
std::conditional_t<kHasDropout,
|
||||
HstuAttentionFwdCommonDropoutKargs,
|
||||
HstuAttentionFwdEmptyKargs<2>>
|
||||
struct HstuAttentionNoGroupBatchedFwdKargs
|
||||
: HstuAttentionNoGroupBatchedFwdBaseKargs,
|
||||
std::conditional_t<kHasBias,
|
||||
HstuAttentionFwdBatchModeBiasKargs,
|
||||
HstuAttentionFwdEmptyKargs<1>>,
|
||||
std::conditional_t<kHasDropout,
|
||||
HstuAttentionFwdCommonDropoutKargs,
|
||||
HstuAttentionFwdEmptyKargs<2>>
|
||||
{
|
||||
};
|
||||
|
||||
struct HstuAttentionFwdJaggModeKargs : HstuAttentionFwdJaggModeBaseKargs,
|
||||
std::conditional_t<kHasBias,
|
||||
HstuAttentionFwdCommonBiasKargs,
|
||||
HstuAttentionFwdEmptyKargs<1>>,
|
||||
std::conditional_t<kHasDropout,
|
||||
HstuAttentionFwdCommonDropoutKargs,
|
||||
HstuAttentionFwdEmptyKargs<2>>
|
||||
struct HstuAttentionNoGroupJaggedFwdKargs
|
||||
: HstuAttentionNoGroupJaggedFwdBaseKargs,
|
||||
std::conditional_t<kHasBias,
|
||||
HstuAttentionFwdCommonBiasKargs,
|
||||
HstuAttentionFwdEmptyKargs<1>>,
|
||||
std::conditional_t<kHasDropout,
|
||||
HstuAttentionFwdCommonDropoutKargs,
|
||||
HstuAttentionFwdEmptyKargs<2>>
|
||||
{
|
||||
};
|
||||
|
||||
using Kargs = std::
|
||||
conditional_t<kIsJagged, HstuAttentionFwdJaggModeKargs, HstuAttentionFwdBatchModeKargs>;
|
||||
struct HstuAttentionGroupFwdKargs : HstuAttentionGroupFwdBaseKargs,
|
||||
std::conditional_t<kHasBias,
|
||||
HstuAttentionFwdCommonBiasKargs,
|
||||
HstuAttentionFwdEmptyKargs<1>>,
|
||||
std::conditional_t<kHasDropout,
|
||||
HstuAttentionFwdCommonDropoutKargs,
|
||||
HstuAttentionFwdEmptyKargs<2>>
|
||||
{
|
||||
};
|
||||
|
||||
template <bool Cond = !kIsJagged>
|
||||
using Kargs = std::conditional_t<kUseGroup,
|
||||
HstuAttentionGroupFwdKargs,
|
||||
std::conditional_t<kIsJagged,
|
||||
HstuAttentionNoGroupJaggedFwdKargs,
|
||||
HstuAttentionNoGroupBatchedFwdKargs>>;
|
||||
|
||||
static constexpr bool kUseNoGroupBatched = (!kUseGroup && !kIsJagged);
|
||||
static constexpr bool kUseNoGroupJagged = (!kUseGroup && kIsJagged);
|
||||
|
||||
template <bool Cond = kUseNoGroupBatched>
|
||||
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
|
||||
MakeKargs(const void* q_ptr,
|
||||
const void* k_ptr,
|
||||
@@ -278,7 +342,7 @@ struct HstuAttentionFwdKernel
|
||||
return kargs;
|
||||
}
|
||||
|
||||
template <bool Cond = kIsJagged>
|
||||
template <bool Cond = kUseNoGroupJagged>
|
||||
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
|
||||
MakeKargs(const void* q_ptr,
|
||||
const void* k_ptr,
|
||||
@@ -355,11 +419,95 @@ struct HstuAttentionFwdKernel
|
||||
return kargs;
|
||||
}
|
||||
|
||||
template <bool Cond = kUseGroup>
|
||||
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
|
||||
MakeKargs(const void* q_ptr,
|
||||
const void* k_ptr,
|
||||
const void* v_ptr,
|
||||
const void* bias_ptr,
|
||||
void* o_ptr,
|
||||
ck_tile::index_t num_batch_per_group,
|
||||
const void* seq_q_offsets_ptr,
|
||||
const void* seq_kv_offsets_ptr,
|
||||
const void* group_max_seqlen_ptr,
|
||||
const void* group_contextual_seqlen_ptr,
|
||||
const void* group_window_size_ptr,
|
||||
const void* group_min_full_attn_seqlen_ptr,
|
||||
const void* group_attn_scale_ptr,
|
||||
ck_tile::index_t hdim_qk,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head,
|
||||
float scale_s,
|
||||
ck_tile::index_t seq_stride_q,
|
||||
ck_tile::index_t seq_stride_k,
|
||||
ck_tile::index_t seq_stride_v,
|
||||
ck_tile::index_t seq_stride_bias,
|
||||
ck_tile::index_t seq_stride_o,
|
||||
ck_tile::index_t nhead_stride_q,
|
||||
ck_tile::index_t nhead_stride_k,
|
||||
ck_tile::index_t nhead_stride_v,
|
||||
ck_tile::index_t nhead_stride_bias,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
const void* num_targets_ptr,
|
||||
float p_drop,
|
||||
uint64_t philox_seed,
|
||||
uint64_t philox_offset)
|
||||
{
|
||||
Kargs kargs{
|
||||
{num_batch_per_group,
|
||||
reinterpret_cast<const int32_t*>(seq_q_offsets_ptr),
|
||||
reinterpret_cast<const int32_t*>(seq_kv_offsets_ptr),
|
||||
seq_stride_q,
|
||||
seq_stride_k,
|
||||
seq_stride_v,
|
||||
seq_stride_o,
|
||||
reinterpret_cast<const int32_t*>(num_targets_ptr),
|
||||
q_ptr,
|
||||
k_ptr,
|
||||
v_ptr,
|
||||
o_ptr,
|
||||
nhead_stride_q,
|
||||
nhead_stride_k,
|
||||
nhead_stride_v,
|
||||
nhead_stride_o,
|
||||
hdim_qk,
|
||||
hdim_v,
|
||||
-1, // seqlen_q will be updated by another pointer
|
||||
-1, // seqlen_kv will be updated by another pointer
|
||||
num_head,
|
||||
scale_s,
|
||||
1.0f, // to be set according to the per-group attn_scale and max_seqlen
|
||||
0, // to be set by the per-group contextual_seqlen
|
||||
0, // to be set by the per-group window_size
|
||||
0, // to be set by the per-group min_full_attn_seqlen
|
||||
reinterpret_cast<const int32_t*>(group_max_seqlen_ptr),
|
||||
reinterpret_cast<const int32_t*>(group_contextual_seqlen_ptr),
|
||||
reinterpret_cast<const int32_t*>(group_window_size_ptr),
|
||||
reinterpret_cast<const int32_t*>(group_min_full_attn_seqlen_ptr),
|
||||
reinterpret_cast<const float*>(group_attn_scale_ptr)}, // args for common karg
|
||||
{}, // placeholder for bias
|
||||
{}, // placeholder for dropout
|
||||
};
|
||||
|
||||
if constexpr(kHasBias)
|
||||
{
|
||||
kargs.bias_ptr = bias_ptr;
|
||||
kargs.seq_stride_bias = seq_stride_bias;
|
||||
kargs.nhead_stride_bias = nhead_stride_bias;
|
||||
}
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
kargs.init_dropout(p_drop, philox_seed, philox_offset);
|
||||
}
|
||||
|
||||
return kargs;
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
|
||||
ck_tile::index_t nhead_,
|
||||
ck_tile::index_t seqlen_,
|
||||
ck_tile::index_t hdim_v_,
|
||||
bool has_minfull_attn_seqlen)
|
||||
bool has_minfull_attn_seqlen = false)
|
||||
{
|
||||
// The Q sequence [0, seqlen) will be split to two parts for allocating workgroups:
|
||||
// 1) [0, seqlen - target - min_full_attn_seqlen)
|
||||
@@ -367,8 +515,15 @@ struct HstuAttentionFwdKernel
|
||||
ck_tile::index_t num_tile_in_seqlen =
|
||||
ck_tile::integer_divide_ceil(seqlen_, HstuAttentionPipeline::kM0);
|
||||
|
||||
if(has_minfull_attn_seqlen)
|
||||
if constexpr(kUseGroup)
|
||||
{
|
||||
num_tile_in_seqlen += 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(has_minfull_attn_seqlen)
|
||||
num_tile_in_seqlen += 1;
|
||||
};
|
||||
|
||||
if constexpr(HstuAttentionPipeline::kN1 < HstuAttentionPipeline::kSubQKHeaddim)
|
||||
{
|
||||
@@ -492,6 +647,20 @@ struct HstuAttentionFwdKernel
|
||||
kargs.seq_q_offsets_ptr[i_batch + 1] - kargs.seq_q_offsets_ptr[i_batch];
|
||||
kargs.seqlen_kv =
|
||||
kargs.seq_kv_offsets_ptr[i_batch + 1] - kargs.seq_kv_offsets_ptr[i_batch];
|
||||
|
||||
// read from device memory for the group specific mask and scaling parameters
|
||||
if constexpr(kUseGroup)
|
||||
{
|
||||
index_t i_group =
|
||||
__builtin_amdgcn_readfirstlane(i_batch / kargs.num_batch_per_group);
|
||||
|
||||
float attn_scale = kargs.group_attn_scale_ptr[i_group];
|
||||
index_t max_seqlen = kargs.group_max_seqlen_ptr[i_group];
|
||||
kargs.scale_p = (attn_scale ? attn_scale : 1.0f / static_cast<float>(max_seqlen));
|
||||
kargs.contextual_seqlen = kargs.group_contextual_seqlen_ptr[i_group];
|
||||
kargs.window_size = kargs.group_window_size_ptr[i_group];
|
||||
kargs.min_full_attn_seqlen = kargs.group_min_full_attn_seqlen_ptr[i_group];
|
||||
};
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -6,11 +6,11 @@
|
||||
|
||||
#include "hstu_attention_bool_switch.hpp"
|
||||
#include "hstu_attention_hdim_switch.hpp"
|
||||
#include "hstu_attention_jagged_forward_dispatch.hpp"
|
||||
#include "hstu_attention_group_forward_dispatch.hpp"
|
||||
|
||||
#include "instances/hstu_attention_jagged_forward_bf16_instances_ref.hpp"
|
||||
#include "instances/hstu_attention_group_forward_bf16_instances_ref.hpp"
|
||||
|
||||
void hstu_attention_jagged_forward_bf16(HstuAttentionFwdParams& param, hipStream_t stream)
|
||||
void hstu_attention_group_forward_bf16(HstuAttentionGroupFwdParams& param, hipStream_t stream)
|
||||
{
|
||||
const bool has_dropout = (param.p_drop > 0.0f);
|
||||
const bool has_bias = (param.bias_ptr != nullptr);
|
||||
@@ -18,12 +18,12 @@ void hstu_attention_jagged_forward_bf16(HstuAttentionFwdParams& param, hipStream
|
||||
BOOL_SWITCH_3(has_bias, kHasBias, has_dropout, kHasDropout, use_causal, kUseCausal, [&] {
|
||||
HDIM_SWITCH(param.hdim_qk, param.hdim_v, MaxK, [&] {
|
||||
BOOL_SWITCH(param.use_softmax, kUseSoftmax, [&] {
|
||||
run_jagged_forward_causal_softmax_bias_dropout_dispatch<ck_tile::bf16_t,
|
||||
kUseCausal,
|
||||
kUseSoftmax,
|
||||
kHasBias,
|
||||
kHasDropout,
|
||||
MaxK>(param, stream);
|
||||
run_group_forward_causal_softmax_bias_dropout_dispatch<ck_tile::bf16_t,
|
||||
kUseCausal,
|
||||
kUseSoftmax,
|
||||
kHasBias,
|
||||
kHasDropout,
|
||||
MaxK>(param, stream);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,184 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <stdexcept>

#include <ck_tile/core/numeric/integer.hpp>
#include <ck_tile/host/kernel_launch.hpp>
#include <ck_tile/host/stream_config.hpp>
#include <ck_tile/ops/epilogue.hpp>
|
||||
|
||||
#include "hstu_attention_bool_switch.hpp"
|
||||
#include "hstu_attention_fwd_type_config.hpp"
|
||||
#include "hstu_attention_fwd_setting.hpp"
|
||||
#include "hstu_attention_params.hpp"
|
||||
#include "hstu_attention_hdim_switch.hpp"
|
||||
#include "hstu_attention_pipeline_problem.hpp"
|
||||
#include "hstu_attention_traits.hpp"
|
||||
#include "hstu_attention_with_softmax_fwd_pipeline.hpp"
|
||||
#include "hstu_attention_no_softmax_fwd_pipeline.hpp"
|
||||
#include "hstu_attention_with_softmax_fwd_trload_pipeline.hpp"
|
||||
#include "hstu_attention_no_softmax_fwd_trload_pipeline.hpp"
|
||||
#include "hstu_attention_fwd_kernel.hpp"
|
||||
#include "hstu_attention_epilogue.hpp"
|
||||
|
||||
template <typename InOutDataType,
|
||||
bool kUseCausal,
|
||||
bool kUseSoftmax,
|
||||
bool kHasBias,
|
||||
bool kHasDropout,
|
||||
ck_tile::index_t MaxK>
|
||||
struct group_forward_causal_softmax_bias_dropout_dispatch
|
||||
{
|
||||
using HstuAttentionTileSetting =
|
||||
typename std::conditional_t<kUseSoftmax,
|
||||
HstuAttentionWithSoftmaxFwdTileSetting<MaxK>,
|
||||
HstuAttentionNoSoftmaxFwdTileSetting<MaxK>>::Type;
|
||||
|
||||
#ifdef BUILD_HSTU_FOR_GFX95_ONLY
|
||||
static constexpr bool kUseTrLoad = true;
|
||||
#else
|
||||
static constexpr bool kUseTrLoad = false;
|
||||
#endif
|
||||
|
||||
template <bool kIsCrossAttention>
|
||||
using HstuPipelineProblemTemp = ck_tile::HstuAttentionFwdPipelineProblem<
|
||||
InOutDataType,
|
||||
typename HstuAttentionFwdTypeConfig<InOutDataType>::GemmAccDataType,
|
||||
typename HstuAttentionFwdTypeConfig<InOutDataType>::CompDataType,
|
||||
typename HstuAttentionFwdTypeConfig<InOutDataType>::BiasDataType,
|
||||
kIsCrossAttention,
|
||||
true, // kUseGroup
|
||||
true, // kIsJagged
|
||||
kHasBias,
|
||||
kHasDropout,
|
||||
kUseCausal,
|
||||
kUseSoftmax,
|
||||
kUseTrLoad,
|
||||
HstuAttentionTileSetting>;
|
||||
|
||||
static void Run(HstuAttentionGroupFwdParams& param, hipStream_t stream)
|
||||
{
|
||||
constexpr ck_tile::index_t occupancy = -1;
|
||||
|
||||
const bool pad_headdim_qk = !(param.hdim_qk % HstuAttentionTileSetting::kQKHeaddim == 0);
|
||||
const bool pad_headdim_v = !(param.hdim_v % HstuAttentionTileSetting::kN1 == 0);
|
||||
|
||||
// no need to check seqlen_q since it is not used as fastest dim,
|
||||
// buffer_load_dwordxx/buffer_store_dwordxx can handle oob access
|
||||
constexpr bool kPadSeqLenQ = false;
|
||||
|
||||
constexpr bool kPadSeqLenK = true;
|
||||
|
||||
BOOL_SWITCH_2(pad_headdim_qk, kPadHeadDimQK, pad_headdim_v, kPadHeadDimV, [&] {
|
||||
using HstuTraits = ck_tile::HstuAttentionFwdTraits<kPadSeqLenQ,
|
||||
kPadSeqLenK,
|
||||
kPadHeadDimQK,
|
||||
kPadHeadDimV,
|
||||
occupancy>;
|
||||
|
||||
using HstuEpilogue = ck_tile::NRepetitions2DEpilogue<ck_tile::Default2DEpilogueProblem<
|
||||
typename HstuAttentionFwdTypeConfig<InOutDataType>::OaccDataType,
|
||||
typename HstuAttentionFwdTypeConfig<InOutDataType>::ODataType,
|
||||
kPadSeqLenQ,
|
||||
kPadHeadDimV>>;
|
||||
|
||||
BOOL_SWITCH(param.is_cross_attention, kIsCrossAttention, [&] {
|
||||
using HstuPipelineProblem = HstuPipelineProblemTemp<kIsCrossAttention>;
|
||||
|
||||
if constexpr(!kUseTrLoad)
|
||||
{
|
||||
using HstuPipeline = std::conditional_t<
|
||||
kUseSoftmax,
|
||||
ck_tile::HstuAttentionWithSoftmaxFwdPipelineQRKSVS<HstuPipelineProblem,
|
||||
HstuTraits>,
|
||||
ck_tile::HstuAttentionNoSoftmaxFwdPipelineQRKSVS<HstuPipelineProblem,
|
||||
HstuTraits>>;
|
||||
|
||||
using HstuKernel = ck_tile::HstuAttentionFwdKernel<HstuPipeline, HstuEpilogue>;
|
||||
|
||||
RunWithKernel<HstuKernel>(param, stream);
|
||||
}
|
||||
else
|
||||
{
|
||||
using HstuPipeline = std::conditional_t<
|
||||
kUseSoftmax,
|
||||
ck_tile::HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad<
|
||||
HstuPipelineProblem,
|
||||
HstuTraits>,
|
||||
ck_tile::HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad<HstuPipelineProblem,
|
||||
HstuTraits>>;
|
||||
|
||||
using HstuKernel = ck_tile::HstuAttentionFwdKernel<HstuPipeline, HstuEpilogue>;
|
||||
|
||||
RunWithKernel<HstuKernel>(param, stream);
|
||||
};
|
||||
});
|
||||
});
|
||||
};
|
||||
|
||||
template <typename HstuKernel>
|
||||
static void RunWithKernel(HstuAttentionGroupFwdParams& param, hipStream_t stream)
|
||||
{
|
||||
const auto kargs = [&] {
|
||||
return HstuKernel::MakeKargs(param.q_ptr,
|
||||
param.k_ptr,
|
||||
param.v_ptr,
|
||||
param.bias_ptr,
|
||||
param.o_ptr,
|
||||
param.num_batch / param.num_group,
|
||||
param.seq_q_offsets_ptr,
|
||||
param.is_cross_attention ? param.seq_kv_offsets_ptr
|
||||
: param.seq_q_offsets_ptr,
|
||||
param.group_max_seqlen_ptr,
|
||||
param.group_contextual_seqlen_ptr,
|
||||
param.group_window_size_ptr,
|
||||
param.group_min_full_attn_seqlen_ptr,
|
||||
param.group_attn_scale_ptr,
|
||||
param.hdim_qk,
|
||||
param.hdim_v,
|
||||
param.num_head,
|
||||
param.scale_s,
|
||||
param.seq_stride_q,
|
||||
param.seq_stride_k,
|
||||
param.seq_stride_v,
|
||||
param.seq_stride_bias,
|
||||
param.seq_stride_o,
|
||||
param.nhead_stride_q,
|
||||
param.nhead_stride_k,
|
||||
param.nhead_stride_v,
|
||||
param.nhead_stride_bias,
|
||||
param.nhead_stride_o,
|
||||
param.num_targets_ptr,
|
||||
param.p_drop,
|
||||
param.philox_seed,
|
||||
param.philox_offset);
|
||||
}();
|
||||
|
||||
dim3 kGridSize =
|
||||
HstuKernel::GridSize(param.num_batch, param.num_head, param.max_seqlen, param.hdim_v);
|
||||
constexpr dim3 kBlockSize = HstuKernel::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu;
|
||||
|
||||
(void)ck_tile::launch_kernel(
|
||||
ck_tile::stream_config{stream, false},
|
||||
ck_tile::make_kernel<kBlockPerCu>(HstuKernel{}, kGridSize, kBlockSize, 0, kargs));
|
||||
};
|
||||
};
|
||||
|
||||
template <typename InOutDataType,
|
||||
bool kUseCausal,
|
||||
bool kUseSoftmax,
|
||||
bool kHasBias,
|
||||
bool kHasDropout,
|
||||
ck_tile::index_t MaxK>
|
||||
void run_group_forward_causal_softmax_bias_dropout_dispatch(HstuAttentionGroupFwdParams& param,
|
||||
hipStream_t stream)
|
||||
{
|
||||
group_forward_causal_softmax_bias_dropout_dispatch<InOutDataType,
|
||||
kUseCausal,
|
||||
kUseSoftmax,
|
||||
kHasBias,
|
||||
kHasDropout,
|
||||
MaxK>::Run(param, stream);
|
||||
};
|
||||
@@ -6,24 +6,25 @@
|
||||
|
||||
#include "hstu_attention_bool_switch.hpp"
|
||||
#include "hstu_attention_hdim_switch.hpp"
|
||||
#include "hstu_attention_jagged_forward_dispatch.hpp"
|
||||
#include "hstu_attention_group_forward_dispatch.hpp"
|
||||
|
||||
#include "instances/hstu_attention_jagged_forward_fp16_instances_ref.hpp"
|
||||
#include "instances/hstu_attention_group_forward_fp16_instances_ref.hpp"
|
||||
|
||||
void hstu_attention_jagged_forward_fp16(HstuAttentionFwdParams& param, hipStream_t stream)
|
||||
void hstu_attention_group_forward_fp16(HstuAttentionGroupFwdParams& param, hipStream_t stream)
|
||||
{
|
||||
const bool has_dropout = (param.p_drop > 0.0f);
|
||||
const bool has_bias = (param.bias_ptr != nullptr);
|
||||
const bool use_causal = param.use_causal;
|
||||
|
||||
BOOL_SWITCH_3(has_bias, kHasBias, has_dropout, kHasDropout, use_causal, kUseCausal, [&] {
|
||||
HDIM_SWITCH(param.hdim_qk, param.hdim_v, MaxK, [&] {
|
||||
BOOL_SWITCH(param.use_softmax, kUseSoftmax, [&] {
|
||||
run_jagged_forward_causal_softmax_bias_dropout_dispatch<ck_tile::fp16_t,
|
||||
kUseCausal,
|
||||
kUseSoftmax,
|
||||
kHasBias,
|
||||
kHasDropout,
|
||||
MaxK>(param, stream);
|
||||
run_group_forward_causal_softmax_bias_dropout_dispatch<ck_tile::fp16_t,
|
||||
kUseCausal,
|
||||
kUseSoftmax,
|
||||
kHasBias,
|
||||
kHasDropout,
|
||||
MaxK>(param, stream);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -48,7 +48,8 @@ struct jagged_forward_causal_softmax_bias_dropout_dispatch
|
||||
typename HstuAttentionFwdTypeConfig<InOutDataType>::CompDataType,
|
||||
typename HstuAttentionFwdTypeConfig<InOutDataType>::BiasDataType,
|
||||
kIsCrossAttention,
|
||||
true, // kIsJagged
|
||||
false, // kUseGroup
|
||||
true, // kIsJagged
|
||||
kHasBias,
|
||||
kHasDropout,
|
||||
kUseCausal,
|
||||
@@ -56,7 +57,7 @@ struct jagged_forward_causal_softmax_bias_dropout_dispatch
|
||||
kUseTrLoad,
|
||||
HstuAttentionTileSetting>;
|
||||
|
||||
static void Run(HstuAttentionFwdParams& param, hipStream_t stream)
|
||||
static void Run(HstuAttentionNoGroupFwdParams& param, hipStream_t stream)
|
||||
{
|
||||
constexpr ck_tile::index_t occupancy = -1;
|
||||
|
||||
@@ -117,7 +118,7 @@ struct jagged_forward_causal_softmax_bias_dropout_dispatch
|
||||
};
|
||||
|
||||
template <typename HstuKernel>
|
||||
static void RunWithKernel(HstuAttentionFwdParams& param, hipStream_t stream)
|
||||
static void RunWithKernel(HstuAttentionNoGroupFwdParams& param, hipStream_t stream)
|
||||
{
|
||||
const auto kargs = [&] {
|
||||
return HstuKernel::MakeKargs(param.q_ptr,
|
||||
@@ -174,7 +175,7 @@ template <typename InOutDataType,
|
||||
bool kHasBias,
|
||||
bool kHasDropout,
|
||||
ck_tile::index_t MaxK>
|
||||
void run_jagged_forward_causal_softmax_bias_dropout_dispatch(HstuAttentionFwdParams& param,
|
||||
void run_jagged_forward_causal_softmax_bias_dropout_dispatch(HstuAttentionNoGroupFwdParams& param,
|
||||
hipStream_t stream)
|
||||
{
|
||||
jagged_forward_causal_softmax_bias_dropout_dispatch<InOutDataType,
|
||||
|
||||
@@ -7,10 +7,12 @@
|
||||
#include "hstu_attention_bool_switch.hpp"
|
||||
#include "hstu_attention_hdim_switch.hpp"
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
#include "hstu_attention_jagged_forward_dispatch.hpp"
|
||||
|
||||
#include "instances/hstu_attention_batched_forward_bf16_instances_ref.hpp"
|
||||
#include "instances/hstu_attention_jagged_forward_bf16_instances_ref.hpp"
|
||||
|
||||
void hstu_attention_batched_forward_bf16(HstuAttentionFwdParams& param, hipStream_t stream)
|
||||
void hstu_attention_no_group_forward_bf16(HstuAttentionNoGroupFwdParams& param, hipStream_t stream)
|
||||
{
|
||||
const bool has_dropout = (param.p_drop > 0.0f);
|
||||
const bool has_bias = (param.bias_ptr != nullptr);
|
||||
@@ -18,12 +20,20 @@ void hstu_attention_batched_forward_bf16(HstuAttentionFwdParams& param, hipStrea
|
||||
BOOL_SWITCH_3(has_bias, kHasBias, has_dropout, kHasDropout, use_causal, kUseCausal, [&] {
|
||||
HDIM_SWITCH(param.hdim_qk, param.hdim_v, MaxK, [&] {
|
||||
BOOL_SWITCH(param.use_softmax, kUseSoftmax, [&] {
|
||||
run_batched_forward_causal_softmax_bias_dropout_dispatch<ck_tile::bf16_t,
|
||||
kUseCausal,
|
||||
kUseSoftmax,
|
||||
kHasBias,
|
||||
kHasDropout,
|
||||
MaxK>(param, stream);
|
||||
if(param.is_jagged)
|
||||
run_jagged_forward_causal_softmax_bias_dropout_dispatch<ck_tile::bf16_t,
|
||||
kUseCausal,
|
||||
kUseSoftmax,
|
||||
kHasBias,
|
||||
kHasDropout,
|
||||
MaxK>(param, stream);
|
||||
else
|
||||
run_batched_forward_causal_softmax_bias_dropout_dispatch<ck_tile::bf16_t,
|
||||
kUseCausal,
|
||||
kUseSoftmax,
|
||||
kHasBias,
|
||||
kHasDropout,
|
||||
MaxK>(param, stream);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -7,10 +7,12 @@
|
||||
#include "hstu_attention_bool_switch.hpp"
|
||||
#include "hstu_attention_hdim_switch.hpp"
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
#include "hstu_attention_jagged_forward_dispatch.hpp"
|
||||
|
||||
#include "instances/hstu_attention_batched_forward_fp16_instances_ref.hpp"
|
||||
#include "instances/hstu_attention_jagged_forward_fp16_instances_ref.hpp"
|
||||
|
||||
void hstu_attention_batched_forward_fp16(HstuAttentionFwdParams& param, hipStream_t stream)
|
||||
void hstu_attention_no_group_forward_fp16(HstuAttentionNoGroupFwdParams& param, hipStream_t stream)
|
||||
{
|
||||
const bool has_dropout = (param.p_drop > 0.0f);
|
||||
const bool has_bias = (param.bias_ptr != nullptr);
|
||||
@@ -18,12 +20,20 @@ void hstu_attention_batched_forward_fp16(HstuAttentionFwdParams& param, hipStrea
|
||||
BOOL_SWITCH_3(has_bias, kHasBias, has_dropout, kHasDropout, use_causal, kUseCausal, [&] {
|
||||
HDIM_SWITCH(param.hdim_qk, param.hdim_v, MaxK, [&] {
|
||||
BOOL_SWITCH(param.use_softmax, kUseSoftmax, [&] {
|
||||
run_batched_forward_causal_softmax_bias_dropout_dispatch<ck_tile::fp16_t,
|
||||
kUseCausal,
|
||||
kUseSoftmax,
|
||||
kHasBias,
|
||||
kHasDropout,
|
||||
MaxK>(param, stream);
|
||||
if(param.is_jagged)
|
||||
run_jagged_forward_causal_softmax_bias_dropout_dispatch<ck_tile::fp16_t,
|
||||
kUseCausal,
|
||||
kUseSoftmax,
|
||||
kHasBias,
|
||||
kHasDropout,
|
||||
MaxK>(param, stream);
|
||||
else
|
||||
run_batched_forward_causal_softmax_bias_dropout_dispatch<ck_tile::fp16_t,
|
||||
kUseCausal,
|
||||
kUseSoftmax,
|
||||
kHasBias,
|
||||
kHasDropout,
|
||||
MaxK>(param, stream);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -5,7 +5,7 @@
|
||||
|
||||
#include <ck_tile/core.hpp>
|
||||
|
||||
struct HstuAttentionFwdParams
|
||||
struct HstuAttentionNoGroupFwdParams
|
||||
{
|
||||
// for self-attention (is_cross_attention = false), we require
|
||||
// 1) either seqlen_kv == 0 or seqlen_kv == seqlen_q
|
||||
@@ -55,6 +55,7 @@ struct HstuAttentionFwdParams
|
||||
const void* num_targets_ptr;
|
||||
|
||||
bool use_causal;
|
||||
// parameters used by Non-Group HSTU
|
||||
ck_tile::index_t window_size;
|
||||
ck_tile::index_t contextual_seqlen;
|
||||
ck_tile::index_t min_full_attn_seqlen;
|
||||
@@ -65,3 +66,63 @@ struct HstuAttentionFwdParams
|
||||
uint64_t philox_seed;
|
||||
uint64_t philox_offset;
|
||||
};
|
||||
|
||||
// Parameter bundle taken by the group-mode HSTU attention forward entry points
// (e.g. hstu_attention_group_forward_fp16(HstuAttentionGroupFwdParams&, hipStream_t)).
struct HstuAttentionGroupFwdParams
|
||||
{
|
||||
// for self-attention (is_cross_attention = false), we require
|
||||
// 1) either seq_kv_offsets_ptr == nullptr, or seq_kv_offsets_ptr == seq_q_offsets_ptr
|
||||
bool is_cross_attention;
|
||||
|
||||
// number of attention groups and number of batches
// NOTE(review): whether num_batch is per group or a total is not visible here -- confirm
ck_tile::index_t num_group;
|
||||
ck_tile::index_t num_batch;
|
||||
// sequence offset arrays for query and key/value (see the self-attention rule above)
// NOTE(review): element type and array length are not visible here -- confirm against the kernel
const void* seq_q_offsets_ptr;
|
||||
const void* seq_kv_offsets_ptr;
|
||||
ck_tile::index_t max_seqlen; // the maximum of all the groups' max_seqlen
|
||||
|
||||
// type-erased tensor pointers; the element type is chosen by the caller
// (the fp16 entry point dispatches with ck_tile::fp16_t)
const void* q_ptr;
|
||||
const void* k_ptr;
|
||||
const void* v_ptr;
|
||||
// the group forward entry point treats bias_ptr == nullptr as "no bias" (kHasBias = false)
const void* bias_ptr;
|
||||
void* o_ptr;
|
||||
|
||||
// hdim_qk / hdim_v are fed to HDIM_SWITCH by the forward entry point to select MaxK
ck_tile::index_t hdim_qk;
|
||||
ck_tile::index_t hdim_v;
|
||||
ck_tile::index_t num_head;
|
||||
float scale_s; // scaling factor exerted on the immediate Q@K result
|
||||
|
||||
// strides along the sequence dimension
// NOTE(review): presumably counted in elements rather than bytes -- confirm
ck_tile::index_t seq_stride_q;
|
||||
ck_tile::index_t seq_stride_k;
|
||||
ck_tile::index_t seq_stride_v;
|
||||
ck_tile::index_t seq_stride_bias;
|
||||
ck_tile::index_t seq_stride_o;
|
||||
|
||||
// strides along the head dimension
ck_tile::index_t nhead_stride_q;
|
||||
ck_tile::index_t nhead_stride_k;
|
||||
ck_tile::index_t nhead_stride_v;
|
||||
ck_tile::index_t nhead_stride_bias;
|
||||
ck_tile::index_t nhead_stride_o;
|
||||
|
||||
// batched mode only parameters
// NOTE(review): HstuAttentionFwdPipelineProblem static_asserts that Group HSTU is only
// used with jagged mode -- confirm whether these batch strides are ever read for group kernels
|
||||
ck_tile::index_t batch_stride_q;
|
||||
ck_tile::index_t batch_stride_k;
|
||||
ck_tile::index_t batch_stride_v;
|
||||
ck_tile::index_t batch_stride_bias;
|
||||
ck_tile::index_t batch_stride_o;
|
||||
|
||||
// NOTE(review): presumably the per-batch target counts -- confirm layout against the kernel
const void* num_targets_ptr;
|
||||
|
||||
// selects the kUseCausal template variant in the forward entry point
bool use_causal;
|
||||
|
||||
// parameters used by Group HSTU
// NOTE(review): presumably arrays with one entry per group -- confirm element types and lengths
|
||||
const void* group_attn_scale_ptr;
|
||||
const void* group_max_seqlen_ptr;
|
||||
const void* group_window_size_ptr;
|
||||
const void* group_contextual_seqlen_ptr;
|
||||
const void* group_min_full_attn_seqlen_ptr;
|
||||
|
||||
// selects the kUseSoftmax template variant in the forward entry point
bool use_softmax;
|
||||
|
||||
// dropout is enabled by the forward entry point only when p_drop > 0.0f
float p_drop;
|
||||
// NOTE(review): presumably the Philox RNG state used for dropout -- confirm
uint64_t philox_seed;
|
||||
uint64_t philox_offset;
|
||||
};
|
||||
|
||||
@@ -64,6 +64,7 @@ template <typename InOutDataType_,
|
||||
typename CompDataType_, // data type for SiLU and other non-linear calculation
|
||||
typename BiasDataType_,
|
||||
bool kIsCrossAttention_,
|
||||
bool kUseGroup_,
|
||||
bool kIsJagged_,
|
||||
bool kHasBias_,
|
||||
bool kHasDropout_,
|
||||
@@ -87,6 +88,7 @@ struct HstuAttentionFwdPipelineProblem
|
||||
using PDataType = QKVDataType;
|
||||
|
||||
static constexpr bool kIsCrossAttention = kIsCrossAttention_;
|
||||
static constexpr bool kUseGroup = kUseGroup_;
|
||||
static constexpr bool kIsJagged = kIsJagged_;
|
||||
static constexpr bool kHasBias = kHasBias_;
|
||||
static constexpr bool kHasDropout = kHasDropout_;
|
||||
@@ -94,6 +96,9 @@ struct HstuAttentionFwdPipelineProblem
|
||||
static constexpr bool kUseSoftmax = kUseSoftmax_;
|
||||
static constexpr bool kUseTrLoad = kUseTrLoad_;
|
||||
|
||||
// Group HSTU implies jagged mode: (kUseGroup -> kIsJagged).
// The former condition `!kUseGroup || (kUseGroup && kIsJagged)` carried a redundant
// `kUseGroup &&` term; this form is logically identical.
static_assert(!kUseGroup || kIsJagged, "Group HSTU is only used with jagged mode!");
|
||||
|
||||
using HstuAttentionTileSetting = remove_cvref_t<AttentionTileSetting_>;
|
||||
|
||||
static constexpr index_t kNumGemm0Warps = AttentionTileSetting_::NumGemm0Warps;
|
||||
|
||||
@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
128>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
128>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
256>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
256>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
64>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
64>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
96>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
96>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
128>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
128>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
256>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
256>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
64>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
64>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
96>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
96>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
128>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
128>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
256>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
256>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
64>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
64>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
96>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
96>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
128>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
128>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
256>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
256>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
64>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
64>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
96>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
96>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
128>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
128>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
256>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
256>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
64>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
64>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
96>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
96>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
128>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
128>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
256>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
256>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
64>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
64>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
96>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
96>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
true,
|
||||
false,
|
||||
true,
|
||||
128>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
128>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
true,
|
||||
false,
|
||||
true,
|
||||
256>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
256>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
true,
|
||||
false,
|
||||
true,
|
||||
64>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
64>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
true,
|
||||
false,
|
||||
true,
|
||||
96>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
96>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
128>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
128>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
256>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
256>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
64>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
64>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
96>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
96>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -15,7 +15,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
64>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
64>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
@@ -23,7 +23,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
64>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
64>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
@@ -31,7 +31,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
64>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
64>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
@@ -39,7 +39,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
64>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
64>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
@@ -47,7 +47,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
64>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
64>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
@@ -55,7 +55,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
64>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
64>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
@@ -63,7 +63,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
64>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
64>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
@@ -71,7 +71,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
64>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
64>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
@@ -79,7 +79,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
true,
|
||||
false,
|
||||
true,
|
||||
64>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
64>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
@@ -87,7 +87,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
64>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
64>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
@@ -95,7 +95,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
true,
|
||||
false,
|
||||
true,
|
||||
64>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
64>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
@@ -103,7 +103,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
64>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
64>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
@@ -111,7 +111,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
64>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
64>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
@@ -119,7 +119,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
64>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
64>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
@@ -127,7 +127,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
64>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
64>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
@@ -135,7 +135,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
64>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
64>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
@@ -143,7 +143,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
96>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
96>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
@@ -151,7 +151,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
96>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
96>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
@@ -159,7 +159,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
96>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
96>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
@@ -167,7 +167,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
96>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
96>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
@@ -175,7 +175,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
96>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
96>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
@@ -183,7 +183,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
96>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
96>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
@@ -191,7 +191,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
96>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
96>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
@@ -199,7 +199,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
96>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
96>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
@@ -207,7 +207,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
true,
|
||||
false,
|
||||
true,
|
||||
96>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
96>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
@@ -215,7 +215,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
96>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
96>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
@@ -223,7 +223,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
true,
|
||||
false,
|
||||
true,
|
||||
96>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
96>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
@@ -231,7 +231,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
96>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
96>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
@@ -239,7 +239,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
96>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
96>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
@@ -247,7 +247,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
96>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
96>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
@@ -255,7 +255,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
96>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
96>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
@@ -263,7 +263,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
96>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
96>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
@@ -271,7 +271,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
128>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
128>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
@@ -279,7 +279,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
128>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
128>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
@@ -287,7 +287,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
128>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
128>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
@@ -295,7 +295,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
128>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
128>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
@@ -303,7 +303,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
128>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
128>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
@@ -311,7 +311,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
128>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
128>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
@@ -319,7 +319,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
128>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
128>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
@@ -327,7 +327,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
128>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
128>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
@@ -335,7 +335,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
true,
|
||||
false,
|
||||
true,
|
||||
128>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
128>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
@@ -343,7 +343,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
128>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
128>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
@@ -351,7 +351,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
true,
|
||||
false,
|
||||
true,
|
||||
128>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
128>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
@@ -359,7 +359,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
128>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
128>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
@@ -367,7 +367,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
128>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
128>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
@@ -375,7 +375,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
128>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
128>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
@@ -383,7 +383,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
128>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
128>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
@@ -391,7 +391,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
128>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
128>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
@@ -399,7 +399,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
256>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
256>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
@@ -407,7 +407,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
256>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
256>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
@@ -415,7 +415,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
256>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
256>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
@@ -423,7 +423,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
256>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
256>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
@@ -431,7 +431,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
256>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
256>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
@@ -439,7 +439,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
256>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
256>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
@@ -447,7 +447,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
256>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
256>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
@@ -455,7 +455,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
256>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
256>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
@@ -463,7 +463,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
true,
|
||||
false,
|
||||
true,
|
||||
256>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
256>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
@@ -471,7 +471,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
256>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
256>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
@@ -479,7 +479,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
true,
|
||||
false,
|
||||
true,
|
||||
256>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
256>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
@@ -487,7 +487,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
256>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
256>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
@@ -495,7 +495,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
256>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
256>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
@@ -503,7 +503,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
256>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
256>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
@@ -511,7 +511,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
256>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
256>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
@@ -519,4 +519,4 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
256>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
256>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
128>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
128>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
256>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
256>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
64>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
64>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
96>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
96>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
128>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
128>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
256>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
256>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
64>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
64>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
96>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
96>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
128>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
128>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
256>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
256>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
64>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
64>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
96>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
96>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
128>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
128>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
256>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
256>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
64>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
64>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
96>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
96>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
128>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
128>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
256>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
256>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
64>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
64>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
96>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
96>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
128>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
128>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
256>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
256>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
64>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
64>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
96>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
96>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
true,
|
||||
false,
|
||||
true,
|
||||
128>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
128>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
true,
|
||||
false,
|
||||
true,
|
||||
256>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
256>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
true,
|
||||
false,
|
||||
true,
|
||||
64>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
64>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
true,
|
||||
false,
|
||||
true,
|
||||
96>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
96>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
128>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
128>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
256>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
256>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
64>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
64>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
96>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
96>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
128>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
128>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
256>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
256>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
64>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
64>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
96>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
96>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
128>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
128>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
256>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
256>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
64>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
64>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
96>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
96>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
128>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
128>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
256>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
256>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
64>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
64>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
96>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
96>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
128>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
128>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
256>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
256>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
64>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
64>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
96>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
96>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
128>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
128>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
256>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
256>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
64>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
64>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
96>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
96>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user