mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-02 13:17:36 +00:00
fixing compile errors...
This commit is contained in:
@@ -7,7 +7,7 @@
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_kernel_traits<unified_attention_args::data_type_enum::bf16, false, true>;
|
||||
unified_attention_kernel_traits<unified_attention_args::data_type_enum::bf16, true>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_kernel_traits<unified_attention_args::data_type_enum::fp16, false, false>;
|
||||
unified_attention_kernel_traits<unified_attention_args::data_type_enum::fp16, false>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
#include <string>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha.hpp"
|
||||
#include "ck_tile/ops/unified_attention.hpp"
|
||||
|
||||
// keep this in sync with ck_tile::GenericAttentionMaskEnum
|
||||
enum class mask_enum
|
||||
|
||||
@@ -85,6 +85,7 @@ struct unified_attention_kernel_traits
|
||||
typename unified_attention_problem_traits<date_type>::qkvp_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::acc_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::acc_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::acc_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::lse_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::qkvp_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::acc_dtype,
|
||||
@@ -93,7 +94,7 @@ struct unified_attention_kernel_traits
|
||||
unified_attention_mask,
|
||||
unified_attention_traits>;
|
||||
|
||||
using unified_attention_pipeline = Blockunified_attentionFwdV3Pipeline<unified_attention_pipeline_problem>;
|
||||
using unified_attention_pipeline = UnifiedAttentionPipeline<unified_attention_pipeline_problem>;
|
||||
|
||||
using epilogue = Default2DEpilogue<
|
||||
Default2DEpilogueProblem<typename unified_attention_problem_traits<date_type>::acc_dtype,
|
||||
@@ -140,7 +141,7 @@ float unified_attention_kernel_launch(const unified_attention_args& args, const
|
||||
args.num_seqs
|
||||
);
|
||||
|
||||
index_t total_num_q_blocks = args.num_tokens / Kernel::BLOCK_Q + args.num_seqs
|
||||
index_t total_num_q_blocks = args.num_tokens / Kernel::BLOCK_Q + args.num_seqs;
|
||||
|
||||
|
||||
dim3 grids = Kernel::GridSize2D(args.num_head_q / args.num_queries_per_kv, total_num_q_blocks);
|
||||
|
||||
30
include/ck_tile/ops/unified_attention.hpp
Normal file
30
include/ck_tile/ops/unified_attention.hpp
Normal file
@@ -0,0 +1,30 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
|
||||
// Block-level components
|
||||
#include "ck_tile/ops/unified_attention/block/block_attention_bias_enum.hpp"
|
||||
#include "ck_tile/ops/unified_attention/block/block_dropout.hpp"
|
||||
#include "ck_tile/ops/unified_attention/block/block_masking.hpp"
|
||||
#include "ck_tile/ops/unified_attention/block/block_position_encoding.hpp"
|
||||
#include "ck_tile/ops/unified_attention/block/block_rotary_embedding.hpp"
|
||||
#include "ck_tile/ops/unified_attention/block/page_block_navigator.hpp"
|
||||
#include "ck_tile/ops/unified_attention/block/variants.hpp"
|
||||
|
||||
// Kernel-level components
|
||||
#include "ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp"
|
||||
|
||||
// Pipeline-level components
|
||||
#include "ck_tile/ops/unified_attention/pipeline/tile_unified_attention_shape.hpp"
|
||||
#include "ck_tile/ops/unified_attention/pipeline/tile_unified_attention_traits.hpp"
|
||||
#include "ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp"
|
||||
#include "ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_default_policy.hpp"
|
||||
#include "ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_enum.hpp"
|
||||
#include "ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_problem.hpp"
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_masking.hpp"
|
||||
#include "ck_tile/ops/unified_attention/block/block_masking.hpp"
|
||||
#include <cmath>
|
||||
#include <vector>
|
||||
|
||||
|
||||
@@ -4,20 +4,20 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_rotary_embedding.hpp"
|
||||
#include "ck_tile/ops/unified_attention/block/block_attention_bias_enum.hpp"
|
||||
#include "ck_tile/ops/unified_attention/block/block_rotary_embedding.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
|
||||
template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
|
||||
bool kPadHeadDim /* paddding for hdim_v */,
|
||||
bool kPadHeadDim_ /* paddding for hdim_v */,
|
||||
bool kStoreLSE_,
|
||||
index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
|
||||
struct TileUnifiedAttentionTraits
|
||||
{
|
||||
static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
|
||||
static constexpr bool kPadHeadDim = kPadHeadDim;
|
||||
static constexpr bool kPadHeadDim = kPadHeadDim_;
|
||||
static constexpr bool kStoreLSE = kStoreLSE_;
|
||||
static constexpr index_t kBlockPerCu = kBlockPerCu_;
|
||||
};
|
||||
|
||||
@@ -389,7 +389,6 @@ struct UnifiedAttentionPipeline
|
||||
index_t num_queries_per_kv,
|
||||
const void* block_tables_ptr,
|
||||
index_t block_table_offset,
|
||||
const LSEElementFunction& lse_element_func,
|
||||
[[maybe_unused]] const SAccElementFunction& s_acc_element_func,
|
||||
const PComputeElementFunction& p_compute_element_func,
|
||||
const OAccElementFunction& o_acc_element_func,
|
||||
@@ -564,7 +563,8 @@ struct UnifiedAttentionPipeline
|
||||
}
|
||||
|
||||
index_t i_total_loops = 0;
|
||||
index_t kv_blk_idx = block_tables_ptr[block_table_offset + i_total_loops];
|
||||
const ck_tile::index_t* block_tables_ptr_ = reinterpret_cast<const ck_tile::index_t*>(block_tables_ptr);
|
||||
index_t kv_blk_idx = block_tables_ptr_[block_table_offset + i_total_loops];
|
||||
index_t kv_blk_idx_prev = 0;
|
||||
|
||||
|
||||
@@ -674,11 +674,7 @@ struct UnifiedAttentionPipeline
|
||||
async_load_tile_raw(k_lds_window_store(k_lds_write_idx), k_dram_window);
|
||||
// TODO maybe needs i_total_loops as argument. Or maybe needs to use the k_lds_write_idx as the index
|
||||
/// FIXME: use the future-predicting method to move the window
|
||||
// move K tile windows
|
||||
auto k_dram_window = make_tile_window(k_dram_window.get_bottom_tensor_view(),
|
||||
k_dram_window.get_window_lengths(),
|
||||
{(block_tables_ptr[block_table_offset + i_total_loops]) * BLOCK_SIZE, 0},
|
||||
Policy::template MakeVDramTileDistribution<Problem>());
|
||||
k_dram_window.set_window_origin({kv_blk_idx * BLOCK_SIZE, 0});
|
||||
};
|
||||
|
||||
auto K_lds_load = [&](auto k_lds_read_idx) {
|
||||
@@ -687,12 +683,9 @@ struct UnifiedAttentionPipeline
|
||||
|
||||
auto V_mem_load = [&](auto v_lds_write_idx) {
|
||||
async_load_tile_raw(v_lds_window_store(v_lds_write_idx), v_dram_window);
|
||||
|
||||
/// FIXME: use the future-predicting method to move the window
|
||||
auto v_dram_window = make_tile_window(v_dram_window.get_bottom_tensor_view(),
|
||||
v_dram_window.get_window_lengths(),
|
||||
{(block_tables_ptr[block_table_offset + i_total_loops]) * BLOCK_SIZE, 0},
|
||||
Policy::template MakeVDramTileDistribution<Problem>());
|
||||
// kv_blk_idx = block_tables_ptr_[block_table_offset + i_total_loops];
|
||||
/// FIXME: use the future-predicting method to move the window
|
||||
v_dram_window.set_window_origin({kv_blk_idx * BLOCK_SIZE, 0});
|
||||
};
|
||||
|
||||
auto V_lds_load = [&](auto v_lds_read_idx) {
|
||||
|
||||
Reference in New Issue
Block a user