From 618ed6defb2d7959ae3dcf3db43be74a241141e6 Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Tue, 11 Nov 2025 14:35:26 +0000 Subject: [PATCH] cmake list update --- .../01_unified_attention/CMakeLists.txt | 34 +- .../example_unified_attention.cpp | 227 +++++----- .../unified_attention.cpp | 12 +- .../unified_attention.hpp | 10 +- .../unified_attention_impl.hpp | 110 ++--- include/ck_tile/ops/unified_attention.hpp | 16 +- .../unified_attention/block/block_masking.hpp | 14 +- .../kernel/unified_attention_kernel.hpp | 402 +++++++++--------- .../pipeline/tile_unified_attention_shape.hpp | 11 +- .../tile_unified_attention_traits.hpp | 5 +- .../pipeline/unified_attention_pipeline.hpp | 66 +-- ...fied_attention_pipeline_default_policy.hpp | 57 +-- .../unified_attention_pipeline_problem.hpp | 18 +- 13 files changed, 495 insertions(+), 487 deletions(-) diff --git a/example/ck_tile/01_unified_attention/CMakeLists.txt b/example/ck_tile/01_unified_attention/CMakeLists.txt index 11c413d192..45f67f3e0d 100644 --- a/example/ck_tile/01_unified_attention/CMakeLists.txt +++ b/example/ck_tile/01_unified_attention/CMakeLists.txt @@ -187,36 +187,42 @@ if(NOT INST_TARGETS) return() endif() -set(EXAMPLE_FMHA_FWD_V3 "tile_example_unified_attention") -message(DEBUG "adding example ${EXAMPLE_FMHA_FWD_V3}") +set(EXAMPLE_UNIFIED_ATTENTION "tile_example_unified_attention") +message(DEBUG "adding example ${EXAMPLE_UNIFIED_ATTENTION}") -add_executable(${EXAMPLE_FMHA_FWD_V3} EXCLUDE_FROM_ALL example_unified_attention.cpp) -target_include_directories(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) -file(GLOB FMHA_FWD_V3_INSTANCES CONFIGURE_DEPENDS +add_executable(${EXAMPLE_UNIFIED_ATTENTION} EXCLUDE_FROM_ALL example_unified_attention.cpp) +target_include_directories(${EXAMPLE_UNIFIED_ATTENTION} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +file(GLOB UNIFIED_ATTENTION_INSTANCES CONFIGURE_DEPENDS "${CMAKE_CURRENT_LIST_DIR}/instances/*.cpp" ) -target_sources(${EXAMPLE_FMHA_FWD_V3} PRIVATE +target_sources(${EXAMPLE_UNIFIED_ATTENTION} PRIVATE unified_attention.cpp - ${FMHA_FWD_V3_INSTANCES} + ${UNIFIED_ATTENTION_INSTANCES} ) -set(EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS) -list(APPEND EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS +set(EXAMPLE_UNIFIED_ATTENTION_COMPILE_OPTIONS) +list(APPEND EXAMPLE_UNIFIED_ATTENTION_COMPILE_OPTIONS -fgpu-flush-denormals-to-zero -Wno-undefined-func-template --save-temps ) -set(EXAMPLE_FMHA_FWD_V3_COMPILE_DEFINITIONS) +set(EXAMPLE_UNIFIED_ATTENTION_COMPILE_DEFINITIONS) check_cxx_compiler_flag("-mllvm --amdgpu-disable-packed-fp32=1" HAS_DISABLE_PACKED_FP32) if(HAS_DISABLE_PACKED_FP32) - list(APPEND EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS + list(APPEND EXAMPLE_UNIFIED_ATTENTION_COMPILE_OPTIONS -mllvm --amdgpu-disable-packed-fp32=1 ) - list(APPEND EXAMPLE_FMHA_FWD_V3_COMPILE_DEFINITIONS + list(APPEND EXAMPLE_UNIFIED_ATTENTION_COMPILE_DEFINITIONS -DCK_TILE_DISABLE_PACKED_FP32=1 ) endif() -target_compile_options(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS}) -target_compile_definitions(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${EXAMPLE_FMHA_FWD_V3_COMPILE_DEFINITIONS}) +target_compile_options(${EXAMPLE_UNIFIED_ATTENTION} PRIVATE ${EXAMPLE_UNIFIED_ATTENTION_COMPILE_OPTIONS}) +target_compile_definitions(${EXAMPLE_UNIFIED_ATTENTION} PRIVATE ${EXAMPLE_UNIFIED_ATTENTION_COMPILE_DEFINITIONS}) + +# TODO: we have to turn off this global prop, otherwise the progress bar generated +# by cmake will print too many files, execvp: /bin/sh: Argument list too long +# however, this property may affect global +# TODO: consider codegen a makefile by us +set_property(GLOBAL PROPERTY RULE_MESSAGES OFF) \ No newline at end of file diff --git a/example/ck_tile/01_unified_attention/example_unified_attention.cpp b/example/ck_tile/01_unified_attention/example_unified_attention.cpp index 50ac6ea94c..b103043b70 100644 --- a/example/ck_tile/01_unified_attention/example_unified_attention.cpp +++ b/example/ck_tile/01_unified_attention/example_unified_attention.cpp @@ -93,20 +93,20 @@ struct Problem { explicit Problem(const ck_tile::ArgParser& args) { - data_type = args.get_str("prec") == "fp16" - ? ck_tile::unified_attention_args::data_type_enum::fp16 - : ck_tile::unified_attention_args::data_type_enum::bf16; - batch = args.get_int("b"); - max_seqlen_q = args.get_int("s"); + data_type = args.get_str("prec") == "fp16" + ? ck_tile::unified_attention_args::data_type_enum::fp16 + : ck_tile::unified_attention_args::data_type_enum::bf16; + batch = args.get_int("b"); + max_seqlen_q = args.get_int("s"); max_context_len = args.get_int("s_k"); - num_blks = args.get_int("nb"); - BLOCK_SIZE = args.get_int("bs"); - nhead_q = args.get_int("h"); - nhead_kv = args.get_int("h_k"); + num_blks = args.get_int("nb"); + BLOCK_SIZE = args.get_int("bs"); + nhead_q = args.get_int("h"); + nhead_kv = args.get_int("h_k"); - hdim = args.get_int("d"); + hdim = args.get_int("d"); query_lens = args.get_int_vec("query_lens"); - kv_lens = args.get_int_vec("kv_lens"); + kv_lens = args.get_int_vec("kv_lens"); // Calculate scale_s scale_s = args.get_float("scale_s"); @@ -114,14 +114,15 @@ struct Problem scale_s = 1.0f / ck_tile::sqrt(static_cast(hdim)); // Initialize other scales - scale = args.get_float("scale"); + scale = args.get_float("scale"); scale_k = args.get_float("scale_k"); scale_v = args.get_float("scale_v"); // Calculate sums of query_lens and kv_lens if provided // int64_t kv_lens_sum = 0; - for (const auto& len : query_lens) { + for(const auto& len : query_lens) + { num_tokens += len; } @@ -130,10 +131,7 @@ struct Problem // } } - std::vector get_query_shape() const - { - return {num_tokens, nhead_q, hdim}; - } + std::vector get_query_shape() const { return {num_tokens, nhead_q, hdim}; } std::vector get_key_shape() const { @@ -145,11 +143,7 @@ struct Problem return {num_blks, BLOCK_SIZE, nhead_kv, hdim}; } - std::vector get_output_shape() const - { - return {num_tokens, nhead_q, hdim}; - - } + std::vector get_output_shape() const { return {num_tokens, nhead_q, hdim}; } ck_tile::unified_attention_args::data_type_enum data_type; ck_tile::index_t batch; @@ -209,7 +203,6 @@ auto generate_qkv(const Problem& problem, return std::make_tuple(q, k, v); } - // namespace host { // template q_host_ref({nhead_q, seqlen_q, hdim_qk}); - // ck_tile::HostTensor k_host_ref({nhead_q, seqlen_kv, hdim_qk}); - // ck_tile::HostTensor v_host_ref({nhead_q, hdim_v, seqlen_kv}); - // ck_tile::HostTensor o_host_ref({nhead_q, seqlen_q, hdim_v}); +// ck_tile::HostTensor q_host_ref({nhead_q, seqlen_q, hdim_qk}); +// ck_tile::HostTensor k_host_ref({nhead_q, seqlen_kv, hdim_qk}); +// ck_tile::HostTensor v_host_ref({nhead_q, hdim_v, seqlen_kv}); +// ck_tile::HostTensor o_host_ref({nhead_q, seqlen_q, hdim_v}); - // ck_tile::HostTensor s_host_ref({nhead_q, seqlen_q, seqlen_kv}); - // ck_tile::HostTensor p_host_ref({nhead_q, seqlen_q, seqlen_kv}); +// ck_tile::HostTensor s_host_ref({nhead_q, seqlen_q, seqlen_kv}); +// ck_tile::HostTensor p_host_ref({nhead_q, seqlen_q, seqlen_kv}); - // // do computation for each batch - // for(int b = 0; b < batch_size; ++b) - // { - // // copy per-batch data from input tensors - // // clang-format off - // q_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = q_bshd(b, idx[1], idx[0] , idx[2]); }); - // k_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = k_bshd(b, idx[1], idx[0] / nr, idx[2]); }); - // v_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = v_bshd(b, idx[2], idx[0] / nr, idx[1]); }); - // // clang-format on - // ck_tile::reference_batched_gemm( - // q_host_ref, k_host_ref, s_host_ref, q_element_op, k_element_op, s_acc_element_op); +// // do computation for each batch +// for(int b = 0; b < batch_size; ++b) +// { +// // copy per-batch data from input tensors +// // clang-format off +// q_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = q_bshd(b, idx[1], idx[0] , +// idx[2]); }); k_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = k_bshd(b, idx[1], +// idx[0] / nr, idx[2]); }); v_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = +// v_bshd(b, idx[2], idx[0] / nr, idx[1]); }); +// // clang-format on +// ck_tile::reference_batched_gemm( +// q_host_ref, k_host_ref, s_host_ref, q_element_op, k_element_op, s_acc_element_op); - // if(mask.type == mask_enum::no_mask) - // { - // ck_tile::reference_batched_masking(s_host_ref, FmhaMasks::NoMask{seqlen_q, seqlen_kv}); - // } - // else if(mask.type == mask_enum::window_generic) - // { - // ck_tile::reference_batched_masking( - // s_host_ref, - // ck_tile::make_generic_attention_mask_from_lr_window( - // mask.left, mask.right, seqlen_q, seqlen_kv)); - // } - // else - // { - // // if left window size is negative, means causal - // // else means generic (for current batch) - // if(mask.left < 0) - // ck_tile::reference_batched_masking( - // s_host_ref, - // ck_tile::make_generic_attention_mask_from_lr_window( - // mask.left, - // mask.right, - // seqlen_q, - // seqlen_kv, - // mask.type == mask_enum::mask_top_left)); - // else - // ck_tile::reference_batched_masking( - // s_host_ref, - // ck_tile::make_generic_attention_mask_from_lr_window( - // mask.left, - // mask.right, - // seqlen_q, - // seqlen_kv, - // mask.type == mask_enum::mask_top_left)); - // } +// if(mask.type == mask_enum::no_mask) +// { +// ck_tile::reference_batched_masking(s_host_ref, FmhaMasks::NoMask{seqlen_q, seqlen_kv}); +// } +// else if(mask.type == mask_enum::window_generic) +// { +// ck_tile::reference_batched_masking( +// s_host_ref, +// ck_tile::make_generic_attention_mask_from_lr_window( +// mask.left, mask.right, seqlen_q, seqlen_kv)); +// } +// else +// { +// // if left window size is negative, means causal +// // else means generic (for current batch) +// if(mask.left < 0) +// ck_tile::reference_batched_masking( +// s_host_ref, +// ck_tile::make_generic_attention_mask_from_lr_window( +// mask.left, +// mask.right, +// seqlen_q, +// seqlen_kv, +// mask.type == mask_enum::mask_top_left)); +// else +// ck_tile::reference_batched_masking( +// s_host_ref, +// ck_tile::make_generic_attention_mask_from_lr_window( +// mask.left, +// mask.right, +// seqlen_q, +// seqlen_kv, +// mask.type == mask_enum::mask_top_left)); +// } - // ck_tile::reference_batched_softmax( - // s_host_ref, p_host_ref, ck_tile::identity{}); +// ck_tile::reference_batched_softmax( +// s_host_ref, p_host_ref, ck_tile::identity{}); - // ck_tile::reference_batched_gemm( - // p_host_ref, v_host_ref, o_host_ref, ck_tile::identity{}, v_element_op); +// ck_tile::reference_batched_gemm( +// p_host_ref, v_host_ref, o_host_ref, ck_tile::identity{}, v_element_op); - // // copy resulting per-batch data to the output tensor - // o_host_ref.ForEach( - // [&](auto& self, auto idx) { o_bshd(b, idx[1], idx[0], idx[2]) = self(idx); }); - // } +// // copy resulting per-batch data to the output tensor +// o_host_ref.ForEach( +// [&](auto& self, auto idx) { o_bshd(b, idx[1], idx[0], idx[2]) = self(idx); }); +// } // } // } // namespace host @@ -328,20 +322,20 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) ck_tile::unified_attention_args args{}; - args.data_type = problem.data_type; - args.num_seqs = problem.batch; + args.data_type = problem.data_type; + args.num_seqs = problem.batch; // args.seqlen_q = problem.seqlen_q; // args.seqlen_k = problem.seqlen_k; - args.num_head_q = problem.nhead_q; - args.num_queries_per_kv = problem.nhead_q / problem.nhead_kv; - args.mask_type = 2; - args.hdim = problem.hdim; + args.num_head_q = problem.nhead_q; + args.num_queries_per_kv = problem.nhead_q / problem.nhead_kv; + args.mask_type = 2; + args.hdim = problem.hdim; args.num_blks = problem.num_blks; // args.query_lens = problem.query_lens // args.kv_lens = problem.kv_lens - args.q_ptr = q_buf.GetDeviceBuffer(); + args.q_ptr = q_buf.GetDeviceBuffer(); args.query_stride_0 = problem.hdim * problem.nhead_q; args.query_stride_0 = problem.hdim; @@ -352,13 +346,13 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) args.stride_k_cache_2 = problem.hdim; args.stride_k_cache_3 = 1; - args.v_ptr = v_buf.GetDeviceBuffer(); + args.v_ptr = v_buf.GetDeviceBuffer(); args.stride_v_cache_0 = args.stride_k_cache_0; args.stride_v_cache_1 = args.stride_k_cache_1; args.stride_v_cache_2 = args.stride_k_cache_2; args.stride_v_cache_3 = args.stride_k_cache_3; - args.o_ptr = o_buf.GetDeviceBuffer(); + args.o_ptr = o_buf.GetDeviceBuffer(); args.output_stride_0 = args.query_stride_0; args.output_stride_1 = args.query_stride_1; @@ -380,13 +374,13 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) return eff; }; - const auto eff_query_lens = make_effective_vec(problem.query_lens, 1024); - const auto eff_kv_lens = make_effective_vec(problem.kv_lens, 1024); + const auto eff_query_lens = make_effective_vec(problem.query_lens, 1024); + const auto eff_kv_lens = make_effective_vec(problem.kv_lens, 1024); args.num_tokens = std::accumulate(eff_query_lens.begin(), eff_query_lens.end(), 0); // Calculate cumulative sums for kernel arguments if varlen is used - std::vector cu_query_lens ; + std::vector cu_query_lens; auto calculate_cumulative = [&](const std::vector& per_batch_vec, std::vector& cum_vec) { @@ -403,14 +397,16 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) seq_lens_buf.ToDevice(eff_kv_lens.data()); query_start_len_buf.ToDevice(cu_query_lens.data()); - args.seq_lens_ptr =reinterpret_cast(seq_lens_buf.GetDeviceBuffer()); - args.query_start_len_ptr =reinterpret_cast(query_start_len_buf.GetDeviceBuffer()); - + args.seq_lens_ptr = reinterpret_cast(seq_lens_buf.GetDeviceBuffer()); + args.query_start_len_ptr = + reinterpret_cast(query_start_len_buf.GetDeviceBuffer()); auto max_element = [&](const std::vector& opt_vec) { ck_tile::index_t max = opt_vec[0]; - for (ck_tile::index_t i: opt_vec) { - if (i > max){ + for(ck_tile::index_t i : opt_vec) + { + if(i > max) + { max = i; } } @@ -419,10 +415,12 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) ck_tile::index_t max_kv_len = max_element(eff_kv_lens); - ck_tile::index_t max_num_blocks_per_seq = (max_kv_len + problem.BLOCK_SIZE - 1) / problem.BLOCK_SIZE; + ck_tile::index_t max_num_blocks_per_seq = + (max_kv_len + problem.BLOCK_SIZE - 1) / problem.BLOCK_SIZE; // Create block_tables - ck_tile::DeviceMem block_tables_buf(problem.batch * max_num_blocks_per_seq * sizeof(ck_tile::index_t)); + ck_tile::DeviceMem block_tables_buf(problem.batch * max_num_blocks_per_seq * + sizeof(ck_tile::index_t)); // Allocate host memory for block_tables std::vector block_tables_host(problem.batch * max_num_blocks_per_seq); @@ -430,7 +428,8 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) // Fill block_tables with random integers between 0 and num_blocks-1 std::mt19937 rng(run_config.seed ? *run_config.seed : std::random_device{}()); std::uniform_int_distribution dist(0, problem.num_blks - 1); - for (size_t i = 0; i < block_tables_host.size(); ++i) { + for(size_t i = 0; i < block_tables_host.size(); ++i) + { block_tables_host[i] = dist(rng); } @@ -438,10 +437,10 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) block_tables_buf.ToDevice(block_tables_host.data()); // Set pointer in args - args.block_tables_ptr = reinterpret_cast(block_tables_buf.GetDeviceBuffer()); + args.block_tables_ptr = + reinterpret_cast(block_tables_buf.GetDeviceBuffer()); args.block_table_stride = max_num_blocks_per_seq; - ck_tile::stream_config stream_config{nullptr, true, /*log_level=*/0, @@ -476,7 +475,8 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) // std::cout << "] b:" << problem.batch << ", h:" << problem.nhead_q << "/" << problem.nhead_kv // << ", s:" << problem.seqlen_q << "/" << problem.seqlen_k << ", d:" << problem.hdim // << ", scale_s:" << problem.sacle_s << ", mask:" << problem.mask << std::fixed - // << ", " << std::setprecision(3) << time << " ms, " << std::setprecision(2) << tflops + // << ", " << std::setprecision(3) << time << " ms, " << std::setprecision(2) << + // tflops // << " TFlops" << std::endl; // if(!run_config.verify) @@ -548,7 +548,6 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) // } // } // } - // ck_tile::HostTensor o(problem.get_output_shape()); // o_buf.FromDevice(o.data()); diff --git a/example/ck_tile/01_unified_attention/unified_attention.cpp b/example/ck_tile/01_unified_attention/unified_attention.cpp index 8c2b22f0a2..fb3e37e1e0 100644 --- a/example/ck_tile/01_unified_attention/unified_attention.cpp +++ b/example/ck_tile/01_unified_attention/unified_attention.cpp @@ -7,7 +7,8 @@ namespace ck_tile { -std::ostream& operator<<(std::ostream& stream, const unified_attention_args::data_type_enum& data_type) +std::ostream& operator<<(std::ostream& stream, + const unified_attention_args::data_type_enum& data_type) { switch(data_type) { @@ -17,14 +18,16 @@ std::ostream& operator<<(std::ostream& stream, const unified_attention_args::dat } } -std::pair unified_attention(const unified_attention_args& args, const stream_config& config) +std::pair unified_attention(const unified_attention_args& args, + const stream_config& config) { if(args.data_type == unified_attention_args::data_type_enum::fp16) { if(args.mask_type == static_cast(mask_enum::no_mask)) { using kernel_traits = - unified_attention_kernel_traits; + unified_attention_kernel_traits; return unified_attention_kernel_dispatch(args, config); } @@ -41,7 +44,8 @@ std::pair unified_attention(const unified_attention_args& args, con if(args.mask_type == static_cast(mask_enum::no_mask)) { using kernel_traits = - unified_attention_kernel_traits; + unified_attention_kernel_traits; return unified_attention_kernel_dispatch(args, config); } diff --git a/example/ck_tile/01_unified_attention/unified_attention.hpp b/example/ck_tile/01_unified_attention/unified_attention.hpp index 50462d3110..083a3acd85 100644 --- a/example/ck_tile/01_unified_attention/unified_attention.hpp +++ b/example/ck_tile/01_unified_attention/unified_attention.hpp @@ -52,24 +52,26 @@ struct unified_attention_args index_t stride_v_cache_1; index_t stride_v_cache_2; index_t stride_v_cache_3; - + void* o_ptr; index_t output_stride_0; index_t output_stride_1; const int32_t* block_tables_ptr; index_t block_table_stride; - const int32_t* seq_lens_ptr; // seq len in each batch + const int32_t* seq_lens_ptr; // seq len in each batch const int32_t* query_start_len_ptr; // [num_seqs+1] index_t num_seqs; // number of batches for q }; -std::ostream& operator<<(std::ostream& stream, const unified_attention_args::data_type_enum& data_type); +std::ostream& operator<<(std::ostream& stream, + const unified_attention_args::data_type_enum& data_type); // return value: // first = whether the kernel was launched (true = launched, false = skipped) // second = elapsed time (ms) of the kernel launch, valid only if first == true -std::pair unified_attention(const unified_attention_args& args, const stream_config& config); +std::pair unified_attention(const unified_attention_args& args, + const stream_config& config); } // namespace ck_tile diff --git a/example/ck_tile/01_unified_attention/unified_attention_impl.hpp b/example/ck_tile/01_unified_attention/unified_attention_impl.hpp index 64aead84f5..dc3104e4f2 100644 --- a/example/ck_tile/01_unified_attention/unified_attention_impl.hpp +++ b/example/ck_tile/01_unified_attention/unified_attention_impl.hpp @@ -20,13 +20,13 @@ #include "unified_attention.hpp" #include "mask.hpp" -#define INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits) \ - template <> \ - std::pair unified_attention_kernel_dispatch( \ - const unified_attention_args& args, const stream_config& config) \ - { \ - return std::make_pair(true, \ - unified_attention_kernel_launch(args, config)); \ +#define INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits) \ + template <> \ + std::pair unified_attention_kernel_dispatch( \ + const unified_attention_args& args, const stream_config& config) \ + { \ + return std::make_pair( \ + true, unified_attention_kernel_launch(args, config)); \ } namespace ck_tile { @@ -55,8 +55,8 @@ struct unified_attention_problem_traits struct unified_attention_kernel_traits { - static constexpr auto date_type = DataType; - static constexpr bool is_masking = IsMasking; + static constexpr auto date_type = DataType; + static constexpr bool is_masking = IsMasking; // BLOCK_M BLOCK_Q BLOCK_SIZE HEAD_SIZE using unified_attention_block_tile = sequence<256, 64, 128, 128>; @@ -64,34 +64,34 @@ struct unified_attention_kernel_traits using unified_attention_block_warps = sequence<8, 1, 1>; using unified_attention_shape = TileUnifiedAttentionShape; + unified_attention_block_warps, + unified_attention_warp_gemm_shape, + unified_attention_block_warps, + unified_attention_warp_gemm_shape, + true // IsVLayoutRowMajor + >; using unified_attention_traits = TileUnifiedAttentionTraits; + false, // kPadHeadDimQ + -1 // kBlockPerCu + >; using unified_attention_mask = GenericAttentionMask; - using unified_attention_pipeline_problem = - UnifiedAttentionPipelineProblem::qkvp_dtype, - typename unified_attention_problem_traits::qkvp_dtype, - typename unified_attention_problem_traits::qkvp_dtype, - typename unified_attention_problem_traits::acc_dtype, - typename unified_attention_problem_traits::acc_dtype, - typename unified_attention_problem_traits::acc_dtype, - typename unified_attention_problem_traits::lse_dtype, - typename unified_attention_problem_traits::qkvp_dtype, - typename unified_attention_problem_traits::acc_dtype, - typename unified_attention_problem_traits::o_dtype, - unified_attention_shape, - unified_attention_mask, - unified_attention_traits>; + using unified_attention_pipeline_problem = UnifiedAttentionPipelineProblem< + typename unified_attention_problem_traits::qkvp_dtype, + typename unified_attention_problem_traits::qkvp_dtype, + typename unified_attention_problem_traits::qkvp_dtype, + typename unified_attention_problem_traits::acc_dtype, + typename unified_attention_problem_traits::acc_dtype, + typename unified_attention_problem_traits::acc_dtype, + typename unified_attention_problem_traits::lse_dtype, + typename unified_attention_problem_traits::qkvp_dtype, + typename unified_attention_problem_traits::acc_dtype, + typename unified_attention_problem_traits::o_dtype, + unified_attention_shape, + unified_attention_mask, + unified_attention_traits>; using unified_attention_pipeline = UnifiedAttentionPipeline; @@ -107,11 +107,12 @@ struct unified_attention_kernel_traits }; template -float unified_attention_kernel_launch(const unified_attention_args& args, const stream_config& config) +float unified_attention_kernel_launch(const unified_attention_args& args, + const stream_config& config) { - + index_t BLOCK_Q = Kernel::BLOCK_M / args.num_queries_per_kv; - + index_t total_num_q_blocks = args.num_tokens / BLOCK_Q + args.num_seqs; auto kargs = Kernel::MakeKargs(args.q_ptr, @@ -128,26 +129,25 @@ float unified_attention_kernel_launch(const unified_attention_args& args, const args.scale_out, total_num_q_blocks, args.query_stride_0, - args.query_stride_1, - args.stride_k_cache_0, - args.stride_k_cache_1, - args.stride_k_cache_2, - args.stride_k_cache_3, - args.stride_v_cache_0, - args.stride_v_cache_1, - args.stride_v_cache_2, - args.stride_v_cache_3, - args.output_stride_0, - args.output_stride_1, - args.block_tables_ptr, - args.block_table_stride, - args.seq_lens_ptr, - args.query_start_len_ptr, - args.num_seqs - ); + args.query_stride_1, + args.stride_k_cache_0, + args.stride_k_cache_1, + args.stride_k_cache_2, + args.stride_k_cache_3, + args.stride_v_cache_0, + args.stride_v_cache_1, + args.stride_v_cache_2, + args.stride_v_cache_3, + args.output_stride_0, + args.output_stride_1, + args.block_tables_ptr, + args.block_table_stride, + args.seq_lens_ptr, + args.query_start_len_ptr, + args.num_seqs); - dim3 grids = Kernel::GridSize2D(args.num_head_q / args.num_queries_per_kv, total_num_q_blocks); - constexpr dim3 blocks = Kernel::BlockSize(); + dim3 grids = Kernel::GridSize2D(args.num_head_q / args.num_queries_per_kv, total_num_q_blocks); + constexpr dim3 blocks = Kernel::BlockSize(); constexpr index_t kBlockPerCu = Kernel::kBlockPerCu; return launch_kernel(config, make_kernel(Kernel{}, grids, blocks, 0, kargs)); @@ -158,6 +158,6 @@ float unified_attention_kernel_launch(const unified_attention_args& args, const // second = elapsed time (ms) of the kernel launch, valid only if first == true template std::pair unified_attention_kernel_dispatch(const unified_attention_args& args, - const stream_config& config); + const stream_config& config); } // namespace ck_tile diff --git a/include/ck_tile/ops/unified_attention.hpp b/include/ck_tile/ops/unified_attention.hpp index 62e6c58acb..20eee5a819 100644 --- a/include/ck_tile/ops/unified_attention.hpp +++ b/include/ck_tile/ops/unified_attention.hpp @@ -3,12 +3,6 @@ #pragma once - -#include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/tensor_layout.hpp" -#include "ck_tile/ops/common/utils.hpp" - -// Block-level components #include "ck_tile/ops/unified_attention/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/unified_attention/block/block_dropout.hpp" #include "ck_tile/ops/unified_attention/block/block_masking.hpp" @@ -16,15 +10,15 @@ #include "ck_tile/ops/unified_attention/block/block_rotary_embedding.hpp" #include "ck_tile/ops/unified_attention/block/page_block_navigator.hpp" #include "ck_tile/ops/unified_attention/block/variants.hpp" - -// Kernel-level components #include "ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp" - -// Pipeline-level components #include "ck_tile/ops/unified_attention/pipeline/tile_unified_attention_shape.hpp" #include "ck_tile/ops/unified_attention/pipeline/tile_unified_attention_traits.hpp" #include "ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp" #include "ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_default_policy.hpp" #include "ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_enum.hpp" #include "ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_problem.hpp" - +#include "ck_tile/ops/common/generic_2d_block_shape.hpp" +#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/streamk_common.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/unified_attention/block/block_masking.hpp b/include/ck_tile/ops/unified_attention/block/block_masking.hpp index 87868a56a1..33ca84d2c5 100644 --- a/include/ck_tile/ops/unified_attention/block/block_masking.hpp +++ b/include/ck_tile/ops/unified_attention/block/block_masking.hpp @@ -86,19 +86,22 @@ struct GenericAttentionMask static constexpr const char* name = impl::MaskName::name; // New constructor accepting repeat_idx with default value 1 - CK_TILE_HOST_DEVICE GenericAttentionMask(index_t y_total_, index_t x_total_, index_t repeat_idx_ = 1) + CK_TILE_HOST_DEVICE + GenericAttentionMask(index_t y_total_, index_t x_total_, index_t repeat_idx_ = 1) : GenericAttentionMask(0, 0, y_total_, x_total_, repeat_idx_) { } CK_TILE_HOST_DEVICE - GenericAttentionMask(index_t y_, index_t x_, index_t y_total_, index_t x_total_, index_t repeat_idx_ = 1) + GenericAttentionMask( + index_t y_, index_t x_, index_t y_total_, index_t x_total_, index_t repeat_idx_ = 1) : y(y_), x(x_), y_total(y_total_), x_total(x_total_), repeat_idx(repeat_idx_) { } template - CK_TILE_HOST_DEVICE GenericAttentionMask(const MaskCoordinates& mask_coord, index_t repeat_idx_ = 1) + CK_TILE_HOST_DEVICE GenericAttentionMask(const MaskCoordinates& mask_coord, + index_t repeat_idx_ = 1) : y(mask_coord.at(number<0>{})), x(mask_coord.at(number<1>{})), y_total(mask_coord.at(number<2>{})), @@ -248,13 +251,12 @@ struct GenericAttentionMask } } -private: + private: index_t y, x; index_t y_total, x_total; index_t repeat_idx; }; - // TODO: prefer use this function in host code // can convert from the FA style left/right to our generic coordinate // if left_size < 0 && right_size = 0, it is normal causal mask @@ -289,7 +291,7 @@ make_generic_attention_mask_from_lr_window(index_t left_size, index_t y_total, index_t x_total, index_t repeat_idx = 1, - bool is_top_left = true) + bool is_top_left = true) { auto r = make_generic_attention_mask_coordinates_from_lr_window( left_size, right_size, y_total, x_total, is_top_left); diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index ee4eeab920..ac7a06a961 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -18,8 +18,8 @@ namespace ck_tile { template struct UnifiedAttentionKernel { - using UnifiedAttentionPipeline = ck_tile::remove_cvref_t; - using EpiloguePipeline = ck_tile::remove_cvref_t; + using UnifiedAttentionPipeline = ck_tile::remove_cvref_t; + using EpiloguePipeline = ck_tile::remove_cvref_t; static constexpr ck_tile::index_t kBlockSize = UnifiedAttentionPipeline::kBlockSize; static constexpr ck_tile::index_t kBlockPerCu = UnifiedAttentionPipeline::kBlockPerCu; static_assert(kBlockPerCu > 0); @@ -29,18 +29,18 @@ struct UnifiedAttentionKernel using VDataType = ck_tile::remove_cvref_t; using ODataType = ck_tile::remove_cvref_t; using SaccDataType = ck_tile::remove_cvref_t; - using FmhaMask = ck_tile::remove_cvref_t; + using FmhaMask = ck_tile::remove_cvref_t; static constexpr bool kHasMask = FmhaMask::IsMasking; - + static constexpr bool kPadSeqLenK = UnifiedAttentionPipeline::kPadSeqLenK; static constexpr bool kPadSeqLenQ = UnifiedAttentionPipeline::kPadSeqLenQ; static constexpr bool kPadHeadDimQ = UnifiedAttentionPipeline::kPadHeadDimQ; static constexpr bool kPadHeadDimV = UnifiedAttentionPipeline::kPadHeadDimV; - // TODO add yjese - static constexpr index_t HEAD_SIZE = UnifiedAttentionPipeline::HEAD_SIZE; + // TODO add yjese + static constexpr index_t HEAD_SIZE = UnifiedAttentionPipeline::HEAD_SIZE; static constexpr index_t HEAD_SIZE_PADDED = UnifiedAttentionPipeline::HEAD_SIZE_PADDED; - + // BLOCK_Q = BLOCK_M // num_queries_per_kv // BLOCK_Q is the block size for q seqlen /// static constexpr index_t BLOCK_Q = UnifiedAttentionPipeline::BLOCK_Q; @@ -49,7 +49,6 @@ struct UnifiedAttentionKernel // BLOCK size for K seqlen static constexpr index_t BLOCK_SIZE = UnifiedAttentionPipeline::BLOCK_SIZE; - // kargs use aggregate initializer, so no constructor will provided // use inheritance to minimize karg size // user need to use MakeKargs() function to create kargs. @@ -88,12 +87,11 @@ struct UnifiedAttentionKernel ck_tile::index_t output_stride_1; }; - - struct UnifiedAttentionVarlenKargs: UnifiedAttentionCommonKargs + struct UnifiedAttentionVarlenKargs : UnifiedAttentionCommonKargs { const int32_t* block_tables_ptr; ck_tile::index_t block_table_stride; - const int32_t* seq_lens_ptr; // seq len in each batch + const int32_t* seq_lens_ptr; // seq len in each batch const int32_t* query_start_len_ptr; // [num_seqs+1] ck_tile::index_t num_seqs; // number of batches for q @@ -101,38 +99,36 @@ struct UnifiedAttentionKernel using Kargs = UnifiedAttentionVarlenKargs; - CK_TILE_HOST static constexpr Kargs MakeKargs( - const void* q_ptr, - const void* k_ptr, - const void* v_ptr, - void* o_ptr, - ck_tile::index_t num_blks, - ck_tile::index_t num_head_q, - const ck_tile::index_t num_queries_per_kv, - float scale_s, - float scale, - float scale_k, - float scale_v, - float scale_out, - ck_tile::index_t total_num_q_blocks, - ck_tile::index_t query_stride_0, - ck_tile::index_t query_stride_1, - ck_tile::index_t stride_k_cache_0, - ck_tile::index_t stride_k_cache_1, - ck_tile::index_t stride_k_cache_2, - ck_tile::index_t stride_k_cache_3, - ck_tile::index_t stride_v_cache_0, - ck_tile::index_t stride_v_cache_1, - ck_tile::index_t stride_v_cache_2, - ck_tile::index_t stride_v_cache_3, - ck_tile::index_t output_stride_0, - ck_tile::index_t output_stride_1, - const int32_t* block_tables_ptr, - ck_tile::index_t block_table_stride, - const int32_t* seq_lens_ptr, - const int32_t* query_start_len_ptr, - ck_tile::index_t num_seqs - ) + CK_TILE_HOST static constexpr Kargs MakeKargs(const void* q_ptr, + const void* k_ptr, + const void* v_ptr, + void* o_ptr, + ck_tile::index_t num_blks, + ck_tile::index_t num_head_q, + const ck_tile::index_t num_queries_per_kv, + float scale_s, + float scale, + float scale_k, + float scale_v, + float scale_out, + ck_tile::index_t total_num_q_blocks, + ck_tile::index_t query_stride_0, + ck_tile::index_t query_stride_1, + ck_tile::index_t stride_k_cache_0, + ck_tile::index_t stride_k_cache_1, + ck_tile::index_t stride_k_cache_2, + ck_tile::index_t stride_k_cache_3, + ck_tile::index_t stride_v_cache_0, + ck_tile::index_t stride_v_cache_1, + ck_tile::index_t stride_v_cache_2, + ck_tile::index_t stride_v_cache_3, + ck_tile::index_t output_stride_0, + ck_tile::index_t output_stride_1, + const int32_t* block_tables_ptr, + ck_tile::index_t block_table_stride, + const int32_t* seq_lens_ptr, + const int32_t* query_start_len_ptr, + ck_tile::index_t num_seqs) { Kargs kargs{{q_ptr, k_ptr, @@ -146,31 +142,30 @@ struct UnifiedAttentionKernel scale_k, scale_v, scale_out, - total_num_q_blocks, - query_stride_0, - query_stride_1, - stride_k_cache_0, - stride_k_cache_1, - stride_k_cache_2, - stride_k_cache_3, - stride_v_cache_0, - stride_v_cache_1, - stride_v_cache_2, - stride_v_cache_3, - output_stride_0, - output_stride_1}, - block_tables_ptr, - block_table_stride, - seq_lens_ptr, - query_start_len_ptr, - num_seqs - }; + total_num_q_blocks, + query_stride_0, + query_stride_1, + stride_k_cache_0, + stride_k_cache_1, + stride_k_cache_2, + stride_k_cache_3, + stride_v_cache_0, + stride_v_cache_1, + stride_v_cache_2, + stride_v_cache_3, + output_stride_0, + output_stride_1}, + block_tables_ptr, + block_table_stride, + seq_lens_ptr, + query_start_len_ptr, + num_seqs}; return kargs; } CK_TILE_HOST static constexpr auto GridSize2D(ck_tile::index_t num_kv_heads, - ck_tile::index_t total_num_q_blocks) + ck_tile::index_t total_num_q_blocks) { return dim3(num_kv_heads * total_num_q_blocks, 0, 0); } @@ -190,16 +185,16 @@ struct UnifiedAttentionKernel ck_tile::index_t block_q, bool use_q_block_mode) { - ck_tile::index_t left = 0; + ck_tile::index_t left = 0; ck_tile::index_t right = num_seqs; - - while (left < right) + + while(left < right) { - ck_tile::index_t mid = (left + right) / 2; - ck_tile::index_t val = query_start_len_ptr[mid]; + ck_tile::index_t mid = (left + right) / 2; + ck_tile::index_t val = query_start_len_ptr[mid]; ck_tile::index_t mid_val = use_q_block_mode ? (val / block_q + mid) : val; - - if (mid_val <= target_idx) + + if(mid_val <= target_idx) { left = mid + 1; } @@ -208,32 +203,31 @@ struct UnifiedAttentionKernel right = mid; } } - + return left - 1; } - - CK_TILE_DEVICE static constexpr auto - RemapTileIndices(const ck_tile::index_t pid, const Kargs& kargs) + + CK_TILE_DEVICE static constexpr auto RemapTileIndices(const ck_tile::index_t pid, + const Kargs& kargs) { using namespace ck_tile; constexpr index_t NUM_XCDS = 8; - const index_t GRID_MN = kargs.total_num_q_blocks * - (kargs.num_head_q); - + const index_t GRID_MN = kargs.total_num_q_blocks * (kargs.num_head_q); + // Number of pids per XCD in the new arrangement const index_t pids_per_xcd = (GRID_MN + NUM_XCDS - 1) / NUM_XCDS; - + // When GRID_MN cannot divide NUM_XCDS, some xcds will have // pids_per_xcd pids, the other will have pids_per_xcd - 1 pids. // We calculate the number of xcds that have pids_per_xcd pids as tall_xcds index_t tall_xcds = GRID_MN % NUM_XCDS; - tall_xcds = tall_xcds == 0 ? NUM_XCDS : tall_xcds; - + tall_xcds = tall_xcds == 0 ? NUM_XCDS : tall_xcds; + // Compute current XCD and local pid within the XCD - const index_t xcd = pid % NUM_XCDS; + const index_t xcd = pid % NUM_XCDS; const index_t local_pid = pid / NUM_XCDS; - + // Calculate new pid based on the new grouping index_t remapped_pid = 0; // Initialize to avoid constexpr error if(xcd < tall_xcds) @@ -242,15 +236,15 @@ struct UnifiedAttentionKernel } else { - remapped_pid = tall_xcds * pids_per_xcd + - (xcd - tall_xcds) * (pids_per_xcd - 1) + - local_pid; + remapped_pid = + tall_xcds * pids_per_xcd + (xcd - tall_xcds) * (pids_per_xcd - 1) + local_pid; } - + return remapped_pid; } - CK_TILE_DEVICE static constexpr auto GetTileIndex(const ck_tile::index_t pid, const Kargs& kargs) + CK_TILE_DEVICE static constexpr auto GetTileIndex(const ck_tile::index_t pid, + const Kargs& kargs) { using namespace ck_tile; @@ -258,8 +252,8 @@ struct UnifiedAttentionKernel // const index_t num_tile_n1 = ck_tile::integer_divide_ceil(kargs.hdim_v, // UnifiedAttentionPipeline::kN1); - const index_t i_tile_m = pid % total_num_q_blocks; // Query block index - const index_t i_tile_n = pid / total_num_q_blocks; // Head index + const index_t i_tile_m = pid % total_num_q_blocks; // Query block index + const index_t i_tile_n = pid / total_num_q_blocks; // Head index return ck_tile::make_tuple(i_tile_m, i_tile_n); } @@ -268,7 +262,8 @@ struct UnifiedAttentionKernel CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { - return ck_tile::max(UnifiedAttentionPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize()); + return ck_tile::max(UnifiedAttentionPipeline::GetSmemSize(), + EpiloguePipeline::GetSmemSize()); } CK_TILE_DEVICE void operator()(Kargs kargs) const @@ -285,7 +280,7 @@ struct UnifiedAttentionKernel // const index_t BLOCK_Q = BLOCK_M / num_queries_per_kv; // for simplicity, batch stride we just modify the pointer // const index_t num_head_q = kargs.num_head_q; - + // const index_t num_head_k = num_head_q / num_queries_per_kv; pid = RemapTileIndices(pid, kargs); @@ -296,65 +291,76 @@ struct UnifiedAttentionKernel // grid size is (num_kv_heads, total_num_q_blocks) // total_num_q_blocks = q.shape[0] // BLOCK_Q + num_seqs // q.shape[0] is total number of query tokens across all batches - // one q_block spans BLOCK_Q = BLOCK_M // num_queries_per_kv number of query token groups. One query token group shares one kv token + // one q_block spans BLOCK_Q = BLOCK_M // num_queries_per_kv number of query token groups. + // One query token group shares one kv token - const index_t seq_idx = find_seq_idx( - kargs.query_start_len_ptr, q_block_global_idx, kargs.num_seqs, BLOCK_Q, true - ); // which batch + const index_t seq_idx = find_seq_idx(kargs.query_start_len_ptr, + q_block_global_idx, + kargs.num_seqs, + BLOCK_Q, + true); // which batch - const index_t q_block_start_idx = amd_wave_read_first_lane(kargs.query_start_len_ptr[seq_idx]); + const index_t q_block_start_idx = + amd_wave_read_first_lane(kargs.query_start_len_ptr[seq_idx]); - const index_t q_block_local_idx = amd_wave_read_first_lane(q_block_global_idx - q_block_start_idx); + const index_t q_block_local_idx = + amd_wave_read_first_lane(q_block_global_idx - q_block_start_idx); - const index_t cur_batch_in_all_start_index = amd_wave_read_first_lane(kargs.query_start_len_ptr[seq_idx]); - const index_t cur_batch_in_all_stop_index = amd_wave_read_first_lane(kargs.query_start_len_ptr[seq_idx + 1]); + const index_t cur_batch_in_all_start_index = + amd_wave_read_first_lane(kargs.query_start_len_ptr[seq_idx]); + const index_t cur_batch_in_all_stop_index = + amd_wave_read_first_lane(kargs.query_start_len_ptr[seq_idx + 1]); - const index_t cur_batch_query_len = cur_batch_in_all_stop_index - cur_batch_in_all_start_index; + const index_t cur_batch_query_len = + cur_batch_in_all_stop_index - cur_batch_in_all_start_index; // TODO check if we get the block size info from pipeline - if (q_block_local_idx * BLOCK_Q >= cur_batch_query_len) { + if(q_block_local_idx * BLOCK_Q >= cur_batch_query_len) + { return; } const index_t query_pos = q_block_local_idx * BLOCK_Q; - const index_t seq_len = kargs.seq_lens_ptr[seq_idx]; + const index_t seq_len = kargs.seq_lens_ptr[seq_idx]; const index_t context_len = seq_len - cur_batch_query_len; - index_t _max_seq_prefix_len = ( - context_len - + q_block_local_idx * BLOCK_Q - + (BLOCK_M - 1) // num_queries_per_kv - + 1 - ); + index_t _max_seq_prefix_len = + (context_len + q_block_local_idx * BLOCK_Q + (BLOCK_M - 1) // num_queries_per_kv + + 1); - if (seq_len < _max_seq_prefix_len) { + if(seq_len < _max_seq_prefix_len) + { _max_seq_prefix_len = seq_len; } const auto max_seq_prefix_len = _max_seq_prefix_len; - const index_t num_blocks = (max_seq_prefix_len + BLOCK_SIZE - 1) / BLOCK_SIZE; + const index_t num_blocks = (max_seq_prefix_len + BLOCK_SIZE - 1) / BLOCK_SIZE; // TODO sliding window const index_t num_blocks_start = 0; - index_t kv_head_offset = kv_head_idx * kargs.stride_k_cache_2; - + index_t kv_head_offset = kv_head_idx * kargs.stride_k_cache_2; + // Q/K/V DRAM and DRAM window - index_t q_ptr_offset_0 = cur_batch_in_all_start_index * kargs.query_stride_0; // move the pointer to the batch start - index_t q_ptr_offset_1 = kv_head_idx * num_queries_per_kv * kargs.query_stride_1; // move the pointer to the correct head group start + index_t q_ptr_offset_0 = cur_batch_in_all_start_index * + kargs.query_stride_0; // move the pointer to the batch start + index_t q_ptr_offset_1 = + kv_head_idx * num_queries_per_kv * + kargs.query_stride_1; // move the pointer to the correct head group start index_t q_ptr_offset = q_ptr_offset_0 + q_ptr_offset_1; - index_t o_ptr_offset_0 = cur_batch_in_all_start_index * kargs.output_stride_0; // move the pointer to the batch start - index_t o_ptr_offset_1 = kv_head_idx * num_queries_per_kv * kargs.output_stride_1; // move the pointer to the correct head group start - index_t o_ptr_offset = o_ptr_offset_0 + o_ptr_offset_1; + index_t o_ptr_offset_0 = cur_batch_in_all_start_index * + kargs.output_stride_0; // move the pointer to the batch start + index_t o_ptr_offset_1 = + kv_head_idx * num_queries_per_kv * + kargs.output_stride_1; // move the pointer to the correct head group start + index_t o_ptr_offset = o_ptr_offset_0 + o_ptr_offset_1; index_t block_table_offset = seq_idx * kargs.block_table_stride; - const QDataType* q_ptr = reinterpret_cast(kargs.q_ptr) + q_ptr_offset; const KDataType* k_ptr = reinterpret_cast(kargs.k_ptr) + kv_head_offset; const VDataType* v_ptr = reinterpret_cast(kargs.v_ptr) + kv_head_offset; - ODataType* o_ptr = reinterpret_cast(kargs.o_ptr) + o_ptr_offset; - + ODataType* o_ptr = reinterpret_cast(kargs.o_ptr) + o_ptr_offset; index_t query_len_padded = integer_divide_ceil(cur_batch_query_len, BLOCK_Q) * BLOCK_Q; // const bool is_query_len_padded = (cur_batch_query_len % BLOCK_Q == 0); @@ -368,37 +374,35 @@ struct UnifiedAttentionKernel number{}, number<2>{}); - const auto q_dram_pad = pad_tensor_view( // aling seqlen with BLOCK_Q and head dim with HEAD_SIZE_PADDED - q_dram_base, - // block sizes - make_tuple(number{}, number<1>{}, number{}), - sequence{} - ); // pads to (seq_len_padded, num_head_q, HEAD_SIZE_PADDED) + const auto q_dram_pad = + pad_tensor_view( // aling seqlen with BLOCK_Q and head dim with HEAD_SIZE_PADDED + q_dram_base, + // block sizes + make_tuple(number{}, number<1>{}, number{}), + sequence{}); // pads to (seq_len_padded, num_head_q, + // HEAD_SIZE_PADDED) const auto q_dram_merged = transform_tensor_view( - q_dram_pad, - make_tuple( - make_merge_transform( - make_tuple(query_len_padded, num_queries_per_kv) - ), - make_pass_through_transform(number{}) - ), - make_tuple(sequence<0, 1>{}, sequence<2>{}), - make_tuple(sequence<0>{}, sequence<1>{}) - ); // flattens the first two dims, head idx is the fastest changing dim in the merged dim - + q_dram_pad, + make_tuple(make_merge_transform(make_tuple(query_len_padded, num_queries_per_kv)), + make_pass_through_transform(number{})), + make_tuple(sequence<0, 1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, + sequence<1>{})); // flattens the first two dims, head idx is the fastest + // changing dim in the merged dim + return q_dram_merged; }(); - // static_assert(q_dram.desc_[number<0>{}] == 0, "q_dram.get_bottom_tensor_view()[number<0>{}] == 0"); + // static_assert(q_dram.desc_[number<0>{}] == 0, + // "q_dram.get_bottom_tensor_view()[number<0>{}] == 0"); // Q has the shape (k_head, seq_len, num_queries_per_kv, head_dim) // stride for dim 0 (num_queries_per_kv * head_dim, head_dim, 1) - auto q_dram_window = make_tile_window( - q_dram, - make_tuple(number{}, number{}), - {query_pos * num_queries_per_kv, 0} - ); - + auto q_dram_window = + make_tile_window(q_dram, + make_tuple(number{}, number{}), + {query_pos * num_queries_per_kv, 0}); + const auto k_dram = [&]() { // HEAD dim is skipped as defined in the ptrs const auto k_dram_naive = make_naive_tensor_view( @@ -408,24 +412,19 @@ struct UnifiedAttentionKernel number{}, number<1>{}); - const auto k_dram_pad = pad_tensor_view( - k_dram_naive, - // TODO can the BLOCK_SIZE_RAW needs padding? - make_tuple(1, BLOCK_SIZE, HEAD_SIZE_PADDED), - sequence{}); - + const auto k_dram_pad = pad_tensor_view(k_dram_naive, + // TODO can the BLOCK_SIZE_RAW needs padding? + make_tuple(1, BLOCK_SIZE, HEAD_SIZE_PADDED), + sequence{}); const auto k_dram_merged = transform_tensor_view( - k_dram_pad, - make_tuple( - make_merge_transform( - make_tuple(kargs.num_blks, BLOCK_SIZE) - ), - make_pass_through_transform(HEAD_SIZE_PADDED) - ), - make_tuple(sequence<0, 1>{}, sequence<2>{}), - make_tuple(sequence<0>{}, sequence<1>{}) - ); // flattens the first two dims, head idx is the fastest changing dim in the merged dim + k_dram_pad, + make_tuple(make_merge_transform(make_tuple(kargs.num_blks, BLOCK_SIZE)), + make_pass_through_transform(HEAD_SIZE_PADDED)), + make_tuple(sequence<0, 1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, + sequence<1>{})); // flattens the first two dims, head idx is the fastest + // changing dim in the merged dim return k_dram_merged; }(); @@ -441,53 +440,50 @@ struct UnifiedAttentionKernel number{}, number<1>{}); - const auto v_dram_pad = pad_tensor_view( - v_dram_naive, - make_tuple(1, BLOCK_SIZE, HEAD_SIZE_PADDED), - sequence{}); + const auto v_dram_pad = pad_tensor_view(v_dram_naive, + make_tuple(1, BLOCK_SIZE, HEAD_SIZE_PADDED), + sequence{}); const auto v_dram_merged = transform_tensor_view( - v_dram_pad, - make_tuple( - make_merge_transform( - make_tuple(kargs.num_blks, BLOCK_SIZE) - ), - make_pass_through_transform(HEAD_SIZE_PADDED) - ), - make_tuple(sequence<0, 1>{}, sequence<2>{}), - make_tuple(sequence<0>{}, sequence<1>{}) - ); // flattens the first two dims, head idx is the fastest changing dim in the merged dim + v_dram_pad, + make_tuple(make_merge_transform(make_tuple(kargs.num_blks, BLOCK_SIZE)), + make_pass_through_transform(HEAD_SIZE_PADDED)), + make_tuple(sequence<0, 1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, + sequence<1>{})); // flattens the first two dims, head idx is the fastest + // changing dim in the merged dim return v_dram_merged; }(); auto v_dram_window = make_tile_window( v_dram, make_tuple(number{}, number{}), {0, 0}); - + FmhaMask mask = [&]() { if constexpr(kHasMask) return ck_tile::make_generic_attention_mask_from_lr_window( seq_len - cur_batch_query_len, // y (i.e. context) - cur_batch_query_len, // x (i.e. extend) - seq_len, // y_total (x + y) - cur_batch_query_len, // x_total - num_queries_per_kv // the same sequence index is repeated num_queries_per_kv times along x dim of the tile + cur_batch_query_len, // x (i.e. extend) + seq_len, // y_total (x + y) + cur_batch_query_len, // x_total + num_queries_per_kv // the same sequence index is repeated num_queries_per_kv + // times along x dim of the tile ); else return FmhaMask{cur_batch_query_len, seq_len}; }(); - + auto o_acc_tile = [&]() { return UnifiedAttentionPipeline{}(q_dram_window, - k_dram_window, - v_dram_window, - num_blocks, - num_blocks_start, - kargs.block_tables_ptr, - block_table_offset, - mask, - kargs.scale_s, - smem_ptr); + k_dram_window, + v_dram_window, + num_blocks, + num_blocks_start, + kargs.block_tables_ptr, + block_table_offset, + mask, + kargs.scale_s, + smem_ptr); }(); // O DRAM and O DRAM window @@ -499,24 +495,20 @@ struct UnifiedAttentionKernel number{}, number<1>{}); - const auto o_dram_pad = pad_tensor_view( // aling cu_seqlen with BLOCK_Q and head dim with HEAD_SIZE_PADDED - o_dram_base, - // block sizes - make_tuple(BLOCK_Q, 1, HEAD_SIZE_PADDED), - sequence{} - ); // pads to (seq_len_padded, num_head_q, HEAD_SIZE_PADDED) + const auto o_dram_pad = + pad_tensor_view( // aling cu_seqlen with BLOCK_Q and head dim with HEAD_SIZE_PADDED + o_dram_base, + // block sizes + make_tuple(BLOCK_Q, 1, HEAD_SIZE_PADDED), + sequence{}); // pads to (seq_len_padded, num_head_q, + // HEAD_SIZE_PADDED) const auto o_dram_merged = transform_tensor_view( - o_dram_pad, - make_tuple( - make_merge_transform( - make_tuple(query_len_padded, num_queries_per_kv) - ), - make_pass_through_transform(HEAD_SIZE_PADDED) - ), - make_tuple(sequence<0, 1>{}, sequence<2>{}), - make_tuple(sequence<0>{}, sequence<1>{}) - ); + o_dram_pad, + make_tuple(make_merge_transform(make_tuple(query_len_padded, num_queries_per_kv)), + make_pass_through_transform(HEAD_SIZE_PADDED)), + make_tuple(sequence<0, 1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); return o_dram_merged; }(); diff --git a/include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_shape.hpp b/include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_shape.hpp index 790b0614a6..de7762e121 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_shape.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_shape.hpp @@ -47,11 +47,14 @@ struct TileUnifiedAttentionShape static constexpr index_t NumWarps = max(NumGemm0Warps, NumGemm1Warps); - static constexpr index_t BLOCK_M = BlockTile::at(number<0>{}); // tile size along the flattened batch dimension (: num_queries_per_kv * BS) - static constexpr index_t BLOCK_Q = BlockTile::at(number<1>{}); // tile size along the flattened batch dimension (: num_queries_per_kv * BS) - // static constexpr index_t BLOCK_M = BlockTile::at(number<1>{}); // tile size along q seqlen * num_queries_per_kv (q_head//kv_head) + static constexpr index_t BLOCK_M = BlockTile::at( + number<0>{}); // tile size along the flattened batch dimension (: num_queries_per_kv * BS) + static constexpr index_t BLOCK_Q = BlockTile::at( + number<1>{}); // tile size along the flattened batch dimension (: num_queries_per_kv * BS) + // static constexpr index_t BLOCK_M = BlockTile::at(number<1>{}); // tile size along q seqlen * + // num_queries_per_kv (q_head//kv_head) static constexpr index_t BLOCK_SIZE = BlockTile::at(number<2>{}); // BLOCK size for K seqlen - static constexpr index_t HEAD_SIZE = BlockTile::at(number<3>{}); // BLOCK size for K seqlen + static constexpr index_t HEAD_SIZE = BlockTile::at(number<3>{}); // BLOCK size for K seqlen // static constexpr index_t kQKHeaddim = // BlockTile::at(number<5>{}); // total length of K0, used for pipeline that need load Q at diff --git a/include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_traits.hpp b/include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_traits.hpp index b27a09a1b4..40ec0fd0aa 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_traits.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_traits.hpp @@ -9,14 +9,13 @@ namespace ck_tile { - template struct TileUnifiedAttentionTraits { static constexpr bool kPadSeqLenQ = kPadSeqLenQ_; - static constexpr bool kPadHeadDim = kPadHeadDim_; + static constexpr bool kPadHeadDim = kPadHeadDim_; static constexpr index_t kBlockPerCu = kBlockPerCu_; }; -} +} // namespace ck_tile diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp index b2541ab74e..486acc4243 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp @@ -268,14 +268,15 @@ struct UnifiedAttentionPipeline static constexpr ck_tile::index_t kBlockSize = Problem::kBlockSize; - static constexpr ck_tile::index_t BLOCK_M = UnifiedAttentionShape::BLOCK_M; - static constexpr ck_tile::index_t BLOCK_Q = UnifiedAttentionShape::BLOCK_Q; - - static constexpr ck_tile::index_t BLOCK_SIZE = UnifiedAttentionShape::BLOCK_SIZE; - static constexpr ck_tile::index_t HEAD_SIZE = UnifiedAttentionShape::HEAD_SIZE; - static constexpr ck_tile::index_t HEAD_SIZE_PADDED = UnifiedAttentionShape::HEAD_SIZE_PADDED; + static constexpr ck_tile::index_t BLOCK_M = UnifiedAttentionShape::BLOCK_M; + static constexpr ck_tile::index_t BLOCK_Q = UnifiedAttentionShape::BLOCK_Q; - static_assert(HEAD_SIZE_PADDED <= 256, "hdim bigger than 256 is not suitable for this pipeline!"); + static constexpr ck_tile::index_t BLOCK_SIZE = UnifiedAttentionShape::BLOCK_SIZE; + static constexpr ck_tile::index_t HEAD_SIZE = UnifiedAttentionShape::HEAD_SIZE; + static constexpr ck_tile::index_t HEAD_SIZE_PADDED = UnifiedAttentionShape::HEAD_SIZE_PADDED; + + static_assert(HEAD_SIZE_PADDED <= 256, + "hdim bigger than 256 is not suitable for this pipeline!"); // static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; static constexpr bool kPadHeadDimQ = Problem::kPadHeadDim; @@ -402,12 +403,13 @@ struct UnifiedAttentionPipeline std::is_same_v>, "wrong!"); - static_assert(BLOCK_M == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - BLOCK_SIZE == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - HEAD_SIZE_PADDED == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && - HEAD_SIZE_PADDED == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - BLOCK_SIZE == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], - "wrong!"); + static_assert( + BLOCK_M == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + BLOCK_SIZE == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + HEAD_SIZE_PADDED == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + HEAD_SIZE_PADDED == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + BLOCK_SIZE == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + "wrong!"); static_assert(sizeof(SaccDataType) * BLOCK_SIZE <= GetSmemSize()); auto s_lds = make_tensor_view( @@ -542,9 +544,11 @@ struct UnifiedAttentionPipeline const auto q_origin = q_dram_window.get_window_origin(); // const auto [seqlen_k_start, seqlen_k_end] = - // mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number{}, number{}); + // mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number{}, + // number{}); - // const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, BLOCK_SIZE); + // const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, + // BLOCK_SIZE); const auto num_total_loop = num_blocks; // index_t kv_token_start = seqlen_k_start; @@ -554,7 +558,6 @@ struct UnifiedAttentionPipeline { if(num_total_loop - num_blocks_start <= 0) { - // Note: here occ are all cleard, return it // Note: q loaded but no fence, ignore it. @@ -562,11 +565,11 @@ struct UnifiedAttentionPipeline } } - index_t i_total_loops = num_blocks_start; - const ck_tile::index_t* block_tables_ptr_ = reinterpret_cast(block_tables_ptr); - index_t kv_blk_idx = block_tables_ptr_[block_table_offset + i_total_loops]; - index_t kv_blk_idx_prev = 0; - + index_t i_total_loops = num_blocks_start; + const ck_tile::index_t* block_tables_ptr_ = + reinterpret_cast(block_tables_ptr); + index_t kv_blk_idx = block_tables_ptr_[block_table_offset + i_total_loops]; + index_t kv_blk_idx_prev = 0; auto k_dram_window = make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), @@ -672,9 +675,10 @@ struct UnifiedAttentionPipeline auto K_mem_load = [&](auto k_lds_write_idx) { async_load_tile_raw(k_lds_window_store(k_lds_write_idx), k_dram_window); - // TODO maybe needs i_total_loops as argument. Or maybe needs to use the k_lds_write_idx as the index + // TODO maybe needs i_total_loops as argument. Or maybe needs to use the k_lds_write_idx + // as the index /// FIXME: use the future-predicting method to move the window - k_dram_window.set_window_origin({kv_blk_idx * BLOCK_SIZE, 0}); + k_dram_window.set_window_origin({kv_blk_idx * BLOCK_SIZE, 0}); }; auto K_lds_load = [&](auto k_lds_read_idx) { @@ -683,9 +687,9 @@ struct UnifiedAttentionPipeline auto V_mem_load = [&](auto v_lds_write_idx) { async_load_tile_raw(v_lds_window_store(v_lds_write_idx), v_dram_window); - // kv_blk_idx = block_tables_ptr_[block_table_offset + i_total_loops]; - /// FIXME: use the future-predicting method to move the window - v_dram_window.set_window_origin({kv_blk_idx * BLOCK_SIZE, 0}); + // kv_blk_idx = block_tables_ptr_[block_table_offset + i_total_loops]; + /// FIXME: use the future-predicting method to move the window + v_dram_window.set_window_origin({kv_blk_idx * BLOCK_SIZE, 0}); }; auto V_lds_load = [&](auto v_lds_read_idx) { @@ -894,8 +898,10 @@ struct UnifiedAttentionPipeline auto fmha_mask = [&](auto sp_reg_idx) { if constexpr(FmhaMask::IsMasking) { - bool need_perpixel_check = mask.IsEdgeTile( - q_origin.at(number<0>{}), i_total_loops * BLOCK_SIZE, number{}, number{}); + bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}), + i_total_loops * BLOCK_SIZE, + number{}, + number{}); if(need_perpixel_check) { set_tile_if(sp(sp_reg_idx).sp_compute, @@ -903,7 +909,8 @@ struct UnifiedAttentionPipeline [&](auto tile_idx) { const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); - const auto col = i_total_loops * BLOCK_SIZE + tile_idx.at(number<1>{}); + const auto col = + i_total_loops * BLOCK_SIZE + tile_idx.at(number<1>{}); return mask.IsOutOfBound(row, col); }); } @@ -1180,7 +1187,6 @@ struct UnifiedAttentionPipeline fmha_post_process(number<0>{}); } - // finally, O constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_default_policy.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_default_policy.hpp index 32f97aba50..3d5b46c176 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_default_policy.hpp @@ -141,11 +141,11 @@ struct UnifiedAttentionPipelineDefaultPolicy constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector)); - constexpr index_t N0 = NumIssues; // 8 + constexpr index_t N0 = NumIssues; // 8 constexpr index_t N1 = LaneGroups; // 2 - constexpr index_t N2 = NumWarps; // 8 - constexpr index_t K0 = LanesPerK; // 32 - constexpr index_t K1 = KVector; // 4 + constexpr index_t N2 = NumWarps; // 8 + constexpr index_t K0 = LanesPerK; // 32 + constexpr index_t K1 = KVector; // 4 return make_static_tile_distribution( tile_distribution_encoding, @@ -259,13 +259,13 @@ struct UnifiedAttentionPipelineDefaultPolicy } }(); - using BlockGemmPolicy = - BlockGemmARegBRegCRegV2CustomPolicy; + using BlockGemmPolicy = BlockGemmARegBRegCRegV2CustomPolicy< + typename Problem::QDataType, + typename Problem::KDataType, + typename Problem::SaccDataType, + typename Problem::UnifiedAttentionShape::Gemm0BlockWarps, + decltype(warp_gemm), + GemmLoopOrder::MNK>; return BlockGemmARegBRegCRegV2{}; } @@ -287,24 +287,25 @@ struct UnifiedAttentionPipelineDefaultPolicy typename Problem::UnifiedAttentionShape::Gemm1WarpTile>>; /// NOTICE: in order to use load_tile_transpose() later for V tiles, we have to pass /// WGAttrNumAccessEnum::Double instead of WGAttrNumAccessEnum::Single - using WarpGemm = WarpGemmDispatcher{}), - Problem::UnifiedAttentionShape::Gemm1WarpTile::at(number<1>{}), - Problem::UnifiedAttentionShape::Gemm1WarpTile::at(number<2>{}), - true, - false, - false, - WGAttrNumAccessEnum::Double>; + using WarpGemm = + WarpGemmDispatcher{}), + Problem::UnifiedAttentionShape::Gemm1WarpTile::at(number<1>{}), + Problem::UnifiedAttentionShape::Gemm1WarpTile::at(number<2>{}), + true, + false, + false, + WGAttrNumAccessEnum::Double>; - using BlockGemmPolicy = - BlockGemmARegBRegCRegV2CustomPolicy; + using BlockGemmPolicy = BlockGemmARegBRegCRegV2CustomPolicy< + typename Problem::PDataType, + typename Problem::VDataType, + typename Problem::OaccDataType, + typename Problem::UnifiedAttentionShape::Gemm1BlockWarps, + WarpGemm, + GemmLoopOrder::MNK>; return BlockGemmARegBRegCRegV2{}; } diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_problem.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_problem.hpp index d21d8316af..f2caaa23df 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_problem.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_problem.hpp @@ -19,23 +19,23 @@ template struct UnifiedAttentionPipelineProblem { // TODO kM0 and KN1?? - using QDataType = remove_cvref_t; - using KDataType = remove_cvref_t; - using VDataType = remove_cvref_t; + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; // first gemm accumulation dtype - using SaccDataType = remove_cvref_t; + using SaccDataType = remove_cvref_t; // Softmax dtype using SMPLComputeDataType = remove_cvref_t; using BiasDataType = remove_cvref_t; using RandValOutputDataType = remove_cvref_t; // data type for A matrix of second gemm - using PDataType = remove_cvref_t; - // data type for second gemm accumulation + using PDataType = remove_cvref_t; + // data type for second gemm accumulation using OaccDataType = remove_cvref_t; using ODataType = remove_cvref_t; using UnifiedAttentionShape = remove_cvref_t; @@ -48,11 +48,11 @@ struct UnifiedAttentionPipelineProblem // attributes from traits static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ; - static constexpr bool kPadHeadDim = Traits::kPadHeadDim; + static constexpr bool kPadHeadDim = Traits::kPadHeadDim; static constexpr bool kHasLogitsSoftCap = Traits::kHasLogitsSoftCap; static constexpr bool kSkipMinSeqlenQ = Traits::kSkipMinSeqlenQ; static constexpr bool kHasDropout = Traits::kHasDropout; static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant; static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; }; -} +} // namespace ck_tile