Clarify the usage of max_seqlen and max_seqlen_q

This commit is contained in:
Qianfeng Zhang
2026-04-17 09:13:45 +00:00
parent 5c84f54fd9
commit db3263469c
9 changed files with 143 additions and 121 deletions

View File

@@ -103,6 +103,7 @@ auto create_args(int argc, char* argv[])
.insert("max_seqlen", "0", "max uih_seqlen, can be ignored, or else must be equal/bigger than the maximum of all uih seqlens")
.insert("max_seqlen_kv", "0", "max uih_seqlen_kv, can be ignored, or else must be equal/bigger than the maximum of all uih seqlens")
.insert("g_max_seqlens", "0", "max uih_seqlen of groups, can be ignored, or else each must be equal/bigger than maximum of all uih seqlens in its group")
.insert("g_max_seqlens_kv", "0", "max uih_seqlen of groups, can be ignored, or else each must be equal/bigger than maximum of all uih seqlens in its group")
.insert("targets", "", "sequence length at the end of query/key token sequence that should be excluded from attention")
.insert("max_target", "0", "max target, can be ignored, or else must be equal/bigger than the maximum of all targets")
.insert("softmax", "0", "use softmax or not")
@@ -478,7 +479,7 @@ bool run_no_group_hstu(const ck_tile::ArgParser& arg_parser, bool is_jagged)
params.num_batch = num_batch;
params.seq_q_offsets_ptr = seq_offsets_q_dev.GetDeviceBuffer();
params.seq_kv_offsets_ptr = seq_offsets_kv_dev.GetDeviceBuffer();
params.max_seqlen = max(max_seqlen_q, max_seqlen_kv);
params.max_seqlen_q = max_seqlen_q;
params.q_ptr = q_dev.GetDeviceBuffer();
params.k_ptr = k_dev.GetDeviceBuffer();
params.v_ptr = v_dev.GetDeviceBuffer();
@@ -697,10 +698,11 @@ bool run_group_hstu(const ck_tile::ArgParser& arg_parser, int num_group)
else
is_cross_attention = true;
str_of_integers = arg_parser.get_str("g_max_seqlens");
std::vector<int> group_input_max_uih_seqlens = get_integers_from_string(str_of_integers);
str_of_integers = arg_parser.get_str("g_max_seqlens");
std::vector<int> group_input_max_uih_seqlens_q = get_integers_from_string(str_of_integers);
HSTU_CHECK(!group_input_max_uih_seqlens.empty(), "group window sizes shoud be defined!");
str_of_integers = arg_parser.get_str("g_max_seqlens_kv");
std::vector<int> group_input_max_uih_seqlens_kv = get_integers_from_string(str_of_integers);
str_of_integers = arg_parser.get_str("g_context_lens");
std::vector<int> group_contextual_seqlens = get_integers_from_string(str_of_integers);
@@ -721,10 +723,8 @@ bool run_group_hstu(const ck_tile::ArgParser& arg_parser, int num_group)
std::vector<float> group_attn_scales = get_floats_from_string(str_of_floats);
HSTU_CHECK(!group_attn_scales.empty(), "group attn_scales shoud be defined!");
// supplement seq_lengths_q using the last input value if user-provided lengths not enough
// supplement seq_lengths using the last input value if user-provided lengths not enough
supplement_array_by_last_element(seq_lengths_q, num_batch);
// supplement seq_lengths_kv using the last input value if user-provided lengths not enough
supplement_array_by_last_element(seq_lengths_kv, num_batch);
if(!num_targets.empty())
@@ -735,7 +735,8 @@ bool run_group_hstu(const ck_tile::ArgParser& arg_parser, int num_group)
// supplement group_input_max_uih_seqlens using the last input value if user-provided lengths
// not enough
supplement_array_by_last_element(group_input_max_uih_seqlens, num_group);
supplement_array_by_last_element(group_input_max_uih_seqlens_q, num_group);
supplement_array_by_last_element(group_input_max_uih_seqlens_kv, num_group);
// supplement group_contextual_seqlens using the last input value if user-provided lengths not
// enough
@@ -751,45 +752,64 @@ bool run_group_hstu(const ck_tile::ArgParser& arg_parser, int num_group)
// supplement group_attn_scales using the last input value if user-provided values not enough
supplement_array_by_last_element(group_attn_scales, num_group);
int phy_seqlen_q = 0;
int phy_seqlen_kv = 0;
int max_max_seqlen = 0;
int phy_seqlen_q = 0;
int phy_seqlen_kv = 0;
int max_max_seqlen_q = 0;
int max_max_seqlen_kv = 0;
std::vector<int> group_max_uih_seqlens;
std::vector<int> group_max_uih_seqlens_q;
std::vector<int> group_max_uih_seqlens_kv;
group_max_uih_seqlens.resize(num_group);
group_max_uih_seqlens_q.resize(num_group);
group_max_uih_seqlens_kv.resize(num_group);
for(int i_grp = 0; i_grp < num_group; i_grp++)
{
group_max_uih_seqlens[i_grp] = 0;
group_max_uih_seqlens_q[i_grp] = 0;
group_max_uih_seqlens_kv[i_grp] = 0;
for(int i_batch = 0; i_batch < num_batch_per_group; i_batch++)
{
auto i_global_batch = i_grp * num_batch_per_group + i_batch;
group_max_uih_seqlens[i_grp] =
max(group_max_uih_seqlens[i_grp], seq_lengths_q[i_global_batch]);
group_max_uih_seqlens_q[i_grp] =
max(group_max_uih_seqlens_q[i_grp], seq_lengths_q[i_global_batch]);
group_max_uih_seqlens_kv[i_grp] =
max(group_max_uih_seqlens_kv[i_grp], seq_lengths_kv[i_global_batch]);
};
HSTU_CHECK(group_input_max_uih_seqlens[i_grp] <= 0 ||
group_input_max_uih_seqlens[i_grp] >= group_max_uih_seqlens[i_grp],
HSTU_CHECK(group_input_max_uih_seqlens_q[i_grp] <= 0 ||
group_input_max_uih_seqlens_q[i_grp] >= group_max_uih_seqlens_q[i_grp],
"the user input of each group max_uih_seqlen can either be ignored or be bigger "
"than all uih_seqlens[] of the group");
group_max_uih_seqlens[i_grp] = group_input_max_uih_seqlens[i_grp] > 0
? group_input_max_uih_seqlens[i_grp]
: group_max_uih_seqlens[i_grp];
HSTU_CHECK(group_input_max_uih_seqlens_kv[i_grp] <= 0 ||
group_input_max_uih_seqlens_kv[i_grp] >= group_max_uih_seqlens_kv[i_grp],
"the user input of each group max_uih_seqlen can either be ignored or be bigger "
"than all uih_seqlens[] of the group");
group_max_uih_seqlens_q[i_grp] = group_input_max_uih_seqlens_q[i_grp] > 0
? group_input_max_uih_seqlens_q[i_grp]
: group_max_uih_seqlens_q[i_grp];
group_max_uih_seqlens_kv[i_grp] = group_input_max_uih_seqlens_kv[i_grp] > 0
? group_input_max_uih_seqlens_kv[i_grp]
: group_max_uih_seqlens_kv[i_grp];
};
std::vector<int> group_max_seqlens;
std::vector<int> group_max_seqlens_q;
std::vector<int> group_max_seqlens_kv;
group_max_seqlens.resize(num_group);
group_max_seqlens_q.resize(num_group);
group_max_seqlens_kv.resize(num_group);
for(int i_grp = 0; i_grp < num_group; i_grp++)
{
group_max_seqlens[i_grp] =
group_max_uih_seqlens[i_grp] + group_contextual_seqlens[i_grp] + num_targets[i_grp];
max_max_seqlen = max(max_max_seqlen, group_max_seqlens[i_grp]);
group_max_seqlens_q[i_grp] =
group_max_uih_seqlens_q[i_grp] + group_contextual_seqlens[i_grp] + num_targets[i_grp];
max_max_seqlen_q = max(max_max_seqlen_q, group_max_seqlens_q[i_grp]);
group_max_seqlens_kv[i_grp] =
group_max_uih_seqlens_kv[i_grp] + group_contextual_seqlens[i_grp] + num_targets[i_grp];
max_max_seqlen_kv = max(max_max_seqlen_kv, group_max_seqlens_kv[i_grp]);
};
std::vector<int> seq_offsets_q;
@@ -859,10 +879,12 @@ bool run_group_hstu(const ck_tile::ArgParser& arg_parser, int num_group)
ck_tile::HostTensor<InOutDataType> o_host_ref(
std::array<ck_tile::index_t, 4>{batches_for_alloc, phy_seqlen_q, num_head, hdim_v});
ck_tile::HostTensor<int8_t> mask_host(
save_mask
? std::array<ck_tile::index_t, 4>{num_batch, num_head, max_max_seqlen, max_max_seqlen}
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
ck_tile::HostTensor<int8_t> mask_host(save_mask
? std::array<ck_tile::index_t, 4>{num_batch,
num_head,
max_max_seqlen_q,
max_max_seqlen_q}
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
if(!initialize_qkv)
{
@@ -904,14 +926,14 @@ bool run_group_hstu(const ck_tile::ArgParser& arg_parser, int num_group)
if(!num_targets.empty())
num_targets_dev.ToDevice(num_targets.data());
ck_tile::DeviceMem group_max_seqlens_dev(group_max_seqlens.size() * sizeof(int));
ck_tile::DeviceMem group_max_seqlens_q_dev(group_max_seqlens_q.size() * sizeof(int));
ck_tile::DeviceMem group_contextual_seqlens_dev(group_contextual_seqlens.size() * sizeof(int));
ck_tile::DeviceMem group_window_sizes_dev(group_window_sizes.size() * sizeof(int));
ck_tile::DeviceMem group_min_full_attn_seqlens_dev(group_min_full_attn_seqlens.size() *
sizeof(int));
ck_tile::DeviceMem group_attn_scales_dev(group_attn_scales.size() * sizeof(float));
group_max_seqlens_dev.ToDevice(group_max_seqlens.data());
group_max_seqlens_q_dev.ToDevice(group_max_seqlens_q.data());
group_contextual_seqlens_dev.ToDevice(group_contextual_seqlens.data());
group_window_sizes_dev.ToDevice(group_window_sizes.data());
group_min_full_attn_seqlens_dev.ToDevice(group_min_full_attn_seqlens.data());
@@ -921,38 +943,38 @@ bool run_group_hstu(const ck_tile::ArgParser& arg_parser, int num_group)
float scale_s = (alpha != 0.f) ? alpha : 1.0f / std::sqrt(hdim_qk);
params.is_cross_attention = is_cross_attention;
params.num_batch = num_batch;
params.num_group = num_group;
params.seq_q_offsets_ptr = seq_offsets_q_dev.GetDeviceBuffer();
params.seq_kv_offsets_ptr = seq_offsets_kv_dev.GetDeviceBuffer();
params.max_seqlen = max_max_seqlen;
params.q_ptr = q_dev.GetDeviceBuffer();
params.k_ptr = k_dev.GetDeviceBuffer();
params.v_ptr = v_dev.GetDeviceBuffer();
params.bias_ptr = nullptr; // bias is not supported at present
params.o_ptr = o_dev.GetDeviceBuffer();
params.hdim_qk = hdim_qk;
params.hdim_v = hdim_v;
params.num_head = num_head;
params.scale_s = scale_s;
params.seq_stride_q = q_host.get_strides()[1];
params.seq_stride_k = k_host.get_strides()[1];
params.seq_stride_v = v_host.get_strides()[1];
params.seq_stride_bias = 0;
params.seq_stride_o = o_host_ref.get_strides()[1];
params.nhead_stride_q = q_host.get_strides()[2];
params.nhead_stride_k = k_host.get_strides()[2];
params.nhead_stride_v = v_host.get_strides()[2];
params.nhead_stride_bias = 0;
params.nhead_stride_o = o_host_ref.get_strides()[2];
params.num_targets_ptr = num_targets.empty() ? nullptr : num_targets_dev.GetDeviceBuffer();
params.use_softmax = use_softmax;
params.use_causal = use_causal;
params.p_drop = 0.0f; // dropout is not supported at present
params.philox_seed = 0UL;
params.philox_offset = 0UL;
params.group_max_seqlen_ptr = group_max_seqlens_dev.GetDeviceBuffer();
params.is_cross_attention = is_cross_attention;
params.num_batch = num_batch;
params.num_group = num_group;
params.seq_q_offsets_ptr = seq_offsets_q_dev.GetDeviceBuffer();
params.seq_kv_offsets_ptr = seq_offsets_kv_dev.GetDeviceBuffer();
params.max_seqlen_q = max_max_seqlen_q;
params.q_ptr = q_dev.GetDeviceBuffer();
params.k_ptr = k_dev.GetDeviceBuffer();
params.v_ptr = v_dev.GetDeviceBuffer();
params.bias_ptr = nullptr; // bias is not supported at present
params.o_ptr = o_dev.GetDeviceBuffer();
params.hdim_qk = hdim_qk;
params.hdim_v = hdim_v;
params.num_head = num_head;
params.scale_s = scale_s;
params.seq_stride_q = q_host.get_strides()[1];
params.seq_stride_k = k_host.get_strides()[1];
params.seq_stride_v = v_host.get_strides()[1];
params.seq_stride_bias = 0;
params.seq_stride_o = o_host_ref.get_strides()[1];
params.nhead_stride_q = q_host.get_strides()[2];
params.nhead_stride_k = k_host.get_strides()[2];
params.nhead_stride_v = v_host.get_strides()[2];
params.nhead_stride_bias = 0;
params.nhead_stride_o = o_host_ref.get_strides()[2];
params.num_targets_ptr = num_targets.empty() ? nullptr : num_targets_dev.GetDeviceBuffer();
params.use_softmax = use_softmax;
params.use_causal = use_causal;
params.p_drop = 0.0f; // dropout is not supported at present
params.philox_seed = 0UL;
params.philox_offset = 0UL;
params.group_max_seqlen_q_ptr = group_max_seqlens_q_dev.GetDeviceBuffer();
params.group_contextual_seqlen_ptr = group_contextual_seqlens_dev.GetDeviceBuffer();
params.group_window_size_ptr = group_window_sizes_dev.GetDeviceBuffer();
params.group_min_full_attn_seqlen_ptr = group_min_full_attn_seqlens_dev.GetDeviceBuffer();
@@ -994,11 +1016,12 @@ bool run_group_hstu(const ck_tile::ArgParser& arg_parser, int num_group)
num_batch,
num_batch / num_group,
scale_s,
max_max_seqlen,
max_max_seqlen_q,
max_max_seqlen_kv,
seq_offsets_q,
seq_offsets_kv,
num_targets,
group_max_seqlens,
group_max_seqlens_q,
group_contextual_seqlens,
group_window_sizes,
group_min_full_attn_seqlens,

View File

@@ -174,7 +174,7 @@ struct HstuAttentionFwdKernel
int32_t window_size; // to be set by the per-group window_size
int32_t min_full_attn_seqlen; // to be set by the per-group min_full_attn_seqlen
const int32_t* group_max_seqlen_ptr;
const int32_t* group_max_seqlen_q_ptr;
const int32_t* group_contextual_seqlen_ptr;
const int32_t* group_window_size_ptr;
const int32_t* group_min_full_attn_seqlen_ptr;
@@ -318,8 +318,7 @@ struct HstuAttentionFwdKernel
seq_stride_o,
num_head,
scale_s,
attn_scale ? attn_scale
: 1.0f / static_cast<float>(max(seqlen_q, seqlen_kv)), // max_seqlen
attn_scale ? attn_scale : 1.0f / static_cast<float>(seqlen_q),
contextual_seqlen,
window_size,
min_full_attn_seqlen}, // args for common karg
@@ -351,7 +350,7 @@ struct HstuAttentionFwdKernel
void* o_ptr,
const void* seq_q_offsets_ptr,
const void* seq_kv_offsets_ptr,
ck_tile::index_t max_seqlen,
ck_tile::index_t max_seqlen_q,
ck_tile::index_t hdim_qk,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head,
@@ -397,7 +396,7 @@ struct HstuAttentionFwdKernel
-1, // seqlen_kv will be updated by another pointer
num_head,
scale_s,
attn_scale ? attn_scale : 1.0f / static_cast<float>(max_seqlen),
attn_scale ? attn_scale : 1.0f / static_cast<float>(max_seqlen_q),
contextual_seqlen,
window_size,
min_full_attn_seqlen}, // args for common karg
@@ -429,7 +428,7 @@ struct HstuAttentionFwdKernel
ck_tile::index_t num_batch_per_group,
const void* seq_q_offsets_ptr,
const void* seq_kv_offsets_ptr,
const void* group_max_seqlen_ptr,
const void* group_max_seqlen_q_ptr,
const void* group_contextual_seqlen_ptr,
const void* group_window_size_ptr,
const void* group_min_full_attn_seqlen_ptr,
@@ -480,7 +479,7 @@ struct HstuAttentionFwdKernel
0, // to be set by the per-group contextual_seqlen
0, // to be set by the per-group window_size
0, // to be set by the per-group min_full_attn_seqlen
reinterpret_cast<const int32_t*>(group_max_seqlen_ptr),
reinterpret_cast<const int32_t*>(group_max_seqlen_q_ptr),
reinterpret_cast<const int32_t*>(group_contextual_seqlen_ptr),
reinterpret_cast<const int32_t*>(group_window_size_ptr),
reinterpret_cast<const int32_t*>(group_min_full_attn_seqlen_ptr),
@@ -654,9 +653,9 @@ struct HstuAttentionFwdKernel
index_t i_group =
__builtin_amdgcn_readfirstlane(i_batch / kargs.num_batch_per_group);
float attn_scale = kargs.group_attn_scale_ptr[i_group];
index_t max_seqlen = kargs.group_max_seqlen_ptr[i_group];
kargs.scale_p = (attn_scale ? attn_scale : 1.0f / static_cast<float>(max_seqlen));
float attn_scale = kargs.group_attn_scale_ptr[i_group];
index_t max_seqlen_q = kargs.group_max_seqlen_q_ptr[i_group];
kargs.scale_p = (attn_scale ? attn_scale : 1.0f / static_cast<float>(max_seqlen_q));
kargs.contextual_seqlen = kargs.group_contextual_seqlen_ptr[i_group];
kargs.window_size = kargs.group_window_size_ptr[i_group];
kargs.min_full_attn_seqlen = kargs.group_min_full_attn_seqlen_ptr[i_group];

View File

@@ -173,7 +173,7 @@ struct HstuAttentionFwdSplitKVKernel
int32_t window_size; // to be set by the per-group window_size
int32_t min_full_attn_seqlen; // to be set by the per-group min_full_attn_seqlen
const int32_t* group_max_seqlen_ptr;
const int32_t* group_max_seqlen_q_ptr;
const int32_t* group_contextual_seqlen_ptr;
const int32_t* group_window_size_ptr;
const int32_t* group_min_full_attn_seqlen_ptr;
@@ -313,8 +313,7 @@ struct HstuAttentionFwdSplitKVKernel
seq_stride_v,
num_head,
scale_s,
attn_scale ? attn_scale
: 1.0f / static_cast<float>(max(seqlen_q, seqlen_kv)), // max_seqlen
attn_scale ? attn_scale : 1.0f / static_cast<float>(seqlen_q),
contextual_seqlen,
window_size,
min_full_attn_seqlen}, // args for common karg
@@ -347,7 +346,7 @@ struct HstuAttentionFwdSplitKVKernel
ck_tile::index_t num_splits, // number of splitted seqlen_kv
const void* seq_q_offsets_ptr,
const void* seq_kv_offsets_ptr,
ck_tile::index_t max_seqlen,
ck_tile::index_t max_seqlen_q,
ck_tile::index_t hdim_qk,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head,
@@ -390,7 +389,7 @@ struct HstuAttentionFwdSplitKVKernel
-1, // seqlen_kv will be updated by another pointer
num_head,
scale_s,
attn_scale ? attn_scale : 1.0f / static_cast<float>(max_seqlen),
attn_scale ? attn_scale : 1.0f / static_cast<float>(max_seqlen_q),
contextual_seqlen,
window_size,
min_full_attn_seqlen}, // args for common karg
@@ -423,7 +422,7 @@ struct HstuAttentionFwdSplitKVKernel
ck_tile::index_t num_batch_per_group,
const void* seq_q_offsets_ptr,
const void* seq_kv_offsets_ptr,
const void* group_max_seqlen_ptr,
const void* group_max_seqlen_q_ptr,
const void* group_contextual_seqlen_ptr,
const void* group_window_size_ptr,
const void* group_min_full_attn_seqlen_ptr,
@@ -471,7 +470,7 @@ struct HstuAttentionFwdSplitKVKernel
0, // to be set by the per-group contextual_seqlen
0, // to be set by the per-group window_size
0, // to be set by the per-group min_full_attn_seqlen
reinterpret_cast<const int32_t*>(group_max_seqlen_ptr),
reinterpret_cast<const int32_t*>(group_max_seqlen_q_ptr),
reinterpret_cast<const int32_t*>(group_contextual_seqlen_ptr),
reinterpret_cast<const int32_t*>(group_window_size_ptr),
reinterpret_cast<const int32_t*>(group_min_full_attn_seqlen_ptr),
@@ -674,9 +673,9 @@ struct HstuAttentionFwdSplitKVKernel
index_t i_group =
__builtin_amdgcn_readfirstlane(i_batch / kargs.num_batch_per_group);
float attn_scale = kargs.group_attn_scale_ptr[i_group];
index_t max_seqlen = kargs.group_max_seqlen_ptr[i_group];
kargs.scale_p = (attn_scale ? attn_scale : 1.0f / static_cast<float>(max_seqlen));
float attn_scale = kargs.group_attn_scale_ptr[i_group];
index_t max_seqlen_q = kargs.group_max_seqlen_q_ptr[i_group];
kargs.scale_p = (attn_scale ? attn_scale : 1.0f / static_cast<float>(max_seqlen_q));
kargs.contextual_seqlen = kargs.group_contextual_seqlen_ptr[i_group];
kargs.window_size = kargs.group_window_size_ptr[i_group];
kargs.min_full_attn_seqlen = kargs.group_min_full_attn_seqlen_ptr[i_group];

