Add unified attention (42_unified_attention)

Squashed from the aghamari/unified-attention-decode-opt branch.

A CK tile paged-KV attention kernel optimized for decode, with a 4-tier
dispatch (tiny/small/medium/large), 16x16 MFMA, a 2D decode grid, and
head-group merging. Supports hdim=64 GQA-8 and hdim=128 MHA with
block_size=32.
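
As a rough sketch of the 4-tier dispatch idea, a host-side selector might
bucket a request by its context length before launching the matching kernel
variant. The tier thresholds and the select_decode_tier helper below are
illustrative assumptions, not the values or names used by the kernel:

    // Illustrative only: tier names mirror the commit message, but the
    // thresholds and the select_decode_tier helper are assumptions.
    enum class DecodeTier { Tiny, Small, Medium, Large };

    inline DecodeTier select_decode_tier(int context_len)
    {
        // Short contexts take the cheapest variant; longer contexts move to
        // tiers that spread the KV pages over a wider 2D decode grid.
        if(context_len <= 256)  return DecodeTier::Tiny;
        if(context_len <= 1024) return DecodeTier::Small;
        if(context_len <= 8192) return DecodeTier::Medium;
        return DecodeTier::Large;
    }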

Made-with: Cursor
root
2026-04-01 16:24:53 +00:00
parent ec2db01e4a
commit cd7ba6e2e8
7 changed files with 19 additions and 455 deletions


@@ -0,0 +1,14 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/unified_attention/block/block_masking.hpp"
#include "ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp"
#include "ck_tile/ops/unified_attention/pipeline/tile_unified_attention_shape.hpp"
#include "ck_tile/ops/unified_attention/pipeline/tile_unified_attention_traits.hpp"
#include "ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp"
#include "ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_default_policy.hpp"
#include "ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_problem.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"


@@ -615,7 +615,7 @@ struct UnifiedAttentionPipeline
}
else
{
-auto casted = detail::cvt_pk_bf16_f32(x, y);
+auto casted = cvt_pk_bf16_f32(x, y);
sp(sp_reg_idx).p.thread_buf_[idx] = casted.x;
sp(sp_reg_idx).p.thread_buf_[idx + 1] = casted.y;
}
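
For reference, cvt_pk_bf16_f32 packs two f32 lanes into a bf16 pair before
they are stored into adjacent thread-buffer slots. A minimal stand-alone
sketch of such a packed conversion, assuming round-to-nearest-even
truncation (the bf16x2 struct and cvt_pk_bf16_f32_ref helper below are
illustrative, not the CK tile implementation, which is often lowered to a
single packed-convert instruction where the hardware provides one):

    // Illustrative reference for a packed f32 -> bf16 conversion; not the
    // CK tile implementation.
    #include <cstdint>
    #include <cstring>

    struct bf16x2
    {
        uint16_t x;
        uint16_t y;
    };

    inline uint16_t f32_to_bf16_rne(float f)
    {
        uint32_t bits;
        std::memcpy(&bits, &f, sizeof(bits));
        // Round to nearest even: bias by 0x7FFF plus the lowest kept bit,
        // then keep the upper 16 bits as the bf16 pattern.
        // (NaN special-casing omitted for brevity.)
        const uint32_t lsb = (bits >> 16) & 1u;
        return static_cast<uint16_t>((bits + 0x7FFFu + lsb) >> 16);
    }

    inline bf16x2 cvt_pk_bf16_f32_ref(float x, float y)
    {
        return bf16x2{f32_to_bf16_rne(x), f32_to_bf16_rne(y)};
    }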