Add unified attention (42_unified_attention)
Squashed from the aghamari/unified-attention-decode-opt branch. Adds a CK tile paged-KV attention kernel optimized for decode: 4-tier dispatch (tiny/small/medium/large), 16x16 MFMA, a 2D decode grid, and head-group merging. Supports hdim=64 GQA-8 and hdim=128 MHA, both with block_size=32.

Made-with: Cursor
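As a mental model for the 4-tier dispatch named above (the tier names come from the commit message; the seqlen_kv criterion, the thresholds, and pick_tier itself are hypothetical stand-ins, not the kernel's actual host logic):

    // Illustrative only: picks a kernel tier from the KV length seen by a
    // decode step. The real dispatch presumably also weighs batch size and
    // head count; none of these thresholds come from the commit.
    enum class DecodeTier { Tiny, Small, Medium, Large };

    inline DecodeTier pick_tier(int seqlen_kv)
    {
        if(seqlen_kv <= 256)  return DecodeTier::Tiny;   // threshold: guess
        if(seqlen_kv <= 1024) return DecodeTier::Small;  // threshold: guess
        if(seqlen_kv <= 8192) return DecodeTier::Medium; // threshold: guess
        return DecodeTier::Large;
    }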
include/ck_tile/ops/unified_attention.hpp (new file, +14)
@@ -0,0 +1,14 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include "ck_tile/ops/unified_attention/block/block_masking.hpp"
+#include "ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp"
+#include "ck_tile/ops/unified_attention/pipeline/tile_unified_attention_shape.hpp"
+#include "ck_tile/ops/unified_attention/pipeline/tile_unified_attention_traits.hpp"
+#include "ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp"
+#include "ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_default_policy.hpp"
+#include "ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_problem.hpp"
+#include "ck_tile/ops/common/tensor_layout.hpp"
+#include "ck_tile/ops/common/utils.hpp"
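The header is a pure umbrella: downstream code includes this single file to pull in the masking block, the kernel, and the pipeline shape/traits/policy/problem headers listed above. A minimal illustrative consumer translation unit, assuming the CK tile include root is on the include path:

    // consumer.cpp -- compile check only; real use would instantiate the
    // kernel and pipeline templates declared by the aggregated headers.
    #include "ck_tile/ops/unified_attention.hpp"

    int main() { return 0; }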
@@ -615,7 +615,7 @@ struct UnifiedAttentionPipeline
         }
         else
         {
-            auto casted = detail::cvt_pk_bf16_f32(x, y);
+            auto casted = cvt_pk_bf16_f32(x, y);
             sp(sp_reg_idx).p.thread_buf_[idx] = casted.x;
             sp(sp_reg_idx).p.thread_buf_[idx + 1] = casted.y;
         }
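The changed line swaps a namespace-qualified detail::cvt_pk_bf16_f32 for an unqualified call; both appear to be a packed two-lane f32 -> bf16 conversion whose result exposes .x/.y, as the two stores below it show. A scalar model of that semantics (assumption: round-to-nearest-even, in the style of AMD's v_cvt_pk_bf16_f32 packed instruction; this is not CK's implementation):

    #include <cstdint>
    #include <cstring>
    #include <cstdio>

    struct bf16x2 { uint16_t x, y; }; // mirrors the .x/.y access in the diff

    // Round-to-nearest-even truncation of an IEEE f32 to bf16 (keep the top
    // 16 bits). NaN inputs would need special-casing; omitted for brevity.
    static uint16_t f32_to_bf16_rne(float f)
    {
        uint32_t u;
        std::memcpy(&u, &f, sizeof(u));
        uint32_t lsb      = (u >> 16) & 1u; // last kept mantissa bit
        uint32_t rounding = 0x7FFFu + lsb;  // RNE bias
        return static_cast<uint16_t>((u + rounding) >> 16);
    }

    // Packed form: converts two f32 lanes at once, like the call site above.
    static bf16x2 cvt_pk_bf16_f32(float a, float b)
    {
        return {f32_to_bf16_rne(a), f32_to_bf16_rne(b)};
    }

    int main()
    {
        bf16x2 r = cvt_pk_bf16_f32(1.0f, 3.14159f);
        std::printf("0x%04x 0x%04x\n", r.x, r.y); // 0x3f80 0x4049
        return 0;
    }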