mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 12:17:00 +00:00
Various fixes
This commit is contained in:
@@ -343,8 +343,8 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
|
||||
};
|
||||
calculate_cumulative(eff_query_lens, cu_query_lens);
|
||||
|
||||
ck_tile::DeviceMem seq_lens_buf(eff_kv_lens.size());
|
||||
ck_tile::DeviceMem query_start_len_buf(cu_query_lens.size());
|
||||
ck_tile::DeviceMem seq_lens_buf(eff_kv_lens.size() * sizeof(ck_tile::index_t));
|
||||
ck_tile::DeviceMem query_start_len_buf(cu_query_lens.size() * sizeof(ck_tile::index_t));
|
||||
|
||||
seq_lens_buf.ToDevice(eff_kv_lens.data());
|
||||
query_start_len_buf.ToDevice(cu_query_lens.data());
|
||||
@@ -525,31 +525,40 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
|
||||
ck_tile::HostTensor<DataType> o(problem.get_output_shape());
|
||||
o_buf.FromDevice(o.data());
|
||||
|
||||
// const auto [rtol, atol] = [&] {
|
||||
// if constexpr(std::is_same_v<DataType, ck_tile::fp16_t>)
|
||||
// return std::make_tuple(1e-3, 1e-3);
|
||||
// else
|
||||
// return std::make_tuple(1e-2, 1e-2);
|
||||
// }();
|
||||
const auto [rtol, atol] = [&] {
|
||||
if constexpr(std::is_same_v<DataType, ck_tile::fp16_t>)
|
||||
return std::make_tuple(1e-3, 1e-3);
|
||||
else
|
||||
return std::make_tuple(1e-2, 1e-2);
|
||||
}();
|
||||
|
||||
// Print some of the output data for debugging
|
||||
std::cout << "\nFirst few elements of output tensor o:" << std::endl;
|
||||
for(int b = 0; b < std::min(2, static_cast<int>(problem.batch)); ++b) {
|
||||
std::cout << "Batch " << b << ":" << std::endl;
|
||||
for(int s = 0; s < std::min(5, static_cast<int>(eff_query_lens[b])); ++s) {
|
||||
for(int h = 0; h < std::min(2, static_cast<int>(problem.nhead_q)); ++h) {
|
||||
for(int d = 0; d < std::min(4, static_cast<int>(problem.hdim)); ++d) {
|
||||
std::cout << "o[" << b << "][" << s << "][" << h << "][" << d << "] = "
|
||||
<< static_cast<float>(o(b, s, h, d))
|
||||
<< std::endl;
|
||||
size_t total = static_cast<size_t>(problem.num_tokens) *
|
||||
static_cast<size_t>(problem.nhead_q) *
|
||||
static_cast<size_t>(problem.hdim);
|
||||
|
||||
size_t nonzero = 0;
|
||||
|
||||
for (int b = 0; b < problem.batch; ++b) {
|
||||
for (int s = 0; s < eff_query_lens[b]; ++s) {
|
||||
for (int h = 0; h < problem.nhead_q; ++h) {
|
||||
for (int d = 0; d < problem.hdim; ++d) {
|
||||
if (static_cast<float>(o(b, s, h, d)) != 0.0f) {
|
||||
nonzero++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
return 1; // ck_tile::check_err(o, o_ref, std::string("found incorrect results!"), rtol, atol);
|
||||
|
||||
float percent = (total > 0)
|
||||
? (100.0f * static_cast<float>(nonzero) / static_cast<float>(total))
|
||||
: 0.0f;
|
||||
|
||||
std::cout << "\nNon-zero elements in output tensor o: "
|
||||
<< nonzero << " / " << total
|
||||
<< " (" << percent << "%)\n";
|
||||
|
||||
return ck_tile::check_err(o, o_ref, std::string("found incorrect results!"), rtol, atol);
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
|
||||
@@ -124,7 +124,7 @@ float unified_attention_kernel_launch(const unified_attention_args& args,
|
||||
"argument num_queries_per_kv must equal compiled num_queries_per_kv");
|
||||
assert(args.BLOCK_SIZE == Kernel::BLOCK_SIZE &&
|
||||
"argument BLOCK_SIZE must equal compiled BLOCK_SIZE");
|
||||
assert(BLOCK_Q == args.num_head_q / args.num_queries_per_kv &&
|
||||
assert(BLOCK_Q == BLOCK_M / args.num_queries_per_kv &&
|
||||
"BLOCK_Q must equal BLOCK_M / num_queries_per_kv");
|
||||
index_t total_num_q_blocks = args.num_tokens / BLOCK_Q + args.num_seqs;
|
||||
auto kargs = Kernel::MakeKargs(args.q_ptr,
|
||||
|
||||
Reference in New Issue
Block a user