From d68a541c1994c552b7baef484682373ccb3be843 Mon Sep 17 00:00:00 2001 From: Juuso Korhonen <40278371+juuso-oskari@users.noreply.github.com> Date: Mon, 20 Oct 2025 15:04:47 +0000 Subject: [PATCH] fixing compile errors... --- .../unified_attention_d128_bf16_mask.cpp | 2 +- .../unified_attention_d128_fp16_nmask.cpp | 2 +- example/ck_tile/01_unified_attention/mask.hpp | 2 +- .../unified_attention_impl.hpp | 5 ++-- include/ck_tile/ops/unified_attention.hpp | 30 +++++++++++++++++++ .../block/block_position_encoding.hpp | 2 +- .../tile_unified_attention_traits.hpp | 8 ++--- .../pipeline/unified_attention_pipeline.hpp | 19 ++++-------- 8 files changed, 47 insertions(+), 23 deletions(-) create mode 100644 include/ck_tile/ops/unified_attention.hpp diff --git a/example/ck_tile/01_unified_attention/instances/unified_attention_d128_bf16_mask.cpp b/example/ck_tile/01_unified_attention/instances/unified_attention_d128_bf16_mask.cpp index d99838d17c..72717026bc 100644 --- a/example/ck_tile/01_unified_attention/instances/unified_attention_d128_bf16_mask.cpp +++ b/example/ck_tile/01_unified_attention/instances/unified_attention_d128_bf16_mask.cpp @@ -7,7 +7,7 @@ namespace ck_tile { using kernel_traits = - unified_attention_kernel_traits; + unified_attention_kernel_traits; INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits) diff --git a/example/ck_tile/01_unified_attention/instances/unified_attention_d128_fp16_nmask.cpp b/example/ck_tile/01_unified_attention/instances/unified_attention_d128_fp16_nmask.cpp index d8fcd7d97d..6a2a9984d1 100644 --- a/example/ck_tile/01_unified_attention/instances/unified_attention_d128_fp16_nmask.cpp +++ b/example/ck_tile/01_unified_attention/instances/unified_attention_d128_fp16_nmask.cpp @@ -7,7 +7,7 @@ namespace ck_tile { using kernel_traits = - unified_attention_kernel_traits; + unified_attention_kernel_traits; INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits) diff --git a/example/ck_tile/01_unified_attention/mask.hpp 
b/example/ck_tile/01_unified_attention/mask.hpp index 2dfe0e7c52..33f9bf72a9 100644 --- a/example/ck_tile/01_unified_attention/mask.hpp +++ b/example/ck_tile/01_unified_attention/mask.hpp @@ -7,7 +7,7 @@ #include #include "ck_tile/core.hpp" -#include "ck_tile/ops/fmha.hpp" +#include "ck_tile/ops/unified_attention.hpp" // keep this in sync with ck_tile::GenericAttentionMaskEnum enum class mask_enum diff --git a/example/ck_tile/01_unified_attention/unified_attention_impl.hpp b/example/ck_tile/01_unified_attention/unified_attention_impl.hpp index 4fa0bdab0d..65f17fa251 100644 --- a/example/ck_tile/01_unified_attention/unified_attention_impl.hpp +++ b/example/ck_tile/01_unified_attention/unified_attention_impl.hpp @@ -85,6 +85,7 @@ struct unified_attention_kernel_traits typename unified_attention_problem_traits::qkvp_dtype, typename unified_attention_problem_traits::acc_dtype, typename unified_attention_problem_traits::acc_dtype, + typename unified_attention_problem_traits::acc_dtype, typename unified_attention_problem_traits::lse_dtype, typename unified_attention_problem_traits::qkvp_dtype, typename unified_attention_problem_traits::acc_dtype, @@ -93,7 +94,7 @@ struct unified_attention_kernel_traits unified_attention_mask, unified_attention_traits>; - using unified_attention_pipeline = Blockunified_attentionFwdV3Pipeline; + using unified_attention_pipeline = UnifiedAttentionPipeline; using epilogue = Default2DEpilogue< Default2DEpilogueProblem::acc_dtype, @@ -140,7 +141,7 @@ float unified_attention_kernel_launch(const unified_attention_args& args, const args.num_seqs ); - index_t total_num_q_blocks = args.num_tokens / Kernel::BLOCK_Q + args.num_seqs + index_t total_num_q_blocks = args.num_tokens / Kernel::BLOCK_Q + args.num_seqs; dim3 grids = Kernel::GridSize2D(args.num_head_q / args.num_queries_per_kv, total_num_q_blocks); diff --git a/include/ck_tile/ops/unified_attention.hpp b/include/ck_tile/ops/unified_attention.hpp new file mode 100644 index 
0000000000..62e6c58acb --- /dev/null +++ b/include/ck_tile/ops/unified_attention.hpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + + +#include "ck_tile/ops/common/generic_2d_block_shape.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/common/utils.hpp" + +// Block-level components +#include "ck_tile/ops/unified_attention/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/unified_attention/block/block_dropout.hpp" +#include "ck_tile/ops/unified_attention/block/block_masking.hpp" +#include "ck_tile/ops/unified_attention/block/block_position_encoding.hpp" +#include "ck_tile/ops/unified_attention/block/block_rotary_embedding.hpp" +#include "ck_tile/ops/unified_attention/block/page_block_navigator.hpp" +#include "ck_tile/ops/unified_attention/block/variants.hpp" + +// Kernel-level components +#include "ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp" + +// Pipeline-level components +#include "ck_tile/ops/unified_attention/pipeline/tile_unified_attention_shape.hpp" +#include "ck_tile/ops/unified_attention/pipeline/tile_unified_attention_traits.hpp" +#include "ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp" +#include "ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_default_policy.hpp" +#include "ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_enum.hpp" +#include "ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_problem.hpp" + diff --git a/include/ck_tile/ops/unified_attention/block/block_position_encoding.hpp b/include/ck_tile/ops/unified_attention/block/block_position_encoding.hpp index 703ec0967a..3dd36a712d 100644 --- a/include/ck_tile/ops/unified_attention/block/block_position_encoding.hpp +++ b/include/ck_tile/ops/unified_attention/block/block_position_encoding.hpp @@ -4,7 +4,7 @@ #pragma once #include "ck_tile/core.hpp" -#include 
"ck_tile/ops/fmha/block/block_masking.hpp" +#include "ck_tile/ops/unified_attention/block/block_masking.hpp" #include #include diff --git a/include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_traits.hpp b/include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_traits.hpp index a285c30876..f10b064487 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_traits.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_traits.hpp @@ -4,20 +4,20 @@ #pragma once #include "ck_tile/core.hpp" -#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" -#include "ck_tile/ops/fmha/block/block_rotary_embedding.hpp" +#include "ck_tile/ops/unified_attention/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/unified_attention/block/block_rotary_embedding.hpp" namespace ck_tile { template struct TileUnifiedAttentionTraits { static constexpr bool kPadSeqLenQ = kPadSeqLenQ_; - static constexpr bool kPadHeadDim = kPadHeadDim; + static constexpr bool kPadHeadDim = kPadHeadDim_; static constexpr bool kStoreLSE = kStoreLSE_; static constexpr index_t kBlockPerCu = kBlockPerCu_; }; diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp index 5c1a91fb22..0b7d313757 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp @@ -389,7 +389,6 @@ struct UnifiedAttentionPipeline index_t num_queries_per_kv, const void* block_tables_ptr, index_t block_table_offset, - const LSEElementFunction& lse_element_func, [[maybe_unused]] const SAccElementFunction& s_acc_element_func, const PComputeElementFunction& p_compute_element_func, const OAccElementFunction& o_acc_element_func, @@ -564,7 +563,8 @@ struct UnifiedAttentionPipeline } index_t i_total_loops = 0; - index_t kv_blk_idx 
= block_tables_ptr[block_table_offset + i_total_loops]; + const ck_tile::index_t* block_tables_ptr_ = reinterpret_cast<const ck_tile::index_t*>(block_tables_ptr); + index_t kv_blk_idx = block_tables_ptr_[block_table_offset + i_total_loops]; index_t kv_blk_idx_prev = 0; @@ -674,11 +674,7 @@ struct UnifiedAttentionPipeline async_load_tile_raw(k_lds_window_store(k_lds_write_idx), k_dram_window); // TODO maybe needs i_total_loops as argument. Or maybe needs to use the k_lds_write_idx as the index /// FIXME: use the future-predicting method to move the window - // move K tile windows - auto k_dram_window = make_tile_window(k_dram_window.get_bottom_tensor_view(), - k_dram_window.get_window_lengths(), - {(block_tables_ptr[block_table_offset + i_total_loops]) * BLOCK_SIZE, 0}, - Policy::template MakeVDramTileDistribution()); + k_dram_window.set_window_origin({kv_blk_idx * BLOCK_SIZE, 0}); }; auto K_lds_load = [&](auto k_lds_read_idx) { @@ -687,12 +683,9 @@ struct UnifiedAttentionPipeline auto V_mem_load = [&](auto v_lds_write_idx) { async_load_tile_raw(v_lds_window_store(v_lds_write_idx), v_dram_window); - - /// FIXME: use the future-predicting method to move the window - auto v_dram_window = make_tile_window(v_dram_window.get_bottom_tensor_view(), - v_dram_window.get_window_lengths(), - {(block_tables_ptr[block_table_offset + i_total_loops]) * BLOCK_SIZE, 0}, - Policy::template MakeVDramTileDistribution()); + // kv_blk_idx = block_tables_ptr_[block_table_offset + i_total_loops]; + /// FIXME: use the future-predicting method to move the window + v_dram_window.set_window_origin({kv_blk_idx * BLOCK_SIZE, 0}); }; auto V_lds_load = [&](auto v_lds_read_idx) {