mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-18 12:00:07 +00:00
Add init_qkv and dump_output example parameters for easier debugging
This commit is contained in:
@@ -52,8 +52,10 @@
|
||||
.insert("local_len", "5", "length of the diagonal window for enabling masking, value 0 to disable")
|
||||
.insert("context_len", "6", "sequence length at the begin of the query sequence the should be included for attention")
|
||||
.insert("minfull_len", "6", "sequence length at the end of the query sequence that should be included for attention")
|
||||
.insert("init_qkv", "0", "initialize q, k, v tensor from local files q.dat, k.dat and v.data")
|
||||
.insert("seed", "13579", "seed by the uniform or normal distribution generator")
|
||||
.insert("save_mask", "1", "save the mask tensor to disk by the CPU validation codes")
|
||||
.insert("perf", "0", "weather measure execution time or not");
|
||||
.insert("dump_output", "0", "dump both device and reference hstu attention outputs to files, only used when validation is true");
|
||||
```
|
||||
|
||||
|
||||
@@ -48,6 +48,28 @@ void dumpBufferToFile(const char* fileName, T* data, size_t dataNumItems)
|
||||
}
|
||||
}
|
||||
|
||||
/// Reads exactly `dataNumItems` elements of type T from a binary file into `data`.
///
/// @param data          destination buffer; the caller must provide room for at least
///                      `dataNumItems` elements
/// @param dataNumItems  number of elements of type T to read
/// @param fileName      path of the file to read; it is opened in binary mode
/// @throws std::runtime_error if the file cannot be opened, or if fewer than
///         `dataNumItems * sizeof(T)` bytes could be read from it
template <typename T>
void readDataToBufferFromFile(T* data, size_t dataNumItems, const std::string& fileName)
{
    std::ifstream infile(fileName, std::ios::binary);

    if(!infile)
    {
        throw std::runtime_error("could not open the file for reading");
    }

    const auto numBytes = static_cast<std::streamsize>(dataNumItems * sizeof(T));

    infile.read(reinterpret_cast<char*>(data), numBytes);

    // read() reports failure through the stream state rather than by throwing, so a
    // truncated file must be detected explicitly; otherwise the tail of the buffer
    // would silently keep whatever it held before the call.
    if(infile.gcount() != numBytes)
    {
        throw std::runtime_error("could not read the expected amount of data from the file");
    }

    // The stream is closed by its destructor when it goes out of scope.
}
|
||||
|
||||
template <typename T>
|
||||
std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
|
||||
{
|
||||
@@ -86,8 +108,10 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("context_len", "6", "sequence length at the begin of the query sequence the should be included for attention")
|
||||
.insert("minfull_len", "6", "sequence length at the end of the query sequence that should be included for attention")
|
||||
.insert("seed", "13579", "seed by the uniform or normal distribution generator")
|
||||
.insert("init_qkv", "0", "initialize q, k, v tensor from local files q.dat, k.dat and v.data")
|
||||
.insert("save_mask", "0", "save the mask tensor to disk by the CPU validation codes")
|
||||
.insert("perf", "0", "weather measure execution time or not");
|
||||
.insert("perf", "0", "weather measure execution time or not")
|
||||
.insert("dump_output", "0", "dump both device and reference hstu attention outputs to files, only used when validation is true");
|
||||
// clang-format on
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
@@ -199,8 +223,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
int seed = arg_parser.get_int("seed");
|
||||
bool measure_perf = static_cast<bool>(arg_parser.get_int("perf"));
|
||||
bool dump_output = static_cast<bool>(arg_parser.get_int("dump_output"));
|
||||
|
||||
bool save_mask = static_cast<bool>(arg_parser.get_int("save_mask"));
|
||||
bool save_mask = static_cast<bool>(arg_parser.get_int("save_mask"));
|
||||
bool initialize_qkv = static_cast<bool>(arg_parser.get_int("init_qkv"));
|
||||
|
||||
std::string str_of_targets = arg_parser.get_str("targets");
|
||||
std::vector<int> num_targets = get_integers_from_string(str_of_targets);
|
||||
@@ -333,9 +359,18 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
save_mask ? std::array<ck_tile::index_t, 4>{num_batch, num_head, max_seqlen, max_seqlen}
|
||||
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
|
||||
|
||||
ck_tile::FillNormalDistribution<InOutDataType>{0.f, 1.f, seed}(q_host);
|
||||
ck_tile::FillNormalDistribution<InOutDataType>{0.f, 1.f, seed}(k_host);
|
||||
ck_tile::FillNormalDistribution<InOutDataType>{0.f, 1.f, seed}(v_host);
|
||||
if(!initialize_qkv)
|
||||
{
|
||||
ck_tile::FillNormalDistribution<InOutDataType>{0.f, 1.f, seed}(q_host);
|
||||
ck_tile::FillNormalDistribution<InOutDataType>{0.f, 1.f, seed}(k_host);
|
||||
ck_tile::FillNormalDistribution<InOutDataType>{0.f, 1.f, seed}(v_host);
|
||||
}
|
||||
else
|
||||
{
|
||||
readDataToBufferFromFile(q_host.data(), q_host.get_element_space_size(), "q.dat");
|
||||
readDataToBufferFromFile(k_host.data(), k_host.get_element_space_size(), "k.dat");
|
||||
readDataToBufferFromFile(v_host.data(), v_host.get_element_space_size(), "v.dat");
|
||||
};
|
||||
|
||||
ck_tile::DeviceMem q_dev(q_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem k_dev(k_host.get_element_space_size_in_bytes());
|
||||
@@ -486,8 +521,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
o_dev.FromDevice(o_host.data());
|
||||
|
||||
// dumpBufferToFile("output_dev.dat", o_host.data(), o_host.get_element_space_size());
|
||||
// dumpBufferToFile("output_host.dat", o_host_ref.data(), o_host.get_element_space_size());
|
||||
if(dump_output)
|
||||
{
|
||||
dumpBufferToFile("output_dev.dat", o_host.data(), o_host.get_element_space_size());
|
||||
dumpBufferToFile("output_host.dat", o_host_ref.data(), o_host.get_element_space_size());
|
||||
}
|
||||
|
||||
if(save_mask)
|
||||
dumpBufferToFile(
|
||||
|
||||
Reference in New Issue
Block a user