View File

@@ -133,7 +133,7 @@ struct group_forward_causal_softmax_bias_dropout_dispatch
param.seq_q_offsets_ptr,
param.is_cross_attention ? param.seq_kv_offsets_ptr
: param.seq_q_offsets_ptr,
param.group_max_seqlen_ptr,
param.group_max_seqlen_q_ptr,
param.group_contextual_seqlen_ptr,
param.group_window_size_ptr,
param.group_min_full_attn_seqlen_ptr,
@@ -159,7 +159,7 @@ struct group_forward_causal_softmax_bias_dropout_dispatch
}();
dim3 kGridSize =
HstuKernel::GridSize(param.num_batch, param.num_head, param.max_seqlen, param.hdim_v);
HstuKernel::GridSize(param.num_batch, param.num_head, param.max_seqlen_q, param.hdim_v);
constexpr dim3 kBlockSize = HstuKernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu;
@@ -178,7 +178,7 @@ template <typename InOutDataType,
void run_group_forward_causal_softmax_bias_dropout_dispatch(HstuAttentionGroupFwdParams& param,
hipStream_t stream)
{
if(get_hstu_attention_fwd_mtile(param.num_batch, param.num_head, param.max_seqlen) == 128)
if(get_hstu_attention_fwd_mtile(param.num_batch, param.num_head, param.max_seqlen_q) == 128)
group_forward_causal_softmax_bias_dropout_dispatch<InOutDataType,
kUseCausal,
kUseSoftmax,
@@ -197,7 +197,7 @@ void run_group_forward_causal_softmax_bias_dropout_dispatch(HstuAttentionGroupFw
// ToDo: enable splitkv when kUseSoftmax is true
if(!disable_fwd_splitkv && !kUseSoftmax &&
shall_use_splitkv(param.num_batch, param.num_head, param.max_seqlen))
shall_use_splitkv(param.num_batch, param.num_head, param.max_seqlen_q))
{
if constexpr(!kUseSoftmax)
{

View File

@@ -176,11 +176,11 @@ struct group_forward_splitkv_causal_softmax_bias_dropout_dispatch
static void RunWithFwdSplitKVKernel(HstuAttentionGroupFwdParams& param, hipStream_t stream)
{
param.num_splits =
get_suggested_num_splits(param.num_batch, param.num_head, param.max_seqlen);
get_suggested_num_splits(param.num_batch, param.num_head, param.max_seqlen_q);
// assume the workspace for o_acc is in compact shape of [num_batch, max_seqlen, num_head,
// num_splits, hdim]
size_t workspace_bytes = static_cast<size_t>(param.num_batch) * param.max_seqlen *
size_t workspace_bytes = static_cast<size_t>(param.num_batch) * param.max_seqlen_q *
param.num_head * param.num_splits * param.hdim_v *
sizeof(OaccDataType);
@@ -197,7 +197,7 @@ struct group_forward_splitkv_causal_softmax_bias_dropout_dispatch
param.seq_q_offsets_ptr,
param.is_cross_attention ? param.seq_kv_offsets_ptr
: param.seq_q_offsets_ptr,
param.group_max_seqlen_ptr,
param.group_max_seqlen_q_ptr,
param.group_contextual_seqlen_ptr,
param.group_window_size_ptr,
param.group_min_full_attn_seqlen_ptr,
@@ -221,7 +221,7 @@ struct group_forward_splitkv_causal_softmax_bias_dropout_dispatch
}();
dim3 kGridSize = HstuKernel::GridSize(
param.num_batch, param.num_head, param.max_seqlen, param.hdim_v, param.num_splits);
param.num_batch, param.num_head, param.max_seqlen_q, param.hdim_v, param.num_splits);
constexpr dim3 kBlockSize = HstuKernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu;
@@ -245,7 +245,7 @@ struct group_forward_splitkv_causal_softmax_bias_dropout_dispatch
param.hdim_v);
}();
dim3 kGridSize = HstuKernel::GridSize(param.num_batch, param.num_head, param.max_seqlen);
dim3 kGridSize = HstuKernel::GridSize(param.num_batch, param.num_head, param.max_seqlen_q);
constexpr dim3 kBlockSize = HstuKernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu;

