cmake list update

This commit is contained in:
Tianxing Wu
2025-11-11 14:35:26 +00:00
parent 47c9d0a131
commit 618ed6defb
13 changed files with 495 additions and 487 deletions

View File

@@ -187,36 +187,42 @@ if(NOT INST_TARGETS)
return()
endif()
set(EXAMPLE_FMHA_FWD_V3 "tile_example_unified_attention")
message(DEBUG "adding example ${EXAMPLE_FMHA_FWD_V3}")
set(EXAMPLE_UNIFIED_ATTENTION "tile_example_unified_attention")
message(DEBUG "adding example ${EXAMPLE_UNIFIED_ATTENTION}")
add_executable(${EXAMPLE_FMHA_FWD_V3} EXCLUDE_FROM_ALL example_unified_attention.cpp)
target_include_directories(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
file(GLOB FMHA_FWD_V3_INSTANCES CONFIGURE_DEPENDS
add_executable(${EXAMPLE_UNIFIED_ATTENTION} EXCLUDE_FROM_ALL example_unified_attention.cpp)
target_include_directories(${EXAMPLE_UNIFIED_ATTENTION} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
file(GLOB UNIFIED_ATTENTION_INSTANCES CONFIGURE_DEPENDS
"${CMAKE_CURRENT_LIST_DIR}/instances/*.cpp"
)
target_sources(${EXAMPLE_FMHA_FWD_V3} PRIVATE
target_sources(${EXAMPLE_UNIFIED_ATTENTION} PRIVATE
unified_attention.cpp
${FMHA_FWD_V3_INSTANCES}
${UNIFIED_ATTENTION_INSTANCES}
)
set(EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS)
list(APPEND EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS
set(EXAMPLE_UNIFIED_ATTENTION_COMPILE_OPTIONS)
list(APPEND EXAMPLE_UNIFIED_ATTENTION_COMPILE_OPTIONS
-fgpu-flush-denormals-to-zero
-Wno-undefined-func-template
--save-temps
)
set(EXAMPLE_FMHA_FWD_V3_COMPILE_DEFINITIONS)
set(EXAMPLE_UNIFIED_ATTENTION_COMPILE_DEFINITIONS)
check_cxx_compiler_flag("-mllvm --amdgpu-disable-packed-fp32=1" HAS_DISABLE_PACKED_FP32)
if(HAS_DISABLE_PACKED_FP32)
list(APPEND EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS
list(APPEND EXAMPLE_UNIFIED_ATTENTION_COMPILE_OPTIONS
-mllvm --amdgpu-disable-packed-fp32=1
)
list(APPEND EXAMPLE_FMHA_FWD_V3_COMPILE_DEFINITIONS
list(APPEND EXAMPLE_UNIFIED_ATTENTION_COMPILE_DEFINITIONS
-DCK_TILE_DISABLE_PACKED_FP32=1
)
endif()
target_compile_options(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS})
target_compile_definitions(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${EXAMPLE_FMHA_FWD_V3_COMPILE_DEFINITIONS})
target_compile_options(${EXAMPLE_UNIFIED_ATTENTION} PRIVATE ${EXAMPLE_UNIFIED_ATTENTION_COMPILE_OPTIONS})
target_compile_definitions(${EXAMPLE_UNIFIED_ATTENTION} PRIVATE ${EXAMPLE_UNIFIED_ATTENTION_COMPILE_DEFINITIONS})
# TODO: we have to turn off this global prop, otherwise the progress bar generated
# by cmake will print too many files, execvp: /bin/sh: Argument list too long
# however, this property may affect global
# TODO: consider codegen a makefile by us
set_property(GLOBAL PROPERTY RULE_MESSAGES OFF)

View File

@@ -93,20 +93,20 @@ struct Problem
{
explicit Problem(const ck_tile::ArgParser& args)
{
data_type = args.get_str("prec") == "fp16"
? ck_tile::unified_attention_args::data_type_enum::fp16
: ck_tile::unified_attention_args::data_type_enum::bf16;
batch = args.get_int("b");
max_seqlen_q = args.get_int("s");
data_type = args.get_str("prec") == "fp16"
? ck_tile::unified_attention_args::data_type_enum::fp16
: ck_tile::unified_attention_args::data_type_enum::bf16;
batch = args.get_int("b");
max_seqlen_q = args.get_int("s");
max_context_len = args.get_int("s_k");
num_blks = args.get_int("nb");
BLOCK_SIZE = args.get_int("bs");
nhead_q = args.get_int("h");
nhead_kv = args.get_int("h_k");
num_blks = args.get_int("nb");
BLOCK_SIZE = args.get_int("bs");
nhead_q = args.get_int("h");
nhead_kv = args.get_int("h_k");
hdim = args.get_int("d");
hdim = args.get_int("d");
query_lens = args.get_int_vec("query_lens");
kv_lens = args.get_int_vec("kv_lens");
kv_lens = args.get_int_vec("kv_lens");
// Calculate scale_s
scale_s = args.get_float("scale_s");
@@ -114,14 +114,15 @@ struct Problem
scale_s = 1.0f / ck_tile::sqrt(static_cast<float>(hdim));
// Initialize other scales
scale = args.get_float("scale");
scale = args.get_float("scale");
scale_k = args.get_float("scale_k");
scale_v = args.get_float("scale_v");
// Calculate sums of query_lens and kv_lens if provided
// int64_t kv_lens_sum = 0;
for (const auto& len : query_lens) {
for(const auto& len : query_lens)
{
num_tokens += len;
}
@@ -130,10 +131,7 @@ struct Problem
// }
}
std::vector<ck_tile::index_t> get_query_shape() const
{
return {num_tokens, nhead_q, hdim};
}
std::vector<ck_tile::index_t> get_query_shape() const { return {num_tokens, nhead_q, hdim}; }
std::vector<ck_tile::index_t> get_key_shape() const
{
@@ -145,11 +143,7 @@ struct Problem
return {num_blks, BLOCK_SIZE, nhead_kv, hdim};
}
std::vector<ck_tile::index_t> get_output_shape() const
{
return {num_tokens, nhead_q, hdim};
}
std::vector<ck_tile::index_t> get_output_shape() const { return {num_tokens, nhead_q, hdim}; }
ck_tile::unified_attention_args::data_type_enum data_type;
ck_tile::index_t batch;
@@ -209,7 +203,6 @@ auto generate_qkv(const Problem& problem,
return std::make_tuple(q, k, v);
}
// namespace host {
// template <typename AccDataType,
// typename PDataType,
@@ -231,81 +224,82 @@ auto generate_qkv(const Problem& problem,
// const VElementOp& v_element_op = {},
// const SAccElementOp& s_acc_element_op = {})
// {
// const int batch_size = q_bshd.mDesc.get_lengths()[0];
// const int seqlen_q = q_bshd.mDesc.get_lengths()[1];
// const int seqlen_kv = k_bshd.mDesc.get_lengths()[1];
// const int nhead_q = q_bshd.mDesc.get_lengths()[2];
// const int nhead_kv = k_bshd.mDesc.get_lengths()[2];
// const int hdim_qk = q_bshd.mDesc.get_lengths()[3];
// const int hdim_v = v_bshd.mDesc.get_lengths()[3];
// const int batch_size = q_bshd.mDesc.get_lengths()[0];
// const int seqlen_q = q_bshd.mDesc.get_lengths()[1];
// const int seqlen_kv = k_bshd.mDesc.get_lengths()[1];
// const int nhead_q = q_bshd.mDesc.get_lengths()[2];
// const int nhead_kv = k_bshd.mDesc.get_lengths()[2];
// const int hdim_qk = q_bshd.mDesc.get_lengths()[3];
// const int hdim_v = v_bshd.mDesc.get_lengths()[3];
// const int nr = nhead_q / nhead_kv;
// const int nr = nhead_q / nhead_kv;
// ck_tile::HostTensor<QDataType> q_host_ref({nhead_q, seqlen_q, hdim_qk});
// ck_tile::HostTensor<KDataType> k_host_ref({nhead_q, seqlen_kv, hdim_qk});
// ck_tile::HostTensor<VDataType> v_host_ref({nhead_q, hdim_v, seqlen_kv});
// ck_tile::HostTensor<ODataType> o_host_ref({nhead_q, seqlen_q, hdim_v});
// ck_tile::HostTensor<QDataType> q_host_ref({nhead_q, seqlen_q, hdim_qk});
// ck_tile::HostTensor<KDataType> k_host_ref({nhead_q, seqlen_kv, hdim_qk});
// ck_tile::HostTensor<VDataType> v_host_ref({nhead_q, hdim_v, seqlen_kv});
// ck_tile::HostTensor<ODataType> o_host_ref({nhead_q, seqlen_q, hdim_v});
// ck_tile::HostTensor<AccDataType> s_host_ref({nhead_q, seqlen_q, seqlen_kv});
// ck_tile::HostTensor<PDataType> p_host_ref({nhead_q, seqlen_q, seqlen_kv});
// ck_tile::HostTensor<AccDataType> s_host_ref({nhead_q, seqlen_q, seqlen_kv});
// ck_tile::HostTensor<PDataType> p_host_ref({nhead_q, seqlen_q, seqlen_kv});
// // do computation for each batch
// for(int b = 0; b < batch_size; ++b)
// {
// // copy per-batch data from input tensors
// // clang-format off
// q_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = q_bshd(b, idx[1], idx[0] , idx[2]); });
// k_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = k_bshd(b, idx[1], idx[0] / nr, idx[2]); });
// v_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = v_bshd(b, idx[2], idx[0] / nr, idx[1]); });
// // clang-format on
// ck_tile::reference_batched_gemm<QDataType, KDataType, AccDataType>(
// q_host_ref, k_host_ref, s_host_ref, q_element_op, k_element_op, s_acc_element_op);
// // do computation for each batch
// for(int b = 0; b < batch_size; ++b)
// {
// // copy per-batch data from input tensors
// // clang-format off
// q_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = q_bshd(b, idx[1], idx[0] ,
// idx[2]); }); k_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = k_bshd(b, idx[1],
// idx[0] / nr, idx[2]); }); v_host_ref.ForEach([&](auto& self, auto idx) { self(idx) =
// v_bshd(b, idx[2], idx[0] / nr, idx[1]); });
// // clang-format on
// ck_tile::reference_batched_gemm<QDataType, KDataType, AccDataType>(
// q_host_ref, k_host_ref, s_host_ref, q_element_op, k_element_op, s_acc_element_op);
// if(mask.type == mask_enum::no_mask)
// {
// ck_tile::reference_batched_masking(s_host_ref, FmhaMasks::NoMask{seqlen_q, seqlen_kv});
// }
// else if(mask.type == mask_enum::window_generic)
// {
// ck_tile::reference_batched_masking(
// s_host_ref,
// ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::GenericMask>(
// mask.left, mask.right, seqlen_q, seqlen_kv));
// }
// else
// {
// // if left window size is negative, means causal
// // else means generic (for current batch)
// if(mask.left < 0)
// ck_tile::reference_batched_masking(
// s_host_ref,
// ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::CausalMask>(
// mask.left,
// mask.right,
// seqlen_q,
// seqlen_kv,
// mask.type == mask_enum::mask_top_left));
// else
// ck_tile::reference_batched_masking(
// s_host_ref,
// ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::GenericMask>(
// mask.left,
// mask.right,
// seqlen_q,
// seqlen_kv,
// mask.type == mask_enum::mask_top_left));
// }
// if(mask.type == mask_enum::no_mask)
// {
// ck_tile::reference_batched_masking(s_host_ref, FmhaMasks::NoMask{seqlen_q, seqlen_kv});
// }
// else if(mask.type == mask_enum::window_generic)
// {
// ck_tile::reference_batched_masking(
// s_host_ref,
// ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::GenericMask>(
// mask.left, mask.right, seqlen_q, seqlen_kv));
// }
// else
// {
// // if left window size is negative, means causal
// // else means generic (for current batch)
// if(mask.left < 0)
// ck_tile::reference_batched_masking(
// s_host_ref,
// ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::CausalMask>(
// mask.left,
// mask.right,
// seqlen_q,
// seqlen_kv,
// mask.type == mask_enum::mask_top_left));
// else
// ck_tile::reference_batched_masking(
// s_host_ref,
// ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::GenericMask>(
// mask.left,
// mask.right,
// seqlen_q,
// seqlen_kv,
// mask.type == mask_enum::mask_top_left));
// }
// ck_tile::reference_batched_softmax<AccDataType, AccDataType>(
// s_host_ref, p_host_ref, ck_tile::identity{});
// ck_tile::reference_batched_softmax<AccDataType, AccDataType>(
// s_host_ref, p_host_ref, ck_tile::identity{});
// ck_tile::reference_batched_gemm<PDataType, VDataType, AccDataType>(
// p_host_ref, v_host_ref, o_host_ref, ck_tile::identity{}, v_element_op);
// ck_tile::reference_batched_gemm<PDataType, VDataType, AccDataType>(
// p_host_ref, v_host_ref, o_host_ref, ck_tile::identity{}, v_element_op);
// // copy resulting per-batch data to the output tensor
// o_host_ref.ForEach(
// [&](auto& self, auto idx) { o_bshd(b, idx[1], idx[0], idx[2]) = self(idx); });
// }
// // copy resulting per-batch data to the output tensor
// o_host_ref.ForEach(
// [&](auto& self, auto idx) { o_bshd(b, idx[1], idx[0], idx[2]) = self(idx); });
// }
// }
// } // namespace host
@@ -328,20 +322,20 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
ck_tile::unified_attention_args args{};
args.data_type = problem.data_type;
args.num_seqs = problem.batch;
args.data_type = problem.data_type;
args.num_seqs = problem.batch;
// args.seqlen_q = problem.seqlen_q;
// args.seqlen_k = problem.seqlen_k;
args.num_head_q = problem.nhead_q;
args.num_queries_per_kv = problem.nhead_q / problem.nhead_kv;
args.mask_type = 2;
args.hdim = problem.hdim;
args.num_head_q = problem.nhead_q;
args.num_queries_per_kv = problem.nhead_q / problem.nhead_kv;
args.mask_type = 2;
args.hdim = problem.hdim;
args.num_blks = problem.num_blks;
// args.query_lens = problem.query_lens
// args.kv_lens = problem.kv_lens
args.q_ptr = q_buf.GetDeviceBuffer();
args.q_ptr = q_buf.GetDeviceBuffer();
args.query_stride_0 = problem.hdim * problem.nhead_q;
args.query_stride_0 = problem.hdim;
@@ -352,13 +346,13 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
args.stride_k_cache_2 = problem.hdim;
args.stride_k_cache_3 = 1;
args.v_ptr = v_buf.GetDeviceBuffer();
args.v_ptr = v_buf.GetDeviceBuffer();
args.stride_v_cache_0 = args.stride_k_cache_0;
args.stride_v_cache_1 = args.stride_k_cache_1;
args.stride_v_cache_2 = args.stride_k_cache_2;
args.stride_v_cache_3 = args.stride_k_cache_3;
args.o_ptr = o_buf.GetDeviceBuffer();
args.o_ptr = o_buf.GetDeviceBuffer();
args.output_stride_0 = args.query_stride_0;
args.output_stride_1 = args.query_stride_1;
@@ -380,13 +374,13 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
return eff;
};
const auto eff_query_lens = make_effective_vec(problem.query_lens, 1024);
const auto eff_kv_lens = make_effective_vec(problem.kv_lens, 1024);
const auto eff_query_lens = make_effective_vec(problem.query_lens, 1024);
const auto eff_kv_lens = make_effective_vec(problem.kv_lens, 1024);
args.num_tokens = std::accumulate(eff_query_lens.begin(), eff_query_lens.end(), 0);
// Calculate cumulative sums for kernel arguments if varlen is used
std::vector<ck_tile::index_t> cu_query_lens ;
std::vector<ck_tile::index_t> cu_query_lens;
auto calculate_cumulative = [&](const std::vector<ck_tile::index_t>& per_batch_vec,
std::vector<ck_tile::index_t>& cum_vec) {
@@ -403,14 +397,16 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
seq_lens_buf.ToDevice(eff_kv_lens.data());
query_start_len_buf.ToDevice(cu_query_lens.data());
args.seq_lens_ptr =reinterpret_cast<const ck_tile::index_t*>(seq_lens_buf.GetDeviceBuffer());
args.query_start_len_ptr =reinterpret_cast<const ck_tile::index_t*>(query_start_len_buf.GetDeviceBuffer());
args.seq_lens_ptr = reinterpret_cast<const ck_tile::index_t*>(seq_lens_buf.GetDeviceBuffer());
args.query_start_len_ptr =
reinterpret_cast<const ck_tile::index_t*>(query_start_len_buf.GetDeviceBuffer());
auto max_element = [&](const std::vector<ck_tile::index_t>& opt_vec) {
ck_tile::index_t max = opt_vec[0];
for (ck_tile::index_t i: opt_vec) {
if (i > max){
for(ck_tile::index_t i : opt_vec)
{
if(i > max)
{
max = i;
}
}
@@ -419,10 +415,12 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
ck_tile::index_t max_kv_len = max_element(eff_kv_lens);
ck_tile::index_t max_num_blocks_per_seq = (max_kv_len + problem.BLOCK_SIZE - 1) / problem.BLOCK_SIZE;
ck_tile::index_t max_num_blocks_per_seq =
(max_kv_len + problem.BLOCK_SIZE - 1) / problem.BLOCK_SIZE;
// Create block_tables
ck_tile::DeviceMem block_tables_buf(problem.batch * max_num_blocks_per_seq * sizeof(ck_tile::index_t));
ck_tile::DeviceMem block_tables_buf(problem.batch * max_num_blocks_per_seq *
sizeof(ck_tile::index_t));
// Allocate host memory for block_tables
std::vector<ck_tile::index_t> block_tables_host(problem.batch * max_num_blocks_per_seq);
@@ -430,7 +428,8 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
// Fill block_tables with random integers between 0 and num_blocks-1
std::mt19937 rng(run_config.seed ? *run_config.seed : std::random_device{}());
std::uniform_int_distribution<ck_tile::index_t> dist(0, problem.num_blks - 1);
for (size_t i = 0; i < block_tables_host.size(); ++i) {
for(size_t i = 0; i < block_tables_host.size(); ++i)
{
block_tables_host[i] = dist(rng);
}
@@ -438,10 +437,10 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
block_tables_buf.ToDevice(block_tables_host.data());
// Set pointer in args
args.block_tables_ptr = reinterpret_cast<const ck_tile::index_t*>(block_tables_buf.GetDeviceBuffer());
args.block_tables_ptr =
reinterpret_cast<const ck_tile::index_t*>(block_tables_buf.GetDeviceBuffer());
args.block_table_stride = max_num_blocks_per_seq;
ck_tile::stream_config stream_config{nullptr,
true,
/*log_level=*/0,
@@ -476,7 +475,8 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
// std::cout << "] b:" << problem.batch << ", h:" << problem.nhead_q << "/" << problem.nhead_kv
// << ", s:" << problem.seqlen_q << "/" << problem.seqlen_k << ", d:" << problem.hdim
// << ", scale_s:" << problem.sacle_s << ", mask:" << problem.mask << std::fixed
// << ", " << std::setprecision(3) << time << " ms, " << std::setprecision(2) << tflops
// << ", " << std::setprecision(3) << time << " ms, " << std::setprecision(2) <<
// tflops
// << " TFlops" << std::endl;
// if(!run_config.verify)
@@ -548,7 +548,6 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
// }
// }
// }
// ck_tile::HostTensor<DataType> o(problem.get_output_shape());
// o_buf.FromDevice(o.data());

View File

@@ -7,7 +7,8 @@
namespace ck_tile {
std::ostream& operator<<(std::ostream& stream, const unified_attention_args::data_type_enum& data_type)
std::ostream& operator<<(std::ostream& stream,
const unified_attention_args::data_type_enum& data_type)
{
switch(data_type)
{
@@ -17,14 +18,16 @@ std::ostream& operator<<(std::ostream& stream, const unified_attention_args::dat
}
}
std::pair<bool, float> unified_attention(const unified_attention_args& args, const stream_config& config)
std::pair<bool, float> unified_attention(const unified_attention_args& args,
const stream_config& config)
{
if(args.data_type == unified_attention_args::data_type_enum::fp16)
{
if(args.mask_type == static_cast<int>(mask_enum::no_mask))
{
using kernel_traits =
unified_attention_kernel_traits<unified_attention_args::data_type_enum::fp16, false>;
unified_attention_kernel_traits<unified_attention_args::data_type_enum::fp16,
false>;
return unified_attention_kernel_dispatch<kernel_traits>(args, config);
}
@@ -41,7 +44,8 @@ std::pair<bool, float> unified_attention(const unified_attention_args& args, con
if(args.mask_type == static_cast<int>(mask_enum::no_mask))
{
using kernel_traits =
unified_attention_kernel_traits<unified_attention_args::data_type_enum::bf16, false>;
unified_attention_kernel_traits<unified_attention_args::data_type_enum::bf16,
false>;
return unified_attention_kernel_dispatch<kernel_traits>(args, config);
}

View File

@@ -52,24 +52,26 @@ struct unified_attention_args
index_t stride_v_cache_1;
index_t stride_v_cache_2;
index_t stride_v_cache_3;
void* o_ptr;
index_t output_stride_0;
index_t output_stride_1;
const int32_t* block_tables_ptr;
index_t block_table_stride;
const int32_t* seq_lens_ptr; // seq len in each batch
const int32_t* seq_lens_ptr; // seq len in each batch
const int32_t* query_start_len_ptr; // [num_seqs+1]
index_t num_seqs; // number of batches for q
};
std::ostream& operator<<(std::ostream& stream, const unified_attention_args::data_type_enum& data_type);
std::ostream& operator<<(std::ostream& stream,
const unified_attention_args::data_type_enum& data_type);
// return value:
// first = whether the kernel was launched (true = launched, false = skipped)
// second = elapsed time (ms) of the kernel launch, valid only if first == true
std::pair<bool, float> unified_attention(const unified_attention_args& args, const stream_config& config);
std::pair<bool, float> unified_attention(const unified_attention_args& args,
const stream_config& config);
} // namespace ck_tile

View File

@@ -20,13 +20,13 @@
#include "unified_attention.hpp"
#include "mask.hpp"
#define INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits) \
template <> \
std::pair<bool, float> unified_attention_kernel_dispatch<kernel_traits>( \
const unified_attention_args& args, const stream_config& config) \
{ \
return std::make_pair(true, \
unified_attention_kernel_launch<kernel_traits::kernel>(args, config)); \
#define INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits) \
template <> \
std::pair<bool, float> unified_attention_kernel_dispatch<kernel_traits>( \
const unified_attention_args& args, const stream_config& config) \
{ \
return std::make_pair( \
true, unified_attention_kernel_launch<kernel_traits::kernel>(args, config)); \
}
namespace ck_tile {
@@ -55,8 +55,8 @@ struct unified_attention_problem_traits<unified_attention_args::data_type_enum::
template <unified_attention_args::data_type_enum DataType, bool IsMasking>
struct unified_attention_kernel_traits
{
static constexpr auto date_type = DataType;
static constexpr bool is_masking = IsMasking;
static constexpr auto date_type = DataType;
static constexpr bool is_masking = IsMasking;
// BLOCK_M BLOCK_Q BLOCK_SIZE HEAD_SIZE
using unified_attention_block_tile = sequence<256, 64, 128, 128>;
@@ -64,34 +64,34 @@ struct unified_attention_kernel_traits
using unified_attention_block_warps = sequence<8, 1, 1>;
using unified_attention_shape = TileUnifiedAttentionShape<unified_attention_block_tile,
unified_attention_block_warps,
unified_attention_warp_gemm_shape,
unified_attention_block_warps,
unified_attention_warp_gemm_shape,
true // IsVLayoutRowMajor
>;
unified_attention_block_warps,
unified_attention_warp_gemm_shape,
unified_attention_block_warps,
unified_attention_warp_gemm_shape,
true // IsVLayoutRowMajor
>;
using unified_attention_traits = TileUnifiedAttentionTraits<true, // kPadSeqLenQ_
false, // kPadHeadDimQ
-1 // kBlockPerCu
>;
false, // kPadHeadDimQ
-1 // kBlockPerCu
>;
using unified_attention_mask = GenericAttentionMask<IsMasking, /*IsLocal=*/false>;
using unified_attention_pipeline_problem =
UnifiedAttentionPipelineProblem<typename unified_attention_problem_traits<date_type>::qkvp_dtype,
typename unified_attention_problem_traits<date_type>::qkvp_dtype,
typename unified_attention_problem_traits<date_type>::qkvp_dtype,
typename unified_attention_problem_traits<date_type>::acc_dtype,
typename unified_attention_problem_traits<date_type>::acc_dtype,
typename unified_attention_problem_traits<date_type>::acc_dtype,
typename unified_attention_problem_traits<date_type>::lse_dtype,
typename unified_attention_problem_traits<date_type>::qkvp_dtype,
typename unified_attention_problem_traits<date_type>::acc_dtype,
typename unified_attention_problem_traits<date_type>::o_dtype,
unified_attention_shape,
unified_attention_mask,
unified_attention_traits>;
using unified_attention_pipeline_problem = UnifiedAttentionPipelineProblem<
typename unified_attention_problem_traits<date_type>::qkvp_dtype,
typename unified_attention_problem_traits<date_type>::qkvp_dtype,
typename unified_attention_problem_traits<date_type>::qkvp_dtype,
typename unified_attention_problem_traits<date_type>::acc_dtype,
typename unified_attention_problem_traits<date_type>::acc_dtype,
typename unified_attention_problem_traits<date_type>::acc_dtype,
typename unified_attention_problem_traits<date_type>::lse_dtype,
typename unified_attention_problem_traits<date_type>::qkvp_dtype,
typename unified_attention_problem_traits<date_type>::acc_dtype,
typename unified_attention_problem_traits<date_type>::o_dtype,
unified_attention_shape,
unified_attention_mask,
unified_attention_traits>;
using unified_attention_pipeline = UnifiedAttentionPipeline<unified_attention_pipeline_problem>;
@@ -107,11 +107,12 @@ struct unified_attention_kernel_traits
};
template <typename Kernel>
float unified_attention_kernel_launch(const unified_attention_args& args, const stream_config& config)
float unified_attention_kernel_launch(const unified_attention_args& args,
const stream_config& config)
{
index_t BLOCK_Q = Kernel::BLOCK_M / args.num_queries_per_kv;
index_t total_num_q_blocks = args.num_tokens / BLOCK_Q + args.num_seqs;
auto kargs = Kernel::MakeKargs(args.q_ptr,
@@ -128,26 +129,25 @@ float unified_attention_kernel_launch(const unified_attention_args& args, const
args.scale_out,
total_num_q_blocks,
args.query_stride_0,
args.query_stride_1,
args.stride_k_cache_0,
args.stride_k_cache_1,
args.stride_k_cache_2,
args.stride_k_cache_3,
args.stride_v_cache_0,
args.stride_v_cache_1,
args.stride_v_cache_2,
args.stride_v_cache_3,
args.output_stride_0,
args.output_stride_1,
args.block_tables_ptr,
args.block_table_stride,
args.seq_lens_ptr,
args.query_start_len_ptr,
args.num_seqs
);
args.query_stride_1,
args.stride_k_cache_0,
args.stride_k_cache_1,
args.stride_k_cache_2,
args.stride_k_cache_3,
args.stride_v_cache_0,
args.stride_v_cache_1,
args.stride_v_cache_2,
args.stride_v_cache_3,
args.output_stride_0,
args.output_stride_1,
args.block_tables_ptr,
args.block_table_stride,
args.seq_lens_ptr,
args.query_start_len_ptr,
args.num_seqs);
dim3 grids = Kernel::GridSize2D(args.num_head_q / args.num_queries_per_kv, total_num_q_blocks);
constexpr dim3 blocks = Kernel::BlockSize();
dim3 grids = Kernel::GridSize2D(args.num_head_q / args.num_queries_per_kv, total_num_q_blocks);
constexpr dim3 blocks = Kernel::BlockSize();
constexpr index_t kBlockPerCu = Kernel::kBlockPerCu;
return launch_kernel(config, make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
@@ -158,6 +158,6 @@ float unified_attention_kernel_launch(const unified_attention_args& args, const
// second = elapsed time (ms) of the kernel launch, valid only if first == true
template <typename KernelTraits>
std::pair<bool, float> unified_attention_kernel_dispatch(const unified_attention_args& args,
const stream_config& config);
const stream_config& config);
} // namespace ck_tile

View File

@@ -3,12 +3,6 @@
#pragma once
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
// Block-level components
#include "ck_tile/ops/unified_attention/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/unified_attention/block/block_dropout.hpp"
#include "ck_tile/ops/unified_attention/block/block_masking.hpp"
@@ -16,15 +10,15 @@
#include "ck_tile/ops/unified_attention/block/block_rotary_embedding.hpp"
#include "ck_tile/ops/unified_attention/block/page_block_navigator.hpp"
#include "ck_tile/ops/unified_attention/block/variants.hpp"
// Kernel-level components
#include "ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp"
// Pipeline-level components
#include "ck_tile/ops/unified_attention/pipeline/tile_unified_attention_shape.hpp"
#include "ck_tile/ops/unified_attention/pipeline/tile_unified_attention_traits.hpp"
#include "ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp"
#include "ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_default_policy.hpp"
#include "ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_enum.hpp"
#include "ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_problem.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
#include "ck_tile/ops/common/streamk_common.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"

View File

@@ -86,19 +86,22 @@ struct GenericAttentionMask
static constexpr const char* name = impl::MaskName<IsMasking, IsLocal>::name;
// New constructor accepting repeat_idx with default value 1
CK_TILE_HOST_DEVICE GenericAttentionMask(index_t y_total_, index_t x_total_, index_t repeat_idx_ = 1)
CK_TILE_HOST_DEVICE
GenericAttentionMask(index_t y_total_, index_t x_total_, index_t repeat_idx_ = 1)
: GenericAttentionMask(0, 0, y_total_, x_total_, repeat_idx_)
{
}
CK_TILE_HOST_DEVICE
GenericAttentionMask(index_t y_, index_t x_, index_t y_total_, index_t x_total_, index_t repeat_idx_ = 1)
GenericAttentionMask(
index_t y_, index_t x_, index_t y_total_, index_t x_total_, index_t repeat_idx_ = 1)
: y(y_), x(x_), y_total(y_total_), x_total(x_total_), repeat_idx(repeat_idx_)
{
}
template <typename MaskCoordinates>
CK_TILE_HOST_DEVICE GenericAttentionMask(const MaskCoordinates& mask_coord, index_t repeat_idx_ = 1)
CK_TILE_HOST_DEVICE GenericAttentionMask(const MaskCoordinates& mask_coord,
index_t repeat_idx_ = 1)
: y(mask_coord.at(number<0>{})),
x(mask_coord.at(number<1>{})),
y_total(mask_coord.at(number<2>{})),
@@ -248,13 +251,12 @@ struct GenericAttentionMask
}
}
private:
private:
index_t y, x;
index_t y_total, x_total;
index_t repeat_idx;
};
// TODO: prefer use this function in host code
// can convert from the FA style left/right to our generic coordinate
// if left_size < 0 && right_size = 0, it is normal causal mask
@@ -289,7 +291,7 @@ make_generic_attention_mask_from_lr_window(index_t left_size,
index_t y_total,
index_t x_total,
index_t repeat_idx = 1,
bool is_top_left = true)
bool is_top_left = true)
{
auto r = make_generic_attention_mask_coordinates_from_lr_window(
left_size, right_size, y_total, x_total, is_top_left);

View File

@@ -18,8 +18,8 @@ namespace ck_tile {
template <typename UnifiedAttentionPipeline_, typename EpiloguePipeline_>
struct UnifiedAttentionKernel
{
using UnifiedAttentionPipeline = ck_tile::remove_cvref_t<UnifiedAttentionPipeline_>;
using EpiloguePipeline = ck_tile::remove_cvref_t<EpiloguePipeline_>;
using UnifiedAttentionPipeline = ck_tile::remove_cvref_t<UnifiedAttentionPipeline_>;
using EpiloguePipeline = ck_tile::remove_cvref_t<EpiloguePipeline_>;
static constexpr ck_tile::index_t kBlockSize = UnifiedAttentionPipeline::kBlockSize;
static constexpr ck_tile::index_t kBlockPerCu = UnifiedAttentionPipeline::kBlockPerCu;
static_assert(kBlockPerCu > 0);
@@ -29,18 +29,18 @@ struct UnifiedAttentionKernel
using VDataType = ck_tile::remove_cvref_t<typename UnifiedAttentionPipeline::VDataType>;
using ODataType = ck_tile::remove_cvref_t<typename UnifiedAttentionPipeline::ODataType>;
using SaccDataType = ck_tile::remove_cvref_t<typename UnifiedAttentionPipeline::SaccDataType>;
using FmhaMask = ck_tile::remove_cvref_t<typename UnifiedAttentionPipeline::FmhaMask>;
using FmhaMask = ck_tile::remove_cvref_t<typename UnifiedAttentionPipeline::FmhaMask>;
static constexpr bool kHasMask = FmhaMask::IsMasking;
static constexpr bool kPadSeqLenK = UnifiedAttentionPipeline::kPadSeqLenK;
static constexpr bool kPadSeqLenQ = UnifiedAttentionPipeline::kPadSeqLenQ;
static constexpr bool kPadHeadDimQ = UnifiedAttentionPipeline::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = UnifiedAttentionPipeline::kPadHeadDimV;
// TODO add yjese
static constexpr index_t HEAD_SIZE = UnifiedAttentionPipeline::HEAD_SIZE;
// TODO add yjese
static constexpr index_t HEAD_SIZE = UnifiedAttentionPipeline::HEAD_SIZE;
static constexpr index_t HEAD_SIZE_PADDED = UnifiedAttentionPipeline::HEAD_SIZE_PADDED;
// BLOCK_Q = BLOCK_M // num_queries_per_kv
// BLOCK_Q is the block size for q seqlen
/// static constexpr index_t BLOCK_Q = UnifiedAttentionPipeline::BLOCK_Q;
@@ -49,7 +49,6 @@ struct UnifiedAttentionKernel
// BLOCK size for K seqlen
static constexpr index_t BLOCK_SIZE = UnifiedAttentionPipeline::BLOCK_SIZE;
// kargs use aggregate initializer, so no constructor will provided
// use inheritance to minimize karg size
// user need to use MakeKargs() function to create kargs.
@@ -88,12 +87,11 @@ struct UnifiedAttentionKernel
ck_tile::index_t output_stride_1;
};
struct UnifiedAttentionVarlenKargs: UnifiedAttentionCommonKargs
struct UnifiedAttentionVarlenKargs : UnifiedAttentionCommonKargs
{
const int32_t* block_tables_ptr;
ck_tile::index_t block_table_stride;
const int32_t* seq_lens_ptr; // seq len in each batch
const int32_t* seq_lens_ptr; // seq len in each batch
const int32_t* query_start_len_ptr; // [num_seqs+1]
ck_tile::index_t num_seqs; // number of batches for q
@@ -101,38 +99,36 @@ struct UnifiedAttentionKernel
using Kargs = UnifiedAttentionVarlenKargs;
CK_TILE_HOST static constexpr Kargs MakeKargs(
const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
void* o_ptr,
ck_tile::index_t num_blks,
ck_tile::index_t num_head_q,
const ck_tile::index_t num_queries_per_kv,
float scale_s,
float scale,
float scale_k,
float scale_v,
float scale_out,
ck_tile::index_t total_num_q_blocks,
ck_tile::index_t query_stride_0,
ck_tile::index_t query_stride_1,
ck_tile::index_t stride_k_cache_0,
ck_tile::index_t stride_k_cache_1,
ck_tile::index_t stride_k_cache_2,
ck_tile::index_t stride_k_cache_3,
ck_tile::index_t stride_v_cache_0,
ck_tile::index_t stride_v_cache_1,
ck_tile::index_t stride_v_cache_2,
ck_tile::index_t stride_v_cache_3,
ck_tile::index_t output_stride_0,
ck_tile::index_t output_stride_1,
const int32_t* block_tables_ptr,
ck_tile::index_t block_table_stride,
const int32_t* seq_lens_ptr,
const int32_t* query_start_len_ptr,
ck_tile::index_t num_seqs
)
CK_TILE_HOST static constexpr Kargs MakeKargs(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
void* o_ptr,
ck_tile::index_t num_blks,
ck_tile::index_t num_head_q,
const ck_tile::index_t num_queries_per_kv,
float scale_s,
float scale,
float scale_k,
float scale_v,
float scale_out,
ck_tile::index_t total_num_q_blocks,
ck_tile::index_t query_stride_0,
ck_tile::index_t query_stride_1,
ck_tile::index_t stride_k_cache_0,
ck_tile::index_t stride_k_cache_1,
ck_tile::index_t stride_k_cache_2,
ck_tile::index_t stride_k_cache_3,
ck_tile::index_t stride_v_cache_0,
ck_tile::index_t stride_v_cache_1,
ck_tile::index_t stride_v_cache_2,
ck_tile::index_t stride_v_cache_3,
ck_tile::index_t output_stride_0,
ck_tile::index_t output_stride_1,
const int32_t* block_tables_ptr,
ck_tile::index_t block_table_stride,
const int32_t* seq_lens_ptr,
const int32_t* query_start_len_ptr,
ck_tile::index_t num_seqs)
{
Kargs kargs{{q_ptr,
k_ptr,
@@ -146,31 +142,30 @@ struct UnifiedAttentionKernel
scale_k,
scale_v,
scale_out,
total_num_q_blocks,
query_stride_0,
query_stride_1,
stride_k_cache_0,
stride_k_cache_1,
stride_k_cache_2,
stride_k_cache_3,
stride_v_cache_0,
stride_v_cache_1,
stride_v_cache_2,
stride_v_cache_3,
output_stride_0,
output_stride_1},
block_tables_ptr,
block_table_stride,
seq_lens_ptr,
query_start_len_ptr,
num_seqs
};
total_num_q_blocks,
query_stride_0,
query_stride_1,
stride_k_cache_0,
stride_k_cache_1,
stride_k_cache_2,
stride_k_cache_3,
stride_v_cache_0,
stride_v_cache_1,
stride_v_cache_2,
stride_v_cache_3,
output_stride_0,
output_stride_1},
block_tables_ptr,
block_table_stride,
seq_lens_ptr,
query_start_len_ptr,
num_seqs};
return kargs;
}
CK_TILE_HOST static constexpr auto GridSize2D(ck_tile::index_t num_kv_heads,
ck_tile::index_t total_num_q_blocks)
ck_tile::index_t total_num_q_blocks)
{
return dim3(num_kv_heads * total_num_q_blocks, 0, 0);
}
@@ -190,16 +185,16 @@ struct UnifiedAttentionKernel
ck_tile::index_t block_q,
bool use_q_block_mode)
{
ck_tile::index_t left = 0;
ck_tile::index_t left = 0;
ck_tile::index_t right = num_seqs;
while (left < right)
while(left < right)
{
ck_tile::index_t mid = (left + right) / 2;
ck_tile::index_t val = query_start_len_ptr[mid];
ck_tile::index_t mid = (left + right) / 2;
ck_tile::index_t val = query_start_len_ptr[mid];
ck_tile::index_t mid_val = use_q_block_mode ? (val / block_q + mid) : val;
if (mid_val <= target_idx)
if(mid_val <= target_idx)
{
left = mid + 1;
}
@@ -208,32 +203,31 @@ struct UnifiedAttentionKernel
right = mid;
}
}
return left - 1;
}
CK_TILE_DEVICE static constexpr auto
RemapTileIndices(const ck_tile::index_t pid, const Kargs& kargs)
CK_TILE_DEVICE static constexpr auto RemapTileIndices(const ck_tile::index_t pid,
const Kargs& kargs)
{
using namespace ck_tile;
constexpr index_t NUM_XCDS = 8;
const index_t GRID_MN = kargs.total_num_q_blocks *
(kargs.num_head_q);
const index_t GRID_MN = kargs.total_num_q_blocks * (kargs.num_head_q);
// Number of pids per XCD in the new arrangement
const index_t pids_per_xcd = (GRID_MN + NUM_XCDS - 1) / NUM_XCDS;
// When GRID_MN cannot divide NUM_XCDS, some xcds will have
// pids_per_xcd pids, the other will have pids_per_xcd - 1 pids.
// We calculate the number of xcds that have pids_per_xcd pids as tall_xcds
index_t tall_xcds = GRID_MN % NUM_XCDS;
tall_xcds = tall_xcds == 0 ? NUM_XCDS : tall_xcds;
tall_xcds = tall_xcds == 0 ? NUM_XCDS : tall_xcds;
// Compute current XCD and local pid within the XCD
const index_t xcd = pid % NUM_XCDS;
const index_t xcd = pid % NUM_XCDS;
const index_t local_pid = pid / NUM_XCDS;
// Calculate new pid based on the new grouping
index_t remapped_pid = 0; // Initialize to avoid constexpr error
if(xcd < tall_xcds)
@@ -242,15 +236,15 @@ struct UnifiedAttentionKernel
}
else
{
remapped_pid = tall_xcds * pids_per_xcd +
(xcd - tall_xcds) * (pids_per_xcd - 1) +
local_pid;
remapped_pid =
tall_xcds * pids_per_xcd + (xcd - tall_xcds) * (pids_per_xcd - 1) + local_pid;
}
return remapped_pid;
}
CK_TILE_DEVICE static constexpr auto GetTileIndex(const ck_tile::index_t pid, const Kargs& kargs)
CK_TILE_DEVICE static constexpr auto GetTileIndex(const ck_tile::index_t pid,
const Kargs& kargs)
{
using namespace ck_tile;
@@ -258,8 +252,8 @@ struct UnifiedAttentionKernel
// const index_t num_tile_n1 = ck_tile::integer_divide_ceil(kargs.hdim_v,
// UnifiedAttentionPipeline::kN1);
const index_t i_tile_m = pid % total_num_q_blocks; // Query block index
const index_t i_tile_n = pid / total_num_q_blocks; // Head index
const index_t i_tile_m = pid % total_num_q_blocks; // Query block index
const index_t i_tile_n = pid / total_num_q_blocks; // Head index
return ck_tile::make_tuple(i_tile_m, i_tile_n);
}
@@ -268,7 +262,8 @@ struct UnifiedAttentionKernel
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return ck_tile::max(UnifiedAttentionPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
return ck_tile::max(UnifiedAttentionPipeline::GetSmemSize(),
EpiloguePipeline::GetSmemSize());
}
CK_TILE_DEVICE void operator()(Kargs kargs) const
@@ -285,7 +280,7 @@ struct UnifiedAttentionKernel
// const index_t BLOCK_Q = BLOCK_M / num_queries_per_kv;
// for simplicity, batch stride we just modify the pointer
// const index_t num_head_q = kargs.num_head_q;
// const index_t num_head_k = num_head_q / num_queries_per_kv;
pid = RemapTileIndices(pid, kargs);
@@ -296,65 +291,76 @@ struct UnifiedAttentionKernel
// grid size is (num_kv_heads, total_num_q_blocks)
// total_num_q_blocks = q.shape[0] // BLOCK_Q + num_seqs
// q.shape[0] is total number of query tokens across all batches
// one q_block spans BLOCK_Q = BLOCK_M // num_queries_per_kv number of query token groups. One query token group shares one kv token
// one q_block spans BLOCK_Q = BLOCK_M // num_queries_per_kv number of query token groups.
// One query token group shares one kv token
const index_t seq_idx = find_seq_idx(
kargs.query_start_len_ptr, q_block_global_idx, kargs.num_seqs, BLOCK_Q, true
); // which batch
const index_t seq_idx = find_seq_idx(kargs.query_start_len_ptr,
q_block_global_idx,
kargs.num_seqs,
BLOCK_Q,
true); // which batch
const index_t q_block_start_idx = amd_wave_read_first_lane(kargs.query_start_len_ptr[seq_idx]);
const index_t q_block_start_idx =
amd_wave_read_first_lane(kargs.query_start_len_ptr[seq_idx]);
const index_t q_block_local_idx = amd_wave_read_first_lane(q_block_global_idx - q_block_start_idx);
const index_t q_block_local_idx =
amd_wave_read_first_lane(q_block_global_idx - q_block_start_idx);
const index_t cur_batch_in_all_start_index = amd_wave_read_first_lane(kargs.query_start_len_ptr[seq_idx]);
const index_t cur_batch_in_all_stop_index = amd_wave_read_first_lane(kargs.query_start_len_ptr[seq_idx + 1]);
const index_t cur_batch_in_all_start_index =
amd_wave_read_first_lane(kargs.query_start_len_ptr[seq_idx]);
const index_t cur_batch_in_all_stop_index =
amd_wave_read_first_lane(kargs.query_start_len_ptr[seq_idx + 1]);
const index_t cur_batch_query_len = cur_batch_in_all_stop_index - cur_batch_in_all_start_index;
const index_t cur_batch_query_len =
cur_batch_in_all_stop_index - cur_batch_in_all_start_index;
// TODO check if we get the block size info from pipeline
if (q_block_local_idx * BLOCK_Q >= cur_batch_query_len) {
if(q_block_local_idx * BLOCK_Q >= cur_batch_query_len)
{
return;
}
const index_t query_pos = q_block_local_idx * BLOCK_Q;
const index_t seq_len = kargs.seq_lens_ptr[seq_idx];
const index_t seq_len = kargs.seq_lens_ptr[seq_idx];
const index_t context_len = seq_len - cur_batch_query_len;
index_t _max_seq_prefix_len = (
context_len
+ q_block_local_idx * BLOCK_Q
+ (BLOCK_M - 1) // num_queries_per_kv
+ 1
);
index_t _max_seq_prefix_len =
(context_len + q_block_local_idx * BLOCK_Q + (BLOCK_M - 1) // num_queries_per_kv
+ 1);
if (seq_len < _max_seq_prefix_len) {
if(seq_len < _max_seq_prefix_len)
{
_max_seq_prefix_len = seq_len;
}
const auto max_seq_prefix_len = _max_seq_prefix_len;
const index_t num_blocks = (max_seq_prefix_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
const index_t num_blocks = (max_seq_prefix_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
// TODO sliding window
const index_t num_blocks_start = 0;
index_t kv_head_offset = kv_head_idx * kargs.stride_k_cache_2;
index_t kv_head_offset = kv_head_idx * kargs.stride_k_cache_2;
// Q/K/V DRAM and DRAM window
index_t q_ptr_offset_0 = cur_batch_in_all_start_index * kargs.query_stride_0; // move the pointer to the batch start
index_t q_ptr_offset_1 = kv_head_idx * num_queries_per_kv * kargs.query_stride_1; // move the pointer to the correct head group start
index_t q_ptr_offset_0 = cur_batch_in_all_start_index *
kargs.query_stride_0; // move the pointer to the batch start
index_t q_ptr_offset_1 =
kv_head_idx * num_queries_per_kv *
kargs.query_stride_1; // move the pointer to the correct head group start
index_t q_ptr_offset = q_ptr_offset_0 + q_ptr_offset_1;
index_t o_ptr_offset_0 = cur_batch_in_all_start_index * kargs.output_stride_0; // move the pointer to the batch start
index_t o_ptr_offset_1 = kv_head_idx * num_queries_per_kv * kargs.output_stride_1; // move the pointer to the correct head group start
index_t o_ptr_offset = o_ptr_offset_0 + o_ptr_offset_1;
index_t o_ptr_offset_0 = cur_batch_in_all_start_index *
kargs.output_stride_0; // move the pointer to the batch start
index_t o_ptr_offset_1 =
kv_head_idx * num_queries_per_kv *
kargs.output_stride_1; // move the pointer to the correct head group start
index_t o_ptr_offset = o_ptr_offset_0 + o_ptr_offset_1;
index_t block_table_offset = seq_idx * kargs.block_table_stride;
const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) + q_ptr_offset;
const KDataType* k_ptr = reinterpret_cast<const KDataType*>(kargs.k_ptr) + kv_head_offset;
const VDataType* v_ptr = reinterpret_cast<const VDataType*>(kargs.v_ptr) + kv_head_offset;
ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr) + o_ptr_offset;
ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr) + o_ptr_offset;
index_t query_len_padded = integer_divide_ceil(cur_batch_query_len, BLOCK_Q) * BLOCK_Q;
// const bool is_query_len_padded = (cur_batch_query_len % BLOCK_Q == 0);
@@ -368,37 +374,35 @@ struct UnifiedAttentionKernel
number<UnifiedAttentionPipeline::kAlignmentQ>{},
number<2>{});
const auto q_dram_pad = pad_tensor_view( // aling seqlen with BLOCK_Q and head dim with HEAD_SIZE_PADDED
q_dram_base,
// block sizes
make_tuple(number<BLOCK_Q>{}, number<1>{}, number<HEAD_SIZE_PADDED>{}),
sequence<true, false, kPadHeadDimQ>{}
); // pads to (seq_len_padded, num_head_q, HEAD_SIZE_PADDED)
const auto q_dram_pad =
pad_tensor_view( // aling seqlen with BLOCK_Q and head dim with HEAD_SIZE_PADDED
q_dram_base,
// block sizes
make_tuple(number<BLOCK_Q>{}, number<1>{}, number<HEAD_SIZE_PADDED>{}),
sequence<true, false, kPadHeadDimQ>{}); // pads to (seq_len_padded, num_head_q,
// HEAD_SIZE_PADDED)
const auto q_dram_merged = transform_tensor_view(
q_dram_pad,
make_tuple(
make_merge_transform(
make_tuple(query_len_padded, num_queries_per_kv)
),
make_pass_through_transform(number<HEAD_SIZE_PADDED>{})
),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1>{})
); // flattens the first two dims, head idx is the fastest changing dim in the merged dim
q_dram_pad,
make_tuple(make_merge_transform(make_tuple(query_len_padded, num_queries_per_kv)),
make_pass_through_transform(number<HEAD_SIZE_PADDED>{})),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0>{},
sequence<1>{})); // flattens the first two dims, head idx is the fastest
// changing dim in the merged dim
return q_dram_merged;
}();
// static_assert(q_dram.desc_[number<0>{}] == 0, "q_dram.get_bottom_tensor_view()[number<0>{}] == 0");
// static_assert(q_dram.desc_[number<0>{}] == 0,
// "q_dram.get_bottom_tensor_view()[number<0>{}] == 0");
// Q has the shape (k_head, seq_len, num_queries_per_kv, head_dim)
// stride for dim 0 (num_queries_per_kv * head_dim, head_dim, 1)
auto q_dram_window = make_tile_window(
q_dram,
make_tuple(number<BLOCK_M>{}, number<HEAD_SIZE_PADDED>{}),
{query_pos * num_queries_per_kv, 0}
);
auto q_dram_window =
make_tile_window(q_dram,
make_tuple(number<BLOCK_M>{}, number<HEAD_SIZE_PADDED>{}),
{query_pos * num_queries_per_kv, 0});
const auto k_dram = [&]() {
// HEAD dim is skipped as defined in the ptrs
const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
@@ -408,24 +412,19 @@ struct UnifiedAttentionKernel
number<UnifiedAttentionPipeline::kAlignmentK>{},
number<1>{});
const auto k_dram_pad = pad_tensor_view(
k_dram_naive,
// TODO can the BLOCK_SIZE_RAW needs padding?
make_tuple(1, BLOCK_SIZE, HEAD_SIZE_PADDED),
sequence<false, false, kPadHeadDimQ>{});
const auto k_dram_pad = pad_tensor_view(k_dram_naive,
// TODO can the BLOCK_SIZE_RAW needs padding?
make_tuple(1, BLOCK_SIZE, HEAD_SIZE_PADDED),
sequence<false, false, kPadHeadDimQ>{});
const auto k_dram_merged = transform_tensor_view(
k_dram_pad,
make_tuple(
make_merge_transform(
make_tuple(kargs.num_blks, BLOCK_SIZE)
),
make_pass_through_transform(HEAD_SIZE_PADDED)
),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1>{})
); // flattens the first two dims, head idx is the fastest changing dim in the merged dim
k_dram_pad,
make_tuple(make_merge_transform(make_tuple(kargs.num_blks, BLOCK_SIZE)),
make_pass_through_transform(HEAD_SIZE_PADDED)),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0>{},
sequence<1>{})); // flattens the first two dims, head idx is the fastest
// changing dim in the merged dim
return k_dram_merged;
}();
@@ -441,53 +440,50 @@ struct UnifiedAttentionKernel
number<UnifiedAttentionPipeline::kAlignmentV>{},
number<1>{});
const auto v_dram_pad = pad_tensor_view(
v_dram_naive,
make_tuple(1, BLOCK_SIZE, HEAD_SIZE_PADDED),
sequence<false, false, kPadHeadDimQ>{});
const auto v_dram_pad = pad_tensor_view(v_dram_naive,
make_tuple(1, BLOCK_SIZE, HEAD_SIZE_PADDED),
sequence<false, false, kPadHeadDimQ>{});
const auto v_dram_merged = transform_tensor_view(
v_dram_pad,
make_tuple(
make_merge_transform(
make_tuple(kargs.num_blks, BLOCK_SIZE)
),
make_pass_through_transform(HEAD_SIZE_PADDED)
),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1>{})
); // flattens the first two dims, head idx is the fastest changing dim in the merged dim
v_dram_pad,
make_tuple(make_merge_transform(make_tuple(kargs.num_blks, BLOCK_SIZE)),
make_pass_through_transform(HEAD_SIZE_PADDED)),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0>{},
sequence<1>{})); // flattens the first two dims, head idx is the fastest
// changing dim in the merged dim
return v_dram_merged;
}();
auto v_dram_window = make_tile_window(
v_dram, make_tuple(number<BLOCK_SIZE>{}, number<HEAD_SIZE_PADDED>{}), {0, 0});
FmhaMask mask = [&]() {
if constexpr(kHasMask)
return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
seq_len - cur_batch_query_len, // y (i.e. context)
cur_batch_query_len, // x (i.e. extend)
seq_len, // y_total (x + y)
cur_batch_query_len, // x_total
num_queries_per_kv // the same sequence index is repeated num_queries_per_kv times along x dim of the tile
cur_batch_query_len, // x (i.e. extend)
seq_len, // y_total (x + y)
cur_batch_query_len, // x_total
num_queries_per_kv // the same sequence index is repeated num_queries_per_kv
// times along x dim of the tile
);
else
return FmhaMask{cur_batch_query_len, seq_len};
}();
auto o_acc_tile = [&]() {
return UnifiedAttentionPipeline{}(q_dram_window,
k_dram_window,
v_dram_window,
num_blocks,
num_blocks_start,
kargs.block_tables_ptr,
block_table_offset,
mask,
kargs.scale_s,
smem_ptr);
k_dram_window,
v_dram_window,
num_blocks,
num_blocks_start,
kargs.block_tables_ptr,
block_table_offset,
mask,
kargs.scale_s,
smem_ptr);
}();
// O DRAM and O DRAM window
@@ -499,24 +495,20 @@ struct UnifiedAttentionKernel
number<UnifiedAttentionPipeline::kAlignmentO>{},
number<1>{});
const auto o_dram_pad = pad_tensor_view( // aling cu_seqlen with BLOCK_Q and head dim with HEAD_SIZE_PADDED
o_dram_base,
// block sizes
make_tuple(BLOCK_Q, 1, HEAD_SIZE_PADDED),
sequence<true, false, kPadHeadDimQ>{}
); // pads to (seq_len_padded, num_head_q, HEAD_SIZE_PADDED)
const auto o_dram_pad =
pad_tensor_view( // aling cu_seqlen with BLOCK_Q and head dim with HEAD_SIZE_PADDED
o_dram_base,
// block sizes
make_tuple(BLOCK_Q, 1, HEAD_SIZE_PADDED),
sequence<true, false, kPadHeadDimQ>{}); // pads to (seq_len_padded, num_head_q,
// HEAD_SIZE_PADDED)
const auto o_dram_merged = transform_tensor_view(
o_dram_pad,
make_tuple(
make_merge_transform(
make_tuple(query_len_padded, num_queries_per_kv)
),
make_pass_through_transform(HEAD_SIZE_PADDED)
),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1>{})
);
o_dram_pad,
make_tuple(make_merge_transform(make_tuple(query_len_padded, num_queries_per_kv)),
make_pass_through_transform(HEAD_SIZE_PADDED)),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return o_dram_merged;
}();

