Update to support grouped mode hstu attention

This commit is contained in:
Qianfeng Zhang
2026-03-09 16:15:58 +00:00
parent 73d6e0eb67
commit 302537c5a8
408 changed files with 5284 additions and 713 deletions

View File

@@ -3,7 +3,7 @@ set(EXAMPLE_HSTU_ATTENTION "tile_example_hstu_attention")
# to be included in "make all/install/check"
message("adding example ${EXAMPLE_HSTU_ATTENTION}")
file(GLOB INSTANCE_SRCS instances/*.cpp)
set(INTERFACES_SRCS hstu_attention_jagged_forward_bf16.cpp hstu_attention_jagged_forward_fp16.cpp hstu_attention_batched_forward_bf16.cpp hstu_attention_batched_forward_fp16.cpp)
set(INTERFACES_SRCS hstu_attention_no_group_forward_bf16.cpp hstu_attention_no_group_forward_fp16.cpp hstu_attention_group_forward_bf16.cpp hstu_attention_group_forward_fp16.cpp)
add_executable(${EXAMPLE_HSTU_ATTENTION} EXCLUDE_FROM_ALL example_hstu_attention.cpp)
target_include_directories(${EXAMPLE_HSTU_ATTENTION} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
target_sources(${EXAMPLE_HSTU_ATTENTION} PRIVATE ${INTERFACES_SRCS} ${INSTANCE_SRCS})

View File

@@ -29,27 +29,33 @@
``` C++
arg_parser.insert("v", "1", "weather do CPU validation or not")
.insert("g", "1", "num of attention group, bigger than 1 indicating group hstu")
.insert("prec", "fp16", "data type. fp16/bf16")
.insert("jagged", "0", "q/k/v batched sequence is jagged or not")
.insert("b", "12", "batch size")
.insert("b", "12", "number of batches")
.insert("nhead", "4", "number of heads")
.insert("hdim_qk", "64", "headdim size of Q/K")
.insert("hdim_v", "64", "headdim size of V/O")
.insert("seqlens", "400", "seqlen of single or all batches for query and key/value tensor, actually allocated seqlen will include the target of each batch and context_len")
.insert("seqlens", "400", "uih seqlen of single or all batches for query and key/value tensor, actually allocated seqlen will include the target of each batch and context_len")
.insert("max_seqlen", "0", "max uih_seqlen, can be ignored, or else must be equal or bigger than the maximum of all uih seqlens")
.insert("targets", "16", "sequence length at the end of query/key token sequence that should be excluded from attention")
.insert("g_max_seqlens", "0", "max uih_seqlen, can be ignored, or else must be equal or bigger than the maximum of all uih seqlens")
.insert("targets", "", "sequence length at the end of query/key token sequence that should be excluded from attention")
.insert("max_target", "0", "max target, can be ignored, or else must be equal of bigger than the maximum of all targets")
.insert("causal", "1", "enable causal mask or not")
.insert("local_len", "5", "length of the diagonal window for enabling masking, value 0 to disable")
.insert("g_local_lens", "5,", "list of all group's length of the diagonal window for enabling masking, value 0 to disable")
.insert("context_len", "6", "sequence length at the begin of the query sequence the should be included for attention")
.insert("g_context_lens", "6,", "list of all group's sequence length at the begin of the query sequence that should be included for attention")
.insert("minfull_len", "6", "sequence length at the end of the query sequence that should be included for attention")
.insert("init_qkv", "0", "initialize q, k, v tensor from local files q.dat, k.dat and v.data")
.insert("g_minfull_lens", "6", "list of all groups's sequence length at the end of the query sequence that should be included for attention")
.insert("seed", "13579", "seed by the uniform or normal distribution generator")
.insert("norm_dist", "0", "if true, initialize the data in normal distribution, or else in uniform distribution")
.insert("alpha", "0", "scale factor of S=Q@K. 0 means equal to 1/sqrt(hdim)")
.insert("attn_scale", "0", "scale factor of SiLu(Q@K), 0 means using 1/max_seqlen for scaling")
.insert("save_mask", "1", "save the mask tensor to disk by the CPU validation codes")
.insert("perf", "0", "weather measure execution time or not");
.insert("attn_scale", "0", "scale factor of SiLU(Q@K). 0 means using 1/max_seqlen for scaling")
.insert("g_attn_scales", "1.0,", "list of all groups's scale factors of S=@@K. 0 means using 1/max_seqlen of the group for scaling")
.insert("init_qkv", "0", "initialize q, k, v tensor from local files q.dat, k.dat and v.data")
.insert("save_mask", "0", "save the mask tensor to disk by the CPU validation codes")
.insert("perf", "0", "weather measure execution time or not")
.insert("dump_output", "0", "dump both device and reference hstu attention outputs to files, only used when validation is true");
```

View File

@@ -11,6 +11,7 @@
#include <utility>
#include <vector>
#include <random>
#include <cassert>
#include <ck_tile/host/host_tensor.hpp>
#include <ck_tile/host/fill.hpp>
@@ -90,9 +91,10 @@ auto create_args(int argc, char* argv[])
// clang-format off
arg_parser.insert("v", "1", "weather do CPU validation or not")
.insert("g", "1", "num of attention group, bigger than 1 indicating group hstu")
.insert("prec", "fp16", "data type. fp16/bf16")
.insert("jagged", "0", "q/k/v batched sequence is jagged or not")
.insert("b", "12", "batch size")
.insert("b", "12", "number of batches")
.insert("nhead", "4", "number of heads")
.insert("hdim_qk", "64", "headdim size of Q/K")
.insert("hdim_v", "64", "headdim size of V/O")
@@ -100,17 +102,22 @@ auto create_args(int argc, char* argv[])
.insert("seqlens_kv", "", "uih seqlen of single or all batches for key/value tensor, actually allocated seqlen will include the target of each batch and context_len")
.insert("max_seqlen", "0", "max uih_seqlen, can be ignored, or else must be equal or bigger than the maximum of all uih seqlens")
.insert("max_seqlen_kv", "0", "max uih_seqlen_kv, can be ignored, or else must be equal or bigger than the maximum of all uih seqlens")
.insert("g_max_seqlens", "0", "max uih_seqlen, can be ignored, or else must be equal or bigger than the maximum of all uih seqlens")
.insert("targets", "", "sequence length at the end of query/key token sequence that should be excluded from attention")
.insert("max_target", "0", "max target, can be ignored, or else must be equal of bigger than the maximum of all targets")
.insert("softmax", "0", "use softmax or not")
.insert("causal", "1", "enable causal mask or not")
.insert("local_len", "5", "length of the diagonal window for enabling masking, value 0 to disable")
.insert("g_local_lens", "5,", "list of all group's length of the diagonal window for enabling masking, value 0 to disable")
.insert("context_len", "6", "sequence length at the begin of the query sequence the should be included for attention")
.insert("g_context_lens", "6,", "list of all group's sequence length at the begin of the query sequence that should be included for attention")
.insert("minfull_len", "6", "sequence length at the end of the query sequence that should be included for attention")
.insert("g_minfull_lens", "6", "list of all groups's sequence length at the end of the query sequence that should be included for attention")
.insert("seed", "13579", "seed by the uniform or normal distribution generator")
.insert("norm_dist", "0", "if true, initialize the data in normal distribution, or else in uniform distribution")
.insert("alpha", "0", "scale factor of S=Q@K. 0 means equal to 1/sqrt(hdim)")
.insert("attn_scale", "0", "scale factor of SiLU(Q@K). 0 means using 1/max_seqlen for scaling")
.insert("g_attn_scales", "1.0,", "list of all groups's scale factors of S=@@K. 0 means using 1/max_seqlen of the group for scaling")
.insert("init_qkv", "0", "initialize q, k, v tensor from local files q.dat, k.dat and v.data")
.insert("save_mask", "0", "save the mask tensor to disk by the CPU validation codes")
.insert("perf", "0", "weather measure execution time or not")
@@ -121,35 +128,66 @@ auto create_args(int argc, char* argv[])
return std::make_tuple(result, arg_parser);
}
static std::vector<int> get_integers_from_string(std::string lengthsStr)
static std::vector<int> get_integers_from_string(std::string srcStr)
{
std::vector<int> lengths;
std::vector<int> integers;
std::size_t pos = 0;
std::size_t new_pos;
new_pos = lengthsStr.find(',', pos);
new_pos = srcStr.find(',', pos);
while(new_pos != std::string::npos)
{
std::string sliceStr = lengthsStr.substr(pos, new_pos - pos);
std::string sliceStr = srcStr.substr(pos, new_pos - pos);
int len = std::stoi(sliceStr);
lengths.push_back(len);
integers.push_back(len);
pos = new_pos + 1;
new_pos = lengthsStr.find(',', pos);
new_pos = srcStr.find(',', pos);
};
std::string sliceStr = lengthsStr.substr(pos);
std::string sliceStr = srcStr.substr(pos);
if(!sliceStr.empty())
{
int len = std::stoi(sliceStr);
lengths.push_back(len);
integers.push_back(len);
};
return (lengths);
return (integers);
};
static std::vector<float> get_floats_from_string(std::string srcStr)
{
std::vector<float> values;
std::size_t pos = 0;
std::size_t new_pos;
new_pos = srcStr.find(',', pos);
while(new_pos != std::string::npos)
{
std::string sliceStr = srcStr.substr(pos, new_pos - pos);
float val = std::stof(sliceStr);
values.push_back(val);
pos = new_pos + 1;
new_pos = srcStr.find(',', pos);
};
std::string sliceStr = srcStr.substr(pos);
if(!sliceStr.empty())
{
float val = std::stof(sliceStr);
values.push_back(val);
};
return (values);
};
template <typename T>
@@ -164,42 +202,6 @@ void supplement_array_by_last_element(std::vector<T>& arr, int target_num_elemen
};
};
static void show_hstu_attention_fwd_param(std::ostream& os, HstuAttentionFwdParams& param)
{
if(param.is_jagged)
{
os << "Jagged inputs used! " << std::endl;
os << "use causal: " << param.use_causal << std::endl;
os << "Num of batches: " << param.num_batch << std::endl;
os << "Num of heads: " << param.num_head << std::endl;
os << "QK hdim: " << param.hdim_qk << " V hdim: " << param.hdim_v << std::endl;
os << "Q/K/V/O seq stride: " << param.seq_stride_q << " " << param.seq_stride_k << " "
<< param.seq_stride_v << " " << param.seq_stride_o << std::endl;
os << "Q/K/V/O nhead stride: " << param.nhead_stride_q << " " << param.nhead_stride_k << " "
<< param.nhead_stride_v << " " << param.nhead_stride_o << std::endl;
os << "contextual_seqlen: " << param.contextual_seqlen << std::endl;
os << "window_size: " << param.window_size << std::endl;
os << "min_full_attn_seqlen: " << param.min_full_attn_seqlen << std::endl;
}
else
{
os << "Batched inputs used! " << std::endl;
os << "use causal: " << param.use_causal << std::endl;
os << "Num of batches: " << param.num_batch << std::endl;
os << "Num of heads: " << param.num_head << std::endl;
os << "QK hdim: " << param.hdim_qk << " V hdim: " << param.hdim_v << std::endl;
os << "Q/K/V/O seq stride: " << param.seq_stride_q << " " << param.seq_stride_k << " "
<< param.seq_stride_v << " " << param.seq_stride_o << std::endl;
os << "Q/K/V/O nhead stride: " << param.nhead_stride_q << " " << param.nhead_stride_k << " "
<< param.nhead_stride_v << " " << param.nhead_stride_o << std::endl;
os << "Q/K/V/O batch stride: " << param.batch_stride_q << " " << param.batch_stride_k << " "
<< param.batch_stride_v << " " << param.batch_stride_o << std::endl;
os << "contextual_seqlen: " << param.contextual_seqlen << std::endl;
os << "window_size: " << param.window_size << std::endl;
os << "min_full_attn_seqlen: " << param.min_full_attn_seqlen << std::endl;
};
};
// threshold for different dtypes
template <typename DataType>
auto get_elimit()
@@ -219,10 +221,9 @@ auto get_elimit<ck_tile::bf16_t>()
}
template <typename InOutDataType>
bool run(const ck_tile::ArgParser& arg_parser)
bool run_no_group_hstu(const ck_tile::ArgParser& arg_parser, bool is_jagged)
{
bool do_validation = static_cast<bool>(arg_parser.get_int("v"));
bool is_jagged = static_cast<bool>(arg_parser.get_int("jagged"));
int num_batch = arg_parser.get_int("b");
int num_head = arg_parser.get_int("nhead");
int hdim_qk = arg_parser.get_int("hdim_qk");
@@ -230,11 +231,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
bool use_softmax = static_cast<bool>(arg_parser.get_int("softmax"));
bool use_causal = static_cast<bool>(arg_parser.get_int("causal"));
int window_size = arg_parser.get_int("local_len");
int contextual_seqlen = arg_parser.get_int("context_len");
int min_full_attn_seqlen = arg_parser.get_int("minfull_len");
float alpha = arg_parser.get_float("alpha");
float attn_scale = arg_parser.get_float("attn_scale");
int seed = arg_parser.get_int("seed");
@@ -245,8 +241,15 @@ bool run(const ck_tile::ArgParser& arg_parser)
bool save_mask = static_cast<bool>(arg_parser.get_int("save_mask"));
bool initialize_qkv = static_cast<bool>(arg_parser.get_int("init_qkv"));
std::string str_of_targets = arg_parser.get_str("targets");
std::vector<int> num_targets = get_integers_from_string(str_of_targets);
std::string str_of_integers;
str_of_integers = arg_parser.get_str("targets");
std::vector<int> num_targets = get_integers_from_string(str_of_integers);
int window_size = arg_parser.get_int("local_len");
int contextual_seqlen = arg_parser.get_int("context_len");
int min_full_attn_seqlen = arg_parser.get_int("minfull_len");
std::string str_of_lengths_q = arg_parser.get_str("seqlens");
std::vector<int> seq_lengths_q = get_integers_from_string(str_of_lengths_q);
@@ -265,16 +268,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
bool is_cross_attention = false;
if(!num_targets.empty())
{
// supplement num_targets using the last input value if user-provided lengths not enough
supplement_array_by_last_element(num_targets, num_batch);
// only consider num_batch values even if more values are provided by the user
for(int i = 0; i < num_batch; i++)
max_target = max(max_target, num_targets[i]);
};
HSTU_CHECK(!seq_lengths_q.empty(), "sequence lengths of q shoud be defined!");
// assume seq_lengths_kv is same as seq_lengths_q if not defined, or else when
@@ -293,15 +286,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(is_jagged)
{
// supplement seq_lengths_q using the last input value if user-provided lengths not
// enough
// supplement seq_lengths_q using the last input value if user-provided lengths not enough
supplement_array_by_last_element(seq_lengths_q, num_batch);
// supplement seq_lengths_kv using the last input value if user-provided lengths not
// enough
// supplement seq_lengths_kv using the last input value if user-provided lengths not enough
supplement_array_by_last_element(seq_lengths_kv, num_batch);
// only consider num_batch values even if more values are provided by the user
for(int i = 0; i < num_batch; i++)
{
max_uih_seqlen_q = max(max_uih_seqlen_q, seq_lengths_q[i]);
@@ -316,6 +306,16 @@ bool run(const ck_tile::ArgParser& arg_parser)
max_uih_seqlen_kv = seq_lengths_kv[0];
};
if(!num_targets.empty())
{
// supplement num_targets using the last input value if user-provided lengths not enough
supplement_array_by_last_element(num_targets, num_batch);
// only consider num_batch values even if more values are provided by the user
for(int i = 0; i < num_batch; i++)
max_target = max(max_target, num_targets[i]);
};
// the user input of max_uih_seqlen can either be ignored or be bigger than all uih_seqlens
// the user input of max_target can either be ignored or be bigger than all targets
HSTU_CHECK(input_max_uih_seqlen_q <= 0 || input_max_uih_seqlen_q >= max_uih_seqlen_q,
@@ -386,7 +386,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
long total_flops = 0;
// estimate the total flops occurred, ignoring the scaling and SILu
// estimate the total flops occurred, ignoring the scaling and SiLu
if(is_jagged)
{
for(int i = 0; i < num_batch; i++)
@@ -467,7 +467,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(!num_targets.empty())
num_targets_dev.ToDevice(num_targets.data());
HstuAttentionFwdParams params;
HstuAttentionNoGroupFwdParams params;
float scale_s = (alpha != 0.f) ? alpha : 1.0f / std::sqrt(hdim_qk);
@@ -552,26 +552,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
params.philox_offset = 0UL;
};
// show_hstu_attention_fwd_param(std::cout, params);
std::ignore = show_hstu_attention_fwd_param;
hipStream_t stream;
HIP_CHECK_ERROR(hipStreamCreate(&stream));
if constexpr(std::is_same<InOutDataType, ck_tile::fp16_t>::value)
{
if(is_jagged)
hstu_attention_jagged_forward_fp16(params, stream);
else
hstu_attention_batched_forward_fp16(params, stream);
hstu_attention_no_group_forward_fp16(params, stream);
}
else if constexpr(std::is_same<InOutDataType, ck_tile::bf16_t>::value)
{
if(is_jagged)
hstu_attention_jagged_forward_bf16(params, stream);
else
hstu_attention_batched_forward_bf16(params, stream);
hstu_attention_no_group_forward_bf16(params, stream);
}
else
throw std::runtime_error("Other data type is not supported at present!");
@@ -584,28 +575,28 @@ bool run(const ck_tile::ArgParser& arg_parser)
using CompDataType = typename HstuAttentionFwdTypeConfig<InOutDataType>::CompDataType;
BOOL_SWITCH_3(is_jagged, kIsJagged, use_softmax, kUseSoftmax, use_causal, kUseCausal, [&] {
ck_tile::reference_hstu_attention<InOutDataType,
GemmAccDataType,
CompDataType,
kIsJagged,
kUseSoftmax,
kUseCausal>::Run(is_cross_attention,
q_host,
k_host,
v_host,
o_host_ref,
mask_host,
num_batch,
scale_s,
attn_scale,
max_seqlen_q,
max_seqlen_kv,
seq_offsets_q,
seq_offsets_kv,
num_targets,
contextual_seqlen,
window_size,
min_full_attn_seqlen);
ck_tile::reference_no_group_hstu_attention<InOutDataType,
GemmAccDataType,
CompDataType,
kIsJagged,
kUseSoftmax,
kUseCausal>::Run(is_cross_attention,
q_host,
k_host,
v_host,
o_host_ref,
mask_host,
num_batch,
scale_s,
attn_scale,
max_seqlen_q,
max_seqlen_kv,
seq_offsets_q,
seq_offsets_kv,
num_targets,
contextual_seqlen,
window_size,
min_full_attn_seqlen);
});
ck_tile::HostTensor<InOutDataType> o_host(
@@ -638,17 +629,385 @@ bool run(const ck_tile::ArgParser& arg_parser)
{
if constexpr(std::is_same<InOutDataType, ck_tile::fp16_t>::value)
{
if(is_jagged)
hstu_attention_jagged_forward_fp16(params, stream);
else
hstu_attention_batched_forward_fp16(params, stream);
hstu_attention_no_group_forward_fp16(params, stream);
}
else if constexpr(std::is_same<InOutDataType, ck_tile::bf16_t>::value)
{
if(is_jagged)
hstu_attention_jagged_forward_bf16(params, stream);
else
hstu_attention_batched_forward_bf16(params, stream);
hstu_attention_no_group_forward_bf16(params, stream);
}
}
timer.stop(stream);
auto ms = timer.duration() / 10.f;
std::cout << "Average execution time of the hstu_attention operation is " << ms
<< " milli-seconds, estimated TFLOPS is "
<< (static_cast<float>(total_flops) / ms) / 1.0e9 << std::endl;
}
return res;
}
template <typename InOutDataType>
bool run_group_hstu(const ck_tile::ArgParser& arg_parser, int num_group)
{
bool do_validation = static_cast<bool>(arg_parser.get_int("v"));
int num_batch = arg_parser.get_int("b");
HSTU_CHECK(num_group > 1, "ru_group_hstu should only be called when num_group > 1 !");
HSTU_CHECK(num_batch > 0 && num_batch % num_group == 0,
"number of batches should be a multi-fold value of num_group!");
int num_batch_per_group = num_batch / num_group;
int num_head = arg_parser.get_int("nhead");
int hdim_qk = arg_parser.get_int("hdim_qk");
int hdim_v = arg_parser.get_int("hdim_v");
bool use_softmax = static_cast<bool>(arg_parser.get_int("softmax"));
bool use_causal = static_cast<bool>(arg_parser.get_int("causal"));
float alpha = arg_parser.get_float("alpha");
int seed = arg_parser.get_int("seed");
bool use_normal_dist = arg_parser.get_int("norm_dist");
bool measure_perf = static_cast<bool>(arg_parser.get_int("perf"));
bool dump_output = static_cast<bool>(arg_parser.get_int("dump_output"));
bool save_mask = static_cast<bool>(arg_parser.get_int("save_mask"));
bool initialize_qkv = static_cast<bool>(arg_parser.get_int("init_qkv"));
std::string str_of_integers;
str_of_integers = arg_parser.get_str("targets");
std::vector<int> num_targets = get_integers_from_string(str_of_integers);
std::string str_of_lengths_q = arg_parser.get_str("seqlens");
std::vector<int> seq_lengths_q = get_integers_from_string(str_of_lengths_q);
std::string str_of_lengths_kv = arg_parser.get_str("seqlens_kv");
std::vector<int> seq_lengths_kv = get_integers_from_string(str_of_lengths_kv);
bool is_cross_attention = false;
HSTU_CHECK(!seq_lengths_q.empty(), "sequence lengths shoud be defined!");
// assume seq_lengths_kv is same as seq_lengths_q if not defined, or else when
// seq_lengths_kv is explicitly defined, we think the input case is a cross_attention case
if(seq_lengths_kv.empty())
seq_lengths_kv = seq_lengths_q;
else
is_cross_attention = true;
str_of_integers = arg_parser.get_str("g_max_seqlens");
std::vector<int> group_max_seqlens = get_integers_from_string(str_of_integers);
HSTU_CHECK(!group_max_seqlens.empty(), "group window sizes shoud be defined!");
str_of_integers = arg_parser.get_str("g_context_lens");
std::vector<int> group_contextual_seqlens = get_integers_from_string(str_of_integers);
HSTU_CHECK(!group_contextual_seqlens.empty(), "group contextual seqlens shoud be defined!");
str_of_integers = arg_parser.get_str("g_local_lens");
std::vector<int> group_window_sizes = get_integers_from_string(str_of_integers);
HSTU_CHECK(!group_window_sizes.empty(), "group window sizes shoud be defined!");
str_of_integers = arg_parser.get_str("g_minfull_lens");
std::vector<int> group_min_full_attn_seqlens = get_integers_from_string(str_of_integers);
HSTU_CHECK(!group_min_full_attn_seqlens.empty(),
"group min_full_attn seqlens shoud be defined!");
std::string str_of_floats = arg_parser.get_str("g_attn_scales");
std::vector<float> group_attn_scales = get_floats_from_string(str_of_floats);
HSTU_CHECK(!group_attn_scales.empty(), "group attn_scales shoud be defined!");
// supplement seq_lengths_q using the last input value if user-provided lengths not enough
supplement_array_by_last_element(seq_lengths_q, num_batch);
// supplement seq_lengths_kv using the last input value if user-provided lengths not enough
supplement_array_by_last_element(seq_lengths_kv, num_batch);
if(!num_targets.empty())
{
// supplement num_targets using the last input value if user-provided lengths not enough
supplement_array_by_last_element(num_targets, num_batch);
};
// supplement group_max_seqlens using the last input value if user-provided lengths not enough
supplement_array_by_last_element(group_max_seqlens, num_group);
// supplement group_contextual_seqlens using the last input value if user-provided lengths not
// enough
supplement_array_by_last_element(group_contextual_seqlens, num_group);
// supplement group_window_sizes using the last input value if user-provided lengths not enough
supplement_array_by_last_element(group_window_sizes, num_group);
// supplement group_min_full_attn_seqlens using the last input value if user-provided lengths
// not enough
supplement_array_by_last_element(group_min_full_attn_seqlens, num_group);
// supplement group_attn_scales using the last input value if user-provided values not enough
supplement_array_by_last_element(group_attn_scales, num_group);
int phy_seqlen_q = 0;
int phy_seqlen_kv = 0;
int max_max_seqlen = 0;
// only consider num_group values even if more values were provided by the user
for(int i = 0; i < num_group; i++)
{
max_max_seqlen = max(max_max_seqlen, group_max_seqlens[i]);
};
std::vector<int> seq_offsets_q;
std::vector<int> seq_offsets_kv;
seq_offsets_q.push_back(0);
for(int i = 0; i < num_batch; i++)
{
int i_group = i / num_batch_per_group;
int batch_seqlen =
num_targets.empty()
? seq_lengths_q[i] + group_contextual_seqlens[i_group]
: seq_lengths_q[i] + num_targets[i] + group_contextual_seqlens[i_group];
phy_seqlen_q += batch_seqlen;
seq_offsets_q.push_back(phy_seqlen_q);
};
seq_offsets_kv.push_back(0);
for(int i = 0; i < num_batch; i++)
{
if(!is_cross_attention)
{
int i_group = i / num_batch_per_group;
int batch_seqlen =
num_targets.empty()
? seq_lengths_kv[i] + group_contextual_seqlens[i_group]
: seq_lengths_kv[i] + num_targets[i] + group_contextual_seqlens[i_group];
phy_seqlen_kv += batch_seqlen;
seq_offsets_kv.push_back(phy_seqlen_kv);
}
else // for cross_attention, assume target_in_kv == false
{
int i_group = i / num_batch_per_group;
int batch_seqlen = seq_lengths_kv[i] + group_contextual_seqlens[i_group];
phy_seqlen_kv += batch_seqlen;
seq_offsets_kv.push_back(phy_seqlen_kv);
}
};
long total_flops = 0;
// estimate the total flops occurred, ignoring the scaling and SILu
for(int i = 0; i < num_batch; i++)
{
int len_q = seq_offsets_q[i + 1] - seq_offsets_q[i];
int len_kv = seq_offsets_kv[i + 1] - seq_offsets_kv[i];
total_flops += (static_cast<long>(len_q) * len_kv * hdim_qk +
static_cast<long>(len_q) * hdim_v * len_kv) *
2;
};
total_flops *= num_head;
int batches_for_alloc = 1;
ck_tile::HostTensor<InOutDataType> q_host(
std::array<ck_tile::index_t, 4>{batches_for_alloc, phy_seqlen_q, num_head, hdim_qk});
ck_tile::HostTensor<InOutDataType> k_host(
std::array<ck_tile::index_t, 4>{batches_for_alloc, phy_seqlen_kv, num_head, hdim_qk});
ck_tile::HostTensor<InOutDataType> v_host(
std::array<ck_tile::index_t, 4>{batches_for_alloc, phy_seqlen_kv, num_head, hdim_v});
ck_tile::HostTensor<InOutDataType> o_host_ref(
std::array<ck_tile::index_t, 4>{batches_for_alloc, phy_seqlen_q, num_head, hdim_v});
ck_tile::HostTensor<int8_t> mask_host(
save_mask
? std::array<ck_tile::index_t, 4>{num_batch, num_head, max_max_seqlen, max_max_seqlen}
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
if(!initialize_qkv)
{
if(use_normal_dist)
{
ck_tile::FillNormalDistribution<InOutDataType>{0.f, 1.f, seed}(q_host);
ck_tile::FillNormalDistribution<InOutDataType>{0.f, 1.f, seed}(k_host);
ck_tile::FillNormalDistribution<InOutDataType>{0.f, 1.f, seed}(v_host);
}
else
{
ck_tile::FillUniformDistribution<InOutDataType>{-1.f, 1.f, seed}(q_host);
ck_tile::FillUniformDistribution<InOutDataType>{-1.f, 1.f, seed}(k_host);
ck_tile::FillUniformDistribution<InOutDataType>{-1.f, 1.f, seed}(v_host);
};
}
else
{
readDataToBufferFromFile(q_host.data(), q_host.get_element_space_size(), "q.dat");
readDataToBufferFromFile(k_host.data(), k_host.get_element_space_size(), "k.dat");
readDataToBufferFromFile(v_host.data(), v_host.get_element_space_size(), "v.dat");
};
ck_tile::DeviceMem q_dev(q_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem k_dev(k_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem v_dev(v_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem o_dev(o_host_ref.get_element_space_size_in_bytes());
ck_tile::DeviceMem seq_offsets_q_dev(seq_offsets_q.size() * sizeof(int));
ck_tile::DeviceMem seq_offsets_kv_dev(seq_offsets_kv.size() * sizeof(int));
ck_tile::DeviceMem num_targets_dev(num_targets.size() * sizeof(int));
q_dev.ToDevice(q_host.data());
k_dev.ToDevice(k_host.data());
v_dev.ToDevice(v_host.data());
seq_offsets_q_dev.ToDevice(seq_offsets_q.data());
seq_offsets_kv_dev.ToDevice(seq_offsets_kv.data());
if(!num_targets.empty())
num_targets_dev.ToDevice(num_targets.data());
ck_tile::DeviceMem group_max_seqlens_dev(group_max_seqlens.size() * sizeof(int));
ck_tile::DeviceMem group_contextual_seqlens_dev(group_contextual_seqlens.size() * sizeof(int));
ck_tile::DeviceMem group_window_sizes_dev(group_window_sizes.size() * sizeof(int));
ck_tile::DeviceMem group_min_full_attn_seqlens_dev(group_min_full_attn_seqlens.size() *
sizeof(int));
ck_tile::DeviceMem group_attn_scales_dev(group_attn_scales.size() * sizeof(float));
group_max_seqlens_dev.ToDevice(group_max_seqlens.data());
group_contextual_seqlens_dev.ToDevice(group_contextual_seqlens.data());
group_window_sizes_dev.ToDevice(group_window_sizes.data());
group_min_full_attn_seqlens_dev.ToDevice(group_min_full_attn_seqlens.data());
group_attn_scales_dev.ToDevice(group_attn_scales.data());
HstuAttentionGroupFwdParams params;
float scale_s = (alpha != 0.f) ? alpha : 1.0f / std::sqrt(hdim_qk);
params.is_cross_attention = is_cross_attention;
params.num_batch = num_batch;
params.num_group = num_group;
params.seq_q_offsets_ptr = seq_offsets_q_dev.GetDeviceBuffer();
params.seq_kv_offsets_ptr = seq_offsets_kv_dev.GetDeviceBuffer();
params.max_seqlen = max_max_seqlen;
params.q_ptr = q_dev.GetDeviceBuffer();
params.k_ptr = k_dev.GetDeviceBuffer();
params.v_ptr = v_dev.GetDeviceBuffer();
params.bias_ptr = nullptr; // bias is not supported at present
params.o_ptr = o_dev.GetDeviceBuffer();
params.hdim_qk = hdim_qk;
params.hdim_v = hdim_v;
params.num_head = num_head;
params.scale_s = scale_s;
params.seq_stride_q = q_host.get_strides()[1];
params.seq_stride_k = k_host.get_strides()[1];
params.seq_stride_v = v_host.get_strides()[1];
params.seq_stride_bias = 0;
params.seq_stride_o = o_host_ref.get_strides()[1];
params.nhead_stride_q = q_host.get_strides()[2];
params.nhead_stride_k = k_host.get_strides()[2];
params.nhead_stride_v = v_host.get_strides()[2];
params.nhead_stride_bias = 0;
params.nhead_stride_o = o_host_ref.get_strides()[2];
params.num_targets_ptr = num_targets.empty() ? nullptr : num_targets_dev.GetDeviceBuffer();
params.use_softmax = use_softmax;
params.use_causal = use_causal;
params.p_drop = 0.0f; // dropout is not supported at present
params.philox_seed = 0UL;
params.philox_offset = 0UL;
params.group_max_seqlen_ptr = group_max_seqlens_dev.GetDeviceBuffer();
params.group_contextual_seqlen_ptr = group_contextual_seqlens_dev.GetDeviceBuffer();
params.group_window_size_ptr = group_window_sizes_dev.GetDeviceBuffer();
params.group_min_full_attn_seqlen_ptr = group_min_full_attn_seqlens_dev.GetDeviceBuffer();
params.group_attn_scale_ptr = group_attn_scales_dev.GetDeviceBuffer();
hipStream_t stream;
HIP_CHECK_ERROR(hipStreamCreate(&stream));
if constexpr(std::is_same<InOutDataType, ck_tile::fp16_t>::value)
{
hstu_attention_group_forward_fp16(params, stream);
}
else if constexpr(std::is_same<InOutDataType, ck_tile::bf16_t>::value)
{
hstu_attention_group_forward_bf16(params, stream);
}
else
throw std::runtime_error("Other data type is not supported at present!");
bool res = true;
if(do_validation)
{
using GemmAccDataType = typename HstuAttentionFwdTypeConfig<InOutDataType>::GemmAccDataType;
using CompDataType = typename HstuAttentionFwdTypeConfig<InOutDataType>::CompDataType;
BOOL_SWITCH_2(use_softmax, kUseSoftmax, use_causal, kUseCausal, [&] {
ck_tile::reference_group_hstu_attention<InOutDataType,
GemmAccDataType,
CompDataType,
kUseSoftmax,
kUseCausal>::Run(is_cross_attention,
q_host,
k_host,
v_host,
o_host_ref,
mask_host,
num_batch,
num_batch / num_group,
scale_s,
max_max_seqlen,
seq_offsets_q,
seq_offsets_kv,
num_targets,
group_max_seqlens,
group_contextual_seqlens,
group_window_sizes,
group_min_full_attn_seqlens,
group_attn_scales);
});
ck_tile::HostTensor<InOutDataType> o_host(
std::array<ck_tile::index_t, 4>{batches_for_alloc, phy_seqlen_q, num_head, hdim_v});
o_dev.FromDevice(o_host.data());
if(dump_output)
{
dumpBufferToFile("output_dev.dat", o_host.data(), o_host.get_element_space_size());
dumpBufferToFile("output_host.dat", o_host_ref.data(), o_host.get_element_space_size());
}
if(save_mask)
dumpBufferToFile(
"ck_hstu_mask.dat", mask_host.data(), mask_host.get_element_space_size());
auto [rtol, atol] = get_elimit<InOutDataType>();
res = ck_tile::check_err(
o_host, o_host_ref, std::string("hstu_attention output error"), rtol, atol);
};
if(measure_perf)
{
ck_tile::gpu_timer timer{};
timer.start(stream);
for(int i = 0; i < 10; i++)
{
if constexpr(std::is_same<InOutDataType, ck_tile::fp16_t>::value)
{
hstu_attention_group_forward_fp16(params, stream);
}
else if constexpr(std::is_same<InOutDataType, ck_tile::bf16_t>::value)
{
hstu_attention_group_forward_bf16(params, stream);
}
}
timer.stop(stream);
@@ -672,15 +1031,33 @@ int main(int argc, char* argv[])
return -1;
}
int num_group = static_cast<int>(arg_parser.get_int("g"));
const std::string data_type = arg_parser.get_str("prec");
if(data_type == "fp16")
if(num_group > 1)
{
return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
if(data_type == "fp16")
{
return run_group_hstu<ck_tile::half_t>(arg_parser, num_group) ? 0 : -2;
}
else if(data_type == "bf16")
{
return run_group_hstu<ck_tile::bf16_t>(arg_parser, num_group) ? 0 : -2;
}
}
else if(data_type == "bf16")
else
{
return run<ck_tile::bf16_t>(arg_parser) ? 0 : -2;
}
bool is_jagged = static_cast<bool>(arg_parser.get_int("jagged"));
if(data_type == "fp16")
{
return run_no_group_hstu<ck_tile::half_t>(arg_parser, is_jagged) ? 0 : -2;
}
else if(data_type == "bf16")
{
return run_no_group_hstu<ck_tile::bf16_t>(arg_parser, is_jagged) ? 0 : -2;
}
};
return -3;
}

View File

@@ -32,7 +32,7 @@ HSTU_FORWARD_INSTANCE_TEMPLATE = """
{use_softmax},
{has_bias},
{has_dropout},
{max_k}>(HstuAttentionFwdParams& param, hipStream_t stream);
{max_k}>(HstuAttention{group_or_not}FwdParams& param, hipStream_t stream);
"""
HSTU_FORWARD_INSTANCE_FNAME = (
@@ -76,13 +76,14 @@ TYPE_FNAME_MAP = {
"bf16": "half",
}
MODE_NAME_MAP = {
"batched": "Batched",
"jagged": "Jagged",
MODE_GROUP_OR_NOT_MAP = {
"batched": "NoGroup",
"jagged": "NoGroup",
"group": "Group",
}
def create_forward_instances(instance_dir: Path, headdims: List) -> None:
for mode in ["batched", "jagged"]:
for mode in ["batched", "jagged", "group"]:
for dtype in ["fp16", "bf16"]:
for has_causal in [True, False]:
for use_softmax in [True, False]:
@@ -113,7 +114,7 @@ def create_forward_instances(instance_dir: Path, headdims: List) -> None:
has_bias=BOOL_MAP[has_bias],
has_dropout=BOOL_MAP[has_dropout],
max_k=max_k,
cap_mode=MODE_NAME_MAP[mode],
group_or_not=MODE_GROUP_OR_NOT_MAP[mode],
)
(instance_dir / fname).write_text(
HSTU_COPYRIGHT_HEADER
@@ -123,7 +124,7 @@ def create_forward_instances(instance_dir: Path, headdims: List) -> None:
def create_forward_instances_ref(instance_dir: Path, headdims: List) -> None:
for mode in ["batched", "jagged"]:
for mode in ["batched", "jagged", "group"]:
for dtype in ["fp16", "bf16"]:
ref_fname = HSTU_INSTANCE_REF_FNAME.format(
mode=mode,
@@ -153,7 +154,7 @@ def create_forward_instances_ref(instance_dir: Path, headdims: List) -> None:
has_bias=BOOL_MAP[has_bias],
has_dropout=BOOL_MAP[has_dropout],
max_k=max_k,
cap_mode=MODE_NAME_MAP[mode],
group_or_not=MODE_GROUP_OR_NOT_MAP[mode],
)
)
file.write(forward_instance)

View File

@@ -7,7 +7,11 @@
#include "hstu_attention_params.hpp"
extern void hstu_attention_batched_forward_fp16(HstuAttentionFwdParams& param, hipStream_t stream);
extern void hstu_attention_batched_forward_bf16(HstuAttentionFwdParams& param, hipStream_t stream);
extern void hstu_attention_jagged_forward_fp16(HstuAttentionFwdParams& param, hipStream_t stream);
extern void hstu_attention_jagged_forward_bf16(HstuAttentionFwdParams& param, hipStream_t stream);
extern void hstu_attention_no_group_forward_fp16(HstuAttentionNoGroupFwdParams& param,
hipStream_t stream);
extern void hstu_attention_no_group_forward_bf16(HstuAttentionNoGroupFwdParams& param,
hipStream_t stream);
extern void hstu_attention_group_forward_fp16(HstuAttentionGroupFwdParams& param,
hipStream_t stream);
extern void hstu_attention_group_forward_bf16(HstuAttentionGroupFwdParams& param,
hipStream_t stream);

View File

@@ -48,6 +48,7 @@ struct batched_forward_causal_softmax_bias_dropout_dispatch
typename HstuAttentionFwdTypeConfig<InOutDataType>::CompDataType,
typename HstuAttentionFwdTypeConfig<InOutDataType>::BiasDataType,
kIsCrossAttention,
false, // kUseGroup
false, // kIsJagged
kHasBias,
kHasDropout,
@@ -56,7 +57,7 @@ struct batched_forward_causal_softmax_bias_dropout_dispatch
kUseTrLoad,
HstuAttentionTileSetting>;
static void Run(HstuAttentionFwdParams& param, hipStream_t stream)
static void Run(HstuAttentionNoGroupFwdParams& param, hipStream_t stream)
{
constexpr ck_tile::index_t occupancy = -1;
@@ -127,7 +128,7 @@ struct batched_forward_causal_softmax_bias_dropout_dispatch
};
template <typename HstuKernel>
static void RunWithKernel(HstuAttentionFwdParams& param, hipStream_t stream)
static void RunWithKernel(HstuAttentionNoGroupFwdParams& param, hipStream_t stream)
{
const auto kargs = [&] {
return HstuKernel::MakeKargs(param.q_ptr,
@@ -185,7 +186,7 @@ template <typename InOutDataType,
bool kHasBias,
bool kHasDropout,
ck_tile::index_t MaxK>
void run_batched_forward_causal_softmax_bias_dropout_dispatch(HstuAttentionFwdParams& param,
void run_batched_forward_causal_softmax_bias_dropout_dispatch(HstuAttentionNoGroupFwdParams& param,
hipStream_t stream)
{
batched_forward_causal_softmax_bias_dropout_dispatch<InOutDataType,

View File

@@ -40,6 +40,7 @@ struct HstuAttentionFwdKernel
using ODataType = ck_tile::remove_cvref_t<typename HstuAttentionPipeline::ODataType>;
static constexpr bool kIsCrossAttention = HstuAttentionPipeline::Problem::kIsCrossAttention;
static constexpr bool kUseGroup = HstuAttentionPipeline::Problem::kUseGroup;
static constexpr bool kIsJagged = HstuAttentionPipeline::Problem::kIsJagged;
static constexpr auto kHasBias = HstuAttentionPipeline::Problem::kHasBias;
static constexpr bool kHasDropout = HstuAttentionPipeline::Problem::kHasDropout;
@@ -60,7 +61,7 @@ struct HstuAttentionFwdKernel
// kargs use aggregate initializer, so no constructor will provided
// use inheritance to minimize karg size
// user need to use MakeKargs() function to create kargs.
struct HstuAttentionFwdBatchModeBaseKargs
struct HstuAttentionNoGroupBatchedFwdBaseKargs
{
ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k;
@@ -98,7 +99,7 @@ struct HstuAttentionFwdKernel
ck_tile::index_t min_full_attn_seqlen;
};
struct HstuAttentionFwdJaggModeBaseKargs
struct HstuAttentionNoGroupJaggedFwdBaseKargs
{
const int32_t* seq_q_offsets_ptr;
const int32_t* seq_kv_offsets_ptr;
@@ -135,6 +136,51 @@ struct HstuAttentionFwdKernel
ck_tile::index_t min_full_attn_seqlen;
};
struct HstuAttentionGroupFwdBaseKargs
{
ck_tile::index_t num_batch_per_group;
const int32_t* seq_q_offsets_ptr;
const int32_t* seq_kv_offsets_ptr;
ck_tile::index_t seq_stride_q;
ck_tile::index_t seq_stride_k;
ck_tile::index_t seq_stride_v;
ck_tile::index_t seq_stride_o;
const int32_t* num_targets_ptr;
const void* q_ptr;
const void* k_ptr;
const void* v_ptr;
void* o_ptr;
ck_tile::index_t nhead_stride_q;
ck_tile::index_t nhead_stride_k;
ck_tile::index_t nhead_stride_v;
ck_tile::index_t nhead_stride_o;
ck_tile::index_t hdim_qk;
ck_tile::index_t hdim_v;
ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_kv;
ck_tile::index_t num_head;
float scale_s; // scaling value exerted on the immediate Q@K result
float scale_p; // scaling value exerted on the SiLU result
int32_t contextual_seqlen; // to be set by the per-group contextual_seqlen
int32_t window_size; // to be set by the per-group window_size
int32_t min_full_attn_seqlen; // to be set by the per-group min_full_attn_seqlen
const int32_t* group_max_seqlen_ptr;
const int32_t* group_contextual_seqlen_ptr;
const int32_t* group_window_size_ptr;
const int32_t* group_min_full_attn_seqlen_ptr;
const float* group_attn_scale_ptr;
};
struct HstuAttentionFwdCommonBiasKargs
{
const void* bias_ptr = nullptr;
@@ -170,30 +216,48 @@ struct HstuAttentionFwdKernel
uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
};
struct HstuAttentionFwdBatchModeKargs : HstuAttentionFwdBatchModeBaseKargs,
std::conditional_t<kHasBias,
HstuAttentionFwdBatchModeBiasKargs,
HstuAttentionFwdEmptyKargs<1>>,
std::conditional_t<kHasDropout,
HstuAttentionFwdCommonDropoutKargs,
HstuAttentionFwdEmptyKargs<2>>
struct HstuAttentionNoGroupBatchedFwdKargs
: HstuAttentionNoGroupBatchedFwdBaseKargs,
std::conditional_t<kHasBias,
HstuAttentionFwdBatchModeBiasKargs,
HstuAttentionFwdEmptyKargs<1>>,
std::conditional_t<kHasDropout,
HstuAttentionFwdCommonDropoutKargs,
HstuAttentionFwdEmptyKargs<2>>
{
};
struct HstuAttentionFwdJaggModeKargs : HstuAttentionFwdJaggModeBaseKargs,
std::conditional_t<kHasBias,
HstuAttentionFwdCommonBiasKargs,
HstuAttentionFwdEmptyKargs<1>>,
std::conditional_t<kHasDropout,
HstuAttentionFwdCommonDropoutKargs,
HstuAttentionFwdEmptyKargs<2>>
struct HstuAttentionNoGroupJaggedFwdKargs
: HstuAttentionNoGroupJaggedFwdBaseKargs,
std::conditional_t<kHasBias,
HstuAttentionFwdCommonBiasKargs,
HstuAttentionFwdEmptyKargs<1>>,
std::conditional_t<kHasDropout,
HstuAttentionFwdCommonDropoutKargs,
HstuAttentionFwdEmptyKargs<2>>
{
};
using Kargs = std::
conditional_t<kIsJagged, HstuAttentionFwdJaggModeKargs, HstuAttentionFwdBatchModeKargs>;
struct HstuAttentionGroupFwdKargs : HstuAttentionGroupFwdBaseKargs,
std::conditional_t<kHasBias,
HstuAttentionFwdCommonBiasKargs,
HstuAttentionFwdEmptyKargs<1>>,
std::conditional_t<kHasDropout,
HstuAttentionFwdCommonDropoutKargs,
HstuAttentionFwdEmptyKargs<2>>
{
};
template <bool Cond = !kIsJagged>
using Kargs = std::conditional_t<kUseGroup,
HstuAttentionGroupFwdKargs,
std::conditional_t<kIsJagged,
HstuAttentionNoGroupJaggedFwdKargs,
HstuAttentionNoGroupBatchedFwdKargs>>;
static constexpr bool kUseNoGroupBatched = (!kUseGroup && !kIsJagged);
static constexpr bool kUseNoGroupJagged = (!kUseGroup && kIsJagged);
template <bool Cond = kUseNoGroupBatched>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr,
const void* k_ptr,
@@ -278,7 +342,7 @@ struct HstuAttentionFwdKernel
return kargs;
}
template <bool Cond = kIsJagged>
template <bool Cond = kUseNoGroupJagged>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr,
const void* k_ptr,
@@ -355,11 +419,95 @@ struct HstuAttentionFwdKernel
return kargs;
}
template <bool Cond = kUseGroup>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
const void* bias_ptr,
void* o_ptr,
ck_tile::index_t num_batch_per_group,
const void* seq_q_offsets_ptr,
const void* seq_kv_offsets_ptr,
const void* group_max_seqlen_ptr,
const void* group_contextual_seqlen_ptr,
const void* group_window_size_ptr,
const void* group_min_full_attn_seqlen_ptr,
const void* group_attn_scale_ptr,
ck_tile::index_t hdim_qk,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head,
float scale_s,
ck_tile::index_t seq_stride_q,
ck_tile::index_t seq_stride_k,
ck_tile::index_t seq_stride_v,
ck_tile::index_t seq_stride_bias,
ck_tile::index_t seq_stride_o,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_bias,
ck_tile::index_t nhead_stride_o,
const void* num_targets_ptr,
float p_drop,
uint64_t philox_seed,
uint64_t philox_offset)
{
Kargs kargs{
{num_batch_per_group,
reinterpret_cast<const int32_t*>(seq_q_offsets_ptr),
reinterpret_cast<const int32_t*>(seq_kv_offsets_ptr),
seq_stride_q,
seq_stride_k,
seq_stride_v,
seq_stride_o,
reinterpret_cast<const int32_t*>(num_targets_ptr),
q_ptr,
k_ptr,
v_ptr,
o_ptr,
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
nhead_stride_o,
hdim_qk,
hdim_v,
-1, // seqlen_q will be updated by another pointer
-1, // seqlen_kv will be updated by another pointer
num_head,
scale_s,
1.0f, // to be set according to the per-group attn_scale and max_seqlen
0, // to be set by the per-group contextual_seqlen
0, // to be set by the per-group window_size
0, // to be set by the per-group min_full_attn_seqlen
reinterpret_cast<const int32_t*>(group_max_seqlen_ptr),
reinterpret_cast<const int32_t*>(group_contextual_seqlen_ptr),
reinterpret_cast<const int32_t*>(group_window_size_ptr),
reinterpret_cast<const int32_t*>(group_min_full_attn_seqlen_ptr),
reinterpret_cast<const float*>(group_attn_scale_ptr)}, // args for common karg
{}, // placeholder for bias
{}, // placeholder for dropout
};
if constexpr(kHasBias)
{
kargs.bias_ptr = bias_ptr;
kargs.seq_stride_bias = seq_stride_bias;
kargs.nhead_stride_bias = nhead_stride_bias;
}
if constexpr(kHasDropout)
{
kargs.init_dropout(p_drop, philox_seed, philox_offset);
}
return kargs;
}
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
ck_tile::index_t nhead_,
ck_tile::index_t seqlen_,
ck_tile::index_t hdim_v_,
bool has_minfull_attn_seqlen)
bool has_minfull_attn_seqlen = false)
{
// The Q sequence [0, seqlen) will be split to two parts for allocating workgroups:
// 1) [0, seqlen - target - min_full_attn_seqlen)
@@ -367,8 +515,15 @@ struct HstuAttentionFwdKernel
ck_tile::index_t num_tile_in_seqlen =
ck_tile::integer_divide_ceil(seqlen_, HstuAttentionPipeline::kM0);
if(has_minfull_attn_seqlen)
if constexpr(kUseGroup)
{
num_tile_in_seqlen += 1;
}
else
{
if(has_minfull_attn_seqlen)
num_tile_in_seqlen += 1;
};
if constexpr(HstuAttentionPipeline::kN1 < HstuAttentionPipeline::kSubQKHeaddim)
{
@@ -492,6 +647,20 @@ struct HstuAttentionFwdKernel
kargs.seq_q_offsets_ptr[i_batch + 1] - kargs.seq_q_offsets_ptr[i_batch];
kargs.seqlen_kv =
kargs.seq_kv_offsets_ptr[i_batch + 1] - kargs.seq_kv_offsets_ptr[i_batch];
// read from device memory for the group specific mask and scaling parameters
if constexpr(kUseGroup)
{
index_t i_group =
__builtin_amdgcn_readfirstlane(i_batch / kargs.num_batch_per_group);
float attn_scale = kargs.group_attn_scale_ptr[i_group];
index_t max_seqlen = kargs.group_max_seqlen_ptr[i_group];
kargs.scale_p = (attn_scale ? attn_scale : 1.0f / static_cast<float>(max_seqlen));
kargs.contextual_seqlen = kargs.group_contextual_seqlen_ptr[i_group];
kargs.window_size = kargs.group_window_size_ptr[i_group];
kargs.min_full_attn_seqlen = kargs.group_min_full_attn_seqlen_ptr[i_group];
};
}
else
{

View File

@@ -6,11 +6,11 @@
#include "hstu_attention_bool_switch.hpp"
#include "hstu_attention_hdim_switch.hpp"
#include "hstu_attention_jagged_forward_dispatch.hpp"
#include "hstu_attention_group_forward_dispatch.hpp"
#include "instances/hstu_attention_jagged_forward_bf16_instances_ref.hpp"
#include "instances/hstu_attention_group_forward_bf16_instances_ref.hpp"
void hstu_attention_jagged_forward_bf16(HstuAttentionFwdParams& param, hipStream_t stream)
void hstu_attention_group_forward_bf16(HstuAttentionGroupFwdParams& param, hipStream_t stream)
{
const bool has_dropout = (param.p_drop > 0.0f);
const bool has_bias = (param.bias_ptr != nullptr);
@@ -18,12 +18,12 @@ void hstu_attention_jagged_forward_bf16(HstuAttentionFwdParams& param, hipStream
BOOL_SWITCH_3(has_bias, kHasBias, has_dropout, kHasDropout, use_causal, kUseCausal, [&] {
HDIM_SWITCH(param.hdim_qk, param.hdim_v, MaxK, [&] {
BOOL_SWITCH(param.use_softmax, kUseSoftmax, [&] {
run_jagged_forward_causal_softmax_bias_dropout_dispatch<ck_tile::bf16_t,
kUseCausal,
kUseSoftmax,
kHasBias,
kHasDropout,
MaxK>(param, stream);
run_group_forward_causal_softmax_bias_dropout_dispatch<ck_tile::bf16_t,
kUseCausal,
kUseSoftmax,
kHasBias,
kHasDropout,
MaxK>(param, stream);
});
});
});

View File

@@ -0,0 +1,184 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <ck_tile/core/numeric/integer.hpp>
#include <ck_tile/host/kernel_launch.hpp>
#include <ck_tile/host/stream_config.hpp>
#include <ck_tile/ops/epilogue.hpp>
#include "hstu_attention_bool_switch.hpp"
#include "hstu_attention_fwd_type_config.hpp"
#include "hstu_attention_fwd_setting.hpp"
#include "hstu_attention_params.hpp"
#include "hstu_attention_hdim_switch.hpp"
#include "hstu_attention_pipeline_problem.hpp"
#include "hstu_attention_traits.hpp"
#include "hstu_attention_with_softmax_fwd_pipeline.hpp"
#include "hstu_attention_no_softmax_fwd_pipeline.hpp"
#include "hstu_attention_with_softmax_fwd_trload_pipeline.hpp"
#include "hstu_attention_no_softmax_fwd_trload_pipeline.hpp"
#include "hstu_attention_fwd_kernel.hpp"
#include "hstu_attention_epilogue.hpp"
template <typename InOutDataType,
bool kUseCausal,
bool kUseSoftmax,
bool kHasBias,
bool kHasDropout,
ck_tile::index_t MaxK>
struct group_forward_causal_softmax_bias_dropout_dispatch
{
using HstuAttentionTileSetting =
typename std::conditional_t<kUseSoftmax,
HstuAttentionWithSoftmaxFwdTileSetting<MaxK>,
HstuAttentionNoSoftmaxFwdTileSetting<MaxK>>::Type;
#ifdef BUILD_HSTU_FOR_GFX95_ONLY
static constexpr bool kUseTrLoad = true;
#else
static constexpr bool kUseTrLoad = false;
#endif
template <bool kIsCrossAttention>
using HstuPipelineProblemTemp = ck_tile::HstuAttentionFwdPipelineProblem<
InOutDataType,
typename HstuAttentionFwdTypeConfig<InOutDataType>::GemmAccDataType,
typename HstuAttentionFwdTypeConfig<InOutDataType>::CompDataType,
typename HstuAttentionFwdTypeConfig<InOutDataType>::BiasDataType,
kIsCrossAttention,
true, // kUseGroup
true, // kIsJagged
kHasBias,
kHasDropout,
kUseCausal,
kUseSoftmax,
kUseTrLoad,
HstuAttentionTileSetting>;
static void Run(HstuAttentionGroupFwdParams& param, hipStream_t stream)
{
constexpr ck_tile::index_t occupancy = -1;
const bool pad_headdim_qk = !(param.hdim_qk % HstuAttentionTileSetting::kQKHeaddim == 0);
const bool pad_headdim_v = !(param.hdim_v % HstuAttentionTileSetting::kN1 == 0);
// no need to check seqlen_q since it is not used as fastest dim,
// buffer_load_dwordxx/buffer_store_dwordxx can handle oob access
constexpr bool kPadSeqLenQ = false;
constexpr bool kPadSeqLenK = true;
BOOL_SWITCH_2(pad_headdim_qk, kPadHeadDimQK, pad_headdim_v, kPadHeadDimV, [&] {
using HstuTraits = ck_tile::HstuAttentionFwdTraits<kPadSeqLenQ,
kPadSeqLenK,
kPadHeadDimQK,
kPadHeadDimV,
occupancy>;
using HstuEpilogue = ck_tile::NRepetitions2DEpilogue<ck_tile::Default2DEpilogueProblem<
typename HstuAttentionFwdTypeConfig<InOutDataType>::OaccDataType,
typename HstuAttentionFwdTypeConfig<InOutDataType>::ODataType,
kPadSeqLenQ,
kPadHeadDimV>>;
BOOL_SWITCH(param.is_cross_attention, kIsCrossAttention, [&] {
using HstuPipelineProblem = HstuPipelineProblemTemp<kIsCrossAttention>;
if constexpr(!kUseTrLoad)
{
using HstuPipeline = std::conditional_t<
kUseSoftmax,
ck_tile::HstuAttentionWithSoftmaxFwdPipelineQRKSVS<HstuPipelineProblem,
HstuTraits>,
ck_tile::HstuAttentionNoSoftmaxFwdPipelineQRKSVS<HstuPipelineProblem,
HstuTraits>>;
using HstuKernel = ck_tile::HstuAttentionFwdKernel<HstuPipeline, HstuEpilogue>;
RunWithKernel<HstuKernel>(param, stream);
}
else
{
using HstuPipeline = std::conditional_t<
kUseSoftmax,
ck_tile::HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad<
HstuPipelineProblem,
HstuTraits>,
ck_tile::HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad<HstuPipelineProblem,
HstuTraits>>;
using HstuKernel = ck_tile::HstuAttentionFwdKernel<HstuPipeline, HstuEpilogue>;
RunWithKernel<HstuKernel>(param, stream);
};
});
});
};
template <typename HstuKernel>
static void RunWithKernel(HstuAttentionGroupFwdParams& param, hipStream_t stream)
{
const auto kargs = [&] {
return HstuKernel::MakeKargs(param.q_ptr,
param.k_ptr,
param.v_ptr,
param.bias_ptr,
param.o_ptr,
param.num_batch / param.num_group,
param.seq_q_offsets_ptr,
param.is_cross_attention ? param.seq_kv_offsets_ptr
: param.seq_q_offsets_ptr,
param.group_max_seqlen_ptr,
param.group_contextual_seqlen_ptr,
param.group_window_size_ptr,
param.group_min_full_attn_seqlen_ptr,
param.group_attn_scale_ptr,
param.hdim_qk,
param.hdim_v,
param.num_head,
param.scale_s,
param.seq_stride_q,
param.seq_stride_k,
param.seq_stride_v,
param.seq_stride_bias,
param.seq_stride_o,
param.nhead_stride_q,
param.nhead_stride_k,
param.nhead_stride_v,
param.nhead_stride_bias,
param.nhead_stride_o,
param.num_targets_ptr,
param.p_drop,
param.philox_seed,
param.philox_offset);
}();
dim3 kGridSize =
HstuKernel::GridSize(param.num_batch, param.num_head, param.max_seqlen, param.hdim_v);
constexpr dim3 kBlockSize = HstuKernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu;
(void)ck_tile::launch_kernel(
ck_tile::stream_config{stream, false},
ck_tile::make_kernel<kBlockPerCu>(HstuKernel{}, kGridSize, kBlockSize, 0, kargs));
};
};
template <typename InOutDataType,
bool kUseCausal,
bool kUseSoftmax,
bool kHasBias,
bool kHasDropout,
ck_tile::index_t MaxK>
void run_group_forward_causal_softmax_bias_dropout_dispatch(HstuAttentionGroupFwdParams& param,
hipStream_t stream)
{
group_forward_causal_softmax_bias_dropout_dispatch<InOutDataType,
kUseCausal,
kUseSoftmax,
kHasBias,
kHasDropout,
MaxK>::Run(param, stream);
};

View File

@@ -6,24 +6,25 @@
#include "hstu_attention_bool_switch.hpp"
#include "hstu_attention_hdim_switch.hpp"
#include "hstu_attention_jagged_forward_dispatch.hpp"
#include "hstu_attention_group_forward_dispatch.hpp"
#include "instances/hstu_attention_jagged_forward_fp16_instances_ref.hpp"
#include "instances/hstu_attention_group_forward_fp16_instances_ref.hpp"
void hstu_attention_jagged_forward_fp16(HstuAttentionFwdParams& param, hipStream_t stream)
void hstu_attention_group_forward_fp16(HstuAttentionGroupFwdParams& param, hipStream_t stream)
{
const bool has_dropout = (param.p_drop > 0.0f);
const bool has_bias = (param.bias_ptr != nullptr);
const bool use_causal = param.use_causal;
BOOL_SWITCH_3(has_bias, kHasBias, has_dropout, kHasDropout, use_causal, kUseCausal, [&] {
HDIM_SWITCH(param.hdim_qk, param.hdim_v, MaxK, [&] {
BOOL_SWITCH(param.use_softmax, kUseSoftmax, [&] {
run_jagged_forward_causal_softmax_bias_dropout_dispatch<ck_tile::fp16_t,
kUseCausal,
kUseSoftmax,
kHasBias,
kHasDropout,
MaxK>(param, stream);
run_group_forward_causal_softmax_bias_dropout_dispatch<ck_tile::fp16_t,
kUseCausal,
kUseSoftmax,
kHasBias,
kHasDropout,
MaxK>(param, stream);
});
});
});

View File

@@ -48,7 +48,8 @@ struct jagged_forward_causal_softmax_bias_dropout_dispatch
typename HstuAttentionFwdTypeConfig<InOutDataType>::CompDataType,
typename HstuAttentionFwdTypeConfig<InOutDataType>::BiasDataType,
kIsCrossAttention,
true, // kIsJagged
false, // kUseGroup
true, // kIsJagged
kHasBias,
kHasDropout,
kUseCausal,
@@ -56,7 +57,7 @@ struct jagged_forward_causal_softmax_bias_dropout_dispatch
kUseTrLoad,
HstuAttentionTileSetting>;
static void Run(HstuAttentionFwdParams& param, hipStream_t stream)
static void Run(HstuAttentionNoGroupFwdParams& param, hipStream_t stream)
{
constexpr ck_tile::index_t occupancy = -1;
@@ -117,7 +118,7 @@ struct jagged_forward_causal_softmax_bias_dropout_dispatch
};
template <typename HstuKernel>
static void RunWithKernel(HstuAttentionFwdParams& param, hipStream_t stream)
static void RunWithKernel(HstuAttentionNoGroupFwdParams& param, hipStream_t stream)
{
const auto kargs = [&] {
return HstuKernel::MakeKargs(param.q_ptr,
@@ -174,7 +175,7 @@ template <typename InOutDataType,
bool kHasBias,
bool kHasDropout,
ck_tile::index_t MaxK>
void run_jagged_forward_causal_softmax_bias_dropout_dispatch(HstuAttentionFwdParams& param,
void run_jagged_forward_causal_softmax_bias_dropout_dispatch(HstuAttentionNoGroupFwdParams& param,
hipStream_t stream)
{
jagged_forward_causal_softmax_bias_dropout_dispatch<InOutDataType,

View File

@@ -7,10 +7,12 @@
#include "hstu_attention_bool_switch.hpp"
#include "hstu_attention_hdim_switch.hpp"
#include "hstu_attention_batched_forward_dispatch.hpp"
#include "hstu_attention_jagged_forward_dispatch.hpp"
#include "instances/hstu_attention_batched_forward_bf16_instances_ref.hpp"
#include "instances/hstu_attention_jagged_forward_bf16_instances_ref.hpp"
void hstu_attention_batched_forward_bf16(HstuAttentionFwdParams& param, hipStream_t stream)
void hstu_attention_no_group_forward_bf16(HstuAttentionNoGroupFwdParams& param, hipStream_t stream)
{
const bool has_dropout = (param.p_drop > 0.0f);
const bool has_bias = (param.bias_ptr != nullptr);
@@ -18,12 +20,20 @@ void hstu_attention_batched_forward_bf16(HstuAttentionFwdParams& param, hipStrea
BOOL_SWITCH_3(has_bias, kHasBias, has_dropout, kHasDropout, use_causal, kUseCausal, [&] {
HDIM_SWITCH(param.hdim_qk, param.hdim_v, MaxK, [&] {
BOOL_SWITCH(param.use_softmax, kUseSoftmax, [&] {
run_batched_forward_causal_softmax_bias_dropout_dispatch<ck_tile::bf16_t,
kUseCausal,
kUseSoftmax,
kHasBias,
kHasDropout,
MaxK>(param, stream);
if(param.is_jagged)
run_jagged_forward_causal_softmax_bias_dropout_dispatch<ck_tile::bf16_t,
kUseCausal,
kUseSoftmax,
kHasBias,
kHasDropout,
MaxK>(param, stream);
else
run_batched_forward_causal_softmax_bias_dropout_dispatch<ck_tile::bf16_t,
kUseCausal,
kUseSoftmax,
kHasBias,
kHasDropout,
MaxK>(param, stream);
});
});
});

View File

@@ -7,10 +7,12 @@
#include "hstu_attention_bool_switch.hpp"
#include "hstu_attention_hdim_switch.hpp"
#include "hstu_attention_batched_forward_dispatch.hpp"
#include "hstu_attention_jagged_forward_dispatch.hpp"
#include "instances/hstu_attention_batched_forward_fp16_instances_ref.hpp"
#include "instances/hstu_attention_jagged_forward_fp16_instances_ref.hpp"
void hstu_attention_batched_forward_fp16(HstuAttentionFwdParams& param, hipStream_t stream)
void hstu_attention_no_group_forward_fp16(HstuAttentionNoGroupFwdParams& param, hipStream_t stream)
{
const bool has_dropout = (param.p_drop > 0.0f);
const bool has_bias = (param.bias_ptr != nullptr);
@@ -18,12 +20,20 @@ void hstu_attention_batched_forward_fp16(HstuAttentionFwdParams& param, hipStrea
BOOL_SWITCH_3(has_bias, kHasBias, has_dropout, kHasDropout, use_causal, kUseCausal, [&] {
HDIM_SWITCH(param.hdim_qk, param.hdim_v, MaxK, [&] {
BOOL_SWITCH(param.use_softmax, kUseSoftmax, [&] {
run_batched_forward_causal_softmax_bias_dropout_dispatch<ck_tile::fp16_t,
kUseCausal,
kUseSoftmax,
kHasBias,
kHasDropout,
MaxK>(param, stream);
if(param.is_jagged)
run_jagged_forward_causal_softmax_bias_dropout_dispatch<ck_tile::fp16_t,
kUseCausal,
kUseSoftmax,
kHasBias,
kHasDropout,
MaxK>(param, stream);
else
run_batched_forward_causal_softmax_bias_dropout_dispatch<ck_tile::fp16_t,
kUseCausal,
kUseSoftmax,
kHasBias,
kHasDropout,
MaxK>(param, stream);
});
});
});

View File

@@ -5,7 +5,7 @@
#include <ck_tile/core.hpp>
struct HstuAttentionFwdParams
struct HstuAttentionNoGroupFwdParams
{
// for self-attention (is_cross_attention = false), we requires
// 1) either seqlen_kv == 0 or seqlen_kv == seqlen_q
@@ -55,6 +55,7 @@ struct HstuAttentionFwdParams
const void* num_targets_ptr;
bool use_causal;
// parameters used by Non-Group HSTU
ck_tile::index_t window_size;
ck_tile::index_t contextual_seqlen;
ck_tile::index_t min_full_attn_seqlen;
@@ -65,3 +66,63 @@ struct HstuAttentionFwdParams
uint64_t philox_seed;
uint64_t philox_offset;
};
struct HstuAttentionGroupFwdParams
{
// for self-attention (is_cross_attention = false), we requires
// 1) either seq_kv_offsets_ptr == nullptr, or seq_kv_offsets_ptr == seq_q_offsets_ptr
bool is_cross_attention;
ck_tile::index_t num_group;
ck_tile::index_t num_batch;
const void* seq_q_offsets_ptr;
const void* seq_kv_offsets_ptr;
ck_tile::index_t max_seqlen; // the maximum of all the groups' max_seqlen
const void* q_ptr;
const void* k_ptr;
const void* v_ptr;
const void* bias_ptr;
void* o_ptr;
ck_tile::index_t hdim_qk;
ck_tile::index_t hdim_v;
ck_tile::index_t num_head;
float scale_s; // scaling factor exerted on the immediate Q@K result
ck_tile::index_t seq_stride_q;
ck_tile::index_t seq_stride_k;
ck_tile::index_t seq_stride_v;
ck_tile::index_t seq_stride_bias;
ck_tile::index_t seq_stride_o;
ck_tile::index_t nhead_stride_q;
ck_tile::index_t nhead_stride_k;
ck_tile::index_t nhead_stride_v;
ck_tile::index_t nhead_stride_bias;
ck_tile::index_t nhead_stride_o;
// batched mode only parameters
ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k;
ck_tile::index_t batch_stride_v;
ck_tile::index_t batch_stride_bias;
ck_tile::index_t batch_stride_o;
const void* num_targets_ptr;
bool use_causal;
// parameters used by Group HSTU
const void* group_attn_scale_ptr;
const void* group_max_seqlen_ptr;
const void* group_window_size_ptr;
const void* group_contextual_seqlen_ptr;
const void* group_min_full_attn_seqlen_ptr;
bool use_softmax;
float p_drop;
uint64_t philox_seed;
uint64_t philox_offset;
};

View File

@@ -64,6 +64,7 @@ template <typename InOutDataType_,
typename CompDataType_, // data type for SiLU and other non-linear calculation
typename BiasDataType_,
bool kIsCrossAttention_,
bool kUseGroup_,
bool kIsJagged_,
bool kHasBias_,
bool kHasDropout_,
@@ -87,6 +88,7 @@ struct HstuAttentionFwdPipelineProblem
using PDataType = QKVDataType;
static constexpr bool kIsCrossAttention = kIsCrossAttention_;
static constexpr bool kUseGroup = kUseGroup_;
static constexpr bool kIsJagged = kIsJagged_;
static constexpr bool kHasBias = kHasBias_;
static constexpr bool kHasDropout = kHasDropout_;
@@ -94,6 +96,9 @@ struct HstuAttentionFwdPipelineProblem
static constexpr bool kUseSoftmax = kUseSoftmax_;
static constexpr bool kUseTrLoad = kUseTrLoad_;
static_assert(!kUseGroup || (kUseGroup && kIsJagged),
"Group HSTU is only used with jagged mode!");
using HstuAttentionTileSetting = remove_cvref_t<AttentionTileSetting_>;
static constexpr index_t kNumGemm0Warps = AttentionTileSetting_::NumGemm0Warps;

View File

@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
false,
true,
true,
128>(HstuAttentionFwdParams& param, hipStream_t stream);
128>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

View File

@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
false,
true,
true,
256>(HstuAttentionFwdParams& param, hipStream_t stream);
256>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

View File

@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
false,
true,
true,
64>(HstuAttentionFwdParams& param, hipStream_t stream);
64>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

View File

@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
false,
true,
true,
96>(HstuAttentionFwdParams& param, hipStream_t stream);
96>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

View File

@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
false,
true,
false,
128>(HstuAttentionFwdParams& param, hipStream_t stream);
128>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

View File

@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
false,
true,
false,
256>(HstuAttentionFwdParams& param, hipStream_t stream);
256>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

View File

@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
false,
true,
false,
64>(HstuAttentionFwdParams& param, hipStream_t stream);
64>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

View File

@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
false,
true,
false,
96>(HstuAttentionFwdParams& param, hipStream_t stream);
96>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

View File

@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
false,
false,
true,
128>(HstuAttentionFwdParams& param, hipStream_t stream);
128>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

View File

@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
false,
false,
true,
256>(HstuAttentionFwdParams& param, hipStream_t stream);
256>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

View File

@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
false,
false,
true,
64>(HstuAttentionFwdParams& param, hipStream_t stream);
64>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

View File

@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
false,
false,
true,
96>(HstuAttentionFwdParams& param, hipStream_t stream);
96>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

View File

@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
false,
false,
false,
128>(HstuAttentionFwdParams& param, hipStream_t stream);
128>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

View File

@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
false,
false,
false,
256>(HstuAttentionFwdParams& param, hipStream_t stream);
256>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

View File

@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
false,
false,
false,
64>(HstuAttentionFwdParams& param, hipStream_t stream);
64>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

View File

@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
false,
false,
false,
96>(HstuAttentionFwdParams& param, hipStream_t stream);
96>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

View File

@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
true,
true,
true,
128>(HstuAttentionFwdParams& param, hipStream_t stream);
128>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

View File

@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
true,
true,
true,
256>(HstuAttentionFwdParams& param, hipStream_t stream);
256>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

View File

@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
true,
true,
true,
64>(HstuAttentionFwdParams& param, hipStream_t stream);
64>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

View File

@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
true,
true,
true,
96>(HstuAttentionFwdParams& param, hipStream_t stream);
96>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

View File

@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
true,
true,
false,
128>(HstuAttentionFwdParams& param, hipStream_t stream);
128>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

View File

@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
true,
true,
false,
256>(HstuAttentionFwdParams& param, hipStream_t stream);
256>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

View File

@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
true,
true,
false,
64>(HstuAttentionFwdParams& param, hipStream_t stream);
64>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

View File

@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
true,
true,
false,
96>(HstuAttentionFwdParams& param, hipStream_t stream);
96>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

View File

@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
true,
false,
true,
128>(HstuAttentionFwdParams& param, hipStream_t stream);
128>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

View File

@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
true,
false,
true,
256>(HstuAttentionFwdParams& param, hipStream_t stream);
256>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

View File

@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
true,
false,
true,
64>(HstuAttentionFwdParams& param, hipStream_t stream);
64>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

View File

@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
true,
false,
true,
96>(HstuAttentionFwdParams& param, hipStream_t stream);
96>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

View File

@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
true,
false,
false,
128>(HstuAttentionFwdParams& param, hipStream_t stream);
128>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

View File

@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
true,
false,
false,
256>(HstuAttentionFwdParams& param, hipStream_t stream);
256>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

View File

@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
true,
false,
false,
64>(HstuAttentionFwdParams& param, hipStream_t stream);
64>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

View File

@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
true,
false,
false,
96>(HstuAttentionFwdParams& param, hipStream_t stream);
96>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

View File

@@ -15,7 +15,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
true,
true,
true,
64>(HstuAttentionFwdParams& param, hipStream_t stream);
64>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
ck_tile::bf16_t,
@@ -23,7 +23,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
false,
true,
true,
64>(HstuAttentionFwdParams& param, hipStream_t stream);
64>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
ck_tile::bf16_t,
@@ -31,7 +31,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
true,
true,
true,
64>(HstuAttentionFwdParams& param, hipStream_t stream);
64>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
ck_tile::bf16_t,
@@ -39,7 +39,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
false,
true,
true,
64>(HstuAttentionFwdParams& param, hipStream_t stream);
64>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
ck_tile::bf16_t,
@@ -47,7 +47,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
true,
true,
false,
64>(HstuAttentionFwdParams& param, hipStream_t stream);
64>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
ck_tile::bf16_t,
@@ -55,7 +55,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
false,
true,
false,
64>(HstuAttentionFwdParams& param, hipStream_t stream);
64>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
ck_tile::bf16_t,
@@ -63,7 +63,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
true,
true,
false,
64>(HstuAttentionFwdParams& param, hipStream_t stream);
64>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
ck_tile::bf16_t,
@@ -71,7 +71,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
false,
true,
false,
64>(HstuAttentionFwdParams& param, hipStream_t stream);
64>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
ck_tile::bf16_t,
@@ -79,7 +79,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
true,
false,
true,
64>(HstuAttentionFwdParams& param, hipStream_t stream);
64>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
ck_tile::bf16_t,
@@ -87,7 +87,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
false,
false,
true,
64>(HstuAttentionFwdParams& param, hipStream_t stream);
64>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
ck_tile::bf16_t,
@@ -95,7 +95,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
true,
false,
true,
64>(HstuAttentionFwdParams& param, hipStream_t stream);
64>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
ck_tile::bf16_t,
@@ -103,7 +103,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
false,
false,
true,
64>(HstuAttentionFwdParams& param, hipStream_t stream);
64>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
ck_tile::bf16_t,
@@ -111,7 +111,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
true,
false,
false,
64>(HstuAttentionFwdParams& param, hipStream_t stream);
64>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
ck_tile::bf16_t,
@@ -119,7 +119,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
false,
false,
false,
64>(HstuAttentionFwdParams& param, hipStream_t stream);
64>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
ck_tile::bf16_t,
@@ -127,7 +127,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
true,
false,
false,
64>(HstuAttentionFwdParams& param, hipStream_t stream);
64>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
ck_tile::bf16_t,
@@ -135,7 +135,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
false,
false,
false,
64>(HstuAttentionFwdParams& param, hipStream_t stream);
64>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
ck_tile::bf16_t,
@@ -143,7 +143,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
true,
true,
true,
96>(HstuAttentionFwdParams& param, hipStream_t stream);
96>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
ck_tile::bf16_t,
@@ -151,7 +151,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
false,
true,
true,
96>(HstuAttentionFwdParams& param, hipStream_t stream);
96>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
ck_tile::bf16_t,
@@ -159,7 +159,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
true,
true,
true,
96>(HstuAttentionFwdParams& param, hipStream_t stream);
96>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
ck_tile::bf16_t,
@@ -167,7 +167,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
false,
true,
true,
96>(HstuAttentionFwdParams& param, hipStream_t stream);
96>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
ck_tile::bf16_t,
@@ -175,7 +175,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
true,
true,
false,
96>(HstuAttentionFwdParams& param, hipStream_t stream);
96>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
ck_tile::bf16_t,
@@ -183,7 +183,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
false,
true,
false,
96>(HstuAttentionFwdParams& param, hipStream_t stream);
96>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
ck_tile::bf16_t,
@@ -191,7 +191,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
true,
true,
false,
96>(HstuAttentionFwdParams& param, hipStream_t stream);
96>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
ck_tile::bf16_t,
@@ -199,7 +199,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
false,
true,
false,
96>(HstuAttentionFwdParams& param, hipStream_t stream);
96>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
ck_tile::bf16_t,
@@ -207,7 +207,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
true,
false,
true,
96>(HstuAttentionFwdParams& param, hipStream_t stream);
96>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
ck_tile::bf16_t,
@@ -215,7 +215,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
false,
false,
true,
96>(HstuAttentionFwdParams& param, hipStream_t stream);
96>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
ck_tile::bf16_t,
@@ -223,7 +223,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
true,
false,
true,
96>(HstuAttentionFwdParams& param, hipStream_t stream);
96>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
ck_tile::bf16_t,
@@ -231,7 +231,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
false,
false,
true,
96>(HstuAttentionFwdParams& param, hipStream_t stream);
96>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
ck_tile::bf16_t,
@@ -239,7 +239,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
true,
false,
false,
96>(HstuAttentionFwdParams& param, hipStream_t stream);
96>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
ck_tile::bf16_t,
@@ -247,7 +247,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
false,
false,
false,
96>(HstuAttentionFwdParams& param, hipStream_t stream);
96>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
ck_tile::bf16_t,
@@ -255,7 +255,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
true,
false,
false,
96>(HstuAttentionFwdParams& param, hipStream_t stream);
96>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
ck_tile::bf16_t,
@@ -263,7 +263,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
false,
false,
false,
96>(HstuAttentionFwdParams& param, hipStream_t stream);
96>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
ck_tile::bf16_t,
@@ -271,7 +271,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
true,
true,
true,
128>(HstuAttentionFwdParams& param, hipStream_t stream);
128>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
ck_tile::bf16_t,
@@ -279,7 +279,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
false,
true,
true,
128>(HstuAttentionFwdParams& param, hipStream_t stream);
128>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
ck_tile::bf16_t,
@@ -287,7 +287,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
true,
true,
true,
128>(HstuAttentionFwdParams& param, hipStream_t stream);
128>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
ck_tile::bf16_t,
@@ -295,7 +295,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
false,
true,
true,
128>(HstuAttentionFwdParams& param, hipStream_t stream);
128>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
ck_tile::bf16_t,
@@ -303,7 +303,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
true,
true,
false,
128>(HstuAttentionFwdParams& param, hipStream_t stream);
128>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
ck_tile::bf16_t,
@@ -311,7 +311,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
false,
true,
false,
128>(HstuAttentionFwdParams& param, hipStream_t stream);
128>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
ck_tile::bf16_t,
@@ -319,7 +319,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
true,
true,
false,
128>(HstuAttentionFwdParams& param, hipStream_t stream);
128>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
ck_tile::bf16_t,
@@ -327,7 +327,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
false,
true,
false,
128>(HstuAttentionFwdParams& param, hipStream_t stream);
128>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
ck_tile::bf16_t,
@@ -335,7 +335,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
true,
false,
true,
128>(HstuAttentionFwdParams& param, hipStream_t stream);
128>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
ck_tile::bf16_t,
@@ -343,7 +343,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
false,
false,
true,
128>(HstuAttentionFwdParams& param, hipStream_t stream);
128>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
ck_tile::bf16_t,
@@ -351,7 +351,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
true,
false,
true,
128>(HstuAttentionFwdParams& param, hipStream_t stream);
128>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
ck_tile::bf16_t,
@@ -359,7 +359,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
false,
false,
true,
128>(HstuAttentionFwdParams& param, hipStream_t stream);
128>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
ck_tile::bf16_t,
@@ -367,7 +367,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
true,
false,
false,
128>(HstuAttentionFwdParams& param, hipStream_t stream);
128>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
ck_tile::bf16_t,
@@ -375,7 +375,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
false,
false,
false,
128>(HstuAttentionFwdParams& param, hipStream_t stream);
128>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
ck_tile::bf16_t,
@@ -383,7 +383,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
true,
false,
false,
128>(HstuAttentionFwdParams& param, hipStream_t stream);
128>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
ck_tile::bf16_t,
@@ -391,7 +391,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
false,
false,
false,
128>(HstuAttentionFwdParams& param, hipStream_t stream);
128>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
ck_tile::bf16_t,
@@ -399,7 +399,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
true,
true,
true,
256>(HstuAttentionFwdParams& param, hipStream_t stream);
256>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
ck_tile::bf16_t,
@@ -407,7 +407,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
false,
true,
true,
256>(HstuAttentionFwdParams& param, hipStream_t stream);
256>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
ck_tile::bf16_t,
@@ -415,7 +415,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
true,
true,
true,
256>(HstuAttentionFwdParams& param, hipStream_t stream);
256>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
ck_tile::bf16_t,
@@ -423,7 +423,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
false,
true,
true,
256>(HstuAttentionFwdParams& param, hipStream_t stream);
256>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
ck_tile::bf16_t,
@@ -431,7 +431,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
true,
true,
false,
256>(HstuAttentionFwdParams& param, hipStream_t stream);
256>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
ck_tile::bf16_t,
@@ -439,7 +439,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
false,
true,
false,
256>(HstuAttentionFwdParams& param, hipStream_t stream);
256>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
ck_tile::bf16_t,
@@ -447,7 +447,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
true,
true,
false,
256>(HstuAttentionFwdParams& param, hipStream_t stream);
256>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
ck_tile::bf16_t,
@@ -455,7 +455,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
false,
true,
false,
256>(HstuAttentionFwdParams& param, hipStream_t stream);
256>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
ck_tile::bf16_t,
@@ -463,7 +463,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
true,
false,
true,
256>(HstuAttentionFwdParams& param, hipStream_t stream);
256>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
ck_tile::bf16_t,
@@ -471,7 +471,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
false,
false,
true,
256>(HstuAttentionFwdParams& param, hipStream_t stream);
256>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
ck_tile::bf16_t,
@@ -479,7 +479,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
true,
false,
true,
256>(HstuAttentionFwdParams& param, hipStream_t stream);
256>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
ck_tile::bf16_t,
@@ -487,7 +487,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
false,
false,
true,
256>(HstuAttentionFwdParams& param, hipStream_t stream);
256>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
ck_tile::bf16_t,
@@ -495,7 +495,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
true,
false,
false,
256>(HstuAttentionFwdParams& param, hipStream_t stream);
256>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
ck_tile::bf16_t,
@@ -503,7 +503,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
false,
false,
false,
256>(HstuAttentionFwdParams& param, hipStream_t stream);
256>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
ck_tile::bf16_t,
@@ -511,7 +511,7 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
true,
false,
false,
256>(HstuAttentionFwdParams& param, hipStream_t stream);
256>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
ck_tile::bf16_t,
@@ -519,4 +519,4 @@ extern template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
false,
false,
false,
256>(HstuAttentionFwdParams& param, hipStream_t stream);
256>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

View File

@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
false,
true,
true,
128>(HstuAttentionFwdParams& param, hipStream_t stream);
128>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

View File

@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
false,
true,
true,
256>(HstuAttentionFwdParams& param, hipStream_t stream);
256>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

View File

@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
false,
true,
true,
64>(HstuAttentionFwdParams& param, hipStream_t stream);
64>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

View File

@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
false,
true,
true,
96>(HstuAttentionFwdParams& param, hipStream_t stream);
96>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

View File

@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
false,
true,
false,
128>(HstuAttentionFwdParams& param, hipStream_t stream);
128>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

View File

@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
false,
true,
false,
256>(HstuAttentionFwdParams& param, hipStream_t stream);
256>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

View File

@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
false,
true,
false,
64>(HstuAttentionFwdParams& param, hipStream_t stream);
64>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

View File

@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
false,
true,
false,
96>(HstuAttentionFwdParams& param, hipStream_t stream);
96>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

View File

@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
false,
false,
true,
128>(HstuAttentionFwdParams& param, hipStream_t stream);
128>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

View File

@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
false,
false,
true,
256>(HstuAttentionFwdParams& param, hipStream_t stream);
256>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

View File

@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
false,
false,
true,
64>(HstuAttentionFwdParams& param, hipStream_t stream);
64>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

View File

@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
false,
false,
true,
96>(HstuAttentionFwdParams& param, hipStream_t stream);
96>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

View File

@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
false,
false,
false,
128>(HstuAttentionFwdParams& param, hipStream_t stream);
128>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

View File

@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
false,
false,
false,
256>(HstuAttentionFwdParams& param, hipStream_t stream);
256>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

View File

@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
false,
false,
false,
64>(HstuAttentionFwdParams& param, hipStream_t stream);
64>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

View File

@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
false,
false,
false,
96>(HstuAttentionFwdParams& param, hipStream_t stream);
96>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

View File

@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
true,
true,
true,
128>(HstuAttentionFwdParams& param, hipStream_t stream);
128>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

View File

@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
true,
true,
true,
256>(HstuAttentionFwdParams& param, hipStream_t stream);
256>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

View File

@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
true,
true,
true,
64>(HstuAttentionFwdParams& param, hipStream_t stream);
64>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

View File

@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
true,
true,
true,
96>(HstuAttentionFwdParams& param, hipStream_t stream);
96>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

View File

@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
true,
true,
false,
128>(HstuAttentionFwdParams& param, hipStream_t stream);
128>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

View File

@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
true,
true,
false,
256>(HstuAttentionFwdParams& param, hipStream_t stream);
256>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

View File

@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
true,
true,
false,
64>(HstuAttentionFwdParams& param, hipStream_t stream);
64>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

View File

@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
true,
true,
false,
96>(HstuAttentionFwdParams& param, hipStream_t stream);
96>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

View File

@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
true,
false,
true,
128>(HstuAttentionFwdParams& param, hipStream_t stream);
128>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

View File

@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
true,
false,
true,
256>(HstuAttentionFwdParams& param, hipStream_t stream);
256>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

View File

@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
true,
false,
true,
64>(HstuAttentionFwdParams& param, hipStream_t stream);
64>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

View File

@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
true,
false,
true,
96>(HstuAttentionFwdParams& param, hipStream_t stream);
96>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

View File

@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
true,
false,
false,
128>(HstuAttentionFwdParams& param, hipStream_t stream);
128>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

View File

@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
true,
false,
false,
256>(HstuAttentionFwdParams& param, hipStream_t stream);
256>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

View File

@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
true,
false,
false,
64>(HstuAttentionFwdParams& param, hipStream_t stream);
64>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

View File

@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
true,
false,
false,
96>(HstuAttentionFwdParams& param, hipStream_t stream);
96>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

View File

@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
false,
true,
true,
128>(HstuAttentionFwdParams& param, hipStream_t stream);
128>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

View File

@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
false,
true,
true,
256>(HstuAttentionFwdParams& param, hipStream_t stream);
256>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

View File

@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
false,
true,
true,
64>(HstuAttentionFwdParams& param, hipStream_t stream);
64>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

View File

@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
false,
true,
true,
96>(HstuAttentionFwdParams& param, hipStream_t stream);
96>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

View File

@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
false,
true,
false,
128>(HstuAttentionFwdParams& param, hipStream_t stream);
128>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

View File

@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
false,
true,
false,
256>(HstuAttentionFwdParams& param, hipStream_t stream);
256>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

View File

@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
false,
true,
false,
64>(HstuAttentionFwdParams& param, hipStream_t stream);
64>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

View File

@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
false,
true,
false,
96>(HstuAttentionFwdParams& param, hipStream_t stream);
96>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

View File

@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
false,
false,
true,
128>(HstuAttentionFwdParams& param, hipStream_t stream);
128>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

View File

@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
false,
false,
true,
256>(HstuAttentionFwdParams& param, hipStream_t stream);
256>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

View File

@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
false,
false,
true,
64>(HstuAttentionFwdParams& param, hipStream_t stream);
64>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

View File

@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
false,
false,
true,
96>(HstuAttentionFwdParams& param, hipStream_t stream);
96>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

View File

@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
false,
false,
false,
128>(HstuAttentionFwdParams& param, hipStream_t stream);
128>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

View File

@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
false,
false,
false,
256>(HstuAttentionFwdParams& param, hipStream_t stream);
256>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

View File

@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
false,
false,
false,
64>(HstuAttentionFwdParams& param, hipStream_t stream);
64>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

View File

@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
false,
false,
false,
96>(HstuAttentionFwdParams& param, hipStream_t stream);
96>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

View File

@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
true,
true,
true,
128>(HstuAttentionFwdParams& param, hipStream_t stream);
128>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

View File

@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
true,
true,
true,
256>(HstuAttentionFwdParams& param, hipStream_t stream);
256>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

View File

@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
true,
true,
true,
64>(HstuAttentionFwdParams& param, hipStream_t stream);
64>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

View File

@@ -15,4 +15,4 @@ template void run_batched_forward_causal_softmax_bias_dropout_dispatch<
true,
true,
true,
96>(HstuAttentionFwdParams& param, hipStream_t stream);
96>(HstuAttentionNoGroupFwdParams& param, hipStream_t stream);

Some files were not shown because too many files have changed in this diff Show More