mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
Add unified attention (42_unified_attention) and topk_softmax_decode
Squashed from aghamari/unified-attention-decode-opt branch. 42_unified_attention: CK tile paged-KV attention kernel optimized for decode with 4-tier dispatch (tiny/small/medium/large), 16x16 MFMA, 2D decode grid, head-group merging. Supports hdim=64 GQA-8 and hdim=128 MHA with block_size=32. topk_softmax_decode: fused topk + softmax kernel for M=1 MoE decode. Made-with: Cursor
This commit is contained in:
313
include/ck_tile/ops/unified_attention/block/block_masking.hpp
Normal file
313
include/ck_tile/ops/unified_attention/block/block_masking.hpp
Normal file
@@ -0,0 +1,313 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Coarse classification of the attention mask variant; used to select between
// no-mask, window-anchored (causal / sliding-window) and fully generic masking.
enum struct GenericAttentionMaskEnum
{
    NO_MASK = 0,

    // below enum could be causal, or sliding window; the two values differ in
    // which corner (top-left vs bottom-right) the mask window is anchored to
    MASK_FROM_TOP_LEFT = 1,
    MASK_FROM_BOTTOM_RIGHT = 2,

    // this enum maybe not used by xformer/FA, since it's hard to
    // specify left/right window for varlen case. put it here for
    // debug purpose
    MASK_GENERIC,
};
|
||||
|
||||
// clang-format off
|
||||
/* generic Attention Mask Coordinate
|
||||
use x(horizontal axis), y(vertical axis) to describe mask.
|
||||
top-left corner is origin
|
||||
|
||||
x=1/y=5(top-left) x=4/y=5(botm-r) x=6/y=5 x=8/y=5(no mask)
|
||||
1 * * * * * * * 1 1 1 1 * * * * 1 1 1 1 1 1 * * 1 1 1 1 1 1 1 1
|
||||
1 1 * * * * * * 1 1 1 1 1 * * * 1 1 1 1 1 1 1 * 1 1 1 1 1 1 1 1
|
||||
1 1 1 * * * * * 1 1 1 1 1 1 * * 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
|
||||
1 1 1 1 * * * * 1 1 1 1 1 1 1 * 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
|
||||
1 1 1 1 1 * * * 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
|
||||
l=7,-1/r=0(tl) l=7,-1/r=0(br)
|
||||
|
||||
x=1/y=2 x=4/y=2 x=6/y=2 x=8/y=2
|
||||
1 * * * * * * * 1 1 1 1 * * * * 1 1 1 1 1 1 * * 1 1 1 1 1 1 1 1
|
||||
1 1 * * * * * * 1 1 1 1 1 * * * 1 1 1 1 1 1 1 * 1 1 1 1 1 1 1 1
|
||||
* 1 1 * * * * * * 1 1 1 1 1 * * * 1 1 1 1 1 1 1 * 1 1 1 1 1 1 1
|
||||
* * 1 1 * * * * * * 1 1 1 1 1 * * * 1 1 1 1 1 1 * * 1 1 1 1 1 1
|
||||
* * * 1 1 * * * * * * 1 1 1 1 1 * * * 1 1 1 1 1 * * * 1 1 1 1 1
|
||||
l=1/r=0(tl) l=1/r=3(tl) l=1/r=5(tl) l=1/r=7(tl)
|
||||
l=4/r=0(br) l=4/r=2(br) l=4/r=4(br)
|
||||
|
||||
x=4/y=-1 x=6/y=-1 x=8/y=-1
|
||||
* * 1 1 * * * * * * 1 1 1 1 * * * * 1 1 1 1 1 1
|
||||
* * * 1 1 * * * * * * 1 1 1 1 * * * * 1 1 1 1 1
|
||||
* * * * 1 1 * * * * * * 1 1 1 1 * * * * 1 1 1 1
|
||||
* * * * * 1 1 * * * * * * 1 1 1 * * * * * 1 1 1
|
||||
* * * * * * 1 1 * * * * * * 1 1 * * * * * * 1 1
|
||||
|
||||
x=-2/y=5 x=1/y=5(top-left) x=0/y=5(botm-r)
|
||||
* * * * * * * * 1 * * * * * * *
|
||||
* * * * * * * * 1 1 * * 1 * * *
|
||||
* * * * * * * * 1 1 1 * 1 1 * *
|
||||
1 * * * * * * * 1 1 1 1 1 1 1 *
|
||||
1 1 * * * * * * 1 1 1 1 1 1 1 1
|
||||
|
||||
Validations:
|
||||
x + y > 1 (x + y >= 2)
|
||||
|
||||
Note:
|
||||
y = seq_q, x = 1 -> top-left
|
||||
y = seq_q, x = seq_k - seq_q + 1 -> bottom-right
|
||||
y < seq_q, x < seq_k -> local-attn
|
||||
y = seq_q, x = seq_k -> no mask
|
||||
|
||||
*/
|
||||
namespace impl {
    // Maps the (IsMasking, IsLocal) template flags to a short tag that is
    // embedded in kernel names: "mn" = no mask, "mc" = masking without local
    // window, "mg" = masking with local (generic) window.
    template <bool IsMasking_, bool IsLocal_> struct MaskName;
    // When IsMasking is false the IsLocal flag is irrelevant, so both
    // <false, *> specializations deliberately share the "mn" tag.
    template<> struct MaskName<false, false> { static constexpr const char * name = "mn"; };
    template<> struct MaskName<false, true> { static constexpr const char * name = "mn"; };
    template<> struct MaskName<true, false> { static constexpr const char * name = "mc"; };
    template<> struct MaskName<true, true> { static constexpr const char * name = "mg"; };
}
|
||||
// clang-format on
|
||||
|
||||
// Generic attention mask over a (y_total x x_total) score matrix, described by
// the window coordinates (y, x) with the top-left corner as origin (see the
// diagram above). For an (effective) query row i the valid column range is
// [i - y + 1, min(i + x, x_total)) — see IsOutOfBound() below.
//
// repeat_idx: every incoming row index is divided by repeat_idx before the
// window test, i.e. each logical query row is replicated repeat_idx times
// along Y. The unified-attention kernel passes num_queries_per_kv here so
// that merged head groups share one mask row.
template <bool IsMasking_ = true, bool IsLocal_ = false>
struct GenericAttentionMask
{
    static constexpr bool IsMasking = IsMasking_; // false will disable masking
    static constexpr bool IsLocal = IsLocal_;     // if true, upper/lower area could have mask,
                                                  // else only upper-right could have mask

    // short tag ("mn"/"mc"/"mg") used when composing kernel names
    static constexpr const char* name = impl::MaskName<IsMasking, IsLocal>::name;

    // Convenience constructor for the no-window case; accepts repeat_idx with
    // default value 1 (no row replication).
    CK_TILE_HOST_DEVICE
    GenericAttentionMask(index_t y_total_, index_t x_total_, index_t repeat_idx_ = 1)
        : GenericAttentionMask(0, 0, y_total_, x_total_, repeat_idx_)
    {
    }

    // Full constructor: (y_, x_) is the mask window, (y_total_, x_total_) the
    // score-matrix extent, repeat_idx_ the Y-row replication factor.
    CK_TILE_HOST_DEVICE
    GenericAttentionMask(
        index_t y_, index_t x_, index_t y_total_, index_t x_total_, index_t repeat_idx_ = 1)
        : y(y_), x(x_), y_total(y_total_), x_total(x_total_), repeat_idx(repeat_idx_)
    {
    }

    // Construct from a 4-element coordinate tuple (y, x, y_total, x_total) as
    // produced by make_generic_attention_mask_coordinates_from_lr_window().
    template <typename MaskCoordinates>
    CK_TILE_HOST_DEVICE GenericAttentionMask(const MaskCoordinates& mask_coord,
                                             index_t repeat_idx_ = 1)
        : y(mask_coord.at(number<0>{})),
          x(mask_coord.at(number<1>{})),
          y_total(mask_coord.at(number<2>{})),
          x_total(mask_coord.at(number<3>{})),
          repeat_idx(repeat_idx_)
    {
    }

    // to get the loop length along X axis, return index:[start, end), end-start=length
    // use this if need loop over X axis tile by tile (like k-seqlen loopover)
    // TODO: x_end still could be negative, so end-start could be negative(need check)
    template <index_t YTile, index_t XTile>
    CK_TILE_HOST_DEVICE constexpr auto
    GetTileRangeAlongX(index_t i_y, number<YTile>, number<XTile>) const
    {
        // Transform the y index according to repeat_idx (replicated rows share
        // the same effective mask row).
        index_t y_eff = i_y / repeat_idx;

        if constexpr(!IsMasking)
        {
            return ck_tile::make_tuple(0, x_total);
        }
        else
        {
            // get the tile start/end range assuming we loop over along X tile by tile
            index_t x_start = [&]() {
                if constexpr(IsLocal)
                {
                    // first valid column of the top row of this Y tile
                    index_t tmp = max(-y + y_eff + 1, 0);
                    return (tmp / XTile) * XTile; // round to tile aligned
                }
                else
                {
                    // non-local masks never exclude the left side
                    return 0;
                }
            }();

            // TODO: end could be negative, we ignore clamp here, and let caller to check
            // ... in which case end-start is negative
            index_t x_end = [&]() {
                // last valid column of the bottom row (y_eff + YTile - 1) of the tile
                index_t tmp = min(y_eff + YTile - 1 + x, x_total);
                return ((tmp + XTile - 1) / XTile) * XTile; // round up to tile aligned
            }();

            return ck_tile::make_tuple(x_start, x_end);
        }
    }