View File

@@ -47,11 +47,14 @@ struct TileUnifiedAttentionShape
static constexpr index_t NumWarps = max(NumGemm0Warps, NumGemm1Warps);
static constexpr index_t BLOCK_M = BlockTile::at(number<0>{}); // tile size along the flattened batch dimension (: num_queries_per_kv * BS)
static constexpr index_t BLOCK_Q = BlockTile::at(number<1>{}); // tile size along the flattened batch dimension (: num_queries_per_kv * BS)
// static constexpr index_t BLOCK_M = BlockTile::at(number<1>{}); // tile size along q seqlen * num_queries_per_kv (q_head//kv_head)
static constexpr index_t BLOCK_M = BlockTile::at(
number<0>{}); // tile size along the flattened batch dimension (: num_queries_per_kv * BS)
static constexpr index_t BLOCK_Q = BlockTile::at(
number<1>{}); // tile size along the flattened batch dimension (: num_queries_per_kv * BS)
// static constexpr index_t BLOCK_M = BlockTile::at(number<1>{}); // tile size along q seqlen *
// num_queries_per_kv (q_head//kv_head)
static constexpr index_t BLOCK_SIZE = BlockTile::at(number<2>{}); // BLOCK size for K seqlen
static constexpr index_t HEAD_SIZE = BlockTile::at(number<3>{}); // BLOCK size for K seqlen
static constexpr index_t HEAD_SIZE = BlockTile::at(number<3>{}); // BLOCK size for K seqlen
// static constexpr index_t kQKHeaddim =
// BlockTile::at(number<5>{}); // total length of K0, used for pipeline that need load Q at

