mirror of https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
Optimize batch prefill kernel performance for VECTORIZED_LAYOUT KV cache (#3657)
- Add multi-dimensional page index support (YsGatherDims) in tile_scatter_gather
- Add is_gather_dim() and get_gather_index() for multi-dim page lookup
- Override MakeVDramTileDistribution() for VECTORIZED_LAYOUT to match
GEMM's BWarpDstrEncoding (K decomposition: {K2, K0, K1})
- Add GetGemmKDecomposition() to retrieve kABKLane and kKPerThread
- Add static_assert for RowMajor VLayout requirement in batch prefill
Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
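
For context on the first bullet: the multi-dimensional page lookup linearizes the participating Y-space coordinates into a single page index (see the "gather_index = y_k1 + y_k2 * len(Y_K1)" comment in the diff below). A minimal standalone C++ sketch of that linearization, with names and values that are illustrative rather than taken from the library:

    #include <cstddef>

    // Hedged sketch: with an outer K iteration, the 2D Y-space coordinate
    // (y_k1, y_k2) is flattened into the 1D index used to select a page offset.
    // len_y_k1 is the length of the inner Y dimension; values are illustrative.
    constexpr std::size_t linearize_gather_index(std::size_t y_k1,
                                                 std::size_t y_k2,
                                                 std::size_t len_y_k1)
    {
        return y_k1 + y_k2 * len_y_k1;
    }

    // e.g. len(Y_K1) = 4: coordinate (y_k1 = 3, y_k2 = 1) -> gather index 7
    static_assert(linearize_gather_index(3, 1, 4) == 7);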
@@ -533,32 +533,170 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
         auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0)>(
             randval_dram_block_window_tmp, seqlen_k_start);

-        auto v_dist = Policy::template MakeVDramTileDistribution<Problem>();
-        auto v_coord = v_dist.calculate_index();
-        const auto VPageIndexDim = I1;
-        using VDstrEncode = typename decltype(v_dist)::DstrEncode;
-        constexpr index_t V_KRepeat = VDstrEncode::hs_lengthss_[I1][I3];
-        statically_indexed_array<index_t, V_KRepeat> v_offsets;
-        kv_offset_array_transform<statically_indexed_array<index_t, V_KRepeat>,
-                                  decltype(v_coord),
-                                  VPageIndexDim,
-                                  kPageBlockSize,
-                                  0,
-                                  V_KRepeat,
-                                  1,
-                                  kKVMemoryLayout,
-                                  false,
-                                  kN0,
-                                  kVectorSize>(
-            page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k);
+        auto v_dist = Policy::template MakeVDramTileDistribution<Problem>();
+        auto v_coord = v_dist.calculate_index();
+        using VDstrEncode = typename decltype(v_dist)::DstrEncode;
+
+        // V tensor K-dimension decomposition for page index computation
+        // ============================================================
+        // The K dimension (seqlen_k) in V distribution is decomposed into multiple sub-dimensions.
+        // This decomposition determines how threads iterate over the K dimension and how page
+        // indices are computed for paged KV cache.
+        //
+        // The decomposition pattern differs by memory layout:
+        //
+        // VECTORIZED_LAYOUT (ColumnMajor, custom distribution):
+        //   3D decomposition: K = K2 × K0 × K1
+        //   - K2 (V_KIterOuter): Outer iteration count
+        //   - K0 (V_KLanes): Lanes for K dimension (matches GEMM kABKLane)
+        //   - K1 (V_KIterInner): Vector load size (matches GEMM kKPerThread)
+        //   - hs_lengthss_[I1] = {K2, K0, K1}, size = 3 (or {K0, K1} size = 2 if no outer iter)
+        //
+        // LINEAR_LAYOUT ColumnMajor (base class distribution):
+        //   2D decomposition: K = K0 × K1
+        //   - K0: Lanes for K dimension (may not match GEMM kABKLane)
+        //   - K1: Vector load size
+        //   - hs_lengthss_[I1] = {K0, K1}, size = 2
+        //
+        // LINEAR_LAYOUT RowMajor (base class distribution):
+        //   4D decomposition: K = K0 × K1 × K2 × K3 (uses shuffle_tile for GEMM alignment)
+        //   3D decomposition: K = K0 × K1 × K2 (fallback case)
+        //   - Page lookup uses Y-space's last dimension only (inner iteration)
+        //
+        // V_PageIdxRepeat = total number of page lookups per thread = V_KIterOuter × V_KIterInner
+        constexpr index_t V_KIterInner = VDstrEncode::hs_lengthss_[I1].back();
+
+        // Compute V_KIterOuter and V_KLanes based on memory layout and K decomposition
+        constexpr index_t V_KIterOuter = [] {
+            if constexpr(kKVMemoryLayout ==
+                         BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT)
+            {
+                // VECTORIZED_LAYOUT: 3D decomposition {K2, K0, K1} when outer iteration is needed
+                if constexpr(VDstrEncode::hs_lengthss_[I1].size() == 3)
+                    return static_cast<index_t>(VDstrEncode::hs_lengthss_[I1][I0]);
+                else
+                    return index_t{1};
+            }
+            else
+            {
+                // LINEAR_LAYOUT: No outer iteration for page lookup
+                // RowMajor uses shuffle_tile, ColumnMajor has simple 2D decomposition
+                // Both cases use single-dimension Y-space page lookup
+                return index_t{1};
+            }
+        }();
+
+        constexpr index_t V_KLanes = [] {
+            if constexpr(kKVMemoryLayout ==
+                         BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT)
+            {
+                // VECTORIZED_LAYOUT: K0 is the lanes dimension
+                if constexpr(V_KIterOuter > 1)
+                    return static_cast<index_t>(VDstrEncode::hs_lengthss_[I1][I1]);
+                else
+                    return static_cast<index_t>(VDstrEncode::hs_lengthss_[I1][I0]);
+            }
+            else
+            {
+                // LINEAR_LAYOUT: First dimension is K0 (lanes)
+                return static_cast<index_t>(VDstrEncode::hs_lengthss_[I1][I0]);
+            }
+        }();
+
+        // This affects page offset computation - need to track offsets for each (k2, k1)
+        // combination
+        constexpr index_t V_PageIdxRepeat = V_KIterInner * V_KIterOuter;
+
+        // VPageIndexYDims: Y-space dimension indices that participate in page index computation
+        // ================================================================================
+        // In tile_scatter_gather, the gather index is computed from Y-space coordinates.
+        // This sequence specifies which Y dimensions should be linearized to form the page lookup
+        // index.
+        //
+        // VECTORIZED_LAYOUT with outer iteration: sequence<Y_K1, Y_K2>
+        //   - Both K1 and K2 are in Y-space (thread iteration dimensions)
+        //   - gather_index = y_k1 + y_k2 * len(Y_K1) (linearized 2D -> 1D)
+        //
+        // VECTORIZED_LAYOUT without outer iteration / LINEAR_LAYOUT: sequence<Y_K1>
+        //   - Only the innermost K dimension is used for page lookup (single dimension)
+        //
+        constexpr auto VPageIndexYDims = []() {
+            // K1Minor is always the last element index in hs_lengthss_[I1]
+            constexpr index_t K1Minor = VDstrEncode::hs_lengthss_[I1].size() - 1;
+            constexpr index_t Y_K1 = VDstrEncode::detail::rhs_major_minor_to_ys_[2][K1Minor];
+
+            if constexpr(kKVMemoryLayout ==
+                             BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT &&
+                         V_KIterOuter > 1)
+            {
+                // VECTORIZED_LAYOUT with outer iteration: need 2D page lookup
+                constexpr index_t Y_K2 = VDstrEncode::detail::rhs_major_minor_to_ys_[2][I0];
+                return sequence<Y_K1, Y_K2>{};
+            }
+            else
+            {
+                // LINEAR_LAYOUT or VECTORIZED_LAYOUT without outer iteration: 1D page lookup
+                return sequence<Y_K1>{};
+            }
+        }();
+
+        static_assert(decltype(VPageIndexYDims)::at(0) < VDstrEncode::NDimY,
+                      "V page-index Y dim must be valid");
+
+        statically_indexed_array<index_t, V_PageIdxRepeat> v_offsets;
+        auto update_v_offsets = [&](auto k_loop_start) {
+            constexpr index_t kLoopStart = decltype(k_loop_start)::value;
+            // For 3D K decomposition (K2, K0, K1), compute offsets for each K2 slice
+            // The global K offset for (k2, k1) is: kLoopStart + k2 * (K0 * K1) + k1
+            // We iterate K2 outer, K1 inner, and merge into 1D v_offsets array
+            if constexpr(V_KIterOuter > 1)
+            {
+                static_for<0, V_KIterOuter, 1>{}([&](auto k2) {
+                    statically_indexed_array<index_t, V_KIterInner> v_offsets_k2;
+                    kv_offset_array_transform<statically_indexed_array<index_t, V_KIterInner>,
+                                              decltype(v_coord),
+                                              I1,
+                                              kPageBlockSize,
+                                              kLoopStart + k2.value * V_KLanes * V_KIterInner,
+                                              V_KIterInner,
+                                              1,
+                                              kKVMemoryLayout,
+                                              false,
+                                              kN0,
+                                              kVectorSize>(
+                        page_idx, stride_v, page_stride_v, v_coord, v_offsets_k2, current_seq_k);
+                    static_for<0, V_KIterInner, 1>{}([&](auto k1) {
+                        constexpr auto idx = number<k1.value + k2.value * V_KIterInner>{};
+                        v_offsets[idx] = v_offsets_k2[k1];
+                    });
+                });
+            }
+            else
+            {
+                kv_offset_array_transform<statically_indexed_array<index_t, V_KIterInner>,
+                                          decltype(v_coord),
+                                          I1,
+                                          kPageBlockSize,
+                                          kLoopStart,
+                                          V_KIterInner,
+                                          1,
+                                          kKVMemoryLayout,
+                                          false,
+                                          kN0,
+                                          kVectorSize>(
+                    page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k);
+            }
+        };
+        update_v_offsets(number<0>{});
         auto v_dram_window =
             make_tile_scatter_gather(v_dram_block_window_tmp.get_bottom_tensor_view(),
                                      v_dram_block_window_tmp.get_window_lengths(),
                                      {0, seqlen_k_start}, // TODO: hdim split?
                                      v_dist,
                                      v_offsets,
-                                     VPageIndexDim);
+                                     number<1>{}, // HsGatherDim
+                                     number<1>{}, // NumCoord
+                                     VPageIndexYDims);

         // prefetch K tile
         async_load_tile_raw(
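
The update_v_offsets lambda above merges per-K2-slice page offsets into one flat array. A hedged sketch of just that bookkeeping, using plain loops in place of static_for and illustrative sizes; the real per-slice offsets come from kv_offset_array_transform and the page table, which this sketch takes as an input:

    #include <array>
    #include <cstddef>

    constexpr std::size_t KIterOuter = 2; // V_KIterOuter (illustrative)
    constexpr std::size_t KIterInner = 2; // V_KIterInner (illustrative)
    constexpr std::size_t KLanes     = 4; // V_KLanes     (illustrative)

    // Base K position of slice k2, as in the lambda above:
    // kLoopStart + k2 * (K0 * K1) = kLoopStart + k2 * KLanes * KIterInner.
    constexpr std::size_t slice_k_base(std::size_t k_loop_start, std::size_t k2)
    {
        return k_loop_start + k2 * KLanes * KIterInner;
    }

    // Merge per-slice results v_offsets_k2[k1] into the flat array at index
    // k1 + k2 * KIterInner (K2 outer, K1 inner), mirroring the static_for pair.
    std::array<std::size_t, KIterOuter * KIterInner>
    merge_slices(const std::array<std::array<std::size_t, KIterInner>, KIterOuter>& per_slice)
    {
        std::array<std::size_t, KIterOuter * KIterInner> flat{};
        for(std::size_t k2 = 0; k2 < KIterOuter; ++k2)
            for(std::size_t k1 = 0; k1 < KIterInner; ++k1)
                flat[k1 + k2 * KIterInner] = per_slice[k2][k1];
        return flat;
    }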
@@ -625,18 +763,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
         __builtin_amdgcn_sched_barrier(1);

         auto v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant<false>{});
-        kv_offset_array_transform<statically_indexed_array<index_t, V_KRepeat>,
-                                  decltype(v_coord),
-                                  VPageIndexDim,
-                                  kPageBlockSize,
-                                  kK1,
-                                  V_KRepeat,
-                                  1,
-                                  kKVMemoryLayout,
-                                  false,
-                                  kN0,
-                                  kVectorSize>(
-            page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k);
+        update_v_offsets(number<kK1>{});
         v_dram_window.update_page_idx(v_offsets);

         const auto p = [&]() {
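
The remaining pipeline hunks repeat one mechanical change: each 12-line kv_offset_array_transform call collapses into a single update_v_offsets call at the next K position, so the offsets are refreshed at kLoopStart = kK1, then 2 * kK1, then (2 + i_k1) * kK1 inside the inner loop. A small sketch of that schedule, with an illustrative kK1 that is not taken from the commit:

    #include <cstddef>

    constexpr std::size_t kK1 = 32; // illustrative tile size

    // i-th refresh of v_offsets across the prefetch pipeline: 0, kK1, 2*kK1, ...
    // (the calls at i >= 2 correspond to update_v_offsets(number<(2 + i_k1) * kK1>{})).
    constexpr std::size_t k_loop_start(std::size_t i) { return i * kK1; }

    static_assert(k_loop_start(0) == 0);
    static_assert(k_loop_start(1) == kK1);
    static_assert(k_loop_start(2) == 2 * kK1);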
@@ -766,7 +893,9 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync

             __builtin_amdgcn_sched_barrier(0x7F);
             // store & prefetch next v, after the max reduction
-            if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
+            if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> &&
+                         kKVMemoryLayout ==
+                             BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT)
             {
                 auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
                     Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
@@ -787,8 +916,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
                     get_slice_tile(v_lds_window,
                                    sequence<(LdsSeq.at(number<k0_loops>{})) * kN1, 0>{},
                                    sequence<(LdsSeq.at(number<k0_loops>{}) + 1) * kN1, kK1>{});
-                store_tile(v_lds_window_tmp,
-                           tile_elementwise_in(v_element_func, v_buf)); // store the prefetch
+                const auto v_store_tile = tile_elementwise_in(v_element_func, v_buf);
+                store_tile(v_lds_window_tmp, v_store_tile); // store the prefetch
             }

             if constexpr(k1_loops > 1)
@@ -799,18 +928,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
                                      kK1}); // will have scratch if move this right after load_tile(v_dram)...
                 v_buf = load_tile(
                     v_dram_window, number<-1>{}, bool_constant<false>{}); // load next v_buf
-                kv_offset_array_transform<statically_indexed_array<index_t, V_KRepeat>,
-                                          decltype(v_coord),
-                                          VPageIndexDim,
-                                          kPageBlockSize,
-                                          2 * kK1,
-                                          V_KRepeat,
-                                          1,
-                                          kKVMemoryLayout,
-                                          false,
-                                          kN0,
-                                          kVectorSize>(
-                    page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k);
+                update_v_offsets(number<2 * kK1>{});
                 v_dram_window.update_page_idx(v_offsets);
             }
             __builtin_amdgcn_sched_barrier(0);
@@ -938,18 +1056,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
                 {
                     v_buf = load_tile(
                         v_dram_window, number<-1>{}, bool_constant<false>{}); // load next v_buf
-                    kv_offset_array_transform<statically_indexed_array<index_t, V_KRepeat>,
-                                              decltype(v_coord),
-                                              VPageIndexDim,
-                                              kPageBlockSize,
-                                              (2 + i_k1.value) * kK1,
-                                              V_KRepeat,
-                                              1,
-                                              kKVMemoryLayout,
-                                              false,
-                                              kN0,
-                                              kVectorSize>(
-                        page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k);
+                    update_v_offsets(number<(2 + i_k1.value) * kK1>{});
                     v_dram_window.update_page_idx(v_offsets);
                 }
                 block_sync_lds();
@@ -961,7 +1068,9 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
                         sequence<(LdsSeq.at(number<k0_loops + i_k1>{})) * kN1, 0>{},
                         sequence<(LdsSeq.at(number<k0_loops + i_k1>{}) + 1) * kN1, kK1>{}));

-                if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
+                if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> &&
+                             kKVMemoryLayout ==
+                                 BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT)
                 {
                     auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
                         Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
@@ -4,15 +4,246 @@
 #pragma once

 #include "ck_tile/core.hpp"
+#include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp"

 namespace ck_tile {

 // This pipeline is qkv all located in LDS
-using BlockFmhaBatchPrefillPipelineQRKSVSAsyncDefaultPolicy =
-    BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
-                                        /* AsyncCopy = */ true,
-                                        /* NumPrefetchK = */ 3,
-                                        /* NumPrefetchV = */ 3>;
+struct BlockFmhaBatchPrefillPipelineQRKSVSAsyncDefaultPolicy
+    : BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
+                                          /* AsyncCopy = */ true,
+                                          /* NumPrefetchK = */ 3,
+                                          /* NumPrefetchV = */ 3>
+{
+    using Base = BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
+                                                     /* AsyncCopy = */ true,
+                                                     /* NumPrefetchK = */ 3,
+                                                     /* NumPrefetchV = */ 3>;
+
+    template <typename Problem>
+    CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentV()
+    {
+        if constexpr(Problem::kKVMemoryLayout ==
+                     BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT)
+        {
+            using VDataType = remove_cvref_t<typename Problem::VDataType>;
+            constexpr index_t kDwordx4Bytes = 16;
+            return kDwordx4Bytes / sizeof(VDataType);
+        }
+        else
+        {
+            return Base::template GetAlignmentV<Problem>();
+        }
+    }
+
+    template <typename Problem>
+    CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackV()
+    {
+        if constexpr(Problem::kKVMemoryLayout ==
+                     BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT)
+        {
+            // For VECTORIZED_LAYOUT, kKPack should match GEMM's kKPerThread
+            // to ensure correct LDS access pattern
+            constexpr auto gemm_k_decomp = GetGemmKDecomposition<Problem>();
+            constexpr index_t kKPerThread = gemm_k_decomp.template at<1>();
+            return kKPerThread;
+        }
+        else
+        {
+            return Base::template GetSmemKPackV<Problem>();
+        }
+    }
+
+    template <typename Problem>
+    CK_TILE_HOST_DEVICE static constexpr auto GetSingleSmemElementSpaceSize()
+    {
+        if constexpr(Problem::kKVMemoryLayout ==
+                     BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT)
+        {
+            // For VECTORIZED_LAYOUT, we need to use our GetSmemKPackV for V size calculation
+            constexpr index_t SingleKSize = [&]() {
+                constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
+                constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
+                constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps;
+                constexpr index_t WarpSize = ck_tile::get_warp_size();
+
+                constexpr index_t KPack = Base::template GetSmemKPackK<Problem>();
+                constexpr index_t KVector = Base::template GetAlignmentK<Problem>();
+                constexpr index_t kPad = KPack;
+
+                static_assert(WarpSize * KVector >= kKPerBlock &&
+                              WarpSize * KVector % kKPerBlock == 0);
+                constexpr index_t LanesPerK = kKPerBlock / KVector;
+                constexpr index_t LaneGroups = WarpSize / LanesPerK;
+                constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
+
+                return NumIssues * NumWarps * (WarpSize * KVector + kPad);
+            }();
+
+            constexpr index_t SingleVSize = [&]() {
+                using VDataType = remove_cvref_t<typename Problem::VDataType>;
+                constexpr index_t Banks = get_n_lds_banks();
+                constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType);
+                constexpr index_t kKPack = GetSmemKPackV<Problem>(); // Use our override!
+                static_assert(PixelsPerRow % kKPack == 0);
+                constexpr index_t NPerRow = PixelsPerRow / kKPack;
+                constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
+                constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
+                static_assert(kNPerBlock % NPerRow == 0);
+                static_assert(kKPerBlock % kKPack == 0);
+
+                return (kKPerBlock / kKPack) * (kNPerBlock / NPerRow) * (PixelsPerRow + kKPack);
+            }();
+
+            return max(SingleKSize, SingleVSize);
+        }
+        else
+        {
+            return Base::template GetSingleSmemElementSpaceSize<Problem>();
+        }
+    }
+
+    template <typename Problem>
+    CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsBlockDescriptor()
+    {
+        if constexpr(Problem::kKVMemoryLayout ==
+                     BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT)
+        {
+            using VDataType = remove_cvref_t<typename Problem::VDataType>;
+            constexpr index_t Banks = get_n_lds_banks();
+            constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType);
+            constexpr index_t kKPack = GetSmemKPackV<Problem>();
+            static_assert(PixelsPerRow % kKPack == 0);
+            constexpr index_t NPerRow = PixelsPerRow / kKPack;
+            constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
+            constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
+            static_assert(kNPerBlock % NPerRow == 0);
+            static_assert(kKPerBlock % kKPack == 0);
+
+            constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor(
+                make_tuple(number<Base::NumKVLdsBuffers>{},
+                           number<kKPerBlock / kKPack>{},
+                           number<kNPerBlock / NPerRow>{},
+                           number<NPerRow>{},
+                           number<kKPack>{}),
+                make_tuple(number<GetSingleSmemElementSpaceSize<Problem>()>{},
+                           number<(kNPerBlock / NPerRow) * (PixelsPerRow + kKPack)>{},
+                           number<PixelsPerRow + kKPack>{},
+                           number<kKPack>{},
+                           number<1>{}),
+                number<kKPack>{},
+                number<1>{});
+
+            constexpr auto v_lds_block_desc = transform_tensor_descriptor(
+                v_lds_block_desc_0,
+                make_tuple(make_merge_transform(make_tuple(number<Base::NumKVLdsBuffers>{},
+                                                           number<kNPerBlock / NPerRow>{},
+                                                           number<NPerRow>{})),
+                           make_merge_transform(
+                               make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
+                make_tuple(sequence<0, 2, 3>{}, sequence<1, 4>{}),
+                make_tuple(sequence<0>{}, sequence<1>{}));
+
+            return v_lds_block_desc;
+        }
+        else
+        {
+            return Base::template MakeVLdsBlockDescriptor<Problem>();
+        }
+    }
+
+    // Helper to get GEMM's K decomposition parameters (kABKLane, kKPerThread)
+    template <typename Problem>
+    CK_TILE_HOST_DEVICE static constexpr auto GetGemmKDecomposition()
+    {
+        // Get the KV block GEMM and extract warp gemm's K decomposition
+        constexpr auto gemm = Base::template GetKVBlockGemm<Problem>();
+        using BlockGemm = remove_cvref_t<decltype(gemm)>;
+        constexpr auto config =
+            BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
+        using WG = remove_cvref_t<decltype(config.template at<0>())>;
+
+        // Return kABKLane and kKPerThread from warp gemm
+        return make_tuple(number<WG::WarpGemmAttribute::Impl::kABKLane>{},
+                          number<WG::kKPerThread>{});
+    }
+
+    template <typename Problem>
+    CK_TILE_DEVICE static constexpr auto MakeVDramTileDistribution()
+    {
+        if constexpr(Problem::kKVMemoryLayout ==
+                     BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT)
+        {
+            // For VECTORIZED_LAYOUT, use column-major distribution (K direction vector load)
+            // The K decomposition must match GEMM's BWarpDstrEncoding to ensure correct LDS access
+            constexpr index_t kBlockSize = Problem::kBlockSize;
+            constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
+            constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
+
+            // Get GEMM's K decomposition (kABKLane, kKPerThread)
+            constexpr auto gemm_k_decomp = GetGemmKDecomposition<Problem>();
+            constexpr index_t kABKLane = gemm_k_decomp.template at<0>();
+            constexpr index_t kKPerThread = gemm_k_decomp.template at<1>();
+
+            // K1 = kKPerThread (inner K dimension, matches GEMM's expectation)
+            // K0 = kKPerBlock / K1 (outer K dimension)
+            // But we need K0 to match kABKLane for the per-warp iteration
+            constexpr index_t K1 = kKPerThread;
+            constexpr index_t K0 = kABKLane;
+
+            // Verify K decomposition matches GEMM's BWarpDstrEncoding requirements
+            static_assert(K0 == kABKLane, "K0 must match GEMM's kABKLane for correct LDS access");
+            static_assert(K1 == kKPerThread,
+                          "K1 must match GEMM's kKPerThread for correct LDS access");
+
+            // K0 * K1 may be less than kKPerBlock, so we need outer iteration
+            constexpr index_t KPerIter = K0 * K1;
+            constexpr index_t KOuterIter = kKPerBlock / KPerIter;
+
+            constexpr index_t N2 = get_warp_size() / K0;
+            constexpr index_t N1 = kBlockSize / get_warp_size();
+            static_assert(N2 != 0, "N2 is zero, which will lead to a division by zero error.");
+            static_assert(N1 != 0, "N1 is zero, which will lead to a division by zero error.");
+            constexpr index_t N0 = kNPerBlock / (N2 * N1);
+            static_assert(N0 != 0, "N0 is zero");
+
+            if constexpr(KOuterIter == 1)
+            {
+                // Simple case: K decomposition matches exactly
+                constexpr auto dstr = make_static_tile_distribution(
+                    tile_distribution_encoding<sequence<1>,
+                                               tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
+                                               tuple<sequence<1>, sequence<1, 2>>,
+                                               tuple<sequence<1>, sequence<2, 0>>,
+                                               sequence<2, 1>,
+                                               sequence<1, 0>>{});
+                static_assert(container_reduce(dstr.get_lengths(), std::multiplies<index_t>{}, 1) ==
+                              kNPerBlock * kKPerBlock);
+                return dstr;
+            }
+            else
+            {
+                // Need outer K iteration
+                constexpr index_t K2 = KOuterIter;
+                constexpr auto dstr = make_static_tile_distribution(
+                    tile_distribution_encoding<sequence<1>,
+                                               tuple<sequence<N0, N1, N2>, sequence<K2, K0, K1>>,
+                                               tuple<sequence<1>, sequence<1, 2>>,
+                                               tuple<sequence<1>, sequence<2, 1>>,
+                                               sequence<2, 1, 2>,
+                                               sequence<2, 0, 0>>{});
+                static_assert(container_reduce(dstr.get_lengths(), std::multiplies<index_t>{}, 1) ==
+                              kNPerBlock * kKPerBlock);
+                return dstr;
+            }
+        }
+        else
+        {
+            // For non-VECTORIZED_LAYOUT, use base class implementation
+            return Base::template MakeVDramTileDistribution<Problem>();
+        }
+    }
+};

 } // namespace ck_tile
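
A hedged, compile-time walk through the arithmetic in GetAlignmentV and MakeVDramTileDistribution above, under assumed values (16-bit V elements, kABKLane = 16, kKPerThread = 8, kK1 = 256); none of these numbers come from the commit itself:

    #include <cstdint>

    using VDataType = std::uint16_t; // stand-in for a 16-bit fp type (illustrative)

    // GetAlignmentV for VECTORIZED_LAYOUT: one dwordx4 (16 bytes) per load.
    constexpr int kDwordx4Bytes = 16;
    constexpr int AlignmentV    = kDwordx4Bytes / sizeof(VDataType); // 8 elements

    // MakeVDramTileDistribution: the K decomposition is taken from the GEMM side.
    constexpr int kABKLane    = 16;  // assumed GEMM lanes along K
    constexpr int kKPerThread = 8;   // assumed per-thread K elements
    constexpr int kKPerBlock  = 256; // assumed kK1

    constexpr int KPerIter   = kABKLane * kKPerThread; // K0 * K1 = 128
    constexpr int KOuterIter = kKPerBlock / KPerIter;  // K2 = 2 -> {K2, K0, K1} branch

    static_assert(AlignmentV == 8);
    static_assert(KOuterIter == 2, "outer K iteration needed: 3D decomposition");
    static_assert(KOuterIter * KPerIter == kKPerBlock);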
@@ -121,6 +121,9 @@ struct BlockFmhaBatchPrefillPipelineProblem
     static_assert(!kIsVectorizedLayout || kPageBlockSize % kVectorSize == 0,
                   "kPageBlockSize must be divisible by kVectorSize for vectorized layout");
     static_assert(kIsGroupMode_, "Batch prefill requires group mode");
+
+    static_assert(BlockFmhaShape_::IsVLayoutRowMajor,
+                  "Batch prefill kernel requires RowMajor VLayout");
 };

 template <typename QDataType_,
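
The new static_assert turns an unsupported configuration into a compile-time error. A minimal sketch of the same trait-guard pattern outside the library; the surrounding types here are illustrative stand-ins, not ck_tile's:

    #include <type_traits>

    struct RowMajor {};
    struct ColumnMajor {};

    template <typename VLayout>
    struct ShapeSketch // stands in for BlockFmhaShape_ (hypothetical)
    {
        static constexpr bool IsVLayoutRowMajor = std::is_same_v<VLayout, RowMajor>;
    };

    // Mirrors the guard added above: instantiating with ColumnMajor would fail.
    static_assert(ShapeSketch<RowMajor>::IsVLayoutRowMajor,
                  "Batch prefill kernel requires RowMajor VLayout");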