Merge branch 'tianxing/unified-attention' of https://github.com/ROCm/composable_kernel into tianxing/unified-attention

This commit is contained in:
Tianxing Wu
2025-11-28 12:04:31 +00:00
6 changed files with 241 additions and 193 deletions

View File

@@ -463,17 +463,33 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
}
std::size_t flop = [&] {
if(problem.mask.type == mask_enum::no_mask)
long flop_result = 0;
for(int b = 0; b < problem.batch; ++b)
{
return 4 * problem.batch * problem.nhead_q * problem.seqlen_q * problem.seqlen_k *
problem.hdim;
}
else
{
/// FIXME: Use a more accurate method; for now, we're just dividing the flop by 2.
return 2 * problem.batch * problem.nhead_q * problem.seqlen_q * problem.seqlen_k *
problem.hdim;
long query_lens = has_varlen_q ? eff_q_vec[b] : problem.seqlen_q;
long kv_lens = has_varlen_k ? eff_kv_vec[b] : problem.seqlen_k;
long valid_out_elements = 0;
if(problem.mask.type == mask_enum::no_mask) {
valid_out_elements = kv_lens * query_lens;
} else {
if(query_lens > kv_lens)
{
valid_out_elements = (kv_lens * kv_lens + kv_lens) / 2;
}
else
{
valid_out_elements =
query_lens * kv_lens - ((query_lens * query_lens - query_lens) / 2);
}
}
// Causal logic for valid output elements
flop_result += 2 * problem.nhead_q * valid_out_elements * (problem.hdim + problem.hdim);
}
return flop_result;
}();
float tflops = static_cast<float>(flop) / 1.e9 / time;

View File