View File

@@ -9,14 +9,13 @@
namespace ck_tile {
template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
bool kPadHeadDim_ /* paddding for hdim_v */,
index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
struct TileUnifiedAttentionTraits
{
static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
static constexpr bool kPadHeadDim = kPadHeadDim_;
static constexpr bool kPadHeadDim = kPadHeadDim_;
static constexpr index_t kBlockPerCu = kBlockPerCu_;
};
}
} // namespace ck_tile

View File

@@ -268,14 +268,15 @@ struct UnifiedAttentionPipeline
static constexpr ck_tile::index_t kBlockSize = Problem::kBlockSize;
static constexpr ck_tile::index_t BLOCK_M = UnifiedAttentionShape::BLOCK_M;
static constexpr ck_tile::index_t BLOCK_Q = UnifiedAttentionShape::BLOCK_Q;
static constexpr ck_tile::index_t BLOCK_SIZE = UnifiedAttentionShape::BLOCK_SIZE;
static constexpr ck_tile::index_t HEAD_SIZE = UnifiedAttentionShape::HEAD_SIZE;
static constexpr ck_tile::index_t HEAD_SIZE_PADDED = UnifiedAttentionShape::HEAD_SIZE_PADDED;
static constexpr ck_tile::index_t BLOCK_M = UnifiedAttentionShape::BLOCK_M;
static constexpr ck_tile::index_t BLOCK_Q = UnifiedAttentionShape::BLOCK_Q;
static_assert(HEAD_SIZE_PADDED <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
static constexpr ck_tile::index_t BLOCK_SIZE = UnifiedAttentionShape::BLOCK_SIZE;
static constexpr ck_tile::index_t HEAD_SIZE = UnifiedAttentionShape::HEAD_SIZE;
static constexpr ck_tile::index_t HEAD_SIZE_PADDED = UnifiedAttentionShape::HEAD_SIZE_PADDED;
static_assert(HEAD_SIZE_PADDED <= 256,
"hdim bigger than 256 is not suitable for this pipeline!");
// static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDim;
@@ -402,12 +403,13 @@ struct UnifiedAttentionPipeline
std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>>,
"wrong!");
static_assert(BLOCK_M == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
BLOCK_SIZE == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
HEAD_SIZE_PADDED == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
HEAD_SIZE_PADDED == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
BLOCK_SIZE == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"wrong!");
static_assert(
BLOCK_M == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
BLOCK_SIZE == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
HEAD_SIZE_PADDED == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
HEAD_SIZE_PADDED == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
BLOCK_SIZE == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"wrong!");
static_assert(sizeof(SaccDataType) * BLOCK_SIZE <= GetSmemSize());
auto s_lds = make_tensor_view<address_space_enum::lds>(
@@ -542,9 +544,11 @@ struct UnifiedAttentionPipeline
const auto q_origin = q_dram_window.get_window_origin();
// const auto [seqlen_k_start, seqlen_k_end] =
// mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<BLOCK_M>{}, number<BLOCK_SIZE>{});
// mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<BLOCK_M>{},
// number<BLOCK_SIZE>{});
// const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, BLOCK_SIZE);
// const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start,
// BLOCK_SIZE);
const auto num_total_loop = num_blocks;
// index_t kv_token_start = seqlen_k_start;
@@ -554,7 +558,6 @@ struct UnifiedAttentionPipeline
{
if(num_total_loop - num_blocks_start <= 0)
{
// Note: here occ are all cleard, return it
// Note: q loaded but no fence, ignore it.
@@ -562,11 +565,11 @@ struct UnifiedAttentionPipeline
}
}
index_t i_total_loops = num_blocks_start;
const ck_tile::index_t* block_tables_ptr_ = reinterpret_cast<const ck_tile::index_t*>(block_tables_ptr);
index_t kv_blk_idx = block_tables_ptr_[block_table_offset + i_total_loops];
index_t kv_blk_idx_prev = 0;
index_t i_total_loops = num_blocks_start;
const ck_tile::index_t* block_tables_ptr_ =
reinterpret_cast<const ck_tile::index_t*>(block_tables_ptr);
index_t kv_blk_idx = block_tables_ptr_[block_table_offset + i_total_loops];
index_t kv_blk_idx_prev = 0;
auto k_dram_window =
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
@@ -672,9 +675,10 @@ struct UnifiedAttentionPipeline
auto K_mem_load = [&](auto k_lds_write_idx) {
async_load_tile_raw(k_lds_window_store(k_lds_write_idx), k_dram_window);
// TODO maybe needs i_total_loops as argument. Or maybe needs to use the k_lds_write_idx as the index
// TODO maybe needs i_total_loops as argument. Or maybe needs to use the k_lds_write_idx
// as the index
/// FIXME: use the future-predicting method to move the window
k_dram_window.set_window_origin({kv_blk_idx * BLOCK_SIZE, 0});
k_dram_window.set_window_origin({kv_blk_idx * BLOCK_SIZE, 0});
};
auto K_lds_load = [&](auto k_lds_read_idx) {
@@ -683,9 +687,9 @@ struct UnifiedAttentionPipeline
auto V_mem_load = [&](auto v_lds_write_idx) {
async_load_tile_raw(v_lds_window_store(v_lds_write_idx), v_dram_window);
// kv_blk_idx = block_tables_ptr_[block_table_offset + i_total_loops];
/// FIXME: use the future-predicting method to move the window
v_dram_window.set_window_origin({kv_blk_idx * BLOCK_SIZE, 0});
// kv_blk_idx = block_tables_ptr_[block_table_offset + i_total_loops];
/// FIXME: use the future-predicting method to move the window
v_dram_window.set_window_origin({kv_blk_idx * BLOCK_SIZE, 0});
};
auto V_lds_load = [&](auto v_lds_read_idx) {
@@ -894,8 +898,10 @@ struct UnifiedAttentionPipeline
auto fmha_mask = [&](auto sp_reg_idx) {
if constexpr(FmhaMask::IsMasking)
{
bool need_perpixel_check = mask.IsEdgeTile(
q_origin.at(number<0>{}), i_total_loops * BLOCK_SIZE, number<BLOCK_M>{}, number<BLOCK_SIZE>{});
bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
i_total_loops * BLOCK_SIZE,
number<BLOCK_M>{},
number<BLOCK_SIZE>{});
if(need_perpixel_check)
{
set_tile_if(sp(sp_reg_idx).sp_compute,
@@ -903,7 +909,8 @@ struct UnifiedAttentionPipeline
[&](auto tile_idx) {
const auto row =
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = i_total_loops * BLOCK_SIZE + tile_idx.at(number<1>{});
const auto col =
i_total_loops * BLOCK_SIZE + tile_idx.at(number<1>{});
return mask.IsOutOfBound(row, col);
});
}
@@ -1180,7 +1187,6 @@ struct UnifiedAttentionPipeline
fmha_post_process(number<0>{});
}
// finally, O
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();

View File

@@ -141,11 +141,11 @@ struct UnifiedAttentionPipelineDefaultPolicy
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
constexpr index_t N0 = NumIssues; // 8
constexpr index_t N0 = NumIssues; // 8
constexpr index_t N1 = LaneGroups; // 2
constexpr index_t N2 = NumWarps; // 8
constexpr index_t K0 = LanesPerK; // 32
constexpr index_t K1 = KVector; // 4
constexpr index_t N2 = NumWarps; // 8
constexpr index_t K0 = LanesPerK; // 32
constexpr index_t K1 = KVector; // 4
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
@@ -259,13 +259,13 @@ struct UnifiedAttentionPipelineDefaultPolicy
}
}();
using BlockGemmPolicy =
BlockGemmARegBRegCRegV2CustomPolicy<typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::SaccDataType,
typename Problem::UnifiedAttentionShape::Gemm0BlockWarps,
decltype(warp_gemm),
GemmLoopOrder::MNK>;
using BlockGemmPolicy = BlockGemmARegBRegCRegV2CustomPolicy<
typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::SaccDataType,
typename Problem::UnifiedAttentionShape::Gemm0BlockWarps,
decltype(warp_gemm),
GemmLoopOrder::MNK>;
return BlockGemmARegBRegCRegV2<GemmProblem, BlockGemmPolicy>{};
}
@@ -287,24 +287,25 @@ struct UnifiedAttentionPipelineDefaultPolicy
typename Problem::UnifiedAttentionShape::Gemm1WarpTile>>;
/// NOTICE: in order to use load_tile_transpose() later for V tiles, we have to pass
/// WGAttrNumAccessEnum::Double instead of WGAttrNumAccessEnum::Single
using WarpGemm = WarpGemmDispatcher<typename Problem::PDataType,
typename Problem::VDataType,
typename Problem::OaccDataType,
Problem::UnifiedAttentionShape::Gemm1WarpTile::at(number<0>{}),
Problem::UnifiedAttentionShape::Gemm1WarpTile::at(number<1>{}),
Problem::UnifiedAttentionShape::Gemm1WarpTile::at(number<2>{}),
true,
false,
false,
WGAttrNumAccessEnum::Double>;
using WarpGemm =
WarpGemmDispatcher<typename Problem::PDataType,
typename Problem::VDataType,
typename Problem::OaccDataType,
Problem::UnifiedAttentionShape::Gemm1WarpTile::at(number<0>{}),
Problem::UnifiedAttentionShape::Gemm1WarpTile::at(number<1>{}),
Problem::UnifiedAttentionShape::Gemm1WarpTile::at(number<2>{}),
true,
false,
false,
WGAttrNumAccessEnum::Double>;
using BlockGemmPolicy =
BlockGemmARegBRegCRegV2CustomPolicy<typename Problem::PDataType,
typename Problem::VDataType,
typename Problem::OaccDataType,
typename Problem::UnifiedAttentionShape::Gemm1BlockWarps,
WarpGemm,
GemmLoopOrder::MNK>;
using BlockGemmPolicy = BlockGemmARegBRegCRegV2CustomPolicy<
typename Problem::PDataType,
typename Problem::VDataType,
typename Problem::OaccDataType,
typename Problem::UnifiedAttentionShape::Gemm1BlockWarps,
WarpGemm,
GemmLoopOrder::MNK>;
return BlockGemmARegBRegCRegV2<GemmProblem, BlockGemmPolicy>{};
}

