Add -save_mask option to the example to output int8 mask tensor

This commit is contained in:
Qianfeng Zhang
2025-05-14 01:54:45 +00:00
parent c3761c3bd6
commit 5b0a2618fd
3 changed files with 33 additions and 2 deletions

View File

@@ -51,6 +51,7 @@
.insert("context_len", "6", "sequence length at the begin of the query sequence the should be included for attention")
.insert("minfull_len", "6", "sequence length at the end of the query sequence that should be included for attention")
.insert("seed", "13579", "seed by the uniform or normal distribution generator")
.insert("save_mask", "1", "save the mask tensor to disk by the CPU validation codes")
.insert("perf", "0", "weather measure execution time or not");
```

View File

@@ -40,7 +40,7 @@ void dumpBufferToFile(const char* fileName, T* data, size_t dataNumItems)
{
outFile.write(reinterpret_cast<char*>(data), dataNumItems * sizeof(T));
outFile.close();
printf("Wrote output to file %s\n", fileName);
printf("Write output to file %s\n", fileName);
}
else
{
@@ -84,6 +84,7 @@ auto create_args(int argc, char* argv[])
.insert("context_len", "6", "sequence length at the begin of the query sequence the should be included for attention")
.insert("minfull_len", "6", "sequence length at the end of the query sequence that should be included for attention")
.insert("seed", "13579", "seed by the uniform or normal distribution generator")
.insert("save_mask", "1", "save the mask tensor to disk by the CPU validation codes")
.insert("perf", "0", "weather measure execution time or not");
// clang-format on
@@ -197,6 +198,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
int seed = arg_parser.get_int("seed");
bool measure_perf = static_cast<bool>(arg_parser.get_int("perf"));
bool save_mask = static_cast<bool>(arg_parser.get_int("save_mask"));
std::string str_of_targets = arg_parser.get_str("targets");
std::vector<int> num_targets = get_integers_from_string(str_of_targets);
@@ -260,7 +263,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
else
{
assert(1 == seq_lengths.size());
seqlen = seq_lengths[0];
seqlen = seq_lengths[0];
max_seqlen = seqlen;
if(!num_targets.empty())
{
@@ -310,6 +314,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::HostTensor<InOutDataType> o_host_ref(
std::array<ck_tile::index_t, 4>{batches_for_alloc, seqlen, num_head, hdim_v});
ck_tile::HostTensor<int8_t> mask_host(
save_mask ? std::array<ck_tile::index_t, 4>{num_batch, num_head, max_seqlen, max_seqlen}
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
ck_tile::FillNormalDistribution<InOutDataType>{0.f, 1.f, seed}(q_host);
ck_tile::FillNormalDistribution<InOutDataType>{0.f, 1.f, seed}(k_host);
ck_tile::FillNormalDistribution<InOutDataType>{0.f, 1.f, seed}(v_host);
@@ -447,6 +455,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
k_host,
v_host,
o_host_ref,
mask_host,
num_batch,
1.0f / std::sqrt(params.hdim_qk),
is_jagged ? max_seqlen : seqlen,
@@ -465,6 +474,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
// dumpBufferToFile("output_dev.dat", o_host.data(), o_host.get_element_space_size());
// dumpBufferToFile("output_host.dat", o_host_ref.data(), o_host.get_element_space_size());
if(save_mask)
dumpBufferToFile("ck_hstu_mask.dat", mask_host.data(), mask_host.get_element_space_size());
auto [rtol, atol] = get_elimit<InOutDataType>();
res = ck_tile::check_err(

View File

@@ -40,6 +40,7 @@ struct reference_hstu_attention
const HostTensor<InOutDataType>& k_batch_seq_nhead_hdim,
const HostTensor<InOutDataType>& v_batch_seq_nhead_hdim,
HostTensor<InOutDataType>& o_batch_seq_nhead_hdim,
HostTensor<int8_t>& mask_batch_nhead_seq_seq,
int num_batch,
float alpha,
int max_seqlen,
@@ -87,6 +88,14 @@ struct reference_hstu_attention
assert(hdim_qk == k_batch_seq_nhead_hdim.get_lengths()[3]);
assert(hdim_v == o_batch_seq_nhead_hdim.get_lengths()[3]);
bool save_mask = false;
if(static_cast<int>(mask_batch_nhead_seq_seq.get_lengths()[0]) == num_batch &&
static_cast<int>(mask_batch_nhead_seq_seq.get_lengths()[1]) == num_head &&
static_cast<int>(mask_batch_nhead_seq_seq.get_lengths()[2]) == max_seqlen &&
static_cast<int>(mask_batch_nhead_seq_seq.get_lengths()[3]) == max_seqlen)
save_mask = true;
// check num_tagets
assert(num_tagets.empty() || num_targets.size() == num_batch);
@@ -111,6 +120,15 @@ struct reference_hstu_attention
seqlen, contextual_seqlen, num_target);
}();
if(save_mask)
{
// initialize the mask
for(int sq = 0; sq < max_seqlen; sq++)
for(int sk = 0; sk < max_seqlen; sk++)
mask_batch_nhead_seq_seq(i_batch, i_head, sq, sk) =
static_cast<int8_t>(mask.IsTokenPairInsideMask(sq, sk));
}
// for all rows in the batch
for(int sq = 0; sq < seqlen; sq++)
{