    // to get the loop length along Y axis, return index:[start, end), end-start=length
    // use this if need loop over Y axis tile by tile (like q-seqlen loopover)
    // Note: this function does not take a dynamic y index so no transform is needed
    template <index_t YTile, index_t XTile>
    CK_TILE_HOST_DEVICE constexpr auto
    GetTileRangeAlongY(index_t i_x, number<YTile>, number<XTile>) const
    {
        if constexpr(!IsMasking)
        {
            return ck_tile::make_tuple(0, y_total);
        }
        else
        {
            // get the tile start/end range assuming we loop over along Y tile by tile
            index_t y_start = [&]() {
                // first row whose window still contains column i_x
                index_t tmp = max(-x + i_x + 1, 0);
                return (tmp / YTile) * YTile; // round to tile aligned
            }();

            // TODO: end could be negative, we ignore clamp here, and let caller to check
            // ... in which case end-start is negative
            index_t y_end = [&]() {
                // last row whose window contains the rightmost column of the tile
                index_t tmp = min(i_x + XTile - 1 + y, y_total);
                return ((tmp + YTile - 1) / YTile) * YTile; // round up to tile aligned
            }();

            return ck_tile::make_tuple(y_start, y_end);
        }
    }

    // per-pixel check if out-of-bound, if true, need mask a value(like -INF)
    CK_TILE_HOST_DEVICE constexpr auto IsOutOfBound(index_t i_y, index_t i_x) const
    {
        // Transform the y index according to repeat_idx
        index_t y_eff = i_y / repeat_idx;

        if constexpr(!IsMasking)
        {
            return i_x >= x_total;
        }
        else
        {
            // no need to do min/max here, since i_x will never be < 0 or >= x_total
            // valid column range for this row is [x_start, x_end);
            // x_start is only consulted for local masks (left side can be masked)
            index_t x_start = -y + y_eff + 1;
            index_t x_end = min(y_eff + x, x_total);

            if constexpr(IsLocal)
            {
                return i_x < x_start || i_x >= x_end;
            }
            else
            {
                return i_x >= x_end || y_eff >= y_total;
            }
        }
    }

    // if current tile is at the edge, means need per-pixel mask check.
    // otherwise no need to check per-pixel
    // Attention! assume the index passed in this function is within range of GetTileRangeAlongX/Y()
    // can be used as a fast-path to decide if do per-pixel check or not
    template <index_t TileHeight, index_t TileWidth>
    CK_TILE_HOST_DEVICE constexpr auto
    IsEdgeTile(index_t i_tile_top, index_t i_tile_left, number<TileHeight>, number<TileWidth>) const
    {
        // Transform the y index according to repeat_idx
        index_t y_eff = i_tile_top / repeat_idx;

        if constexpr(!IsMasking)
        {
            // TODO: no need to check begin
            return (i_tile_left + TileWidth) > x_total;
        }
        else
        {
            if constexpr(IsLocal)
            {
                // check top-right corner > x or left-bottom corner < x
                index_t i_tile_right = i_tile_left + TileWidth;
                index_t i_tile_bottom = y_eff + TileHeight;
                index_t x_end = min(y_eff + x, x_total);

                bool top_right_edge = i_tile_right > (y_eff + x);
                bool bottom_left_edge = i_tile_bottom > (i_tile_left + y);
                bool is_partial_out_of_bound =
                    i_tile_right > x_end; // only consider right-pad for now

                return top_right_edge || bottom_left_edge || is_partial_out_of_bound;
            }
            else
            {
                // only need to check top-right corner > x
                index_t i_tile_right = i_tile_left + TileWidth;
                index_t x_end = min(y_eff + x, x_total);

                bool top_right_edge = i_tile_right > x_end;
                return top_right_edge;
            }
        }
    }

    private:
    index_t y, x;             // mask window (top-left origin; see class comment)
    index_t y_total, x_total; // full score-matrix extent (seq_q, seq_k)
    index_t repeat_idx;       // Y-row replication factor (1 = no replication)
};
|
||||
|
||||
// Type trait: true iff the type is an instantiation of GenericAttentionMask.
// Primary template: any other type is not a generic attention mask.
template <typename>
struct is_generic_attention_mask : std::false_type
{
};

// Partial specialization: matches every GenericAttentionMask instantiation.
template <bool IsMasking, bool IsLocal>
struct is_generic_attention_mask<GenericAttentionMask<IsMasking, IsLocal>> : std::true_type
{
};

