mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 03:37:38 +00:00
Merge branch 'tianxing/unified-attention' of https://github.com/ROCm/composable_kernel into tianxing/unified-attention
This commit is contained in:
@@ -72,6 +72,13 @@ auto parse_cmd_args(int argc, char* argv[]) -> std::pair<bool, ck_tile::ArgParse
|
||||
return std::make_pair(result, arg_parser);
|
||||
}
|
||||
|
||||
struct FmhaMasks
|
||||
{
|
||||
using NoMask = ck_tile::GenericAttentionMask<false>;
|
||||
using GenericMask = ck_tile::GenericAttentionMask<true, true>;
|
||||
using CausalMask = ck_tile::GenericAttentionMask<true, false>;
|
||||
};
|
||||
|
||||
struct Problem
|
||||
{
|
||||
explicit Problem(const ck_tile::ArgParser& args)
|
||||
@@ -176,105 +183,105 @@ auto generate_qkv(const Problem& problem,
|
||||
return std::make_tuple(q, k, v);
|
||||
}
|
||||
|
||||
// namespace host {
|
||||
// template <typename AccDataType,
|
||||
// typename PDataType,
|
||||
// typename QDataType,
|
||||
// typename KDataType,
|
||||
// typename VDataType,
|
||||
// typename ODataType,
|
||||
// typename QElementOp,
|
||||
// typename KElementOp,
|
||||
// typename VElementOp,
|
||||
// typename SAccElementOp>
|
||||
// CK_TILE_HOST void fmha_fwd(const ck_tile::HostTensor<QDataType>& q_bshd,
|
||||
// const ck_tile::HostTensor<KDataType>& k_bshd,
|
||||
// const ck_tile::HostTensor<VDataType>& v_bshd,
|
||||
// const mask_info& mask,
|
||||
// ck_tile::HostTensor<ODataType>& o_bshd,
|
||||
// const QElementOp& q_element_op = {},
|
||||
// const KElementOp& k_element_op = {},
|
||||
// const VElementOp& v_element_op = {},
|
||||
// const SAccElementOp& s_acc_element_op = {})
|
||||
// {
|
||||
// const int batch_size = q_bshd.mDesc.get_lengths()[0];
|
||||
// const int seqlen_q = q_bshd.mDesc.get_lengths()[1];
|
||||
// const int seqlen_kv = k_bshd.mDesc.get_lengths()[1];
|
||||
// const int nhead_q = q_bshd.mDesc.get_lengths()[2];
|
||||
// const int nhead_kv = k_bshd.mDesc.get_lengths()[2];
|
||||
// const int hdim_qk = q_bshd.mDesc.get_lengths()[3];
|
||||
// const int hdim_v = v_bshd.mDesc.get_lengths()[3];
|
||||
namespace host {
|
||||
template <typename AccDataType,
|
||||
typename PDataType,
|
||||
typename QDataType,
|
||||
typename KDataType,
|
||||
typename VDataType,
|
||||
typename ODataType,
|
||||
typename QElementOp,
|
||||
typename KElementOp,
|
||||
typename VElementOp,
|
||||
typename SAccElementOp>
|
||||
CK_TILE_HOST void fmha_fwd(const ck_tile::HostTensor<QDataType>& q_bshd,
|
||||
const ck_tile::HostTensor<KDataType>& k_bshd,
|
||||
const ck_tile::HostTensor<VDataType>& v_bshd,
|
||||
const mask_info& mask,
|
||||
ck_tile::HostTensor<ODataType>& o_bshd,
|
||||
const QElementOp& q_element_op = {},
|
||||
const KElementOp& k_element_op = {},
|
||||
const VElementOp& v_element_op = {},
|
||||
const SAccElementOp& s_acc_element_op = {})
|
||||
{
|
||||
const int batch_size = q_bshd.mDesc.get_lengths()[0];
|
||||
const int seqlen_q = q_bshd.mDesc.get_lengths()[1];
|
||||
const int seqlen_kv = k_bshd.mDesc.get_lengths()[1];
|
||||
const int nhead_q = q_bshd.mDesc.get_lengths()[2];
|
||||
const int nhead_kv = k_bshd.mDesc.get_lengths()[2];
|
||||
const int hdim_qk = q_bshd.mDesc.get_lengths()[3];
|
||||
const int hdim_v = v_bshd.mDesc.get_lengths()[3];
|
||||
|
||||
// const int nr = nhead_q / nhead_kv;
|
||||
const int nr = nhead_q / nhead_kv;
|
||||
|
||||
// ck_tile::HostTensor<QDataType> q_host_ref({nhead_q, seqlen_q, hdim_qk});
|
||||
// ck_tile::HostTensor<KDataType> k_host_ref({nhead_q, seqlen_kv, hdim_qk});
|
||||
// ck_tile::HostTensor<VDataType> v_host_ref({nhead_q, hdim_v, seqlen_kv});
|
||||
// ck_tile::HostTensor<ODataType> o_host_ref({nhead_q, seqlen_q, hdim_v});
|
||||
ck_tile::HostTensor<QDataType> q_host_ref({nhead_q, seqlen_q, hdim_qk});
|
||||
ck_tile::HostTensor<KDataType> k_host_ref({nhead_q, seqlen_kv, hdim_qk});
|
||||
ck_tile::HostTensor<VDataType> v_host_ref({nhead_q, hdim_v, seqlen_kv});
|
||||
ck_tile::HostTensor<ODataType> o_host_ref({nhead_q, seqlen_q, hdim_v});
|
||||
|
||||
// ck_tile::HostTensor<AccDataType> s_host_ref({nhead_q, seqlen_q, seqlen_kv});
|
||||
// ck_tile::HostTensor<PDataType> p_host_ref({nhead_q, seqlen_q, seqlen_kv});
|
||||
ck_tile::HostTensor<AccDataType> s_host_ref({nhead_q, seqlen_q, seqlen_kv});
|
||||
ck_tile::HostTensor<PDataType> p_host_ref({nhead_q, seqlen_q, seqlen_kv});
|
||||
|
||||
// // do computation for each batch
|
||||
// for(int b = 0; b < batch_size; ++b)
|
||||
// {
|
||||
// // copy per-batch data from input tensors
|
||||
// // clang-format off
|
||||
// q_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = q_bshd(b, idx[1], idx[0] ,
|
||||
// idx[2]); }); k_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = k_bshd(b, idx[1],
|
||||
// idx[0] / nr, idx[2]); }); v_host_ref.ForEach([&](auto& self, auto idx) { self(idx) =
|
||||
// v_bshd(b, idx[2], idx[0] / nr, idx[1]); });
|
||||
// // clang-format on
|
||||
// ck_tile::reference_batched_gemm<QDataType, KDataType, AccDataType>(
|
||||
// q_host_ref, k_host_ref, s_host_ref, q_element_op, k_element_op, s_acc_element_op);
|
||||
// do computation for each batch
|
||||
for(int b = 0; b < batch_size; ++b)
|
||||
{
|
||||
// copy per-batch data from input tensors
|
||||
// clang-format off
|
||||
q_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = q_bshd(b, idx[1], idx[0] ,
|
||||
idx[2]); }); k_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = k_bshd(b, idx[1],
|
||||
idx[0] / nr, idx[2]); }); v_host_ref.ForEach([&](auto& self, auto idx) { self(idx) =
|
||||
v_bshd(b, idx[2], idx[0] / nr, idx[1]); });
|
||||
// clang-format on
|
||||
ck_tile::reference_batched_gemm<QDataType, KDataType, AccDataType>(
|
||||
q_host_ref, k_host_ref, s_host_ref, q_element_op, k_element_op, s_acc_element_op);
|
||||
|
||||
// if(mask.type == mask_enum::no_mask)
|
||||
// {
|
||||
// ck_tile::reference_batched_masking(s_host_ref, FmhaMasks::NoMask{seqlen_q, seqlen_kv});
|
||||
// }
|
||||
// else if(mask.type == mask_enum::window_generic)
|
||||
// {
|
||||
// ck_tile::reference_batched_masking(
|
||||
// s_host_ref,
|
||||
// ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::GenericMask>(
|
||||
// mask.left, mask.right, seqlen_q, seqlen_kv));
|
||||
// }
|
||||
// else
|
||||
// {
|
||||
// // if left window size is negative, means causal
|
||||
// // else means generic (for current batch)
|
||||
// if(mask.left < 0)
|
||||
// ck_tile::reference_batched_masking(
|
||||
// s_host_ref,
|
||||
// ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::CausalMask>(
|
||||
// mask.left,
|
||||
// mask.right,
|
||||
// seqlen_q,
|
||||
// seqlen_kv,
|
||||
// mask.type == mask_enum::mask_top_left));
|
||||
// else
|
||||
// ck_tile::reference_batched_masking(
|
||||
// s_host_ref,
|
||||
// ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::GenericMask>(
|
||||
// mask.left,
|
||||
// mask.right,
|
||||
// seqlen_q,
|
||||
// seqlen_kv,
|
||||
// mask.type == mask_enum::mask_top_left));
|
||||
// }
|
||||
if(mask.type == mask_enum::no_mask)
|
||||
{
|
||||
ck_tile::reference_batched_masking(s_host_ref, FmhaMasks::NoMask{seqlen_q, seqlen_kv});
|
||||
}
|
||||
else if(mask.type == mask_enum::window_generic)
|
||||
{
|
||||
ck_tile::reference_batched_masking(
|
||||
s_host_ref,
|
||||
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::GenericMask>(
|
||||
mask.left, mask.right, seqlen_q, seqlen_kv));
|
||||
}
|
||||
else
|
||||
{
|
||||
// if left window size is negative, means causal
|
||||
// else means generic (for current batch)
|
||||
if(mask.left < 0)
|
||||
ck_tile::reference_batched_masking(
|
||||
s_host_ref,
|
||||
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::CausalMask>(
|
||||
mask.left,
|
||||
mask.right,
|
||||
seqlen_q,
|
||||
seqlen_kv,
|
||||
mask.type == mask_enum::mask_top_left));
|
||||
else
|
||||
ck_tile::reference_batched_masking(
|
||||
s_host_ref,
|
||||
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::GenericMask>(
|
||||
mask.left,
|
||||
mask.right,
|
||||
seqlen_q,
|
||||
seqlen_kv,
|
||||
mask.type == mask_enum::mask_top_left));
|
||||
}
|
||||
|
||||
// ck_tile::reference_batched_softmax<AccDataType, AccDataType>(
|
||||
// s_host_ref, p_host_ref, ck_tile::identity{});
|
||||
ck_tile::reference_batched_softmax<AccDataType, AccDataType>(
|
||||
s_host_ref, p_host_ref, ck_tile::identity{});
|
||||
|
||||
// ck_tile::reference_batched_gemm<PDataType, VDataType, AccDataType>(
|
||||
// p_host_ref, v_host_ref, o_host_ref, ck_tile::identity{}, v_element_op);
|
||||
ck_tile::reference_batched_gemm<PDataType, VDataType, AccDataType>(
|
||||
p_host_ref, v_host_ref, o_host_ref, ck_tile::identity{}, v_element_op);
|
||||
|
||||
// // copy resulting per-batch data to the output tensor
|
||||
// o_host_ref.ForEach(
|
||||
// [&](auto& self, auto idx) { o_bshd(b, idx[1], idx[0], idx[2]) = self(idx); });
|
||||
// }
|
||||
// }
|
||||
// } // namespace host
|
||||
// copy resulting per-batch data to the output tensor
|
||||
o_host_ref.ForEach(
|
||||
[&](auto& self, auto idx) { o_bshd(b, idx[1], idx[0], idx[2]) = self(idx); });
|
||||
}
|
||||
}
|
||||
} // namespace host
|
||||
|
||||
template <typename DataType>
|
||||
bool run_impl(const Problem& problem, const RunConfig& run_config)
|
||||
@@ -425,111 +432,108 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
|
||||
return false;
|
||||
}
|
||||
|
||||
// std::size_t flop = [&] {
|
||||
// if(problem.mask.type == mask_enum::no_mask)
|
||||
// {
|
||||
// return 4 * args.num_tokens * problem.nhead_q *
|
||||
// problem.hdim;
|
||||
// }
|
||||
// else
|
||||
// {
|
||||
// /// FIXME: Use a more accurate method; for now, we’re just dividing the flop by 2.
|
||||
// return 2 * args.num_tokens * problem.nhead_q *
|
||||
// problem.hdim;
|
||||
// }
|
||||
// }();
|
||||
std::size_t flop = [&] {
|
||||
if(problem.mask.type == mask_enum::no_mask)
|
||||
{
|
||||
return 4 * args.num_tokens * problem.nhead_q * problem.hdim;
|
||||
}
|
||||
else
|
||||
{
|
||||
/// FIXME: Use a more accurate method; for now, we’re just dividing the flop by 2.
|
||||
return 2 * args.num_tokens * problem.nhead_q * problem.hdim;
|
||||
}
|
||||
}();
|
||||
// TODO fix this
|
||||
// std::size_t flop = 1;
|
||||
// float tflops = static_cast<float>(flop) / 1.e9 / time;
|
||||
float tflops = static_cast<float>(flop) / 1.e9 / time;
|
||||
|
||||
// std::cout << "[" << problem.data_type << "|";
|
||||
// std::cout << "] b:" << problem.batch << ", h:" << problem.nhead_q << "/" << problem.nhead_kv
|
||||
// << ", s:" << problem.seqlen_q << "/" << problem.seqlen_k << ", d:" << problem.hdim
|
||||
// << ", scale_s:" << problem.sacle_s << ", mask:" << problem.mask << std::fixed
|
||||
// << ", " << std::setprecision(3) << time << " ms, " << std::setprecision(2) <<
|
||||
// tflops
|
||||
// << " TFlops" << std::endl;
|
||||
std::cout << "[" << problem.data_type << "|";
|
||||
std::cout << "] b:" << problem.batch << ", h:" << problem.nhead_q << "/" << problem.nhead_kv
|
||||
<< ", d:" << problem.hdim << ", mask:" << problem.mask << std::fixed << ", "
|
||||
<< std::setprecision(3) << time << " ms, " << std::setprecision(2) << tflops
|
||||
<< " TFlops" << std::endl;
|
||||
|
||||
// if(!run_config.verify)
|
||||
// {
|
||||
// return true;
|
||||
// }
|
||||
|
||||
// transpose tensor descriptors from bhsd to bshd if necessary
|
||||
// if(problem.input_layout != TensorLayout::bshd)
|
||||
// {
|
||||
// q = q.transpose({0, 2, 1, 3});
|
||||
// k = k.transpose({0, 2, 1, 3});
|
||||
// v = v.transpose({0, 2, 1, 3});
|
||||
// }
|
||||
|
||||
// ck_tile::HostTensor<DataType> o_ref(problem.get_output_shape());
|
||||
// if(problem.output_layout != TensorLayout::bshd)
|
||||
// {
|
||||
// o_ref = o_ref.transpose({0, 2, 1, 3});
|
||||
// }
|
||||
|
||||
// If variable lengths are provided, compute per-batch references
|
||||
// variable lengths are provided -> compute per-batch references
|
||||
// with the effective lengths; else compute a single full reference.
|
||||
// Variable-length aware verification: zero-fill padded region and only compute valid part.
|
||||
// o_ref.SetZero();
|
||||
ck_tile::HostTensor<DataType> o_ref(problem.get_output_shape());
|
||||
o_ref.SetZero();
|
||||
|
||||
// for(int b = 0; b < problem.batch; ++b)
|
||||
// {
|
||||
// const ck_tile::index_t seqlen_q_eff = eff_q_vec[b];
|
||||
// const ck_tile::index_t seqlen_kv_eff = eff_kv_vec[b];
|
||||
for(int b = 0; b < problem.batch; ++b)
|
||||
{
|
||||
const ck_tile::index_t seqlen_q_eff = eff_query_lens[b];
|
||||
const ck_tile::index_t seqlen_kv_eff = eff_kv_lens[b];
|
||||
|
||||
// if(seqlen_q_eff <= 0 || seqlen_kv_eff <= 0)
|
||||
// continue;
|
||||
if(seqlen_q_eff <= 0 || seqlen_kv_eff <= 0)
|
||||
continue;
|
||||
|
||||
// // Slice current batch from inputs (bshd) and build single-batch tensors
|
||||
// ck_tile::HostTensor<DataType> q_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim});
|
||||
// ck_tile::HostTensor<DataType> k_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim});
|
||||
// ck_tile::HostTensor<DataType> v_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim});
|
||||
// ck_tile::HostTensor<DataType> o_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim});
|
||||
// Slice current batch from inputs (bshd) and build single-batch tensors
|
||||
ck_tile::HostTensor<DataType> q_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim});
|
||||
ck_tile::HostTensor<DataType> k_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim});
|
||||
ck_tile::HostTensor<DataType> v_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim});
|
||||
ck_tile::HostTensor<DataType> o_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim});
|
||||
|
||||
// // Copy effective region
|
||||
// q_b.ForEach([&](auto& self, auto idx) {
|
||||
// // idx: [0, s, h, d]
|
||||
// self(idx) = q(b, idx[1], idx[2], idx[3]);
|
||||
// });
|
||||
// k_b.ForEach([&](auto& self, auto idx) { self(idx) = k(b, idx[1], idx[2], idx[3]); });
|
||||
// v_b.ForEach([&](auto& self, auto idx) { self(idx) = v(b, idx[1], idx[2], idx[3]); });
|
||||
// Copy effective region
|
||||
q_b.ForEach([&](auto& self, auto idx) {
|
||||
// idx: [0, s, h, d]
|
||||
self(idx) = q(b, idx[1], idx[2], idx[3]);
|
||||
});
|
||||
k_b.ForEach([&](auto& self, auto idx) {
|
||||
// kv cache is paged
|
||||
ck_tile::index_t table_col = int(idx[1] / problem.BLOCK_SIZE);
|
||||
ck_tile::index_t block_table_offset = b * max_num_blocks_per_seq + table_col;
|
||||
ck_tile::index_t block_idx = block_tables_host[block_table_offset];
|
||||
|
||||
// // Compute reference for this batch segment (host::fmha_fwd expects bshd tensors)
|
||||
// host::fmha_fwd<float, DataType>(q_b,
|
||||
// k_b,
|
||||
// v_b,
|
||||
// problem.mask,
|
||||
// o_b,
|
||||
// ck_tile::identity{},
|
||||
// ck_tile::identity{},
|
||||
// ck_tile::identity{},
|
||||
// ck_tile::scales{problem.scale_s});
|
||||
self(idx) = k(block_idx, idx[1] % problem.BLOCK_SIZE, idx[2], idx[3]);
|
||||
});
|
||||
v_b.ForEach([&](auto& self, auto idx) {
|
||||
ck_tile::index_t table_col = int(idx[1] / problem.BLOCK_SIZE);
|
||||
ck_tile::index_t block_table_offset = b * max_num_blocks_per_seq + table_col;
|
||||
ck_tile::index_t block_idx = block_tables_host[block_table_offset];
|
||||
|
||||
// // Scatter into o_ref's bshd descriptor memory
|
||||
// for(int s = 0; s < seqlen_q_eff; ++s)
|
||||
// {
|
||||
// for(int h = 0; h < problem.nhead_q; ++h)
|
||||
// {
|
||||
// for(int d = 0; d < problem.hdim; ++d)
|
||||
// {
|
||||
// o_ref(b, s, h, d) = o_b(0, s, h, d);
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
self(idx) = v(block_idx, idx[1] % problem.BLOCK_SIZE, idx[2], idx[3]);
|
||||
});
|
||||
// v_b.ForEach([&](auto& self, auto idx) { self(idx) = v(b, idx[1], idx[2], idx[3]); });
|
||||
|
||||
// ck_tile::HostTensor<DataType> o(problem.get_output_shape());
|
||||
// o_buf.FromDevice(o.data());
|
||||
// Compute reference for this batch segment (host::fmha_fwd expects bshd tensors)
|
||||
host::fmha_fwd<float, DataType>(q_b,
|
||||
k_b,
|
||||
v_b,
|
||||
problem.mask,
|
||||
o_b,
|
||||
ck_tile::identity{},
|
||||
ck_tile::identity{},
|
||||
ck_tile::identity{},
|
||||
ck_tile::scales{problem.scale_s});
|
||||
|
||||
// const auto [rtol, atol] = [&] {
|
||||
// if constexpr(std::is_same_v<DataType, ck_tile::fp16_t>)
|
||||
// return std::make_tuple(1e-3, 1e-3);
|
||||
// else
|
||||
// return std::make_tuple(1e-2, 1e-2);
|
||||
// }();
|
||||
// return ck_tile::check_err(o, o_ref, std::string("found incorrect results!"), rtol, atol);
|
||||
// Scatter into o_ref's bshd descriptor memory
|
||||
for(int s = 0; s < seqlen_q_eff; ++s)
|
||||
{
|
||||
for(int h = 0; h < problem.nhead_q; ++h)
|
||||
{
|
||||
for(int d = 0; d < problem.hdim; ++d)
|
||||
{
|
||||
o_ref(b, s, h, d) = o_b(0, s, h, d);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ck_tile::HostTensor<DataType> o(problem.get_output_shape());
|
||||
o_buf.FromDevice(o.data());
|
||||
|
||||
const auto [rtol, atol] = [&] {
|
||||
if constexpr(std::is_same_v<DataType, ck_tile::fp16_t>)
|
||||
return std::make_tuple(1e-3, 1e-3);
|
||||
else
|
||||
return std::make_tuple(1e-2, 1e-2);
|
||||
}();
|
||||
return ck_tile::check_err(o, o_ref, std::string("found incorrect results!"), rtol, atol);
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user