Support separate sequence lengths for q and kv

This commit is contained in:
Qianfeng Zhang
2025-10-31 14:04:32 +00:00
parent eaf9650fed
commit 17e404be3b
7 changed files with 277 additions and 188 deletions

View File

@@ -100,7 +100,8 @@ auto create_args(int argc, char* argv[])
.insert("nhead", "4", "number of heads")
.insert("hdim_qk", "64", "headdim size of Q/K")
.insert("hdim_v", "64", "headdim size of V/O")
.insert("seqlens", "400", "uih seqlen of single or all batches for query and key/value tensor, actually allocated seqlen will include the target of each batch and context_len")
.insert("seqlens", "400", "uih seqlen of single or all batches for query tensor, actually allocated seqlen will include the target of each batch and context_len")
.insert("seqlens_kv", "", "uih seqlen of single or all batches for key/value tensor, actually allocated seqlen will include the target of each batch and context_len")
.insert("max_seqlen", "0", "max uih_seqlen, can be ignored, or else must be equal or bigger than the maximum of all uih seqlens")
.insert("targets", "", "sequence length at the end of query/key token sequence that should be excluded from attention")
.insert("max_target", "0", "max target, can be ignored, or else must be equal or bigger than the maximum of all targets")
@@ -238,13 +239,15 @@ bool run(const ck_tile::ArgParser& arg_parser)
std::string str_of_targets = arg_parser.get_str("targets");
std::vector<int> num_targets = get_integers_from_string(str_of_targets);
std::string str_of_lengths = arg_parser.get_str("seqlens");
std::vector<int> seq_lengths = get_integers_from_string(str_of_lengths);
std::string str_of_lengths_q = arg_parser.get_str("seqlens");
std::vector<int> seq_lengths_q = get_integers_from_string(str_of_lengths_q);
std::string str_of_lengths_kv = arg_parser.get_str("seqlens_kv");
std::vector<int> seq_lengths_kv = get_integers_from_string(str_of_lengths_kv);
int input_max_uih_seqlen = arg_parser.get_int("max_seqlen");
int input_max_target = arg_parser.get_int("max_target");
int uih_seqlen = 0; // means total seq lengths for jagged
int max_uih_seqlen = 0;
int max_target = 0;
@@ -264,31 +267,43 @@ bool run(const ck_tile::ArgParser& arg_parser)
max_target = max(max_target, num_targets[i]);
};
HSTU_CHECK(!seq_lengths.empty(), "sequence lengths shoud be defined!");
HSTU_CHECK(!seq_lengths_q.empty(), "sequence lengths of q should be defined!");
// assume seq_lengths_kv is same as seq_lengths_q if not defined
if(seq_lengths_kv.empty())
seq_lengths_kv = seq_lengths_q;
if(is_jagged)
{
// supplement seq_lengths using the last input value if user-provided lengths not enough
if(static_cast<int>(seq_lengths.size()) < num_batch)
if(static_cast<int>(seq_lengths_q.size()) < num_batch)
{
auto last_len = seq_lengths.back();
auto last_len = seq_lengths_q.back();
for(int i = seq_lengths.size(); i < num_batch; i++)
seq_lengths.push_back(last_len);
for(int i = seq_lengths_q.size(); i < num_batch; i++)
seq_lengths_q.push_back(last_len);
};
// supplement seq_lengths_kv using the last input value if user-provided lengths not enough
if(static_cast<int>(seq_lengths_kv.size()) < num_batch)
{
auto last_len = seq_lengths_kv.back();
for(int i = seq_lengths_kv.size(); i < num_batch; i++)
seq_lengths_kv.push_back(last_len);
};
// only consider num_batch values even if more values are provided by the user
for(int i = 0; i < num_batch; i++)
{
max_uih_seqlen = max(max_uih_seqlen, seq_lengths[i]);
max_uih_seqlen = max(max_uih_seqlen, seq_lengths_q[i]);
};
}
else
{
HSTU_CHECK(1 == seq_lengths.size(),
HSTU_CHECK(1 == seq_lengths_q.size() && 1 == seq_lengths_kv.size(),
"sequence lengths for batched mode should have a single element!");
uih_seqlen = seq_lengths[0];
max_uih_seqlen = uih_seqlen;
max_uih_seqlen = max(seq_lengths_q[0], seq_lengths_kv[0]);
};
// the user input of max_uih_seqlen can either be ignored or be bigger than all uih_seqlens
@@ -304,28 +319,43 @@ bool run(const ck_tile::ArgParser& arg_parser)
max_uih_seqlen = (input_max_uih_seqlen > 0) ? input_max_uih_seqlen : max_uih_seqlen;
max_target = (input_max_target > 0) ? input_max_target : max_target;
int phy_seqlen = 0;
int max_seqlen = max_uih_seqlen + max_target + contextual_seqlen;
int phy_seqlen_q = 0;
int phy_seqlen_kv = 0;
int max_seqlen = max_uih_seqlen + max_target + contextual_seqlen;
std::vector<int> seq_offsets;
std::vector<int> seq_offsets_q;
std::vector<int> seq_offsets_kv;
if(is_jagged)
{
seq_offsets.push_back(0);
seq_offsets_q.push_back(0);
for(int i = 0; i < num_batch; i++)
{
int batch_seqlen = num_targets.empty()
? seq_lengths[i] + contextual_seqlen
: seq_lengths[i] + num_targets[i] + contextual_seqlen;
? seq_lengths_q[i] + contextual_seqlen
: seq_lengths_q[i] + num_targets[i] + contextual_seqlen;
phy_seqlen += batch_seqlen;
seq_offsets.push_back(phy_seqlen);
phy_seqlen_q += batch_seqlen;
seq_offsets_q.push_back(phy_seqlen_q);
};
seq_offsets_kv.push_back(0);
for(int i = 0; i < num_batch; i++)
{
int batch_seqlen = num_targets.empty()
? seq_lengths_kv[i] + contextual_seqlen
: seq_lengths_kv[i] + num_targets[i] + contextual_seqlen;
phy_seqlen_kv += batch_seqlen;
seq_offsets_kv.push_back(phy_seqlen_kv);
};
}
else
{
phy_seqlen = max_seqlen;
phy_seqlen_q = max_seqlen;
phy_seqlen_kv = max_seqlen;
};
long total_flops = 0;
@@ -335,10 +365,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
{
for(int i = 0; i < num_batch; i++)
{
int len = seq_offsets[i + 1] - seq_offsets[i];
total_flops +=
(static_cast<long>(len) * len * hdim_qk + static_cast<long>(len) * hdim_v * len) *
2;
int len_q = seq_offsets_q[i + 1] - seq_offsets_q[i];
int len_kv = seq_offsets_kv[i + 1] - seq_offsets_kv[i];
total_flops += (static_cast<long>(len_q) * len_kv * hdim_qk +
static_cast<long>(len_q) * hdim_v * len_kv) *
2;
};
total_flops *= num_head;
@@ -346,21 +377,21 @@ bool run(const ck_tile::ArgParser& arg_parser)
else
{
total_flops = static_cast<long>(num_batch) * num_head *
(static_cast<long>(phy_seqlen) * phy_seqlen * hdim_qk +
static_cast<long>(phy_seqlen) * hdim_v * phy_seqlen) *
(static_cast<long>(phy_seqlen_q) * phy_seqlen_kv * hdim_qk +
static_cast<long>(phy_seqlen_q) * hdim_v * phy_seqlen_kv) *
2;
};
int batches_for_alloc = is_jagged ? 1 : num_batch;
ck_tile::HostTensor<InOutDataType> q_host(
std::array<ck_tile::index_t, 4>{batches_for_alloc, phy_seqlen, num_head, hdim_qk});
std::array<ck_tile::index_t, 4>{batches_for_alloc, phy_seqlen_q, num_head, hdim_qk});
ck_tile::HostTensor<InOutDataType> k_host(
std::array<ck_tile::index_t, 4>{batches_for_alloc, phy_seqlen, num_head, hdim_qk});
std::array<ck_tile::index_t, 4>{batches_for_alloc, phy_seqlen_kv, num_head, hdim_qk});
ck_tile::HostTensor<InOutDataType> v_host(
std::array<ck_tile::index_t, 4>{batches_for_alloc, phy_seqlen, num_head, hdim_v});
std::array<ck_tile::index_t, 4>{batches_for_alloc, phy_seqlen_kv, num_head, hdim_v});
ck_tile::HostTensor<InOutDataType> o_host_ref(
std::array<ck_tile::index_t, 4>{batches_for_alloc, phy_seqlen, num_head, hdim_v});
std::array<ck_tile::index_t, 4>{batches_for_alloc, phy_seqlen_q, num_head, hdim_v});
ck_tile::HostTensor<int8_t> mask_host(
save_mask ? std::array<ck_tile::index_t, 4>{num_batch, num_head, max_seqlen, max_seqlen}
@@ -393,7 +424,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::DeviceMem v_dev(v_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem o_dev(o_host_ref.get_element_space_size_in_bytes());
ck_tile::DeviceMem seq_offsets_dev(seq_offsets.size() * sizeof(int));
ck_tile::DeviceMem seq_offsets_q_dev(seq_offsets_q.size() * sizeof(int));
ck_tile::DeviceMem seq_offsets_kv_dev(seq_offsets_kv.size() * sizeof(int));
ck_tile::DeviceMem num_targets_dev(num_targets.size() * sizeof(int));
q_dev.ToDevice(q_host.data());
@@ -401,7 +433,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
v_dev.ToDevice(v_host.data());
if(is_jagged)
seq_offsets_dev.ToDevice(seq_offsets.data());
{
seq_offsets_q_dev.ToDevice(seq_offsets_q.data());
seq_offsets_kv_dev.ToDevice(seq_offsets_kv.data());
};
if(!num_targets.empty())
num_targets_dev.ToDevice(num_targets.data());
@@ -411,30 +446,31 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(is_jagged)
{
params.is_jagged = true;
params.num_batch = num_batch;
params.seq_offsets_ptr = seq_offsets_dev.GetDeviceBuffer();
params.max_seqlen = max_seqlen;
params.q_ptr = q_dev.GetDeviceBuffer();
params.k_ptr = k_dev.GetDeviceBuffer();
params.v_ptr = v_dev.GetDeviceBuffer();
params.bias_ptr = nullptr; // bias is not supported at present
params.o_ptr = o_dev.GetDeviceBuffer();
params.hdim_qk = hdim_qk;
params.hdim_v = hdim_v;
params.num_head = num_head;
params.scale_s = scale_s;
params.attn_scale = attn_scale;
params.seq_stride_q = q_host.get_strides()[1];
params.seq_stride_k = k_host.get_strides()[1];
params.seq_stride_v = v_host.get_strides()[1];
params.seq_stride_bias = 0;
params.seq_stride_o = o_host_ref.get_strides()[1];
params.nhead_stride_q = q_host.get_strides()[2];
params.nhead_stride_k = k_host.get_strides()[2];
params.nhead_stride_v = v_host.get_strides()[2];
params.nhead_stride_bias = 0;
params.nhead_stride_o = o_host_ref.get_strides()[2];
params.is_jagged = true;
params.num_batch = num_batch;
params.seq_q_offsets_ptr = seq_offsets_q_dev.GetDeviceBuffer();
params.seq_kv_offsets_ptr = seq_offsets_kv_dev.GetDeviceBuffer();
params.max_seqlen = max_seqlen;
params.q_ptr = q_dev.GetDeviceBuffer();
params.k_ptr = k_dev.GetDeviceBuffer();
params.v_ptr = v_dev.GetDeviceBuffer();
params.bias_ptr = nullptr; // bias is not supported at present
params.o_ptr = o_dev.GetDeviceBuffer();
params.hdim_qk = hdim_qk;
params.hdim_v = hdim_v;
params.num_head = num_head;
params.scale_s = scale_s;
params.attn_scale = attn_scale;
params.seq_stride_q = q_host.get_strides()[1];
params.seq_stride_k = k_host.get_strides()[1];
params.seq_stride_v = v_host.get_strides()[1];
params.seq_stride_bias = 0;
params.seq_stride_o = o_host_ref.get_strides()[1];
params.nhead_stride_q = q_host.get_strides()[2];
params.nhead_stride_k = k_host.get_strides()[2];
params.nhead_stride_v = v_host.get_strides()[2];
params.nhead_stride_bias = 0;
params.nhead_stride_o = o_host_ref.get_strides()[2];
params.num_targets_ptr = num_targets.empty() ? nullptr : num_targets_dev.GetDeviceBuffer();
params.use_softmax = use_softmax;
params.use_causal = use_causal;
@@ -449,7 +485,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
{
params.is_jagged = false;
params.num_batch = num_batch;
params.seqlen = max_seqlen;
params.seqlen_q = phy_seqlen_q;
params.seqlen_kv = phy_seqlen_kv;
params.q_ptr = q_dev.GetDeviceBuffer();
params.k_ptr = k_dev.GetDeviceBuffer();
params.v_ptr = v_dev.GetDeviceBuffer();
@@ -532,7 +569,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
scale_s,
attn_scale,
max_seqlen,
seq_offsets,
seq_offsets_q,
seq_offsets_kv,
num_targets,
contextual_seqlen,
window_size,
@@ -540,7 +578,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
});
ck_tile::HostTensor<InOutDataType> o_host(
std::array<ck_tile::index_t, 4>{batches_for_alloc, phy_seqlen, num_head, hdim_v});
std::array<ck_tile::index_t, 4>{batches_for_alloc, phy_seqlen_q, num_head, hdim_v});
o_dev.FromDevice(o_host.data());

View File

@@ -60,7 +60,7 @@ struct batched_forward_causal_softmax_bias_dropout_dispatch
{
constexpr ck_tile::index_t occupancy = -1;
const bool pad_seqlen_k = !(param.seqlen % HstuAttentionTileSetting::kN0 == 0);
const bool pad_seqlen_k = !(param.seqlen_kv % HstuAttentionTileSetting::kN0 == 0);
const bool pad_headdim_qk = !(param.hdim_qk % HstuAttentionTileSetting::kSubQKHeaddim == 0);
const bool pad_headdim_v = !(param.hdim_v % HstuAttentionTileSetting::kN1 == 0);
@@ -125,7 +125,8 @@ struct batched_forward_causal_softmax_bias_dropout_dispatch
param.v_ptr,
param.bias_ptr,
param.o_ptr,
param.seqlen,
param.seqlen_q,
param.seqlen_kv,
param.hdim_qk,
param.hdim_v,
param.num_head,
@@ -157,7 +158,7 @@ struct batched_forward_causal_softmax_bias_dropout_dispatch
bool has_minfull_attn_seqlen = (param.min_full_attn_seqlen > 0);
dim3 kGridSize = HstuKernel::GridSize(
param.num_batch, param.num_head, param.seqlen, param.hdim_v, has_minfull_attn_seqlen);
param.num_batch, param.num_head, param.seqlen_q, param.hdim_v, has_minfull_attn_seqlen);
constexpr dim3 kBlockSize = HstuKernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu;

View File

@@ -78,7 +78,8 @@ struct HstuAttentionFwdKernel
ck_tile::index_t nhead_stride_v;
ck_tile::index_t nhead_stride_o;
ck_tile::index_t seqlen;
ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_kv;
ck_tile::index_t hdim_qk;
ck_tile::index_t hdim_v;
@@ -98,7 +99,8 @@ struct HstuAttentionFwdKernel
struct HstuAttentionFwdJaggModeBaseKargs
{
const int32_t* seq_offsets_ptr;
const int32_t* seq_q_offsets_ptr;
const int32_t* seq_kv_offsets_ptr;
ck_tile::index_t seq_stride_q;
ck_tile::index_t seq_stride_k;
@@ -120,7 +122,8 @@ struct HstuAttentionFwdKernel
ck_tile::index_t hdim_qk;
ck_tile::index_t hdim_v;
ck_tile::index_t seqlen;
ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_kv;
ck_tile::index_t num_head;
float scale_s; // scaling value exerted on the immediate Q@K result
@@ -196,7 +199,8 @@ struct HstuAttentionFwdKernel
const void* v_ptr,
const void* bias_ptr,
void* o_ptr,
ck_tile::index_t seqlen,
ck_tile::index_t seqlen_q,
ck_tile::index_t seqlen_kv,
ck_tile::index_t hdim_qk,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head,
@@ -239,7 +243,8 @@ struct HstuAttentionFwdKernel
nhead_stride_k,
nhead_stride_v,
nhead_stride_o,
seqlen,
seqlen_q,
seqlen_kv,
hdim_qk,
hdim_v,
seq_stride_q,
@@ -248,7 +253,8 @@ struct HstuAttentionFwdKernel
seq_stride_o,
num_head,
scale_s,
attn_scale ? attn_scale : 1.0f / static_cast<float>(seqlen), // max_seqlen
attn_scale ? attn_scale
: 1.0f / static_cast<float>(max(seqlen_q, seqlen_kv)), // max_seqlen
contextual_seqlen,
window_size,
min_full_attn_seqlen}, // args for common karg
@@ -278,7 +284,8 @@ struct HstuAttentionFwdKernel
const void* v_ptr,
const void* bias_ptr,
void* o_ptr,
const void* seq_offsets_ptr,
const void* seq_q_offsets_ptr,
const void* seq_kv_offsets_ptr,
ck_tile::index_t max_seqlen,
ck_tile::index_t hdim_qk,
ck_tile::index_t hdim_v,
@@ -304,7 +311,8 @@ struct HstuAttentionFwdKernel
uint64_t philox_offset)
{
Kargs kargs{
{reinterpret_cast<const int32_t*>(seq_offsets_ptr),
{reinterpret_cast<const int32_t*>(seq_q_offsets_ptr),
reinterpret_cast<const int32_t*>(seq_kv_offsets_ptr),
seq_stride_q,
seq_stride_k,
seq_stride_v,
@@ -320,7 +328,8 @@ struct HstuAttentionFwdKernel
nhead_stride_o,
hdim_qk,
hdim_v,
-1, // seqlen will be updated by another pointer
-1, // seqlen_q will be updated by another pointer
-1, // seqlen_kv will be updated by another pointer
num_head,
scale_s,
attn_scale ? attn_scale : 1.0f / static_cast<float>(max_seqlen),
@@ -465,8 +474,8 @@ struct HstuAttentionFwdKernel
if constexpr(kIsJagged)
{
// get starting offset for each batch
const long_index_t query_start = kargs.seq_offsets_ptr[i_batch];
const long_index_t key_start = query_start;
const long_index_t query_start = kargs.seq_q_offsets_ptr[i_batch];
const long_index_t key_start = kargs.seq_kv_offsets_ptr[i_batch];
batch_offset_q = query_start * kargs.seq_stride_q;
batch_offset_k = key_start * kargs.seq_stride_k;
@@ -478,7 +487,10 @@ struct HstuAttentionFwdKernel
}
batch_offset_o = query_start * kargs.seq_stride_o;
kargs.seqlen = kargs.seq_offsets_ptr[i_batch + 1] - kargs.seq_offsets_ptr[i_batch];
kargs.seqlen_q =
kargs.seq_q_offsets_ptr[i_batch + 1] - kargs.seq_q_offsets_ptr[i_batch];
kargs.seqlen_kv =
kargs.seq_kv_offsets_ptr[i_batch + 1] - kargs.seq_kv_offsets_ptr[i_batch];
}
else
{
@@ -494,16 +506,16 @@ struct HstuAttentionFwdKernel
int num_target = (kargs.num_targets_ptr == nullptr) ? 0 : kargs.num_targets_ptr[i_batch];
index_t seqlen_in_first_split = kargs.seqlen;
index_t seqlen_in_first_split = kargs.seqlen_q;
bool is_tile_in_first_split = true;
index_t i_m0;
if(kargs.min_full_attn_seqlen > 0)
{
// need consider for cases where min_full_attn_seqlen be bigger than max_uih_len
if(kargs.seqlen - num_target > kargs.min_full_attn_seqlen)
if(kargs.seqlen_q - num_target > kargs.min_full_attn_seqlen)
{
seqlen_in_first_split = kargs.seqlen - num_target - kargs.min_full_attn_seqlen;
seqlen_in_first_split = kargs.seqlen_q - num_target - kargs.min_full_attn_seqlen;
index_t num_tile_in_first_split =
ck_tile::integer_divide_ceil(seqlen_in_first_split, HstuAttentionPipeline::kM0);
@@ -522,7 +534,7 @@ struct HstuAttentionFwdKernel
is_tile_in_first_split = false;
// adjust the min_full_attn_seqlen to be passed to HstuBlockMask constructor
kargs.min_full_attn_seqlen = kargs.seqlen - num_target;
kargs.min_full_attn_seqlen = kargs.seqlen_q - num_target;
i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * HstuAttentionPipeline::kM0);
};
@@ -532,7 +544,7 @@ struct HstuAttentionFwdKernel
const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * HstuAttentionPipeline::kN1);
index_t seqlen_q_in_ctrl = is_tile_in_first_split ? seqlen_in_first_split : kargs.seqlen;
index_t seqlen_q_in_ctrl = is_tile_in_first_split ? seqlen_in_first_split : kargs.seqlen_q;
if(seqlen_q_in_ctrl <= i_m0)
return;
@@ -567,7 +579,7 @@ struct HstuAttentionFwdKernel
const auto k_dram = [&]() {
const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
k_ptr,
make_tuple(kargs.seqlen, kargs.hdim_qk),
make_tuple(kargs.seqlen_kv, kargs.hdim_qk),
make_tuple(kargs.seq_stride_k, 1),
number<HstuAttentionPipeline::kAlignmentK>{},
number<1>{});
@@ -580,7 +592,7 @@ struct HstuAttentionFwdKernel
const auto v_dram = [&]() {
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
v_ptr,
make_tuple(kargs.seqlen, kargs.hdim_v),
make_tuple(kargs.seqlen_kv, kargs.hdim_v),
make_tuple(kargs.seq_stride_v, 1),
number<HstuAttentionPipeline::kAlignmentV>{},
number<1>{});
@@ -590,7 +602,7 @@ struct HstuAttentionFwdKernel
const auto v_dram_transposed =
transform_tensor_view(v_dram_naive,
make_tuple(make_pass_through_transform(kargs.hdim_v),
make_pass_through_transform(kargs.seqlen)),
make_pass_through_transform(kargs.seqlen_kv)),
make_tuple(sequence<1>{}, sequence<0>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
@@ -641,7 +653,7 @@ struct HstuAttentionFwdKernel
const auto bias_dram = [&]() {
const auto bias_dram_naive = make_naive_tensor_view<address_space_enum::global>(
bias_ptr,
make_tuple(seqlen_q_in_ctrl, kargs.seqlen),
make_tuple(seqlen_q_in_ctrl, kargs.seqlen_kv),
make_tuple(kargs.seq_stride_bias, 1),
number<HstuAttentionPipeline::kAlignmentBias>{},
number<1>{});
@@ -682,7 +694,8 @@ struct HstuAttentionFwdKernel
using HstuMaskType = typename ck_tile::HstuBlockMasking<kHasCausalMask, true>::Type;
const auto mask =
make_hstu_block_mask_with_local<HstuMaskType>(is_tile_in_first_split,
kargs.seqlen,
kargs.seqlen_q,
kargs.seqlen_kv,
kargs.contextual_seqlen,
num_target,
kargs.window_size,
@@ -703,7 +716,7 @@ struct HstuAttentionFwdKernel
using HstuMaskType =
typename ck_tile::HstuBlockMasking<kHasCausalMask, false>::Type;
const auto mask = make_hstu_block_mask_without_local<HstuMaskType>(
kargs.seqlen, kargs.contextual_seqlen, num_target);
kargs.seqlen_q, kargs.seqlen_kv, kargs.contextual_seqlen, num_target);
return HstuAttentionPipeline{}(q_dram_window,
k_dram_window,

View File

@@ -118,7 +118,8 @@ struct jagged_forward_causal_softmax_bias_dropout_dispatch
param.v_ptr,
param.bias_ptr,
param.o_ptr,
param.seq_offsets_ptr,
param.seq_q_offsets_ptr,
param.seq_kv_offsets_ptr,
param.max_seqlen,
param.hdim_qk,
param.hdim_v,

View File

@@ -10,9 +10,11 @@ struct HstuAttentionFwdParams
bool is_jagged;
ck_tile::index_t num_batch;
ck_tile::index_t seqlen; // batched mode only
const void* seq_offsets_ptr; // jagged mode only
ck_tile::index_t max_seqlen; // jagged mode only
ck_tile::index_t seqlen_q; // batched mode only
ck_tile::index_t seqlen_kv; // batched mode only
const void* seq_q_offsets_ptr; // jagged mode only
const void* seq_kv_offsets_ptr; // jagged mode only
ck_tile::index_t max_seqlen; // jagged mode only
const void* q_ptr;
const void* k_ptr;

View File

@@ -15,42 +15,54 @@ struct HstuBlockMaskWithLocal
static constexpr bool IsMasking = true;
// is_tile_in_first_split is false only when min_full_attn_seqlen > 0 and the current
// tile is inside scope [max_uih_len - min_full_attn_seqlen, seqlen); for other cases
// tile is inside scope [max_uih_len - min_full_attn_seqlen, seqlen_q); for other cases
// and tiles, is_tile_in_first_split is true
bool is_tile_in_first_split;
int seqlen;
int seqlen_q;
int seqlen_k;
int contextual_seqlen;
int min_full_attn_seqlen;
int max_attn_len;
int max_uih_len;
int max_id;
int max_q_uih_len;
int max_k_uih_len;
int max_row_id;
int max_col_id;
CK_TILE_HOST_DEVICE HstuBlockMaskWithLocal(bool is_tile_in_first_split_,
int seqlen_,
int seqlen_q_,
int seqlen_k_,
int contextual_seqlen_,
int max_attn_len_,
int min_full_attn_seqlen_,
int num_target_)
: is_tile_in_first_split(is_tile_in_first_split_),
seqlen(seqlen_),
seqlen_q(seqlen_q_),
seqlen_k(seqlen_k_),
contextual_seqlen(contextual_seqlen_),
min_full_attn_seqlen(min_full_attn_seqlen_)
{
max_uih_len = seqlen - num_target_;
max_q_uih_len = seqlen_q - num_target_;
max_k_uih_len = seqlen_k - num_target_;
// in case the user-provided max_attn_len_ could be bigger than max_uih_len
max_attn_len = min(max_uih_len, max_attn_len_);
max_attn_len = min(max_k_uih_len, min(max_q_uih_len, max_attn_len_));
// assuming min_full_attn_seqlen has higher priority, ensure contextual scope not collide
// with min_full_attn_seqlen scope
contextual_seqlen = min(contextual_seqlen, max_uih_len - min_full_attn_seqlen);
contextual_seqlen = min(contextual_seqlen, max_q_uih_len - min_full_attn_seqlen);
if(contextual_seqlen > 0)
max_id = max_uih_len - (contextual_seqlen - 1);
{
max_row_id = max_q_uih_len - (contextual_seqlen - 1);
max_col_id = max_k_uih_len - (contextual_seqlen - 1);
}
else
max_id = max_uih_len;
{
max_row_id = max_q_uih_len;
max_col_id = max_k_uih_len;
}
};
// to get the loop length along X axis, return index:[start, end), end-start=length
@@ -65,20 +77,20 @@ struct HstuBlockMaskWithLocal
{
if constexpr(kUseCausal)
{
index_t x_end = min(i_y + YTile, seqlen);
index_t x_end = min(i_y + YTile, seqlen_k);
return ck_tile::make_tuple(0, x_end);
}
else
{
// tile is partially or completely in [max_uih_len-min_full_attn_seqlen,
// max_uih_len)
if(i_y < max_uih_len)
// max_q_uih_len)
if(i_y < max_q_uih_len)
{
return ck_tile::make_tuple(0, seqlen);
return ck_tile::make_tuple(0, seqlen_k);
}
else // tile is completely inside [max_uih_len, seqlen)
else // tile is completely inside [max_q_uih_len, seqlen_q)
{
index_t x_end = min(i_y + YTile, seqlen);
index_t x_end = min(i_y + YTile, seqlen_k);
return ck_tile::make_tuple(0, x_end);
};
};
@@ -90,18 +102,18 @@ struct HstuBlockMaskWithLocal
{
if(i_y >= min(contextual_seqlen, 1) + max_attn_len)
{
// some row of the tile in [contextual_seqlen+max_attn_len, max_uih_len)
if(i_y < max_uih_len)
// some row of the tile in [contextual_seqlen+max_attn_len, max_q_uih_len)
if(i_y < max_q_uih_len)
{
index_t x_start = i_y - max_attn_len;
index_t x_start_aligned = x_start - x_start % XTile;
// some rows of the tile in [max_uih_len -max_attn_len, max_uih_len)
if(i_y + YTile > max_uih_len - max_attn_len)
// some rows of the tile in [max_q_uih_len - max_attn_len, max_q_uih_len)
if(i_y + YTile > max_q_uih_len - max_attn_len)
{
return ck_tile::make_tuple(x_start_aligned, seqlen);
return ck_tile::make_tuple(x_start_aligned, seqlen_k);
}
else // whole tile in [contextual_seqlen+max_attn_len, max_uih_len
else // whole tile in [contextual_seqlen+max_attn_len, max_q_uih_len
// -max_attn_len)
{
index_t x_end = i_y + YTile + max_attn_len;
@@ -110,8 +122,8 @@ struct HstuBlockMaskWithLocal
}
else // whole tile in [max_uih_len, seqlen)
{
index_t x_start = max_uih_len - max_attn_len;
index_t x_end = min(i_y + YTile, seqlen);
index_t x_start = max_k_uih_len - max_attn_len;
index_t x_end = min(i_y + YTile, seqlen_k);
return ck_tile::make_tuple(x_start - x_start % XTile, x_end);
}
@@ -120,12 +132,12 @@ struct HstuBlockMaskWithLocal
{
if(i_y < contextual_seqlen) // some row of the tile in [0, contextual_seqlen)
{
index_t x_end = min(max(i_y + YTile + max_attn_len, max_uih_len), seqlen);
index_t x_end = min(max(i_y + YTile + max_attn_len, max_k_uih_len), seqlen_k);
return ck_tile::make_tuple(0, x_end);
}
else // whole tile in [contextual_seqlen, seqlen)
{
index_t x_end = min(i_y + YTile + max_attn_len, seqlen);
index_t x_end = min(i_y + YTile + max_attn_len, seqlen_k);
return ck_tile::make_tuple(0, x_end);
}
}
@@ -134,17 +146,17 @@ struct HstuBlockMaskWithLocal
{
if(i_y >= min(contextual_seqlen, 1) + max_attn_len)
{
index_t x_end = min(i_y + YTile, seqlen);
index_t x_end = min(i_y + YTile, seqlen_k);
// some row of the tile in [contextual_seqlen+max_attn_len, max_uih_len)
if(i_y < max_uih_len)
// some row of the tile in [contextual_seqlen+max_attn_len, max_q_uih_len)
if(i_y < max_q_uih_len)
{
index_t x_start = i_y - max_attn_len;
return ck_tile::make_tuple(x_start - x_start % XTile, x_end);
}
else // whole tile in [max_uih_len, seqlen)
{
index_t x_start = max_uih_len - max_attn_len;
index_t x_start = max_k_uih_len - max_attn_len;
return ck_tile::make_tuple(x_start - x_start % XTile, x_end);
}
}
@@ -152,12 +164,12 @@ struct HstuBlockMaskWithLocal
{
if(i_y < contextual_seqlen) // some row of the tile in [0, contextual_seqlen)
{
index_t x_end = min(max(i_y + YTile, max_uih_len), seqlen);
index_t x_end = min(max(i_y + YTile, max_k_uih_len), seqlen_k);
return ck_tile::make_tuple(0, x_end);
}
else // whole tile in [contextual_seqlen, seqlen)
{
index_t x_end = min(i_y + YTile, seqlen);
index_t x_end = min(i_y + YTile, seqlen_k);
return ck_tile::make_tuple(0, x_end);
}
}
@@ -176,18 +188,18 @@ struct HstuBlockMaskWithLocal
row_id = max(row - contextual_seqlen + 1, 0);
col_id = max(col - contextual_seqlen + 1, 0);
row_id = min(row_id, max_id);
col_id = min(col_id, max_id);
row_id = min(row_id, max_row_id);
col_id = min(col_id, max_col_id);
if(row_id == 0 && col_id < max_id)
if(row_id == 0 && col_id < max_col_id)
return true;
}
else
{
// row_id/col_id is clamped from physical row/col according to contextual_seqlen and
// max_uih_len
row_id = min(row, max_id);
col_id = min(col, max_id);
row_id = min(row, max_row_id);
col_id = min(col, max_col_id);
};
// use row_id/col_id to check the dist between two q/k token pair, token pairs on the
@@ -195,7 +207,7 @@ struct HstuBlockMaskWithLocal
if constexpr(kUseCausal)
{
bool in_min_full_scope =
(min_full_attn_seqlen > 0) ? (row_id >= max_id - min_full_attn_seqlen) : false;
(min_full_attn_seqlen > 0) ? (row_id >= max_row_id - min_full_attn_seqlen) : false;
return (((row_id > col_id) || (row == col)) &&
((row_id - col_id <= max_attn_len) || in_min_full_scope));
@@ -203,7 +215,7 @@ struct HstuBlockMaskWithLocal
else
{
bool in_min_full_scope =
(min_full_attn_seqlen > 0) ? (row_id >= max_id - min_full_attn_seqlen) : false;
(min_full_attn_seqlen > 0) ? (row_id >= max_row_id - min_full_attn_seqlen) : false;
return (((row_id != col_id) || (row == col)) &&
((abs(row_id - col_id) <= max_attn_len) || in_min_full_scope));
@@ -222,18 +234,18 @@ struct HstuBlockMaskWithLocal
row_id = max(row - contextual_seqlen + 1, 0);
col_id = max(col - contextual_seqlen + 1, 0);
row_id = min(row_id, max_id);
col_id = min(col_id, max_id);
row_id = min(row_id, max_row_id);
col_id = min(col_id, max_col_id);
if(row_id == 0 && col_id < max_id)
if(row_id == 0 && col_id < max_col_id)
return true;
}
else
{
// row_id/col_id is clamped from physical row/col according to contextual_seqlen and
// max_uih_len
row_id = min(row, max_id);
col_id = min(col, max_id);
row_id = min(row, max_row_id);
col_id = min(col, max_col_id);
};
// use row_id/col_id to check the dist between two q/k token pair, token pairs on the
@@ -269,7 +281,7 @@ struct HstuBlockMaskWithLocal
{
index_t i_tile_right = i_tile_left + TileWidth;
if(!is_tile_in_first_split && i_tile_right <= min(i_tile_top + 1, max_uih_len))
if(!is_tile_in_first_split && i_tile_right <= min(i_tile_top + 1, max_k_uih_len))
return true;
}
else
@@ -277,11 +289,11 @@ struct HstuBlockMaskWithLocal
index_t i_tile_right = i_tile_left + TileWidth;
index_t i_tile_bottom = i_tile_top + TileHeight;
// 1) tile is completely in [max_uih_len-min_full_attn_seqlen, max_uih_len]
// 2) some row of tile is in [max_uih_len, seqlen], requires i_tile_right <= max_uih_len
// to return true
// 1) tile is completely in [max_q_uih_len-min_full_attn_seqlen, max_q_uih_len]
// 2) some row of tile is in [max_q_uih_len, seqlen_q], requires i_tile_right <=
// max_k_uih_len to return true
if(!is_tile_in_first_split &&
(i_tile_bottom <= max_uih_len || i_tile_right <= max_uih_len))
(i_tile_bottom <= max_q_uih_len || i_tile_right <= max_k_uih_len))
return true;
};
@@ -295,21 +307,32 @@ struct HstuBlockMaskNoLocal
static constexpr bool kUseLocal = false;
static constexpr bool IsMasking = kUseCausal;
int seqlen;
int seqlen_q;
int seqlen_k;
int contextual_seqlen;
int max_uih_len;
int max_id;
int max_q_uih_len;
int max_k_uih_len;
int max_row_id;
int max_col_id;
CK_TILE_HOST_DEVICE HstuBlockMaskNoLocal(int seqlen_, int contextual_seqlen_, int num_target_)
: seqlen(seqlen_), contextual_seqlen(contextual_seqlen_)
CK_TILE_HOST_DEVICE
HstuBlockMaskNoLocal(int seqlen_q_, int seqlen_k_, int contextual_seqlen_, int num_target_)
: seqlen_q(seqlen_q_), seqlen_k(seqlen_k_), contextual_seqlen(contextual_seqlen_)
{
max_uih_len = seqlen - num_target_;
max_q_uih_len = seqlen_q - num_target_;
max_k_uih_len = seqlen_k - num_target_;
if(contextual_seqlen > 0)
max_id = max_uih_len - (contextual_seqlen - 1);
{
max_row_id = max_q_uih_len - (contextual_seqlen - 1);
max_col_id = max_k_uih_len - (contextual_seqlen - 1);
}
else
max_id = max_uih_len;
{
max_row_id = max_q_uih_len;
max_col_id = max_k_uih_len;
}
};
// to get the loop length along X axis, return index:[start, end), end-start=length
@@ -321,21 +344,21 @@ struct HstuBlockMaskNoLocal
{
if constexpr(!IsMasking)
{
return ck_tile::make_tuple(0, seqlen);
return ck_tile::make_tuple(0, seqlen_k);
}
else
{
index_t x_end = min(i_y + YTile, seqlen);
index_t x_end = min(i_y + YTile, seqlen_k);
if(i_y < contextual_seqlen)
{
if(i_y + YTile > max_uih_len)
if(i_y + YTile > max_k_uih_len)
{
return ck_tile::make_tuple(0, x_end);
}
else
{
return ck_tile::make_tuple(0, max_uih_len);
return ck_tile::make_tuple(0, max_k_uih_len);
};
}
else
@@ -357,18 +380,18 @@ struct HstuBlockMaskNoLocal
row_id = max(row - contextual_seqlen + 1, 0);
col_id = max(col - contextual_seqlen + 1, 0);
row_id = min(row_id, max_id);
col_id = min(col_id, max_id);
row_id = min(row_id, max_row_id);
col_id = min(col_id, max_col_id);
if(row_id == 0 && col_id < max_id)
if(row_id == 0 && col_id < max_col_id)
return true;
}
else
{
// row_id/col_id is clamped from physical row/col according to contextual_seqlen and
// max_uih_len
row_id = min(row, max_id);
col_id = min(col, max_id);
row_id = min(row, max_row_id);
col_id = min(col, max_col_id);
};
// use row_id/col_id to check the dist between two q/k token pair, token pairs on the
@@ -395,18 +418,18 @@ struct HstuBlockMaskNoLocal
row_id = max(row - contextual_seqlen + 1, 0);
col_id = max(col - contextual_seqlen + 1, 0);
row_id = min(row_id, max_id);
col_id = min(col_id, max_id);
row_id = min(row_id, max_row_id);
col_id = min(col_id, max_col_id);
if(row_id == 0 && col_id < max_id)
if(row_id == 0 && col_id < max_col_id)
return true;
}
else
{
// row_id/col_id is clamped from physical row/col according to contextual_seqlen and
// max_uih_len
row_id = min(row, max_id);
col_id = min(col, max_id);
row_id = min(row, max_row_id);
col_id = min(col, max_col_id);
};
// use row_id/col_id to check the dist between two q/k token pair, token pairs on the
@@ -439,7 +462,7 @@ struct HstuBlockMaskNoLocal
// assume num_target > 0 with high probability, don't check whether num_target is 0;
// so if num_target is 0, IsTokenPairInsideMask() will be called for the bottom tile
if(i_tile_bottom >= max_uih_len || i_tile_right > i_tile_top)
if(i_tile_bottom >= max_q_uih_len || i_tile_right > i_tile_top)
return false;
return true;
@@ -451,7 +474,7 @@ struct HstuBlockMaskNoLocal
// assume num_target > 0 with high probability, don't check whether num_target is 0;
// so if num_target is 0, IsTokenPairInsideMask() will be called for the bottom tile
if(i_tile_bottom >= max_uih_len || i_tile_right >= max_uih_len)
if(i_tile_bottom >= max_q_uih_len || i_tile_right >= max_k_uih_len)
return false;
return true;
@@ -469,14 +492,16 @@ struct HstuBlockMasking
template <typename HstuBlockMaskType>
CK_TILE_HOST_DEVICE constexpr auto make_hstu_block_mask_with_local(bool is_tile_in_first_split_,
int seqlen_,
int seqlen_q_,
int seqlen_k_,
int contextual_seqlen_,
int num_target,
int max_attn_len_,
int min_full_attn_seqlen_)
{
return HstuBlockMaskType{is_tile_in_first_split_,
seqlen_,
seqlen_q_,
seqlen_k_,
contextual_seqlen_,
max_attn_len_,
min_full_attn_seqlen_,
@@ -484,10 +509,12 @@ CK_TILE_HOST_DEVICE constexpr auto make_hstu_block_mask_with_local(bool is_tile_
};
template <typename HstuBlockMaskType>
CK_TILE_HOST_DEVICE constexpr auto
make_hstu_block_mask_without_local(int seqlen_, int contextual_seqlen_, int num_target)
CK_TILE_HOST_DEVICE constexpr auto make_hstu_block_mask_without_local(int seqlen_q_,
int seqlen_k_,
int contextual_seqlen_,
int num_target)
{
return HstuBlockMaskType{seqlen_, contextual_seqlen_, num_target};
return HstuBlockMaskType{seqlen_q_, seqlen_k_, contextual_seqlen_, num_target};
};
} // namespace ck_tile

View File

@@ -42,7 +42,8 @@ struct reference_hstu_attention
float alpha,
float attn_scale,
int max_seqlen,
std::vector<int> seq_offsets,
std::vector<int> seq_q_offsets,
std::vector<int> seq_kv_offsets,
std::vector<int> num_targets, // define masking length at the end of token
// sequence to be excluded for attention
int contextual_seqlen, // define masking length at the begin of query token
@@ -54,7 +55,8 @@ struct reference_hstu_attention
if constexpr(kIsJagged)
{
// check the number of batches
assert(!seq_offsets.empty() && seq_offsets.size() == num_batch + 1);
assert(!seq_q_offsets.empty() && seq_q_offsets.size() == num_batch + 1);
assert(!seq_kv_offsets.empty() && seq_kv_offsets.size() == num_batch + 1);
assert(q_batch_seq_nhead_hdim.get_lengths()[0] == 1);
assert(k_batch_seq_nhead_hdim.get_lengths()[0] == 1);
assert(v_batch_seq_nhead_hdim.get_lengths()[0] == 1);
@@ -62,7 +64,8 @@ struct reference_hstu_attention
}
else
{
assert(seq_offsets.empty());
assert(seq_q_offsets.empty());
assert(seq_kv_offsets.empty());
assert(q_batch_seq_nhead_hdim.get_lengths()[0] == num_batch);
assert(k_batch_seq_nhead_hdim.get_lengths()[0] == num_batch);
assert(v_batch_seq_nhead_hdim.get_lengths()[0] == num_batch);
@@ -104,8 +107,10 @@ struct reference_hstu_attention
};
auto f = [&](auto i_batch, auto i_head) {
int seqlen = kIsJagged ? (seq_offsets[i_batch + 1] - seq_offsets[i_batch])
: q_batch_seq_nhead_hdim.get_lengths()[1];
int seqlen_q = kIsJagged ? (seq_q_offsets[i_batch + 1] - seq_q_offsets[i_batch])
: q_batch_seq_nhead_hdim.get_lengths()[1];
int seqlen_kv = kIsJagged ? (seq_kv_offsets[i_batch + 1] - seq_kv_offsets[i_batch])
: k_batch_seq_nhead_hdim.get_lengths()[1];
int num_target = num_targets.empty() ? 0 : num_targets[i_batch];
@@ -118,10 +123,11 @@ struct reference_hstu_attention
if constexpr(kHasLocal)
// need adjust the min_full_attn_seqlen passed to the HstuBlockMask() if the
// user passed min_full_attn_seqlen is bigger than max_uih_len
if(seqlen - num_target > min_full_attn_seqlen)
if(seqlen_q - num_target > min_full_attn_seqlen)
return ck_tile::make_hstu_block_mask_with_local<HstuMaskType>(
true,
seqlen,
seqlen_q,
seqlen_kv,
contextual_seqlen,
num_target,
window_size,
@@ -129,14 +135,15 @@ struct reference_hstu_attention
else
return ck_tile::make_hstu_block_mask_with_local<HstuMaskType>(
true,
seqlen,
seqlen_q,
seqlen_kv,
contextual_seqlen,
num_target,
window_size,
seqlen - num_target);
seqlen_q - num_target);
else
return ck_tile::make_hstu_block_mask_without_local<HstuMaskType>(
seqlen, contextual_seqlen, num_target);
seqlen_q, seqlen_kv, contextual_seqlen, num_target);
}();
if(save_mask)
@@ -149,7 +156,7 @@ struct reference_hstu_attention
}
// for all rows in the batch
for(int sq = 0; sq < seqlen; sq++)
for(int sq = 0; sq < seqlen_q; sq++)
{
CompDataType m =
-ck_tile::numeric<CompDataType>::infinity(); // max value of the row
@@ -159,7 +166,7 @@ struct reference_hstu_attention
std::vector<CompDataType> locals;
// for all cols in the batch
for(int sk = 0; sk < seqlen; sk++)
for(int sk = 0; sk < seqlen_kv; sk++)
{
if(mask.IsTokenPairInsideMask(sq, sk))
{
@@ -169,9 +176,9 @@ struct reference_hstu_attention
if constexpr(kIsJagged)
{
InOutDataType qreg = q_batch_seq_nhead_hdim(
0, seq_offsets[i_batch] + sq, i_head, k);
0, seq_q_offsets[i_batch] + sq, i_head, k);
InOutDataType kreg = k_batch_seq_nhead_hdim(
0, seq_offsets[i_batch] + sk, i_head, k);
0, seq_kv_offsets[i_batch] + sk, i_head, k);
dot_prod += ck_tile::type_convert<GemmAccDataType>(qreg) *
ck_tile::type_convert<GemmAccDataType>(kreg);
@@ -233,14 +240,14 @@ struct reference_hstu_attention
{
GemmAccDataType dot_prod = 0.f;
for(int sk = 0; sk < seqlen; sk++)
for(int sk = 0; sk < seqlen_kv; sk++)
{
if constexpr(kIsJagged)
{
InOutDataType preg =
ck_tile::type_convert<InOutDataType>(locals[sk]);
InOutDataType vreg =
v_batch_seq_nhead_hdim(0, seq_offsets[i_batch] + sk, i_head, k);
InOutDataType vreg = v_batch_seq_nhead_hdim(
0, seq_kv_offsets[i_batch] + sk, i_head, k);
dot_prod += ck_tile::type_convert<GemmAccDataType>(preg) *
ck_tile::type_convert<GemmAccDataType>(vreg);
@@ -257,7 +264,7 @@ struct reference_hstu_attention
};
if constexpr(kIsJagged)
o_batch_seq_nhead_hdim(0, seq_offsets[i_batch] + sq, i_head, k) =
o_batch_seq_nhead_hdim(0, seq_q_offsets[i_batch] + sq, i_head, k) =
ck_tile::type_convert<InOutDataType>(dot_prod);
else
o_batch_seq_nhead_hdim(i_batch, sq, i_head, k) =