// Convenience variable template mirroring the trait above.
template <typename Mask>
static constexpr bool is_generic_attention_mask_v = is_generic_attention_mask<Mask>::value;
|
||||
|
||||
// TODO: prefer use this function in host code
|
||||
// can convert from the FA style left/right to our generic coordinate
|
||||
// if left_size < 0 && right_size = 0, it is normal causal mask
|
||||
// local is left_size >=0 or right_size >=0
|
||||
// Convert an FA-style (left_size, right_size) window into the generic mask
// coordinates (y, x, y_total, x_total) used by GenericAttentionMask.
//
// A negative window size means "unlimited" on that side; e.g. left_size < 0
// with right_size == 0 is the normal causal mask. A local mask is one where
// left_size >= 0 or right_size >= 0. is_top_left selects whether the window
// is anchored at the top-left or bottom-right corner of the score matrix.
//
// TODO: prefer use this function in host code
// TODO: below should all use sgpr arithmetic
CK_TILE_HOST_DEVICE constexpr auto
make_generic_attention_mask_coordinates_from_lr_window(index_t left_size,
                                                       index_t right_size,
                                                       index_t y_total,
                                                       index_t x_total,
                                                       bool is_top_left = true)
{
    // Resolve "unlimited" (negative) window sizes to the maximum extent
    // available from the chosen anchor corner.
    index_t effective_left  = left_size;
    index_t effective_right = right_size;
    if(effective_left < 0)
    {
        effective_left = is_top_left ? y_total - 1 : x_total - 1;
    }
    if(effective_right < 0)
    {
        effective_right = is_top_left ? x_total - 1 : y_total - 1;
    }

    // Bottom-right anchoring shifts the window by the (possibly negative)
    // difference between the two sequence lengths.
    index_t x_shift = 0;
    index_t y_shift = 0;
    if(!is_top_left)
    {
        x_shift = x_total - y_total;
        y_shift = y_total - x_total;
    }

    index_t x = 1 + effective_right + x_shift;
    index_t y = 1 + effective_left + y_shift;

    return ck_tile::make_tuple(y, x, y_total, x_total);
}
|
||||
|
||||
// Convenience factory: converts an FA-style left/right window into generic
// mask coordinates and constructs a MaskType (e.g. GenericAttentionMask)
// from them. repeat_idx is forwarded to the mask's row-replication factor;
// is_top_left selects the window anchor corner (see the coordinate helper).
template <typename MaskType>
CK_TILE_HOST_DEVICE constexpr auto
make_generic_attention_mask_from_lr_window(index_t left_size,
                                           index_t right_size,
                                           index_t y_total,
                                           index_t x_total,
                                           index_t repeat_idx = 1,
                                           bool is_top_left = true)
{
    auto r = make_generic_attention_mask_coordinates_from_lr_window(
        left_size, right_size, y_total, x_total, is_top_left);
    // r = (y, x, y_total, x_total)
    return MaskType{r.at(number<0>{}), r.at(number<1>{}), y_total, x_total, repeat_idx};
}
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,487 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/ops/unified_attention/block/block_masking.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
#include <variant>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename UnifiedAttentionPipeline_, typename EpiloguePipeline_>
|
||||
struct UnifiedAttentionKernel
|
||||
{
|
||||
using UnifiedAttentionPipeline = ck_tile::remove_cvref_t<UnifiedAttentionPipeline_>;
|
||||
using EpiloguePipeline = ck_tile::remove_cvref_t<EpiloguePipeline_>;
|
||||
static constexpr ck_tile::index_t kBlockSize = UnifiedAttentionPipeline::kBlockSize;
|
||||
static constexpr ck_tile::index_t kBlockPerCu = UnifiedAttentionPipeline::kBlockPerCu;
|
||||
static_assert(kBlockPerCu > 0);
|
||||
|
||||
using QDataType = ck_tile::remove_cvref_t<typename UnifiedAttentionPipeline::QDataType>;
|
||||
using KDataType = ck_tile::remove_cvref_t<typename UnifiedAttentionPipeline::KDataType>;
|
||||
using VDataType = ck_tile::remove_cvref_t<typename UnifiedAttentionPipeline::VDataType>;
|
||||
using ODataType = ck_tile::remove_cvref_t<typename UnifiedAttentionPipeline::ODataType>;
|
||||
using SaccDataType = ck_tile::remove_cvref_t<typename UnifiedAttentionPipeline::SaccDataType>;
|
||||
using FmhaMask = ck_tile::remove_cvref_t<typename UnifiedAttentionPipeline::FmhaMask>;
|
||||
static constexpr bool kHasMask = FmhaMask::IsMasking;
|
||||
|
||||
static constexpr bool kPadSeqLenK = UnifiedAttentionPipeline::kPadSeqLenK;
|
||||
static constexpr bool kPadSeqLenQ = UnifiedAttentionPipeline::kPadSeqLenQ;
|
||||
static constexpr bool kPadHeadDimQ = UnifiedAttentionPipeline::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = UnifiedAttentionPipeline::kPadHeadDimV;
|
||||
|
||||
static constexpr index_t kHeadDim = UnifiedAttentionPipeline::kHeadDim;
|
||||
static constexpr index_t kHeadDimPadded = UnifiedAttentionPipeline::kHeadDimPadded;
|
||||
|
||||
// kBlockQ = kBlockM // num_queries_per_kv
|
||||
// kBlockQ is the block size for q seqlen
|
||||
/// static constexpr index_t kBlockQ = UnifiedAttentionPipeline::kBlockQ;
|
||||
static constexpr index_t kBlockM = UnifiedAttentionPipeline::kBlockM;
|
||||
static constexpr index_t kBlockQ = UnifiedAttentionPipeline::kBlockQ;
|
||||
// BLOCK size for K seqlen
|
||||
static constexpr index_t kPageBlockSize = UnifiedAttentionPipeline::kPageBlockSize;
|
||||
|
||||
// kargs use aggregate initializer, so no constructor will provided
|
||||
// use inheritance to minimize karg size
|
||||
// user need to use MakeKargs() function to create kargs.
|
||||
// The attention is default causal
|
||||
struct UnifiedAttentionCommonKargs
|
||||
{
|
||||
const void* q_ptr;
|
||||
const void* k_ptr; // [num_blks, page_size, num_kv_heads, head_size]
|
||||
const void* v_ptr; // [num_blks, page_size, num_kv_heads, head_size]
|
||||
void* o_ptr;
|
||||
|
||||
ck_tile::index_t num_blks;
|
||||
ck_tile::index_t num_head_q;
|
||||
// for MQA/GQA, nhead could be different. This parameter is nhead_q / nhead_k
|
||||
// if this param is larger than 1, indicate MQA/GQA case
|
||||
const ck_tile::index_t num_queries_per_kv;
|
||||
// scales
|
||||
float scale_s;
|
||||
float scale;
|
||||
float scale_k;
|
||||
float scale_v;
|
||||
float scale_out;
|
||||
|
||||
ck_tile::index_t page_size;
|
||||
|
||||
ck_tile::index_t total_num_q_blocks;
|
||||
ck_tile::index_t query_stride_0;
|
||||
ck_tile::index_t query_stride_1;
|
||||
ck_tile::index_t stride_k_cache_0;
|
||||
ck_tile::index_t stride_k_cache_1;
|
||||
ck_tile::index_t stride_k_cache_2;
|
||||
ck_tile::index_t stride_k_cache_3;
|
||||
ck_tile::index_t stride_v_cache_0;
|
||||
ck_tile::index_t stride_v_cache_1;
|
||||
ck_tile::index_t stride_v_cache_2;
|
||||
ck_tile::index_t stride_v_cache_3;
|
||||
ck_tile::index_t output_stride_0;
|
||||
ck_tile::index_t output_stride_1;
|
||||
};
|
||||
|
||||
struct UnifiedAttentionVarlenKargs : UnifiedAttentionCommonKargs
|
||||
{
|
||||
const int32_t* block_tables_ptr;
|
||||
ck_tile::index_t block_table_stride;
|
||||
const int32_t* seq_lens_ptr; // seq len in each batch
|
||||
const int32_t* query_start_len_ptr; // [num_seqs+1]
|
||||
|
||||
ck_tile::index_t num_seqs; // number of batches for q
|
||||
};
|
||||
|
||||
using Kargs = UnifiedAttentionVarlenKargs;
|
||||
|
||||
CK_TILE_HOST static constexpr Kargs MakeKargs(const void* q_ptr,
|
||||
const void* k_ptr,
|
||||
const void* v_ptr,
|
||||
void* o_ptr,
|
||||
ck_tile::index_t num_blks,
|
||||
ck_tile::index_t num_head_q,
|
||||
const ck_tile::index_t num_queries_per_kv,
|
||||
float scale_s,
|
||||
float scale,
|
||||
float scale_k,
|
||||
float scale_v,
|
||||
float scale_out,
|
||||
ck_tile::index_t page_size,
|
||||
ck_tile::index_t total_num_q_blocks,
|
||||
ck_tile::index_t query_stride_0,
|
||||
ck_tile::index_t query_stride_1,
|
||||
ck_tile::index_t stride_k_cache_0,
|
||||
ck_tile::index_t stride_k_cache_1,
|
||||
ck_tile::index_t stride_k_cache_2,
|
||||
ck_tile::index_t stride_k_cache_3,
|
||||
ck_tile::index_t stride_v_cache_0,
|
||||
ck_tile::index_t stride_v_cache_1,
|
||||
ck_tile::index_t stride_v_cache_2,
|
||||
ck_tile::index_t stride_v_cache_3,
|
||||
ck_tile::index_t output_stride_0,
|
||||
ck_tile::index_t output_stride_1,
|
||||
const int32_t* block_tables_ptr,
|
||||
ck_tile::index_t block_table_stride,
|
||||
const int32_t* seq_lens_ptr,
|
||||
const int32_t* query_start_len_ptr,
|
||||
ck_tile::index_t num_seqs)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
k_ptr,
|
||||
v_ptr,
|
||||
o_ptr,
|
||||
num_blks,
|
||||
num_head_q,
|
||||
num_queries_per_kv,
|
||||
static_cast<float>(scale_s * ck_tile::log2e_v<>),
|
||||
scale,
|
||||
scale_k,
|
||||
scale_v,
|
||||
scale_out,
|
||||
page_size,
|
||||
total_num_q_blocks,
|
||||
query_stride_0,
|
||||
query_stride_1,
|
||||
stride_k_cache_0,
|
||||
stride_k_cache_1,
|
||||
stride_k_cache_2,
|
||||
stride_k_cache_3,
|
||||
stride_v_cache_0,
|
||||
stride_v_cache_1,
|
||||
stride_v_cache_2,
|
||||
stride_v_cache_3,
|
||||
output_stride_0,
|
||||
output_stride_1},
|
||||
block_tables_ptr,
|
||||
block_table_stride,
|
||||
seq_lens_ptr,
|
||||
query_start_len_ptr,
|
||||
num_seqs};
|
||||
|
||||
return kargs;
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize2D(ck_tile::index_t num_kv_heads,
|
||||
ck_tile::index_t total_num_q_blocks)
|
||||
{
|
||||
return dim3(num_kv_heads * total_num_q_blocks);
|
||||
}
|
||||
|
||||
// Binary search to find the sequence index for a given target index
|
||||
CK_TILE_DEVICE static constexpr ck_tile::index_t
|
||||
find_seq_idx(const int32_t* query_start_len_ptr,
|
||||
ck_tile::index_t target_idx,
|
||||
ck_tile::index_t num_seqs,
|
||||
ck_tile::index_t block_q,
|
||||
bool use_q_block_mode)
|
||||
{
|
||||
ck_tile::index_t left = 0;
|
||||
ck_tile::index_t right = num_seqs;
|
||||
|
||||
while(left < right)
|
||||
{
|
||||
ck_tile::index_t mid = (left + right) / 2;
|
||||
ck_tile::index_t val = amd_wave_read_first_lane(query_start_len_ptr[mid]);
|
||||
ck_tile::index_t mid_val = use_q_block_mode ? (val / block_q + mid) : val;
|
||||
|
||||
if(mid_val <= target_idx)
|
||||
{
|
||||
left = mid + 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
right = mid;
|
||||
}
|
||||
}
|
||||
|
||||
return left - 1;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto GetTileIndex(const ck_tile::index_t pid,
|
||||
const Kargs& kargs)
|
||||
{
|
||||
using namespace ck_tile;
|
||||
|
||||
ck_tile::index_t num_head_kv = kargs.num_head_q / kargs.num_queries_per_kv;
|
||||
|
||||
return ck_tile::make_tuple(pid % num_head_kv, pid / num_head_kv);
|
||||
}
|
||||
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
{
|
||||
return ck_tile::max(UnifiedAttentionPipeline::GetSmemSize(),
|
||||
EpiloguePipeline::GetSmemSize());
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSizeDecode(ck_tile::index_t num_kv_heads,
|
||||
ck_tile::index_t num_seqs)
|
||||
{
|
||||
return dim3(num_kv_heads, num_seqs);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
{
|
||||
using namespace ck_tile;
|
||||
|
||||
const index_t num_queries_per_kv = kargs.num_queries_per_kv;
|
||||
|
||||
assert(kBlockM / num_queries_per_kv == kBlockQ);
|
||||
|
||||
index_t kv_head_idx;
|
||||
index_t seq_idx;
|
||||
index_t q_block_local_idx;
|
||||
index_t cur_batch_in_all_start_index;
|
||||
index_t cur_batch_query_len;
|
||||
|
||||
if(gridDim.y > 1)
|
||||
{
|
||||
// Decode grid: dim3(num_kv_heads, num_seqs)
|
||||
// Direct mapping, no binary search, no padding CTAs.
|
||||
kv_head_idx = blockIdx.x;
|
||||
seq_idx = blockIdx.y;
|
||||
q_block_local_idx = 0;
|
||||
cur_batch_in_all_start_index = kargs.query_start_len_ptr[seq_idx];
|
||||
const index_t stop = kargs.query_start_len_ptr[seq_idx + 1];
|
||||
cur_batch_query_len = amd_wave_read_first_lane(stop - cur_batch_in_all_start_index);
|
||||
}
|
||||
else
|
||||
{
|
||||
// Standard 1D grid with binary search
|
||||
ck_tile::index_t pid = blockIdx.x;
|
||||
|
||||
const auto [kv_head_idx_, q_block_global_idx] = GetTileIndex(pid, kargs);
|
||||
kv_head_idx = kv_head_idx_;
|
||||
|
||||
if(q_block_global_idx >= kargs.total_num_q_blocks)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
seq_idx = find_seq_idx(kargs.query_start_len_ptr,
|
||||
q_block_global_idx,
|
||||
kargs.num_seqs,
|
||||
kBlockQ,
|
||||
true);
|
||||
|
||||
const index_t q_block_start_idx =
|
||||
kargs.query_start_len_ptr[seq_idx] / kBlockQ + seq_idx;
|
||||
|
||||
q_block_local_idx =
|
||||
amd_wave_read_first_lane(q_block_global_idx - q_block_start_idx);
|
||||
|
||||
cur_batch_in_all_start_index = kargs.query_start_len_ptr[seq_idx];
|
||||
const index_t cur_batch_in_all_stop_index = kargs.query_start_len_ptr[seq_idx + 1];
|
||||
|
||||
cur_batch_query_len =
|
||||
amd_wave_read_first_lane(cur_batch_in_all_stop_index - cur_batch_in_all_start_index);
|
||||
|
||||
if(q_block_local_idx * kBlockQ >= cur_batch_query_len)
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr[GetSmemSize()];
|
||||
|
||||
const index_t query_pos = amd_wave_read_first_lane(q_block_local_idx * kBlockQ);
|
||||
const index_t seq_len = kargs.seq_lens_ptr[seq_idx];
|
||||
|
||||
const index_t context_len = amd_wave_read_first_lane(seq_len - cur_batch_query_len);
|
||||
|
||||
index_t _max_seq_prefix_len = amd_wave_read_first_lane(
|
||||
(context_len + q_block_local_idx * kBlockQ + (kBlockM - 1) + 1));
|
||||
|
||||
if(seq_len < _max_seq_prefix_len)
|
||||
{
|
||||
_max_seq_prefix_len = seq_len;
|
||||
}
|
||||
|
||||
const auto max_seq_prefix_len = _max_seq_prefix_len;
|
||||
const index_t num_blocks =
|
||||
amd_wave_read_first_lane((max_seq_prefix_len + kPageBlockSize - 1) / kPageBlockSize);
|
||||
|
||||
// TODO sliding window
|
||||
const index_t num_blocks_start = 0;
|
||||
long_index_t kv_head_offset = static_cast<long_index_t>(kv_head_idx) * kargs.stride_k_cache_2;
|
||||
|
||||
// Q/K/V DRAM and DRAM window
|
||||
index_t q_ptr_offset_0 = cur_batch_in_all_start_index *
|
||||
kargs.query_stride_0; // move the pointer to the batch start
|
||||
index_t q_ptr_offset_1 =
|
||||
kv_head_idx * num_queries_per_kv *
|
||||
kargs.query_stride_1; // move the pointer to the correct head group start
|
||||
index_t q_ptr_offset = q_ptr_offset_0 + q_ptr_offset_1;
|
||||
|
||||
index_t o_ptr_offset_0 = cur_batch_in_all_start_index *
|
||||
kargs.output_stride_0; // move the pointer to the batch start
|
||||
index_t o_ptr_offset_1 =
|
||||
kv_head_idx * num_queries_per_kv *
|
||||
kargs.output_stride_1; // move the pointer to the correct head group start
|
||||
index_t o_ptr_offset = o_ptr_offset_0 + o_ptr_offset_1;
|
||||
index_t block_table_offset = seq_idx * kargs.block_table_stride;
|
||||
|
||||
const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) + q_ptr_offset;
|
||||
const KDataType* k_ptr = reinterpret_cast<const KDataType*>(kargs.k_ptr) + kv_head_offset;
|
||||
const VDataType* v_ptr = reinterpret_cast<const VDataType*>(kargs.v_ptr) + kv_head_offset;
|
||||
ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr) + o_ptr_offset;
|
||||
|
||||
index_t query_len_padded =
|
||||
amd_wave_read_first_lane(integer_divide_ceil(cur_batch_query_len, kBlockQ) * kBlockQ);
|
||||
// const bool is_query_len_padded = (cur_batch_query_len % kBlockQ == 0);
|
||||
|
||||
// Q/K/V DRAM and DRAM window
|
||||
const auto q_dram = [&]() {
|
||||
const auto q_dram_base = make_naive_tensor_view<address_space_enum::global>(
|
||||
q_ptr,
|
||||
make_tuple(cur_batch_query_len, num_queries_per_kv, kHeadDim),
|
||||
make_tuple(kargs.query_stride_0, kargs.query_stride_1, 1),
|
||||
number<UnifiedAttentionPipeline::kAlignmentQ>{},
|
||||
number<1>{});
|
||||
|
||||
const auto q_dram_pad =
|
||||
pad_tensor_view( // aling seqlen with kBlockQ and head dim with kHeadDimPadded
|
||||
q_dram_base,
|
||||
// block sizes
|
||||
make_tuple(number<kBlockQ>{}, 1, kHeadDimPadded),
|
||||
sequence<true, false, kPadHeadDimQ>{}); // pads to (seq_len_padded, num_head_q,
|
||||
// kHeadDimPadded)
|
||||
|
||||
const auto q_dram_merged = transform_tensor_view(
|
||||
q_dram_pad,
|
||||
make_tuple(make_merge_transform(make_tuple(query_len_padded, num_queries_per_kv)),
|
||||
make_pass_through_transform(kHeadDimPadded)),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0>{},
|
||||
sequence<1>{})); // flattens the first two dims, head idx is the fastest
|
||||
// changing dim in the merged dim
|
||||
|
||||
return q_dram_merged;
|
||||
}();
|
||||
// static_assert(q_dram.desc_[number<0>{}] == 0,
|
||||
// "q_dram.get_bottom_tensor_view()[number<0>{}] == 0");
|
||||
|
||||
// Q has the shape (k_head, seq_len, num_queries_per_kv, head_dim)
|
||||
// stride for dim 0 (num_queries_per_kv * head_dim, head_dim, 1)
|
||||
auto q_dram_window =
|
||||
make_tile_window(q_dram,
|
||||
make_tuple(number<kBlockM>{}, number<kHeadDimPadded>{}),
|
||||
{query_pos * num_queries_per_kv, 0});
|
||||
|
||||
const auto k_dram = [&]() {
|
||||
// Use long_index_t for size/strides to prevent int32 overflow
|
||||
// when row * stride exceeds 2^31 (happens at ~66K blocks for d64/GQA-8).
|
||||
const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
k_ptr,
|
||||
make_tuple(static_cast<long_index_t>(kargs.num_blks) * kargs.page_size,
|
||||
static_cast<long_index_t>(kHeadDim)),
|
||||
make_tuple(static_cast<long_index_t>(kargs.stride_k_cache_1),
|
||||
static_cast<long_index_t>(kargs.stride_k_cache_3)),
|
||||
number<UnifiedAttentionPipeline::kAlignmentK>{},
|
||||
number<1>{});
|
||||
|
||||
const auto k_dram_pad =
|
||||
pad_tensor_view(k_dram_naive,
|
||||
make_tuple(kPageBlockSize, kHeadDimPadded),
|
||||
sequence<false, kPadHeadDimQ>{});
|
||||
|
||||
return k_dram_pad;
|
||||
}();
|
||||
|
||||
auto k_dram_window = make_tile_window(
|
||||
k_dram, make_tuple(number<kPageBlockSize>{}, number<kHeadDimPadded>{}), {0, 0});
|
||||
|
||||
const auto v_dram = [&]() {
|
||||
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
v_ptr,
|
||||
make_tuple(static_cast<long_index_t>(kargs.num_blks) * kargs.page_size,
|
||||
static_cast<long_index_t>(kHeadDim)),
|
||||
make_tuple(static_cast<long_index_t>(kargs.stride_v_cache_1),
|
||||
static_cast<long_index_t>(kargs.stride_v_cache_3)),
|
||||
number<UnifiedAttentionPipeline::kAlignmentV>{},
|
||||
number<1>{});
|
||||
|
||||
const auto v_dram_pad = pad_tensor_view(v_dram_naive,
|
||||
make_tuple(kPageBlockSize, kHeadDimPadded),
|
||||
sequence<false, kPadHeadDimQ>{});
|
||||
|
||||
return v_dram_pad;
|
||||
}();
|
||||
|
||||
auto v_dram_window = make_tile_window(
|
||||
v_dram, make_tuple(number<kPageBlockSize>{}, number<kHeadDimPadded>{}), {0, 0});
|
||||
|
||||
FmhaMask mask = [&]() {
|
||||
if constexpr(kHasMask)
|
||||
return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
|
||||
-1,
|
||||
0,
|
||||
cur_batch_query_len, // y_total
|
||||
seq_len, // x_total
|
||||
num_queries_per_kv, // the same sequence index is repeated num_queries_per_kv
|
||||
// times along x dim of the tile
|
||||
false);
|
||||
else
|
||||
return FmhaMask{cur_batch_query_len, seq_len};
|
||||
}();
|
||||
|
||||
const index_t kv_page_size_in_blocks = kargs.page_size / kPageBlockSize;
|
||||
assert(kv_page_size_in_blocks >= 1); // kPageBlockSize <= page_size
|
||||
|
||||
auto o_acc_tile = [&]() {
|
||||
return UnifiedAttentionPipeline{}(q_dram_window,
|
||||
k_dram_window,
|
||||
v_dram_window,
|
||||
num_blocks,
|
||||
num_blocks_start,
|
||||
kargs.block_tables_ptr,
|
||||
block_table_offset,
|
||||
kv_page_size_in_blocks,
|
||||
mask,
|
||||
kargs.scale_s,
|
||||
smem_ptr);
|
||||
}();
|
||||
|
||||
// O DRAM and O DRAM window
|
||||
auto o_dram = [&]() {
|
||||
const auto o_dram_base = make_naive_tensor_view<address_space_enum::global>(
|
||||
o_ptr,
|
||||
make_tuple(cur_batch_query_len, num_queries_per_kv, kHeadDim),
|
||||
make_tuple(kargs.output_stride_0, kargs.output_stride_1, 1),
|
||||
number<UnifiedAttentionPipeline::kAlignmentO>{},
|
||||
number<1>{});
|
||||
|
||||
const auto o_dram_pad =
|
||||
pad_tensor_view( // aling cu_seqlen with kBlockQ and head dim with kHeadDimPadded
|
||||
o_dram_base,
|
||||
// block sizes
|
||||
make_tuple(kBlockQ, 1, kHeadDimPadded),
|
||||
sequence<true, false, kPadHeadDimQ>{}); // pads to (seq_len_padded, num_head_q,
|
||||
// kHeadDimPadded)
|
||||
|
||||
const auto o_dram_merged = transform_tensor_view(
|
||||
o_dram_pad,
|
||||
make_tuple(make_merge_transform(make_tuple(query_len_padded, num_queries_per_kv)),
|
||||
make_pass_through_transform(kHeadDimPadded)),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return o_dram_merged;
|
||||
}();
|
||||
|
||||
auto o_dram_window =
|
||||
make_tile_window(o_dram,
|
||||
make_tuple(number<kBlockM>{}, number<kHeadDimPadded>{}),
|
||||
{query_pos * num_queries_per_kv, 0});
|
||||
|
||||
EpiloguePipeline{}(o_dram_window, o_acc_tile, nullptr);
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,72 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <index_t Headdim>
|
||||
static CK_TILE_HOST_DEVICE constexpr index_t ceil_to_qualified_tile_length()
|
||||
{
|
||||
if constexpr(Headdim == 48)
|
||||
return 48;
|
||||
else if constexpr(Headdim == 96)
|
||||
return 128;
|
||||
else if constexpr(Headdim == 160)
|
||||
return 256;
|
||||
else if constexpr(Headdim == 192)
|
||||
return 192;
|
||||
else if constexpr(is_power_of_two_integer(Headdim))
|
||||
return Headdim;
|
||||
else
|
||||
static_assert(Headdim == 0,
|
||||
"only Headdim of 48, 96, 160, 192 and power-of-two is supported");
|
||||
};
|
||||
|
||||
// Compile-time tile-shape description for the unified-attention kernel.
//
// BlockTile_ is a sequence of four extents:
//   [0] kBlockM        - tile size along the flattened (num_queries_per_kv * BS) batch dim
//   [1] kBlockQ        - tile size along the query sequence (used to pad cu_seqlen)
//   [2] kPageBlockSize - paged-KV block size, i.e. the tile along the K/V sequence
//   [3] kHeadDim       - head dimension
template <typename BlockTile_, // sequence<...>
          typename Gemm0BlockWarps_,
          typename Gemm0WarpTile_,
          typename Gemm1BlockWarps_,
          typename Gemm1WarpTile_,
          bool IsVLayoutRowMajor_>
struct TileUnifiedAttentionShape
{
    using BlockTile       = remove_cvref_t<BlockTile_>;
    using Gemm0BlockWarps = remove_cvref_t<Gemm0BlockWarps_>; // QK gemm warp grid
    using Gemm0WarpTile   = remove_cvref_t<Gemm0WarpTile_>;   // QK gemm per-warp tile
    using Gemm1BlockWarps = remove_cvref_t<Gemm1BlockWarps_>; // PV gemm warp grid
    using Gemm1WarpTile   = remove_cvref_t<Gemm1WarpTile_>;   // PV gemm per-warp tile

    // Warp counts are the product of each gemm's warp-grid extents.
    static constexpr index_t NumGemm0Warps =
        reduce_on_sequence(Gemm0BlockWarps{}, multiplies<>{}, number<1>{});
    static constexpr index_t NumGemm1Warps =
        reduce_on_sequence(Gemm1BlockWarps{}, multiplies<>{}, number<1>{});
    static_assert(NumGemm1Warps % NumGemm0Warps == 0);

    static constexpr index_t NumWarps = max(NumGemm0Warps, NumGemm1Warps);

    // Tile size along the flattened batch dimension (num_queries_per_kv * BS).
    static constexpr index_t kBlockM = BlockTile::at(
        number<0>{});
    // Tile size along the query sequence length; the kernel pads cu_seqlen to a
    // multiple of this. NOTE(review): the original comment here duplicated
    // kBlockM's text — confirm intended meaning against the kernel's use.
    static constexpr index_t kBlockQ = BlockTile::at(
        number<1>{});
    // Paged-KV block size: tile extent along the K/V sequence length.
    static constexpr index_t kPageBlockSize =
        BlockTile::at(number<2>{});
    // Head dimension (the original "BLOCK size for K seqlen" comment was a
    // copy/paste error).
    static constexpr index_t kHeadDim = BlockTile::at(number<3>{});

    // Head dim rounded up to a tile length the gemms support.
    static constexpr index_t kHeadDimPadded = ceil_to_qualified_tile_length<kHeadDim>();

    // v, rowmajor : seqlen*hdim, colmajor : hdim*seqlen
    static constexpr bool IsVLayoutRowMajor = IsVLayoutRowMajor_;
    using VLayout = std::conditional_t<IsVLayoutRowMajor,
                                       ck_tile::tensor_layout::gemm::RowMajor,
                                       ck_tile::tensor_layout::gemm::ColumnMajor>;
};
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,19 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
|
||||
bool kPadHeadDim_ /* paddding for hdim_v */,
|
||||
index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
|
||||
struct TileUnifiedAttentionTraits
|
||||
{
|
||||
static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
|
||||
static constexpr bool kPadHeadDim = kPadHeadDim_;
|
||||
static constexpr index_t kBlockPerCu = kBlockPerCu_;
|
||||
};
|
||||
} // namespace ck_tile
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,613 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2_custom_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Default policy for the unified-attention pipeline.
//
// Per Problem type this policy supplies:
//  * global-load alignments for Q/K/V and the O store vector size,
//  * DRAM / register tile distributions for the Q, K, V and P tiles,
//  * construction of the QK (gemm0) and PV (gemm1) block gemms,
//  * async-copy LDS store/load descriptors for K and V (padded between warps),
//  * the total shared-memory footprint (K and V are each double buffered).
struct UnifiedAttentionPipelineDefaultPolicy
{
    static constexpr ck_tile::index_t NumWarpPerGroup = 4;
    static constexpr ck_tile::index_t NumThreadPerWarpGroup =
        NumWarpPerGroup * ck_tile::get_warp_size();

