Add init_qkv and dump_output example parameters for easier debugging

This commit is contained in:
Qianfeng Zhang
2025-05-28 15:33:54 +00:00
parent 10c35125d2
commit 68a5ab8ff8
2 changed files with 47 additions and 7 deletions

View File

@@ -52,8 +52,10 @@
.insert("local_len", "5", "length of the diagonal window for enabling masking, value 0 to disable")
.insert("context_len", "6", "sequence length at the begin of the query sequence the should be included for attention")
.insert("minfull_len", "6", "sequence length at the end of the query sequence that should be included for attention")
.insert("init_qkv", "0", "initialize q, k, v tensor from local files q.dat, k.dat and v.data")
.insert("seed", "13579", "seed by the uniform or normal distribution generator")
.insert("save_mask", "1", "save the mask tensor to disk by the CPU validation codes")
.insert("perf", "0", "weather measure execution time or not");
.insert("dump_output", "0", "dump both device and reference hstu attention outputs to files, only used when validation is true");
```

View File

@@ -48,6 +48,28 @@ void dumpBufferToFile(const char* fileName, T* data, size_t dataNumItems)
}
}
template <typename T>
void readDataToBufferFromFile(T* data, size_t dataNumItems, const std::string& fileName)
{
std::ifstream infile(fileName, std::ios::binary);
if(infile)
{
try
{
infile.read(reinterpret_cast<char*>(data), dataNumItems * sizeof(T));
infile.close();
}
catch(const std::runtime_error& e)
{
throw e;
};
}
else
{
throw std::runtime_error("could not open the file for reading");
}
}
template <typename T>
std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
{
@@ -86,8 +108,10 @@ auto create_args(int argc, char* argv[])
.insert("context_len", "6", "sequence length at the begin of the query sequence the should be included for attention")
.insert("minfull_len", "6", "sequence length at the end of the query sequence that should be included for attention")
.insert("seed", "13579", "seed by the uniform or normal distribution generator")
.insert("init_qkv", "0", "initialize q, k, v tensor from local files q.dat, k.dat and v.data")
.insert("save_mask", "0", "save the mask tensor to disk by the CPU validation codes")
.insert("perf", "0", "weather measure execution time or not");
.insert("perf", "0", "weather measure execution time or not")
.insert("dump_output", "0", "dump both device and reference hstu attention outputs to files, only used when validation is true");
// clang-format on
bool result = arg_parser.parse(argc, argv);
@@ -199,8 +223,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
int seed = arg_parser.get_int("seed");
bool measure_perf = static_cast<bool>(arg_parser.get_int("perf"));
bool dump_output = static_cast<bool>(arg_parser.get_int("dump_output"));
bool save_mask = static_cast<bool>(arg_parser.get_int("save_mask"));
bool save_mask = static_cast<bool>(arg_parser.get_int("save_mask"));
bool initialize_qkv = static_cast<bool>(arg_parser.get_int("init_qkv"));
std::string str_of_targets = arg_parser.get_str("targets");
std::vector<int> num_targets = get_integers_from_string(str_of_targets);
@@ -333,9 +359,18 @@ bool run(const ck_tile::ArgParser& arg_parser)
save_mask ? std::array<ck_tile::index_t, 4>{num_batch, num_head, max_seqlen, max_seqlen}
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
ck_tile::FillNormalDistribution<InOutDataType>{0.f, 1.f, seed}(q_host);
ck_tile::FillNormalDistribution<InOutDataType>{0.f, 1.f, seed}(k_host);
ck_tile::FillNormalDistribution<InOutDataType>{0.f, 1.f, seed}(v_host);
if(!initialize_qkv)
{
ck_tile::FillNormalDistribution<InOutDataType>{0.f, 1.f, seed}(q_host);
ck_tile::FillNormalDistribution<InOutDataType>{0.f, 1.f, seed}(k_host);
ck_tile::FillNormalDistribution<InOutDataType>{0.f, 1.f, seed}(v_host);
}
else
{
readDataToBufferFromFile(q_host.data(), q_host.get_element_space_size(), "q.dat");
readDataToBufferFromFile(k_host.data(), k_host.get_element_space_size(), "k.dat");
readDataToBufferFromFile(v_host.data(), v_host.get_element_space_size(), "v.dat");
};
ck_tile::DeviceMem q_dev(q_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem k_dev(k_host.get_element_space_size_in_bytes());
@@ -486,8 +521,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
o_dev.FromDevice(o_host.data());
// dumpBufferToFile("output_dev.dat", o_host.data(), o_host.get_element_space_size());
// dumpBufferToFile("output_host.dat", o_host_ref.data(), o_host.get_element_space_size());
if(dump_output)
{
dumpBufferToFile("output_dev.dat", o_host.data(), o_host.get_element_space_size());
dumpBufferToFile("output_host.dat", o_host_ref.data(), o_host.get_element_space_size());
}
if(save_mask)
dumpBufferToFile(