mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 18:42:06 +00:00
Fix and change in example
This commit is contained in:
@@ -62,7 +62,7 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("hdim_qk", "64", "headdim size of Q/K")
|
||||
.insert("hdim_v", "64", "headdim size of V/O")
|
||||
.insert("seqlen", "400", "seqlen of single or all batches for query and key/value tensor")
|
||||
.insert("targets", "16", "sequence length at the end of query/key token sequence that should be excluded from attention")
|
||||
.insert("targets", "", "sequence length at the end of query/key token sequence that should be excluded from attention")
|
||||
.insert("causal", "1", "enable causal mask or not")
|
||||
.insert("local_len", "5", "length of the diagonal window for enabling masking, value 0 to disable")
|
||||
.insert("context_len", "6", "sequence length at the begin of the query sequence the should be included for attention")
|
||||
@@ -95,13 +95,53 @@ static std::vector<int> get_integers_from_string(std::string lengthsStr)
|
||||
};
|
||||
|
||||
std::string sliceStr = lengthsStr.substr(pos);
|
||||
int len = std::stoi(sliceStr);
|
||||
|
||||
lengths.push_back(len);
|
||||
if(!sliceStr.empty())
|
||||
{
|
||||
int len = std::stoi(sliceStr);
|
||||
|
||||
lengths.push_back(len);
|
||||
};
|
||||
|
||||
return (lengths);
|
||||
};
|
||||
|
||||
static void show_hstu_attention_fwd_param(std::ostream& os, HstuAttentionFwdParams& param)
|
||||
{
|
||||
if(param.is_jagged)
|
||||
{
|
||||
os << "Jagged inputs used! " << std::endl;
|
||||
os << "use causal: " << param.use_causal << std::endl;
|
||||
os << "Num of batches: " << param.num_batch << std::endl;
|
||||
os << "Num of heads: " << param.num_head << std::endl;
|
||||
os << "QK hdim: " << param.hdim_qk << " V hdim: " << param.hdim_v << std::endl;
|
||||
os << "Q/K/V/O seq stride: " << param.seq_stride_q << " " << param.seq_stride_k << " "
|
||||
<< param.seq_stride_v << " " << param.seq_stride_o << std::endl;
|
||||
os << "Q/K/V/O nhead stride: " << param.nhead_stride_q << " " << param.nhead_stride_k << " "
|
||||
<< param.nhead_stride_v << " " << param.nhead_stride_o << std::endl;
|
||||
os << "contextual_seqlen: " << param.contextual_seqlen << std::endl;
|
||||
os << "window_size: " << param.window_size << std::endl;
|
||||
os << "min_full_attn_seqlen: " << param.min_full_attn_seqlen << std::endl;
|
||||
}
|
||||
else
|
||||
{
|
||||
os << "Batched inputs used! " << std::endl;
|
||||
os << "use causal: " << param.use_causal << std::endl;
|
||||
os << "Num of batches: " << param.num_batch << std::endl;
|
||||
os << "Num of heads: " << param.num_head << std::endl;
|
||||
os << "QK hdim: " << param.hdim_qk << " V hdim: " << param.hdim_v << std::endl;
|
||||
os << "Q/K/V/O seq stride: " << param.seq_stride_q << " " << param.seq_stride_k << " "
|
||||
<< param.seq_stride_v << " " << param.seq_stride_o << std::endl;
|
||||
os << "Q/K/V/O nhead stride: " << param.nhead_stride_q << " " << param.nhead_stride_k << " "
|
||||
<< param.nhead_stride_v << " " << param.nhead_stride_o << std::endl;
|
||||
os << "Q/K/V/O batch stride: " << param.batch_stride_q << " " << param.batch_stride_k << " "
|
||||
<< param.batch_stride_v << " " << param.batch_stride_o << std::endl;
|
||||
os << "contextual_seqlen: " << param.contextual_seqlen << std::endl;
|
||||
os << "window_size: " << param.window_size << std::endl;
|
||||
os << "min_full_attn_seqlen: " << param.min_full_attn_seqlen << std::endl;
|
||||
};
|
||||
};
|
||||
|
||||
// threshold for different dtypes
|
||||
template <typename DataType>
|
||||
auto get_elimit()
|
||||
@@ -152,9 +192,27 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
int seqlen = 0; // means total seq lengths for jagged
|
||||
int max_seqlen = 0;
|
||||
|
||||
// supplement the sequence of lengths to have num_batch values
|
||||
if(!num_targets.empty() && static_cast<int>(num_targets.size()) < num_batch)
|
||||
{
|
||||
auto last_val = num_targets.back();
|
||||
|
||||
for(int i = num_targets.size(); i < num_batch; i++)
|
||||
num_targets.push_back(last_val);
|
||||
};
|
||||
|
||||
if(is_jagged)
|
||||
{
|
||||
assert(num_batch == seq_lengths.size());
|
||||
assert(num_batch >= seq_lengths.size());
|
||||
|
||||
// supplement the sequence of lengths to have num_batch values
|
||||
if(static_cast<int>(seq_lengths.size()) < num_batch)
|
||||
{
|
||||
auto last_len = seq_lengths.back();
|
||||
|
||||
for(int i = seq_lengths.size(); i < num_batch; i++)
|
||||
seq_lengths.push_back(last_len);
|
||||
};
|
||||
|
||||
seq_offsets.push_back(0);
|
||||
for(size_t i = 0; i < seq_lengths.size(); i++)
|
||||
@@ -190,10 +248,13 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
if(!num_targets.empty())
|
||||
{
|
||||
assert(1 == num_targets.size());
|
||||
assert(num_batch == num_targets.size());
|
||||
|
||||
assert(seqlen - num_targets[0] >= min_full_attn_seqlen);
|
||||
assert(seqlen - num_targets[0] >= contextual_seqlen);
|
||||
for(size_t i = 0; i < seq_lengths.size(); i++)
|
||||
{
|
||||
assert(seqlen - num_targets[i] >= min_full_attn_seqlen);
|
||||
assert(seqlen - num_targets[i] >= contextual_seqlen);
|
||||
};
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -245,7 +306,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
params.q_ptr = q_dev.GetDeviceBuffer();
|
||||
params.k_ptr = k_dev.GetDeviceBuffer();
|
||||
params.v_ptr = v_dev.GetDeviceBuffer();
|
||||
params.bias_ptr = nullptr;
|
||||
params.bias_ptr = nullptr; // bias is not supported at present
|
||||
params.o_ptr = o_dev.GetDeviceBuffer();
|
||||
params.hdim_qk = hdim_qk;
|
||||
params.hdim_v = hdim_v;
|
||||
@@ -278,7 +339,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
params.q_ptr = q_dev.GetDeviceBuffer();
|
||||
params.k_ptr = k_dev.GetDeviceBuffer();
|
||||
params.v_ptr = v_dev.GetDeviceBuffer();
|
||||
params.bias_ptr = nullptr;
|
||||
params.bias_ptr = nullptr; // bias is not supported at present
|
||||
params.o_ptr = o_dev.GetDeviceBuffer();
|
||||
params.hdim_qk = hdim_qk;
|
||||
params.hdim_v = hdim_v;
|
||||
@@ -309,6 +370,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
params.philox_offset = 0UL;
|
||||
};
|
||||
|
||||
show_hstu_attention_fwd_param(std::cout, params);
|
||||
|
||||
hipStream_t stream;
|
||||
|
||||
HIP_CHECK_ERROR(hipStreamCreate(&stream));
|
||||
@@ -348,7 +411,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
v_host,
|
||||
o_host_ref,
|
||||
num_batch,
|
||||
1.0f,
|
||||
1.0f / std::sqrt(params.hdim_qk),
|
||||
seq_offsets,
|
||||
num_targets,
|
||||
window_size,
|
||||
@@ -393,7 +456,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
auto ms = timer.duration() / 20.f;
|
||||
|
||||
std::cout << "Average execution time of the gather_attention operator is " << ms
|
||||
std::cout << "Average execution time of the hstu_attention operator is " << ms
|
||||
<< " milli-seconds" << std::endl;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user