Optimize batch prefill kernel performance for VECTORIZED_LAYOUT KV cache (#3657)

- Add multi-dimensional page index support (YsGatherDims) in tile_scatter_gather
- Add is_gather_dim() and get_gather_index() for multi-dim page lookup
- Override MakeVDramTileDistribution() for VECTORIZED_LAYOUT to match
  GEMM's BWarpDstrEncoding (K decomposition: {K2, K0, K1})
- Add GetGemmKDecomposition() to retrieve kABKLane and kKPerThread
- Add static_assert for RowMajor VLayout requirement in batch prefill

Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
Commit: e3556fed04 (parent: 83b58bb0c3)
Author: Jeff Huang
Date: 2026-01-29 07:18:41 +08:00
Committed by: GitHub
4 changed files with 553 additions and 84 deletions

View File

@@ -26,17 +26,26 @@ namespace ck_tile {
*
* @tparam BottomTensorView_ Class describing & holding device tensor memory.
* @tparam WindowLengths_ Spatial sizes of windowed view on tensor.
* @tparam StaticTileDistribution_ Thread distribution (mapping) into Tile dimensions
* @tparam NumCoord TBD
* @tparam StaticTileDistribution_ Thread distribution (mapping) into Tile dimensions.
* @tparam StaticPageIndexArray_ Array type holding page indices for scatter/gather.
* @tparam StaticValidArray_ Array type holding validity flags (nullptr_t if unused).
* @tparam HsGatherDim H-space dimension index used for gather lookup (default: 0).
* @tparam NumCoord Number of pre-computed coordinates for pipelining (default: 1).
* @tparam YsGatherDims Sequence of Y-space dimension indices used for page lookup.
* For single dimension: sequence<0> (default).
* For multiple dimensions: sequence<dim0, dim1, ...> where
* the combined index is computed as:
* idx[dim0] + idx[dim1] * len[dim0] + idx[dim2] * len[dim0] *
* len[dim1] + ...
*/
template <typename BottomTensorView_,
typename WindowLengths_,
typename StaticTileDistribution_,
typename StaticPageIndexArray_,
typename StaticValidArray_,
index_t HsGatherDim = 0,
index_t NumCoord = 1,
index_t YsGatherDim = 0>
index_t HsGatherDim = 0,
index_t NumCoord = 1,
typename YsGatherDims = sequence<0>>
struct tile_scatter_gather
{
using BottomTensorView = remove_reference_t<BottomTensorView_>;
@@ -77,6 +86,75 @@ struct tile_scatter_gather
using BottomTensorCoord =
decltype(make_tensor_coordinate(BottomTensorDesc{}, BottomTensorIndex{}));
/**
 * @brief Query whether Y-space dimension @p dim participates in page gathering.
 *
 * A dimension participates when it appears in the YsGatherDims template
 * parameter. Such dimensions are zeroed out when building forward_step_scatter,
 * because the page-offset lookup already accounts for their contribution to
 * the final address.
 *
 * @param dim Y-space dimension index to test
 * @return true when @p dim is listed in YsGatherDims, false otherwise
 */
CK_TILE_DEVICE static constexpr bool is_gather_dim(index_t dim)
{
    return sequence_any_of(YsGatherDims{}, [dim](auto gather_dim) { return gather_dim == dim; });
}
/**
 * @brief Compute the linearized gather index from Y-space indices for page lookup.
 *
 * This function converts multi-dimensional Y-space indices (specified by YsGatherDims)
 * into a single linearized index used to look up the page offset in page_idx_ array.
 *
 * For single gather dimension (YsGatherDims::size() == 1):
 *   Simply returns idx_ys_start[YsGatherDims::at(0)]
 *
 * For multiple gather dimensions (e.g., YsGatherDims = sequence<0, 2>):
 *   Computes: idx[dim0] + idx[dim1] * len[dim0] + idx[dim2] * len[dim0] * len[dim1] + ...
 *   This is row-major linearization where earlier dimensions are inner (faster-varying).
 *
 * @tparam YsIndex Type of the Y-space index tuple/array
 * @param idx_ys_start Current Y-space indices from space-filling curve iteration
 * @return Linearized index for page_idx_ array lookup
 */
template <typename YsIndex>
CK_TILE_DEVICE static constexpr auto get_gather_index(const YsIndex& idx_ys_start)
{
    // TODO: Consider making ys_lengths_ part of public API or adding accessor
    static_assert(sizeof(TileDstr::DstrEncode::detail::ys_lengths_) > 0,
                  "Relies on internal detail::ys_lengths_");
    constexpr index_t num_gather_dims = YsGatherDims::size();
    if constexpr(num_gather_dims == 1)
    {
        // Fast path: one gather dimension, no linearization arithmetic needed.
        return idx_ys_start[number<YsGatherDims::at(0)>{}];
    }
    else
    {
        // Recursive lambda to compute index as a compile-time number
        // Uses row-major linearization: idx[0] + idx[1] * len[0] + idx[2] * len[0] * len[1] +
        // ...
        // The lambda receives itself as `self` because a lambda cannot name
        // itself for recursion; `i_constant` carries the current position in
        // YsGatherDims as a compile-time constant.
        auto recurse = [&](auto self, auto i_constant) {
            constexpr index_t i = decltype(i_constant)::value;
            constexpr index_t dim = YsGatherDims::at(i);
            auto current_val = idx_ys_start[number<dim>{}];
            if constexpr(i + 1 < num_gather_dims)
            {
                // Horner-style evaluation: current dimension's length scales
                // the combined contribution of all later (outer) dimensions,
                // i.e. idx[d0] + len[d0] * (idx[d1] + len[d1] * (idx[d2] + ...)).
                constexpr index_t len = TileDstr::DstrEncode::detail::ys_lengths_[dim];
                return current_val + self(self, number<i + 1>{}) * number<len>{};
            }
            else
            {
                // Last gather dimension terminates the recursion.
                return current_val;
            }
        };
        return recurse(recurse, number<0>{});
    }
}
struct load_store_traits
{
private:
@@ -375,7 +453,7 @@ struct tile_scatter_gather
// data index [y0, y1, ...]
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
constexpr auto idx_gather = idx_ys_start[number<YsGatherDim>{}];
constexpr auto idx_gather = get_gather_index(idx_ys_start);
const auto page_offset = page_idx_[idx_gather];
// read from bottom tensor
@@ -427,7 +505,7 @@ struct tile_scatter_gather
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
constexpr auto forward_step_scatter = generate_tuple(
[&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
[&](auto i) { return is_gather_dim(i) ? 0 : idx_diff_ys[i]; },
number<NDimY>{});
constexpr auto idx_diff_ps_ys = container_concat(
@@ -485,7 +563,7 @@ struct tile_scatter_gather
// data index [y0, y1, ...]
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
constexpr auto idx_gather = idx_ys_start[number<YsGatherDim>{}];
constexpr auto idx_gather = get_gather_index(idx_ys_start);
const auto page_offset = page_idx_[idx_gather];
// merge page_offset into bottom_coord
@@ -513,7 +591,7 @@ struct tile_scatter_gather
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
constexpr auto forward_step_scatter = generate_tuple(
[&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
[&](auto i) { return is_gather_dim(i) ? 0 : idx_diff_ys[i]; },
number<NDimY>{});
constexpr auto idx_diff_ps_ys = container_concat(
@@ -598,7 +676,7 @@ struct tile_scatter_gather
}();
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
constexpr auto idx_gather = idx_ys_start[number<YsGatherDim>{}];
constexpr auto idx_gather = get_gather_index(idx_ys_start);
const auto page_offset = page_idx_[idx_gather];
// read from bottom tensor
@@ -624,7 +702,7 @@ struct tile_scatter_gather
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
constexpr auto forward_step_scatter = generate_tuple(
[&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
[&](auto i) { return is_gather_dim(i) ? 0 : idx_diff_ys[i]; },
number<NDimY>{});
constexpr auto idx_diff_ps_ys = container_concat(
@@ -718,7 +796,7 @@ struct tile_scatter_gather
}();
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
constexpr auto idx_gather = idx_ys_start[number<YsGatherDim>{}];
constexpr auto idx_gather = get_gather_index(idx_ys_start);
const auto page_offset = page_idx_[idx_gather];
auto mixed_bottom_thread_coord = bottom_tensor_thread_coord;
@@ -748,7 +826,7 @@ struct tile_scatter_gather
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
constexpr auto forward_step_scatter = generate_tuple(
[&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
[&](auto i) { return is_gather_dim(i) ? 0 : idx_diff_ys[i]; },
number<NDimY>{});
constexpr auto idx_diff_ps_ys = container_concat(
@@ -791,7 +869,7 @@ struct tile_scatter_gather
// data index [y0, y1, ...]
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
constexpr auto idx_gather = idx_ys_start[number<0>{}];
constexpr auto idx_gather = get_gather_index(idx_ys_start);
const auto page_offset = page_idx_[idx_gather];
// read from distributed tensor
@@ -837,7 +915,7 @@ struct tile_scatter_gather
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
constexpr auto forward_step_scatter = generate_tuple(
[&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
[&](auto i) { return is_gather_dim(i) ? 0 : idx_diff_ys[i]; },
number<NDimY>{});
constexpr auto idx_diff_ps_ys = container_concat(
@@ -874,11 +952,11 @@ struct tile_scatter_gather
// data index [y0, y1, ...]
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
constexpr auto idx_gather = idx_ys_start[number<0>{}];
constexpr auto idx_gather = get_gather_index(idx_ys_start);
const auto page_offset = page_idx_[idx_gather];
// printf("idx_ys_start[0], idx_ys_start[1](%d, %d) \n",
// idx_ys_start[number<0>{}]+0, idx_ys_start[number<1>{}]+0);
// get_gather_index(idx_ys_start)+0, idx_ys_start[number<1>{}]+0);
// read from distributed tensor
// vector_type_t vec;
@@ -928,7 +1006,7 @@ struct tile_scatter_gather
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
constexpr auto forward_step_scatter = generate_tuple(
[&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
[&](auto i) { return is_gather_dim(i) ? 0 : idx_diff_ys[i]; },
number<NDimY>{});
constexpr auto idx_diff_ps_ys = container_concat(
@@ -1076,6 +1154,53 @@ struct tile_scatter_gather
};
// TODO: use strategy
/**
 * @brief Create a tile_scatter_gather whose page lookup spans several Y-space dimensions.
 *
 * Pass a sequence<YsGatherDims...> when the tile distribution splits the paged
 * dimension across more than one Y-space dimension (for example the
 * VECTORIZED_LAYOUT V tensor with the K decomposition {K2, K0, K1}, where both
 * Y0 and Y2 contribute to the page index).
 *
 * @tparam HsGatherDim  H-space dimension used for gathering
 * @tparam NumCoord     Number of pre-computed coordinates
 * @tparam YsGatherDims Y-space dimensions linearized into the page-lookup index
 *
 * @param tensor_view       Device tensor view backing the window
 * @param window_lengths    Static per-dimension window sizes
 * @param origin            Window origin on the bottom tensor
 * @param tile_distribution Thread-to-tile mapping distribution
 * @param page_idx          Page offsets consumed by the scatter/gather
 * @return A tile_scatter_gather configured for multi-dimensional page lookup
 */
template <typename TensorView_,
          typename WindowLengths_,
          typename StaticTileDistribution_,
          typename StaticPageIndexArray_,
          index_t HsGatherDim,
          index_t NumCoord,
          index_t... YsGatherDims>
CK_TILE_DEVICE constexpr auto
make_tile_scatter_gather(const TensorView_& tensor_view,
                         const WindowLengths_& window_lengths,
                         const multi_index<TensorView_::get_num_of_dimension()>& origin,
                         const StaticTileDistribution_& tile_distribution,
                         const StaticPageIndexArray_& page_idx,
                         number<HsGatherDim>,
                         number<NumCoord>,
                         sequence<YsGatherDims...>)
{
    // No validity array is supplied by this overload, hence std::nullptr_t.
    using result_t = tile_scatter_gather<remove_cvref_t<TensorView_>,
                                         remove_cvref_t<WindowLengths_>,
                                         remove_cvref_t<StaticTileDistribution_>,
                                         remove_cvref_t<StaticPageIndexArray_>,
                                         std::nullptr_t,
                                         HsGatherDim,
                                         NumCoord,
                                         sequence<YsGatherDims...>>;
    return result_t{tensor_view, window_lengths, origin, tile_distribution, page_idx, nullptr};
}
// Legacy overload (compatible with original API)
template <typename TensorView_,
typename WindowLengths_,
typename StaticTileDistribution_,
@@ -1087,7 +1212,7 @@ make_tile_scatter_gather(const TensorView_& tensor_view,
const WindowLengths_& window_lengths,
const multi_index<TensorView_::get_num_of_dimension()>& origin,
const StaticTileDistribution_& tile_distribution,
const StaticPageIndexArray_& page_idx, // perbytes
const StaticPageIndexArray_& page_idx,
number<HsGatherDim> = {},
number<NumCoord> = {})
{
@@ -1097,7 +1222,8 @@ make_tile_scatter_gather(const TensorView_& tensor_view,
remove_cvref_t<StaticPageIndexArray_>,
std::nullptr_t,
HsGatherDim,
NumCoord>{
NumCoord,
sequence<0>>{
tensor_view, window_lengths, origin, tile_distribution, page_idx, nullptr};
}

View File

@@ -533,32 +533,170 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0)>(
randval_dram_block_window_tmp, seqlen_k_start);
auto v_dist = Policy::template MakeVDramTileDistribution<Problem>();
auto v_coord = v_dist.calculate_index();
const auto VPageIndexDim = I1;
using VDstrEncode = typename decltype(v_dist)::DstrEncode;
constexpr index_t V_KRepeat = VDstrEncode::hs_lengthss_[I1][I3];
statically_indexed_array<index_t, V_KRepeat> v_offsets;
kv_offset_array_transform<statically_indexed_array<index_t, V_KRepeat>,
decltype(v_coord),
VPageIndexDim,
kPageBlockSize,
0,
V_KRepeat,
1,
kKVMemoryLayout,
false,
kN0,
kVectorSize>(
page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k);
auto v_dist = Policy::template MakeVDramTileDistribution<Problem>();
auto v_coord = v_dist.calculate_index();
using VDstrEncode = typename decltype(v_dist)::DstrEncode;
// V tensor K-dimension decomposition for page index computation
// ============================================================
// The K dimension (seqlen_k) in V distribution is decomposed into multiple sub-dimensions.
// This decomposition determines how threads iterate over the K dimension and how page
// indices are computed for paged KV cache.
//
// The decomposition pattern differs by memory layout:
//
// VECTORIZED_LAYOUT (ColumnMajor, custom distribution):
// 3D decomposition: K = K2 × K0 × K1
// - K2 (V_KIterOuter): Outer iteration count
// - K0 (V_KLanes): Lanes for K dimension (matches GEMM kABKLane)
// - K1 (V_KIterInner): Vector load size (matches GEMM kKPerThread)
// - hs_lengthss_[I1] = {K2, K0, K1}, size = 3 (or {K0, K1} size = 2 if no outer iter)
//
// LINEAR_LAYOUT ColumnMajor (base class distribution):
// 2D decomposition: K = K0 × K1
// - K0: Lanes for K dimension (may not match GEMM kABKLane)
// - K1: Vector load size
// - hs_lengthss_[I1] = {K0, K1}, size = 2
//
// LINEAR_LAYOUT RowMajor (base class distribution):
// 4D decomposition: K = K0 × K1 × K2 × K3 (uses shuffle_tile for GEMM alignment)
// 3D decomposition: K = K0 × K1 × K2 (fallback case)
// - Page lookup uses Y-space's last dimension only (inner iteration)
//
// V_PageIdxRepeat = total number of page lookups per thread = V_KIterOuter × V_KIterInner
constexpr index_t V_KIterInner = VDstrEncode::hs_lengthss_[I1].back();
// Compute V_KIterOuter and V_KLanes based on memory layout and K decomposition
constexpr index_t V_KIterOuter = [] {
if constexpr(kKVMemoryLayout ==
BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT)
{
// VECTORIZED_LAYOUT: 3D decomposition {K2, K0, K1} when outer iteration is needed
if constexpr(VDstrEncode::hs_lengthss_[I1].size() == 3)
return static_cast<index_t>(VDstrEncode::hs_lengthss_[I1][I0]);
else
return index_t{1};
}
else
{
// LINEAR_LAYOUT: No outer iteration for page lookup
// RowMajor uses shuffle_tile, ColumnMajor has simple 2D decomposition
// Both cases use single-dimension Y-space page lookup
return index_t{1};
}
}();
constexpr index_t V_KLanes = [] {
if constexpr(kKVMemoryLayout ==
BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT)
{
// VECTORIZED_LAYOUT: K0 is the lanes dimension
if constexpr(V_KIterOuter > 1)
return static_cast<index_t>(VDstrEncode::hs_lengthss_[I1][I1]);
else
return static_cast<index_t>(VDstrEncode::hs_lengthss_[I1][I0]);
}
else
{
// LINEAR_LAYOUT: First dimension is K0 (lanes)
return static_cast<index_t>(VDstrEncode::hs_lengthss_[I1][I0]);
}
}();
// This affects page offset computation - need to track offsets for each (k2, k1)
// combination
constexpr index_t V_PageIdxRepeat = V_KIterInner * V_KIterOuter;
// VPageIndexYDims: Y-space dimension indices that participate in page index computation
// ================================================================================
// In tile_scatter_gather, the gather index is computed from Y-space coordinates.
// This sequence specifies which Y dimensions should be linearized to form the page lookup
// index.
//
// VECTORIZED_LAYOUT with outer iteration: sequence<Y_K1, Y_K2>
// - Both K1 and K2 are in Y-space (thread iteration dimensions)
// - gather_index = y_k1 + y_k2 * len(Y_K1) (linearized 2D -> 1D)
//
// VECTORIZED_LAYOUT without outer iteration / LINEAR_LAYOUT: sequence<Y_K1>
// - Only the innermost K dimension is used for page lookup (single dimension)
//
constexpr auto VPageIndexYDims = []() {
// K1Minor is always the last element index in hs_lengthss_[I1]
constexpr index_t K1Minor = VDstrEncode::hs_lengthss_[I1].size() - 1;
constexpr index_t Y_K1 = VDstrEncode::detail::rhs_major_minor_to_ys_[2][K1Minor];
if constexpr(kKVMemoryLayout ==
BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT &&
V_KIterOuter > 1)
{
// VECTORIZED_LAYOUT with outer iteration: need 2D page lookup
constexpr index_t Y_K2 = VDstrEncode::detail::rhs_major_minor_to_ys_[2][I0];
return sequence<Y_K1, Y_K2>{};
}
else
{
// LINEAR_LAYOUT or VECTORIZED_LAYOUT without outer iteration: 1D page lookup
return sequence<Y_K1>{};
}
}();
static_assert(decltype(VPageIndexYDims)::at(0) < VDstrEncode::NDimY,
"V page-index Y dim must be valid");
statically_indexed_array<index_t, V_PageIdxRepeat> v_offsets;
auto update_v_offsets = [&](auto k_loop_start) {
constexpr index_t kLoopStart = decltype(k_loop_start)::value;
// For 3D K decomposition (K2, K0, K1), compute offsets for each K2 slice
// The global K offset for (k2, k1) is: kLoopStart + k2 * (K0 * K1) + k1
// We iterate K2 outer, K1 inner, and merge into 1D v_offsets array
if constexpr(V_KIterOuter > 1)
{
static_for<0, V_KIterOuter, 1>{}([&](auto k2) {
statically_indexed_array<index_t, V_KIterInner> v_offsets_k2;
kv_offset_array_transform<statically_indexed_array<index_t, V_KIterInner>,
decltype(v_coord),
I1,
kPageBlockSize,
kLoopStart + k2.value * V_KLanes * V_KIterInner,
V_KIterInner,
1,
kKVMemoryLayout,
false,
kN0,
kVectorSize>(
page_idx, stride_v, page_stride_v, v_coord, v_offsets_k2, current_seq_k);
static_for<0, V_KIterInner, 1>{}([&](auto k1) {
constexpr auto idx = number<k1.value + k2.value * V_KIterInner>{};
v_offsets[idx] = v_offsets_k2[k1];
});
});
}
else
{
kv_offset_array_transform<statically_indexed_array<index_t, V_KIterInner>,
decltype(v_coord),
I1,
kPageBlockSize,
kLoopStart,
V_KIterInner,
1,
kKVMemoryLayout,
false,
kN0,
kVectorSize>(
page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k);
}
};
update_v_offsets(number<0>{});
auto v_dram_window =
make_tile_scatter_gather(v_dram_block_window_tmp.get_bottom_tensor_view(),
v_dram_block_window_tmp.get_window_lengths(),
{0, seqlen_k_start}, // TODO: hdim split?
v_dist,
v_offsets,
VPageIndexDim);
number<1>{}, // HsGatherDim
number<1>{}, // NumCoord
VPageIndexYDims);
// prefetch K tile
async_load_tile_raw(
@@ -625,18 +763,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
__builtin_amdgcn_sched_barrier(1);
auto v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant<false>{});
kv_offset_array_transform<statically_indexed_array<index_t, V_KRepeat>,
decltype(v_coord),
VPageIndexDim,
kPageBlockSize,
kK1,
V_KRepeat,
1,
kKVMemoryLayout,
false,
kN0,
kVectorSize>(
page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k);
update_v_offsets(number<kK1>{});
v_dram_window.update_page_idx(v_offsets);
const auto p = [&]() {
@@ -766,7 +893,9 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
__builtin_amdgcn_sched_barrier(0x7F);
// store & prefetch next v, after the max reduction
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> &&
kKVMemoryLayout ==
BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT)
{
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
@@ -787,8 +916,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
get_slice_tile(v_lds_window,
sequence<(LdsSeq.at(number<k0_loops>{})) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops>{}) + 1) * kN1, kK1>{});
store_tile(v_lds_window_tmp,
tile_elementwise_in(v_element_func, v_buf)); // store the prefetch
const auto v_store_tile = tile_elementwise_in(v_element_func, v_buf);
store_tile(v_lds_window_tmp, v_store_tile); // store the prefetch
}
if constexpr(k1_loops > 1)
@@ -799,18 +928,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
kK1}); // will have scratch if move this right after load_tile(v_dram)...
v_buf = load_tile(
v_dram_window, number<-1>{}, bool_constant<false>{}); // load next v_buf
kv_offset_array_transform<statically_indexed_array<index_t, V_KRepeat>,
decltype(v_coord),
VPageIndexDim,
kPageBlockSize,
2 * kK1,
V_KRepeat,
1,
kKVMemoryLayout,
false,
kN0,
kVectorSize>(
page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k);
update_v_offsets(number<2 * kK1>{});
v_dram_window.update_page_idx(v_offsets);
}
__builtin_amdgcn_sched_barrier(0);
@@ -938,18 +1056,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
{
v_buf = load_tile(
v_dram_window, number<-1>{}, bool_constant<false>{}); // load next v_buf
kv_offset_array_transform<statically_indexed_array<index_t, V_KRepeat>,
decltype(v_coord),
VPageIndexDim,
kPageBlockSize,
(2 + i_k1.value) * kK1,
V_KRepeat,
1,
kKVMemoryLayout,
false,
kN0,
kVectorSize>(
page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k);
update_v_offsets(number<(2 + i_k1.value) * kK1>{});
v_dram_window.update_page_idx(v_offsets);
}
block_sync_lds();
@@ -961,7 +1068,9 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
sequence<(LdsSeq.at(number<k0_loops + i_k1>{})) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops + i_k1>{}) + 1) * kN1, kK1>{}));
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> &&
kKVMemoryLayout ==
BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT)
{
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());

View File

@@ -4,15 +4,246 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp"
namespace ck_tile {
// This pipeline is qkv all located in LDS
using BlockFmhaBatchPrefillPipelineQRKSVSAsyncDefaultPolicy =
BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
/* AsyncCopy = */ true,
/* NumPrefetchK = */ 3,
/* NumPrefetchV = */ 3>;
/// Default policy for the batch-prefill QR/KS/VS async pipeline.
/// Extends the generic custom policy with VECTORIZED_LAYOUT-specific overrides
/// for V's DRAM alignment, LDS K-packing, LDS descriptor, and DRAM tile
/// distribution; all other layouts fall through to the Base implementation.
struct BlockFmhaBatchPrefillPipelineQRKSVSAsyncDefaultPolicy
    : BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
                                          /* AsyncCopy = */ true,
                                          /* NumPrefetchK = */ 3,
                                          /* NumPrefetchV = */ 3>
{
    // Alias for delegating non-VECTORIZED_LAYOUT cases back to the base policy.
    using Base = BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
                                                     /* AsyncCopy = */ true,
                                                     /* NumPrefetchK = */ 3,
                                                     /* NumPrefetchV = */ 3>;
    /// @brief Vector-load alignment (in elements) for the V tensor.
    /// For VECTORIZED_LAYOUT, uses the widest 16-byte (dwordx4) load;
    /// otherwise defers to the base policy.
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentV()
    {
        if constexpr(Problem::kKVMemoryLayout ==
                     BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT)
        {
            using VDataType = remove_cvref_t<typename Problem::VDataType>;
            constexpr index_t kDwordx4Bytes = 16;
            return kDwordx4Bytes / sizeof(VDataType);
        }
        else
        {
            return Base::template GetAlignmentV<Problem>();
        }
    }
    /// @brief LDS K-pack size for V. For VECTORIZED_LAYOUT this must equal the
    /// GEMM's kKPerThread so the LDS write pattern matches the GEMM read pattern.
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackV()
    {
        if constexpr(Problem::kKVMemoryLayout ==
                     BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT)
        {
            // For VECTORIZED_LAYOUT, kKPack should match GEMM's kKPerThread
            // to ensure correct LDS access pattern
            constexpr auto gemm_k_decomp = GetGemmKDecomposition<Problem>();
            constexpr index_t kKPerThread = gemm_k_decomp.template at<1>();
            return kKPerThread;
        }
        else
        {
            return Base::template GetSmemKPackV<Problem>();
        }
    }
    /// @brief Size (in elements) of a single K/V LDS buffer: max of the K-side
    /// and V-side requirements. The V-side sizing must use this policy's
    /// GetSmemKPackV override rather than the base's.
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetSingleSmemElementSpaceSize()
    {
        if constexpr(Problem::kKVMemoryLayout ==
                     BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT)
        {
            // For VECTORIZED_LAYOUT, we need to use our GetSmemKPackV for V size calculation
            // NOTE(review): the K-side sizing below presumably duplicates the
            // base policy's computation — keep in sync if Base changes.
            constexpr index_t SingleKSize = [&]() {
                constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
                constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
                constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps;
                constexpr index_t WarpSize = ck_tile::get_warp_size();
                constexpr index_t KPack = Base::template GetSmemKPackK<Problem>();
                constexpr index_t KVector = Base::template GetAlignmentK<Problem>();
                constexpr index_t kPad = KPack;
                static_assert(WarpSize * KVector >= kKPerBlock &&
                              WarpSize * KVector % kKPerBlock == 0);
                constexpr index_t LanesPerK = kKPerBlock / KVector;
                constexpr index_t LaneGroups = WarpSize / LanesPerK;
                constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
                return NumIssues * NumWarps * (WarpSize * KVector + kPad);
            }();
            constexpr index_t SingleVSize = [&]() {
                using VDataType = remove_cvref_t<typename Problem::VDataType>;
                constexpr index_t Banks = get_n_lds_banks();
                constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType);
                constexpr index_t kKPack = GetSmemKPackV<Problem>(); // Use our override!
                static_assert(PixelsPerRow % kKPack == 0);
                constexpr index_t NPerRow = PixelsPerRow / kKPack;
                constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
                constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
                static_assert(kNPerBlock % NPerRow == 0);
                static_assert(kKPerBlock % kKPack == 0);
                // (PixelsPerRow + kKPack) adds one pack of padding per row.
                return (kKPerBlock / kKPack) * (kNPerBlock / NPerRow) * (PixelsPerRow + kKPack);
            }();
            return max(SingleKSize, SingleVSize);
        }
        else
        {
            return Base::template GetSingleSmemElementSpaceSize<Problem>();
        }
    }
    /// @brief LDS block descriptor for V under VECTORIZED_LAYOUT: a padded
    /// 5-D naive descriptor (buffers, K-chunks, N-rows, N-per-row, K-pack)
    /// merged down to a 2-D (N, K) view. Other layouts use the base policy.
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsBlockDescriptor()
    {
        if constexpr(Problem::kKVMemoryLayout ==
                     BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT)
        {
            using VDataType = remove_cvref_t<typename Problem::VDataType>;
            constexpr index_t Banks = get_n_lds_banks();
            constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType);
            constexpr index_t kKPack = GetSmemKPackV<Problem>();
            static_assert(PixelsPerRow % kKPack == 0);
            constexpr index_t NPerRow = PixelsPerRow / kKPack;
            constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
            constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
            static_assert(kNPerBlock % NPerRow == 0);
            static_assert(kKPerBlock % kKPack == 0);
            // Row stride is PixelsPerRow + kKPack: one extra pack of padding
            // per row (same padding scheme as GetSingleSmemElementSpaceSize).
            constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor(
                make_tuple(number<Base::NumKVLdsBuffers>{},
                           number<kKPerBlock / kKPack>{},
                           number<kNPerBlock / NPerRow>{},
                           number<NPerRow>{},
                           number<kKPack>{}),
                make_tuple(number<GetSingleSmemElementSpaceSize<Problem>()>{},
                           number<(kNPerBlock / NPerRow) * (PixelsPerRow + kKPack)>{},
                           number<PixelsPerRow + kKPack>{},
                           number<kKPack>{},
                           number<1>{}),
                number<kKPack>{},
                number<1>{});
            // Merge (buffers, N-rows, N-per-row) -> N and (K-chunks, K-pack) -> K.
            constexpr auto v_lds_block_desc = transform_tensor_descriptor(
                v_lds_block_desc_0,
                make_tuple(make_merge_transform(make_tuple(number<Base::NumKVLdsBuffers>{},
                                                           number<kNPerBlock / NPerRow>{},
                                                           number<NPerRow>{})),
                           make_merge_transform(
                               make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
                make_tuple(sequence<0, 2, 3>{}, sequence<1, 4>{}),
                make_tuple(sequence<0>{}, sequence<1>{}));
            return v_lds_block_desc;
        }
        else
        {
            return Base::template MakeVLdsBlockDescriptor<Problem>();
        }
    }
    // Helper to get GEMM's K decomposition parameters (kABKLane, kKPerThread)
    /// @return tuple of (kABKLane, kKPerThread) as compile-time numbers.
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetGemmKDecomposition()
    {
        // Get the KV block GEMM and extract warp gemm's K decomposition
        constexpr auto gemm = Base::template GetKVBlockGemm<Problem>();
        using BlockGemm = remove_cvref_t<decltype(gemm)>;
        constexpr auto config =
            BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
        using WG = remove_cvref_t<decltype(config.template at<0>())>;
        // Return kABKLane and kKPerThread from warp gemm
        return make_tuple(number<WG::WarpGemmAttribute::Impl::kABKLane>{},
                          number<WG::kKPerThread>{});
    }
    /// @brief DRAM tile distribution for loading V. For VECTORIZED_LAYOUT the
    /// K decomposition is forced to match the GEMM's BWarpDstrEncoding
    /// (K = [K2 x] K0 x K1 with K0 = kABKLane, K1 = kKPerThread) so the
    /// subsequent LDS access pattern lines up with the GEMM reads.
    template <typename Problem>
    CK_TILE_DEVICE static constexpr auto MakeVDramTileDistribution()
    {
        if constexpr(Problem::kKVMemoryLayout ==
                     BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT)
        {
            // For VECTORIZED_LAYOUT, use column-major distribution (K direction vector load)
            // The K decomposition must match GEMM's BWarpDstrEncoding to ensure correct LDS access
            constexpr index_t kBlockSize = Problem::kBlockSize;
            constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
            constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
            // Get GEMM's K decomposition (kABKLane, kKPerThread)
            constexpr auto gemm_k_decomp = GetGemmKDecomposition<Problem>();
            constexpr index_t kABKLane = gemm_k_decomp.template at<0>();
            constexpr index_t kKPerThread = gemm_k_decomp.template at<1>();
            // K1 = kKPerThread (inner K dimension, matches GEMM's expectation)
            // K0 = kKPerBlock / K1 (outer K dimension)
            // But we need K0 to match kABKLane for the per-warp iteration
            constexpr index_t K1 = kKPerThread;
            constexpr index_t K0 = kABKLane;
            // Verify K decomposition matches GEMM's BWarpDstrEncoding requirements
            // NOTE(review): K0/K1 are defined directly from kABKLane/kKPerThread
            // above, so these asserts are tautological; they stand as
            // documentation of the invariant rather than an actual check.
            static_assert(K0 == kABKLane, "K0 must match GEMM's kABKLane for correct LDS access");
            static_assert(K1 == kKPerThread,
                          "K1 must match GEMM's kKPerThread for correct LDS access");
            // K0 * K1 may be less than kKPerBlock, so we need outer iteration
            // NOTE(review): assumes kKPerBlock % (K0 * K1) == 0 — the later
            // product static_assert catches a mismatch indirectly; TODO confirm
            // an explicit divisibility assert isn't warranted.
            constexpr index_t KPerIter = K0 * K1;
            constexpr index_t KOuterIter = kKPerBlock / KPerIter;
            constexpr index_t N2 = get_warp_size() / K0;
            constexpr index_t N1 = kBlockSize / get_warp_size();
            static_assert(N2 != 0, "N2 is zero, which will lead to a division by zero error.");
            static_assert(N1 != 0, "N1 is zero, which will lead to a division by zero error.");
            constexpr index_t N0 = kNPerBlock / (N2 * N1);
            static_assert(N0 != 0, "N0 is zero");
            if constexpr(KOuterIter == 1)
            {
                // Simple case: K decomposition matches exactly
                constexpr auto dstr = make_static_tile_distribution(
                    tile_distribution_encoding<sequence<1>,
                                               tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
                                               tuple<sequence<1>, sequence<1, 2>>,
                                               tuple<sequence<1>, sequence<2, 0>>,
                                               sequence<2, 1>,
                                               sequence<1, 0>>{});
                // Sanity check: the distribution covers the whole N x K tile.
                static_assert(container_reduce(dstr.get_lengths(), std::multiplies<index_t>{}, 1) ==
                              kNPerBlock * kKPerBlock);
                return dstr;
            }
            else
            {
                // Need outer K iteration
                constexpr index_t K2 = KOuterIter;
                constexpr auto dstr = make_static_tile_distribution(
                    tile_distribution_encoding<sequence<1>,
                                               tuple<sequence<N0, N1, N2>, sequence<K2, K0, K1>>,
                                               tuple<sequence<1>, sequence<1, 2>>,
                                               tuple<sequence<1>, sequence<2, 1>>,
                                               sequence<2, 1, 2>,
                                               sequence<2, 0, 0>>{});
                // Sanity check: the distribution covers the whole N x K tile.
                static_assert(container_reduce(dstr.get_lengths(), std::multiplies<index_t>{}, 1) ==
                              kNPerBlock * kKPerBlock);
                return dstr;
            }
        }
        else
        {
            // For non-VECTORIZED_LAYOUT, use base class implementation
            return Base::template MakeVDramTileDistribution<Problem>();
        }
    }
};
} // namespace ck_tile

View File

@@ -121,6 +121,9 @@ struct BlockFmhaBatchPrefillPipelineProblem
static_assert(!kIsVectorizedLayout || kPageBlockSize % kVectorSize == 0,
"kPageBlockSize must be divisible by kVectorSize for vectorized layout");
static_assert(kIsGroupMode_, "Batch prefill requires group mode");
static_assert(BlockFmhaShape_::IsVLayoutRowMajor,
"Batch prefill kernel requires RowMajor VLayout");
};
template <typename QDataType_,