Fix and change in example

This commit is contained in:
Qianfeng Zhang
2025-04-03 14:44:36 +00:00
parent 121a950df5
commit 733734553b

View File

@@ -62,7 +62,7 @@ auto create_args(int argc, char* argv[])
.insert("hdim_qk", "64", "headdim size of Q/K")
.insert("hdim_v", "64", "headdim size of V/O")
.insert("seqlen", "400", "seqlen of single or all batches for query and key/value tensor")
.insert("targets", "16", "sequence length at the end of query/key token sequence that should be excluded from attention")
.insert("targets", "", "sequence length at the end of query/key token sequence that should be excluded from attention")
.insert("causal", "1", "enable causal mask or not")
.insert("local_len", "5", "length of the diagonal window for enabling masking, value 0 to disable")
.insert("context_len", "6", "sequence length at the begin of the query sequence the should be included for attention")
@@ -95,13 +95,53 @@ static std::vector<int> get_integers_from_string(std::string lengthsStr)
};
std::string sliceStr = lengthsStr.substr(pos);
int len = std::stoi(sliceStr);
lengths.push_back(len);
if(!sliceStr.empty())
{
int len = std::stoi(sliceStr);
lengths.push_back(len);
};
return (lengths);
};
static void show_hstu_attention_fwd_param(std::ostream& os, HstuAttentionFwdParams& param)
{
if(param.is_jagged)
{
os << "Jagged inputs used! " << std::endl;
os << "use causal: " << param.use_causal << std::endl;
os << "Num of batches: " << param.num_batch << std::endl;
os << "Num of heads: " << param.num_head << std::endl;
os << "QK hdim: " << param.hdim_qk << " V hdim: " << param.hdim_v << std::endl;
os << "Q/K/V/O seq stride: " << param.seq_stride_q << " " << param.seq_stride_k << " "
<< param.seq_stride_v << " " << param.seq_stride_o << std::endl;
os << "Q/K/V/O nhead stride: " << param.nhead_stride_q << " " << param.nhead_stride_k << " "
<< param.nhead_stride_v << " " << param.nhead_stride_o << std::endl;
os << "contextual_seqlen: " << param.contextual_seqlen << std::endl;
os << "window_size: " << param.window_size << std::endl;
os << "min_full_attn_seqlen: " << param.min_full_attn_seqlen << std::endl;
}
else
{
os << "Batched inputs used! " << std::endl;
os << "use causal: " << param.use_causal << std::endl;
os << "Num of batches: " << param.num_batch << std::endl;
os << "Num of heads: " << param.num_head << std::endl;
os << "QK hdim: " << param.hdim_qk << " V hdim: " << param.hdim_v << std::endl;
os << "Q/K/V/O seq stride: " << param.seq_stride_q << " " << param.seq_stride_k << " "
<< param.seq_stride_v << " " << param.seq_stride_o << std::endl;
os << "Q/K/V/O nhead stride: " << param.nhead_stride_q << " " << param.nhead_stride_k << " "
<< param.nhead_stride_v << " " << param.nhead_stride_o << std::endl;
os << "Q/K/V/O batch stride: " << param.batch_stride_q << " " << param.batch_stride_k << " "
<< param.batch_stride_v << " " << param.batch_stride_o << std::endl;
os << "contextual_seqlen: " << param.contextual_seqlen << std::endl;
os << "window_size: " << param.window_size << std::endl;
os << "min_full_attn_seqlen: " << param.min_full_attn_seqlen << std::endl;
};
};
// threshold for different dtypes
template <typename DataType>
auto get_elimit()
@@ -152,9 +192,27 @@ bool run(const ck_tile::ArgParser& arg_parser)
int seqlen = 0; // means total seq lengths for jagged
int max_seqlen = 0;
// supplement the sequence of lengths to have num_batch values
if(!num_targets.empty() && static_cast<int>(num_targets.size()) < num_batch)
{
auto last_val = num_targets.back();
for(int i = num_targets.size(); i < num_batch; i++)
num_targets.push_back(last_val);
};
if(is_jagged)
{
assert(num_batch == seq_lengths.size());
assert(num_batch >= seq_lengths.size());
// supplement the sequence of lengths to have num_batch values
if(static_cast<int>(seq_lengths.size()) < num_batch)
{
auto last_len = seq_lengths.back();
for(int i = seq_lengths.size(); i < num_batch; i++)
seq_lengths.push_back(last_len);
};
seq_offsets.push_back(0);
for(size_t i = 0; i < seq_lengths.size(); i++)
@@ -190,10 +248,13 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(!num_targets.empty())
{
assert(1 == num_targets.size());
assert(num_batch == num_targets.size());
assert(seqlen - num_targets[0] >= min_full_attn_seqlen);
assert(seqlen - num_targets[0] >= contextual_seqlen);
for(size_t i = 0; i < seq_lengths.size(); i++)
{
assert(seqlen - num_targets[i] >= min_full_attn_seqlen);
assert(seqlen - num_targets[i] >= contextual_seqlen);
};
}
else
{
@@ -245,7 +306,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
params.q_ptr = q_dev.GetDeviceBuffer();
params.k_ptr = k_dev.GetDeviceBuffer();
params.v_ptr = v_dev.GetDeviceBuffer();
params.bias_ptr = nullptr;
params.bias_ptr = nullptr; // bias is not supported at present
params.o_ptr = o_dev.GetDeviceBuffer();
params.hdim_qk = hdim_qk;
params.hdim_v = hdim_v;
@@ -278,7 +339,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
params.q_ptr = q_dev.GetDeviceBuffer();
params.k_ptr = k_dev.GetDeviceBuffer();
params.v_ptr = v_dev.GetDeviceBuffer();
params.bias_ptr = nullptr;
params.bias_ptr = nullptr; // bias is not supported at present
params.o_ptr = o_dev.GetDeviceBuffer();
params.hdim_qk = hdim_qk;
params.hdim_v = hdim_v;
@@ -309,6 +370,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
params.philox_offset = 0UL;
};
show_hstu_attention_fwd_param(std::cout, params);
hipStream_t stream;
HIP_CHECK_ERROR(hipStreamCreate(&stream));
@@ -348,7 +411,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
v_host,
o_host_ref,
num_batch,
1.0f,
1.0f / std::sqrt(params.hdim_qk),
seq_offsets,
num_targets,
window_size,
@@ -393,7 +456,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
auto ms = timer.duration() / 20.f;
std::cout << "Average execution time of the gather_attention operator is " << ms
std::cout << "Average execution time of the hstu_attention operator is " << ms
<< " milli-seconds" << std::endl;
}