mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Optimize batch prefill kernel performance for VECTORIZED_LAYOUT KV cache (#3657)
- Add multi-dimensional page index support (YsGatherDims) in tile_scatter_gather
- Add is_gather_dim() and get_gather_index() for multi-dim page lookup
- Override MakeVDramTileDistribution() for VECTORIZED_LAYOUT to match
GEMM's BWarpDstrEncoding (K decomposition: {K2, K0, K1})
- Add GetGemmKDecomposition() to retrieve kABKLane and kKPerThread
- Add static_assert for RowMajor VLayout requirement in batch prefill
Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
This commit is contained in:
@@ -26,17 +26,26 @@ namespace ck_tile {
|
||||
*
|
||||
* @tparam BottomTensorView_ Class describing & holding device tensor memory.
|
||||
* @tparam WindowLengths_ Spatial sizes of windowed view on tensor.
|
||||
* @tparam StaticTileDistribution_ Thread distribution (mapping) into Tile dimensions
|
||||
* @tparam NumCoord TBD
|
||||
* @tparam StaticTileDistribution_ Thread distribution (mapping) into Tile dimensions.
|
||||
* @tparam StaticPageIndexArray_ Array type holding page indices for scatter/gather.
|
||||
* @tparam StaticValidArray_ Array type holding validity flags (nullptr_t if unused).
|
||||
* @tparam HsGatherDim H-space dimension index used for gather lookup (default: 0).
|
||||
* @tparam NumCoord Number of pre-computed coordinates for pipelining (default: 1).
|
||||
* @tparam YsGatherDims Sequence of Y-space dimension indices used for page lookup.
|
||||
* For single dimension: sequence<0> (default).
|
||||
* For multiple dimensions: sequence<dim0, dim1, ...> where
|
||||
* the combined index is computed as:
|
||||
* idx[dim0] + idx[dim1] * len[dim0] + idx[dim2] * len[dim0] *
|
||||
* len[dim1] + ...
|
||||
*/
|
||||
template <typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename StaticTileDistribution_,
|
||||
typename StaticPageIndexArray_,
|
||||
typename StaticValidArray_,
|
||||
index_t HsGatherDim = 0,
|
||||
index_t NumCoord = 1,
|
||||
index_t YsGatherDim = 0>
|
||||
index_t HsGatherDim = 0,
|
||||
index_t NumCoord = 1,
|
||||
typename YsGatherDims = sequence<0>>
|
||||
struct tile_scatter_gather
|
||||
{
|
||||
using BottomTensorView = remove_reference_t<BottomTensorView_>;
|
||||
@@ -77,6 +86,75 @@ struct tile_scatter_gather
|
||||
using BottomTensorCoord =
|
||||
decltype(make_tensor_coordinate(BottomTensorDesc{}, BottomTensorIndex{}));
|
||||
|
||||
/**
|
||||
* @brief Check if a given Y-space dimension index is a gather dimension.
|
||||
*
|
||||
* Gather dimensions are those specified in YsGatherDims template parameter.
|
||||
* When computing forward_step_scatter, gather dimensions are set to 0
|
||||
* because page offset lookup handles address calculation for these dimensions.
|
||||
*
|
||||
* @param i Y-space dimension index to check
|
||||
* @return true if dimension i is in YsGatherDims, false otherwise
|
||||
*/
|
||||
CK_TILE_DEVICE static constexpr bool is_gather_dim(index_t i)
|
||||
{
|
||||
return sequence_any_of(YsGatherDims{}, [i](auto k) { return i == k; });
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Compute the linearized gather index from Y-space indices for page lookup.
|
||||
*
|
||||
* This function converts multi-dimensional Y-space indices (specified by YsGatherDims)
|
||||
* into a single linearized index used to look up the page offset in page_idx_ array.
|
||||
*
|
||||
* For single gather dimension (YsGatherDims::size() == 1):
|
||||
* Simply returns idx_ys_start[YsGatherDims::at(0)]
|
||||
*
|
||||
* For multiple gather dimensions (e.g., YsGatherDims = sequence<0, 2>):
|
||||
* Computes: idx[dim0] + idx[dim1] * len[dim0] + idx[dim2] * len[dim0] * len[dim1] + ...
|
||||
* This is row-major linearization where earlier dimensions are inner (faster-varying).
|
||||
*
|
||||
* @tparam YsIndex Type of the Y-space index tuple/array
|
||||
* @param idx_ys_start Current Y-space indices from space-filling curve iteration
|
||||
* @return Linearized index for page_idx_ array lookup
|
||||
*/
|
||||
template <typename YsIndex>
|
||||
CK_TILE_DEVICE static constexpr auto get_gather_index(const YsIndex& idx_ys_start)
|
||||
{
|
||||
// TODO: Consider making ys_lengths_ part of public API or adding accessor
|
||||
static_assert(sizeof(TileDstr::DstrEncode::detail::ys_lengths_) > 0,
|
||||
"Relies on internal detail::ys_lengths_");
|
||||
|
||||
constexpr index_t num_gather_dims = YsGatherDims::size();
|
||||
|
||||
if constexpr(num_gather_dims == 1)
|
||||
{
|
||||
return idx_ys_start[number<YsGatherDims::at(0)>{}];
|
||||
}
|
||||
else
|
||||
{
|
||||
// Recursive lambda to compute index as a compile-time number
|
||||
// Uses row-major linearization: idx[0] + idx[1] * len[0] + idx[2] * len[0] * len[1] +
|
||||
// ...
|
||||
auto recurse = [&](auto self, auto i_constant) {
|
||||
constexpr index_t i = decltype(i_constant)::value;
|
||||
constexpr index_t dim = YsGatherDims::at(i);
|
||||
auto current_val = idx_ys_start[number<dim>{}];
|
||||
|
||||
if constexpr(i + 1 < num_gather_dims)
|
||||
{
|
||||
constexpr index_t len = TileDstr::DstrEncode::detail::ys_lengths_[dim];
|
||||
return current_val + self(self, number<i + 1>{}) * number<len>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return current_val;
|
||||
}
|
||||
};
|
||||
return recurse(recurse, number<0>{});
|
||||
}
|
||||
}
|
||||
|
||||
struct load_store_traits
|
||||
{
|
||||
private:
|
||||
@@ -375,7 +453,7 @@ struct tile_scatter_gather
|
||||
|
||||
// data index [y0, y1, ...]
|
||||
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
|
||||
constexpr auto idx_gather = idx_ys_start[number<YsGatherDim>{}];
|
||||
constexpr auto idx_gather = get_gather_index(idx_ys_start);
|
||||
const auto page_offset = page_idx_[idx_gather];
|
||||
|
||||
// read from bottom tensor
|
||||
@@ -427,7 +505,7 @@ struct tile_scatter_gather
|
||||
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
|
||||
|
||||
constexpr auto forward_step_scatter = generate_tuple(
|
||||
[&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
|
||||
[&](auto i) { return is_gather_dim(i) ? 0 : idx_diff_ys[i]; },
|
||||
number<NDimY>{});
|
||||
|
||||
constexpr auto idx_diff_ps_ys = container_concat(
|
||||
@@ -485,7 +563,7 @@ struct tile_scatter_gather
|
||||
|
||||
// data index [y0, y1, ...]
|
||||
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
|
||||
constexpr auto idx_gather = idx_ys_start[number<YsGatherDim>{}];
|
||||
constexpr auto idx_gather = get_gather_index(idx_ys_start);
|
||||
const auto page_offset = page_idx_[idx_gather];
|
||||
|
||||
// merge page_offset into bottom_coord
|
||||
@@ -513,7 +591,7 @@ struct tile_scatter_gather
|
||||
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
|
||||
|
||||
constexpr auto forward_step_scatter = generate_tuple(
|
||||
[&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
|
||||
[&](auto i) { return is_gather_dim(i) ? 0 : idx_diff_ys[i]; },
|
||||
number<NDimY>{});
|
||||
|
||||
constexpr auto idx_diff_ps_ys = container_concat(
|
||||
@@ -598,7 +676,7 @@ struct tile_scatter_gather
|
||||
}();
|
||||
|
||||
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
|
||||
constexpr auto idx_gather = idx_ys_start[number<YsGatherDim>{}];
|
||||
constexpr auto idx_gather = get_gather_index(idx_ys_start);
|
||||
const auto page_offset = page_idx_[idx_gather];
|
||||
|
||||
// read from bottom tensor
|
||||
@@ -624,7 +702,7 @@ struct tile_scatter_gather
|
||||
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
|
||||
|
||||
constexpr auto forward_step_scatter = generate_tuple(
|
||||
[&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
|
||||
[&](auto i) { return is_gather_dim(i) ? 0 : idx_diff_ys[i]; },
|
||||
number<NDimY>{});
|
||||
|
||||
constexpr auto idx_diff_ps_ys = container_concat(
|
||||
@@ -718,7 +796,7 @@ struct tile_scatter_gather
|
||||
}();
|
||||
|
||||
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
|
||||
constexpr auto idx_gather = idx_ys_start[number<YsGatherDim>{}];
|
||||
constexpr auto idx_gather = get_gather_index(idx_ys_start);
|
||||
const auto page_offset = page_idx_[idx_gather];
|
||||
|
||||
auto mixed_bottom_thread_coord = bottom_tensor_thread_coord;
|
||||
@@ -748,7 +826,7 @@ struct tile_scatter_gather
|
||||
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
|
||||
|
||||
constexpr auto forward_step_scatter = generate_tuple(
|
||||
[&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
|
||||
[&](auto i) { return is_gather_dim(i) ? 0 : idx_diff_ys[i]; },
|
||||
number<NDimY>{});
|
||||
|
||||
constexpr auto idx_diff_ps_ys = container_concat(
|
||||
@@ -791,7 +869,7 @@ struct tile_scatter_gather
|
||||
|
||||
// data index [y0, y1, ...]
|
||||
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
|
||||
constexpr auto idx_gather = idx_ys_start[number<0>{}];
|
||||
constexpr auto idx_gather = get_gather_index(idx_ys_start);
|
||||
const auto page_offset = page_idx_[idx_gather];
|
||||
|
||||
// read from distributed tensor
|
||||
@@ -837,7 +915,7 @@ struct tile_scatter_gather
|
||||
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
|
||||
|
||||
constexpr auto forward_step_scatter = generate_tuple(
|
||||
[&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
|
||||
[&](auto i) { return is_gather_dim(i) ? 0 : idx_diff_ys[i]; },
|
||||
number<NDimY>{});
|
||||
|
||||
constexpr auto idx_diff_ps_ys = container_concat(
|
||||
@@ -874,11 +952,11 @@ struct tile_scatter_gather
|
||||
|
||||
// data index [y0, y1, ...]
|
||||
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
|
||||
constexpr auto idx_gather = idx_ys_start[number<0>{}];
|
||||
constexpr auto idx_gather = get_gather_index(idx_ys_start);
|
||||
const auto page_offset = page_idx_[idx_gather];
|
||||
|
||||
// printf("idx_ys_start[0], idx_ys_start[1](%d, %d) \n",
|
||||
// idx_ys_start[number<0>{}]+0, idx_ys_start[number<1>{}]+0);
|
||||
// get_gather_index(idx_ys_start)+0, idx_ys_start[number<1>{}]+0);
|
||||
|
||||
// read from distributed tensor
|
||||
// vector_type_t vec;
|
||||
@@ -928,7 +1006,7 @@ struct tile_scatter_gather
|
||||
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
|
||||
|
||||
constexpr auto forward_step_scatter = generate_tuple(
|
||||
[&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
|
||||
[&](auto i) { return is_gather_dim(i) ? 0 : idx_diff_ys[i]; },
|
||||
number<NDimY>{});
|
||||
|
||||
constexpr auto idx_diff_ps_ys = container_concat(
|
||||
@@ -1076,6 +1154,53 @@ struct tile_scatter_gather
|
||||
};
|
||||
|
||||
// TODO: use strategy
|
||||
/**
|
||||
* @brief Factory function to create tile_scatter_gather with multi-dimensional gather support.
|
||||
*
|
||||
* This overload accepts a sequence<YsGatherDims...> to specify multiple Y-space dimensions
|
||||
* for page lookup. Use this when the tile distribution decomposes the paged dimension
|
||||
* into multiple Y-space dimensions (e.g., VECTORIZED_LAYOUT V tensor with K decomposition
|
||||
* {K2, K0, K1} where both Y0 and Y2 contribute to page index).
|
||||
*
|
||||
* @tparam HsGatherDim H-space dimension for gather
|
||||
* @tparam NumCoord Number of pre-computed coordinates
|
||||
* @tparam YsGatherDims Parameter pack specifying which Y-dimensions are used for page lookup
|
||||
*
|
||||
* @param tensor_view The underlying tensor view for device memory access
|
||||
* @param window_lengths Static window sizes for each dimension
|
||||
* @param origin Window origin coordinates on the bottom tensor
|
||||
* @param tile_distribution Thread-to-tile mapping distribution
|
||||
* @param page_idx Array of page offsets (in bytes) for scatter/gather
|
||||
*/
|
||||
template <typename TensorView_,
|
||||
typename WindowLengths_,
|
||||
typename StaticTileDistribution_,
|
||||
typename StaticPageIndexArray_,
|
||||
index_t HsGatherDim,
|
||||
index_t NumCoord,
|
||||
index_t... YsGatherDims>
|
||||
CK_TILE_DEVICE constexpr auto
|
||||
make_tile_scatter_gather(const TensorView_& tensor_view,
|
||||
const WindowLengths_& window_lengths,
|
||||
const multi_index<TensorView_::get_num_of_dimension()>& origin,
|
||||
const StaticTileDistribution_& tile_distribution,
|
||||
const StaticPageIndexArray_& page_idx,
|
||||
number<HsGatherDim>,
|
||||
number<NumCoord>,
|
||||
sequence<YsGatherDims...>)
|
||||
{
|
||||
return tile_scatter_gather<remove_cvref_t<TensorView_>,
|
||||
remove_cvref_t<WindowLengths_>,
|
||||
remove_cvref_t<StaticTileDistribution_>,
|
||||
remove_cvref_t<StaticPageIndexArray_>,
|
||||
std::nullptr_t,
|
||||
HsGatherDim,
|
||||
NumCoord,
|
||||
sequence<YsGatherDims...>>{
|
||||
tensor_view, window_lengths, origin, tile_distribution, page_idx, nullptr};
|
||||
}
|
||||
|
||||
// Legacy overload (compatible with original API)
|
||||
template <typename TensorView_,
|
||||
typename WindowLengths_,
|
||||
typename StaticTileDistribution_,
|
||||
@@ -1087,7 +1212,7 @@ make_tile_scatter_gather(const TensorView_& tensor_view,
|
||||
const WindowLengths_& window_lengths,
|
||||
const multi_index<TensorView_::get_num_of_dimension()>& origin,
|
||||
const StaticTileDistribution_& tile_distribution,
|
||||
const StaticPageIndexArray_& page_idx, // perbytes
|
||||
const StaticPageIndexArray_& page_idx,
|
||||
number<HsGatherDim> = {},
|
||||
number<NumCoord> = {})
|
||||
{
|
||||
@@ -1097,7 +1222,8 @@ make_tile_scatter_gather(const TensorView_& tensor_view,
|
||||
remove_cvref_t<StaticPageIndexArray_>,
|
||||
std::nullptr_t,
|
||||
HsGatherDim,
|
||||
NumCoord>{
|
||||
NumCoord,
|
||||
sequence<0>>{
|
||||
tensor_view, window_lengths, origin, tile_distribution, page_idx, nullptr};
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user