mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-16 02:54:21 +00:00
Clarify the using the max_seqlen and max_seqlen_q
This commit is contained in:
@@ -103,6 +103,7 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("max_seqlen", "0", "max uih_seqlen, can be ignored, or else must be equal/bigger than the maximum of all uih seqlens")
|
||||
.insert("max_seqlen_kv", "0", "max uih_seqlen_kv, can be ignored, or else must be equal/bigger than the maximum of all uih seqlens")
|
||||
.insert("g_max_seqlens", "0", "max uih_seqlen of groups, can be ignored, or else each must be equal/bigger than maximum of all uih seqlens in its group")
|
||||
.insert("g_max_seqlens_kv", "0", "max uih_seqlen of groups, can be ignored, or else each must be equal/bigger than maximum of all uih seqlens in its group")
|
||||
.insert("targets", "", "sequence length at the end of query/key token sequence that should be excluded from attention")
|
||||
.insert("max_target", "0", "max target, can be ignored, or else must be equal/bigger than the maximum of all targets")
|
||||
.insert("softmax", "0", "use softmax or not")
|
||||
@@ -478,7 +479,7 @@ bool run_no_group_hstu(const ck_tile::ArgParser& arg_parser, bool is_jagged)
|
||||
params.num_batch = num_batch;
|
||||
params.seq_q_offsets_ptr = seq_offsets_q_dev.GetDeviceBuffer();
|
||||
params.seq_kv_offsets_ptr = seq_offsets_kv_dev.GetDeviceBuffer();
|
||||
params.max_seqlen = max(max_seqlen_q, max_seqlen_kv);
|
||||
params.max_seqlen_q = max_seqlen_q;
|
||||
params.q_ptr = q_dev.GetDeviceBuffer();
|
||||
params.k_ptr = k_dev.GetDeviceBuffer();
|
||||
params.v_ptr = v_dev.GetDeviceBuffer();
|
||||
@@ -697,10 +698,11 @@ bool run_group_hstu(const ck_tile::ArgParser& arg_parser, int num_group)
|
||||
else
|
||||
is_cross_attention = true;
|
||||
|
||||
str_of_integers = arg_parser.get_str("g_max_seqlens");
|
||||
std::vector<int> group_input_max_uih_seqlens = get_integers_from_string(str_of_integers);
|
||||
str_of_integers = arg_parser.get_str("g_max_seqlens");
|
||||
std::vector<int> group_input_max_uih_seqlens_q = get_integers_from_string(str_of_integers);
|
||||
|
||||
HSTU_CHECK(!group_input_max_uih_seqlens.empty(), "group window sizes shoud be defined!");
|
||||
str_of_integers = arg_parser.get_str("g_max_seqlens_kv");
|
||||
std::vector<int> group_input_max_uih_seqlens_kv = get_integers_from_string(str_of_integers);
|
||||
|
||||
str_of_integers = arg_parser.get_str("g_context_lens");
|
||||
std::vector<int> group_contextual_seqlens = get_integers_from_string(str_of_integers);
|
||||
@@ -721,10 +723,8 @@ bool run_group_hstu(const ck_tile::ArgParser& arg_parser, int num_group)
|
||||
std::vector<float> group_attn_scales = get_floats_from_string(str_of_floats);
|
||||
HSTU_CHECK(!group_attn_scales.empty(), "group attn_scales shoud be defined!");
|
||||
|
||||
// supplement seq_lengths_q using the last input value if user-provided lengths not enough
|
||||
// supplement seq_lengths using the last input value if user-provided lengths not enough
|
||||
supplement_array_by_last_element(seq_lengths_q, num_batch);
|
||||
|
||||
// supplement seq_lengths_kv using the last input value if user-provided lengths not enough
|
||||
supplement_array_by_last_element(seq_lengths_kv, num_batch);
|
||||
|
||||
if(!num_targets.empty())
|
||||
@@ -735,7 +735,8 @@ bool run_group_hstu(const ck_tile::ArgParser& arg_parser, int num_group)
|
||||
|
||||
// supplement group_input_max_uih_seqlens using the last input value if user-provided lengths
|
||||
// not enough
|
||||
supplement_array_by_last_element(group_input_max_uih_seqlens, num_group);
|
||||
supplement_array_by_last_element(group_input_max_uih_seqlens_q, num_group);
|
||||
supplement_array_by_last_element(group_input_max_uih_seqlens_kv, num_group);
|
||||
|
||||
// supplement group_contextual_seqlens using the last input value if user-provided lengths not
|
||||
// enough
|
||||
@@ -751,45 +752,64 @@ bool run_group_hstu(const ck_tile::ArgParser& arg_parser, int num_group)
|
||||
// supplement group_attn_scales using the last input value if user-provided values not enough
|
||||
supplement_array_by_last_element(group_attn_scales, num_group);
|
||||
|
||||
int phy_seqlen_q = 0;
|
||||
int phy_seqlen_kv = 0;
|
||||
int max_max_seqlen = 0;
|
||||
int phy_seqlen_q = 0;
|
||||
int phy_seqlen_kv = 0;
|
||||
int max_max_seqlen_q = 0;
|
||||
int max_max_seqlen_kv = 0;
|
||||
|
||||
std::vector<int> group_max_uih_seqlens;
|
||||
std::vector<int> group_max_uih_seqlens_q;
|
||||
std::vector<int> group_max_uih_seqlens_kv;
|
||||
|
||||
group_max_uih_seqlens.resize(num_group);
|
||||
group_max_uih_seqlens_q.resize(num_group);
|
||||
group_max_uih_seqlens_kv.resize(num_group);
|
||||
|
||||
for(int i_grp = 0; i_grp < num_group; i_grp++)
|
||||
{
|
||||
group_max_uih_seqlens[i_grp] = 0;
|
||||
group_max_uih_seqlens_q[i_grp] = 0;
|
||||
group_max_uih_seqlens_kv[i_grp] = 0;
|
||||
|
||||
for(int i_batch = 0; i_batch < num_batch_per_group; i_batch++)
|
||||
{
|
||||
auto i_global_batch = i_grp * num_batch_per_group + i_batch;
|
||||
|
||||
group_max_uih_seqlens[i_grp] =
|
||||
max(group_max_uih_seqlens[i_grp], seq_lengths_q[i_global_batch]);
|
||||
group_max_uih_seqlens_q[i_grp] =
|
||||
max(group_max_uih_seqlens_q[i_grp], seq_lengths_q[i_global_batch]);
|
||||
group_max_uih_seqlens_kv[i_grp] =
|
||||
max(group_max_uih_seqlens_kv[i_grp], seq_lengths_kv[i_global_batch]);
|
||||
};
|
||||
|
||||
HSTU_CHECK(group_input_max_uih_seqlens[i_grp] <= 0 ||
|
||||
group_input_max_uih_seqlens[i_grp] >= group_max_uih_seqlens[i_grp],
|
||||
HSTU_CHECK(group_input_max_uih_seqlens_q[i_grp] <= 0 ||
|
||||
group_input_max_uih_seqlens_q[i_grp] >= group_max_uih_seqlens_q[i_grp],
|
||||
"the user input of each group max_uih_seqlen can either be ignored or be bigger "
|
||||
"than all uih_seqlens[] of the group");
|
||||
|
||||
group_max_uih_seqlens[i_grp] = group_input_max_uih_seqlens[i_grp] > 0
|
||||
? group_input_max_uih_seqlens[i_grp]
|
||||
: group_max_uih_seqlens[i_grp];
|
||||
HSTU_CHECK(group_input_max_uih_seqlens_kv[i_grp] <= 0 ||
|
||||
group_input_max_uih_seqlens_kv[i_grp] >= group_max_uih_seqlens_kv[i_grp],
|
||||
"the user input of each group max_uih_seqlen can either be ignored or be bigger "
|
||||
"than all uih_seqlens[] of the group");
|
||||
|
||||
group_max_uih_seqlens_q[i_grp] = group_input_max_uih_seqlens_q[i_grp] > 0
|
||||
? group_input_max_uih_seqlens_q[i_grp]
|
||||
: group_max_uih_seqlens_q[i_grp];
|
||||
group_max_uih_seqlens_kv[i_grp] = group_input_max_uih_seqlens_kv[i_grp] > 0
|
||||
? group_input_max_uih_seqlens_kv[i_grp]
|
||||
: group_max_uih_seqlens_kv[i_grp];
|
||||
};
|
||||
|
||||
std::vector<int> group_max_seqlens;
|
||||
std::vector<int> group_max_seqlens_q;
|
||||
std::vector<int> group_max_seqlens_kv;
|
||||
|
||||
group_max_seqlens.resize(num_group);
|
||||
group_max_seqlens_q.resize(num_group);
|
||||
group_max_seqlens_kv.resize(num_group);
|
||||
|
||||
for(int i_grp = 0; i_grp < num_group; i_grp++)
|
||||
{
|
||||
group_max_seqlens[i_grp] =
|
||||
group_max_uih_seqlens[i_grp] + group_contextual_seqlens[i_grp] + num_targets[i_grp];
|
||||
max_max_seqlen = max(max_max_seqlen, group_max_seqlens[i_grp]);
|
||||
group_max_seqlens_q[i_grp] =
|
||||
group_max_uih_seqlens_q[i_grp] + group_contextual_seqlens[i_grp] + num_targets[i_grp];
|
||||
max_max_seqlen_q = max(max_max_seqlen_q, group_max_seqlens_q[i_grp]);
|
||||
group_max_seqlens_kv[i_grp] =
|
||||
group_max_uih_seqlens_kv[i_grp] + group_contextual_seqlens[i_grp] + num_targets[i_grp];
|
||||
max_max_seqlen_kv = max(max_max_seqlen_kv, group_max_seqlens_kv[i_grp]);
|
||||
};
|
||||
|
||||
std::vector<int> seq_offsets_q;
|
||||
@@ -859,10 +879,12 @@ bool run_group_hstu(const ck_tile::ArgParser& arg_parser, int num_group)
|
||||
ck_tile::HostTensor<InOutDataType> o_host_ref(
|
||||
std::array<ck_tile::index_t, 4>{batches_for_alloc, phy_seqlen_q, num_head, hdim_v});
|
||||
|
||||
ck_tile::HostTensor<int8_t> mask_host(
|
||||
save_mask
|
||||
? std::array<ck_tile::index_t, 4>{num_batch, num_head, max_max_seqlen, max_max_seqlen}
|
||||
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
|
||||
ck_tile::HostTensor<int8_t> mask_host(save_mask
|
||||
? std::array<ck_tile::index_t, 4>{num_batch,
|
||||
num_head,
|
||||
max_max_seqlen_q,
|
||||
max_max_seqlen_q}
|
||||
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
|
||||
|
||||
if(!initialize_qkv)
|
||||
{
|
||||
@@ -904,14 +926,14 @@ bool run_group_hstu(const ck_tile::ArgParser& arg_parser, int num_group)
|
||||
if(!num_targets.empty())
|
||||
num_targets_dev.ToDevice(num_targets.data());
|
||||
|
||||
ck_tile::DeviceMem group_max_seqlens_dev(group_max_seqlens.size() * sizeof(int));
|
||||
ck_tile::DeviceMem group_max_seqlens_q_dev(group_max_seqlens_q.size() * sizeof(int));
|
||||
ck_tile::DeviceMem group_contextual_seqlens_dev(group_contextual_seqlens.size() * sizeof(int));
|
||||
ck_tile::DeviceMem group_window_sizes_dev(group_window_sizes.size() * sizeof(int));
|
||||
ck_tile::DeviceMem group_min_full_attn_seqlens_dev(group_min_full_attn_seqlens.size() *
|
||||
sizeof(int));
|
||||
ck_tile::DeviceMem group_attn_scales_dev(group_attn_scales.size() * sizeof(float));
|
||||
|
||||
group_max_seqlens_dev.ToDevice(group_max_seqlens.data());
|
||||
group_max_seqlens_q_dev.ToDevice(group_max_seqlens_q.data());
|
||||
group_contextual_seqlens_dev.ToDevice(group_contextual_seqlens.data());
|
||||
group_window_sizes_dev.ToDevice(group_window_sizes.data());
|
||||
group_min_full_attn_seqlens_dev.ToDevice(group_min_full_attn_seqlens.data());
|
||||
@@ -921,38 +943,38 @@ bool run_group_hstu(const ck_tile::ArgParser& arg_parser, int num_group)
|
||||
|
||||
float scale_s = (alpha != 0.f) ? alpha : 1.0f / std::sqrt(hdim_qk);
|
||||
|
||||
params.is_cross_attention = is_cross_attention;
|
||||
params.num_batch = num_batch;
|
||||
params.num_group = num_group;
|
||||
params.seq_q_offsets_ptr = seq_offsets_q_dev.GetDeviceBuffer();
|
||||
params.seq_kv_offsets_ptr = seq_offsets_kv_dev.GetDeviceBuffer();
|
||||
params.max_seqlen = max_max_seqlen;
|
||||
params.q_ptr = q_dev.GetDeviceBuffer();
|
||||
params.k_ptr = k_dev.GetDeviceBuffer();
|
||||
params.v_ptr = v_dev.GetDeviceBuffer();
|
||||
params.bias_ptr = nullptr; // bias is not supported at present
|
||||
params.o_ptr = o_dev.GetDeviceBuffer();
|
||||
params.hdim_qk = hdim_qk;
|
||||
params.hdim_v = hdim_v;
|
||||
params.num_head = num_head;
|
||||
params.scale_s = scale_s;
|
||||
params.seq_stride_q = q_host.get_strides()[1];
|
||||
params.seq_stride_k = k_host.get_strides()[1];
|
||||
params.seq_stride_v = v_host.get_strides()[1];
|
||||
params.seq_stride_bias = 0;
|
||||
params.seq_stride_o = o_host_ref.get_strides()[1];
|
||||
params.nhead_stride_q = q_host.get_strides()[2];
|
||||
params.nhead_stride_k = k_host.get_strides()[2];
|
||||
params.nhead_stride_v = v_host.get_strides()[2];
|
||||
params.nhead_stride_bias = 0;
|
||||
params.nhead_stride_o = o_host_ref.get_strides()[2];
|
||||
params.num_targets_ptr = num_targets.empty() ? nullptr : num_targets_dev.GetDeviceBuffer();
|
||||
params.use_softmax = use_softmax;
|
||||
params.use_causal = use_causal;
|
||||
params.p_drop = 0.0f; // dropout is not supported at present
|
||||
params.philox_seed = 0UL;
|
||||
params.philox_offset = 0UL;
|
||||
params.group_max_seqlen_ptr = group_max_seqlens_dev.GetDeviceBuffer();
|
||||
params.is_cross_attention = is_cross_attention;
|
||||
params.num_batch = num_batch;
|
||||
params.num_group = num_group;
|
||||
params.seq_q_offsets_ptr = seq_offsets_q_dev.GetDeviceBuffer();
|
||||
params.seq_kv_offsets_ptr = seq_offsets_kv_dev.GetDeviceBuffer();
|
||||
params.max_seqlen_q = max_max_seqlen_q;
|
||||
params.q_ptr = q_dev.GetDeviceBuffer();
|
||||
params.k_ptr = k_dev.GetDeviceBuffer();
|
||||
params.v_ptr = v_dev.GetDeviceBuffer();
|
||||
params.bias_ptr = nullptr; // bias is not supported at present
|
||||
params.o_ptr = o_dev.GetDeviceBuffer();
|
||||
params.hdim_qk = hdim_qk;
|
||||
params.hdim_v = hdim_v;
|
||||
params.num_head = num_head;
|
||||
params.scale_s = scale_s;
|
||||
params.seq_stride_q = q_host.get_strides()[1];
|
||||
params.seq_stride_k = k_host.get_strides()[1];
|
||||
params.seq_stride_v = v_host.get_strides()[1];
|
||||
params.seq_stride_bias = 0;
|
||||
params.seq_stride_o = o_host_ref.get_strides()[1];
|
||||
params.nhead_stride_q = q_host.get_strides()[2];
|
||||
params.nhead_stride_k = k_host.get_strides()[2];
|
||||
params.nhead_stride_v = v_host.get_strides()[2];
|
||||
params.nhead_stride_bias = 0;
|
||||
params.nhead_stride_o = o_host_ref.get_strides()[2];
|
||||
params.num_targets_ptr = num_targets.empty() ? nullptr : num_targets_dev.GetDeviceBuffer();
|
||||
params.use_softmax = use_softmax;
|
||||
params.use_causal = use_causal;
|
||||
params.p_drop = 0.0f; // dropout is not supported at present
|
||||
params.philox_seed = 0UL;
|
||||
params.philox_offset = 0UL;
|
||||
params.group_max_seqlen_q_ptr = group_max_seqlens_q_dev.GetDeviceBuffer();
|
||||
params.group_contextual_seqlen_ptr = group_contextual_seqlens_dev.GetDeviceBuffer();
|
||||
params.group_window_size_ptr = group_window_sizes_dev.GetDeviceBuffer();
|
||||
params.group_min_full_attn_seqlen_ptr = group_min_full_attn_seqlens_dev.GetDeviceBuffer();
|
||||
@@ -994,11 +1016,12 @@ bool run_group_hstu(const ck_tile::ArgParser& arg_parser, int num_group)
|
||||
num_batch,
|
||||
num_batch / num_group,
|
||||
scale_s,
|
||||
max_max_seqlen,
|
||||
max_max_seqlen_q,
|
||||
max_max_seqlen_kv,
|
||||
seq_offsets_q,
|
||||
seq_offsets_kv,
|
||||
num_targets,
|
||||
group_max_seqlens,
|
||||
group_max_seqlens_q,
|
||||
group_contextual_seqlens,
|
||||
group_window_sizes,
|
||||
group_min_full_attn_seqlens,
|
||||
|
||||
@@ -174,7 +174,7 @@ struct HstuAttentionFwdKernel
|
||||
int32_t window_size; // to be set by the per-group window_size
|
||||
int32_t min_full_attn_seqlen; // to be set by the per-group min_full_attn_seqlen
|
||||
|
||||
const int32_t* group_max_seqlen_ptr;
|
||||
const int32_t* group_max_seqlen_q_ptr;
|
||||
const int32_t* group_contextual_seqlen_ptr;
|
||||
const int32_t* group_window_size_ptr;
|
||||
const int32_t* group_min_full_attn_seqlen_ptr;
|
||||
@@ -318,8 +318,7 @@ struct HstuAttentionFwdKernel
|
||||
seq_stride_o,
|
||||
num_head,
|
||||
scale_s,
|
||||
attn_scale ? attn_scale
|
||||
: 1.0f / static_cast<float>(max(seqlen_q, seqlen_kv)), // max_seqlen
|
||||
attn_scale ? attn_scale : 1.0f / static_cast<float>(seqlen_q),
|
||||
contextual_seqlen,
|
||||
window_size,
|
||||
min_full_attn_seqlen}, // args for common karg
|
||||
@@ -351,7 +350,7 @@ struct HstuAttentionFwdKernel
|
||||
void* o_ptr,
|
||||
const void* seq_q_offsets_ptr,
|
||||
const void* seq_kv_offsets_ptr,
|
||||
ck_tile::index_t max_seqlen,
|
||||
ck_tile::index_t max_seqlen_q,
|
||||
ck_tile::index_t hdim_qk,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head,
|
||||
@@ -397,7 +396,7 @@ struct HstuAttentionFwdKernel
|
||||
-1, // seqlen_kv will be updated by another pointer
|
||||
num_head,
|
||||
scale_s,
|
||||
attn_scale ? attn_scale : 1.0f / static_cast<float>(max_seqlen),
|
||||
attn_scale ? attn_scale : 1.0f / static_cast<float>(max_seqlen_q),
|
||||
contextual_seqlen,
|
||||
window_size,
|
||||
min_full_attn_seqlen}, // args for common karg
|
||||
@@ -429,7 +428,7 @@ struct HstuAttentionFwdKernel
|
||||
ck_tile::index_t num_batch_per_group,
|
||||
const void* seq_q_offsets_ptr,
|
||||
const void* seq_kv_offsets_ptr,
|
||||
const void* group_max_seqlen_ptr,
|
||||
const void* group_max_seqlen_q_ptr,
|
||||
const void* group_contextual_seqlen_ptr,
|
||||
const void* group_window_size_ptr,
|
||||
const void* group_min_full_attn_seqlen_ptr,
|
||||
@@ -480,7 +479,7 @@ struct HstuAttentionFwdKernel
|
||||
0, // to be set by the per-group contextual_seqlen
|
||||
0, // to be set by the per-group window_size
|
||||
0, // to be set by the per-group min_full_attn_seqlen
|
||||
reinterpret_cast<const int32_t*>(group_max_seqlen_ptr),
|
||||
reinterpret_cast<const int32_t*>(group_max_seqlen_q_ptr),
|
||||
reinterpret_cast<const int32_t*>(group_contextual_seqlen_ptr),
|
||||
reinterpret_cast<const int32_t*>(group_window_size_ptr),
|
||||
reinterpret_cast<const int32_t*>(group_min_full_attn_seqlen_ptr),
|
||||
@@ -654,9 +653,9 @@ struct HstuAttentionFwdKernel
|
||||
index_t i_group =
|
||||
__builtin_amdgcn_readfirstlane(i_batch / kargs.num_batch_per_group);
|
||||
|
||||
float attn_scale = kargs.group_attn_scale_ptr[i_group];
|
||||
index_t max_seqlen = kargs.group_max_seqlen_ptr[i_group];
|
||||
kargs.scale_p = (attn_scale ? attn_scale : 1.0f / static_cast<float>(max_seqlen));
|
||||
float attn_scale = kargs.group_attn_scale_ptr[i_group];
|
||||
index_t max_seqlen_q = kargs.group_max_seqlen_q_ptr[i_group];
|
||||
kargs.scale_p = (attn_scale ? attn_scale : 1.0f / static_cast<float>(max_seqlen_q));
|
||||
kargs.contextual_seqlen = kargs.group_contextual_seqlen_ptr[i_group];
|
||||
kargs.window_size = kargs.group_window_size_ptr[i_group];
|
||||
kargs.min_full_attn_seqlen = kargs.group_min_full_attn_seqlen_ptr[i_group];
|
||||
|
||||
@@ -173,7 +173,7 @@ struct HstuAttentionFwdSplitKVKernel
|
||||
int32_t window_size; // to be set by the per-group window_size
|
||||
int32_t min_full_attn_seqlen; // to be set by the per-group min_full_attn_seqlen
|
||||
|
||||
const int32_t* group_max_seqlen_ptr;
|
||||
const int32_t* group_max_seqlen_q_ptr;
|
||||
const int32_t* group_contextual_seqlen_ptr;
|
||||
const int32_t* group_window_size_ptr;
|
||||
const int32_t* group_min_full_attn_seqlen_ptr;
|
||||
@@ -313,8 +313,7 @@ struct HstuAttentionFwdSplitKVKernel
|
||||
seq_stride_v,
|
||||
num_head,
|
||||
scale_s,
|
||||
attn_scale ? attn_scale
|
||||
: 1.0f / static_cast<float>(max(seqlen_q, seqlen_kv)), // max_seqlen
|
||||
attn_scale ? attn_scale : 1.0f / static_cast<float>(seqlen_q),
|
||||
contextual_seqlen,
|
||||
window_size,
|
||||
min_full_attn_seqlen}, // args for common karg
|
||||
@@ -347,7 +346,7 @@ struct HstuAttentionFwdSplitKVKernel
|
||||
ck_tile::index_t num_splits, // number of splitted seqlen_kv
|
||||
const void* seq_q_offsets_ptr,
|
||||
const void* seq_kv_offsets_ptr,
|
||||
ck_tile::index_t max_seqlen,
|
||||
ck_tile::index_t max_seqlen_q,
|
||||
ck_tile::index_t hdim_qk,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head,
|
||||
@@ -390,7 +389,7 @@ struct HstuAttentionFwdSplitKVKernel
|
||||
-1, // seqlen_kv will be updated by another pointer
|
||||
num_head,
|
||||
scale_s,
|
||||
attn_scale ? attn_scale : 1.0f / static_cast<float>(max_seqlen),
|
||||
attn_scale ? attn_scale : 1.0f / static_cast<float>(max_seqlen_q),
|
||||
contextual_seqlen,
|
||||
window_size,
|
||||
min_full_attn_seqlen}, // args for common karg
|
||||
@@ -423,7 +422,7 @@ struct HstuAttentionFwdSplitKVKernel
|
||||
ck_tile::index_t num_batch_per_group,
|
||||
const void* seq_q_offsets_ptr,
|
||||
const void* seq_kv_offsets_ptr,
|
||||
const void* group_max_seqlen_ptr,
|
||||
const void* group_max_seqlen_q_ptr,
|
||||
const void* group_contextual_seqlen_ptr,
|
||||
const void* group_window_size_ptr,
|
||||
const void* group_min_full_attn_seqlen_ptr,
|
||||
@@ -471,7 +470,7 @@ struct HstuAttentionFwdSplitKVKernel
|
||||
0, // to be set by the per-group contextual_seqlen
|
||||
0, // to be set by the per-group window_size
|
||||
0, // to be set by the per-group min_full_attn_seqlen
|
||||
reinterpret_cast<const int32_t*>(group_max_seqlen_ptr),
|
||||
reinterpret_cast<const int32_t*>(group_max_seqlen_q_ptr),
|
||||
reinterpret_cast<const int32_t*>(group_contextual_seqlen_ptr),
|
||||
reinterpret_cast<const int32_t*>(group_window_size_ptr),
|
||||
reinterpret_cast<const int32_t*>(group_min_full_attn_seqlen_ptr),
|
||||
@@ -674,9 +673,9 @@ struct HstuAttentionFwdSplitKVKernel
|
||||
index_t i_group =
|
||||
__builtin_amdgcn_readfirstlane(i_batch / kargs.num_batch_per_group);
|
||||
|
||||
float attn_scale = kargs.group_attn_scale_ptr[i_group];
|
||||
index_t max_seqlen = kargs.group_max_seqlen_ptr[i_group];
|
||||
kargs.scale_p = (attn_scale ? attn_scale : 1.0f / static_cast<float>(max_seqlen));
|
||||
float attn_scale = kargs.group_attn_scale_ptr[i_group];
|
||||
index_t max_seqlen_q = kargs.group_max_seqlen_q_ptr[i_group];
|
||||
kargs.scale_p = (attn_scale ? attn_scale : 1.0f / static_cast<float>(max_seqlen_q));
|
||||
kargs.contextual_seqlen = kargs.group_contextual_seqlen_ptr[i_group];
|
||||
kargs.window_size = kargs.group_window_size_ptr[i_group];
|
||||
kargs.min_full_attn_seqlen = kargs.group_min_full_attn_seqlen_ptr[i_group];
|
||||
|
||||
@@ -133,7 +133,7 @@ struct group_forward_causal_softmax_bias_dropout_dispatch
|
||||
param.seq_q_offsets_ptr,
|
||||
param.is_cross_attention ? param.seq_kv_offsets_ptr
|
||||
: param.seq_q_offsets_ptr,
|
||||
param.group_max_seqlen_ptr,
|
||||
param.group_max_seqlen_q_ptr,
|
||||
param.group_contextual_seqlen_ptr,
|
||||
param.group_window_size_ptr,
|
||||
param.group_min_full_attn_seqlen_ptr,
|
||||
@@ -159,7 +159,7 @@ struct group_forward_causal_softmax_bias_dropout_dispatch
|
||||
}();
|
||||
|
||||
dim3 kGridSize =
|
||||
HstuKernel::GridSize(param.num_batch, param.num_head, param.max_seqlen, param.hdim_v);
|
||||
HstuKernel::GridSize(param.num_batch, param.num_head, param.max_seqlen_q, param.hdim_v);
|
||||
constexpr dim3 kBlockSize = HstuKernel::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu;
|
||||
|
||||
@@ -178,7 +178,7 @@ template <typename InOutDataType,
|
||||
void run_group_forward_causal_softmax_bias_dropout_dispatch(HstuAttentionGroupFwdParams& param,
|
||||
hipStream_t stream)
|
||||
{
|
||||
if(get_hstu_attention_fwd_mtile(param.num_batch, param.num_head, param.max_seqlen) == 128)
|
||||
if(get_hstu_attention_fwd_mtile(param.num_batch, param.num_head, param.max_seqlen_q) == 128)
|
||||
group_forward_causal_softmax_bias_dropout_dispatch<InOutDataType,
|
||||
kUseCausal,
|
||||
kUseSoftmax,
|
||||
@@ -197,7 +197,7 @@ void run_group_forward_causal_softmax_bias_dropout_dispatch(HstuAttentionGroupFw
|
||||
|
||||
// ToDo: enable splitkv when kUseSoftmax is true
|
||||
if(!disable_fwd_splitkv && !kUseSoftmax &&
|
||||
shall_use_splitkv(param.num_batch, param.num_head, param.max_seqlen))
|
||||
shall_use_splitkv(param.num_batch, param.num_head, param.max_seqlen_q))
|
||||
{
|
||||
if constexpr(!kUseSoftmax)
|
||||
{
|
||||
|
||||
@@ -176,11 +176,11 @@ struct group_forward_splitkv_causal_softmax_bias_dropout_dispatch
|
||||
static void RunWithFwdSplitKVKernel(HstuAttentionGroupFwdParams& param, hipStream_t stream)
|
||||
{
|
||||
param.num_splits =
|
||||
get_suggested_num_splits(param.num_batch, param.num_head, param.max_seqlen);
|
||||
get_suggested_num_splits(param.num_batch, param.num_head, param.max_seqlen_q);
|
||||
|
||||
// assume the workspace for o_acc is in compact shape of [num_batch, max_seqlen, num_head,
|
||||
// num_splits, hdim]
|
||||
size_t workspace_bytes = static_cast<size_t>(param.num_batch) * param.max_seqlen *
|
||||
size_t workspace_bytes = static_cast<size_t>(param.num_batch) * param.max_seqlen_q *
|
||||
param.num_head * param.num_splits * param.hdim_v *
|
||||
sizeof(OaccDataType);
|
||||
|
||||
@@ -197,7 +197,7 @@ struct group_forward_splitkv_causal_softmax_bias_dropout_dispatch
|
||||
param.seq_q_offsets_ptr,
|
||||
param.is_cross_attention ? param.seq_kv_offsets_ptr
|
||||
: param.seq_q_offsets_ptr,
|
||||
param.group_max_seqlen_ptr,
|
||||
param.group_max_seqlen_q_ptr,
|
||||
param.group_contextual_seqlen_ptr,
|
||||
param.group_window_size_ptr,
|
||||
param.group_min_full_attn_seqlen_ptr,
|
||||
@@ -221,7 +221,7 @@ struct group_forward_splitkv_causal_softmax_bias_dropout_dispatch
|
||||
}();
|
||||
|
||||
dim3 kGridSize = HstuKernel::GridSize(
|
||||
param.num_batch, param.num_head, param.max_seqlen, param.hdim_v, param.num_splits);
|
||||
param.num_batch, param.num_head, param.max_seqlen_q, param.hdim_v, param.num_splits);
|
||||
constexpr dim3 kBlockSize = HstuKernel::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu;
|
||||
|
||||
@@ -245,7 +245,7 @@ struct group_forward_splitkv_causal_softmax_bias_dropout_dispatch
|
||||
param.hdim_v);
|
||||
}();
|
||||
|
||||
dim3 kGridSize = HstuKernel::GridSize(param.num_batch, param.num_head, param.max_seqlen);
|
||||
dim3 kGridSize = HstuKernel::GridSize(param.num_batch, param.num_head, param.max_seqlen_q);
|
||||
constexpr dim3 kBlockSize = HstuKernel::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu;
|
||||
|
||||
|
||||
@@ -132,7 +132,7 @@ struct jagged_forward_causal_softmax_bias_dropout_dispatch
|
||||
param.seq_q_offsets_ptr,
|
||||
param.is_cross_attention ? param.seq_kv_offsets_ptr
|
||||
: param.seq_q_offsets_ptr,
|
||||
param.max_seqlen,
|
||||
param.max_seqlen_q,
|
||||
param.hdim_qk,
|
||||
param.hdim_v,
|
||||
param.num_head,
|
||||
@@ -160,7 +160,7 @@ struct jagged_forward_causal_softmax_bias_dropout_dispatch
|
||||
bool has_minfull_attn_seqlen = (param.min_full_attn_seqlen > 0);
|
||||
dim3 kGridSize = HstuKernel::GridSize(param.num_batch,
|
||||
param.num_head,
|
||||
param.max_seqlen,
|
||||
param.max_seqlen_q,
|
||||
param.hdim_v,
|
||||
has_minfull_attn_seqlen);
|
||||
constexpr dim3 kBlockSize = HstuKernel::BlockSize();
|
||||
@@ -181,7 +181,7 @@ template <typename InOutDataType,
|
||||
void run_jagged_forward_causal_softmax_bias_dropout_dispatch(HstuAttentionNoGroupFwdParams& param,
|
||||
hipStream_t stream)
|
||||
{
|
||||
if(get_hstu_attention_fwd_mtile(param.num_batch, param.num_head, param.max_seqlen) == 128)
|
||||
if(get_hstu_attention_fwd_mtile(param.num_batch, param.num_head, param.max_seqlen_q) == 128)
|
||||
jagged_forward_causal_softmax_bias_dropout_dispatch<InOutDataType,
|
||||
kUseCausal,
|
||||
kUseSoftmax,
|
||||
@@ -200,7 +200,7 @@ void run_jagged_forward_causal_softmax_bias_dropout_dispatch(HstuAttentionNoGrou
|
||||
|
||||
// ToDo: enable splitkv when kUseSoftmax is true
|
||||
if(!disable_fwd_splitkv && !kUseSoftmax &&
|
||||
shall_use_splitkv(param.num_batch, param.num_head, param.max_seqlen))
|
||||
shall_use_splitkv(param.num_batch, param.num_head, param.max_seqlen_q))
|
||||
{
|
||||
if constexpr(!kUseSoftmax)
|
||||
{
|
||||
|
||||
@@ -175,11 +175,11 @@ struct jagged_forward_splitkv_causal_softmax_bias_dropout_dispatch
|
||||
static void RunWithFwdSplitKVKernel(HstuAttentionNoGroupFwdParams& param, hipStream_t stream)
|
||||
{
|
||||
param.num_splits =
|
||||
get_suggested_num_splits(param.num_batch, param.num_head, param.max_seqlen);
|
||||
get_suggested_num_splits(param.num_batch, param.num_head, param.max_seqlen_q);
|
||||
|
||||
// assume the workspace for o_acc is in compact shape of [num_batch, max_seqlen, num_head,
|
||||
// num_splits, hdim]
|
||||
size_t workspace_bytes = static_cast<size_t>(param.num_batch) * param.max_seqlen *
|
||||
size_t workspace_bytes = static_cast<size_t>(param.num_batch) * param.max_seqlen_q *
|
||||
param.num_head * param.num_splits * param.hdim_v *
|
||||
sizeof(OaccDataType);
|
||||
|
||||
@@ -195,7 +195,7 @@ struct jagged_forward_splitkv_causal_softmax_bias_dropout_dispatch
|
||||
param.seq_q_offsets_ptr,
|
||||
param.is_cross_attention ? param.seq_kv_offsets_ptr
|
||||
: param.seq_q_offsets_ptr,
|
||||
param.max_seqlen,
|
||||
param.max_seqlen_q,
|
||||
param.hdim_qk,
|
||||
param.hdim_v,
|
||||
param.num_head,
|
||||
@@ -221,7 +221,7 @@ struct jagged_forward_splitkv_causal_softmax_bias_dropout_dispatch
|
||||
bool has_minfull_attn_seqlen = (param.min_full_attn_seqlen > 0);
|
||||
dim3 kGridSize = HstuKernel::GridSize(param.num_batch,
|
||||
param.num_head,
|
||||
param.max_seqlen,
|
||||
param.max_seqlen_q,
|
||||
param.hdim_v,
|
||||
param.num_splits,
|
||||
has_minfull_attn_seqlen);
|
||||
@@ -248,7 +248,7 @@ struct jagged_forward_splitkv_causal_softmax_bias_dropout_dispatch
|
||||
param.hdim_v);
|
||||
}();
|
||||
|
||||
dim3 kGridSize = HstuKernel::GridSize(param.num_batch, param.num_head, param.max_seqlen);
|
||||
dim3 kGridSize = HstuKernel::GridSize(param.num_batch, param.num_head, param.max_seqlen_q);
|
||||
constexpr dim3 kBlockSize = HstuKernel::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu;
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ struct HstuAttentionNoGroupFwdParams
|
||||
ck_tile::index_t seqlen_kv; // batched mode only
|
||||
const void* seq_q_offsets_ptr; // jagged mode only
|
||||
const void* seq_kv_offsets_ptr; // jagged mode only
|
||||
ck_tile::index_t max_seqlen; // jagged mode only
|
||||
ck_tile::index_t max_seqlen_q; // jagged mode only
|
||||
|
||||
const void* q_ptr;
|
||||
const void* k_ptr;
|
||||
@@ -84,7 +84,7 @@ struct HstuAttentionGroupFwdParams
|
||||
ck_tile::index_t num_batch;
|
||||
const void* seq_q_offsets_ptr;
|
||||
const void* seq_kv_offsets_ptr;
|
||||
ck_tile::index_t max_seqlen; // the maximum of all the groups' max_seqlen
|
||||
ck_tile::index_t max_seqlen_q; // the maximum of all the groups' max_seqlen_q
|
||||
|
||||
const void* q_ptr;
|
||||
const void* k_ptr;
|
||||
@@ -122,7 +122,7 @@ struct HstuAttentionGroupFwdParams
|
||||
|
||||
// parameters used by Group HSTU
|
||||
const void* group_attn_scale_ptr;
|
||||
const void* group_max_seqlen_ptr;
|
||||
const void* group_max_seqlen_q_ptr; // use for setting attn_scales
|
||||
const void* group_window_size_ptr;
|
||||
const void* group_contextual_seqlen_ptr;
|
||||
const void* group_min_full_attn_seqlen_ptr;
|
||||
|
||||
@@ -43,7 +43,7 @@ struct reference_no_group_hstu_attention
|
||||
float alpha,
|
||||
float attn_scale,
|
||||
int max_seqlen_q,
|
||||
int max_seqlen_kv,
|
||||
int max_seqlen_kv, // only used as last dim of the tensor for saving the mask
|
||||
std::vector<int> seq_q_offsets,
|
||||
std::vector<int> seq_kv_offsets,
|
||||
std::vector<int> num_targets, // define masking length at the end of token
|
||||
@@ -116,9 +116,7 @@ struct reference_no_group_hstu_attention
|
||||
|
||||
int num_target = num_targets.empty() ? 0 : num_targets[i_batch];
|
||||
|
||||
float scale_p = attn_scale
|
||||
? attn_scale
|
||||
: 1.0f / static_cast<float>(max(max_seqlen_q, max_seqlen_kv));
|
||||
float scale_p = attn_scale ? attn_scale : 1.0f / static_cast<float>(max_seqlen_q);
|
||||
|
||||
BOOL_SWITCH_2(window_size > 0, kHasLocal, is_cross_attention, kIsCrossAttention, [&] {
|
||||
using HstuMaskType =
|
||||
@@ -335,12 +333,15 @@ struct reference_group_hstu_attention
|
||||
int num_batch,
|
||||
int num_batch_per_group,
|
||||
float alpha,
|
||||
int max_max_seqlen, // the maximum of all groups's max_seqlen
|
||||
int max_max_seqlen_q, // the maximum of all groups's max_seqlen_q, only used as second last
|
||||
// dim of the tensor for saving the mask
|
||||
int max_max_seqlen_kv, // the maximum of all groups's max_seqlen_k, only used as last dim of
|
||||
// the tensor for saving the mask
|
||||
const std::vector<int>& seq_q_offsets,
|
||||
const std::vector<int>& seq_kv_offsets,
|
||||
const std::vector<int>& num_targets, // define masking length at the end of token
|
||||
// sequence to be excluded for attention
|
||||
const std::vector<int>& group_max_seqlens, // max seqlen list by groups
|
||||
const std::vector<int>& num_targets, // define masking length at the end of token
|
||||
// sequence to be excluded for attention
|
||||
const std::vector<int>& group_max_seqlens_q, // max seqlen_q list by groups
|
||||
const std::vector<int>& group_contextual_seqlens, // contextual seqlen list by groups
|
||||
const std::vector<int>& group_window_sizes, // window_size list by groups
|
||||
const std::vector<int>& group_min_full_attn_seqlens, // min_full_attn_seqlen list by groups
|
||||
@@ -374,8 +375,8 @@ struct reference_group_hstu_attention
|
||||
|
||||
if(static_cast<int>(mask_batch_nhead_seq_seq.get_lengths()[0]) == num_batch &&
|
||||
static_cast<int>(mask_batch_nhead_seq_seq.get_lengths()[1]) == num_head &&
|
||||
static_cast<int>(mask_batch_nhead_seq_seq.get_lengths()[2]) == max_max_seqlen &&
|
||||
static_cast<int>(mask_batch_nhead_seq_seq.get_lengths()[3]) == max_max_seqlen)
|
||||
static_cast<int>(mask_batch_nhead_seq_seq.get_lengths()[2]) == max_max_seqlen_q &&
|
||||
static_cast<int>(mask_batch_nhead_seq_seq.get_lengths()[3]) == max_max_seqlen_kv)
|
||||
save_mask = true;
|
||||
|
||||
// check num_tagets
|
||||
@@ -394,10 +395,10 @@ struct reference_group_hstu_attention
|
||||
|
||||
int num_target = num_targets.empty() ? 0 : num_targets[i_batch];
|
||||
|
||||
int max_seqlen = group_max_seqlens[i_group];
|
||||
int max_seqlen_q = group_max_seqlens_q[i_group];
|
||||
float attn_scale = group_attn_scales[i_group];
|
||||
|
||||
float scale_p = (attn_scale ? attn_scale : 1.0f / static_cast<float>(max_seqlen));
|
||||
float scale_p = (attn_scale ? attn_scale : 1.0f / static_cast<float>(max_seqlen_q));
|
||||
|
||||
int contextual_seqlen = group_contextual_seqlens[i_group];
|
||||
int window_size = group_window_sizes[i_group];
|
||||
@@ -468,8 +469,8 @@ struct reference_group_hstu_attention
|
||||
|
||||
if(save_mask)
|
||||
{
|
||||
for(int sq = 0; sq < max_seqlen; sq++)
|
||||
for(int sk = 0; sk < max_seqlen; sk++)
|
||||
for(int sq = 0; sq < max_max_seqlen_q; sq++)
|
||||
for(int sk = 0; sk < max_max_seqlen_kv; sk++)
|
||||
mask_batch_nhead_seq_seq(i_batch, i_head, sq, sk) = 0;
|
||||
|
||||
for(int sq = 0; sq < seqlen_q; sq++)
|
||||
|
||||
Reference in New Issue
Block a user