View File

@@ -132,7 +132,7 @@ struct jagged_forward_causal_softmax_bias_dropout_dispatch
param.seq_q_offsets_ptr,
param.is_cross_attention ? param.seq_kv_offsets_ptr
: param.seq_q_offsets_ptr,
param.max_seqlen,
param.max_seqlen_q,
param.hdim_qk,
param.hdim_v,
param.num_head,
@@ -160,7 +160,7 @@ struct jagged_forward_causal_softmax_bias_dropout_dispatch
bool has_minfull_attn_seqlen = (param.min_full_attn_seqlen > 0);
dim3 kGridSize = HstuKernel::GridSize(param.num_batch,
param.num_head,
param.max_seqlen,
param.max_seqlen_q,
param.hdim_v,
has_minfull_attn_seqlen);
constexpr dim3 kBlockSize = HstuKernel::BlockSize();
@@ -181,7 +181,7 @@ template <typename InOutDataType,
void run_jagged_forward_causal_softmax_bias_dropout_dispatch(HstuAttentionNoGroupFwdParams& param,
hipStream_t stream)
{
if(get_hstu_attention_fwd_mtile(param.num_batch, param.num_head, param.max_seqlen) == 128)
if(get_hstu_attention_fwd_mtile(param.num_batch, param.num_head, param.max_seqlen_q) == 128)
jagged_forward_causal_softmax_bias_dropout_dispatch<InOutDataType,
kUseCausal,
kUseSoftmax,
@@ -200,7 +200,7 @@ void run_jagged_forward_causal_softmax_bias_dropout_dispatch(HstuAttentionNoGrou
// ToDo: enable splitkv when kUseSoftmax is true
if(!disable_fwd_splitkv && !kUseSoftmax &&
shall_use_splitkv(param.num_batch, param.num_head, param.max_seqlen))
shall_use_splitkv(param.num_batch, param.num_head, param.max_seqlen_q))
{
if constexpr(!kUseSoftmax)
{

View File

@@ -175,11 +175,11 @@ struct jagged_forward_splitkv_causal_softmax_bias_dropout_dispatch
static void RunWithFwdSplitKVKernel(HstuAttentionNoGroupFwdParams& param, hipStream_t stream)
{
param.num_splits =
get_suggested_num_splits(param.num_batch, param.num_head, param.max_seqlen);
get_suggested_num_splits(param.num_batch, param.num_head, param.max_seqlen_q);
// assume the workspace for o_acc is in compact shape of [num_batch, max_seqlen, num_head,
// num_splits, hdim]
size_t workspace_bytes = static_cast<size_t>(param.num_batch) * param.max_seqlen *
size_t workspace_bytes = static_cast<size_t>(param.num_batch) * param.max_seqlen_q *
param.num_head * param.num_splits * param.hdim_v *
sizeof(OaccDataType);
@@ -195,7 +195,7 @@ struct jagged_forward_splitkv_causal_softmax_bias_dropout_dispatch
param.seq_q_offsets_ptr,
param.is_cross_attention ? param.seq_kv_offsets_ptr
: param.seq_q_offsets_ptr,
param.max_seqlen,
param.max_seqlen_q,
param.hdim_qk,
param.hdim_v,
param.num_head,
@@ -221,7 +221,7 @@ struct jagged_forward_splitkv_causal_softmax_bias_dropout_dispatch
bool has_minfull_attn_seqlen = (param.min_full_attn_seqlen > 0);
dim3 kGridSize = HstuKernel::GridSize(param.num_batch,
param.num_head,
param.max_seqlen,
param.max_seqlen_q,
param.hdim_v,
param.num_splits,
has_minfull_attn_seqlen);
@@ -248,7 +248,7 @@ struct jagged_forward_splitkv_causal_softmax_bias_dropout_dispatch
param.hdim_v);
}();
dim3 kGridSize = HstuKernel::GridSize(param.num_batch, param.num_head, param.max_seqlen);
dim3 kGridSize = HstuKernel::GridSize(param.num_batch, param.num_head, param.max_seqlen_q);
constexpr dim3 kBlockSize = HstuKernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu;

View File

@@ -19,7 +19,7 @@ struct HstuAttentionNoGroupFwdParams
ck_tile::index_t seqlen_kv; // batched mode only
const void* seq_q_offsets_ptr; // jagged mode only
const void* seq_kv_offsets_ptr; // jagged mode only
ck_tile::index_t max_seqlen; // jagged mode only
ck_tile::index_t max_seqlen_q; // jagged mode only
const void* q_ptr;
const void* k_ptr;
@@ -84,7 +84,7 @@ struct HstuAttentionGroupFwdParams
ck_tile::index_t num_batch;
const void* seq_q_offsets_ptr;
const void* seq_kv_offsets_ptr;
ck_tile::index_t max_seqlen; // the maximum of all the groups' max_seqlen
ck_tile::index_t max_seqlen_q; // the maximum of all the groups' max_seqlen_q
const void* q_ptr;
const void* k_ptr;
@@ -122,7 +122,7 @@ struct HstuAttentionGroupFwdParams
// parameters used by Group HSTU
const void* group_attn_scale_ptr;
const void* group_max_seqlen_ptr;
const void* group_max_seqlen_q_ptr; // used for setting attn_scales
const void* group_window_size_ptr;
const void* group_contextual_seqlen_ptr;
const void* group_min_full_attn_seqlen_ptr;

View File

@@ -43,7 +43,7 @@ struct reference_no_group_hstu_attention
float alpha,
float attn_scale,
int max_seqlen_q,
int max_seqlen_kv,
int max_seqlen_kv, // only used as last dim of the tensor for saving the mask
std::vector<int> seq_q_offsets,
std::vector<int> seq_kv_offsets,
std::vector<int> num_targets, // define masking length at the end of token
@@ -116,9 +116,7 @@ struct reference_no_group_hstu_attention
int num_target = num_targets.empty() ? 0 : num_targets[i_batch];
float scale_p = attn_scale
? attn_scale
: 1.0f / static_cast<float>(max(max_seqlen_q, max_seqlen_kv));
float scale_p = attn_scale ? attn_scale : 1.0f / static_cast<float>(max_seqlen_q);
BOOL_SWITCH_2(window_size > 0, kHasLocal, is_cross_attention, kIsCrossAttention, [&] {
using HstuMaskType =
@@ -335,12 +333,15 @@ struct reference_group_hstu_attention
int num_batch,
int num_batch_per_group,
float alpha,
int max_max_seqlen, // the maximum of all groups's max_seqlen
int max_max_seqlen_q, // the maximum of all groups' max_seqlen_q, only used as second last
// dim of the tensor for saving the mask
int max_max_seqlen_kv, // the maximum of all groups' max_seqlen_kv, only used as last dim of
// the tensor for saving the mask
const std::vector<int>& seq_q_offsets,
const std::vector<int>& seq_kv_offsets,
const std::vector<int>& num_targets, // define masking length at the end of token
// sequence to be excluded for attention
const std::vector<int>& group_max_seqlens, // max seqlen list by groups
const std::vector<int>& num_targets, // define masking length at the end of token
// sequence to be excluded for attention
const std::vector<int>& group_max_seqlens_q, // max seqlen_q list by groups
const std::vector<int>& group_contextual_seqlens, // contextual seqlen list by groups
const std::vector<int>& group_window_sizes, // window_size list by groups
const std::vector<int>& group_min_full_attn_seqlens, // min_full_attn_seqlen list by groups
@@ -374,8 +375,8 @@ struct reference_group_hstu_attention
if(static_cast<int>(mask_batch_nhead_seq_seq.get_lengths()[0]) == num_batch &&
static_cast<int>(mask_batch_nhead_seq_seq.get_lengths()[1]) == num_head &&
static_cast<int>(mask_batch_nhead_seq_seq.get_lengths()[2]) == max_max_seqlen &&
static_cast<int>(mask_batch_nhead_seq_seq.get_lengths()[3]) == max_max_seqlen)
static_cast<int>(mask_batch_nhead_seq_seq.get_lengths()[2]) == max_max_seqlen_q &&
static_cast<int>(mask_batch_nhead_seq_seq.get_lengths()[3]) == max_max_seqlen_kv)
save_mask = true;
// check num_targets
@@ -394,10 +395,10 @@ struct reference_group_hstu_attention
int num_target = num_targets.empty() ? 0 : num_targets[i_batch];
int max_seqlen = group_max_seqlens[i_group];
int max_seqlen_q = group_max_seqlens_q[i_group];
float attn_scale = group_attn_scales[i_group];
float scale_p = (attn_scale ? attn_scale : 1.0f / static_cast<float>(max_seqlen));
float scale_p = (attn_scale ? attn_scale : 1.0f / static_cast<float>(max_seqlen_q));
int contextual_seqlen = group_contextual_seqlens[i_group];
int window_size = group_window_sizes[i_group];
@@ -468,8 +469,8 @@ struct reference_group_hstu_attention
if(save_mask)
{
for(int sq = 0; sq < max_seqlen; sq++)
for(int sk = 0; sk < max_seqlen; sk++)
for(int sq = 0; sq < max_max_seqlen_q; sq++)
for(int sk = 0; sk < max_max_seqlen_kv; sk++)
mask_batch_nhead_seq_seq(i_batch, i_head, sq, sk) = 0;
for(int sq = 0; sq < seqlen_q; sq++)