@@ -32,15 +32,13 @@ auto parse_cmd_args(int argc, char* argv[]) -> std::pair<bool, ck_tile::ArgParse
{
ck_tile::ArgParser arg_parser;
arg_parser
.insert("prec", "fp16", "data type. fp16/bf16")
.insert("prec", "bf16", "data type. fp16/bf16")
// .insert("b", "3", "batch size")
.insert("h_k", "8", "num head for k/v. num head for q is 4 times this")
// .insert("h_k",
// "-1",
// "num of head, for k/v, -1 means equal to h\n"
// "if not equal to h, then this is GQA/MQA case")
.insert("h_k", "8", "num head for k/v. num head for q is " + std::to_string(num_queries_per_kv) + " times this")
.insert("s", "3328", "max seqlen_q")
.insert("s_k", "-1", "max seqlen_k, -1 means equal to s")
.insert("nb", "1024", "num_blks")
.insert("b", "3", "batch")
.insert("d", "128", "head dim for q & k")
.insert("scale_s", "0", "scale factor of S. 0 means equal to 1/sqrt(hdim)")
// TODO scale factors
@@ -55,6 +53,7 @@ auto parse_cmd_args(int argc, char* argv[]) -> std::pair<bool, ck_tile::ArgParse
.insert("operm", "0", "permute output")
.insert("causal", "0", "0: no mask, 1: causal mask")
.insert("v", "1", "0:no verify, 1:verify")
.insert("varlen", "1", "0: fixed length, 1: variable length")
.insert("seed",
"11939",
"random seed used for initializing input tensors. 0 for "
@@ -63,11 +62,11 @@ auto parse_cmd_args(int argc, char* argv[]) -> std::pair<bool, ck_tile::ArgParse
.insert("repeat", "30", "number of iterations to benchmark the kernel")
// Optional effective seqlen override (exclude PAD) for batch mode
.insert("query_lens",
"1, 5, 129",
"",
"Batch-mode only: per-batch effective seqlen for Q (exclude PAD).\n"
"Comma-separated list of length 'b'. If empty, no override.")
.insert("kv_lens",
"1328, 18, 463",
"",
"Batch-mode only: per-batch effective seqlen for KV (exclude PAD).\n"
"Comma-separated list of length 'b'. If empty, no override.");
@@ -75,12 +74,48 @@ auto parse_cmd_args(int argc, char* argv[]) -> std::pair<bool, ck_tile::ArgParse
return std::make_pair(result, arg_parser);
}
struct FmhaMasks
auto seqlen_preprocess(ck_tile::index_t batch,
ck_tile::index_t max_seqlen_q,
ck_tile::index_t max_seqlen_kv,
const std::vector<int>& query_lens_input,
const std::vector<int>& kv_lens_input,
bool varlen) -> std::pair<std::vector<int>, std::vector<int>>
{
using NoMask = ck_tile::GenericAttentionMask<false>;
using GenericMask = ck_tile::GenericAttentionMask<true, true>;
using CausalMask = ck_tile::GenericAttentionMask<true, false>;
};
// If both query_lens and kv_lens are provided, return them directly
if(!query_lens_input.empty() && !kv_lens_input.empty())
{
return std::make_pair(query_lens_input, kv_lens_input);
}
std::vector<int> query_lens;
std::vector<int> kv_lens;
if(!varlen)
{
// Fixed length mode: fill with max seqlen
query_lens.assign(batch, max_seqlen_q);
kv_lens.assign(batch, max_seqlen_kv);
}
else
{
// Variable length mode: generate random lengths up to max
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_int_distribution<int> q_dist(1, max_seqlen_q);
std::uniform_int_distribution<int> kv_dist(1, max_seqlen_kv);
query_lens.resize(batch);
kv_lens.resize(batch);
for(ck_tile::index_t i = 0; i < batch; ++i)
{
query_lens[i] = q_dist(gen);
kv_lens[i] = kv_dist(gen);
}
}
return std::make_pair(query_lens, kv_lens);
}
struct Problem
{
@@ -94,10 +129,31 @@ struct Problem
// TODO: support other GQA/MQA cases than just 4x
nhead_q = nhead_kv * num_queries_per_kv;
ck_tile::index_t max_seqlen_q = args.get_int("s");
ck_tile::index_t max_seqlen_kv = args.get_int("s_k");
if (max_seqlen_kv == -1) {
max_seqlen_kv = max_seqlen_q;
}
hdim = args.get_int("d");
query_lens = args.get_int_vec("query_lens");
kv_lens = args.get_int_vec("kv_lens");
batch = std::max(query_lens.size(), kv_lens.size());
assert(query_lens.size() == kv_lens.size() && "query_lens and kv_lens must have the same length b");
batch = args.get_int("b");
bool varlen = args.get_bool("varlen");
auto [query_lens_, kv_lens_] = seqlen_preprocess(
batch,
max_seqlen_q,
max_seqlen_kv,
query_lens,
kv_lens,
varlen);
query_lens = query_lens_;
kv_lens = kv_lens_;
batch = query_lens.size();
// Calculate scale_s
scale_s = args.get_float("scale_s");
@@ -108,7 +164,7 @@ struct Problem
scale = args.get_float("scale");
scale_k = args.get_float("scale_k");
scale_v = args.get_float("scale_v");
num_tokens = 0;
for(const auto& len : query_lens)
{
num_tokens += len;
@@ -198,7 +254,7 @@ template <typename AccDataType,
CK_TILE_HOST void fmha_fwd(const ck_tile::HostTensor<QDataType>& q_bshd,
const ck_tile::HostTensor<KDataType>& k_bshd,
const ck_tile::HostTensor<VDataType>& v_bshd,
const mask_info& mask,
// const mask_info& mask,
ck_tile::HostTensor<ODataType>& o_bshd,
const QElementOp& q_element_op = {},
const KElementOp& k_element_op = {},
@@ -222,61 +278,35 @@ CK_TILE_HOST void fmha_fwd(const ck_tile::HostTensor<QDataType>& q_bshd,
ck_tile::HostTensor<AccDataType> s_host_ref({nhead_q, seqlen_q, seqlen_kv});
ck_tile::HostTensor<PDataType> p_host_ref({nhead_q, seqlen_q, seqlen_kv});
// do computation for each batch
for(int b = 0; b < batch_size; ++b)
{
// copy per-batch data from input tensors
// clang-format off
q_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = q_bshd(b, idx[1], idx[0] ,
idx[2]); }); k_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = k_bshd(b, idx[1],
idx[0] / nr, idx[2]); }); v_host_ref.ForEach([&](auto& self, auto idx) { self(idx) =
v_bshd(b, idx[2], idx[0] / nr, idx[1]); });
q_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = q_bshd(b, idx[1], idx[0] ,
idx[2]); });
k_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = k_bshd(b, idx[1],
idx[0] / nr, idx[2]); });
v_host_ref.ForEach([&](auto& self, auto idx) { self(idx) =
v_bshd(b, idx[2], idx[0] / nr, idx[1]); });
// clang-format on
ck_tile::reference_batched_gemm<QDataType, KDataType, AccDataType>(
q_host_ref, k_host_ref, s_host_ref, q_element_op, k_element_op, s_acc_element_op);
if(mask.type == mask_enum::no_mask)
{
ck_tile::reference_batched_masking(s_host_ref, FmhaMasks::NoMask{seqlen_q, seqlen_kv});
}
else if(mask.type == mask_enum::window_generic)
{
ck_tile::reference_batched_masking(
s_host_ref,
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::GenericMask>(
mask.left, mask.right, seqlen_q, seqlen_kv));
}
else
{
// if left window size is negative, means causal
// else means generic (for current batch)
if(mask.left < 0)
ck_tile::reference_batched_masking(
s_host_ref,
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::CausalMask>(
mask.left,
mask.right,
seqlen_q,
seqlen_kv,
mask.type == mask_enum::mask_top_left));
else
ck_tile::reference_batched_masking(
s_host_ref,
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::GenericMask>(
mask.left,
mask.right,
seqlen_q,
seqlen_kv,
mask.type == mask_enum::mask_top_left));
}
ck_tile::reference_batched_masking(
s_host_ref,
ck_tile::make_generic_attention_mask_from_lr_window<UnifiedAttentionMasks::CausalMask>(
-1,
0,
seqlen_q,
seqlen_kv,
1,
false));
ck_tile::reference_batched_softmax<AccDataType, AccDataType>(
s_host_ref, p_host_ref, ck_tile::identity{});
ck_tile::reference_batched_gemm<PDataType, VDataType, AccDataType>(
p_host_ref, v_host_ref, o_host_ref, ck_tile::identity{}, v_element_op);
// copy resulting per-batch data to the output tensor
o_host_ref.ForEach(
[&](auto& self, auto idx) { o_bshd(b, idx[1], idx[0], idx[2]) = self(idx); });
@@ -303,6 +333,7 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
ck_tile::unified_attention_args args{};
args.scale_s = problem.scale_s;
args.data_type = problem.data_type;
args.num_seqs = problem.batch;
args.num_head_q = problem.nhead_q;
@@ -369,8 +400,8 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
};
calculate_cumulative(eff_query_lens, cu_query_lens);
ck_tile::DeviceMem seq_lens_buf(eff_kv_lens.size());
ck_tile::DeviceMem query_start_len_buf(cu_query_lens.size());
ck_tile::DeviceMem seq_lens_buf(eff_kv_lens.size() * sizeof(ck_tile::index_t));
ck_tile::DeviceMem query_start_len_buf(cu_query_lens.size() * sizeof(ck_tile::index_t));
seq_lens_buf.ToDevice(eff_kv_lens.data());
query_start_len_buf.ToDevice(cu_query_lens.data());
@@ -428,7 +459,7 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
if(!result)
{
std::cerr << "faild to run fmha_fwd_v3()" << std::endl;
std::cerr << "faild to run unified_attention()" << std::endl;
return false;
}
@@ -471,7 +502,18 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
std::cout << "[" << problem.data_type << "|";
std::cout << "] b:" << problem.batch << ", h:" << problem.nhead_q << "/" << problem.nhead_kv
<< ", d:" << problem.hdim << ", mask:" << problem.mask << std::fixed << ", "
<< ", d:" << problem.hdim << ", scale_s:" << problem.scale_s
<< ", query_lens:[";
for (size_t i = 0; i < problem.query_lens.size(); ++i) {
std::cout << problem.query_lens[i];
if (i < problem.query_lens.size() - 1) std::cout << ",";
}
std::cout << "], kv_lens:[";
for (size_t i = 0; i < problem.kv_lens.size(); ++i) {
std::cout << problem.kv_lens[i];
if (i < problem.kv_lens.size() - 1) std::cout << ",";
}
std::cout << "], mask:" << "causal mask" << std::fixed << ", "
<< std::setprecision(8) << time << " ms, " << std::setprecision(2) << tflops
<< " TFlops, " << std::setprecision(2)
<< (static_cast<double>(mem) / 1e12 / (time / 1e3)) << " TB/s" << std::endl;
@@ -500,11 +542,12 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
ck_tile::HostTensor<DataType> k_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim});
ck_tile::HostTensor<DataType> v_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim});
ck_tile::HostTensor<DataType> o_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim});
ck_tile::index_t seq_q_off = cu_query_lens[b];
// Copy effective region
q_b.ForEach([&](auto& self, auto idx) {
// idx: [0, s, h, d]
self(idx) = q(b, idx[1], idx[2], idx[3]);
self(idx) = q(seq_q_off + idx[1], idx[2], idx[3]);
});
k_b.ForEach([&](auto& self, auto idx) {
// kv cache is paged
@@ -527,7 +570,7 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
host::fmha_fwd<float, DataType>(q_b,
k_b,
v_b,
problem.mask,
// problem.mask,
o_b,
ck_tile::identity{},
ck_tile::identity{},
@@ -541,7 +584,7 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
{
for(int d = 0; d < problem.hdim; ++d)
{
o_ref(b, s, h, d) = o_b(0, s, h, d);
o_ref(seq_q_off + s, h, d) = o_b(0, s, h, d);
}
}
}
@@ -550,13 +593,62 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
ck_tile::HostTensor<DataType> o(problem.get_output_shape());
o_buf.FromDevice(o.data());
const auto [rtol, atol] = [&] {
if constexpr(std::is_same_v<DataType, ck_tile::fp16_t>)
return std::make_tuple(1e-3, 1e-3);
else
return std::make_tuple(1e-2, 1e-2);
}();
return ck_tile::check_err(o, o_ref, std::string("found incorrect results!"), rtol, atol);
size_t total = static_cast<size_t>(problem.num_tokens) *
static_cast<size_t>(problem.nhead_q) *
static_cast<size_t>(problem.hdim);
size_t nonzero = 0;
for (int tok = 0; tok < problem.num_tokens; ++tok) {
for (int h = 0; h < problem.nhead_q; ++h) {
for (int d = 0; d < problem.hdim; ++d) {
if (static_cast<float>(o(tok, h, d)) != 0.0f) {
nonzero++;
}
}
}
}
float percent = (total > 0)
? (100.0f * static_cast<float>(nonzero) / static_cast<float>(total))
: 0.0f;
std::cout << "\nNon-zero elements in output tensor o: "
<< nonzero << " / " << total
<< " (" << percent << "%)\n";
// std::cout << "\n=== Complete Output Tensor (o) ===\n";
// for (int tok = 0; tok < problem.num_tokens; ++tok) {
// std::cout << "Token " << tok << ":\n";
// for (int h = 0; h < problem.nhead_q; ++h) {
// std::cout << " Head " << h << ": ";
// for (int d = 0; d < problem.hdim; ++d) {
// std::cout << static_cast<float>(o(tok, h, d)) << " ";
// }
// std::cout << "\n";
// }
// }
// std::cout << "\n=== Complete Reference Tensor (o_ref) ===\n";
// for (int tok = 0; tok < problem.num_tokens; ++tok) {
// std::cout << "Token " << tok << ":\n";
// for (int h = 0; h < problem.nhead_q; ++h) {
// std::cout << " Head " << h << ": ";
// for (int d = 0; d < problem.hdim; ++d) {
// std::cout << static_cast<float>(o_ref(tok, h, d)) << " ";
// }
// std::cout << "\n";
// }
// }
return ck_tile::check_err(o, o_ref, std::string("found incorrect results!"), rtol, atol);
}
int main(int argc, char* argv[])

View File

@@ -8,6 +8,7 @@
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/host/stream_config.hpp"
#include "ck_tile/ops/unified_attention.hpp"
namespace ck_tile {
@@ -76,3 +77,10 @@ std::pair<bool, float> unified_attention(const unified_attention_args& args,
const stream_config& config);
} // namespace ck_tile
// Named aliases for the ck_tile::GenericAttentionMask instantiations used with
// unified attention. Declared after the ck_tile namespace is closed, i.e. at
// global scope.
// NOTE(review): the template arguments are presumably <IsMasking, IsLocal> --
// confirm against the GenericAttentionMask declaration.
struct UnifiedAttentionMasks
{
// Masking switched off via the first template argument.
using NoMask = ck_tile::GenericAttentionMask<false>;
// Both template flags enabled.
using GenericMask = ck_tile::GenericAttentionMask<true, true>;
// First flag enabled, second disabled.
using CausalMask = ck_tile::GenericAttentionMask<true, false>;
};

View File

@@ -124,7 +124,7 @@ float unified_attention_kernel_launch(const unified_attention_args& args,
"argument num_queries_per_kv must equal compiled num_queries_per_kv");
assert(args.BLOCK_SIZE == Kernel::BLOCK_SIZE &&
"argument BLOCK_SIZE must equal compiled BLOCK_SIZE");
assert(BLOCK_Q == args.num_head_q / args.num_queries_per_kv &&
assert(BLOCK_Q == BLOCK_M / args.num_queries_per_kv &&
"BLOCK_Q must equal BLOCK_M / num_queries_per_kv");
index_t total_num_q_blocks = args.num_tokens / BLOCK_Q + args.num_seqs;
auto kargs = Kernel::MakeKargs(args.q_ptr,

View File

@@ -184,7 +184,7 @@ struct UnifiedAttentionKernel
while(left < right)
{
ck_tile::index_t mid = (left + right) / 2;
ck_tile::index_t val = query_start_len_ptr[mid];
ck_tile::index_t val = amd_wave_read_first_lane(query_start_len_ptr[mid]);
ck_tile::index_t mid_val = use_q_block_mode ? (val / block_q + mid) : val;
if(mid_val <= target_idx)
@@ -200,55 +200,15 @@ struct UnifiedAttentionKernel
return left - 1;
}
/// Remap a linear pid so that all pids sharing the same `pid % 8` residue (the
/// value this function treats as the XCD id) become one contiguous range of
/// remapped pids.
///
/// The grid is taken to hold `total_num_q_blocks * num_head_q` pids. When that
/// count is not a multiple of the XCD count, the first `grid % 8` XCDs own one
/// more pid than the remaining ones.
///
/// @param pid   linear pid to remap
/// @param kargs kernel arguments; reads `total_num_q_blocks` and `num_head_q`
/// @return the remapped pid
CK_TILE_DEVICE static constexpr auto RemapTileIndices(const ck_tile::index_t pid,
                                                      const Kargs& kargs)
{
    using namespace ck_tile;

    constexpr index_t kNumXcds = 8;

    const index_t grid_size = kargs.total_num_q_blocks * (kargs.num_head_q);

    // Share of the "large" XCDs: ceil(grid_size / kNumXcds).
    const index_t large_share = (grid_size + kNumXcds - 1) / kNumXcds;

    // Number of XCDs that receive the large share; an exact division means all
    // of them do.
    const index_t leftover       = grid_size % kNumXcds;
    const index_t num_large_xcds = (leftover != 0) ? leftover : kNumXcds;

    // Position of this pid: which XCD it belongs to and its slot within it.
    const index_t xcd_id = pid % kNumXcds;
    const index_t slot   = pid / kNumXcds;

    // Large XCDs are laid out first, followed by the ones holding one pid less.
    const index_t remapped_pid =
        (xcd_id < num_large_xcds)
            ? (xcd_id * large_share + slot)
            : (num_large_xcds * large_share + (xcd_id - num_large_xcds) * (large_share - 1) +
               slot);

    return remapped_pid;
}
/// Split a linear pid into (kv head index, global query-block index).
///
/// The kv-head index is the fastest-varying component: consecutive pids walk
/// through all `num_head_q / num_queries_per_kv` kv heads before advancing to
/// the next query block. The element order matches the caller's structured
/// binding `[kv_head_idx, q_block_global_idx]`.
///
/// @param pid   linear pid to decompose
/// @param kargs kernel arguments; reads `num_head_q` and `num_queries_per_kv`
/// @return tuple (pid % num_head_kv, pid / num_head_kv)
CK_TILE_DEVICE static constexpr auto GetTileIndex(const ck_tile::index_t pid,
                                                  const Kargs& kargs)
{
    using namespace ck_tile;

    const index_t num_head_kv = kargs.num_head_q / kargs.num_queries_per_kv;

    return ck_tile::make_tuple(pid % num_head_kv, pid / num_head_kv);
}
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
@@ -278,8 +238,6 @@ struct UnifiedAttentionKernel
// const index_t num_head_k = num_head_q / num_queries_per_kv;
pid = RemapTileIndices(pid, kargs);
// divide problem
const auto [kv_head_idx, q_block_global_idx] = GetTileIndex(pid, kargs);
@@ -295,19 +253,15 @@ struct UnifiedAttentionKernel
BLOCK_Q,
true); // which batch
const index_t q_block_start_idx =
amd_wave_read_first_lane(kargs.query_start_len_ptr[seq_idx]);
const index_t q_block_start_idx = kargs.query_start_len_ptr[seq_idx] / BLOCK_Q + seq_idx;
const index_t q_block_local_idx =
amd_wave_read_first_lane(q_block_global_idx - q_block_start_idx);
const index_t q_block_local_idx = amd_wave_read_first_lane(q_block_global_idx - q_block_start_idx);
const index_t cur_batch_in_all_start_index =
amd_wave_read_first_lane(kargs.query_start_len_ptr[seq_idx]);
const index_t cur_batch_in_all_stop_index =
amd_wave_read_first_lane(kargs.query_start_len_ptr[seq_idx + 1]);
const index_t cur_batch_in_all_start_index = kargs.query_start_len_ptr[seq_idx];
const index_t cur_batch_in_all_stop_index = kargs.query_start_len_ptr[seq_idx + 1];
const index_t cur_batch_query_len =
cur_batch_in_all_stop_index - cur_batch_in_all_start_index;
amd_wave_read_first_lane(cur_batch_in_all_stop_index - cur_batch_in_all_start_index);
// TODO check if we get the block size info from pipeline
if(q_block_local_idx * BLOCK_Q >= cur_batch_query_len)
@@ -315,14 +269,14 @@ struct UnifiedAttentionKernel
return;
}
const index_t query_pos = q_block_local_idx * BLOCK_Q;
const index_t query_pos = amd_wave_read_first_lane(q_block_local_idx * BLOCK_Q);
const index_t seq_len = kargs.seq_lens_ptr[seq_idx];
const index_t context_len = seq_len - cur_batch_query_len;
const index_t context_len = amd_wave_read_first_lane(seq_len - cur_batch_query_len);
index_t _max_seq_prefix_len =
(context_len + q_block_local_idx * BLOCK_Q + (BLOCK_M - 1) // num_queries_per_kv
+ 1);
amd_wave_read_first_lane((context_len + q_block_local_idx * BLOCK_Q + (BLOCK_M - 1)
+ 1));
if(seq_len < _max_seq_prefix_len)
{
@@ -330,7 +284,7 @@ struct UnifiedAttentionKernel
}
const auto max_seq_prefix_len = _max_seq_prefix_len;
const index_t num_blocks = (max_seq_prefix_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
const index_t num_blocks = amd_wave_read_first_lane((max_seq_prefix_len + BLOCK_SIZE - 1) / BLOCK_SIZE);
// TODO sliding window
const index_t num_blocks_start = 0;
@@ -357,7 +311,7 @@ struct UnifiedAttentionKernel
const VDataType* v_ptr = reinterpret_cast<const VDataType*>(kargs.v_ptr) + kv_head_offset;
ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr) + o_ptr_offset;
index_t query_len_padded = integer_divide_ceil(cur_batch_query_len, BLOCK_Q) * BLOCK_Q;
index_t query_len_padded = amd_wave_read_first_lane(integer_divide_ceil(cur_batch_query_len, BLOCK_Q) * BLOCK_Q);
// const bool is_query_len_padded = (cur_batch_query_len % BLOCK_Q == 0);
// Q/K/V DRAM and DRAM window
@@ -367,20 +321,20 @@ struct UnifiedAttentionKernel
make_tuple(cur_batch_query_len, num_queries_per_kv, HEAD_SIZE),
make_tuple(kargs.query_stride_0, kargs.query_stride_1, 1),
number<UnifiedAttentionPipeline::kAlignmentQ>{},
number<2>{});
number<1>{});
const auto q_dram_pad =
pad_tensor_view( // aling seqlen with BLOCK_Q and head dim with HEAD_SIZE_PADDED
q_dram_base,
// block sizes
make_tuple(number<BLOCK_Q>{}, number<1>{}, number<HEAD_SIZE_PADDED>{}),
make_tuple(number<BLOCK_Q>{}, 1, HEAD_SIZE_PADDED),
sequence<true, false, kPadHeadDimQ>{}); // pads to (seq_len_padded, num_head_q,
// HEAD_SIZE_PADDED)
const auto q_dram_merged = transform_tensor_view(
q_dram_pad,
make_tuple(make_merge_transform(make_tuple(query_len_padded, num_queries_per_kv)),
make_pass_through_transform(number<HEAD_SIZE_PADDED>{})),
make_pass_through_transform(HEAD_SIZE_PADDED)),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0>{},
sequence<1>{})); // flattens the first two dims, head idx is the fastest
@@ -402,26 +356,17 @@ struct UnifiedAttentionKernel
// HEAD dim is skipped as defined in the ptrs
const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
k_ptr,
make_tuple(kargs.num_blks, BLOCK_SIZE, HEAD_SIZE),
make_tuple(kargs.stride_k_cache_0, kargs.stride_k_cache_1, kargs.stride_k_cache_3),
make_tuple(kargs.num_blks * BLOCK_SIZE, HEAD_SIZE),
make_tuple(kargs.stride_k_cache_1, kargs.stride_k_cache_3),
number<UnifiedAttentionPipeline::kAlignmentK>{},
number<1>{});
const auto k_dram_pad = pad_tensor_view(k_dram_naive,
// TODO can the BLOCK_SIZE_RAW needs padding?
make_tuple(1, BLOCK_SIZE, HEAD_SIZE_PADDED),
sequence<false, false, kPadHeadDimQ>{});
make_tuple(BLOCK_SIZE, HEAD_SIZE_PADDED),
sequence<false, kPadHeadDimQ>{});
const auto k_dram_merged = transform_tensor_view(
k_dram_pad,
make_tuple(make_merge_transform(make_tuple(kargs.num_blks, BLOCK_SIZE)),
make_pass_through_transform(HEAD_SIZE_PADDED)),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0>{},
sequence<1>{})); // flattens the first two dims, head idx is the fastest
// changing dim in the merged dim
return k_dram_merged;
return k_dram_pad;
}();
auto k_dram_window = make_tile_window(
@@ -430,25 +375,16 @@ struct UnifiedAttentionKernel
const auto v_dram = [&]() {
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
v_ptr,
make_tuple(kargs.num_blks, BLOCK_SIZE, HEAD_SIZE),
make_tuple(kargs.stride_v_cache_0, kargs.stride_v_cache_1, kargs.stride_v_cache_3),
make_tuple(kargs.num_blks * BLOCK_SIZE, HEAD_SIZE),
make_tuple(kargs.stride_v_cache_1, kargs.stride_v_cache_3),
number<UnifiedAttentionPipeline::kAlignmentV>{},
number<1>{});
const auto v_dram_pad = pad_tensor_view(v_dram_naive,
make_tuple(1, BLOCK_SIZE, HEAD_SIZE_PADDED),
sequence<false, false, kPadHeadDimQ>{});
make_tuple(BLOCK_SIZE, HEAD_SIZE_PADDED),
sequence<false, kPadHeadDimQ>{});
const auto v_dram_merged = transform_tensor_view(
v_dram_pad,
make_tuple(make_merge_transform(make_tuple(kargs.num_blks, BLOCK_SIZE)),
make_pass_through_transform(HEAD_SIZE_PADDED)),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0>{},
sequence<1>{})); // flattens the first two dims, head idx is the fastest
// changing dim in the merged dim
return v_dram_merged;
return v_dram_pad;
}();
auto v_dram_window = make_tile_window(
@@ -457,12 +393,13 @@ struct UnifiedAttentionKernel
FmhaMask mask = [&]() {
if constexpr(kHasMask)
return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
seq_len - cur_batch_query_len, // y (i.e. context)
cur_batch_query_len, // x (i.e. extend)
seq_len, // y_total (x + y)
cur_batch_query_len, // x_total
num_queries_per_kv // the same sequence index is repeated num_queries_per_kv
-1,
0,
cur_batch_query_len, // y_total
seq_len, // x_total
num_queries_per_kv, // the same sequence index is repeated num_queries_per_kv
// times along x dim of the tile
false
);
else
return FmhaMask{cur_batch_query_len, seq_len};

View File

@@ -6,7 +6,6 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_default_policy.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
#define ENABLE_ASM_MARKER 1
#if ENABLE_ASM_MARKER
#define ASM_MARKER(marker) \
@@ -411,7 +410,7 @@ struct UnifiedAttentionPipeline
HEAD_SIZE_PADDED == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"wrong!");
static_assert(sizeof(SaccDataType) * BLOCK_SIZE <= GetSmemSize());
static_assert(sizeof(SaccDataType) * BLOCK_SIZE * BLOCK_M <= GetSmemSize());
auto s_lds = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<SaccDataType*>(static_cast<char*>(smem_ptr)),
MakeSimpleLdsDesc<BLOCK_M, BLOCK_SIZE>());
@@ -427,7 +426,7 @@ struct UnifiedAttentionPipeline
auto o_lds = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<PDataType*>(static_cast<char*>(smem_ptr)),
MakeSimpleLdsDesc<BLOCK_M, BLOCK_SIZE>());
MakeSimpleLdsDesc<BLOCK_M, HEAD_SIZE_PADDED>());
[[maybe_unused]] auto o_lds_window = make_tile_window(
o_lds, make_tuple(number<BLOCK_M>{}, number<HEAD_SIZE_PADDED>{}), {0, 0});
@@ -543,16 +542,11 @@ struct UnifiedAttentionPipeline
clear_tile(l);
const auto q_origin = q_dram_window.get_window_origin();
// const auto [seqlen_k_start, seqlen_k_end] =
// mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<BLOCK_M>{},
// number<BLOCK_SIZE>{});
// const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start,
// BLOCK_SIZE);
const auto num_total_loop = num_blocks;
// index_t kv_token_start = seqlen_k_start;
index_t k_block_table_off = num_blocks_start;
index_t v_block_table_off = num_blocks_start;
// TODO check is paddings kPadSeqLenK
// check early exit if no work to do
if constexpr(FmhaMask::IsMasking)
{
@@ -565,23 +559,23 @@ struct UnifiedAttentionPipeline
}
}
// TODO check correctness of this
index_t i_total_loops = num_blocks_start;
const ck_tile::index_t* block_tables_ptr_ =
reinterpret_cast<const ck_tile::index_t*>(block_tables_ptr);
index_t kv_blk_idx = block_tables_ptr_[block_table_offset + i_total_loops];
index_t kv_blk_idx_prev = 0;
index_t kv_blk_idx_intial = block_tables_ptr_[block_table_offset + k_block_table_off];
auto k_dram_window =
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
k_dram_block_window_tmp.get_window_lengths(),
{(kv_blk_idx - kv_blk_idx_prev) * BLOCK_SIZE, 0},
{kv_blk_idx_intial * BLOCK_SIZE, 0},
Policy::template MakeKDramTileDistribution<Problem>());
k_dram_window.init_raw();
auto v_dram_window =
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
v_dram_block_window_tmp.get_window_lengths(),
{(kv_blk_idx - kv_blk_idx_prev) * BLOCK_SIZE, 0}, // TODO: hdim split?
{kv_blk_idx_intial * BLOCK_SIZE, 0},
Policy::template MakeVDramTileDistribution<Problem>());
v_dram_window.init_raw();
@@ -677,6 +671,9 @@ struct UnifiedAttentionPipeline
async_load_tile_raw(k_lds_window_store(k_lds_write_idx), k_dram_window);
// TODO maybe needs i_total_loops as argument. Or maybe needs to use the k_lds_write_idx
// as the index
k_block_table_off++;
index_t kv_blk_idx = block_tables_ptr_[block_table_offset + k_block_table_off];
/// FIXME: use the future-predicting method to move the window
k_dram_window.set_window_origin({kv_blk_idx * BLOCK_SIZE, 0});
};
@@ -687,7 +684,9 @@ struct UnifiedAttentionPipeline
auto V_mem_load = [&](auto v_lds_write_idx) {
async_load_tile_raw(v_lds_window_store(v_lds_write_idx), v_dram_window);
// kv_blk_idx = block_tables_ptr_[block_table_offset + i_total_loops];
v_block_table_off++;
index_t kv_blk_idx = block_tables_ptr_[block_table_offset + v_block_table_off];
/// FIXME: use the future-predicting method to move the window
v_dram_window.set_window_origin({kv_blk_idx * BLOCK_SIZE, 0});
};
@@ -900,7 +899,7 @@ struct UnifiedAttentionPipeline
{
bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
i_total_loops * BLOCK_SIZE,
number<BLOCK_M>{},
number<BLOCK_Q>{},
number<BLOCK_SIZE>{});
if(need_perpixel_check)
{
@@ -985,7 +984,6 @@ struct UnifiedAttentionPipeline
__builtin_amdgcn_s_barrier();
__builtin_amdgcn_sched_barrier(0);
cl_load(memK, K_w0_lds_wr_idx, V_w0_lds_rd_idx);
// TODO what is this???
Scheduler::schedule(cl_p, number<1>{});
fmha_mask(xdl_SP_p01_reg_idx);
@@ -1014,7 +1012,6 @@ struct UnifiedAttentionPipeline
cl_load(memV, V_w0_lds_wr_idx, K_w0_lds_rd_idx);
Scheduler::schedule(cl_p, number<3>{});
// kv_token_start += BLOCK_SIZE;
if(num_total_loop <= ++i_total_loops)
{
result = false;
@@ -1061,7 +1058,6 @@ struct UnifiedAttentionPipeline
Scheduler::schedule(cl_p, number<2>{});
fmha_mask(xdl_SP_p01_reg_idx);
// kv_token_start += BLOCK_SIZE;
if(num_total_loop <= ++i_total_loops)
{
result = false;
@@ -1139,7 +1135,6 @@ struct UnifiedAttentionPipeline
fmha_alu0(number<0>{});
fmha_alu_D_upd();
// kv_token_start += BLOCK_SIZE;
++i_total_loops;
if(num_total_loop <= i_total_loops)
{