mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-02 21:27:45 +00:00
Merge branch 'tianxing/unified-attention' of https://github.com/ROCm/composable_kernel into tianxing/unified-attention
This commit is contained in:
@@ -463,17 +463,33 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
|
||||
}
|
||||
|
||||
std::size_t flop = [&] {
|
||||
if(problem.mask.type == mask_enum::no_mask)
|
||||
long flop_result = 0;
|
||||
|
||||
for(int b = 0; b < problem.batch; ++b)
|
||||
{
|
||||
return 4 * problem.batch * problem.nhead_q * problem.seqlen_q * problem.seqlen_k *
|
||||
problem.hdim;
|
||||
}
|
||||
else
|
||||
{
|
||||
/// FIXME: Use a more accurate method; for now, we’re just dividing the flop by 2.
|
||||
return 2 * problem.batch * problem.nhead_q * problem.seqlen_q * problem.seqlen_k *
|
||||
problem.hdim;
|
||||
long query_lens = has_varlen_q ? eff_q_vec[b] : problem.seqlen_q;
|
||||
long kv_lens = has_varlen_k ? eff_kv_vec[b] : problem.seqlen_k;
|
||||
long valid_out_elements = 0;
|
||||
|
||||
if(problem.mask.type == mask_enum::no_mask) {
|
||||
valid_out_elements = kv_lens * query_lens;
|
||||
} else {
|
||||
if(query_lens > kv_lens)
|
||||
{
|
||||
valid_out_elements = (kv_lens * kv_lens + kv_lens) / 2;
|
||||
}
|
||||
else
|
||||
{
|
||||
valid_out_elements =
|
||||
query_lens * kv_lens - ((query_lens * query_lens - query_lens) / 2);
|
||||
}
|
||||
|
||||
}
|
||||
// Causal logic for valid output elements
|
||||
|
||||
flop_result += 2 * problem.nhead_q * valid_out_elements * (problem.hdim + problem.hdim);
|
||||
}
|
||||
return flop_result;
|
||||
}();
|
||||
float tflops = static_cast<float>(flop) / 1.e9 / time;
|
||||
|
||||
|
||||
@@ -32,15 +32,13 @@ auto parse_cmd_args(int argc, char* argv[]) -> std::pair<bool, ck_tile::ArgParse
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser
|
||||
.insert("prec", "fp16", "data type. fp16/bf16")
|
||||
.insert("prec", "bf16", "data type. fp16/bf16")
|
||||
// .insert("b", "3", "batch size")
|
||||
.insert("h_k", "8", "num head for k/v. num head for q is 4 times this")
|
||||
// .insert("h_k",
|
||||
// "-1",
|
||||
// "num of head, for k/v, -1 means equal to h\n"
|
||||
// "if not equal to h, then this is GQA/MQA case")
|
||||
|
||||
.insert("h_k", "8", "num head for k/v. num head for q is " + std::to_string(num_queries_per_kv) + " times this")
|
||||
.insert("s", "3328", "max seqlen_q")
|
||||
.insert("s_k", "-1", "max seqlen_k, -1 means equal to s")
|
||||
.insert("nb", "1024", "num_blks")
|
||||
.insert("b", "3", "batch")
|
||||
.insert("d", "128", "head dim for q & k")
|
||||
.insert("scale_s", "0", "scale factor of S. 0 means equal to 1/sqrt(hdim)")
|
||||
// TODO scale factors
|
||||
@@ -55,6 +53,7 @@ auto parse_cmd_args(int argc, char* argv[]) -> std::pair<bool, ck_tile::ArgParse
|
||||
.insert("operm", "0", "permute output")
|
||||
.insert("causal", "0", "0: no mask, 1: causal mask")
|
||||
.insert("v", "1", "0:no verify, 1:verify")
|
||||
.insert("varlen", "1", "0: fixed length, 1: variable length")
|
||||
.insert("seed",
|
||||
"11939",
|
||||
"random seed used for initializing input tensors. 0 for "
|
||||
@@ -63,11 +62,11 @@ auto parse_cmd_args(int argc, char* argv[]) -> std::pair<bool, ck_tile::ArgParse
|
||||
.insert("repeat", "30", "number of iterations to benchmark the kernel")
|
||||
// Optional effective seqlen override (exclude PAD) for batch mode
|
||||
.insert("query_lens",
|
||||
"1, 5, 129",
|
||||
"",
|
||||
"Batch-mode only: per-batch effective seqlen for Q (exclude PAD).\n"
|
||||
"Comma-separated list of length 'b'. If empty, no override.")
|
||||
.insert("kv_lens",
|
||||
"1328, 18, 463",
|
||||
"",
|
||||
"Batch-mode only: per-batch effective seqlen for KV (exclude PAD).\n"
|
||||
"Comma-separated list of length 'b'. If empty, no override.");
|
||||
|
||||
@@ -75,12 +74,48 @@ auto parse_cmd_args(int argc, char* argv[]) -> std::pair<bool, ck_tile::ArgParse
|
||||
return std::make_pair(result, arg_parser);
|
||||
}
|
||||
|
||||
struct FmhaMasks
|
||||
auto seqlen_preprocess(ck_tile::index_t batch,
|
||||
ck_tile::index_t max_seqlen_q,
|
||||
ck_tile::index_t max_seqlen_kv,
|
||||
const std::vector<int>& query_lens_input,
|
||||
const std::vector<int>& kv_lens_input,
|
||||
bool varlen) -> std::pair<std::vector<int>, std::vector<int>>
|
||||
{
|
||||
using NoMask = ck_tile::GenericAttentionMask<false>;
|
||||
using GenericMask = ck_tile::GenericAttentionMask<true, true>;
|
||||
using CausalMask = ck_tile::GenericAttentionMask<true, false>;
|
||||
};
|
||||
// If both query_lens and kv_lens are provided, return them directly
|
||||
if(!query_lens_input.empty() && !kv_lens_input.empty())
|
||||
{
|
||||
return std::make_pair(query_lens_input, kv_lens_input);
|
||||
}
|
||||
|
||||
std::vector<int> query_lens;
|
||||
std::vector<int> kv_lens;
|
||||
|
||||
if(!varlen)
|
||||
{
|
||||
// Fixed length mode: fill with max seqlen
|
||||
query_lens.assign(batch, max_seqlen_q);
|
||||
kv_lens.assign(batch, max_seqlen_kv);
|
||||
}
|
||||
else
|
||||
{
|
||||
// Variable length mode: generate random lengths up to max
|
||||
std::random_device rd;
|
||||
std::mt19937 gen(rd());
|
||||
std::uniform_int_distribution<int> q_dist(1, max_seqlen_q);
|
||||
std::uniform_int_distribution<int> kv_dist(1, max_seqlen_kv);
|
||||
|
||||
query_lens.resize(batch);
|
||||
kv_lens.resize(batch);
|
||||
|
||||
for(ck_tile::index_t i = 0; i < batch; ++i)
|
||||
{
|
||||
query_lens[i] = q_dist(gen);
|
||||
kv_lens[i] = kv_dist(gen);
|
||||
}
|
||||
}
|
||||
|
||||
return std::make_pair(query_lens, kv_lens);
|
||||
}
|
||||
|
||||
struct Problem
|
||||
{
|
||||
@@ -94,10 +129,31 @@ struct Problem
|
||||
// TODO: support other GQA/MQA cases than just 4x
|
||||
nhead_q = nhead_kv * num_queries_per_kv;
|
||||
|
||||
ck_tile::index_t max_seqlen_q = args.get_int("s");
|
||||
ck_tile::index_t max_seqlen_kv = args.get_int("s_k");
|
||||
|
||||
if (max_seqlen_kv == -1) {
|
||||
max_seqlen_kv = max_seqlen_q;
|
||||
}
|
||||
|
||||
hdim = args.get_int("d");
|
||||
query_lens = args.get_int_vec("query_lens");
|
||||
kv_lens = args.get_int_vec("kv_lens");
|
||||
batch = std::max(query_lens.size(), kv_lens.size());
|
||||
assert(query_lens.size() == kv_lens.size() && "query_lens and kv_lens must have the same length b");
|
||||
batch = args.get_int("b");
|
||||
|
||||
bool varlen = args.get_bool("varlen");
|
||||
auto [query_lens_, kv_lens_] = seqlen_preprocess(
|
||||
batch,
|
||||
max_seqlen_q,
|
||||
max_seqlen_kv,
|
||||
query_lens,
|
||||
kv_lens,
|
||||
varlen);
|
||||
|
||||
query_lens = query_lens_;
|
||||
kv_lens = kv_lens_;
|
||||
batch = query_lens.size();
|
||||
|
||||
// Calculate scale_s
|
||||
scale_s = args.get_float("scale_s");
|
||||
@@ -108,7 +164,7 @@ struct Problem
|
||||
scale = args.get_float("scale");
|
||||
scale_k = args.get_float("scale_k");
|
||||
scale_v = args.get_float("scale_v");
|
||||
|
||||
num_tokens = 0;
|
||||
for(const auto& len : query_lens)
|
||||
{
|
||||
num_tokens += len;
|
||||
@@ -198,7 +254,7 @@ template <typename AccDataType,
|
||||
CK_TILE_HOST void fmha_fwd(const ck_tile::HostTensor<QDataType>& q_bshd,
|
||||
const ck_tile::HostTensor<KDataType>& k_bshd,
|
||||
const ck_tile::HostTensor<VDataType>& v_bshd,
|
||||
const mask_info& mask,
|
||||
// const mask_info& mask,
|
||||
ck_tile::HostTensor<ODataType>& o_bshd,
|
||||
const QElementOp& q_element_op = {},
|
||||
const KElementOp& k_element_op = {},
|
||||
@@ -222,61 +278,35 @@ CK_TILE_HOST void fmha_fwd(const ck_tile::HostTensor<QDataType>& q_bshd,
|
||||
|
||||
ck_tile::HostTensor<AccDataType> s_host_ref({nhead_q, seqlen_q, seqlen_kv});
|
||||
ck_tile::HostTensor<PDataType> p_host_ref({nhead_q, seqlen_q, seqlen_kv});
|
||||
|
||||
// do computation for each batch
|
||||
for(int b = 0; b < batch_size; ++b)
|
||||
{
|
||||
// copy per-batch data from input tensors
|
||||
// clang-format off
|
||||
q_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = q_bshd(b, idx[1], idx[0] ,
|
||||
idx[2]); }); k_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = k_bshd(b, idx[1],
|
||||
idx[0] / nr, idx[2]); }); v_host_ref.ForEach([&](auto& self, auto idx) { self(idx) =
|
||||
v_bshd(b, idx[2], idx[0] / nr, idx[1]); });
|
||||
q_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = q_bshd(b, idx[1], idx[0] ,
|
||||
idx[2]); });
|
||||
k_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = k_bshd(b, idx[1],
|
||||
idx[0] / nr, idx[2]); });
|
||||
v_host_ref.ForEach([&](auto& self, auto idx) { self(idx) =
|
||||
v_bshd(b, idx[2], idx[0] / nr, idx[1]); });
|
||||
// clang-format on
|
||||
ck_tile::reference_batched_gemm<QDataType, KDataType, AccDataType>(
|
||||
q_host_ref, k_host_ref, s_host_ref, q_element_op, k_element_op, s_acc_element_op);
|
||||
|
||||
if(mask.type == mask_enum::no_mask)
|
||||
{
|
||||
ck_tile::reference_batched_masking(s_host_ref, FmhaMasks::NoMask{seqlen_q, seqlen_kv});
|
||||
}
|
||||
else if(mask.type == mask_enum::window_generic)
|
||||
{
|
||||
ck_tile::reference_batched_masking(
|
||||
s_host_ref,
|
||||
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::GenericMask>(
|
||||
mask.left, mask.right, seqlen_q, seqlen_kv));
|
||||
}
|
||||
else
|
||||
{
|
||||
// if left window size is negative, means causal
|
||||
// else means generic (for current batch)
|
||||
if(mask.left < 0)
|
||||
ck_tile::reference_batched_masking(
|
||||
s_host_ref,
|
||||
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::CausalMask>(
|
||||
mask.left,
|
||||
mask.right,
|
||||
seqlen_q,
|
||||
seqlen_kv,
|
||||
mask.type == mask_enum::mask_top_left));
|
||||
else
|
||||
ck_tile::reference_batched_masking(
|
||||
s_host_ref,
|
||||
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::GenericMask>(
|
||||
mask.left,
|
||||
mask.right,
|
||||
seqlen_q,
|
||||
seqlen_kv,
|
||||
mask.type == mask_enum::mask_top_left));
|
||||
}
|
||||
|
||||
ck_tile::reference_batched_masking(
|
||||
s_host_ref,
|
||||
ck_tile::make_generic_attention_mask_from_lr_window<UnifiedAttentionMasks::CausalMask>(
|
||||
-1,
|
||||
0,
|
||||
seqlen_q,
|
||||
seqlen_kv,
|
||||
1,
|
||||
false));
|
||||
ck_tile::reference_batched_softmax<AccDataType, AccDataType>(
|
||||
s_host_ref, p_host_ref, ck_tile::identity{});
|
||||
|
||||
ck_tile::reference_batched_gemm<PDataType, VDataType, AccDataType>(
|
||||
p_host_ref, v_host_ref, o_host_ref, ck_tile::identity{}, v_element_op);
|
||||
|
||||
|
||||
// copy resulting per-batch data to the output tensor
|
||||
o_host_ref.ForEach(
|
||||
[&](auto& self, auto idx) { o_bshd(b, idx[1], idx[0], idx[2]) = self(idx); });
|
||||
@@ -303,6 +333,7 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
|
||||
|
||||
ck_tile::unified_attention_args args{};
|
||||
|
||||
args.scale_s = problem.scale_s;
|
||||
args.data_type = problem.data_type;
|
||||
args.num_seqs = problem.batch;
|
||||
args.num_head_q = problem.nhead_q;
|
||||
@@ -369,8 +400,8 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
|
||||
};
|
||||
calculate_cumulative(eff_query_lens, cu_query_lens);
|
||||
|
||||
ck_tile::DeviceMem seq_lens_buf(eff_kv_lens.size());
|
||||
ck_tile::DeviceMem query_start_len_buf(cu_query_lens.size());
|
||||
ck_tile::DeviceMem seq_lens_buf(eff_kv_lens.size() * sizeof(ck_tile::index_t));
|
||||
ck_tile::DeviceMem query_start_len_buf(cu_query_lens.size() * sizeof(ck_tile::index_t));
|
||||
|
||||
seq_lens_buf.ToDevice(eff_kv_lens.data());
|
||||
query_start_len_buf.ToDevice(cu_query_lens.data());
|
||||
@@ -428,7 +459,7 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
|
||||
|
||||
if(!result)
|
||||
{
|
||||
std::cerr << "faild to run fmha_fwd_v3()" << std::endl;
|
||||
std::cerr << "faild to run unified_attention()" << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -471,7 +502,18 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
|
||||
|
||||
std::cout << "[" << problem.data_type << "|";
|
||||
std::cout << "] b:" << problem.batch << ", h:" << problem.nhead_q << "/" << problem.nhead_kv
|
||||
<< ", d:" << problem.hdim << ", mask:" << problem.mask << std::fixed << ", "
|
||||
<< ", d:" << problem.hdim << ", scale_s:" << problem.scale_s
|
||||
<< ", query_lens:[";
|
||||
for (size_t i = 0; i < problem.query_lens.size(); ++i) {
|
||||
std::cout << problem.query_lens[i];
|
||||
if (i < problem.query_lens.size() - 1) std::cout << ",";
|
||||
}
|
||||
std::cout << "], kv_lens:[";
|
||||
for (size_t i = 0; i < problem.kv_lens.size(); ++i) {
|
||||
std::cout << problem.kv_lens[i];
|
||||
if (i < problem.kv_lens.size() - 1) std::cout << ",";
|
||||
}
|
||||
std::cout << "], mask:" << "causal mask" << std::fixed << ", "
|
||||
<< std::setprecision(8) << time << " ms, " << std::setprecision(2) << tflops
|
||||
<< " TFlops, " << std::setprecision(2)
|
||||
<< (static_cast<double>(mem) / 1e12 / (time / 1e3)) << " TB/s" << std::endl;
|
||||
@@ -500,11 +542,12 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
|
||||
ck_tile::HostTensor<DataType> k_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim});
|
||||
ck_tile::HostTensor<DataType> v_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim});
|
||||
ck_tile::HostTensor<DataType> o_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim});
|
||||
ck_tile::index_t seq_q_off = cu_query_lens[b];
|
||||
|
||||
// Copy effective region
|
||||
q_b.ForEach([&](auto& self, auto idx) {
|
||||
// idx: [0, s, h, d]
|
||||
self(idx) = q(b, idx[1], idx[2], idx[3]);
|
||||
self(idx) = q(seq_q_off + idx[1], idx[2], idx[3]);
|
||||
});
|
||||
k_b.ForEach([&](auto& self, auto idx) {
|
||||
// kv cache is paged
|
||||
@@ -527,7 +570,7 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
|
||||
host::fmha_fwd<float, DataType>(q_b,
|
||||
k_b,
|
||||
v_b,
|
||||
problem.mask,
|
||||
// problem.mask,
|
||||
o_b,
|
||||
ck_tile::identity{},
|
||||
ck_tile::identity{},
|
||||
@@ -541,7 +584,7 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
|
||||
{
|
||||
for(int d = 0; d < problem.hdim; ++d)
|
||||
{
|
||||
o_ref(b, s, h, d) = o_b(0, s, h, d);
|
||||
o_ref(seq_q_off + s, h, d) = o_b(0, s, h, d);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -550,13 +593,62 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
|
||||
ck_tile::HostTensor<DataType> o(problem.get_output_shape());
|
||||
o_buf.FromDevice(o.data());
|
||||
|
||||
|
||||
const auto [rtol, atol] = [&] {
|
||||
if constexpr(std::is_same_v<DataType, ck_tile::fp16_t>)
|
||||
return std::make_tuple(1e-3, 1e-3);
|
||||
else
|
||||
return std::make_tuple(1e-2, 1e-2);
|
||||
}();
|
||||
return ck_tile::check_err(o, o_ref, std::string("found incorrect results!"), rtol, atol);
|
||||
|
||||
size_t total = static_cast<size_t>(problem.num_tokens) *
|
||||
static_cast<size_t>(problem.nhead_q) *
|
||||
static_cast<size_t>(problem.hdim);
|
||||
|
||||
size_t nonzero = 0;
|
||||
|
||||
for (int tok = 0; tok < problem.num_tokens; ++tok) {
|
||||
for (int h = 0; h < problem.nhead_q; ++h) {
|
||||
for (int d = 0; d < problem.hdim; ++d) {
|
||||
if (static_cast<float>(o(tok, h, d)) != 0.0f) {
|
||||
nonzero++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
float percent = (total > 0)
|
||||
? (100.0f * static_cast<float>(nonzero) / static_cast<float>(total))
|
||||
: 0.0f;
|
||||
|
||||
std::cout << "\nNon-zero elements in output tensor o: "
|
||||
<< nonzero << " / " << total
|
||||
<< " (" << percent << "%)\n";
|
||||
|
||||
// std::cout << "\n=== Complete Output Tensor (o) ===\n";
|
||||
// for (int tok = 0; tok < problem.num_tokens; ++tok) {
|
||||
// std::cout << "Token " << tok << ":\n";
|
||||
// for (int h = 0; h < problem.nhead_q; ++h) {
|
||||
// std::cout << " Head " << h << ": ";
|
||||
// for (int d = 0; d < problem.hdim; ++d) {
|
||||
// std::cout << static_cast<float>(o(tok, h, d)) << " ";
|
||||
// }
|
||||
// std::cout << "\n";
|
||||
// }
|
||||
// }
|
||||
|
||||
// std::cout << "\n=== Complete Reference Tensor (o_ref) ===\n";
|
||||
// for (int tok = 0; tok < problem.num_tokens; ++tok) {
|
||||
// std::cout << "Token " << tok << ":\n";
|
||||
// for (int h = 0; h < problem.nhead_q; ++h) {
|
||||
// std::cout << " Head " << h << ": ";
|
||||
// for (int d = 0; d < problem.hdim; ++d) {
|
||||
// std::cout << static_cast<float>(o_ref(tok, h, d)) << " ";
|
||||
// }
|
||||
// std::cout << "\n";
|
||||
// }
|
||||
// }
|
||||
return ck_tile::check_err(o, o_ref, std::string("found incorrect results!"), rtol, atol);
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/host/stream_config.hpp"
|
||||
#include "ck_tile/ops/unified_attention.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -76,3 +77,10 @@ std::pair<bool, float> unified_attention(const unified_attention_args& args,
|
||||
const stream_config& config);
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
struct UnifiedAttentionMasks
|
||||
{
|
||||
using NoMask = ck_tile::GenericAttentionMask<false>;
|
||||
using GenericMask = ck_tile::GenericAttentionMask<true, true>;
|
||||
using CausalMask = ck_tile::GenericAttentionMask<true, false>;
|
||||
};
|
||||
|
||||
@@ -124,7 +124,7 @@ float unified_attention_kernel_launch(const unified_attention_args& args,
|
||||
"argument num_queries_per_kv must equal compiled num_queries_per_kv");
|
||||
assert(args.BLOCK_SIZE == Kernel::BLOCK_SIZE &&
|
||||
"argument BLOCK_SIZE must equal compiled BLOCK_SIZE");
|
||||
assert(BLOCK_Q == args.num_head_q / args.num_queries_per_kv &&
|
||||
assert(BLOCK_Q == BLOCK_M / args.num_queries_per_kv &&
|
||||
"BLOCK_Q must equal BLOCK_M / num_queries_per_kv");
|
||||
index_t total_num_q_blocks = args.num_tokens / BLOCK_Q + args.num_seqs;
|
||||
auto kargs = Kernel::MakeKargs(args.q_ptr,
|
||||
|
||||
@@ -184,7 +184,7 @@ struct UnifiedAttentionKernel
|
||||
while(left < right)
|
||||
{
|
||||
ck_tile::index_t mid = (left + right) / 2;
|
||||
ck_tile::index_t val = query_start_len_ptr[mid];
|
||||
ck_tile::index_t val = amd_wave_read_first_lane(query_start_len_ptr[mid]);
|
||||
ck_tile::index_t mid_val = use_q_block_mode ? (val / block_q + mid) : val;
|
||||
|
||||
if(mid_val <= target_idx)
|
||||
@@ -200,55 +200,15 @@ struct UnifiedAttentionKernel
|
||||
return left - 1;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto RemapTileIndices(const ck_tile::index_t pid,
|
||||
const Kargs& kargs)
|
||||
{
|
||||
using namespace ck_tile;
|
||||
|
||||
constexpr index_t NUM_XCDS = 8;
|
||||
const index_t GRID_MN = kargs.total_num_q_blocks * (kargs.num_head_q);
|
||||
|
||||
// Number of pids per XCD in the new arrangement
|
||||
const index_t pids_per_xcd = (GRID_MN + NUM_XCDS - 1) / NUM_XCDS;
|
||||
|
||||
// When GRID_MN cannot divide NUM_XCDS, some xcds will have
|
||||
// pids_per_xcd pids, the other will have pids_per_xcd - 1 pids.
|
||||
// We calculate the number of xcds that have pids_per_xcd pids as tall_xcds
|
||||
index_t tall_xcds = GRID_MN % NUM_XCDS;
|
||||
tall_xcds = tall_xcds == 0 ? NUM_XCDS : tall_xcds;
|
||||
|
||||
// Compute current XCD and local pid within the XCD
|
||||
const index_t xcd = pid % NUM_XCDS;
|
||||
const index_t local_pid = pid / NUM_XCDS;
|
||||
|
||||
// Calculate new pid based on the new grouping
|
||||
index_t remapped_pid = 0; // Initialize to avoid constexpr error
|
||||
if(xcd < tall_xcds)
|
||||
{
|
||||
remapped_pid = xcd * pids_per_xcd + local_pid;
|
||||
}
|
||||
else
|
||||
{
|
||||
remapped_pid =
|
||||
tall_xcds * pids_per_xcd + (xcd - tall_xcds) * (pids_per_xcd - 1) + local_pid;
|
||||
}
|
||||
|
||||
return remapped_pid;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto GetTileIndex(const ck_tile::index_t pid,
|
||||
const Kargs& kargs)
|
||||
{
|
||||
using namespace ck_tile;
|
||||
|
||||
ck_tile::index_t total_num_q_blocks = kargs.total_num_q_blocks;
|
||||
// const index_t num_tile_n1 = ck_tile::integer_divide_ceil(kargs.hdim_v,
|
||||
// UnifiedAttentionPipeline::kN1);
|
||||
ck_tile::index_t num_head_kv = kargs.num_head_q / kargs.num_queries_per_kv;
|
||||
|
||||
const index_t i_tile_m = pid % total_num_q_blocks; // Query block index
|
||||
const index_t i_tile_n = pid / total_num_q_blocks; // Head index
|
||||
|
||||
return ck_tile::make_tuple(i_tile_m, i_tile_n);
|
||||
return ck_tile::make_tuple(pid % num_head_kv, pid / num_head_kv);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
|
||||
@@ -278,8 +238,6 @@ struct UnifiedAttentionKernel
|
||||
|
||||
// const index_t num_head_k = num_head_q / num_queries_per_kv;
|
||||
|
||||
pid = RemapTileIndices(pid, kargs);
|
||||
|
||||
// divide problem
|
||||
const auto [kv_head_idx, q_block_global_idx] = GetTileIndex(pid, kargs);
|
||||
|
||||
@@ -295,19 +253,15 @@ struct UnifiedAttentionKernel
|
||||
BLOCK_Q,
|
||||
true); // which batch
|
||||
|
||||
const index_t q_block_start_idx =
|
||||
amd_wave_read_first_lane(kargs.query_start_len_ptr[seq_idx]);
|
||||
const index_t q_block_start_idx = kargs.query_start_len_ptr[seq_idx] / BLOCK_Q + seq_idx;
|
||||
|
||||
const index_t q_block_local_idx =
|
||||
amd_wave_read_first_lane(q_block_global_idx - q_block_start_idx);
|
||||
const index_t q_block_local_idx = amd_wave_read_first_lane(q_block_global_idx - q_block_start_idx);
|
||||
|
||||
const index_t cur_batch_in_all_start_index =
|
||||
amd_wave_read_first_lane(kargs.query_start_len_ptr[seq_idx]);
|
||||
const index_t cur_batch_in_all_stop_index =
|
||||
amd_wave_read_first_lane(kargs.query_start_len_ptr[seq_idx + 1]);
|
||||
const index_t cur_batch_in_all_start_index = kargs.query_start_len_ptr[seq_idx];
|
||||
const index_t cur_batch_in_all_stop_index = kargs.query_start_len_ptr[seq_idx + 1];
|
||||
|
||||
const index_t cur_batch_query_len =
|
||||
cur_batch_in_all_stop_index - cur_batch_in_all_start_index;
|
||||
amd_wave_read_first_lane(cur_batch_in_all_stop_index - cur_batch_in_all_start_index);
|
||||
|
||||
// TODO check if we get the block size info from pipeline
|
||||
if(q_block_local_idx * BLOCK_Q >= cur_batch_query_len)
|
||||
@@ -315,14 +269,14 @@ struct UnifiedAttentionKernel
|
||||
return;
|
||||
}
|
||||
|
||||
const index_t query_pos = q_block_local_idx * BLOCK_Q;
|
||||
const index_t query_pos = amd_wave_read_first_lane(q_block_local_idx * BLOCK_Q);
|
||||
const index_t seq_len = kargs.seq_lens_ptr[seq_idx];
|
||||
|
||||
const index_t context_len = seq_len - cur_batch_query_len;
|
||||
const index_t context_len = amd_wave_read_first_lane(seq_len - cur_batch_query_len);
|
||||
|
||||
index_t _max_seq_prefix_len =
|
||||
(context_len + q_block_local_idx * BLOCK_Q + (BLOCK_M - 1) // num_queries_per_kv
|
||||
+ 1);
|
||||
amd_wave_read_first_lane((context_len + q_block_local_idx * BLOCK_Q + (BLOCK_M - 1)
|
||||
+ 1));
|
||||
|
||||
if(seq_len < _max_seq_prefix_len)
|
||||
{
|
||||
@@ -330,7 +284,7 @@ struct UnifiedAttentionKernel
|
||||
}
|
||||
|
||||
const auto max_seq_prefix_len = _max_seq_prefix_len;
|
||||
const index_t num_blocks = (max_seq_prefix_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
const index_t num_blocks = amd_wave_read_first_lane((max_seq_prefix_len + BLOCK_SIZE - 1) / BLOCK_SIZE);
|
||||
|
||||
// TODO sliding window
|
||||
const index_t num_blocks_start = 0;
|
||||
@@ -357,7 +311,7 @@ struct UnifiedAttentionKernel
|
||||
const VDataType* v_ptr = reinterpret_cast<const VDataType*>(kargs.v_ptr) + kv_head_offset;
|
||||
ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr) + o_ptr_offset;
|
||||
|
||||
index_t query_len_padded = integer_divide_ceil(cur_batch_query_len, BLOCK_Q) * BLOCK_Q;
|
||||
index_t query_len_padded = amd_wave_read_first_lane(integer_divide_ceil(cur_batch_query_len, BLOCK_Q) * BLOCK_Q);
|
||||
// const bool is_query_len_padded = (cur_batch_query_len % BLOCK_Q == 0);
|
||||
|
||||
// Q/K/V DRAM and DRAM window
|
||||
@@ -367,20 +321,20 @@ struct UnifiedAttentionKernel
|
||||
make_tuple(cur_batch_query_len, num_queries_per_kv, HEAD_SIZE),
|
||||
make_tuple(kargs.query_stride_0, kargs.query_stride_1, 1),
|
||||
number<UnifiedAttentionPipeline::kAlignmentQ>{},
|
||||
number<2>{});
|
||||
number<1>{});
|
||||
|
||||
const auto q_dram_pad =
|
||||
pad_tensor_view( // aling seqlen with BLOCK_Q and head dim with HEAD_SIZE_PADDED
|
||||
q_dram_base,
|
||||
// block sizes
|
||||
make_tuple(number<BLOCK_Q>{}, number<1>{}, number<HEAD_SIZE_PADDED>{}),
|
||||
make_tuple(number<BLOCK_Q>{}, 1, HEAD_SIZE_PADDED),
|
||||
sequence<true, false, kPadHeadDimQ>{}); // pads to (seq_len_padded, num_head_q,
|
||||
// HEAD_SIZE_PADDED)
|
||||
|
||||
const auto q_dram_merged = transform_tensor_view(
|
||||
q_dram_pad,
|
||||
make_tuple(make_merge_transform(make_tuple(query_len_padded, num_queries_per_kv)),
|
||||
make_pass_through_transform(number<HEAD_SIZE_PADDED>{})),
|
||||
make_pass_through_transform(HEAD_SIZE_PADDED)),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0>{},
|
||||
sequence<1>{})); // flattens the first two dims, head idx is the fastest
|
||||
@@ -402,26 +356,17 @@ struct UnifiedAttentionKernel
|
||||
// HEAD dim is skipped as defined in the ptrs
|
||||
const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
k_ptr,
|
||||
make_tuple(kargs.num_blks, BLOCK_SIZE, HEAD_SIZE),
|
||||
make_tuple(kargs.stride_k_cache_0, kargs.stride_k_cache_1, kargs.stride_k_cache_3),
|
||||
make_tuple(kargs.num_blks * BLOCK_SIZE, HEAD_SIZE),
|
||||
make_tuple(kargs.stride_k_cache_1, kargs.stride_k_cache_3),
|
||||
number<UnifiedAttentionPipeline::kAlignmentK>{},
|
||||
number<1>{});
|
||||
|
||||
const auto k_dram_pad = pad_tensor_view(k_dram_naive,
|
||||
// TODO can the BLOCK_SIZE_RAW needs padding?
|
||||
make_tuple(1, BLOCK_SIZE, HEAD_SIZE_PADDED),
|
||||
sequence<false, false, kPadHeadDimQ>{});
|
||||
make_tuple(BLOCK_SIZE, HEAD_SIZE_PADDED),
|
||||
sequence<false, kPadHeadDimQ>{});
|
||||
|
||||
const auto k_dram_merged = transform_tensor_view(
|
||||
k_dram_pad,
|
||||
make_tuple(make_merge_transform(make_tuple(kargs.num_blks, BLOCK_SIZE)),
|
||||
make_pass_through_transform(HEAD_SIZE_PADDED)),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0>{},
|
||||
sequence<1>{})); // flattens the first two dims, head idx is the fastest
|
||||
// changing dim in the merged dim
|
||||
|
||||
return k_dram_merged;
|
||||
return k_dram_pad;
|
||||
}();
|
||||
|
||||
auto k_dram_window = make_tile_window(
|
||||
@@ -430,25 +375,16 @@ struct UnifiedAttentionKernel
|
||||
const auto v_dram = [&]() {
|
||||
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
v_ptr,
|
||||
make_tuple(kargs.num_blks, BLOCK_SIZE, HEAD_SIZE),
|
||||
make_tuple(kargs.stride_v_cache_0, kargs.stride_v_cache_1, kargs.stride_v_cache_3),
|
||||
make_tuple(kargs.num_blks * BLOCK_SIZE, HEAD_SIZE),
|
||||
make_tuple(kargs.stride_v_cache_1, kargs.stride_v_cache_3),
|
||||
number<UnifiedAttentionPipeline::kAlignmentV>{},
|
||||
number<1>{});
|
||||
|
||||
const auto v_dram_pad = pad_tensor_view(v_dram_naive,
|
||||
make_tuple(1, BLOCK_SIZE, HEAD_SIZE_PADDED),
|
||||
sequence<false, false, kPadHeadDimQ>{});
|
||||
make_tuple(BLOCK_SIZE, HEAD_SIZE_PADDED),
|
||||
sequence<false, kPadHeadDimQ>{});
|
||||
|
||||
const auto v_dram_merged = transform_tensor_view(
|
||||
v_dram_pad,
|
||||
make_tuple(make_merge_transform(make_tuple(kargs.num_blks, BLOCK_SIZE)),
|
||||
make_pass_through_transform(HEAD_SIZE_PADDED)),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0>{},
|
||||
sequence<1>{})); // flattens the first two dims, head idx is the fastest
|
||||
// changing dim in the merged dim
|
||||
|
||||
return v_dram_merged;
|
||||
return v_dram_pad;
|
||||
}();
|
||||
|
||||
auto v_dram_window = make_tile_window(
|
||||
@@ -457,12 +393,13 @@ struct UnifiedAttentionKernel
|
||||
FmhaMask mask = [&]() {
|
||||
if constexpr(kHasMask)
|
||||
return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
|
||||
seq_len - cur_batch_query_len, // y (i.e. context)
|
||||
cur_batch_query_len, // x (i.e. extend)
|
||||
seq_len, // y_total (x + y)
|
||||
cur_batch_query_len, // x_total
|
||||
num_queries_per_kv // the same sequence index is repeated num_queries_per_kv
|
||||
-1,
|
||||
0,
|
||||
cur_batch_query_len, // y_total
|
||||
seq_len, // x_total
|
||||
num_queries_per_kv, // the same sequence index is repeated num_queries_per_kv
|
||||
// times along x dim of the tile
|
||||
false
|
||||
);
|
||||
else
|
||||
return FmhaMask{cur_batch_query_len, seq_len};
|
||||
|
||||
@@ -6,7 +6,6 @@
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_default_policy.hpp"
|
||||
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
|
||||
|
||||
#define ENABLE_ASM_MARKER 1
|
||||
#if ENABLE_ASM_MARKER
|
||||
#define ASM_MARKER(marker) \
|
||||
@@ -411,7 +410,7 @@ struct UnifiedAttentionPipeline
|
||||
HEAD_SIZE_PADDED == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
"wrong!");
|
||||
|
||||
static_assert(sizeof(SaccDataType) * BLOCK_SIZE <= GetSmemSize());
|
||||
static_assert(sizeof(SaccDataType) * BLOCK_SIZE * BLOCK_M <= GetSmemSize());
|
||||
auto s_lds = make_tensor_view<address_space_enum::lds>(
|
||||
reinterpret_cast<SaccDataType*>(static_cast<char*>(smem_ptr)),
|
||||
MakeSimpleLdsDesc<BLOCK_M, BLOCK_SIZE>());
|
||||
@@ -427,7 +426,7 @@ struct UnifiedAttentionPipeline
|
||||
|
||||
auto o_lds = make_tensor_view<address_space_enum::lds>(
|
||||
reinterpret_cast<PDataType*>(static_cast<char*>(smem_ptr)),
|
||||
MakeSimpleLdsDesc<BLOCK_M, BLOCK_SIZE>());
|
||||
MakeSimpleLdsDesc<BLOCK_M, HEAD_SIZE_PADDED>());
|
||||
[[maybe_unused]] auto o_lds_window = make_tile_window(
|
||||
o_lds, make_tuple(number<BLOCK_M>{}, number<HEAD_SIZE_PADDED>{}), {0, 0});
|
||||
|
||||
@@ -543,16 +542,11 @@ struct UnifiedAttentionPipeline
|
||||
clear_tile(l);
|
||||
|
||||
const auto q_origin = q_dram_window.get_window_origin();
|
||||
// const auto [seqlen_k_start, seqlen_k_end] =
|
||||
// mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<BLOCK_M>{},
|
||||
// number<BLOCK_SIZE>{});
|
||||
|
||||
// const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start,
|
||||
// BLOCK_SIZE);
|
||||
const auto num_total_loop = num_blocks;
|
||||
// index_t kv_token_start = seqlen_k_start;
|
||||
index_t k_block_table_off = num_blocks_start;
|
||||
index_t v_block_table_off = num_blocks_start;
|
||||
|
||||
// TODO check is paddings kPadSeqLenK
|
||||
// check early exit if no work to do
|
||||
if constexpr(FmhaMask::IsMasking)
|
||||
{
|
||||
@@ -565,23 +559,23 @@ struct UnifiedAttentionPipeline
|
||||
}
|
||||
}
|
||||
|
||||
// TODO check correctness of this
|
||||
index_t i_total_loops = num_blocks_start;
|
||||
const ck_tile::index_t* block_tables_ptr_ =
|
||||
reinterpret_cast<const ck_tile::index_t*>(block_tables_ptr);
|
||||
index_t kv_blk_idx = block_tables_ptr_[block_table_offset + i_total_loops];
|
||||
index_t kv_blk_idx_prev = 0;
|
||||
index_t kv_blk_idx_intial = block_tables_ptr_[block_table_offset + k_block_table_off];
|
||||
|
||||
auto k_dram_window =
|
||||
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
k_dram_block_window_tmp.get_window_lengths(),
|
||||
{(kv_blk_idx - kv_blk_idx_prev) * BLOCK_SIZE, 0},
|
||||
{kv_blk_idx_intial * BLOCK_SIZE, 0},
|
||||
Policy::template MakeKDramTileDistribution<Problem>());
|
||||
k_dram_window.init_raw();
|
||||
|
||||
auto v_dram_window =
|
||||
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
v_dram_block_window_tmp.get_window_lengths(),
|
||||
{(kv_blk_idx - kv_blk_idx_prev) * BLOCK_SIZE, 0}, // TODO: hdim split?
|
||||
{kv_blk_idx_intial * BLOCK_SIZE, 0},
|
||||
Policy::template MakeVDramTileDistribution<Problem>());
|
||||
v_dram_window.init_raw();
|
||||
|
||||
@@ -677,6 +671,9 @@ struct UnifiedAttentionPipeline
|
||||
async_load_tile_raw(k_lds_window_store(k_lds_write_idx), k_dram_window);
|
||||
// TODO maybe needs i_total_loops as argument. Or maybe needs to use the k_lds_write_idx
|
||||
// as the index
|
||||
|
||||
k_block_table_off++;
|
||||
index_t kv_blk_idx = block_tables_ptr_[block_table_offset + k_block_table_off];
|
||||
/// FIXME: use the future-predicting method to move the window
|
||||
k_dram_window.set_window_origin({kv_blk_idx * BLOCK_SIZE, 0});
|
||||
};
|
||||
@@ -687,7 +684,9 @@ struct UnifiedAttentionPipeline
|
||||
|
||||
auto V_mem_load = [&](auto v_lds_write_idx) {
|
||||
async_load_tile_raw(v_lds_window_store(v_lds_write_idx), v_dram_window);
|
||||
// kv_blk_idx = block_tables_ptr_[block_table_offset + i_total_loops];
|
||||
v_block_table_off++;
|
||||
|
||||
index_t kv_blk_idx = block_tables_ptr_[block_table_offset + v_block_table_off];
|
||||
/// FIXME: use the future-predicting method to move the window
|
||||
v_dram_window.set_window_origin({kv_blk_idx * BLOCK_SIZE, 0});
|
||||
};
|
||||
@@ -900,7 +899,7 @@ struct UnifiedAttentionPipeline
|
||||
{
|
||||
bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
|
||||
i_total_loops * BLOCK_SIZE,
|
||||
number<BLOCK_M>{},
|
||||
number<BLOCK_Q>{},
|
||||
number<BLOCK_SIZE>{});
|
||||
if(need_perpixel_check)
|
||||
{
|
||||
@@ -985,7 +984,6 @@ struct UnifiedAttentionPipeline
|
||||
__builtin_amdgcn_s_barrier();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
cl_load(memK, K_w0_lds_wr_idx, V_w0_lds_rd_idx);
|
||||
// TODO what is this???
|
||||
Scheduler::schedule(cl_p, number<1>{});
|
||||
fmha_mask(xdl_SP_p01_reg_idx);
|
||||
|
||||
@@ -1014,7 +1012,6 @@ struct UnifiedAttentionPipeline
|
||||
cl_load(memV, V_w0_lds_wr_idx, K_w0_lds_rd_idx);
|
||||
|
||||
Scheduler::schedule(cl_p, number<3>{});
|
||||
// kv_token_start += BLOCK_SIZE;
|
||||
if(num_total_loop <= ++i_total_loops)
|
||||
{
|
||||
result = false;
|
||||
@@ -1061,7 +1058,6 @@ struct UnifiedAttentionPipeline
|
||||
Scheduler::schedule(cl_p, number<2>{});
|
||||
fmha_mask(xdl_SP_p01_reg_idx);
|
||||
|
||||
// kv_token_start += BLOCK_SIZE;
|
||||
if(num_total_loop <= ++i_total_loops)
|
||||
{
|
||||
result = false;
|
||||
@@ -1139,7 +1135,6 @@ struct UnifiedAttentionPipeline
|
||||
fmha_alu0(number<0>{});
|
||||
fmha_alu_D_upd();
|
||||
|
||||
// kv_token_start += BLOCK_SIZE;
|
||||
++i_total_loops;
|
||||
if(num_total_loop <= i_total_loops)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user