    // TODO: GetAlignment*() currently didn't consider if need padding or not
    // so in pipeline still need check padding requirement

    // Q global-load vector size in elements: capped by a 16-byte load and by
    // the K extent one lane covers in the QK warp gemm.
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ()
    {
        constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::QDataType);

        using BlockGemm        = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
        constexpr auto config  = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
        using WG               = remove_cvref_t<decltype(config.template at<0>())>;

        return min(MaxVectorSize, WG::kK / WG::WarpGemmAttribute::Impl::kABKLane);
    }

    // K global-load vector size in elements: 16 bytes per lane on gfx950,
    // 4 bytes (one dword) elsewhere.
    template <typename Problem>
    CK_TILE_DEVICE static constexpr auto GetAlignmentK()
    {
        using namespace ck_tile;
        using KDataType = remove_cvref_t<typename Problem::KDataType>;
#if defined(__gfx950__)
        constexpr index_t MaxReadSizeInBytes = 16;
#else
        constexpr index_t MaxReadSizeInBytes = 4;
#endif
        return MaxReadSizeInBytes / sizeof(KDataType);
    }

    // V global-load vector size in elements (same scheme as GetAlignmentK).
    template <typename Problem>
    CK_TILE_DEVICE static constexpr auto GetAlignmentV()
    {
        using namespace ck_tile;
        using VDataType = remove_cvref_t<typename Problem::VDataType>;
#if defined(__gfx950__)
        constexpr index_t MaxReadSizeInBytes = 16;
#else
        constexpr index_t MaxReadSizeInBytes = 4;
#endif
        return MaxReadSizeInBytes / sizeof(VDataType);
    }

    // O store vector size: contiguous per-lane M elements of the PV warp gemm's
    // C tile.
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentO()
    {
        using BlockGemm       = remove_cvref_t<decltype(GetPVBlockGemm<Problem>())>;
        constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
        using WG              = remove_cvref_t<decltype(config.template at<0>())>;

        return WG::WarpGemmAttribute::Impl::kCM1PerLane;
    }

    // LDS pack size (elements per 16-byte access) for the K tile.
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackK()
    {
        using namespace ck_tile;

        // TODO: this is for 3d layout
        using KDataType = remove_cvref_t<typename Problem::KDataType>;
        return 16 / sizeof(KDataType);
    }

    // LDS pack size (elements per 16-byte access) for the V tile.
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetSmemVPackK()
    {
        using namespace ck_tile;

        // TODO: this is for 3d layout
        using VDataType = remove_cvref_t<typename Problem::VDataType>;
        return 16 / sizeof(VDataType);
    }

    // Thread distribution for the K tile global load
    // (kPageBlockSize x kHeadDim, k-major): each lane loads KVector contiguous
    // elements; lanes tile K first, lane groups and warps tile N, and each
    // thread issues NumIssues loads.
    template <typename Problem>
    CK_TILE_DEVICE static constexpr auto MakeKDramTileDistribution()
    {
        using namespace ck_tile;

        constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kPageBlockSize;
        constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kHeadDim;
        constexpr index_t kBlockSize = Problem::kBlockSize;
        constexpr index_t NumWarps   = Problem::UnifiedAttentionShape::NumWarps;
        constexpr index_t WarpSize   = ck_tile::get_warp_size();

        constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load

        static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0);
        constexpr index_t LanesPerK  = kKPerBlock / KVector;  // within a wave
        constexpr index_t LaneGroups = WarpSize / LanesPerK;  // within a wave
        constexpr index_t NumIssues  = kNPerBlock / (LaneGroups * NumWarps);
        static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));

        constexpr index_t N0 = NumIssues;
        constexpr index_t N1 = LaneGroups;
        constexpr index_t N2 = NumWarps;
        constexpr index_t K0 = LanesPerK;
        constexpr index_t K1 = KVector;

        return make_static_tile_distribution(
            tile_distribution_encoding<sequence<1>,
                                       tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
                                       tuple<sequence<1>, sequence<1, 2>>,
                                       tuple<sequence<2>, sequence<1, 0>>,
                                       sequence<1, 2>,
                                       sequence<0, 1>>{});
    }

    // Thread distribution for the V tile global load; identical scheme to
    // MakeKDramTileDistribution but sized with V's alignment. Inline numbers
    // are example values for one configuration.
    template <typename Problem>
    CK_TILE_DEVICE static constexpr auto MakeVDramTileDistribution()
    {
        using namespace ck_tile;

        constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kPageBlockSize;
        constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kHeadDim;
        constexpr index_t kBlockSize = Problem::kBlockSize;
        constexpr index_t NumWarps   = Problem::UnifiedAttentionShape::NumWarps;
        constexpr index_t WarpSize   = ck_tile::get_warp_size(); // 64

        constexpr index_t KVector = GetAlignmentV<Problem>(); // this is for global load
                                                              // 4

        static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0);
        constexpr index_t LanesPerK  = kKPerBlock / KVector;  // within a wave
        constexpr index_t LaneGroups = WarpSize / LanesPerK;  // within a wave
        constexpr index_t NumIssues  = kNPerBlock / (LaneGroups * NumWarps);
        static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));

        constexpr index_t N0 = NumIssues;   // 8
        constexpr index_t N1 = LaneGroups;  // 2
        constexpr index_t N2 = NumWarps;    // 8
        constexpr index_t K0 = LanesPerK;   // 32
        constexpr index_t K1 = KVector;     // 4

        return make_static_tile_distribution(
            tile_distribution_encoding<sequence<1>,
                                       tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
                                       tuple<sequence<1>, sequence<1, 2>>,
                                       tuple<sequence<2>, sequence<1, 0>>,
                                       sequence<1, 2>,
                                       sequence<0, 1>>{});
    }

    // Register-tile distribution for Q as the A operand of the QK gemm.
    template <typename Problem>
    CK_TILE_DEVICE static constexpr auto MakeQRegTileDistribution()
    {
        using namespace ck_tile;

        using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;

        return make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
    }

    // Register-tile distribution for K as the B operand of the QK gemm.
    template <typename Problem>
    CK_TILE_DEVICE static constexpr auto MakeKRegTileDistribution()
    {
        using namespace ck_tile;

        using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;

        return make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode());
    }

    // Register-tile distribution for P (softmax output) as the A operand of the
    // PV gemm.
    template <typename Problem>
    CK_TILE_DEVICE static constexpr auto MakePRegTileDistribution()
    {
        using namespace ck_tile;

        using BlockGemm = remove_cvref_t<decltype(GetPVBlockGemm<Problem>())>;

        return make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
    }

    // Register-tile distribution for V as the B operand of the PV gemm; the
    // encoding is transposed so V tiles can be read with load_tile_transpose().
    template <typename Problem>
    CK_TILE_DEVICE static constexpr auto MakeVRegTileDistribution()
    {
        using namespace ck_tile;

        using BlockGemm       = remove_cvref_t<decltype(GetPVBlockGemm<Problem>())>;
        constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
        using WarpGemm        = remove_cvref_t<decltype(config.template at<0>())>;

        constexpr index_t MWarp = Problem::UnifiedAttentionShape::Gemm1BlockWarps::at(number<0>{});
        constexpr index_t NWarp = Problem::UnifiedAttentionShape::Gemm1BlockWarps::at(number<1>{});

        // For the PV gemm, N is the head dim and K is the KV-sequence tile.
        constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kHeadDim;
        constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kPageBlockSize;

        constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
        constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;

        constexpr auto v_block_outer_dstr_encoding =
            tile_distribution_encoding<sequence<MWarp>,
                                       tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
                                       tuple<sequence<0, 1>>,
                                       tuple<sequence<0, 1>>,
                                       sequence<1, 2>,
                                       sequence<0, 0>>{};

        constexpr auto v_block_dstr_encode = ck_tile::detail::make_embed_tile_distribution_encoding(
            v_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});

        // compute the encoding before transpose
        constexpr auto v_block_dstr =
            make_static_tile_distribution(typename InputTileDistributionTraits<
                                          decltype(v_block_dstr_encode),
                                          typename Problem::VDataType>::TransposedDstrEncode{});

        return v_block_dstr;
    }

    // Block gemm computing S = Q * K^T (gemm0).
    // NOTE(review): the WarpGemmDispatcher below reads Gemm1WarpTile, not
    // Gemm0WarpTile — correct only when both warp tiles are identical; confirm
    // against the shapes used by the kernel configs.
    template <typename Problem>
    CK_TILE_DEVICE static constexpr auto GetQKBlockGemm()
    {
        using namespace ck_tile;

        using GemmProblem =
            BlockGemmProblem<typename Problem::QDataType,
                             typename Problem::KDataType,
                             typename Problem::SaccDataType,
                             Problem::kBlockSize,
                             TileGemmShape<sequence<Problem::UnifiedAttentionShape::kBlockM,
                                                    Problem::UnifiedAttentionShape::kPageBlockSize,
                                                    Problem::UnifiedAttentionShape::kHeadDim>,
                                           typename Problem::UnifiedAttentionShape::Gemm0BlockWarps,
                                           typename Problem::UnifiedAttentionShape::Gemm0WarpTile>>;

        using WarpGemm =
            WarpGemmDispatcher<typename Problem::QDataType,
                               typename Problem::KDataType,
                               typename Problem::SaccDataType,
                               Problem::UnifiedAttentionShape::Gemm1WarpTile::at(number<0>{}),
                               Problem::UnifiedAttentionShape::Gemm1WarpTile::at(number<1>{}),
                               Problem::UnifiedAttentionShape::Gemm1WarpTile::at(number<2>{}),
                               true,
                               false,
                               false>;

        using BlockGemmPolicy = BlockGemmARegBRegCRegV2CustomPolicy<
            typename Problem::QDataType,
            typename Problem::KDataType,
            typename Problem::SaccDataType,
            typename Problem::UnifiedAttentionShape::Gemm0BlockWarps,
            WarpGemm,
            GemmLoopOrder::MNK>;

        return BlockGemmARegBRegCRegV2<GemmProblem, BlockGemmPolicy>{};
    }

    // Block gemm computing O = P * V (gemm1).
    template <typename Problem>
    CK_TILE_DEVICE static constexpr auto GetPVBlockGemm()
    {
        using namespace ck_tile;

        using GemmProblem =
            BlockGemmProblem<typename Problem::PDataType,
                             typename Problem::VDataType,
                             typename Problem::OaccDataType,
                             Problem::kBlockSize,
                             TileGemmShape<sequence<Problem::UnifiedAttentionShape::kBlockM,
                                                    Problem::UnifiedAttentionShape::kHeadDim,
                                                    Problem::UnifiedAttentionShape::kPageBlockSize>,
                                           typename Problem::UnifiedAttentionShape::Gemm1BlockWarps,
                                           typename Problem::UnifiedAttentionShape::Gemm1WarpTile>>;
        /// NOTICE: in order to use load_tile_transpose() later for V tiles, we have to pass
        /// WGAttrNumAccessEnum::Double instead of WGAttrNumAccessEnum::Single
        using WarpGemm =
            WarpGemmDispatcher<typename Problem::PDataType,
                               typename Problem::VDataType,
                               typename Problem::OaccDataType,
                               Problem::UnifiedAttentionShape::Gemm1WarpTile::at(number<0>{}),
                               Problem::UnifiedAttentionShape::Gemm1WarpTile::at(number<1>{}),
                               Problem::UnifiedAttentionShape::Gemm1WarpTile::at(number<2>{}),
                               true,
                               false,
                               false,
                               WGAttrNumAccessEnum::Double>;

        using BlockGemmPolicy = BlockGemmARegBRegCRegV2CustomPolicy<
            typename Problem::PDataType,
            typename Problem::VDataType,
            typename Problem::OaccDataType,
            typename Problem::UnifiedAttentionShape::Gemm1BlockWarps,
            WarpGemm,
            GemmLoopOrder::MNK>;
        return BlockGemmARegBRegCRegV2<GemmProblem, BlockGemmPolicy>{};
    }

    // Per-warp LDS padding (in bytes) inserted to reduce bank conflicts.
    static constexpr ck_tile::index_t kKLdsPadInBytes = 4 * 4;  // 4 dwords
    static constexpr ck_tile::index_t kVLdsPadInBytes = 4 * 16; // 16 dwords

    // LDS store descriptor for the K tile, shaped for async-copy
    // (issues x warps x lanes). IBuf selects the double-buffer slot.
    template <typename Problem, ck_tile::index_t IBuf = 0>
    CK_TILE_DEVICE static constexpr auto
    MakeKLdsStoreBlockDescriptor(ck_tile::number<IBuf> = ck_tile::number<0>{})
    {
        using namespace ck_tile;

        // K is always k-major, we use async-copy to load into LDS
        constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kPageBlockSize;
        constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kHeadDim;
        constexpr index_t kBlockSize = Problem::kBlockSize;
        constexpr index_t NumWarps   = Problem::UnifiedAttentionShape::NumWarps;
        constexpr index_t WarpSize   = ck_tile::get_warp_size();

        [[maybe_unused]] constexpr index_t KPack = GetSmemKPackK<Problem>(); // this is for lds
        constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
        constexpr index_t kPad =
            kKLdsPadInBytes /
            sizeof(typename Problem::KDataType); // for async-copy, this pad is between warps.
                                                 // Optimize this for lds_read speed

        static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0);
        constexpr index_t LanesPerK =
            kKPerBlock / KVector; // how many lane (within a wave) to load K
        constexpr index_t LaneGroups =
            WarpSize /
            LanesPerK; // how many groups (within a wave), they may load different N, but same K
        constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
        static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));

        constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor_with_offset(
            make_tuple(number<NumIssues>{},  // n0
                       number<LaneGroups>{}, // n1
                       number<NumWarps>{},   // n2
                       number<LanesPerK>{},  // k0
                       number<KVector>{}),   // k1
            make_tuple(number<NumWarps*(WarpSize * KVector + kPad)>{},
                       number<kKPerBlock>{},
                       number<WarpSize * KVector + kPad>{},
                       number<KVector>{},
                       number<1>{}),
            // double-buffer offset: K occupies buffer slots 0 and 1
            number<IBuf * GetSingleSmemElementSpaceSize<Problem>()>{},
            number<KVector>{},
            number<1>{});

        // TODO this layout is hard coded, and will be used in async copy buffer view load
        // in LDS the real layout is (bufs, N0, N2, N1*K0*K1)
        constexpr auto k_lds_block_desc_issues_warps_lanes = transform_tensor_descriptor(
            k_lds_block_desc_0,
            make_tuple(make_pass_through_transform(number<NumIssues>{}),
                       make_pass_through_transform(number<NumWarps>{}),
                       make_merge_transform(make_tuple(
                           number<LaneGroups>{}, number<LanesPerK>{}, number<KVector>{}))),
            make_tuple(sequence<0>{}, sequence<2>{}, sequence<1, 3, 4>{}),
            make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));

        return k_lds_block_desc_issues_warps_lanes;
    }

    // LDS load descriptor for the K tile: presents the stored data as a 2D
    // (kPageBlockSize, kHeadDim) view for the gemm reader.
    template <typename Problem>
    CK_TILE_DEVICE static constexpr auto MakeKLdsLoadBlockDescriptor()
    {
        using namespace ck_tile;

        // K is always k-major, we use async-copy to load into LDS
        constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kPageBlockSize;
        constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kHeadDim;
        constexpr index_t kBlockSize = Problem::kBlockSize;
        constexpr index_t NumWarps   = Problem::UnifiedAttentionShape::NumWarps;
        constexpr index_t WarpSize   = ck_tile::get_warp_size();

        constexpr index_t KPack   = GetSmemKPackK<Problem>(); // this is for lds
        constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
        constexpr index_t kPad =
            kKLdsPadInBytes /
            sizeof(typename Problem::KDataType); // for async-copy, this pad is between warps

        static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0);
        constexpr index_t LanesPerK  = kKPerBlock / KVector;  // within a wave
        constexpr index_t LaneGroups = WarpSize / LanesPerK;  // within a wave
        constexpr index_t NumIssues  = kNPerBlock / (LaneGroups * NumWarps);
        static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));

        constexpr auto k_lds_block_desc_0 =
            make_naive_tensor_descriptor(make_tuple(number<NumIssues>{},          // n0
                                                    number<NumWarps>{},           // n2
                                                    number<LaneGroups>{},         // n1
                                                    number<kKPerBlock / KPack>{}, // k0
                                                    number<KPack>{}),             // k1
                                         make_tuple(number<NumWarps*(WarpSize * KVector + kPad)>{},
                                                    number<WarpSize * KVector + kPad>{},
                                                    number<kKPerBlock>{},
                                                    number<KPack>{},
                                                    number<1>{}),
                                         number<KPack>{},
                                         number<1>{});

        constexpr auto k_lds_block_desc = transform_tensor_descriptor(
            k_lds_block_desc_0,
            make_tuple(
                make_merge_transform(
                    make_tuple(number<NumIssues>{}, number<LaneGroups>{}, number<NumWarps>{})),
                make_merge_transform(make_tuple(number<kKPerBlock / KPack>{}, number<KPack>{}))),
            make_tuple(sequence<0, 2, 1>{}, sequence<3, 4>{}),
            make_tuple(sequence<0>{}, sequence<1>{}));

        return k_lds_block_desc;
    }

    // Element count of one K-or-V LDS buffer slot (K and V slots share a size).
    // NOTE(review): inside the SingleKSize lambda kKPerBlock is kPageBlockSize
    // (the LDS descriptors use kHeadDim as the K extent) and kPad is KPack
    // (the descriptors use kKLdsPadInBytes / sizeof(KDataType)) — looks like
    // drift from the descriptor code; the static_asserts in GetSmemSizeKV()
    // are what actually guard the computed size. Confirm before relying on it.
    template <typename Problem>
    CK_TILE_DEVICE static constexpr auto GetSingleSmemElementSpaceSize()
    {
        // this function assume K/V can share smem
        constexpr index_t SingleKSize = [&]() {
            constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kPageBlockSize;
            constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kPageBlockSize;
            constexpr index_t NumWarps   = Problem::UnifiedAttentionShape::NumWarps;
            constexpr index_t WarpSize   = ck_tile::get_warp_size();

            constexpr index_t KPack   = GetSmemKPackK<Problem>(); // this is for lds
            constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
            constexpr index_t kPad    = KPack;

            static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0);
            constexpr index_t LanesPerK  = kKPerBlock / KVector;
            constexpr index_t LaneGroups = WarpSize / LanesPerK;
            constexpr index_t NumIssues  = kNPerBlock / (LaneGroups * NumWarps);

            return NumIssues * NumWarps * (WarpSize * KVector + kPad);
        }();

        constexpr index_t SingleVSize = [&]() {
            using VDataType = remove_cvref_t<typename Problem::VDataType>;
            constexpr index_t Banks        = 32; // TODO: need change based on arch
            constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType);
            constexpr index_t kKPack       = GetSmemKPackK<Problem>();
            static_assert(PixelsPerRow % kKPack == 0);
            constexpr index_t NPerRow    = PixelsPerRow / kKPack;
            constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kHeadDim;
            constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kPageBlockSize;
            static_assert(kNPerBlock % NPerRow == 0);
            static_assert(kKPerBlock % kKPack == 0);

            return (kKPerBlock / kKPack) * (kNPerBlock / NPerRow) * (PixelsPerRow + kKPack);
        }();

        return max(SingleKSize, SingleVSize);
    }

    // LDS store descriptor for the V tile; same scheme as the K store but with
    // V's pad, and offset past the two K buffer slots.
    template <typename Problem, ck_tile::index_t IBuf = 0>
    CK_TILE_DEVICE static constexpr auto
    MakeVLdsStoreBlockDescriptor(ck_tile::number<IBuf> = ck_tile::number<0>{})
    {
        using namespace ck_tile;

        /// FIXME: rename the kNPerBlock & kKPerBlock since the kN1 is contiguous dimension
        constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kPageBlockSize;
        constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kHeadDim;
        constexpr index_t kBlockSize = Problem::kBlockSize;
        constexpr index_t NumWarps   = Problem::UnifiedAttentionShape::NumWarps;
        constexpr index_t WarpSize   = ck_tile::get_warp_size();

        [[maybe_unused]] constexpr index_t KPack = GetSmemVPackK<Problem>(); // this is for lds
        constexpr index_t KVector = GetAlignmentV<Problem>(); // this is for global load
        constexpr index_t kPad =
            kVLdsPadInBytes /
            sizeof(typename Problem::VDataType); // for async-copy, this pad is between warps.
                                                 // Optimize this for lds_read speed

        static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0);
        constexpr index_t LanesPerK =
            kKPerBlock / KVector; // how many lane (within a wave) to load K
        constexpr index_t LaneGroups =
            WarpSize /
            LanesPerK; // how many groups (within a wave), they may load different N, but same K
        constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
        static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));

        constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor_with_offset(
            make_tuple(number<NumIssues>{},  // n0
                       number<LaneGroups>{}, // n1
                       number<NumWarps>{},   // n2
                       number<LanesPerK>{},  // k0
                       number<KVector>{}),   // k1
            make_tuple(number<NumWarps*(WarpSize * KVector + kPad)>{},
                       number<kKPerBlock>{},
                       number<WarpSize * KVector + kPad>{},
                       number<KVector>{},
                       number<1>{}),
            // V occupies buffer slots 2 and 3 (after K's two slots)
            number<(IBuf + 2) * GetSingleSmemElementSpaceSize<Problem>()>{},
            number<KVector>{},
            number<1>{});

        // TODO this layout is hard coded, and will be used in async copy buffer view load
        // in LDS the real layout is (bufs, N0, N2, N1*K0*K1)
        constexpr auto v_lds_block_desc_issues_warps_lanes = transform_tensor_descriptor(
            v_lds_block_desc_0,
            make_tuple(make_pass_through_transform(number<NumIssues>{}),
                       make_pass_through_transform(number<NumWarps>{}),
                       make_merge_transform(make_tuple(
                           number<LaneGroups>{}, number<LanesPerK>{}, number<KVector>{}))),
            make_tuple(sequence<0>{}, sequence<2>{}, sequence<1, 3, 4>{}),
            make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));

        return v_lds_block_desc_issues_warps_lanes;
    }

    // LDS load descriptor for the V tile: presents stored data as a 2D view.
    // NOTE(review): KVector comes from GetAlignmentK here while the store path
    // uses GetAlignmentV — equal only when K and V element sizes match (the
    // same-type static_assert in GetSmemSizeKV() makes that hold today).
    template <typename Problem>
    CK_TILE_DEVICE static constexpr auto MakeVLdsLoadBlockDescriptor()
    {
        using namespace ck_tile;

        /// FIXME: rename the kNPerBlock & kKPerBlock since the kN1 is contiguous dimension
        constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kPageBlockSize;
        constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kHeadDim;
        constexpr index_t kBlockSize = Problem::kBlockSize;
        constexpr index_t NumWarps   = Problem::UnifiedAttentionShape::NumWarps;
        constexpr index_t WarpSize   = ck_tile::get_warp_size();

        constexpr index_t KPack   = GetSmemVPackK<Problem>(); // this is for lds
        constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
        constexpr index_t kPad =
            kVLdsPadInBytes /
            sizeof(typename Problem::VDataType); // for async-copy, this pad is between warps

        static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0);
        constexpr index_t LanesPerK  = kKPerBlock / KVector;  // within a wave
        constexpr index_t LaneGroups = WarpSize / LanesPerK;  // within a wave
        constexpr index_t NumIssues  = kNPerBlock / (LaneGroups * NumWarps);
        static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));

        constexpr auto v_lds_block_desc_0 =
            make_naive_tensor_descriptor(make_tuple(number<NumIssues>{},          // n0
                                                    number<NumWarps>{},           // n2
                                                    number<LaneGroups>{},         // n1
                                                    number<kKPerBlock / KPack>{}, // k0
                                                    number<KPack>{}),             // k1
                                         make_tuple(number<NumWarps*(WarpSize * KVector + kPad)>{},
                                                    number<WarpSize * KVector + kPad>{},
                                                    number<kKPerBlock>{},
                                                    number<KPack>{},
                                                    number<1>{}),
                                         number<KPack>{},
                                         number<1>{});

        constexpr auto v_lds_block_desc = transform_tensor_descriptor(
            v_lds_block_desc_0,
            make_tuple(
                make_merge_transform(
                    make_tuple(number<NumIssues>{}, number<LaneGroups>{}, number<NumWarps>{})),
                make_merge_transform(make_tuple(number<kKPerBlock / KPack>{}, number<KPack>{}))),
            make_tuple(sequence<0, 2, 1>{}, sequence<3, 4>{}),
            make_tuple(sequence<0>{}, sequence<1>{}));

        return v_lds_block_desc;
    }

    // Bytes of one K/V LDS buffer slot; sanity-checks that the load/store
    // descriptors agree and fit inside GetSingleSmemElementSpaceSize().
    template <typename Problem>
    CK_TILE_DEVICE static constexpr ck_tile::index_t GetSmemSizeKV()
    {
        using namespace ck_tile;

        static_assert(MakeKLdsLoadBlockDescriptor<Problem>().get_element_space_size() ==
                      MakeKLdsStoreBlockDescriptor<Problem>().get_element_space_size());
        constexpr index_t k_element_space_size =
            MakeKLdsLoadBlockDescriptor<Problem>().get_element_space_size();

        static_assert(MakeVLdsLoadBlockDescriptor<Problem>().get_element_space_size() ==
                      MakeVLdsStoreBlockDescriptor<Problem>().get_element_space_size());
        constexpr index_t v_element_space_size =
            MakeVLdsLoadBlockDescriptor<Problem>().get_element_space_size();

        static_assert(ck_tile::max(k_element_space_size, v_element_space_size) <=
                      GetSingleSmemElementSpaceSize<Problem>());

        /// TODO: override GetSingleSmemElementSpaceSize() to align with MakeKLdsBlockDescriptor() &
        /// MakeVLdsBlockDescriptor()
        static_assert(std::is_same_v<typename Problem::KDataType, typename Problem::VDataType>);
        constexpr index_t kv_element_space_size_in_bytes =
            GetSingleSmemElementSpaceSize<Problem>() * sizeof(typename Problem::KDataType);

        return kv_element_space_size_in_bytes;
    }

    // Total LDS bytes: four buffer slots — double-buffered K (slots 0-1) plus
    // double-buffered V (slots 2-3).
    template <typename Problem>
    CK_TILE_DEVICE static constexpr ck_tile::index_t GetSmemSize()
    {
        return 4 * GetSmemSizeKV<Problem>();
    }
};
|
||||
|
||||
// Decode-optimized policy: 2 warps per warp group instead of 4.
// NOTE: these constants hide (shadow) the base-class values; the base-class
// member functions do not read them, so only code that names the constants
// through this derived type observes the override.
struct UnifiedAttentionPipelineDecodePolicy : UnifiedAttentionPipelineDefaultPolicy
{
    static constexpr ck_tile::index_t NumWarpPerGroup = 2;
    static constexpr ck_tile::index_t NumThreadPerWarpGroup =
        NumWarpPerGroup * ck_tile::get_warp_size();
};
|
||||
|
||||
// Tiny-decode policy: a single warp per warp group.
// NOTE: these constants hide (shadow) the base-class values; the base-class
// member functions do not read them, so only code that names the constants
// through this derived type observes the override.
struct UnifiedAttentionPipelineTinyDecodePolicy : UnifiedAttentionPipelineDefaultPolicy
{
    static constexpr ck_tile::index_t NumWarpPerGroup = 1;
    static constexpr ck_tile::index_t NumThreadPerWarpGroup =
        NumWarpPerGroup * ck_tile::get_warp_size();
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,57 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Aggregates every compile-time parameter the unified-attention pipeline needs:
// operand/accumulator data types, tile shape, mask type and feature traits.
template <typename QDataType_,
          typename KDataType_,
          typename VDataType_,
          typename SaccDataType_,
          typename SMPLComputeDataType_,
          typename BiasDataType_,
          typename RandValOutputDataType_,
          typename PDataType_,
          typename OaccDataType_,
          typename ODataType_,
          typename UnifiedAttentionShape_,
          typename FmhaMask_,
          typename Traits_>
struct UnifiedAttentionPipelineProblem
{
    // TODO kM0 and KN1??
    using QDataType = remove_cvref_t<QDataType_>;
    using KDataType = remove_cvref_t<KDataType_>;
    using VDataType = remove_cvref_t<VDataType_>;
    // first gemm accumulation dtype
    using SaccDataType = remove_cvref_t<SaccDataType_>;
    // Softmax compute dtype
    using SMPLComputeDataType = remove_cvref_t<SMPLComputeDataType_>;
    using BiasDataType        = remove_cvref_t<BiasDataType_>;
    // dtype of the dropout random-value output (when dropout is enabled)
    using RandValOutputDataType = remove_cvref_t<RandValOutputDataType_>;
    // data type for A matrix of second gemm (softmax probabilities P)
    using PDataType = remove_cvref_t<PDataType_>;
    // data type for second gemm accumulation
    using OaccDataType          = remove_cvref_t<OaccDataType_>;
    using ODataType             = remove_cvref_t<ODataType_>;
    using UnifiedAttentionShape = remove_cvref_t<UnifiedAttentionShape_>;
    using Traits                = remove_cvref_t<Traits_>;
    // attention mask policy (causal / sliding-window / none)
    using FmhaMask = remove_cvref_t<FmhaMask_>;

    static constexpr index_t kNumGemm0Warps = UnifiedAttentionShape::NumGemm0Warps;
    static constexpr index_t kNumGemm1Warps = UnifiedAttentionShape::NumGemm1Warps;
    // one thread block = NumWarps full warps
    static constexpr index_t kBlockSize = UnifiedAttentionShape::NumWarps * get_warp_size();

    // attributes from traits
    // NOTE(review): the last four members below are read from Traits, but
    // TileUnifiedAttentionTraits as declared in this change only provides
    // kPadSeqLenQ / kPadHeadDim / kBlockPerCu — confirm which traits type is
    // actually instantiated here.
    static constexpr bool kPadSeqLenQ       = Traits::kPadSeqLenQ;
    static constexpr bool kPadHeadDim       = Traits::kPadHeadDim;
    static constexpr bool kHasLogitsSoftCap = Traits::kHasLogitsSoftCap;
    static constexpr bool kSkipMinSeqlenQ   = Traits::kSkipMinSeqlenQ;
    static constexpr bool kHasDropout       = Traits::kHasDropout;
    static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
    static constexpr index_t kBlockPerCu    = Traits::kBlockPerCu;
};
|
||||
} // namespace ck_tile
|
||||
Reference in New Issue
Block a user