View File

@@ -19,23 +19,23 @@ template <typename QDataType_,
typename OaccDataType_,
typename ODataType_,
typename UnifiedAttentionShape_,
typename FmhaMask_,
typename FmhaMask_,
typename Traits_>
struct UnifiedAttentionPipelineProblem
{
// TODO kM0 and KN1??
using QDataType = remove_cvref_t<QDataType_>;
using KDataType = remove_cvref_t<KDataType_>;
using VDataType = remove_cvref_t<VDataType_>;
using QDataType = remove_cvref_t<QDataType_>;
using KDataType = remove_cvref_t<KDataType_>;
using VDataType = remove_cvref_t<VDataType_>;
// first gemm accumulation dtype
using SaccDataType = remove_cvref_t<SaccDataType_>;
using SaccDataType = remove_cvref_t<SaccDataType_>;
// Softmax dtype
using SMPLComputeDataType = remove_cvref_t<SMPLComputeDataType_>;
using BiasDataType = remove_cvref_t<BiasDataType_>;
using RandValOutputDataType = remove_cvref_t<RandValOutputDataType_>;
// data type for A matrix of second gemm
using PDataType = remove_cvref_t<PDataType_>;
// data type for second gemm accumulation
using PDataType = remove_cvref_t<PDataType_>;
// data type for second gemm accumulation
using OaccDataType = remove_cvref_t<OaccDataType_>;
using ODataType = remove_cvref_t<ODataType_>;
using UnifiedAttentionShape = remove_cvref_t<UnifiedAttentionShape_>;
@@ -48,11 +48,11 @@ struct UnifiedAttentionPipelineProblem
// attributes from traits
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
static constexpr bool kPadHeadDim = Traits::kPadHeadDim;
static constexpr bool kPadHeadDim = Traits::kPadHeadDim;
static constexpr bool kHasLogitsSoftCap = Traits::kHasLogitsSoftCap;
static constexpr bool kSkipMinSeqlenQ = Traits::kSkipMinSeqlenQ;
static constexpr bool kHasDropout = Traits::kHasDropout;
static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
};
}
} // namespace ck_tile