mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-16 19:09:59 +00:00
Support separate sequence lengths for q and kv
This commit is contained in:
@@ -100,7 +100,8 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("nhead", "4", "number of heads")
|
||||
.insert("hdim_qk", "64", "headdim size of Q/K")
|
||||
.insert("hdim_v", "64", "headdim size of V/O")
|
||||
.insert("seqlens", "400", "uih seqlen of single or all batches for query and key/value tensor, actually allocated seqlen will include the target of each batch and context_len")
|
||||
.insert("seqlens", "400", "uih seqlen of single or all batches for query tensor, actually allocated seqlen will include the target of each batch and context_len")
|
||||
.insert("seqlens_kv", "", "uih seqlen of single or all batches for key/value tensor, actually allocated seqlen will include the target of each batch and context_len")
|
||||
.insert("max_seqlen", "0", "max uih_seqlen, can be ignored, or else must be equal or bigger than the maximum of all uih seqlens")
|
||||
.insert("targets", "", "sequence length at the end of query/key token sequence that should be excluded from attention")
|
||||
.insert("max_target", "0", "max target, can be ignored, or else must be equal of bigger than the maximum of all targets")
|
||||
@@ -238,13 +239,15 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
std::string str_of_targets = arg_parser.get_str("targets");
|
||||
std::vector<int> num_targets = get_integers_from_string(str_of_targets);
|
||||
|
||||
std::string str_of_lengths = arg_parser.get_str("seqlens");
|
||||
std::vector<int> seq_lengths = get_integers_from_string(str_of_lengths);
|
||||
std::string str_of_lengths_q = arg_parser.get_str("seqlens");
|
||||
std::vector<int> seq_lengths_q = get_integers_from_string(str_of_lengths_q);
|
||||
|
||||
std::string str_of_lengths_kv = arg_parser.get_str("seqlens_kv");
|
||||
std::vector<int> seq_lengths_kv = get_integers_from_string(str_of_lengths_kv);
|
||||
|
||||
int input_max_uih_seqlen = arg_parser.get_int("max_seqlen");
|
||||
int input_max_target = arg_parser.get_int("max_target");
|
||||
|
||||
int uih_seqlen = 0; // means total seq lengths for jagged
|
||||
int max_uih_seqlen = 0;
|
||||
int max_target = 0;
|
||||
|
||||
@@ -264,31 +267,43 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
max_target = max(max_target, num_targets[i]);
|
||||
};
|
||||
|
||||
HSTU_CHECK(!seq_lengths.empty(), "sequence lengths shoud be defined!");
|
||||
HSTU_CHECK(!seq_lengths_q.empty(), "sequence lengths of q shoud be defined!");
|
||||
|
||||
// assume seq_lengths_kv is same as seq_lengths_q if not defined
|
||||
if(seq_lengths_kv.empty())
|
||||
seq_lengths_kv = seq_lengths_q;
|
||||
|
||||
if(is_jagged)
|
||||
{
|
||||
// supplement seq_lengths_q using the last input value if the user-provided lengths are not enough
|
||||
if(static_cast<int>(seq_lengths.size()) < num_batch)
|
||||
if(static_cast<int>(seq_lengths_q.size()) < num_batch)
|
||||
{
|
||||
auto last_len = seq_lengths.back();
|
||||
auto last_len = seq_lengths_q.back();
|
||||
|
||||
for(int i = seq_lengths.size(); i < num_batch; i++)
|
||||
seq_lengths.push_back(last_len);
|
||||
for(int i = seq_lengths_q.size(); i < num_batch; i++)
|
||||
seq_lengths_q.push_back(last_len);
|
||||
};
|
||||
|
||||
// supplement seq_lengths_kv using the last input value if the user-provided lengths are not enough
|
||||
if(static_cast<int>(seq_lengths_kv.size()) < num_batch)
|
||||
{
|
||||
auto last_len = seq_lengths_kv.back();
|
||||
|
||||
for(int i = seq_lengths_kv.size(); i < num_batch; i++)
|
||||
seq_lengths_kv.push_back(last_len);
|
||||
};
|
||||
|
||||
// only consider num_batch values even if more values are provided by the user
|
||||
for(int i = 0; i < num_batch; i++)
|
||||
{
|
||||
max_uih_seqlen = max(max_uih_seqlen, seq_lengths[i]);
|
||||
max_uih_seqlen = max(max_uih_seqlen, seq_lengths_q[i]);
|
||||
};
|
||||
}
|
||||
else
|
||||
{
|
||||
HSTU_CHECK(1 == seq_lengths.size(),
|
||||
HSTU_CHECK(1 == seq_lengths_q.size() && 1 == seq_lengths_kv.size(),
|
||||
"sequence lengths for batched mode shoud have single element!");
|
||||
uih_seqlen = seq_lengths[0];
|
||||
max_uih_seqlen = uih_seqlen;
|
||||
max_uih_seqlen = max(seq_lengths_q[0], seq_lengths_kv[0]);
|
||||
};
|
||||
|
||||
// the user input of max_uih_seqlen can either be ignored or be bigger than all uih_seqlens
|
||||
@@ -304,28 +319,43 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
max_uih_seqlen = (input_max_uih_seqlen > 0) ? input_max_uih_seqlen : max_uih_seqlen;
|
||||
max_target = (input_max_target > 0) ? input_max_target : max_target;
|
||||
|
||||
int phy_seqlen = 0;
|
||||
int max_seqlen = max_uih_seqlen + max_target + contextual_seqlen;
|
||||
int phy_seqlen_q = 0;
|
||||
int phy_seqlen_kv = 0;
|
||||
int max_seqlen = max_uih_seqlen + max_target + contextual_seqlen;
|
||||
|
||||
std::vector<int> seq_offsets;
|
||||
std::vector<int> seq_offsets_q;
|
||||
std::vector<int> seq_offsets_kv;
|
||||
|
||||
if(is_jagged)
|
||||
{
|
||||
seq_offsets.push_back(0);
|
||||
seq_offsets_q.push_back(0);
|
||||
|
||||
for(int i = 0; i < num_batch; i++)
|
||||
{
|
||||
int batch_seqlen = num_targets.empty()
|
||||
? seq_lengths[i] + contextual_seqlen
|
||||
: seq_lengths[i] + num_targets[i] + contextual_seqlen;
|
||||
? seq_lengths_q[i] + contextual_seqlen
|
||||
: seq_lengths_q[i] + num_targets[i] + contextual_seqlen;
|
||||
|
||||
phy_seqlen += batch_seqlen;
|
||||
seq_offsets.push_back(phy_seqlen);
|
||||
phy_seqlen_q += batch_seqlen;
|
||||
seq_offsets_q.push_back(phy_seqlen_q);
|
||||
};
|
||||
|
||||
seq_offsets_kv.push_back(0);
|
||||
|
||||
for(int i = 0; i < num_batch; i++)
|
||||
{
|
||||
int batch_seqlen = num_targets.empty()
|
||||
? seq_lengths_kv[i] + contextual_seqlen
|
||||
: seq_lengths_kv[i] + num_targets[i] + contextual_seqlen;
|
||||
|
||||
phy_seqlen_kv += batch_seqlen;
|
||||
seq_offsets_kv.push_back(phy_seqlen_kv);
|
||||
};
|
||||
}
|
||||
else
|
||||
{
|
||||
phy_seqlen = max_seqlen;
|
||||
phy_seqlen_q = max_seqlen;
|
||||
phy_seqlen_kv = max_seqlen;
|
||||
};
|
||||
|
||||
long total_flops = 0;
|
||||
@@ -335,10 +365,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
for(int i = 0; i < num_batch; i++)
|
||||
{
|
||||
int len = seq_offsets[i + 1] - seq_offsets[i];
|
||||
total_flops +=
|
||||
(static_cast<long>(len) * len * hdim_qk + static_cast<long>(len) * hdim_v * len) *
|
||||
2;
|
||||
int len_q = seq_offsets_q[i + 1] - seq_offsets_q[i];
|
||||
int len_kv = seq_offsets_kv[i + 1] - seq_offsets_kv[i];
|
||||
total_flops += (static_cast<long>(len_q) * len_kv * hdim_qk +
|
||||
static_cast<long>(len_q) * hdim_v * len_kv) *
|
||||
2;
|
||||
};
|
||||
|
||||
total_flops *= num_head;
|
||||
@@ -346,21 +377,21 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
else
|
||||
{
|
||||
total_flops = static_cast<long>(num_batch) * num_head *
|
||||
(static_cast<long>(phy_seqlen) * phy_seqlen * hdim_qk +
|
||||
static_cast<long>(phy_seqlen) * hdim_v * phy_seqlen) *
|
||||
(static_cast<long>(phy_seqlen_q) * phy_seqlen_kv * hdim_qk +
|
||||
static_cast<long>(phy_seqlen_q) * hdim_v * phy_seqlen_kv) *
|
||||
2;
|
||||
};
|
||||
|
||||
int batches_for_alloc = is_jagged ? 1 : num_batch;
|
||||
|
||||
ck_tile::HostTensor<InOutDataType> q_host(
|
||||
std::array<ck_tile::index_t, 4>{batches_for_alloc, phy_seqlen, num_head, hdim_qk});
|
||||
std::array<ck_tile::index_t, 4>{batches_for_alloc, phy_seqlen_q, num_head, hdim_qk});
|
||||
ck_tile::HostTensor<InOutDataType> k_host(
|
||||
std::array<ck_tile::index_t, 4>{batches_for_alloc, phy_seqlen, num_head, hdim_qk});
|
||||
std::array<ck_tile::index_t, 4>{batches_for_alloc, phy_seqlen_kv, num_head, hdim_qk});
|
||||
ck_tile::HostTensor<InOutDataType> v_host(
|
||||
std::array<ck_tile::index_t, 4>{batches_for_alloc, phy_seqlen, num_head, hdim_v});
|
||||
std::array<ck_tile::index_t, 4>{batches_for_alloc, phy_seqlen_kv, num_head, hdim_v});
|
||||
ck_tile::HostTensor<InOutDataType> o_host_ref(
|
||||
std::array<ck_tile::index_t, 4>{batches_for_alloc, phy_seqlen, num_head, hdim_v});
|
||||
std::array<ck_tile::index_t, 4>{batches_for_alloc, phy_seqlen_q, num_head, hdim_v});
|
||||
|
||||
ck_tile::HostTensor<int8_t> mask_host(
|
||||
save_mask ? std::array<ck_tile::index_t, 4>{num_batch, num_head, max_seqlen, max_seqlen}
|
||||
@@ -393,7 +424,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::DeviceMem v_dev(v_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem o_dev(o_host_ref.get_element_space_size_in_bytes());
|
||||
|
||||
ck_tile::DeviceMem seq_offsets_dev(seq_offsets.size() * sizeof(int));
|
||||
ck_tile::DeviceMem seq_offsets_q_dev(seq_offsets_q.size() * sizeof(int));
|
||||
ck_tile::DeviceMem seq_offsets_kv_dev(seq_offsets_kv.size() * sizeof(int));
|
||||
ck_tile::DeviceMem num_targets_dev(num_targets.size() * sizeof(int));
|
||||
|
||||
q_dev.ToDevice(q_host.data());
|
||||
@@ -401,7 +433,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
v_dev.ToDevice(v_host.data());
|
||||
|
||||
if(is_jagged)
|
||||
seq_offsets_dev.ToDevice(seq_offsets.data());
|
||||
{
|
||||
seq_offsets_q_dev.ToDevice(seq_offsets_q.data());
|
||||
seq_offsets_kv_dev.ToDevice(seq_offsets_kv.data());
|
||||
};
|
||||
if(!num_targets.empty())
|
||||
num_targets_dev.ToDevice(num_targets.data());
|
||||
|
||||
@@ -411,30 +446,31 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
if(is_jagged)
|
||||
{
|
||||
params.is_jagged = true;
|
||||
params.num_batch = num_batch;
|
||||
params.seq_offsets_ptr = seq_offsets_dev.GetDeviceBuffer();
|
||||
params.max_seqlen = max_seqlen;
|
||||
params.q_ptr = q_dev.GetDeviceBuffer();
|
||||
params.k_ptr = k_dev.GetDeviceBuffer();
|
||||
params.v_ptr = v_dev.GetDeviceBuffer();
|
||||
params.bias_ptr = nullptr; // bias is not supported at present
|
||||
params.o_ptr = o_dev.GetDeviceBuffer();
|
||||
params.hdim_qk = hdim_qk;
|
||||
params.hdim_v = hdim_v;
|
||||
params.num_head = num_head;
|
||||
params.scale_s = scale_s;
|
||||
params.attn_scale = attn_scale;
|
||||
params.seq_stride_q = q_host.get_strides()[1];
|
||||
params.seq_stride_k = k_host.get_strides()[1];
|
||||
params.seq_stride_v = v_host.get_strides()[1];
|
||||
params.seq_stride_bias = 0;
|
||||
params.seq_stride_o = o_host_ref.get_strides()[1];
|
||||
params.nhead_stride_q = q_host.get_strides()[2];
|
||||
params.nhead_stride_k = k_host.get_strides()[2];
|
||||
params.nhead_stride_v = v_host.get_strides()[2];
|
||||
params.nhead_stride_bias = 0;
|
||||
params.nhead_stride_o = o_host_ref.get_strides()[2];
|
||||
params.is_jagged = true;
|
||||
params.num_batch = num_batch;
|
||||
params.seq_q_offsets_ptr = seq_offsets_q_dev.GetDeviceBuffer();
|
||||
params.seq_kv_offsets_ptr = seq_offsets_kv_dev.GetDeviceBuffer();
|
||||
params.max_seqlen = max_seqlen;
|
||||
params.q_ptr = q_dev.GetDeviceBuffer();
|
||||
params.k_ptr = k_dev.GetDeviceBuffer();
|
||||
params.v_ptr = v_dev.GetDeviceBuffer();
|
||||
params.bias_ptr = nullptr; // bias is not supported at present
|
||||
params.o_ptr = o_dev.GetDeviceBuffer();
|
||||
params.hdim_qk = hdim_qk;
|
||||
params.hdim_v = hdim_v;
|
||||
params.num_head = num_head;
|
||||
params.scale_s = scale_s;
|
||||
params.attn_scale = attn_scale;
|
||||
params.seq_stride_q = q_host.get_strides()[1];
|
||||
params.seq_stride_k = k_host.get_strides()[1];
|
||||
params.seq_stride_v = v_host.get_strides()[1];
|
||||
params.seq_stride_bias = 0;
|
||||
params.seq_stride_o = o_host_ref.get_strides()[1];
|
||||
params.nhead_stride_q = q_host.get_strides()[2];
|
||||
params.nhead_stride_k = k_host.get_strides()[2];
|
||||
params.nhead_stride_v = v_host.get_strides()[2];
|
||||
params.nhead_stride_bias = 0;
|
||||
params.nhead_stride_o = o_host_ref.get_strides()[2];
|
||||
params.num_targets_ptr = num_targets.empty() ? nullptr : num_targets_dev.GetDeviceBuffer();
|
||||
params.use_softmax = use_softmax;
|
||||
params.use_causal = use_causal;
|
||||
@@ -449,7 +485,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
params.is_jagged = false;
|
||||
params.num_batch = num_batch;
|
||||
params.seqlen = max_seqlen;
|
||||
params.seqlen_q = phy_seqlen_q;
|
||||
params.seqlen_kv = phy_seqlen_kv;
|
||||
params.q_ptr = q_dev.GetDeviceBuffer();
|
||||
params.k_ptr = k_dev.GetDeviceBuffer();
|
||||
params.v_ptr = v_dev.GetDeviceBuffer();
|
||||
@@ -532,7 +569,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
scale_s,
|
||||
attn_scale,
|
||||
max_seqlen,
|
||||
seq_offsets,
|
||||
seq_offsets_q,
|
||||
seq_offsets_kv,
|
||||
num_targets,
|
||||
contextual_seqlen,
|
||||
window_size,
|
||||
@@ -540,7 +578,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
});
|
||||
|
||||
ck_tile::HostTensor<InOutDataType> o_host(
|
||||
std::array<ck_tile::index_t, 4>{batches_for_alloc, phy_seqlen, num_head, hdim_v});
|
||||
std::array<ck_tile::index_t, 4>{batches_for_alloc, phy_seqlen_q, num_head, hdim_v});
|
||||
|
||||
o_dev.FromDevice(o_host.data());
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ struct batched_forward_causal_softmax_bias_dropout_dispatch
|
||||
{
|
||||
constexpr ck_tile::index_t occupancy = -1;
|
||||
|
||||
const bool pad_seqlen_k = !(param.seqlen % HstuAttentionTileSetting::kN0 == 0);
|
||||
const bool pad_seqlen_k = !(param.seqlen_kv % HstuAttentionTileSetting::kN0 == 0);
|
||||
const bool pad_headdim_qk = !(param.hdim_qk % HstuAttentionTileSetting::kSubQKHeaddim == 0);
|
||||
const bool pad_headdim_v = !(param.hdim_v % HstuAttentionTileSetting::kN1 == 0);
|
||||
|
||||
@@ -125,7 +125,8 @@ struct batched_forward_causal_softmax_bias_dropout_dispatch
|
||||
param.v_ptr,
|
||||
param.bias_ptr,
|
||||
param.o_ptr,
|
||||
param.seqlen,
|
||||
param.seqlen_q,
|
||||
param.seqlen_kv,
|
||||
param.hdim_qk,
|
||||
param.hdim_v,
|
||||
param.num_head,
|
||||
@@ -157,7 +158,7 @@ struct batched_forward_causal_softmax_bias_dropout_dispatch
|
||||
|
||||
bool has_minfull_attn_seqlen = (param.min_full_attn_seqlen > 0);
|
||||
dim3 kGridSize = HstuKernel::GridSize(
|
||||
param.num_batch, param.num_head, param.seqlen, param.hdim_v, has_minfull_attn_seqlen);
|
||||
param.num_batch, param.num_head, param.seqlen_q, param.hdim_v, has_minfull_attn_seqlen);
|
||||
constexpr dim3 kBlockSize = HstuKernel::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu;
|
||||
|
||||
|
||||
@@ -78,7 +78,8 @@ struct HstuAttentionFwdKernel
|
||||
ck_tile::index_t nhead_stride_v;
|
||||
ck_tile::index_t nhead_stride_o;
|
||||
|
||||
ck_tile::index_t seqlen;
|
||||
ck_tile::index_t seqlen_q;
|
||||
ck_tile::index_t seqlen_kv;
|
||||
ck_tile::index_t hdim_qk;
|
||||
ck_tile::index_t hdim_v;
|
||||
|
||||
@@ -98,7 +99,8 @@ struct HstuAttentionFwdKernel
|
||||
|
||||
struct HstuAttentionFwdJaggModeBaseKargs
|
||||
{
|
||||
const int32_t* seq_offsets_ptr;
|
||||
const int32_t* seq_q_offsets_ptr;
|
||||
const int32_t* seq_kv_offsets_ptr;
|
||||
|
||||
ck_tile::index_t seq_stride_q;
|
||||
ck_tile::index_t seq_stride_k;
|
||||
@@ -120,7 +122,8 @@ struct HstuAttentionFwdKernel
|
||||
ck_tile::index_t hdim_qk;
|
||||
ck_tile::index_t hdim_v;
|
||||
|
||||
ck_tile::index_t seqlen;
|
||||
ck_tile::index_t seqlen_q;
|
||||
ck_tile::index_t seqlen_kv;
|
||||
|
||||
ck_tile::index_t num_head;
|
||||
float scale_s; // scaling value exerted on the immediate Q@K result
|
||||
@@ -196,7 +199,8 @@ struct HstuAttentionFwdKernel
|
||||
const void* v_ptr,
|
||||
const void* bias_ptr,
|
||||
void* o_ptr,
|
||||
ck_tile::index_t seqlen,
|
||||
ck_tile::index_t seqlen_q,
|
||||
ck_tile::index_t seqlen_kv,
|
||||
ck_tile::index_t hdim_qk,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head,
|
||||
@@ -239,7 +243,8 @@ struct HstuAttentionFwdKernel
|
||||
nhead_stride_k,
|
||||
nhead_stride_v,
|
||||
nhead_stride_o,
|
||||
seqlen,
|
||||
seqlen_q,
|
||||
seqlen_kv,
|
||||
hdim_qk,
|
||||
hdim_v,
|
||||
seq_stride_q,
|
||||
@@ -248,7 +253,8 @@ struct HstuAttentionFwdKernel
|
||||
seq_stride_o,
|
||||
num_head,
|
||||
scale_s,
|
||||
attn_scale ? attn_scale : 1.0f / static_cast<float>(seqlen), // max_seqlen
|
||||
attn_scale ? attn_scale
|
||||
: 1.0f / static_cast<float>(max(seqlen_q, seqlen_kv)), // max_seqlen
|
||||
contextual_seqlen,
|
||||
window_size,
|
||||
min_full_attn_seqlen}, // args for common karg
|
||||
@@ -278,7 +284,8 @@ struct HstuAttentionFwdKernel
|
||||
const void* v_ptr,
|
||||
const void* bias_ptr,
|
||||
void* o_ptr,
|
||||
const void* seq_offsets_ptr,
|
||||
const void* seq_q_offsets_ptr,
|
||||
const void* seq_kv_offsets_ptr,
|
||||
ck_tile::index_t max_seqlen,
|
||||
ck_tile::index_t hdim_qk,
|
||||
ck_tile::index_t hdim_v,
|
||||
@@ -304,7 +311,8 @@ struct HstuAttentionFwdKernel
|
||||
uint64_t philox_offset)
|
||||
{
|
||||
Kargs kargs{
|
||||
{reinterpret_cast<const int32_t*>(seq_offsets_ptr),
|
||||
{reinterpret_cast<const int32_t*>(seq_q_offsets_ptr),
|
||||
reinterpret_cast<const int32_t*>(seq_kv_offsets_ptr),
|
||||
seq_stride_q,
|
||||
seq_stride_k,
|
||||
seq_stride_v,
|
||||
@@ -320,7 +328,8 @@ struct HstuAttentionFwdKernel
|
||||
nhead_stride_o,
|
||||
hdim_qk,
|
||||
hdim_v,
|
||||
-1, // seqlen will be updated by another pointer
|
||||
-1, // seqlen_q will be updated by another pointer
|
||||
-1, // seqlen_kv will be updated by another pointer
|
||||
num_head,
|
||||
scale_s,
|
||||
attn_scale ? attn_scale : 1.0f / static_cast<float>(max_seqlen),
|
||||
@@ -465,8 +474,8 @@ struct HstuAttentionFwdKernel
|
||||
if constexpr(kIsJagged)
|
||||
{
|
||||
// get starting offset for each batch
|
||||
const long_index_t query_start = kargs.seq_offsets_ptr[i_batch];
|
||||
const long_index_t key_start = query_start;
|
||||
const long_index_t query_start = kargs.seq_q_offsets_ptr[i_batch];
|
||||
const long_index_t key_start = kargs.seq_kv_offsets_ptr[i_batch];
|
||||
|
||||
batch_offset_q = query_start * kargs.seq_stride_q;
|
||||
batch_offset_k = key_start * kargs.seq_stride_k;
|
||||
@@ -478,7 +487,10 @@ struct HstuAttentionFwdKernel
|
||||
}
|
||||
batch_offset_o = query_start * kargs.seq_stride_o;
|
||||
|
||||
kargs.seqlen = kargs.seq_offsets_ptr[i_batch + 1] - kargs.seq_offsets_ptr[i_batch];
|
||||
kargs.seqlen_q =
|
||||
kargs.seq_q_offsets_ptr[i_batch + 1] - kargs.seq_q_offsets_ptr[i_batch];
|
||||
kargs.seqlen_kv =
|
||||
kargs.seq_kv_offsets_ptr[i_batch + 1] - kargs.seq_kv_offsets_ptr[i_batch];
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -494,16 +506,16 @@ struct HstuAttentionFwdKernel
|
||||
|
||||
int num_target = (kargs.num_targets_ptr == nullptr) ? 0 : kargs.num_targets_ptr[i_batch];
|
||||
|
||||
index_t seqlen_in_first_split = kargs.seqlen;
|
||||
index_t seqlen_in_first_split = kargs.seqlen_q;
|
||||
bool is_tile_in_first_split = true;
|
||||
index_t i_m0;
|
||||
|
||||
if(kargs.min_full_attn_seqlen > 0)
|
||||
{
|
||||
// need to consider cases where min_full_attn_seqlen is bigger than max_uih_len
|
||||
if(kargs.seqlen - num_target > kargs.min_full_attn_seqlen)
|
||||
if(kargs.seqlen_q - num_target > kargs.min_full_attn_seqlen)
|
||||
{
|
||||
seqlen_in_first_split = kargs.seqlen - num_target - kargs.min_full_attn_seqlen;
|
||||
seqlen_in_first_split = kargs.seqlen_q - num_target - kargs.min_full_attn_seqlen;
|
||||
|
||||
index_t num_tile_in_first_split =
|
||||
ck_tile::integer_divide_ceil(seqlen_in_first_split, HstuAttentionPipeline::kM0);
|
||||
@@ -522,7 +534,7 @@ struct HstuAttentionFwdKernel
|
||||
is_tile_in_first_split = false;
|
||||
|
||||
// adjust the min_full_attn_seqlen to be passed to HstuBlockMask constructor
|
||||
kargs.min_full_attn_seqlen = kargs.seqlen - num_target;
|
||||
kargs.min_full_attn_seqlen = kargs.seqlen_q - num_target;
|
||||
|
||||
i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * HstuAttentionPipeline::kM0);
|
||||
};
|
||||
@@ -532,7 +544,7 @@ struct HstuAttentionFwdKernel
|
||||
|
||||
const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * HstuAttentionPipeline::kN1);
|
||||
|
||||
index_t seqlen_q_in_ctrl = is_tile_in_first_split ? seqlen_in_first_split : kargs.seqlen;
|
||||
index_t seqlen_q_in_ctrl = is_tile_in_first_split ? seqlen_in_first_split : kargs.seqlen_q;
|
||||
|
||||
if(seqlen_q_in_ctrl <= i_m0)
|
||||
return;
|
||||
@@ -567,7 +579,7 @@ struct HstuAttentionFwdKernel
|
||||
const auto k_dram = [&]() {
|
||||
const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
k_ptr,
|
||||
make_tuple(kargs.seqlen, kargs.hdim_qk),
|
||||
make_tuple(kargs.seqlen_kv, kargs.hdim_qk),
|
||||
make_tuple(kargs.seq_stride_k, 1),
|
||||
number<HstuAttentionPipeline::kAlignmentK>{},
|
||||
number<1>{});
|
||||
@@ -580,7 +592,7 @@ struct HstuAttentionFwdKernel
|
||||
const auto v_dram = [&]() {
|
||||
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
v_ptr,
|
||||
make_tuple(kargs.seqlen, kargs.hdim_v),
|
||||
make_tuple(kargs.seqlen_kv, kargs.hdim_v),
|
||||
make_tuple(kargs.seq_stride_v, 1),
|
||||
number<HstuAttentionPipeline::kAlignmentV>{},
|
||||
number<1>{});
|
||||
@@ -590,7 +602,7 @@ struct HstuAttentionFwdKernel
|
||||
const auto v_dram_transposed =
|
||||
transform_tensor_view(v_dram_naive,
|
||||
make_tuple(make_pass_through_transform(kargs.hdim_v),
|
||||
make_pass_through_transform(kargs.seqlen)),
|
||||
make_pass_through_transform(kargs.seqlen_kv)),
|
||||
make_tuple(sequence<1>{}, sequence<0>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
@@ -641,7 +653,7 @@ struct HstuAttentionFwdKernel
|
||||
const auto bias_dram = [&]() {
|
||||
const auto bias_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
bias_ptr,
|
||||
make_tuple(seqlen_q_in_ctrl, kargs.seqlen),
|
||||
make_tuple(seqlen_q_in_ctrl, kargs.seqlen_kv),
|
||||
make_tuple(kargs.seq_stride_bias, 1),
|
||||
number<HstuAttentionPipeline::kAlignmentBias>{},
|
||||
number<1>{});
|
||||
@@ -682,7 +694,8 @@ struct HstuAttentionFwdKernel
|
||||
using HstuMaskType = typename ck_tile::HstuBlockMasking<kHasCausalMask, true>::Type;
|
||||
const auto mask =
|
||||
make_hstu_block_mask_with_local<HstuMaskType>(is_tile_in_first_split,
|
||||
kargs.seqlen,
|
||||
kargs.seqlen_q,
|
||||
kargs.seqlen_kv,
|
||||
kargs.contextual_seqlen,
|
||||
num_target,
|
||||
kargs.window_size,
|
||||
@@ -703,7 +716,7 @@ struct HstuAttentionFwdKernel
|
||||
using HstuMaskType =
|
||||
typename ck_tile::HstuBlockMasking<kHasCausalMask, false>::Type;
|
||||
const auto mask = make_hstu_block_mask_without_local<HstuMaskType>(
|
||||
kargs.seqlen, kargs.contextual_seqlen, num_target);
|
||||
kargs.seqlen_q, kargs.seqlen_kv, kargs.contextual_seqlen, num_target);
|
||||
|
||||
return HstuAttentionPipeline{}(q_dram_window,
|
||||
k_dram_window,
|
||||
|
||||
@@ -118,7 +118,8 @@ struct jagged_forward_causal_softmax_bias_dropout_dispatch
|
||||
param.v_ptr,
|
||||
param.bias_ptr,
|
||||
param.o_ptr,
|
||||
param.seq_offsets_ptr,
|
||||
param.seq_q_offsets_ptr,
|
||||
param.seq_kv_offsets_ptr,
|
||||
param.max_seqlen,
|
||||
param.hdim_qk,
|
||||
param.hdim_v,
|
||||
|
||||
@@ -10,9 +10,11 @@ struct HstuAttentionFwdParams
|
||||
bool is_jagged;
|
||||
|
||||
ck_tile::index_t num_batch;
|
||||
ck_tile::index_t seqlen; // batched mode only
|
||||
const void* seq_offsets_ptr; // jagged mode only
|
||||
ck_tile::index_t max_seqlen; // jagged mode only
|
||||
ck_tile::index_t seqlen_q; // batched mode only
|
||||
ck_tile::index_t seqlen_kv; // batched mode only
|
||||
const void* seq_q_offsets_ptr; // jagged mode only
|
||||
const void* seq_kv_offsets_ptr; // jagged mode only
|
||||
ck_tile::index_t max_seqlen; // jagged mode only
|
||||
|
||||
const void* q_ptr;
|
||||
const void* k_ptr;
|
||||
|
||||
@@ -15,42 +15,54 @@ struct HstuBlockMaskWithLocal
|
||||
static constexpr bool IsMasking = true;
|
||||
|
||||
// is_tile_in_first_split is false only when min_full_attn_seqlen > 0 and the current
|
||||
// tile is inside scope [max_uih_len - min_full_attn_seqlen, seqlen); for other cases
|
||||
// tile is inside scope [max_uih_len - min_full_attn_seqlen, seqlen_q); for other cases
|
||||
// and tiles, is_tile_in_first_split is true
|
||||
bool is_tile_in_first_split;
|
||||
int seqlen;
|
||||
int seqlen_q;
|
||||
int seqlen_k;
|
||||
int contextual_seqlen;
|
||||
|
||||
int min_full_attn_seqlen;
|
||||
int max_attn_len;
|
||||
|
||||
int max_uih_len;
|
||||
int max_id;
|
||||
int max_q_uih_len;
|
||||
int max_k_uih_len;
|
||||
int max_row_id;
|
||||
int max_col_id;
|
||||
|
||||
CK_TILE_HOST_DEVICE HstuBlockMaskWithLocal(bool is_tile_in_first_split_,
|
||||
int seqlen_,
|
||||
int seqlen_q_,
|
||||
int seqlen_k_,
|
||||
int contextual_seqlen_,
|
||||
int max_attn_len_,
|
||||
int min_full_attn_seqlen_,
|
||||
int num_target_)
|
||||
: is_tile_in_first_split(is_tile_in_first_split_),
|
||||
seqlen(seqlen_),
|
||||
seqlen_q(seqlen_q_),
|
||||
seqlen_k(seqlen_k_),
|
||||
contextual_seqlen(contextual_seqlen_),
|
||||
min_full_attn_seqlen(min_full_attn_seqlen_)
|
||||
{
|
||||
max_uih_len = seqlen - num_target_;
|
||||
max_q_uih_len = seqlen_q - num_target_;
|
||||
max_k_uih_len = seqlen_k - num_target_;
|
||||
|
||||
// the user-provided max_attn_len_ could be bigger than max_uih_len, so clamp it
|
||||
max_attn_len = min(max_uih_len, max_attn_len_);
|
||||
max_attn_len = min(max_k_uih_len, min(max_q_uih_len, max_attn_len_));
|
||||
|
||||
// assuming min_full_attn_seqlen has higher priority, ensure contextual scope not collide
|
||||
// with min_full_attn_seqlen scope
|
||||
contextual_seqlen = min(contextual_seqlen, max_uih_len - min_full_attn_seqlen);
|
||||
contextual_seqlen = min(contextual_seqlen, max_q_uih_len - min_full_attn_seqlen);
|
||||
|
||||
if(contextual_seqlen > 0)
|
||||
max_id = max_uih_len - (contextual_seqlen - 1);
|
||||
{
|
||||
max_row_id = max_q_uih_len - (contextual_seqlen - 1);
|
||||
max_col_id = max_k_uih_len - (contextual_seqlen - 1);
|
||||
}
|
||||
else
|
||||
max_id = max_uih_len;
|
||||
{
|
||||
max_row_id = max_q_uih_len;
|
||||
max_col_id = max_k_uih_len;
|
||||
}
|
||||
};
|
||||
|
||||
// to get the loop length along X axis, return index:[start, end), end-start=length
|
||||
@@ -65,20 +77,20 @@ struct HstuBlockMaskWithLocal
|
||||
{
|
||||
if constexpr(kUseCausal)
|
||||
{
|
||||
index_t x_end = min(i_y + YTile, seqlen);
|
||||
index_t x_end = min(i_y + YTile, seqlen_k);
|
||||
return ck_tile::make_tuple(0, x_end);
|
||||
}
|
||||
else
|
||||
{
|
||||
// tile is partially or completely in [max_uih_len-min_full_attn_seqlen,
|
||||
// max_uih_len)
|
||||
if(i_y < max_uih_len)
|
||||
// max_q_uih_len)
|
||||
if(i_y < max_q_uih_len)
|
||||
{
|
||||
return ck_tile::make_tuple(0, seqlen);
|
||||
return ck_tile::make_tuple(0, seqlen_k);
|
||||
}
|
||||
else // tile is completely inside [max_uih_len, seqlen)
|
||||
else // tile is completely inside [max_q_uih_len, seqlen_q)
|
||||
{
|
||||
index_t x_end = min(i_y + YTile, seqlen);
|
||||
index_t x_end = min(i_y + YTile, seqlen_k);
|
||||
return ck_tile::make_tuple(0, x_end);
|
||||
};
|
||||
};
|
||||
@@ -90,18 +102,18 @@ struct HstuBlockMaskWithLocal
|
||||
{
|
||||
if(i_y >= min(contextual_seqlen, 1) + max_attn_len)
|
||||
{
|
||||
// some row of the tile in [contextual_seqlen+max_attn_len, max_uih_len)
|
||||
if(i_y < max_uih_len)
|
||||
// some row of the tile in [contextual_seqlen+max_attn_len, max_q_uih_len)
|
||||
if(i_y < max_q_uih_len)
|
||||
{
|
||||
index_t x_start = i_y - max_attn_len;
|
||||
index_t x_start_aligned = x_start - x_start % XTile;
|
||||
|
||||
// some rows of the tile in [max_uih_len -max_attn_len, max_uih_len)
|
||||
if(i_y + YTile > max_uih_len - max_attn_len)
|
||||
// some rows of the tile in [max_q_uih_len - max_attn_len, max_q_uih_len)
|
||||
if(i_y + YTile > max_q_uih_len - max_attn_len)
|
||||
{
|
||||
return ck_tile::make_tuple(x_start_aligned, seqlen);
|
||||
return ck_tile::make_tuple(x_start_aligned, seqlen_k);
|
||||
}
|
||||
else // whole tile in [contextual_seqlen+max_attn_len, max_uih_len
|
||||
else // whole tile in [contextual_seqlen+max_attn_len, max_q_uih_len
|
||||
// -max_attn_len)
|
||||
{
|
||||
index_t x_end = i_y + YTile + max_attn_len;
|
||||
@@ -110,8 +122,8 @@ struct HstuBlockMaskWithLocal
|
||||
}
|
||||
else // whole tile in [max_uih_len, seqlen)
|
||||
{
|
||||
index_t x_start = max_uih_len - max_attn_len;
|
||||
index_t x_end = min(i_y + YTile, seqlen);
|
||||
index_t x_start = max_k_uih_len - max_attn_len;
|
||||
index_t x_end = min(i_y + YTile, seqlen_k);
|
||||
|
||||
return ck_tile::make_tuple(x_start - x_start % XTile, x_end);
|
||||
}
|
||||
@@ -120,12 +132,12 @@ struct HstuBlockMaskWithLocal
|
||||
{
|
||||
if(i_y < contextual_seqlen) // some row of the tile in [0, contextual_seqlen)
|
||||
{
|
||||
index_t x_end = min(max(i_y + YTile + max_attn_len, max_uih_len), seqlen);
|
||||
index_t x_end = min(max(i_y + YTile + max_attn_len, max_k_uih_len), seqlen_k);
|
||||
return ck_tile::make_tuple(0, x_end);
|
||||
}
|
||||
else // whole tile in [contextual_seqlen, seqlen)
|
||||
{
|
||||
index_t x_end = min(i_y + YTile + max_attn_len, seqlen);
|
||||
index_t x_end = min(i_y + YTile + max_attn_len, seqlen_k);
|
||||
return ck_tile::make_tuple(0, x_end);
|
||||
}
|
||||
}
|
||||
@@ -134,17 +146,17 @@ struct HstuBlockMaskWithLocal
|
||||
{
|
||||
if(i_y >= min(contextual_seqlen, 1) + max_attn_len)
|
||||
{
|
||||
index_t x_end = min(i_y + YTile, seqlen);
|
||||
index_t x_end = min(i_y + YTile, seqlen_k);
|
||||
|
||||
// some row of the tile in [contextual_seqlen+max_attn_len, max_uih_len)
|
||||
if(i_y < max_uih_len)
|
||||
// some row of the tile in [contextual_seqlen+max_attn_len, max_q_uih_len)
|
||||
if(i_y < max_q_uih_len)
|
||||
{
|
||||
index_t x_start = i_y - max_attn_len;
|
||||
return ck_tile::make_tuple(x_start - x_start % XTile, x_end);
|
||||
}
|
||||
else // whole tile in [max_uih_len, seqlen)
|
||||
{
|
||||
index_t x_start = max_uih_len - max_attn_len;
|
||||
index_t x_start = max_k_uih_len - max_attn_len;
|
||||
return ck_tile::make_tuple(x_start - x_start % XTile, x_end);
|
||||
}
|
||||
}
|
||||
@@ -152,12 +164,12 @@ struct HstuBlockMaskWithLocal
|
||||
{
|
||||
if(i_y < contextual_seqlen) // some row of the tile in [0, contextual_seqlen)
|
||||
{
|
||||
index_t x_end = min(max(i_y + YTile, max_uih_len), seqlen);
|
||||
index_t x_end = min(max(i_y + YTile, max_k_uih_len), seqlen_k);
|
||||
return ck_tile::make_tuple(0, x_end);
|
||||
}
|
||||
else // whole tile in [contextual_seqlen, seqlen)
|
||||
{
|
||||
index_t x_end = min(i_y + YTile, seqlen);
|
||||
index_t x_end = min(i_y + YTile, seqlen_k);
|
||||
return ck_tile::make_tuple(0, x_end);
|
||||
}
|
||||
}
|
||||
@@ -176,18 +188,18 @@ struct HstuBlockMaskWithLocal
|
||||
row_id = max(row - contextual_seqlen + 1, 0);
|
||||
col_id = max(col - contextual_seqlen + 1, 0);
|
||||
|
||||
row_id = min(row_id, max_id);
|
||||
col_id = min(col_id, max_id);
|
||||
row_id = min(row_id, max_row_id);
|
||||
col_id = min(col_id, max_col_id);
|
||||
|
||||
if(row_id == 0 && col_id < max_id)
|
||||
if(row_id == 0 && col_id < max_col_id)
|
||||
return true;
|
||||
}
|
||||
else
|
||||
{
|
||||
// row_id/col_id is clamped from physical row/col according to contextual_seqlen and
|
||||
// max_uih_len
|
||||
row_id = min(row, max_id);
|
||||
col_id = min(col, max_id);
|
||||
row_id = min(row, max_row_id);
|
||||
col_id = min(col, max_col_id);
|
||||
};
|
||||
|
||||
// use row_id/col_id to check the dist between two q/k token pair, token pairs on the
|
||||
@@ -195,7 +207,7 @@ struct HstuBlockMaskWithLocal
|
||||
if constexpr(kUseCausal)
|
||||
{
|
||||
bool in_min_full_scope =
|
||||
(min_full_attn_seqlen > 0) ? (row_id >= max_id - min_full_attn_seqlen) : false;
|
||||
(min_full_attn_seqlen > 0) ? (row_id >= max_row_id - min_full_attn_seqlen) : false;
|
||||
|
||||
return (((row_id > col_id) || (row == col)) &&
|
||||
((row_id - col_id <= max_attn_len) || in_min_full_scope));
|
||||
@@ -203,7 +215,7 @@ struct HstuBlockMaskWithLocal
|
||||
else
|
||||
{
|
||||
bool in_min_full_scope =
|
||||
(min_full_attn_seqlen > 0) ? (row_id >= max_id - min_full_attn_seqlen) : false;
|
||||
(min_full_attn_seqlen > 0) ? (row_id >= max_row_id - min_full_attn_seqlen) : false;
|
||||
|
||||
return (((row_id != col_id) || (row == col)) &&
|
||||
((abs(row_id - col_id) <= max_attn_len) || in_min_full_scope));
|
||||
@@ -222,18 +234,18 @@ struct HstuBlockMaskWithLocal
|
||||
row_id = max(row - contextual_seqlen + 1, 0);
|
||||
col_id = max(col - contextual_seqlen + 1, 0);
|
||||
|
||||
row_id = min(row_id, max_id);
|
||||
col_id = min(col_id, max_id);
|
||||
row_id = min(row_id, max_row_id);
|
||||
col_id = min(col_id, max_col_id);
|
||||
|
||||
if(row_id == 0 && col_id < max_id)
|
||||
if(row_id == 0 && col_id < max_col_id)
|
||||
return true;
|
||||
}
|
||||
else
|
||||
{
|
||||
// row_id/col_id is clamped from physical row/col according to contextual_seqlen and
|
||||
// max_uih_len
|
||||
row_id = min(row, max_id);
|
||||
col_id = min(col, max_id);
|
||||
row_id = min(row, max_row_id);
|
||||
col_id = min(col, max_col_id);
|
||||
};
|
||||
|
||||
// use row_id/col_id to check the dist between two q/k token pair, token pairs on the
|
||||
@@ -269,7 +281,7 @@ struct HstuBlockMaskWithLocal
|
||||
{
|
||||
index_t i_tile_right = i_tile_left + TileWidth;
|
||||
|
||||
if(!is_tile_in_first_split && i_tile_right <= min(i_tile_top + 1, max_uih_len))
|
||||
if(!is_tile_in_first_split && i_tile_right <= min(i_tile_top + 1, max_k_uih_len))
|
||||
return true;
|
||||
}
|
||||
else
|
||||
@@ -277,11 +289,11 @@ struct HstuBlockMaskWithLocal
|
||||
index_t i_tile_right = i_tile_left + TileWidth;
|
||||
index_t i_tile_bottom = i_tile_top + TileHeight;
|
||||
|
||||
// 1) tile is completely in [max_uih_len-min_full_attn_seqlen, max_uih_len]
|
||||
// 2) some row of tile is in [max_uih_len, seqlen], requires i_tile_right <= max_uih_len
|
||||
// to return true
|
||||
// 1) tile is completely in [max_q_uih_len-min_full_attn_seqlen, max_q_uih_len]
|
||||
// 2) some row of tile is in [max_q_uih_len, seqlen_q], requires i_tile_right <=
|
||||
// max_k_uih_len to return true
|
||||
if(!is_tile_in_first_split &&
|
||||
(i_tile_bottom <= max_uih_len || i_tile_right <= max_uih_len))
|
||||
(i_tile_bottom <= max_q_uih_len || i_tile_right <= max_k_uih_len))
|
||||
return true;
|
||||
};
|
||||
|
||||
@@ -295,21 +307,32 @@ struct HstuBlockMaskNoLocal
|
||||
static constexpr bool kUseLocal = false;
|
||||
static constexpr bool IsMasking = kUseCausal;
|
||||
|
||||
int seqlen;
|
||||
int seqlen_q;
|
||||
int seqlen_k;
|
||||
int contextual_seqlen;
|
||||
|
||||
int max_uih_len;
|
||||
int max_id;
|
||||
int max_q_uih_len;
|
||||
int max_k_uih_len;
|
||||
int max_row_id;
|
||||
int max_col_id;
|
||||
|
||||
CK_TILE_HOST_DEVICE HstuBlockMaskNoLocal(int seqlen_, int contextual_seqlen_, int num_target_)
|
||||
: seqlen(seqlen_), contextual_seqlen(contextual_seqlen_)
|
||||
CK_TILE_HOST_DEVICE
|
||||
HstuBlockMaskNoLocal(int seqlen_q_, int seqlen_k_, int contextual_seqlen_, int num_target_)
|
||||
: seqlen_q(seqlen_q_), seqlen_k(seqlen_k_), contextual_seqlen(contextual_seqlen_)
|
||||
{
|
||||
max_uih_len = seqlen - num_target_;
|
||||
max_q_uih_len = seqlen_q - num_target_;
|
||||
max_k_uih_len = seqlen_k - num_target_;
|
||||
|
||||
if(contextual_seqlen > 0)
|
||||
max_id = max_uih_len - (contextual_seqlen - 1);
|
||||
{
|
||||
max_row_id = max_q_uih_len - (contextual_seqlen - 1);
|
||||
max_col_id = max_k_uih_len - (contextual_seqlen - 1);
|
||||
}
|
||||
else
|
||||
max_id = max_uih_len;
|
||||
{
|
||||
max_row_id = max_q_uih_len;
|
||||
max_col_id = max_k_uih_len;
|
||||
}
|
||||
};
|
||||
|
||||
// to get the loop length along X axis, return index:[start, end), end-start=length
|
||||
@@ -321,21 +344,21 @@ struct HstuBlockMaskNoLocal
|
||||
{
|
||||
if constexpr(!IsMasking)
|
||||
{
|
||||
return ck_tile::make_tuple(0, seqlen);
|
||||
return ck_tile::make_tuple(0, seqlen_k);
|
||||
}
|
||||
else
|
||||
{
|
||||
index_t x_end = min(i_y + YTile, seqlen);
|
||||
index_t x_end = min(i_y + YTile, seqlen_k);
|
||||
|
||||
if(i_y < contextual_seqlen)
|
||||
{
|
||||
if(i_y + YTile > max_uih_len)
|
||||
if(i_y + YTile > max_k_uih_len)
|
||||
{
|
||||
return ck_tile::make_tuple(0, x_end);
|
||||
}
|
||||
else
|
||||
{
|
||||
return ck_tile::make_tuple(0, max_uih_len);
|
||||
return ck_tile::make_tuple(0, max_k_uih_len);
|
||||
};
|
||||
}
|
||||
else
|
||||
@@ -357,18 +380,18 @@ struct HstuBlockMaskNoLocal
|
||||
row_id = max(row - contextual_seqlen + 1, 0);
|
||||
col_id = max(col - contextual_seqlen + 1, 0);
|
||||
|
||||
row_id = min(row_id, max_id);
|
||||
col_id = min(col_id, max_id);
|
||||
row_id = min(row_id, max_row_id);
|
||||
col_id = min(col_id, max_col_id);
|
||||
|
||||
if(row_id == 0 && col_id < max_id)
|
||||
if(row_id == 0 && col_id < max_col_id)
|
||||
return true;
|
||||
}
|
||||
else
|
||||
{
|
||||
// row_id/col_id is clamped from physical row/col according to contextual_seqlen and
|
||||
// max_uih_len
|
||||
row_id = min(row, max_id);
|
||||
col_id = min(col, max_id);
|
||||
row_id = min(row, max_row_id);
|
||||
col_id = min(col, max_col_id);
|
||||
};
|
||||
|
||||
// use row_id/col_id to check the dist between two q/k token pair, token pairs on the
|
||||
@@ -395,18 +418,18 @@ struct HstuBlockMaskNoLocal
|
||||
row_id = max(row - contextual_seqlen + 1, 0);
|
||||
col_id = max(col - contextual_seqlen + 1, 0);
|
||||
|
||||
row_id = min(row_id, max_id);
|
||||
col_id = min(col_id, max_id);
|
||||
row_id = min(row_id, max_row_id);
|
||||
col_id = min(col_id, max_col_id);
|
||||
|
||||
if(row_id == 0 && col_id < max_id)
|
||||
if(row_id == 0 && col_id < max_col_id)
|
||||
return true;
|
||||
}
|
||||
else
|
||||
{
|
||||
// row_id/col_id is clamped from physical row/col according to contextual_seqlen and
|
||||
// max_uih_len
|
||||
row_id = min(row, max_id);
|
||||
col_id = min(col, max_id);
|
||||
row_id = min(row, max_row_id);
|
||||
col_id = min(col, max_col_id);
|
||||
};
|
||||
|
||||
// use row_id/col_id to check the dist between two q/k token pair, token pairs on the
|
||||
@@ -439,7 +462,7 @@ struct HstuBlockMaskNoLocal
|
||||
|
||||
// assume num_target > 0 with high probability, don't check whether num_target is 0;
|
||||
// so if num_target is 0, IsTokenPairInsideMask() will be called for the bottom tile
|
||||
if(i_tile_bottom >= max_uih_len || i_tile_right > i_tile_top)
|
||||
if(i_tile_bottom >= max_q_uih_len || i_tile_right > i_tile_top)
|
||||
return false;
|
||||
|
||||
return true;
|
||||
@@ -451,7 +474,7 @@ struct HstuBlockMaskNoLocal
|
||||
|
||||
// assume num_target > 0 with high probability, don't check whether num_target is 0;
|
||||
// so if num_target is 0, IsTokenPairInsideMask() will be called for the bottom tile
|
||||
if(i_tile_bottom >= max_uih_len || i_tile_right >= max_uih_len)
|
||||
if(i_tile_bottom >= max_q_uih_len || i_tile_right >= max_k_uih_len)
|
||||
return false;
|
||||
|
||||
return true;
|
||||
@@ -469,14 +492,16 @@ struct HstuBlockMasking
|
||||
|
||||
template <typename HstuBlockMaskType>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_hstu_block_mask_with_local(bool is_tile_in_first_split_,
|
||||
int seqlen_,
|
||||
int seqlen_q_,
|
||||
int seqlen_k_,
|
||||
int contextual_seqlen_,
|
||||
int num_target,
|
||||
int max_attn_len_,
|
||||
int min_full_attn_seqlen_)
|
||||
{
|
||||
return HstuBlockMaskType{is_tile_in_first_split_,
|
||||
seqlen_,
|
||||
seqlen_q_,
|
||||
seqlen_k_,
|
||||
contextual_seqlen_,
|
||||
max_attn_len_,
|
||||
min_full_attn_seqlen_,
|
||||
@@ -484,10 +509,12 @@ CK_TILE_HOST_DEVICE constexpr auto make_hstu_block_mask_with_local(bool is_tile_
|
||||
};
|
||||
|
||||
template <typename HstuBlockMaskType>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
make_hstu_block_mask_without_local(int seqlen_, int contextual_seqlen_, int num_target)
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_hstu_block_mask_without_local(int seqlen_q_,
|
||||
int seqlen_k_,
|
||||
int contextual_seqlen_,
|
||||
int num_target)
|
||||
{
|
||||
return HstuBlockMaskType{seqlen_, contextual_seqlen_, num_target};
|
||||
return HstuBlockMaskType{seqlen_q_, seqlen_k_, contextual_seqlen_, num_target};
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -42,7 +42,8 @@ struct reference_hstu_attention
|
||||
float alpha,
|
||||
float attn_scale,
|
||||
int max_seqlen,
|
||||
std::vector<int> seq_offsets,
|
||||
std::vector<int> seq_q_offsets,
|
||||
std::vector<int> seq_kv_offsets,
|
||||
std::vector<int> num_targets, // define masking length at the end of token
|
||||
// sequence to be excluded for attention
|
||||
int contextual_seqlen, // define masking length at the begin of query token
|
||||
@@ -54,7 +55,8 @@ struct reference_hstu_attention
|
||||
if constexpr(kIsJagged)
|
||||
{
|
||||
// check the number of batches
|
||||
assert(!seq_offsets.empty() && seq_offsets.size() == num_batch + 1);
|
||||
assert(!seq_q_offsets.empty() && seq_q_offsets.size() == num_batch + 1);
|
||||
assert(!seq_kv_offsets.empty() && seq_kv_offsets.size() == num_batch + 1);
|
||||
assert(q_batch_seq_nhead_hdim.get_lengths()[0] == 1);
|
||||
assert(k_batch_seq_nhead_hdim.get_lengths()[0] == 1);
|
||||
assert(v_batch_seq_nhead_hdim.get_lengths()[0] == 1);
|
||||
@@ -62,7 +64,8 @@ struct reference_hstu_attention
|
||||
}
|
||||
else
|
||||
{
|
||||
assert(seq_offsets.empty());
|
||||
assert(seq_q_offsets.empty());
|
||||
assert(seq_kv_offsets.empty());
|
||||
assert(q_batch_seq_nhead_hdim.get_lengths()[0] == num_batch);
|
||||
assert(k_batch_seq_nhead_hdim.get_lengths()[0] == num_batch);
|
||||
assert(v_batch_seq_nhead_hdim.get_lengths()[0] == num_batch);
|
||||
@@ -104,8 +107,10 @@ struct reference_hstu_attention
|
||||
};
|
||||
|
||||
auto f = [&](auto i_batch, auto i_head) {
|
||||
int seqlen = kIsJagged ? (seq_offsets[i_batch + 1] - seq_offsets[i_batch])
|
||||
: q_batch_seq_nhead_hdim.get_lengths()[1];
|
||||
int seqlen_q = kIsJagged ? (seq_q_offsets[i_batch + 1] - seq_q_offsets[i_batch])
|
||||
: q_batch_seq_nhead_hdim.get_lengths()[1];
|
||||
int seqlen_kv = kIsJagged ? (seq_kv_offsets[i_batch + 1] - seq_kv_offsets[i_batch])
|
||||
: k_batch_seq_nhead_hdim.get_lengths()[1];
|
||||
|
||||
int num_target = num_targets.empty() ? 0 : num_targets[i_batch];
|
||||
|
||||
@@ -118,10 +123,11 @@ struct reference_hstu_attention
|
||||
if constexpr(kHasLocal)
|
||||
// need adjust the min_full_attn_seqlen passed to the HstuBlockMask() if the
|
||||
// user passed min_full_attn_seqlen is bigger than max_uih_len
|
||||
if(seqlen - num_target > min_full_attn_seqlen)
|
||||
if(seqlen_q - num_target > min_full_attn_seqlen)
|
||||
return ck_tile::make_hstu_block_mask_with_local<HstuMaskType>(
|
||||
true,
|
||||
seqlen,
|
||||
seqlen_q,
|
||||
seqlen_kv,
|
||||
contextual_seqlen,
|
||||
num_target,
|
||||
window_size,
|
||||
@@ -129,14 +135,15 @@ struct reference_hstu_attention
|
||||
else
|
||||
return ck_tile::make_hstu_block_mask_with_local<HstuMaskType>(
|
||||
true,
|
||||
seqlen,
|
||||
seqlen_q,
|
||||
seqlen_kv,
|
||||
contextual_seqlen,
|
||||
num_target,
|
||||
window_size,
|
||||
seqlen - num_target);
|
||||
seqlen_q - num_target);
|
||||
else
|
||||
return ck_tile::make_hstu_block_mask_without_local<HstuMaskType>(
|
||||
seqlen, contextual_seqlen, num_target);
|
||||
seqlen_q, seqlen_kv, contextual_seqlen, num_target);
|
||||
}();
|
||||
|
||||
if(save_mask)
|
||||
@@ -149,7 +156,7 @@ struct reference_hstu_attention
|
||||
}
|
||||
|
||||
// for all rows in the batch
|
||||
for(int sq = 0; sq < seqlen; sq++)
|
||||
for(int sq = 0; sq < seqlen_q; sq++)
|
||||
{
|
||||
CompDataType m =
|
||||
-ck_tile::numeric<CompDataType>::infinity(); // max value of the row
|
||||
@@ -159,7 +166,7 @@ struct reference_hstu_attention
|
||||
std::vector<CompDataType> locals;
|
||||
|
||||
// for all cols in the batch
|
||||
for(int sk = 0; sk < seqlen; sk++)
|
||||
for(int sk = 0; sk < seqlen_kv; sk++)
|
||||
{
|
||||
if(mask.IsTokenPairInsideMask(sq, sk))
|
||||
{
|
||||
@@ -169,9 +176,9 @@ struct reference_hstu_attention
|
||||
if constexpr(kIsJagged)
|
||||
{
|
||||
InOutDataType qreg = q_batch_seq_nhead_hdim(
|
||||
0, seq_offsets[i_batch] + sq, i_head, k);
|
||||
0, seq_q_offsets[i_batch] + sq, i_head, k);
|
||||
InOutDataType kreg = k_batch_seq_nhead_hdim(
|
||||
0, seq_offsets[i_batch] + sk, i_head, k);
|
||||
0, seq_kv_offsets[i_batch] + sk, i_head, k);
|
||||
|
||||
dot_prod += ck_tile::type_convert<GemmAccDataType>(qreg) *
|
||||
ck_tile::type_convert<GemmAccDataType>(kreg);
|
||||
@@ -233,14 +240,14 @@ struct reference_hstu_attention
|
||||
{
|
||||
GemmAccDataType dot_prod = 0.f;
|
||||
|
||||
for(int sk = 0; sk < seqlen; sk++)
|
||||
for(int sk = 0; sk < seqlen_kv; sk++)
|
||||
{
|
||||
if constexpr(kIsJagged)
|
||||
{
|
||||
InOutDataType preg =
|
||||
ck_tile::type_convert<InOutDataType>(locals[sk]);
|
||||
InOutDataType vreg =
|
||||
v_batch_seq_nhead_hdim(0, seq_offsets[i_batch] + sk, i_head, k);
|
||||
InOutDataType vreg = v_batch_seq_nhead_hdim(
|
||||
0, seq_kv_offsets[i_batch] + sk, i_head, k);
|
||||
|
||||
dot_prod += ck_tile::type_convert<GemmAccDataType>(preg) *
|
||||
ck_tile::type_convert<GemmAccDataType>(vreg);
|
||||
@@ -257,7 +264,7 @@ struct reference_hstu_attention
|
||||
};
|
||||
|
||||
if constexpr(kIsJagged)
|
||||
o_batch_seq_nhead_hdim(0, seq_offsets[i_batch] + sq, i_head, k) =
|
||||
o_batch_seq_nhead_hdim(0, seq_q_offsets[i_batch] + sq, i_head, k) =
|
||||
ck_tile::type_convert<InOutDataType>(dot_prod);
|
||||
else
|
||||
o_batch_seq_nhead_hdim(i_batch, sq, i_head, k) =
|
||||
|
||||
Reference in New Issue
Block a user