fixing compile errors...

This commit is contained in:
Juuso Korhonen
2025-10-20 15:04:47 +00:00
parent 97e7527eb1
commit d68a541c19
8 changed files with 47 additions and 23 deletions

View File

@@ -7,7 +7,7 @@
namespace ck_tile {
using kernel_traits =
unified_attention_kernel_traits<unified_attention_args::data_type_enum::bf16, false, true>;
unified_attention_kernel_traits<unified_attention_args::data_type_enum::bf16, true>;
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)

View File

@@ -7,7 +7,7 @@
namespace ck_tile {
using kernel_traits =
unified_attention_kernel_traits<unified_attention_args::data_type_enum::fp16, false, false>;
unified_attention_kernel_traits<unified_attention_args::data_type_enum::fp16, false>;
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)

View File

@@ -7,7 +7,7 @@
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha.hpp"
#include "ck_tile/ops/unified_attention.hpp"
// keep this in sync with ck_tile::GenericAttentionMaskEnum
enum class mask_enum

View File

@@ -85,6 +85,7 @@ struct unified_attention_kernel_traits
typename unified_attention_problem_traits<date_type>::qkvp_dtype,
typename unified_attention_problem_traits<date_type>::acc_dtype,
typename unified_attention_problem_traits<date_type>::acc_dtype,
typename unified_attention_problem_traits<date_type>::acc_dtype,
typename unified_attention_problem_traits<date_type>::lse_dtype,
typename unified_attention_problem_traits<date_type>::qkvp_dtype,
typename unified_attention_problem_traits<date_type>::acc_dtype,
@@ -93,7 +94,7 @@ struct unified_attention_kernel_traits
unified_attention_mask,
unified_attention_traits>;
using unified_attention_pipeline = Blockunified_attentionFwdV3Pipeline<unified_attention_pipeline_problem>;
using unified_attention_pipeline = UnifiedAttentionPipeline<unified_attention_pipeline_problem>;
using epilogue = Default2DEpilogue<
Default2DEpilogueProblem<typename unified_attention_problem_traits<date_type>::acc_dtype,
@@ -140,7 +141,7 @@ float unified_attention_kernel_launch(const unified_attention_args& args, const
args.num_seqs
);
index_t total_num_q_blocks = args.num_tokens / Kernel::BLOCK_Q + args.num_seqs
index_t total_num_q_blocks = args.num_tokens / Kernel::BLOCK_Q + args.num_seqs;
dim3 grids = Kernel::GridSize2D(args.num_head_q / args.num_queries_per_kv, total_num_q_blocks);

View File

@@ -0,0 +1,30 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
// Block-level components
#include "ck_tile/ops/unified_attention/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/unified_attention/block/block_dropout.hpp"
#include "ck_tile/ops/unified_attention/block/block_masking.hpp"
#include "ck_tile/ops/unified_attention/block/block_position_encoding.hpp"
#include "ck_tile/ops/unified_attention/block/block_rotary_embedding.hpp"
#include "ck_tile/ops/unified_attention/block/page_block_navigator.hpp"
#include "ck_tile/ops/unified_attention/block/variants.hpp"
// Kernel-level components
#include "ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp"
// Pipeline-level components
#include "ck_tile/ops/unified_attention/pipeline/tile_unified_attention_shape.hpp"
#include "ck_tile/ops/unified_attention/pipeline/tile_unified_attention_traits.hpp"
#include "ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp"
#include "ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_default_policy.hpp"
#include "ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_enum.hpp"
#include "ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_problem.hpp"

View File

@@ -4,7 +4,7 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_masking.hpp"
#include "ck_tile/ops/unified_attention/block/block_masking.hpp"
#include <cmath>
#include <vector>

View File

@@ -4,20 +4,20 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/block_rotary_embedding.hpp"
#include "ck_tile/ops/unified_attention/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/unified_attention/block/block_rotary_embedding.hpp"
namespace ck_tile {
template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
bool kPadHeadDim /* paddding for hdim_v */,
bool kPadHeadDim_ /* paddding for hdim_v */,
bool kStoreLSE_,
index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
struct TileUnifiedAttentionTraits
{
static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
static constexpr bool kPadHeadDim = kPadHeadDim;
static constexpr bool kPadHeadDim = kPadHeadDim_;
static constexpr bool kStoreLSE = kStoreLSE_;
static constexpr index_t kBlockPerCu = kBlockPerCu_;
};

View File

@@ -389,7 +389,6 @@ struct UnifiedAttentionPipeline
index_t num_queries_per_kv,
const void* block_tables_ptr,
index_t block_table_offset,
const LSEElementFunction& lse_element_func,
[[maybe_unused]] const SAccElementFunction& s_acc_element_func,
const PComputeElementFunction& p_compute_element_func,
const OAccElementFunction& o_acc_element_func,
@@ -564,7 +563,8 @@ struct UnifiedAttentionPipeline
}
index_t i_total_loops = 0;
index_t kv_blk_idx = block_tables_ptr[block_table_offset + i_total_loops];
const ck_tile::index_t* block_tables_ptr_ = reinterpret_cast<const ck_tile::index_t*>(block_tables_ptr);
index_t kv_blk_idx = block_tables_ptr_[block_table_offset + i_total_loops];
index_t kv_blk_idx_prev = 0;
@@ -674,11 +674,7 @@ struct UnifiedAttentionPipeline
async_load_tile_raw(k_lds_window_store(k_lds_write_idx), k_dram_window);
// TODO maybe needs i_total_loops as argument. Or maybe needs to use the k_lds_write_idx as the index
/// FIXME: use the future-predicting method to move the window
// move K tile windows
auto k_dram_window = make_tile_window(k_dram_window.get_bottom_tensor_view(),
k_dram_window.get_window_lengths(),
{(block_tables_ptr[block_table_offset + i_total_loops]) * BLOCK_SIZE, 0},
Policy::template MakeVDramTileDistribution<Problem>());
k_dram_window.set_window_origin({kv_blk_idx * BLOCK_SIZE, 0});
};
auto K_lds_load = [&](auto k_lds_read_idx) {
@@ -687,12 +683,9 @@ struct UnifiedAttentionPipeline
auto V_mem_load = [&](auto v_lds_write_idx) {
async_load_tile_raw(v_lds_window_store(v_lds_write_idx), v_dram_window);
/// FIXME: use the future-predicting method to move the window
auto v_dram_window = make_tile_window(v_dram_window.get_bottom_tensor_view(),
v_dram_window.get_window_lengths(),
{(block_tables_ptr[block_table_offset + i_total_loops]) * BLOCK_SIZE, 0},
Policy::template MakeVDramTileDistribution<Problem>());
// kv_blk_idx = block_tables_ptr_[block_table_offset + i_total_loops];
/// FIXME: use the future-predicting method to move the window
v_dram_window.set_window_origin({kv_blk_idx * BLOCK_SIZE, 0});
};
auto V_lds_load = [&](auto v_lds_read_idx) {