Add unified attention (42_unified_attention) and topk_softmax_decode

Squashed from aghamari/unified-attention-decode-opt branch.

42_unified_attention: CK tile paged-KV attention kernel optimized for
decode with 4-tier dispatch (tiny/small/medium/large), 16x16 MFMA,
2D decode grid, head-group merging. Supports hdim=64 GQA-8 and
hdim=128 MHA with block_size=32.
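A hypothetical sketch of the tier selection (the thresholds and the helper are
illustrative assumptions, not code from this commit; only the three policy
structs named in the comments exist, in the pipeline-policy header below):

// Hypothetical sketch: the real tier cut-offs live in the suppressed
// instantiation files.
enum class Tier { Tiny, Small, Medium, Large };
inline Tier pick_tier(int seq_len_k)
{
    if(seq_len_k <= 128)  return Tier::Tiny;   // UnifiedAttentionPipelineTinyDecodePolicy, 1 warp/group
    if(seq_len_k <= 512)  return Tier::Small;  // UnifiedAttentionPipelineDecodePolicy, 2 warps/group
    if(seq_len_k <= 2048) return Tier::Medium;
    return Tier::Large;                        // UnifiedAttentionPipelineDefaultPolicy, 4 warps/group
}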

topk_softmax_decode: fused topk + softmax kernel for M=1 MoE decode.
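For reference, a scalar sketch of the semantics such a kernel typically fuses
(an assumption: softmax over all expert logits, then top-k of the resulting
probabilities; the fused kernel's actual signature is in the suppressed diff):

#include <algorithm>
#include <cmath>
#include <numeric>
#include <vector>

// Reference (unfused) topk + softmax for a single token (M=1) of MoE routing.
void topk_softmax_ref(const std::vector<float>& logits, int k,
                      std::vector<int>& topk_id, std::vector<float>& topk_w)
{
    const float m = *std::max_element(logits.begin(), logits.end());
    std::vector<float> p(logits.size());
    float sum = 0.f;
    for(size_t i = 0; i < p.size(); ++i) { p[i] = std::exp(logits[i] - m); sum += p[i]; }
    for(float& v : p) v /= sum; // softmax over num_experts
    std::vector<int> idx(p.size());
    std::iota(idx.begin(), idx.end(), 0);
    std::partial_sort(idx.begin(), idx.begin() + k, idx.end(),
                      [&](int a, int b) { return p[a] > p[b]; });
    topk_id.assign(idx.begin(), idx.begin() + k); // selected expert ids
    topk_w.resize(k);
    for(int i = 0; i < k; ++i) topk_w[i] = p[topk_id[i]]; // their routing weights
}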

Made-with: Cursor
Author: root
2026-04-01 16:24:04 +00:00
parent 2bb69a24ea
commit 4c5e290378
67 changed files with 6469 additions and 3 deletions

@@ -0,0 +1,313 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
enum struct GenericAttentionMaskEnum
{
NO_MASK = 0,
// below enum could be causal, or sliding window
MASK_FROM_TOP_LEFT = 1,
MASK_FROM_BOTTOM_RIGHT = 2,
// this enum may not be used by xformers/FA, since it's hard to
// specify a left/right window for the varlen case. kept here for
// debugging purposes
MASK_GENERIC,
};
// clang-format off
/* generic Attention Mask Coordinate
use x(horizontal axis), y(vertical axis) to describe mask.
top-left corner is origin
x=1/y=5(top-left) x=4/y=5(botm-r) x=6/y=5 x=8/y=5(no mask)
1 * * * * * * * 1 1 1 1 * * * * 1 1 1 1 1 1 * * 1 1 1 1 1 1 1 1
1 1 * * * * * * 1 1 1 1 1 * * * 1 1 1 1 1 1 1 * 1 1 1 1 1 1 1 1
1 1 1 * * * * * 1 1 1 1 1 1 * * 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
1 1 1 1 * * * * 1 1 1 1 1 1 1 * 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 * * * 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
l=7,-1/r=0(tl) l=7,-1/r=0(br)
x=1/y=2 x=4/y=2 x=6/y=2 x=8/y=2
1 * * * * * * * 1 1 1 1 * * * * 1 1 1 1 1 1 * * 1 1 1 1 1 1 1 1
1 1 * * * * * * 1 1 1 1 1 * * * 1 1 1 1 1 1 1 * 1 1 1 1 1 1 1 1
* 1 1 * * * * * * 1 1 1 1 1 * * * 1 1 1 1 1 1 1 * 1 1 1 1 1 1 1
* * 1 1 * * * * * * 1 1 1 1 1 * * * 1 1 1 1 1 1 * * 1 1 1 1 1 1
* * * 1 1 * * * * * * 1 1 1 1 1 * * * 1 1 1 1 1 * * * 1 1 1 1 1
l=1/r=0(tl) l=1/r=3(tl) l=1/r=5(tl) l=1/r=7(tl)
l=4/r=0(br) l=4/r=2(br) l=4/r=4(br)
x=4/y=-1 x=6/y=-1 x=8/y=-1
* * 1 1 * * * * * * 1 1 1 1 * * * * 1 1 1 1 1 1
* * * 1 1 * * * * * * 1 1 1 1 * * * * 1 1 1 1 1
* * * * 1 1 * * * * * * 1 1 1 1 * * * * 1 1 1 1
* * * * * 1 1 * * * * * * 1 1 1 * * * * * 1 1 1
* * * * * * 1 1 * * * * * * 1 1 * * * * * * 1 1
x=-2/y=5 x=1/y=5(top-left) x=0/y=5(botm-r)
* * * * * * * * 1 * * * * * * *
* * * * * * * * 1 1 * * 1 * * *
* * * * * * * * 1 1 1 * 1 1 * *
1 * * * * * * * 1 1 1 1 1 1 1 *
1 1 * * * * * * 1 1 1 1 1 1 1 1
Validations:
x + y > 1 (x + y >= 2)
Note:
y = seq_q, x = 1 -> top-left
y = seq_q, x = seq_k - seq_q + 1 -> bottom-right
y < seq_q, x < seq_k -> local-attn
y = seq_q, x = seq_k -> no mask
*/
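/* Worked example (for illustration): for the 5x8 causal case above,
   make_generic_attention_mask_coordinates_from_lr_window(-1, 0, 5, 8, true)
   gives left_size = y_total - 1 = 4, hence y = 1 + 4 = 5 and x = 1 + 0 = 1,
   i.e. the x=1/y=5 (top-left) picture. IsOutOfBound(i_y, i_x) then masks
   i_x >= min(i_y + x, x_total) = i_y + 1, the upper-right triangle. */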
namespace impl {
template <bool IsMasking_, bool IsLocal_> struct MaskName;
template<> struct MaskName<false, false> { static constexpr const char * name = "mn"; };
template<> struct MaskName<false, true> { static constexpr const char * name = "mn"; };
template<> struct MaskName<true, false> { static constexpr const char * name = "mc"; };
template<> struct MaskName<true, true> { static constexpr const char * name = "mg"; };
}
// clang-format on
template <bool IsMasking_ = true, bool IsLocal_ = false>
struct GenericAttentionMask
{
static constexpr bool IsMasking = IsMasking_; // false will disable masking
static constexpr bool IsLocal = IsLocal_; // if true, upper/lower area could have mask,
// else only upper-right could have mask
static constexpr const char* name = impl::MaskName<IsMasking, IsLocal>::name;
// New constructor accepting repeat_idx with default value 1
CK_TILE_HOST_DEVICE
GenericAttentionMask(index_t y_total_, index_t x_total_, index_t repeat_idx_ = 1)
: GenericAttentionMask(0, 0, y_total_, x_total_, repeat_idx_)
{
}
CK_TILE_HOST_DEVICE
GenericAttentionMask(
index_t y_, index_t x_, index_t y_total_, index_t x_total_, index_t repeat_idx_ = 1)
: y(y_), x(x_), y_total(y_total_), x_total(x_total_), repeat_idx(repeat_idx_)
{
}
template <typename MaskCoordinates>
CK_TILE_HOST_DEVICE GenericAttentionMask(const MaskCoordinates& mask_coord,
index_t repeat_idx_ = 1)
: y(mask_coord.at(number<0>{})),
x(mask_coord.at(number<1>{})),
y_total(mask_coord.at(number<2>{})),
x_total(mask_coord.at(number<3>{})),
repeat_idx(repeat_idx_)
{
}
// get the loop range along the X axis; returns indices [start, end), end - start = length
// use this when looping over the X axis tile by tile (like the k-seqlen loop)
// TODO: x_end could still be negative, so end - start could be negative (caller must check)
template <index_t YTile, index_t XTile>
CK_TILE_HOST_DEVICE constexpr auto
GetTileRangeAlongX(index_t i_y, number<YTile>, number<XTile>) const
{
// Transform the y index according to repeat_idx
index_t y_eff = i_y / repeat_idx;
if constexpr(!IsMasking)
{
return ck_tile::make_tuple(0, x_total);
}
else
{
// get the tile start/end range assuming we loop over along X tile by tile
index_t x_start = [&]() {
if constexpr(IsLocal)
{
index_t tmp = max(-y + y_eff + 1, 0);
return (tmp / XTile) * XTile; // round to tile aligned
}
else
{
return 0;
}
}();
// TODO: end could be negative, we ignore clamp here, and let caller to check
// ... in which case end-start is negative
index_t x_end = [&]() {
index_t tmp = min(y_eff + YTile - 1 + x, x_total);
return ((tmp + XTile - 1) / XTile) * XTile;
}();
return ck_tile::make_tuple(x_start, x_end);
}
}
// get the loop range along the Y axis; returns indices [start, end), end - start = length
// use this when looping over the Y axis tile by tile (like the q-seqlen loop)
// Note: this function does not take a dynamic y index, so no repeat_idx transform is needed
template <index_t YTile, index_t XTile>
CK_TILE_HOST_DEVICE constexpr auto
GetTileRangeAlongY(index_t i_x, number<YTile>, number<XTile>) const
{
if constexpr(!IsMasking)
{
return ck_tile::make_tuple(0, y_total);
}
else
{
// get the tile start/end range assuming we loop over along Y tile by tile
index_t y_start = [&]() {
index_t tmp = max(-x + i_x + 1, 0);
return (tmp / YTile) * YTile; // round to tile aligned
}();
// TODO: end could be negative, we ignore clamp here, and let caller to check
// ... in which case end-start is negative
index_t y_end = [&]() {
index_t tmp = min(i_x + XTile - 1 + y, y_total);
return ((tmp + YTile - 1) / YTile) * YTile;
}();
return ck_tile::make_tuple(y_start, y_end);
}
}
// per-pixel check if out-of-bound, if true, need mask a value(like -INF)
CK_TILE_HOST_DEVICE constexpr auto IsOutOfBound(index_t i_y, index_t i_x) const
{
// Transform the y index according to repeat_idx
index_t y_eff = i_y / repeat_idx;
if constexpr(!IsMasking)
{
return i_x >= x_total;
}
else
{
// no need to do min/max here, since i_x will never be < 0 or >= x_total
index_t x_start = -y + y_eff + 1;
index_t x_end = min(y_eff + x, x_total);
if constexpr(IsLocal)
{
return i_x < x_start || i_x >= x_end;
}
else
{
return i_x >= x_end || y_eff >= y_total;
}
}
}
// if the current tile is at the edge, a per-pixel mask check is needed;
// otherwise the per-pixel check can be skipped
// Attention! assumes the indices passed to this function are within the range of GetTileRangeAlongX/Y()
// can be used as a fast-path to decide if do per-pixel check or not
template <index_t TileHeight, index_t TileWidth>
CK_TILE_HOST_DEVICE constexpr auto
IsEdgeTile(index_t i_tile_top, index_t i_tile_left, number<TileHeight>, number<TileWidth>) const
{
// Transform the y index according to repeat_idx
index_t y_eff = i_tile_top / repeat_idx;
if constexpr(!IsMasking)
{
// TODO: no need to check begin
return (i_tile_left + TileWidth) > x_total;
}
else
{
if constexpr(IsLocal)
{
// check top-right corner > x or left-bottom corner < x
index_t i_tile_right = i_tile_left + TileWidth;
index_t i_tile_bottom = y_eff + TileHeight;
index_t x_end = min(y_eff + x, x_total);
bool top_right_edge = i_tile_right > (y_eff + x);
bool bottom_left_edge = i_tile_bottom > (i_tile_left + y);
bool is_partial_out_of_bound =
i_tile_right > x_end; // only consider right-pad for now
return top_right_edge || bottom_left_edge || is_partial_out_of_bound;
}
else
{
// only need to check top-right corner > x
index_t i_tile_right = i_tile_left + TileWidth;
index_t x_end = min(y_eff + x, x_total);
bool top_right_edge = i_tile_right > x_end;
return top_right_edge;
}
}
}
private:
index_t y, x;
index_t y_total, x_total;
index_t repeat_idx;
};
template <typename>
struct is_generic_attention_mask : std::false_type
{
};
template <bool IsMasking, bool IsLocal>
struct is_generic_attention_mask<GenericAttentionMask<IsMasking, IsLocal>> : std::true_type
{
};
template <typename Mask>
static constexpr bool is_generic_attention_mask_v = is_generic_attention_mask<Mask>::value;
// TODO: prefer using this function in host code
// converts from the FA-style left/right window to our generic coordinates
// if left_size < 0 && right_size == 0, it is a normal causal mask
// local attention is left_size >= 0 or right_size >= 0
CK_TILE_HOST_DEVICE constexpr auto
make_generic_attention_mask_coordinates_from_lr_window(index_t left_size,
index_t right_size,
index_t y_total,
index_t x_total,
bool is_top_left = true)
{
// TODO: below should all use sgpr arithmetic
index_t left_size_tmp = is_top_left ? y_total - 1 : x_total - 1;
index_t right_size_tmp = is_top_left ? x_total - 1 : y_total - 1;
left_size = left_size < 0 ? left_size_tmp : left_size;
right_size = right_size < 0 ? right_size_tmp : right_size;
index_t x_tmp = is_top_left ? 0 : x_total - y_total;
index_t y_tmp = is_top_left ? 0 : y_total - x_total;
index_t x = 1 + right_size + x_tmp;
index_t y = 1 + left_size + y_tmp;
return ck_tile::make_tuple(y, x, y_total, x_total);
}
template <typename MaskType>
CK_TILE_HOST_DEVICE constexpr auto
make_generic_attention_mask_from_lr_window(index_t left_size,
index_t right_size,
index_t y_total,
index_t x_total,
index_t repeat_idx = 1,
bool is_top_left = true)
{
auto r = make_generic_attention_mask_coordinates_from_lr_window(
left_size, right_size, y_total, x_total, is_top_left);
return MaskType{r.at(number<0>{}), r.at(number<1>{}), y_total, x_total, repeat_idx};
}
} // namespace ck_tile
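
A minimal host-side sketch of the mask API defined above; the values are
illustrative and only functions from this header are used:

// Build a causal mask for seq_q = 5, seq_k = 8 (illustrative values).
using Mask = ck_tile::GenericAttentionMask<true /*IsMasking*/, false /*IsLocal*/>;
auto mask  = ck_tile::make_generic_attention_mask_from_lr_window<Mask>(
    -1 /*left*/, 0 /*right*/, 5 /*y_total*/, 8 /*x_total*/, 1 /*repeat_idx*/, true /*top-left*/);
bool masked = mask.IsOutOfBound(0 /*i_y*/, 3 /*i_x*/); // true: column 3 is above the diagonal at row 0
auto range  = mask.GetTileRangeAlongX(0, ck_tile::number<2>{} /*YTile*/, ck_tile::number<4>{} /*XTile*/);
// range.at(number<0>{}) == 0, range.at(number<1>{}) == 4: one 4-wide K tile covers rows 0..1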


@@ -0,0 +1,487 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/unified_attention/block/block_masking.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include <string>
#include <type_traits>
#include <utility>
#include <variant>
namespace ck_tile {
template <typename UnifiedAttentionPipeline_, typename EpiloguePipeline_>
struct UnifiedAttentionKernel
{
using UnifiedAttentionPipeline = ck_tile::remove_cvref_t<UnifiedAttentionPipeline_>;
using EpiloguePipeline = ck_tile::remove_cvref_t<EpiloguePipeline_>;
static constexpr ck_tile::index_t kBlockSize = UnifiedAttentionPipeline::kBlockSize;
static constexpr ck_tile::index_t kBlockPerCu = UnifiedAttentionPipeline::kBlockPerCu;
static_assert(kBlockPerCu > 0);
using QDataType = ck_tile::remove_cvref_t<typename UnifiedAttentionPipeline::QDataType>;
using KDataType = ck_tile::remove_cvref_t<typename UnifiedAttentionPipeline::KDataType>;
using VDataType = ck_tile::remove_cvref_t<typename UnifiedAttentionPipeline::VDataType>;
using ODataType = ck_tile::remove_cvref_t<typename UnifiedAttentionPipeline::ODataType>;
using SaccDataType = ck_tile::remove_cvref_t<typename UnifiedAttentionPipeline::SaccDataType>;
using FmhaMask = ck_tile::remove_cvref_t<typename UnifiedAttentionPipeline::FmhaMask>;
static constexpr bool kHasMask = FmhaMask::IsMasking;
static constexpr bool kPadSeqLenK = UnifiedAttentionPipeline::kPadSeqLenK;
static constexpr bool kPadSeqLenQ = UnifiedAttentionPipeline::kPadSeqLenQ;
static constexpr bool kPadHeadDimQ = UnifiedAttentionPipeline::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = UnifiedAttentionPipeline::kPadHeadDimV;
static constexpr index_t kHeadDim = UnifiedAttentionPipeline::kHeadDim;
static constexpr index_t kHeadDimPadded = UnifiedAttentionPipeline::kHeadDimPadded;
// kBlockQ = kBlockM / num_queries_per_kv
// kBlockQ is the block size along the q seqlen
static constexpr index_t kBlockM = UnifiedAttentionPipeline::kBlockM;
static constexpr index_t kBlockQ = UnifiedAttentionPipeline::kBlockQ;
// BLOCK size for K seqlen
static constexpr index_t kPageBlockSize = UnifiedAttentionPipeline::kPageBlockSize;
// kargs use aggregate initialization, so no constructor is provided
// use inheritance to minimize karg size
// users need to use the MakeKargs() function to create kargs.
// The attention is causal by default
struct UnifiedAttentionCommonKargs
{
const void* q_ptr;
const void* k_ptr; // [num_blks, page_size, num_kv_heads, head_size]
const void* v_ptr; // [num_blks, page_size, num_kv_heads, head_size]
void* o_ptr;
ck_tile::index_t num_blks;
ck_tile::index_t num_head_q;
// for MQA/GQA, nhead could be different. This parameter is nhead_q / nhead_k
// if this param is larger than 1, indicate MQA/GQA case
const ck_tile::index_t num_queries_per_kv;
// scales
float scale_s;
float scale;
float scale_k;
float scale_v;
float scale_out;
ck_tile::index_t page_size;
ck_tile::index_t total_num_q_blocks;
ck_tile::index_t query_stride_0;
ck_tile::index_t query_stride_1;
ck_tile::index_t stride_k_cache_0;
ck_tile::index_t stride_k_cache_1;
ck_tile::index_t stride_k_cache_2;
ck_tile::index_t stride_k_cache_3;
ck_tile::index_t stride_v_cache_0;
ck_tile::index_t stride_v_cache_1;
ck_tile::index_t stride_v_cache_2;
ck_tile::index_t stride_v_cache_3;
ck_tile::index_t output_stride_0;
ck_tile::index_t output_stride_1;
};
struct UnifiedAttentionVarlenKargs : UnifiedAttentionCommonKargs
{
const int32_t* block_tables_ptr;
ck_tile::index_t block_table_stride;
const int32_t* seq_lens_ptr; // seq len in each batch
const int32_t* query_start_len_ptr; // [num_seqs+1]
ck_tile::index_t num_seqs; // number of batches for q
};
using Kargs = UnifiedAttentionVarlenKargs;
CK_TILE_HOST static constexpr Kargs MakeKargs(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
void* o_ptr,
ck_tile::index_t num_blks,
ck_tile::index_t num_head_q,
const ck_tile::index_t num_queries_per_kv,
float scale_s,
float scale,
float scale_k,
float scale_v,
float scale_out,
ck_tile::index_t page_size,
ck_tile::index_t total_num_q_blocks,
ck_tile::index_t query_stride_0,
ck_tile::index_t query_stride_1,
ck_tile::index_t stride_k_cache_0,
ck_tile::index_t stride_k_cache_1,
ck_tile::index_t stride_k_cache_2,
ck_tile::index_t stride_k_cache_3,
ck_tile::index_t stride_v_cache_0,
ck_tile::index_t stride_v_cache_1,
ck_tile::index_t stride_v_cache_2,
ck_tile::index_t stride_v_cache_3,
ck_tile::index_t output_stride_0,
ck_tile::index_t output_stride_1,
const int32_t* block_tables_ptr,
ck_tile::index_t block_table_stride,
const int32_t* seq_lens_ptr,
const int32_t* query_start_len_ptr,
ck_tile::index_t num_seqs)
{
Kargs kargs{{q_ptr,
k_ptr,
v_ptr,
o_ptr,
num_blks,
num_head_q,
num_queries_per_kv,
static_cast<float>(scale_s * ck_tile::log2e_v<>),
scale,
scale_k,
scale_v,
scale_out,
page_size,
total_num_q_blocks,
query_stride_0,
query_stride_1,
stride_k_cache_0,
stride_k_cache_1,
stride_k_cache_2,
stride_k_cache_3,
stride_v_cache_0,
stride_v_cache_1,
stride_v_cache_2,
stride_v_cache_3,
output_stride_0,
output_stride_1},
block_tables_ptr,
block_table_stride,
seq_lens_ptr,
query_start_len_ptr,
num_seqs};
return kargs;
}
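// Note: MakeKargs() above pre-multiplies scale_s by log2(e) so the softmax in
// the pipeline can work in base 2 (exp2) instead of base e.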
CK_TILE_HOST static constexpr auto GridSize2D(ck_tile::index_t num_kv_heads,
ck_tile::index_t total_num_q_blocks)
{
return dim3(num_kv_heads * total_num_q_blocks);
}
// Binary search to find the sequence index for a given target index
CK_TILE_DEVICE static constexpr ck_tile::index_t
find_seq_idx(const int32_t* query_start_len_ptr,
ck_tile::index_t target_idx,
ck_tile::index_t num_seqs,
ck_tile::index_t block_q,
bool use_q_block_mode)
{
ck_tile::index_t left = 0;
ck_tile::index_t right = num_seqs;
while(left < right)
{
ck_tile::index_t mid = (left + right) / 2;
ck_tile::index_t val = amd_wave_read_first_lane(query_start_len_ptr[mid]);
ck_tile::index_t mid_val = use_q_block_mode ? (val / block_q + mid) : val;
if(mid_val <= target_idx)
{
left = mid + 1;
}
else
{
right = mid;
}
}
return left - 1;
}
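// Worked example (illustrative): with kBlockQ = 4 and
// query_start_len_ptr = {0, 7, 9, 12} (three sequences of q-len 7, 2, 3),
// q-block mode compares against mid_val = {0/4+0, 7/4+1, 9/4+2} = {0, 2, 4}.
// find_seq_idx(..., target_idx = 1, ...) returns 0 (second q block of seq 0);
// target_idx = 3 returns 1, which the caller later rejects as a padding CTA
// via the q_block_local_idx * kBlockQ >= cur_batch_query_len check.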
CK_TILE_DEVICE static constexpr auto GetTileIndex(const ck_tile::index_t pid,
const Kargs& kargs)
{
using namespace ck_tile;
ck_tile::index_t num_head_kv = kargs.num_head_q / kargs.num_queries_per_kv;
return ck_tile::make_tuple(pid % num_head_kv, pid / num_head_kv);
}
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return ck_tile::max(UnifiedAttentionPipeline::GetSmemSize(),
EpiloguePipeline::GetSmemSize());
}
CK_TILE_HOST static constexpr auto GridSizeDecode(ck_tile::index_t num_kv_heads,
ck_tile::index_t num_seqs)
{
return dim3(num_kv_heads, num_seqs);
}
CK_TILE_DEVICE void operator()(Kargs kargs) const
{
using namespace ck_tile;
const index_t num_queries_per_kv = kargs.num_queries_per_kv;
assert(kBlockM / num_queries_per_kv == kBlockQ);
index_t kv_head_idx;
index_t seq_idx;
index_t q_block_local_idx;
index_t cur_batch_in_all_start_index;
index_t cur_batch_query_len;
if(gridDim.y > 1)
{
// Decode grid: dim3(num_kv_heads, num_seqs)
// Direct mapping, no binary search, no padding CTAs.
kv_head_idx = blockIdx.x;
seq_idx = blockIdx.y;
q_block_local_idx = 0;
cur_batch_in_all_start_index = kargs.query_start_len_ptr[seq_idx];
const index_t stop = kargs.query_start_len_ptr[seq_idx + 1];
cur_batch_query_len = amd_wave_read_first_lane(stop - cur_batch_in_all_start_index);
}
else
{
// Standard 1D grid with binary search
ck_tile::index_t pid = blockIdx.x;
const auto [kv_head_idx_, q_block_global_idx] = GetTileIndex(pid, kargs);
kv_head_idx = kv_head_idx_;
if(q_block_global_idx >= kargs.total_num_q_blocks)
{
return;
}
seq_idx = find_seq_idx(kargs.query_start_len_ptr,
q_block_global_idx,
kargs.num_seqs,
kBlockQ,
true);
const index_t q_block_start_idx =
kargs.query_start_len_ptr[seq_idx] / kBlockQ + seq_idx;
q_block_local_idx =
amd_wave_read_first_lane(q_block_global_idx - q_block_start_idx);
cur_batch_in_all_start_index = kargs.query_start_len_ptr[seq_idx];
const index_t cur_batch_in_all_stop_index = kargs.query_start_len_ptr[seq_idx + 1];
cur_batch_query_len =
amd_wave_read_first_lane(cur_batch_in_all_stop_index - cur_batch_in_all_start_index);
if(q_block_local_idx * kBlockQ >= cur_batch_query_len)
{
return;
}
}
// allocate LDS
__shared__ char smem_ptr[GetSmemSize()];
const index_t query_pos = amd_wave_read_first_lane(q_block_local_idx * kBlockQ);
const index_t seq_len = kargs.seq_lens_ptr[seq_idx];
const index_t context_len = amd_wave_read_first_lane(seq_len - cur_batch_query_len);
index_t _max_seq_prefix_len = amd_wave_read_first_lane(
(context_len + q_block_local_idx * kBlockQ + (kBlockM - 1) + 1));
if(seq_len < _max_seq_prefix_len)
{
_max_seq_prefix_len = seq_len;
}
const auto max_seq_prefix_len = _max_seq_prefix_len;
const index_t num_blocks =
amd_wave_read_first_lane((max_seq_prefix_len + kPageBlockSize - 1) / kPageBlockSize);
// TODO sliding window
const index_t num_blocks_start = 0;
long_index_t kv_head_offset = static_cast<long_index_t>(kv_head_idx) * kargs.stride_k_cache_2;
// Q/K/V DRAM and DRAM window
index_t q_ptr_offset_0 = cur_batch_in_all_start_index *
kargs.query_stride_0; // move the pointer to the batch start
index_t q_ptr_offset_1 =
kv_head_idx * num_queries_per_kv *
kargs.query_stride_1; // move the pointer to the correct head group start
index_t q_ptr_offset = q_ptr_offset_0 + q_ptr_offset_1;
index_t o_ptr_offset_0 = cur_batch_in_all_start_index *
kargs.output_stride_0; // move the pointer to the batch start
index_t o_ptr_offset_1 =
kv_head_idx * num_queries_per_kv *
kargs.output_stride_1; // move the pointer to the correct head group start
index_t o_ptr_offset = o_ptr_offset_0 + o_ptr_offset_1;
index_t block_table_offset = seq_idx * kargs.block_table_stride;
const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) + q_ptr_offset;
const KDataType* k_ptr = reinterpret_cast<const KDataType*>(kargs.k_ptr) + kv_head_offset;
const VDataType* v_ptr = reinterpret_cast<const VDataType*>(kargs.v_ptr) + kv_head_offset;
ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr) + o_ptr_offset;
index_t query_len_padded =
amd_wave_read_first_lane(integer_divide_ceil(cur_batch_query_len, kBlockQ) * kBlockQ);
// const bool is_query_len_padded = (cur_batch_query_len % kBlockQ == 0);
// Q/K/V DRAM and DRAM window
const auto q_dram = [&]() {
const auto q_dram_base = make_naive_tensor_view<address_space_enum::global>(
q_ptr,
make_tuple(cur_batch_query_len, num_queries_per_kv, kHeadDim),
make_tuple(kargs.query_stride_0, kargs.query_stride_1, 1),
number<UnifiedAttentionPipeline::kAlignmentQ>{},
number<1>{});
const auto q_dram_pad =
pad_tensor_view( // align seqlen with kBlockQ and head dim with kHeadDimPadded
q_dram_base,
// block sizes
make_tuple(number<kBlockQ>{}, 1, kHeadDimPadded),
sequence<true, false, kPadHeadDimQ>{}); // pads to (seq_len_padded, num_head_q,
// kHeadDimPadded)
const auto q_dram_merged = transform_tensor_view(
q_dram_pad,
make_tuple(make_merge_transform(make_tuple(query_len_padded, num_queries_per_kv)),
make_pass_through_transform(kHeadDimPadded)),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0>{},
sequence<1>{})); // flattens the first two dims, head idx is the fastest
// changing dim in the merged dim
return q_dram_merged;
}();
// static_assert(q_dram.desc_[number<0>{}] == 0,
// "q_dram.get_bottom_tensor_view()[number<0>{}] == 0");
// Q has the shape (k_head, seq_len, num_queries_per_kv, head_dim)
// stride for dim 0 (num_queries_per_kv * head_dim, head_dim, 1)
auto q_dram_window =
make_tile_window(q_dram,
make_tuple(number<kBlockM>{}, number<kHeadDimPadded>{}),
{query_pos * num_queries_per_kv, 0});
const auto k_dram = [&]() {
// Use long_index_t for size/strides to prevent int32 overflow
// when row * stride exceeds 2^31 (happens at ~66K blocks for d64/GQA-8).
const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
k_ptr,
make_tuple(static_cast<long_index_t>(kargs.num_blks) * kargs.page_size,
static_cast<long_index_t>(kHeadDim)),
make_tuple(static_cast<long_index_t>(kargs.stride_k_cache_1),
static_cast<long_index_t>(kargs.stride_k_cache_3)),
number<UnifiedAttentionPipeline::kAlignmentK>{},
number<1>{});
const auto k_dram_pad =
pad_tensor_view(k_dram_naive,
make_tuple(kPageBlockSize, kHeadDimPadded),
sequence<false, kPadHeadDimQ>{});
return k_dram_pad;
}();
auto k_dram_window = make_tile_window(
k_dram, make_tuple(number<kPageBlockSize>{}, number<kHeadDimPadded>{}), {0, 0});
const auto v_dram = [&]() {
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
v_ptr,
make_tuple(static_cast<long_index_t>(kargs.num_blks) * kargs.page_size,
static_cast<long_index_t>(kHeadDim)),
make_tuple(static_cast<long_index_t>(kargs.stride_v_cache_1),
static_cast<long_index_t>(kargs.stride_v_cache_3)),
number<UnifiedAttentionPipeline::kAlignmentV>{},
number<1>{});
const auto v_dram_pad = pad_tensor_view(v_dram_naive,
make_tuple(kPageBlockSize, kHeadDimPadded),
sequence<false, kPadHeadDimQ>{});
return v_dram_pad;
}();
auto v_dram_window = make_tile_window(
v_dram, make_tuple(number<kPageBlockSize>{}, number<kHeadDimPadded>{}), {0, 0});
FmhaMask mask = [&]() {
if constexpr(kHasMask)
return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
-1,
0,
cur_batch_query_len, // y_total
seq_len, // x_total
num_queries_per_kv, // the same sequence index is repeated num_queries_per_kv
// times along x dim of the tile
false);
else
return FmhaMask{cur_batch_query_len, seq_len};
}();
const index_t kv_page_size_in_blocks = kargs.page_size / kPageBlockSize;
assert(kv_page_size_in_blocks >= 1); // kPageBlockSize <= page_size
auto o_acc_tile = [&]() {
return UnifiedAttentionPipeline{}(q_dram_window,
k_dram_window,
v_dram_window,
num_blocks,
num_blocks_start,
kargs.block_tables_ptr,
block_table_offset,
kv_page_size_in_blocks,
mask,
kargs.scale_s,
smem_ptr);
}();
// O DRAM and O DRAM window
auto o_dram = [&]() {
const auto o_dram_base = make_naive_tensor_view<address_space_enum::global>(
o_ptr,
make_tuple(cur_batch_query_len, num_queries_per_kv, kHeadDim),
make_tuple(kargs.output_stride_0, kargs.output_stride_1, 1),
number<UnifiedAttentionPipeline::kAlignmentO>{},
number<1>{});
const auto o_dram_pad =
pad_tensor_view( // align cu_seqlen with kBlockQ and head dim with kHeadDimPadded
o_dram_base,
// block sizes
make_tuple(kBlockQ, 1, kHeadDimPadded),
sequence<true, false, kPadHeadDimQ>{}); // pads to (seq_len_padded, num_head_q,
// kHeadDimPadded)
const auto o_dram_merged = transform_tensor_view(
o_dram_pad,
make_tuple(make_merge_transform(make_tuple(query_len_padded, num_queries_per_kv)),
make_pass_through_transform(kHeadDimPadded)),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return o_dram_merged;
}();
auto o_dram_window =
make_tile_window(o_dram,
make_tuple(number<kBlockM>{}, number<kHeadDimPadded>{}),
{query_pos * num_queries_per_kv, 0});
EpiloguePipeline{}(o_dram_window, o_acc_tile, nullptr);
}
};
} // namespace ck_tile
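
A launch-side sketch of how the two grids pair with the kernel's gridDim.y
check; the is_pure_decode condition and the Pipeline/Epilogue instantiations
are assumptions for illustration:

template <typename Pipeline, typename Epilogue>
dim3 pick_grid(bool is_pure_decode, // e.g. every batch has query len 1 (assumption)
               ck_tile::index_t num_kv_heads,
               ck_tile::index_t num_seqs,
               ck_tile::index_t total_num_q_blocks)
{
    using Kernel = ck_tile::UnifiedAttentionKernel<Pipeline, Epilogue>;
    // The kernel branches on gridDim.y > 1, so the decode grid must be 2D and
    // the unified grid 1D; note that with num_seqs == 1 even the decode grid
    // has gridDim.y == 1 and the kernel takes the 1D (binary-search) path.
    return is_pure_decode
               ? Kernel::GridSizeDecode(num_kv_heads, num_seqs)        // dim3(num_kv_heads, num_seqs)
               : Kernel::GridSize2D(num_kv_heads, total_num_q_blocks); // dim3(num_kv_heads * total_num_q_blocks)
}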


@@ -0,0 +1,72 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
template <index_t Headdim>
static CK_TILE_HOST_DEVICE constexpr index_t ceil_to_qualified_tile_length()
{
if constexpr(Headdim == 48)
return 48;
else if constexpr(Headdim == 96)
return 128;
else if constexpr(Headdim == 160)
return 256;
else if constexpr(Headdim == 192)
return 192;
else if constexpr(is_power_of_two_integer(Headdim))
return Headdim;
else
static_assert(Headdim == 0,
"only Headdim of 48, 96, 160, 192 and power-of-two is supported");
}
template <typename BlockTile_, // sequence<...
typename Gemm0BlockWarps_,
typename Gemm0WarpTile_,
typename Gemm1BlockWarps_,
typename Gemm1WarpTile_,
bool IsVLayoutRowMajor_>
struct TileUnifiedAttentionShape
{
using BlockTile = remove_cvref_t<BlockTile_>;
using Gemm0BlockWarps = remove_cvref_t<Gemm0BlockWarps_>;
using Gemm0WarpTile = remove_cvref_t<Gemm0WarpTile_>;
using Gemm1BlockWarps = remove_cvref_t<Gemm1BlockWarps_>;
using Gemm1WarpTile = remove_cvref_t<Gemm1WarpTile_>;
static constexpr index_t NumGemm0Warps =
reduce_on_sequence(Gemm0BlockWarps{}, multiplies<>{}, number<1>{});
static constexpr index_t NumGemm1Warps =
reduce_on_sequence(Gemm1BlockWarps{}, multiplies<>{}, number<1>{});
static_assert(NumGemm1Warps % NumGemm0Warps == 0);
static constexpr index_t NumWarps = max(NumGemm0Warps, NumGemm1Warps);
static constexpr index_t kBlockM = BlockTile::at(
number<0>{}); // tile size along the flattened batch dimension (num_queries_per_kv * kBlockQ)
static constexpr index_t kBlockQ = BlockTile::at(
number<1>{}); // tile size along the q seqlen dimension
static constexpr index_t kPageBlockSize =
BlockTile::at(number<2>{}); // block size along the K seqlen
static constexpr index_t kHeadDim = BlockTile::at(number<3>{}); // head dimension
static constexpr index_t kHeadDimPadded = ceil_to_qualified_tile_length<kHeadDim>();
// v, rowmajor : seqlen*hdim, colmajor : hdim*seqlen
static constexpr bool IsVLayoutRowMajor = IsVLayoutRowMajor_;
using VLayout = std::conditional_t<IsVLayoutRowMajor,
ck_tile::tensor_layout::gemm::RowMajor,
ck_tile::tensor_layout::gemm::ColumnMajor>;
};
} // namespace ck_tile
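
An illustrative instantiation of the shape above; the values are assumptions
chosen to satisfy the static checks (16x16 MFMA warp tiles and block_size=32
per the commit message), not the tuned configs from the suppressed files:

// MHA-style example: num_queries_per_kv = 1, so kBlockQ == kBlockM.
using Shape = ck_tile::TileUnifiedAttentionShape<
    ck_tile::sequence<16, 16, 32, 128>, // kBlockM, kBlockQ, kPageBlockSize, kHeadDim
    ck_tile::sequence<1, 1>,            // Gemm0BlockWarps
    ck_tile::sequence<16, 16, 16>,      // Gemm0WarpTile (16x16 MFMA)
    ck_tile::sequence<1, 1>,            // Gemm1BlockWarps
    ck_tile::sequence<16, 16, 16>,      // Gemm1WarpTile
    true>;                              // IsVLayoutRowMajor
static_assert(Shape::kHeadDimPadded == 128); // power-of-two hdim keeps its length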


@@ -0,0 +1,19 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
bool kPadHeadDim_ /* padding for hdim_v */,
index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
struct TileUnifiedAttentionTraits
{
static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
static constexpr bool kPadHeadDim = kPadHeadDim_;
static constexpr index_t kBlockPerCu = kBlockPerCu_;
};
} // namespace ck_tile

File diff suppressed because it is too large.


@@ -0,0 +1,613 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
namespace ck_tile {
struct UnifiedAttentionPipelineDefaultPolicy
{
static constexpr ck_tile::index_t NumWarpPerGroup = 4;
static constexpr ck_tile::index_t NumThreadPerWarpGroup =
NumWarpPerGroup * ck_tile::get_warp_size();
// TODO: GetAlignment*() currently does not consider whether padding is needed,
// so the pipeline still needs to check the padding requirement
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ()
{
constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::QDataType);
using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
return min(MaxVectorSize, WG::kK / WG::WarpGemmAttribute::Impl::kABKLane);
}
template <typename Problem>
CK_TILE_DEVICE static constexpr auto GetAlignmentK()
{
using namespace ck_tile;
using KDataType = remove_cvref_t<typename Problem::KDataType>;
#if defined(__gfx950__)
constexpr index_t MaxReadSizeInBytes = 16;
#else
constexpr index_t MaxReadSizeInBytes = 4;
#endif
return MaxReadSizeInBytes / sizeof(KDataType);
}
template <typename Problem>
CK_TILE_DEVICE static constexpr auto GetAlignmentV()
{
using namespace ck_tile;
using VDataType = remove_cvref_t<typename Problem::VDataType>;
#if defined(__gfx950__)
constexpr index_t MaxReadSizeInBytes = 16;
#else
constexpr index_t MaxReadSizeInBytes = 4;
#endif
return MaxReadSizeInBytes / sizeof(VDataType);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentO()
{
using BlockGemm = remove_cvref_t<decltype(GetPVBlockGemm<Problem>())>;
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
return WG::WarpGemmAttribute::Impl::kCM1PerLane;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackK()
{
using namespace ck_tile;
// TODO: this is for 3d layout
using KDataType = remove_cvref_t<typename Problem::KDataType>;
return 16 / sizeof(KDataType);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemVPackK()
{
using namespace ck_tile;
// TODO: this is for 3d layout
using VDataType = remove_cvref_t<typename Problem::VDataType>;
return 16 / sizeof(VDataType);
}
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeKDramTileDistribution()
{
using namespace ck_tile;
constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kPageBlockSize;
constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kHeadDim;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t NumWarps = Problem::UnifiedAttentionShape::NumWarps;
constexpr index_t WarpSize = ck_tile::get_warp_size();
constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0);
constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave
constexpr index_t LaneGroups = WarpSize / LanesPerK; // within a wave
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
constexpr index_t N0 = NumIssues;
constexpr index_t N1 = LaneGroups;
constexpr index_t N2 = NumWarps;
constexpr index_t K0 = LanesPerK;
constexpr index_t K1 = KVector;
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<2>, sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeVDramTileDistribution()
{
using namespace ck_tile;
constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kPageBlockSize;
constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kHeadDim;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t NumWarps = Problem::UnifiedAttentionShape::NumWarps;
constexpr index_t WarpSize = ck_tile::get_warp_size(); // 64
constexpr index_t KVector = GetAlignmentV<Problem>(); // this is for global load
// 4
static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0);
constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave
constexpr index_t LaneGroups = WarpSize / LanesPerK; // within a wave
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
constexpr index_t N0 = NumIssues; // 8
constexpr index_t N1 = LaneGroups; // 2
constexpr index_t N2 = NumWarps; // 8
constexpr index_t K0 = LanesPerK; // 32
constexpr index_t K1 = KVector; // 4
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<2>, sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeQRegTileDistribution()
{
using namespace ck_tile;
using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
return make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
}
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeKRegTileDistribution()
{
using namespace ck_tile;
using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
return make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode());
}
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakePRegTileDistribution()
{
using namespace ck_tile;
using BlockGemm = remove_cvref_t<decltype(GetPVBlockGemm<Problem>())>;
return make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
}
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeVRegTileDistribution()
{
using namespace ck_tile;
using BlockGemm = remove_cvref_t<decltype(GetPVBlockGemm<Problem>())>;
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = Problem::UnifiedAttentionShape::Gemm1BlockWarps::at(number<0>{});
constexpr index_t NWarp = Problem::UnifiedAttentionShape::Gemm1BlockWarps::at(number<1>{});
constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kHeadDim;
constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kPageBlockSize;
constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
constexpr auto v_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<MWarp>,
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
tuple<sequence<0, 1>>,
tuple<sequence<0, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto v_block_dstr_encode = ck_tile::detail::make_embed_tile_distribution_encoding(
v_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
// compute the encoding before the transpose
constexpr auto v_block_dstr =
make_static_tile_distribution(typename InputTileDistributionTraits<
decltype(v_block_dstr_encode),
typename Problem::VDataType>::TransposedDstrEncode{});
return v_block_dstr;
}
template <typename Problem>
CK_TILE_DEVICE static constexpr auto GetQKBlockGemm()
{
using namespace ck_tile;
using GemmProblem =
BlockGemmProblem<typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::SaccDataType,
Problem::kBlockSize,
TileGemmShape<sequence<Problem::UnifiedAttentionShape::kBlockM,
Problem::UnifiedAttentionShape::kPageBlockSize,
Problem::UnifiedAttentionShape::kHeadDim>,
typename Problem::UnifiedAttentionShape::Gemm0BlockWarps,
typename Problem::UnifiedAttentionShape::Gemm0WarpTile>>;
using WarpGemm =
WarpGemmDispatcher<typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::SaccDataType,
Problem::UnifiedAttentionShape::Gemm0WarpTile::at(number<0>{}),
Problem::UnifiedAttentionShape::Gemm0WarpTile::at(number<1>{}),
Problem::UnifiedAttentionShape::Gemm0WarpTile::at(number<2>{}),
true,
false,
false>;
using BlockGemmPolicy = BlockGemmARegBRegCRegV2CustomPolicy<
typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::SaccDataType,
typename Problem::UnifiedAttentionShape::Gemm0BlockWarps,
WarpGemm,
GemmLoopOrder::MNK>;
return BlockGemmARegBRegCRegV2<GemmProblem, BlockGemmPolicy>{};
}
template <typename Problem>
CK_TILE_DEVICE static constexpr auto GetPVBlockGemm()
{
using namespace ck_tile;
using GemmProblem =
BlockGemmProblem<typename Problem::PDataType,
typename Problem::VDataType,
typename Problem::OaccDataType,
Problem::kBlockSize,
TileGemmShape<sequence<Problem::UnifiedAttentionShape::kBlockM,
Problem::UnifiedAttentionShape::kHeadDim,
Problem::UnifiedAttentionShape::kPageBlockSize>,
typename Problem::UnifiedAttentionShape::Gemm1BlockWarps,
typename Problem::UnifiedAttentionShape::Gemm1WarpTile>>;
/// NOTICE: in order to use load_tile_transpose() later for V tiles, we have to pass
/// WGAttrNumAccessEnum::Double instead of WGAttrNumAccessEnum::Single
using WarpGemm =
WarpGemmDispatcher<typename Problem::PDataType,
typename Problem::VDataType,
typename Problem::OaccDataType,
Problem::UnifiedAttentionShape::Gemm1WarpTile::at(number<0>{}),
Problem::UnifiedAttentionShape::Gemm1WarpTile::at(number<1>{}),
Problem::UnifiedAttentionShape::Gemm1WarpTile::at(number<2>{}),
true,
false,
false,
WGAttrNumAccessEnum::Double>;
using BlockGemmPolicy = BlockGemmARegBRegCRegV2CustomPolicy<
typename Problem::PDataType,
typename Problem::VDataType,
typename Problem::OaccDataType,
typename Problem::UnifiedAttentionShape::Gemm1BlockWarps,
WarpGemm,
GemmLoopOrder::MNK>;
return BlockGemmARegBRegCRegV2<GemmProblem, BlockGemmPolicy>{};
}
static constexpr ck_tile::index_t kKLdsPadInBytes = 4 * 4; // 4 dwords
static constexpr ck_tile::index_t kVLdsPadInBytes = 4 * 16; // 16 dwords
template <typename Problem, ck_tile::index_t IBuf = 0>
CK_TILE_DEVICE static constexpr auto
MakeKLdsStoreBlockDescriptor(ck_tile::number<IBuf> = ck_tile::number<0>{})
{
using namespace ck_tile;
// K is always k-major, we use async-copy to load into LDS
constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kPageBlockSize;
constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kHeadDim;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t NumWarps = Problem::UnifiedAttentionShape::NumWarps;
constexpr index_t WarpSize = ck_tile::get_warp_size();
[[maybe_unused]] constexpr index_t KPack = GetSmemKPackK<Problem>(); // this is for lds
constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
constexpr index_t kPad =
kKLdsPadInBytes /
sizeof(typename Problem::KDataType); // for async-copy, this pad is between warps.
// Optimize this for lds_read speed
static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0);
constexpr index_t LanesPerK =
kKPerBlock / KVector; // how many lane (within a wave) to load K
constexpr index_t LaneGroups =
WarpSize /
LanesPerK; // how many groups (within a wave), they may load different N, but same K
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor_with_offset(
make_tuple(number<NumIssues>{}, // n0
number<LaneGroups>{}, // n1
number<NumWarps>{}, // n2
number<LanesPerK>{}, // k0
number<KVector>{}), // k1
make_tuple(number<NumWarps*(WarpSize * KVector + kPad)>{},
number<kKPerBlock>{},
number<WarpSize * KVector + kPad>{},
number<KVector>{},
number<1>{}),
number<IBuf * GetSingleSmemElementSpaceSize<Problem>()>{},
number<KVector>{},
number<1>{});
// TODO this layout is hard coded, and will be used in async copy buffer view load
// in LDS the real layout is (bufs, N0, N2, N1*K0*K1)
constexpr auto k_lds_block_desc_issues_warps_lanes = transform_tensor_descriptor(
k_lds_block_desc_0,
make_tuple(make_pass_through_transform(number<NumIssues>{}),
make_pass_through_transform(number<NumWarps>{}),
make_merge_transform(make_tuple(
number<LaneGroups>{}, number<LanesPerK>{}, number<KVector>{}))),
make_tuple(sequence<0>{}, sequence<2>{}, sequence<1, 3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
return k_lds_block_desc_issues_warps_lanes;
}
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeKLdsLoadBlockDescriptor()
{
using namespace ck_tile;
// K is always k-major, we use async-copy to load into LDS
constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kPageBlockSize;
constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kHeadDim;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t NumWarps = Problem::UnifiedAttentionShape::NumWarps;
constexpr index_t WarpSize = ck_tile::get_warp_size();
constexpr index_t KPack = GetSmemKPackK<Problem>(); // this is for lds
constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
constexpr index_t kPad =
kKLdsPadInBytes /
sizeof(typename Problem::KDataType); // for async-copy, this pad is between warps
static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0);
constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave
constexpr index_t LaneGroups = WarpSize / LanesPerK; // within a wave
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
constexpr auto k_lds_block_desc_0 =
make_naive_tensor_descriptor(make_tuple(number<NumIssues>{}, // n0
number<NumWarps>{}, // n2
number<LaneGroups>{}, // n1
number<kKPerBlock / KPack>{}, // k0
number<KPack>{}), // k1
make_tuple(number<NumWarps*(WarpSize * KVector + kPad)>{},
number<WarpSize * KVector + kPad>{},
number<kKPerBlock>{},
number<KPack>{},
number<1>{}),
number<KPack>{},
number<1>{});
constexpr auto k_lds_block_desc = transform_tensor_descriptor(
k_lds_block_desc_0,
make_tuple(
make_merge_transform(
make_tuple(number<NumIssues>{}, number<LaneGroups>{}, number<NumWarps>{})),
make_merge_transform(make_tuple(number<kKPerBlock / KPack>{}, number<KPack>{}))),
make_tuple(sequence<0, 2, 1>{}, sequence<3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return k_lds_block_desc;
}
template <typename Problem>
CK_TILE_DEVICE static constexpr auto GetSingleSmemElementSpaceSize()
{
// this function assume K/V can share smem
constexpr index_t SingleKSize = [&]() {
constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kPageBlockSize;
constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kHeadDim;
constexpr index_t NumWarps = Problem::UnifiedAttentionShape::NumWarps;
constexpr index_t WarpSize = ck_tile::get_warp_size();
constexpr index_t KPack = GetSmemKPackK<Problem>(); // this is for lds
constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
constexpr index_t kPad = KPack;
static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0);
constexpr index_t LanesPerK = kKPerBlock / KVector;
constexpr index_t LaneGroups = WarpSize / LanesPerK;
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
return NumIssues * NumWarps * (WarpSize * KVector + kPad);
}();
constexpr index_t SingleVSize = [&]() {
using VDataType = remove_cvref_t<typename Problem::VDataType>;
constexpr index_t Banks = 32; // TODO: need change based on arch
constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType);
constexpr index_t kKPack = GetSmemKPackK<Problem>();
static_assert(PixelsPerRow % kKPack == 0);
constexpr index_t NPerRow = PixelsPerRow / kKPack;
constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kHeadDim;
constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kPageBlockSize;
static_assert(kNPerBlock % NPerRow == 0);
static_assert(kKPerBlock % kKPack == 0);
return (kKPerBlock / kKPack) * (kNPerBlock / NPerRow) * (PixelsPerRow + kKPack);
}();
return max(SingleKSize, SingleVSize);
}
template <typename Problem, ck_tile::index_t IBuf = 0>
CK_TILE_DEVICE static constexpr auto
MakeVLdsStoreBlockDescriptor(ck_tile::number<IBuf> = ck_tile::number<0>{})
{
using namespace ck_tile;
/// FIXME: rename kNPerBlock & kKPerBlock since kN1 is the contiguous dimension
constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kPageBlockSize;
constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kHeadDim;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t NumWarps = Problem::UnifiedAttentionShape::NumWarps;
constexpr index_t WarpSize = ck_tile::get_warp_size();
[[maybe_unused]] constexpr index_t KPack = GetSmemVPackK<Problem>(); // this is for lds
constexpr index_t KVector = GetAlignmentV<Problem>(); // this is for global load
constexpr index_t kPad =
kVLdsPadInBytes /
sizeof(typename Problem::VDataType); // for async-copy, this pad is between warps.
// Optimize this for lds_read speed
static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0);
constexpr index_t LanesPerK =
kKPerBlock / KVector; // how many lane (within a wave) to load K
constexpr index_t LaneGroups =
WarpSize /
LanesPerK; // how many groups (within a wave), they may load different N, but same K
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor_with_offset(
make_tuple(number<NumIssues>{}, // n0
number<LaneGroups>{}, // n1
number<NumWarps>{}, // n2
number<LanesPerK>{}, // k0
number<KVector>{}), // k1
make_tuple(number<NumWarps*(WarpSize * KVector + kPad)>{},
number<kKPerBlock>{},
number<WarpSize * KVector + kPad>{},
number<KVector>{},
number<1>{}),
number<(IBuf + 2) * GetSingleSmemElementSpaceSize<Problem>()>{},
number<KVector>{},
number<1>{});
// TODO this layout is hard coded, and will be used in async copy buffer view load
// in LDS the real layout is (bufs, N0, N2, N1*K0*K1)
constexpr auto v_lds_block_desc_issues_warps_lanes = transform_tensor_descriptor(
v_lds_block_desc_0,
make_tuple(make_pass_through_transform(number<NumIssues>{}),
make_pass_through_transform(number<NumWarps>{}),
make_merge_transform(make_tuple(
number<LaneGroups>{}, number<LanesPerK>{}, number<KVector>{}))),
make_tuple(sequence<0>{}, sequence<2>{}, sequence<1, 3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
return v_lds_block_desc_issues_warps_lanes;
}
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeVLdsLoadBlockDescriptor()
{
using namespace ck_tile;
/// FIXME: rename kNPerBlock & kKPerBlock since kN1 is the contiguous dimension
constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kPageBlockSize;
constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kHeadDim;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t NumWarps = Problem::UnifiedAttentionShape::NumWarps;
constexpr index_t WarpSize = ck_tile::get_warp_size();
constexpr index_t KPack = GetSmemVPackK<Problem>(); // this is for lds
constexpr index_t KVector = GetAlignmentV<Problem>(); // this is for global load
constexpr index_t kPad =
kVLdsPadInBytes /
sizeof(typename Problem::VDataType); // for async-copy, this pad is between warps
static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0);
constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave
constexpr index_t LaneGroups = WarpSize / LanesPerK; // within a wave
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
constexpr auto v_lds_block_desc_0 =
make_naive_tensor_descriptor(make_tuple(number<NumIssues>{}, // n0
number<NumWarps>{}, // n2
number<LaneGroups>{}, // n1
number<kKPerBlock / KPack>{}, // k0
number<KPack>{}), // k1
make_tuple(number<NumWarps*(WarpSize * KVector + kPad)>{},
number<WarpSize * KVector + kPad>{},
number<kKPerBlock>{},
number<KPack>{},
number<1>{}),
number<KPack>{},
number<1>{});
constexpr auto v_lds_block_desc = transform_tensor_descriptor(
v_lds_block_desc_0,
make_tuple(
make_merge_transform(
make_tuple(number<NumIssues>{}, number<LaneGroups>{}, number<NumWarps>{})),
make_merge_transform(make_tuple(number<kKPerBlock / KPack>{}, number<KPack>{}))),
make_tuple(sequence<0, 2, 1>{}, sequence<3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return v_lds_block_desc;
}
template <typename Problem>
CK_TILE_DEVICE static constexpr ck_tile::index_t GetSmemSizeKV()
{
using namespace ck_tile;
static_assert(MakeKLdsLoadBlockDescriptor<Problem>().get_element_space_size() ==
MakeKLdsStoreBlockDescriptor<Problem>().get_element_space_size());
constexpr index_t k_element_space_size =
MakeKLdsLoadBlockDescriptor<Problem>().get_element_space_size();
static_assert(MakeVLdsLoadBlockDescriptor<Problem>().get_element_space_size() ==
MakeVLdsStoreBlockDescriptor<Problem>().get_element_space_size());
constexpr index_t v_element_space_size =
MakeVLdsLoadBlockDescriptor<Problem>().get_element_space_size();
static_assert(ck_tile::max(k_element_space_size, v_element_space_size) <=
GetSingleSmemElementSpaceSize<Problem>());
/// TODO: override GetSingleSmemElementSpaceSize() to align with MakeKLdsBlockDescriptor() &
/// MakeVLdsBlockDescriptor()
static_assert(std::is_same_v<typename Problem::KDataType, typename Problem::VDataType>);
constexpr index_t kv_element_space_size_in_bytes =
GetSingleSmemElementSpaceSize<Problem>() * sizeof(typename Problem::KDataType);
return kv_element_space_size_in_bytes;
}
template <typename Problem>
CK_TILE_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return 4 * GetSmemSizeKV<Problem>();
}
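// Smem layout implied by the descriptor offsets above, in units of
// GetSingleSmemElementSpaceSize(): slots 0/1 hold the two K buffers
// (offset IBuf * single) and slots 2/3 the two V buffers
// (offset (IBuf + 2) * single), hence the factor of 4.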
};
struct UnifiedAttentionPipelineDecodePolicy : UnifiedAttentionPipelineDefaultPolicy
{
static constexpr ck_tile::index_t NumWarpPerGroup = 2;
static constexpr ck_tile::index_t NumThreadPerWarpGroup =
NumWarpPerGroup * ck_tile::get_warp_size();
};
struct UnifiedAttentionPipelineTinyDecodePolicy : UnifiedAttentionPipelineDefaultPolicy
{
static constexpr ck_tile::index_t NumWarpPerGroup = 1;
static constexpr ck_tile::index_t NumThreadPerWarpGroup =
NumWarpPerGroup * ck_tile::get_warp_size();
};
} // namespace ck_tile


@@ -0,0 +1,57 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
template <typename QDataType_,
typename KDataType_,
typename VDataType_,
typename SaccDataType_,
typename SMPLComputeDataType_,
typename BiasDataType_,
typename RandValOutputDataType_,
typename PDataType_,
typename OaccDataType_,
typename ODataType_,
typename UnifiedAttentionShape_,
typename FmhaMask_,
typename Traits_>
struct UnifiedAttentionPipelineProblem
{
// TODO kM0 and KN1??
using QDataType = remove_cvref_t<QDataType_>;
using KDataType = remove_cvref_t<KDataType_>;
using VDataType = remove_cvref_t<VDataType_>;
// first gemm accumulation dtype
using SaccDataType = remove_cvref_t<SaccDataType_>;
// Softmax dtype
using SMPLComputeDataType = remove_cvref_t<SMPLComputeDataType_>;
using BiasDataType = remove_cvref_t<BiasDataType_>;
using RandValOutputDataType = remove_cvref_t<RandValOutputDataType_>;
// data type for A matrix of second gemm
using PDataType = remove_cvref_t<PDataType_>;
// data type for second gemm accumulation
using OaccDataType = remove_cvref_t<OaccDataType_>;
using ODataType = remove_cvref_t<ODataType_>;
using UnifiedAttentionShape = remove_cvref_t<UnifiedAttentionShape_>;
using Traits = remove_cvref_t<Traits_>;
using FmhaMask = remove_cvref_t<FmhaMask_>;
static constexpr index_t kNumGemm0Warps = UnifiedAttentionShape::NumGemm0Warps;
static constexpr index_t kNumGemm1Warps = UnifiedAttentionShape::NumGemm1Warps;
static constexpr index_t kBlockSize = UnifiedAttentionShape::NumWarps * get_warp_size();
// attributes from traits
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
static constexpr bool kPadHeadDim = Traits::kPadHeadDim;
static constexpr bool kHasLogitsSoftCap = Traits::kHasLogitsSoftCap;
static constexpr bool kSkipMinSeqlenQ = Traits::kSkipMinSeqlenQ;
static constexpr bool kHasDropout = Traits::kHasDropout;
static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
};
} // namespace ck_tile
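
Putting the pieces together, a hedged sketch of how this problem type composes
with the headers above. The traits type here is an assumption: the
TileUnifiedAttentionTraits shown earlier only defines kPadSeqLenQ, kPadHeadDim
and kBlockPerCu, while this problem also reads kHasLogitsSoftCap,
kSkipMinSeqlenQ, kHasDropout and kDoFp8StaticQuant, so an extended traits type
is sketched:

struct ExampleTraits : ck_tile::TileUnifiedAttentionTraits<true /*kPadSeqLenQ*/, false /*kPadHeadDim*/>
{
    // Extra flags read by UnifiedAttentionPipelineProblem (assumed extension).
    static constexpr bool kHasLogitsSoftCap = false;
    static constexpr bool kSkipMinSeqlenQ   = false;
    static constexpr bool kHasDropout       = false;
    static constexpr bool kDoFp8StaticQuant = false;
};
using Problem = ck_tile::UnifiedAttentionPipelineProblem<
    ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, // Q/K/V
    float,                                             // Sacc (gemm0 accumulator)
    float,                                             // SMPLCompute (softmax)
    ck_tile::fp16_t,                                   // bias (unused by this kernel)
    uint8_t,                                           // rand-val output (unused by this kernel)
    ck_tile::fp16_t,                                   // P (gemm1 A matrix)
    float,                                             // Oacc (gemm1 accumulator)
    ck_tile::fp16_t,                                   // O
    Shape,                                             // e.g. the Shape sketched earlier
    ck_tile::GenericAttentionMask<true /*causal*/, false>,
    ExampleTraits>;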