Merge branch 'tianxing/unified-attention' of https://github.com/ROCm/composable_kernel into tianxing/unified-attention

This commit is contained in:
Tianxing Wu
2025-11-17 10:06:05 +00:00

View File

@@ -72,6 +72,13 @@ auto parse_cmd_args(int argc, char* argv[]) -> std::pair<bool, ck_tile::ArgParse
return std::make_pair(result, arg_parser);
}
struct FmhaMasks
{
using NoMask = ck_tile::GenericAttentionMask<false>;
using GenericMask = ck_tile::GenericAttentionMask<true, true>;
using CausalMask = ck_tile::GenericAttentionMask<true, false>;
};
struct Problem
{
explicit Problem(const ck_tile::ArgParser& args)
@@ -176,105 +183,105 @@ auto generate_qkv(const Problem& problem,
return std::make_tuple(q, k, v);
}
// namespace host {
// template <typename AccDataType,
// typename PDataType,
// typename QDataType,
// typename KDataType,
// typename VDataType,
// typename ODataType,
// typename QElementOp,
// typename KElementOp,
// typename VElementOp,
// typename SAccElementOp>
// CK_TILE_HOST void fmha_fwd(const ck_tile::HostTensor<QDataType>& q_bshd,
// const ck_tile::HostTensor<KDataType>& k_bshd,
// const ck_tile::HostTensor<VDataType>& v_bshd,
// const mask_info& mask,
// ck_tile::HostTensor<ODataType>& o_bshd,
// const QElementOp& q_element_op = {},
// const KElementOp& k_element_op = {},
// const VElementOp& v_element_op = {},
// const SAccElementOp& s_acc_element_op = {})
// {
// const int batch_size = q_bshd.mDesc.get_lengths()[0];
// const int seqlen_q = q_bshd.mDesc.get_lengths()[1];
// const int seqlen_kv = k_bshd.mDesc.get_lengths()[1];
// const int nhead_q = q_bshd.mDesc.get_lengths()[2];
// const int nhead_kv = k_bshd.mDesc.get_lengths()[2];
// const int hdim_qk = q_bshd.mDesc.get_lengths()[3];
// const int hdim_v = v_bshd.mDesc.get_lengths()[3];
namespace host {
template <typename AccDataType,
typename PDataType,
typename QDataType,
typename KDataType,
typename VDataType,
typename ODataType,
typename QElementOp,
typename KElementOp,
typename VElementOp,
typename SAccElementOp>
CK_TILE_HOST void fmha_fwd(const ck_tile::HostTensor<QDataType>& q_bshd,
const ck_tile::HostTensor<KDataType>& k_bshd,
const ck_tile::HostTensor<VDataType>& v_bshd,
const mask_info& mask,
ck_tile::HostTensor<ODataType>& o_bshd,
const QElementOp& q_element_op = {},
const KElementOp& k_element_op = {},
const VElementOp& v_element_op = {},
const SAccElementOp& s_acc_element_op = {})
{
const int batch_size = q_bshd.mDesc.get_lengths()[0];
const int seqlen_q = q_bshd.mDesc.get_lengths()[1];
const int seqlen_kv = k_bshd.mDesc.get_lengths()[1];
const int nhead_q = q_bshd.mDesc.get_lengths()[2];
const int nhead_kv = k_bshd.mDesc.get_lengths()[2];
const int hdim_qk = q_bshd.mDesc.get_lengths()[3];
const int hdim_v = v_bshd.mDesc.get_lengths()[3];
// const int nr = nhead_q / nhead_kv;
const int nr = nhead_q / nhead_kv;
// ck_tile::HostTensor<QDataType> q_host_ref({nhead_q, seqlen_q, hdim_qk});
// ck_tile::HostTensor<KDataType> k_host_ref({nhead_q, seqlen_kv, hdim_qk});
// ck_tile::HostTensor<VDataType> v_host_ref({nhead_q, hdim_v, seqlen_kv});
// ck_tile::HostTensor<ODataType> o_host_ref({nhead_q, seqlen_q, hdim_v});
ck_tile::HostTensor<QDataType> q_host_ref({nhead_q, seqlen_q, hdim_qk});
ck_tile::HostTensor<KDataType> k_host_ref({nhead_q, seqlen_kv, hdim_qk});
ck_tile::HostTensor<VDataType> v_host_ref({nhead_q, hdim_v, seqlen_kv});
ck_tile::HostTensor<ODataType> o_host_ref({nhead_q, seqlen_q, hdim_v});
// ck_tile::HostTensor<AccDataType> s_host_ref({nhead_q, seqlen_q, seqlen_kv});
// ck_tile::HostTensor<PDataType> p_host_ref({nhead_q, seqlen_q, seqlen_kv});
ck_tile::HostTensor<AccDataType> s_host_ref({nhead_q, seqlen_q, seqlen_kv});
ck_tile::HostTensor<PDataType> p_host_ref({nhead_q, seqlen_q, seqlen_kv});
// // do computation for each batch
// for(int b = 0; b < batch_size; ++b)
// {
// // copy per-batch data from input tensors
// // clang-format off
// q_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = q_bshd(b, idx[1], idx[0] ,
// idx[2]); }); k_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = k_bshd(b, idx[1],
// idx[0] / nr, idx[2]); }); v_host_ref.ForEach([&](auto& self, auto idx) { self(idx) =
// v_bshd(b, idx[2], idx[0] / nr, idx[1]); });
// // clang-format on
// ck_tile::reference_batched_gemm<QDataType, KDataType, AccDataType>(
// q_host_ref, k_host_ref, s_host_ref, q_element_op, k_element_op, s_acc_element_op);
// do computation for each batch
for(int b = 0; b < batch_size; ++b)
{
// copy per-batch data from input tensors
// clang-format off
q_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = q_bshd(b, idx[1], idx[0] ,
idx[2]); }); k_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = k_bshd(b, idx[1],
idx[0] / nr, idx[2]); }); v_host_ref.ForEach([&](auto& self, auto idx) { self(idx) =
v_bshd(b, idx[2], idx[0] / nr, idx[1]); });
// clang-format on
ck_tile::reference_batched_gemm<QDataType, KDataType, AccDataType>(
q_host_ref, k_host_ref, s_host_ref, q_element_op, k_element_op, s_acc_element_op);
// if(mask.type == mask_enum::no_mask)
// {
// ck_tile::reference_batched_masking(s_host_ref, FmhaMasks::NoMask{seqlen_q, seqlen_kv});
// }
// else if(mask.type == mask_enum::window_generic)
// {
// ck_tile::reference_batched_masking(
// s_host_ref,
// ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::GenericMask>(
// mask.left, mask.right, seqlen_q, seqlen_kv));
// }
// else
// {
// // if left window size is negative, means causal
// // else means generic (for current batch)
// if(mask.left < 0)
// ck_tile::reference_batched_masking(
// s_host_ref,
// ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::CausalMask>(
// mask.left,
// mask.right,
// seqlen_q,
// seqlen_kv,
// mask.type == mask_enum::mask_top_left));
// else
// ck_tile::reference_batched_masking(
// s_host_ref,
// ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::GenericMask>(
// mask.left,
// mask.right,
// seqlen_q,
// seqlen_kv,
// mask.type == mask_enum::mask_top_left));
// }
if(mask.type == mask_enum::no_mask)
{
ck_tile::reference_batched_masking(s_host_ref, FmhaMasks::NoMask{seqlen_q, seqlen_kv});
}
else if(mask.type == mask_enum::window_generic)
{
ck_tile::reference_batched_masking(
s_host_ref,
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::GenericMask>(
mask.left, mask.right, seqlen_q, seqlen_kv));
}
else
{
// if left window size is negative, means causal
// else means generic (for current batch)
if(mask.left < 0)
ck_tile::reference_batched_masking(
s_host_ref,
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::CausalMask>(
mask.left,
mask.right,
seqlen_q,
seqlen_kv,
mask.type == mask_enum::mask_top_left));
else
ck_tile::reference_batched_masking(
s_host_ref,
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::GenericMask>(
mask.left,
mask.right,
seqlen_q,
seqlen_kv,
mask.type == mask_enum::mask_top_left));
}
// ck_tile::reference_batched_softmax<AccDataType, AccDataType>(
// s_host_ref, p_host_ref, ck_tile::identity{});
ck_tile::reference_batched_softmax<AccDataType, AccDataType>(
s_host_ref, p_host_ref, ck_tile::identity{});
// ck_tile::reference_batched_gemm<PDataType, VDataType, AccDataType>(
// p_host_ref, v_host_ref, o_host_ref, ck_tile::identity{}, v_element_op);
ck_tile::reference_batched_gemm<PDataType, VDataType, AccDataType>(
p_host_ref, v_host_ref, o_host_ref, ck_tile::identity{}, v_element_op);
// // copy resulting per-batch data to the output tensor
// o_host_ref.ForEach(
// [&](auto& self, auto idx) { o_bshd(b, idx[1], idx[0], idx[2]) = self(idx); });
// }
// }
// } // namespace host
// copy resulting per-batch data to the output tensor
o_host_ref.ForEach(
[&](auto& self, auto idx) { o_bshd(b, idx[1], idx[0], idx[2]) = self(idx); });
}
}
} // namespace host
template <typename DataType>
bool run_impl(const Problem& problem, const RunConfig& run_config)
@@ -425,111 +432,108 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
return false;
}
// std::size_t flop = [&] {
// if(problem.mask.type == mask_enum::no_mask)
// {
// return 4 * args.num_tokens * problem.nhead_q *
// problem.hdim;
// }
// else
// {
// /// FIXME: Use a more accurate method; for now, were just dividing the flop by 2.
// return 2 * args.num_tokens * problem.nhead_q *
// problem.hdim;
// }
// }();
std::size_t flop = [&] {
if(problem.mask.type == mask_enum::no_mask)
{
return 4 * args.num_tokens * problem.nhead_q * problem.hdim;
}
else
{
/// FIXME: Use a more accurate method; for now, were just dividing the flop by 2.
return 2 * args.num_tokens * problem.nhead_q * problem.hdim;
}
}();
// TODO fix this
// std::size_t flop = 1;
// float tflops = static_cast<float>(flop) / 1.e9 / time;
float tflops = static_cast<float>(flop) / 1.e9 / time;
// std::cout << "[" << problem.data_type << "|";
// std::cout << "] b:" << problem.batch << ", h:" << problem.nhead_q << "/" << problem.nhead_kv
// << ", s:" << problem.seqlen_q << "/" << problem.seqlen_k << ", d:" << problem.hdim
// << ", scale_s:" << problem.sacle_s << ", mask:" << problem.mask << std::fixed
// << ", " << std::setprecision(3) << time << " ms, " << std::setprecision(2) <<
// tflops
// << " TFlops" << std::endl;
std::cout << "[" << problem.data_type << "|";
std::cout << "] b:" << problem.batch << ", h:" << problem.nhead_q << "/" << problem.nhead_kv
<< ", d:" << problem.hdim << ", mask:" << problem.mask << std::fixed << ", "
<< std::setprecision(3) << time << " ms, " << std::setprecision(2) << tflops
<< " TFlops" << std::endl;
// if(!run_config.verify)
// {
// return true;
// }
// transpose tensor descriptors from bhsd to bshd if necessary
// if(problem.input_layout != TensorLayout::bshd)
// {
// q = q.transpose({0, 2, 1, 3});
// k = k.transpose({0, 2, 1, 3});
// v = v.transpose({0, 2, 1, 3});
// }
// ck_tile::HostTensor<DataType> o_ref(problem.get_output_shape());
// if(problem.output_layout != TensorLayout::bshd)
// {
// o_ref = o_ref.transpose({0, 2, 1, 3});
// }
// If variable lengths are provided, compute per-batch references
// variable lengths are provided -> compute per-batch references
// with the effective lengths; else compute a single full reference.
// Variable-length aware verification: zero-fill padded region and only compute valid part.
// o_ref.SetZero();
ck_tile::HostTensor<DataType> o_ref(problem.get_output_shape());
o_ref.SetZero();
// for(int b = 0; b < problem.batch; ++b)
// {
// const ck_tile::index_t seqlen_q_eff = eff_q_vec[b];
// const ck_tile::index_t seqlen_kv_eff = eff_kv_vec[b];
for(int b = 0; b < problem.batch; ++b)
{
const ck_tile::index_t seqlen_q_eff = eff_query_lens[b];
const ck_tile::index_t seqlen_kv_eff = eff_kv_lens[b];
// if(seqlen_q_eff <= 0 || seqlen_kv_eff <= 0)
// continue;
if(seqlen_q_eff <= 0 || seqlen_kv_eff <= 0)
continue;
// // Slice current batch from inputs (bshd) and build single-batch tensors
// ck_tile::HostTensor<DataType> q_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim});
// ck_tile::HostTensor<DataType> k_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim});
// ck_tile::HostTensor<DataType> v_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim});
// ck_tile::HostTensor<DataType> o_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim});
// Slice current batch from inputs (bshd) and build single-batch tensors
ck_tile::HostTensor<DataType> q_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim});
ck_tile::HostTensor<DataType> k_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim});
ck_tile::HostTensor<DataType> v_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim});
ck_tile::HostTensor<DataType> o_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim});
// // Copy effective region
// q_b.ForEach([&](auto& self, auto idx) {
// // idx: [0, s, h, d]
// self(idx) = q(b, idx[1], idx[2], idx[3]);
// });
// k_b.ForEach([&](auto& self, auto idx) { self(idx) = k(b, idx[1], idx[2], idx[3]); });
// v_b.ForEach([&](auto& self, auto idx) { self(idx) = v(b, idx[1], idx[2], idx[3]); });
// Copy effective region
q_b.ForEach([&](auto& self, auto idx) {
// idx: [0, s, h, d]
self(idx) = q(b, idx[1], idx[2], idx[3]);
});
k_b.ForEach([&](auto& self, auto idx) {
// kv cache is paged
ck_tile::index_t table_col = int(idx[1] / problem.BLOCK_SIZE);
ck_tile::index_t block_table_offset = b * max_num_blocks_per_seq + table_col;
ck_tile::index_t block_idx = block_tables_host[block_table_offset];
// // Compute reference for this batch segment (host::fmha_fwd expects bshd tensors)
// host::fmha_fwd<float, DataType>(q_b,
// k_b,
// v_b,
// problem.mask,
// o_b,
// ck_tile::identity{},
// ck_tile::identity{},
// ck_tile::identity{},
// ck_tile::scales{problem.scale_s});
self(idx) = k(block_idx, idx[1] % problem.BLOCK_SIZE, idx[2], idx[3]);
});
v_b.ForEach([&](auto& self, auto idx) {
ck_tile::index_t table_col = int(idx[1] / problem.BLOCK_SIZE);
ck_tile::index_t block_table_offset = b * max_num_blocks_per_seq + table_col;
ck_tile::index_t block_idx = block_tables_host[block_table_offset];
// // Scatter into o_ref's bshd descriptor memory
// for(int s = 0; s < seqlen_q_eff; ++s)
// {
// for(int h = 0; h < problem.nhead_q; ++h)
// {
// for(int d = 0; d < problem.hdim; ++d)
// {
// o_ref(b, s, h, d) = o_b(0, s, h, d);
// }
// }
// }
// }
self(idx) = v(block_idx, idx[1] % problem.BLOCK_SIZE, idx[2], idx[3]);
});
// v_b.ForEach([&](auto& self, auto idx) { self(idx) = v(b, idx[1], idx[2], idx[3]); });
// ck_tile::HostTensor<DataType> o(problem.get_output_shape());
// o_buf.FromDevice(o.data());
// Compute reference for this batch segment (host::fmha_fwd expects bshd tensors)
host::fmha_fwd<float, DataType>(q_b,
k_b,
v_b,
problem.mask,
o_b,
ck_tile::identity{},
ck_tile::identity{},
ck_tile::identity{},
ck_tile::scales{problem.scale_s});
// const auto [rtol, atol] = [&] {
// if constexpr(std::is_same_v<DataType, ck_tile::fp16_t>)
// return std::make_tuple(1e-3, 1e-3);
// else
// return std::make_tuple(1e-2, 1e-2);
// }();
// return ck_tile::check_err(o, o_ref, std::string("found incorrect results!"), rtol, atol);
// Scatter into o_ref's bshd descriptor memory
for(int s = 0; s < seqlen_q_eff; ++s)
{
for(int h = 0; h < problem.nhead_q; ++h)
{
for(int d = 0; d < problem.hdim; ++d)
{
o_ref(b, s, h, d) = o_b(0, s, h, d);
}
}
}
}
ck_tile::HostTensor<DataType> o(problem.get_output_shape());
o_buf.FromDevice(o.data());
const auto [rtol, atol] = [&] {
if constexpr(std::is_same_v<DataType, ck_tile::fp16_t>)
return std::make_tuple(1e-3, 1e-3);
else
return std::make_tuple(1e-2, 1e-2);
}();
return ck_tile::check_err(o, o_ref, std::string("found incorrect results!"), rtol, atol);
return true;
}