Optimize batch prefill kernel performance for VECTORIZED_LAYOUT KV cache (#3657)

- Add multi-dimensional page index support (YsGatherDims) in tile_scatter_gather
- Add is_gather_dim() and get_gather_index() for multi-dim page lookup
- Override MakeVDramTileDistribution() for VECTORIZED_LAYOUT to match
  GEMM's BWarpDstrEncoding (K decomposition: {K2, K0, K1})
- Add GetGemmKDecomposition() to retrieve kABKLane and kKPerThread
- Add static_assert for RowMajor VLayout requirement in batch prefill

Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
This commit is contained in:
Jeff Huang
2026-01-29 07:18:41 +08:00
committed by GitHub
parent 83b58bb0c3
commit e3556fed04
4 changed files with 553 additions and 84 deletions

View File

@@ -26,17 +26,26 @@ namespace ck_tile {
*
* @tparam BottomTensorView_ Class describing & holding device tensor memory.
* @tparam WindowLengths_ Spatial sizes of windowed view on tensor.
* @tparam StaticTileDistribution_ Thread distribution (mapping) into Tile dimensions
* @tparam NumCoord TBD
* @tparam StaticTileDistribution_ Thread distribution (mapping) into Tile dimensions.
* @tparam StaticPageIndexArray_ Array type holding page indices for scatter/gather.
* @tparam StaticValidArray_ Array type holding validity flags (nullptr_t if unused).
* @tparam HsGatherDim H-space dimension index used for gather lookup (default: 0).
* @tparam NumCoord Number of pre-computed coordinates for pipelining (default: 1).
* @tparam YsGatherDims Sequence of Y-space dimension indices used for page lookup.
* For single dimension: sequence<0> (default).
* For multiple dimensions: sequence<dim0, dim1, ...> where
* the combined index is computed as:
* idx[dim0] + idx[dim1] * len[dim0] + idx[dim2] * len[dim0] *
* len[dim1] + ...
*/
template <typename BottomTensorView_,
typename WindowLengths_,
typename StaticTileDistribution_,
typename StaticPageIndexArray_,
typename StaticValidArray_,
index_t HsGatherDim = 0,
index_t NumCoord = 1,
index_t YsGatherDim = 0>
index_t HsGatherDim = 0,
index_t NumCoord = 1,
typename YsGatherDims = sequence<0>>
struct tile_scatter_gather
{
using BottomTensorView = remove_reference_t<BottomTensorView_>;
@@ -77,6 +86,75 @@ struct tile_scatter_gather
using BottomTensorCoord =
decltype(make_tensor_coordinate(BottomTensorDesc{}, BottomTensorIndex{}));
/**
 * @brief Test whether a Y-space dimension participates in the page-index gather.
 *
 * A dimension participates when it is listed in the YsGatherDims template
 * parameter. Callers use this when building forward_step_scatter: the step for
 * a participating dimension is forced to 0, because the page-offset lookup —
 * not coordinate arithmetic — supplies the address contribution along it.
 *
 * @param i Y-space dimension index to test
 * @return true when dimension i appears in YsGatherDims, false otherwise
 */
CK_TILE_DEVICE static constexpr bool is_gather_dim(index_t i)
{
    constexpr auto gather_dims = YsGatherDims{};
    return sequence_any_of(gather_dims, [i](auto dim) { return i == dim; });
}
/**
 * @brief Compute the linearized gather index from Y-space indices for page lookup.
 *
 * Converts the multi-dimensional Y-space indices named by YsGatherDims into a
 * single linear index used to look up the page offset in the page_idx_ array.
 *
 * For a single gather dimension (YsGatherDims::size() == 1):
 *   Simply returns idx_ys_start[YsGatherDims::at(0)].
 *
 * For multiple gather dimensions (e.g., YsGatherDims = sequence<0, 2>):
 *   Computes: idx[dim0] + idx[dim1] * len[dim0] + idx[dim2] * len[dim0] * len[dim1] + ...
 *   This is a mixed-radix linearization in which the FIRST listed dimension is
 *   the fastest-varying (column-major-style ordering over the listed dims).
 *   The ordering of YsGatherDims must therefore match the layout the page_idx_
 *   array was populated with.
 *
 * @tparam YsIndex Type of the Y-space index tuple/array
 * @param idx_ys_start Current Y-space indices from space-filling curve iteration
 * @return Linearized index for page_idx_ array lookup
 */
template <typename YsIndex>
CK_TILE_DEVICE static constexpr auto get_gather_index(const YsIndex& idx_ys_start)
{
    // TODO: Consider making ys_lengths_ part of public API or adding accessor.
    // NOTE(review): this reaches into DstrEncode::detail — the static_assert is a
    // tautology (array sizeof is never 0) that exists only to force the name to
    // be looked up, documenting the dependency at compile time.
    static_assert(sizeof(TileDstr::DstrEncode::detail::ys_lengths_) > 0,
                  "Relies on internal detail::ys_lengths_");
    constexpr index_t num_gather_dims = YsGatherDims::size();
    if constexpr(num_gather_dims == 1)
    {
        // Fast path: no radix arithmetic needed, the single Y index IS the page index.
        return idx_ys_start[number<YsGatherDims::at(0)>{}];
    }
    else
    {
        // Self-referential generic lambda (no std::function available on device)
        // that folds the listed dimensions into one index as a compile-time number.
        // Mixed-radix linearization, first listed dimension fastest-varying:
        // idx[0] + idx[1] * len[0] + idx[2] * len[0] * len[1] + ...
        auto recurse = [&](auto self, auto i_constant) {
            constexpr index_t i   = decltype(i_constant)::value;
            constexpr index_t dim = YsGatherDims::at(i);
            auto current_val      = idx_ys_start[number<dim>{}];
            if constexpr(i + 1 < num_gather_dims)
            {
                // Radix for this position is the Y-length of the CURRENT dimension;
                // it scales everything contributed by the remaining (slower) dims.
                constexpr index_t len = TileDstr::DstrEncode::detail::ys_lengths_[dim];
                return current_val + self(self, number<i + 1>{}) * number<len>{};
            }
            else
            {
                // Last listed dimension: contributes its raw index, unscaled here.
                return current_val;
            }
        };
        return recurse(recurse, number<0>{});
    }
}
struct load_store_traits
{
private:
@@ -375,7 +453,7 @@ struct tile_scatter_gather
// data index [y0, y1, ...]
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
constexpr auto idx_gather = idx_ys_start[number<YsGatherDim>{}];
constexpr auto idx_gather = get_gather_index(idx_ys_start);
const auto page_offset = page_idx_[idx_gather];
// read from bottom tensor
@@ -427,7 +505,7 @@ struct tile_scatter_gather
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
constexpr auto forward_step_scatter = generate_tuple(
[&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
[&](auto i) { return is_gather_dim(i) ? 0 : idx_diff_ys[i]; },
number<NDimY>{});
constexpr auto idx_diff_ps_ys = container_concat(
@@ -485,7 +563,7 @@ struct tile_scatter_gather
// data index [y0, y1, ...]
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
constexpr auto idx_gather = idx_ys_start[number<YsGatherDim>{}];
constexpr auto idx_gather = get_gather_index(idx_ys_start);
const auto page_offset = page_idx_[idx_gather];
// merge page_offset into bottom_coord
@@ -513,7 +591,7 @@ struct tile_scatter_gather
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
constexpr auto forward_step_scatter = generate_tuple(
[&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
[&](auto i) { return is_gather_dim(i) ? 0 : idx_diff_ys[i]; },
number<NDimY>{});
constexpr auto idx_diff_ps_ys = container_concat(
@@ -598,7 +676,7 @@ struct tile_scatter_gather
}();
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
constexpr auto idx_gather = idx_ys_start[number<YsGatherDim>{}];
constexpr auto idx_gather = get_gather_index(idx_ys_start);
const auto page_offset = page_idx_[idx_gather];
// read from bottom tensor
@@ -624,7 +702,7 @@ struct tile_scatter_gather
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
constexpr auto forward_step_scatter = generate_tuple(
[&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
[&](auto i) { return is_gather_dim(i) ? 0 : idx_diff_ys[i]; },
number<NDimY>{});
constexpr auto idx_diff_ps_ys = container_concat(
@@ -718,7 +796,7 @@ struct tile_scatter_gather
}();
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
constexpr auto idx_gather = idx_ys_start[number<YsGatherDim>{}];
constexpr auto idx_gather = get_gather_index(idx_ys_start);
const auto page_offset = page_idx_[idx_gather];
auto mixed_bottom_thread_coord = bottom_tensor_thread_coord;
@@ -748,7 +826,7 @@ struct tile_scatter_gather
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
constexpr auto forward_step_scatter = generate_tuple(
[&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
[&](auto i) { return is_gather_dim(i) ? 0 : idx_diff_ys[i]; },
number<NDimY>{});
constexpr auto idx_diff_ps_ys = container_concat(
@@ -791,7 +869,7 @@ struct tile_scatter_gather
// data index [y0, y1, ...]
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
constexpr auto idx_gather = idx_ys_start[number<0>{}];
constexpr auto idx_gather = get_gather_index(idx_ys_start);
const auto page_offset = page_idx_[idx_gather];
// read from distributed tensor
@@ -837,7 +915,7 @@ struct tile_scatter_gather
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
constexpr auto forward_step_scatter = generate_tuple(
[&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
[&](auto i) { return is_gather_dim(i) ? 0 : idx_diff_ys[i]; },
number<NDimY>{});
constexpr auto idx_diff_ps_ys = container_concat(
@@ -874,11 +952,11 @@ struct tile_scatter_gather
// data index [y0, y1, ...]
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
constexpr auto idx_gather = idx_ys_start[number<0>{}];
constexpr auto idx_gather = get_gather_index(idx_ys_start);
const auto page_offset = page_idx_[idx_gather];
// printf("idx_ys_start[0], idx_ys_start[1](%d, %d) \n",
// idx_ys_start[number<0>{}]+0, idx_ys_start[number<1>{}]+0);
// get_gather_index(idx_ys_start)+0, idx_ys_start[number<1>{}]+0);
// read from distributed tensor
// vector_type_t vec;
@@ -928,7 +1006,7 @@ struct tile_scatter_gather
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
constexpr auto forward_step_scatter = generate_tuple(
[&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
[&](auto i) { return is_gather_dim(i) ? 0 : idx_diff_ys[i]; },
number<NDimY>{});
constexpr auto idx_diff_ps_ys = container_concat(
@@ -1076,6 +1154,53 @@ struct tile_scatter_gather
};
// TODO: use strategy
/**
 * @brief Factory creating a tile_scatter_gather whose page lookup spans several
 *        Y-space dimensions.
 *
 * This overload takes a sequence<YsGatherDims...> naming every Y-space
 * dimension that contributes to the page index. Use it when the tile
 * distribution decomposes the paged dimension across multiple Y dimensions
 * (e.g. the VECTORIZED_LAYOUT V tensor with K decomposed as {K2, K0, K1},
 * where both Y0 and Y2 feed the page index).
 *
 * @tparam HsGatherDim H-space dimension for gather
 * @tparam NumCoord Number of pre-computed coordinates
 * @tparam YsGatherDims Parameter pack of Y-dimensions participating in page lookup
 *
 * @param tensor_view Underlying tensor view over device memory
 * @param window_lengths Static window extents for each dimension
 * @param origin Window origin coordinates on the bottom tensor
 * @param tile_distribution Thread-to-tile mapping distribution
 * @param page_idx Array of page offsets (in bytes) used by the scatter/gather
 */
template <typename TensorView_,
          typename WindowLengths_,
          typename StaticTileDistribution_,
          typename StaticPageIndexArray_,
          index_t HsGatherDim,
          index_t NumCoord,
          index_t... YsGatherDims>
CK_TILE_DEVICE constexpr auto
make_tile_scatter_gather(const TensorView_& tensor_view,
                         const WindowLengths_& window_lengths,
                         const multi_index<TensorView_::get_num_of_dimension()>& origin,
                         const StaticTileDistribution_& tile_distribution,
                         const StaticPageIndexArray_& page_idx,
                         number<HsGatherDim>,
                         number<NumCoord>,
                         sequence<YsGatherDims...>)
{
    // No validity array in this overload, hence the std::nullptr_t slot / nullptr arg.
    using ScatterGather = tile_scatter_gather<remove_cvref_t<TensorView_>,
                                              remove_cvref_t<WindowLengths_>,
                                              remove_cvref_t<StaticTileDistribution_>,
                                              remove_cvref_t<StaticPageIndexArray_>,
                                              std::nullptr_t,
                                              HsGatherDim,
                                              NumCoord,
                                              sequence<YsGatherDims...>>;
    return ScatterGather{tensor_view, window_lengths, origin, tile_distribution, page_idx, nullptr};
}
// Legacy overload (compatible with original API)
template <typename TensorView_,
typename WindowLengths_,
typename StaticTileDistribution_,
@@ -1087,7 +1212,7 @@ make_tile_scatter_gather(const TensorView_& tensor_view,
const WindowLengths_& window_lengths,
const multi_index<TensorView_::get_num_of_dimension()>& origin,
const StaticTileDistribution_& tile_distribution,
const StaticPageIndexArray_& page_idx, // perbytes
const StaticPageIndexArray_& page_idx,
number<HsGatherDim> = {},
number<NumCoord> = {})
{
@@ -1097,7 +1222,8 @@ make_tile_scatter_gather(const TensorView_& tensor_view,
remove_cvref_t<StaticPageIndexArray_>,
std::nullptr_t,
HsGatherDim,
NumCoord>{
NumCoord,
sequence<0>>{
tensor_view, window_lengths, origin, tile_distribution, page_idx, nullptr};
}