Mirror of https://github.com/ROCm/composable_kernel.git (synced 2026-05-15 10:37:44 +00:00)

Add -save_mask option to the example to output int8 mask tensor
@@ -51,6 +51,7 @@
         .insert("context_len", "6", "sequence length at the beginning of the query sequence that should be included for attention")
         .insert("minfull_len", "6", "sequence length at the end of the query sequence that should be included for attention")
         .insert("seed", "13579", "seed for the uniform or normal distribution generator")
+        .insert("save_mask", "1", "save the mask tensor to disk from the CPU validation code")
         .insert("perf", "0", "whether to measure execution time or not");
 ```
@@ -40,7 +40,7 @@ void dumpBufferToFile(const char* fileName, T* data, size_t dataNumItems)
     {
         outFile.write(reinterpret_cast<char*>(data), dataNumItems * sizeof(T));
         outFile.close();
-        printf("Wrote output to file %s\n", fileName);
+        printf("Write output to file %s\n", fileName);
     }
     else
     {
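For orientation, this hunk shows only the tail of the helper that the new option reuses. Below is a minimal sketch of the full function, assuming the usual std::ofstream setup; only the write/close/printf lines come from the diff, the rest is reconstructed:

```cpp
#include <cstdio>
#include <fstream>

// Sketch of a raw-binary dump helper matching the diff context above.
// The ofstream setup and the error branch are assumptions; only the
// write/close/printf body shown in the hunk is taken from the commit.
template <typename T>
void dumpBufferToFile(const char* fileName, T* data, size_t dataNumItems)
{
    std::ofstream outFile(fileName, std::ios::binary);
    if(outFile)
    {
        outFile.write(reinterpret_cast<char*>(data), dataNumItems * sizeof(T));
        outFile.close();
        printf("Write output to file %s\n", fileName);
    }
    else
    {
        printf("Could not open file %s for writing\n", fileName);
    }
}
```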
@@ -84,6 +84,7 @@ auto create_args(int argc, char* argv[])
         .insert("context_len", "6", "sequence length at the beginning of the query sequence that should be included for attention")
         .insert("minfull_len", "6", "sequence length at the end of the query sequence that should be included for attention")
         .insert("seed", "13579", "seed for the uniform or normal distribution generator")
+        .insert("save_mask", "1", "save the mask tensor to disk from the CPU validation code")
         .insert("perf", "0", "whether to measure execution time or not");
     // clang-format on
@@ -197,6 +198,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
     int seed = arg_parser.get_int("seed");
     bool measure_perf = static_cast<bool>(arg_parser.get_int("perf"));
+
+    bool save_mask = static_cast<bool>(arg_parser.get_int("save_mask"));
 
     std::string str_of_targets = arg_parser.get_str("targets");
     std::vector<int> num_targets = get_integers_from_string(str_of_targets);
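Downstream, the flag follows the usual ck_tile example flow from create_args into run. Here is a hedged sketch of that wiring, assuming the parse()/structured-binding pattern common to the ck_tile examples; only the insert() and get_int() usage is taken from this diff:

```cpp
#include <tuple>
#include "ck_tile/host.hpp"

// Sketch: how the new "save_mask" option is consumed. The parse()/main()
// wiring is an assumption based on the common ck_tile example pattern;
// insert() and get_int() usage mirrors the hunks above.
auto create_args(int argc, char* argv[])
{
    ck_tile::ArgParser arg_parser;
    arg_parser.insert(
        "save_mask", "1", "save the mask tensor to disk from the CPU validation code");
    bool result = arg_parser.parse(argc, argv);
    return std::make_tuple(result, arg_parser);
}

int main(int argc, char* argv[])
{
    auto [result, arg_parser] = create_args(argc, argv);
    if(!result)
        return -1;

    // e.g. passing -save_mask=0 on the command line disables the dump
    const bool save_mask = static_cast<bool>(arg_parser.get_int("save_mask"));
    return save_mask ? 0 : 1;
}
```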
@@ -260,7 +263,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
     else
     {
         assert(1 == seq_lengths.size());
-        seqlen = seq_lengths[0];
+        seqlen     = seq_lengths[0];
+        max_seqlen = seqlen;
 
         if(!num_targets.empty())
         {
@@ -310,6 +314,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
     ck_tile::HostTensor<InOutDataType> o_host_ref(
         std::array<ck_tile::index_t, 4>{batches_for_alloc, seqlen, num_head, hdim_v});
+
+    ck_tile::HostTensor<int8_t> mask_host(
+        save_mask ? std::array<ck_tile::index_t, 4>{num_batch, num_head, max_seqlen, max_seqlen}
+                  : std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
 
     ck_tile::FillNormalDistribution<InOutDataType>{0.f, 1.f, seed}(q_host);
     ck_tile::FillNormalDistribution<InOutDataType>{0.f, 1.f, seed}(k_host);
     ck_tile::FillNormalDistribution<InOutDataType>{0.f, 1.f, seed}(v_host);
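Note the {1, 1, 1, 1} fallback: the reference implementation always receives a mask tensor by reference, so a one-element placeholder is allocated when saving is disabled. The full mask grows quadratically with max_seqlen, which is why the allocation is gated; a back-of-the-envelope sketch with made-up sizes illustrates the scale:

```cpp
#include <cstddef>
#include <cstdint>
#include <cstdio>

// Sketch: the int8 mask buffer scales with num_batch * num_head * max_seqlen^2.
// All sizes below are made up purely for illustration.
int main()
{
    std::size_t num_batch = 8, num_head = 4, max_seqlen = 4096;
    std::size_t bytes = num_batch * num_head * max_seqlen * max_seqlen * sizeof(int8_t);
    std::printf("mask buffer: %zu MiB\n", bytes >> 20); // 512 MiB for these sizes
}
```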
@@ -447,6 +455,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
         k_host,
         v_host,
         o_host_ref,
+        mask_host,
         num_batch,
         1.0f / std::sqrt(params.hdim_qk),
         is_jagged ? max_seqlen : seqlen,
@@ -465,6 +474,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
     // dumpBufferToFile("output_dev.dat", o_host.data(), o_host.get_element_space_size());
     // dumpBufferToFile("output_host.dat", o_host_ref.data(), o_host.get_element_space_size());
 
+    if(save_mask)
+        dumpBufferToFile("ck_hstu_mask.dat", mask_host.data(), mask_host.get_element_space_size());
+
     auto [rtol, atol] = get_elimit<InOutDataType>();
 
     res = ck_tile::check_err(
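Because dumpBufferToFile writes the raw int8 bytes, ck_hstu_mask.dat can be reloaded and re-indexed offline. A sketch follows, assuming the packed row-major [num_batch][num_head][max_seqlen][max_seqlen] layout used for the allocation above; the sizes are placeholders and must match the run that produced the file:

```cpp
#include <cstdint>
#include <cstdio>
#include <fstream>
#include <vector>

// Sketch: reload ck_hstu_mask.dat. Sizes are placeholders; they must
// match the producing run, and a packed row-major layout is assumed.
int main()
{
    const std::size_t num_batch = 2, num_head = 4, max_seqlen = 64;
    std::vector<int8_t> mask(num_batch * num_head * max_seqlen * max_seqlen);

    std::ifstream inFile("ck_hstu_mask.dat", std::ios::binary);
    if(!inFile.read(reinterpret_cast<char*>(mask.data()), mask.size())) // int8 => 1 byte/elem
        return 1;

    // Row-major [batch][head][sq][sk] indexing, matching the assumed layout.
    auto at = [&](std::size_t b, std::size_t h, std::size_t sq, std::size_t sk) {
        return mask[((b * num_head + h) * max_seqlen + sq) * max_seqlen + sk];
    };
    std::printf("mask(0,0,1,0) = %d\n", static_cast<int>(at(0, 0, 1, 0)));
}
```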
@@ -40,6 +40,7 @@ struct reference_hstu_attention
                    const HostTensor<InOutDataType>& k_batch_seq_nhead_hdim,
                    const HostTensor<InOutDataType>& v_batch_seq_nhead_hdim,
                    HostTensor<InOutDataType>& o_batch_seq_nhead_hdim,
+                   HostTensor<int8_t>& mask_batch_nhead_seq_seq,
                    int num_batch,
                    float alpha,
                    int max_seqlen,
@@ -87,6 +88,14 @@ struct reference_hstu_attention
         assert(hdim_qk == k_batch_seq_nhead_hdim.get_lengths()[3]);
         assert(hdim_v == o_batch_seq_nhead_hdim.get_lengths()[3]);
 
+        bool save_mask = false;
+
+        if(static_cast<int>(mask_batch_nhead_seq_seq.get_lengths()[0]) == num_batch &&
+           static_cast<int>(mask_batch_nhead_seq_seq.get_lengths()[1]) == num_head &&
+           static_cast<int>(mask_batch_nhead_seq_seq.get_lengths()[2]) == max_seqlen &&
+           static_cast<int>(mask_batch_nhead_seq_seq.get_lengths()[3]) == max_seqlen)
+            save_mask = true;
+
         // check num_targets
         assert(num_targets.empty() || num_targets.size() == num_batch);
 
@@ -111,6 +120,15 @@ struct reference_hstu_attention
                 seqlen, contextual_seqlen, num_target);
         }();
 
+        if(save_mask)
+        {
+            // initialize the mask
+            for(int sq = 0; sq < max_seqlen; sq++)
+                for(int sk = 0; sk < max_seqlen; sk++)
+                    mask_batch_nhead_seq_seq(i_batch, i_head, sq, sk) =
+                        static_cast<int8_t>(mask.IsTokenPairInsideMask(sq, sk));
+        }
+
         // for all rows in the batch
         for(int sq = 0; sq < seqlen; sq++)
         {
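The predicate mask.IsTokenPairInsideMask(sq, sk) comes from the HSTU mask object built earlier in the reference; its exact logic is not part of this diff. Purely as an illustration of what such a predicate decides, here is a hypothetical, simplified causal stand-in:

```cpp
#include <cstdint>

// Hypothetical, simplified stand-in for an IsTokenPairInsideMask-style
// predicate: plain causal attention (query sq may attend to keys sk <= sq).
// The real HSTU mask additionally handles contextual and target regions.
struct SimpleCausalMask
{
    bool IsTokenPairInsideMask(int sq, int sk) const { return sk <= sq; }
};

// The reference loop above then stores 1 for attended pairs and 0 otherwise:
//     mask(i_batch, i_head, sq, sk) = static_cast<int8_t>(m.IsTokenPairInsideMask(sq, sk));
```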