mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
[CK_TILE] Add PagedAttention kernels (#1387)
* Use dictionary to config all the functions * Add init codegen logic for fmha fwd appendkv * Call HIP_CHECK_ERROR() macro to get real source info * Setup meaningfull arguments * Sync kernel name with the codegen * Add knew/vnew tensors to the kernel argument * Fix wrong K values after appending * Fix vnew append errro * Extract common logics * Fix Vnew tile dstr for row major case * Conditionally add fwd_splitkv API in fmha_fwd example * Conditionally add call to fmha_fwd_splitkv() * Remove "EXAMPLE_" prefix of cmake variables * Regsiter API handlers automatically * Early return if 0 < s_k_new is not supported * Show message if we are ignoring option * Unify CMakeLists.txt coding style * Set num_splits=1 if split-kv is not supported * Add length/stride getters for HostTensor * Add RoPE example utilities * Add reference_rotary_position_embedding() (not implemented) * Finish reference_rotary_position_embedding() impl * Fix typo of HostTensor<>::get_length() * Fix compilation errors * Fix wrong answer when interleaved=false * Fix wrong answer when interleaved=true * Append K/V in the host verification code * Simplify K appending logics * Simplify v_host_ref definition * Reduce input/output dimensions * Rename function: add "batched" prefix * Apply RoPE on host side * Rename RoPE utility function * Fix wrong tensor size * Avoid invoking deprecated method 'find_module' * Pass RoPE kernel args * Create Rotary Cos/Sin tile windows in kernel * Add compute data type alias for RoPE * Randomly generate seqlen_knew if needed * Fix seqlen_knew enabling check logic * Add minimum seqlen_k to generate compliance kvcache * Fix compilation error in debug mode * Fix wrong boundaries * Fix wrong seqlen_k for kvcache * Rename variables used in distributio encoding * Fix rotary cos/sin tensor/tile size * Add constraint to the rotary_dim option * Remove unused inner namespace * Add dram distribution for rotary_cos/rotary_sin (interleaved) * Only apply interleaved RoPE on Knew for now * Fix wrong thread starting offset * Instantiate multiple kernels for RoPE approaches * Clean-up pipeline * Fix error in RoPE host reference * Handle RoPE half-rotated logics * Support 8x rotary_dim under half-rotated RoPE * Add comment * Apply elementwise function to the loaded tiles * Unify parameter/variable naming style * Remove constness from q_ptr * Add code blocks for q_tile * Apply RoPE to q_tile * Remove debug print code in kernel * Fix wrong knew/vnew appending positions * Use better naming for tile indices * Add make_tile_window() for adding distribution only * Skip code if # of block is more than needed * Move thread locating logics into policy * Remove always true static_assert() * Rename header * Rename RotaryEmbeddingEnum * Extract rotary embedding logic out * Re-order parameters * Align naming of some tile size constants * Rename more tile size constants * Fix wrong grid size * Fix wrong shape of knew_host/vnew_host * Fix wrong index into knew_host/vnew_host * Fix wrong rotary_cos/rotary_sin memory size for Q * Extract Q/Knew vector size to helper methods * Use different rotary_cos/rotary_sin distr for Q/Knew * Update host/device specifiers * Fix wrong data type for Q rotary_cos/rotary_sin * Remove RoPEComputeDataType type alias * Shift rotary_cos/rotary_sin by cache_seqlen_k * Add comment for why I just 't' for all padding flags * Align commit message to the real comment * Fix wrong pipeline * Rename utility function * Disable host verification if API not exist * Fix wrong rope key for fp8 pipeline * Allow only apply RoPE on Q (without append KV) * Add append-kv smoke tests * Remove debug statements * Remove more debug statements * Re-arrange the 'set +x' command * Remove no-longer used method in pipeline * Add missing init code * Refine pipeline padding settings * Enlarge rotary_dim limit (8 -> 16) * Enlarge KPerThread for rotary_interleaved=false * Update rotary_dim range in smoke_test_fwd.sh * Add template argument 'kIsPagedKV' for splitkv kernels * Launch splitkv kernel if given page_block_size * Fix wrong kernel name * Fix seqlen_k_min for pre-fill case (1 -> 0) * Add copy_const<> type trait * Add another make_tile_window() * Introduce 'TileWindowNavigator' types * Simplify TileWindowNavigator interfaces * Fix tile window navigation bugs * Disable calling fmha_fwd() * Remove ununnecessary data members * Simplify more make_tile_window() overloads * Move V tile through TileWindowNavigator * Fix uneven split checking logic * Move code after decide seqlen_q/seqlen_k * Make sure we always start reading complete tile * Use 128 as minimus page_block_size * Fix wrong origin for bias * Add batch_stride_k/batch_stride_v in group mode * Unify origin * Add missing kernel arguments for group mode * Add paged-kv codegen logic for appendkv kernels * Add block_table kernel args for appendkv kernel * Add tile navigators to the appendkv kernel * Fix wrong tensor descriptor lengths * Pass re-created tile window to pipeline * Fix wrong strides for appendkv kernel * Allow transit tile_window to another page-block * Handle cross-page-block write * Donot perform write again if already in last page-block * Always add fmha_fwd() api * Add missing group mode argument * Remove debug macro usages * Rename option s_k_new to s_knew * Separate splitkv/non-splitkv args/traits * Remove fmha_fwd_dispatch() * Fix compilation errors * Remove dropout code in splitkv kernel * Allow problem types without define kHasDropout attr * Use generic lambda to init traits objects * Separate more non-splitkv & splitkv traits/args * Display more info for specific kernels * Show more detailed warning message * Rename 'max_num_blocks' to 'max_num_page_blocks' * Remove no-longer used pipeline files * Wrap code by #if directives * Move functors to the begining of validation code * Use generic lambda to init all the api traits/args * Fix wrong seqlen for kvcache * Add missing comment * Rename TileWindowNavigator to PageBlockNavigator * Only expose necessary methods (not attributes) * Re-order pipeline paremeters * Refine smoke_test_fwd.sh * Fix wrong arugment count * Make tile window directly via PageBlockNavigator * Remove unused template paremeter * Remove group mode from appendkv kernel * Fix skcheck logic * Fix wrong syntax in skcheck expr * Use meaningful options in smoke test * Remove options * Fix formatting * Fix more format * Re-organize bash functions * Pass cache_batch_idx to kernels * Support cache_batch_idx in example * Fix compilation error * Add more appendkv test * Add more case for appendkv * Fix unexisted attribute * Remove 0 < seqlen_knew constraint * Clarify the case in warning message * Remove macro checking * Force batch mode when invoking appendkv & splitkv apis * Fix mode overriding logics * Fix wrong parameter name * Randomize seqlen_k if use kvcache * Use randomized seqlen_k for kvcache * Avoid using too small rotary_cos & rotary_sin * Rename parameter * Add seqlen_q & seqlen_k rules * Add comment * Add more comments * Fix compilation errors * Fix typo in comment * Remove type argument * Avoid seqlen_k=0 for kvcache * Revert "Avoid seqlen_k=0 for kvcache" This reverts commit21c4df89e4. * Fix wrong uneven split checking logics * Only randomize kvcache seqlen_k if 1 < batch * Return earlier if split is empty * Revert "Only randomize kvcache seqlen_k if 1 < batch" This reverts commitb9a4ab0d7e. * Re-order seqlen_k_start adjustment logics * Fix compilation errors * Re-format script * Find executable from folder automatically * Fix kvcache seqlen_k generating logic * Make comment more clear * Fix wrong knew/vew appending logic on host * Add s_barrier to sync threads * Revert "Add s_barrier to sync threads" This reverts commitd3f550f30c. * Support only using 1 row of rotary_cos/rotary_sin * Rotate Q in different way * Unify tensor view creation logics * Fix wrong argument * Add mask to switch how we use the rotary_cos/sin * Move attr from traits to problem * Move has_mask to fmha_fwd_appendkv_args * Support use uint32_t as SAD operand in Alibi<> * Use sad_u32() in splitkv kernels * Store tensor views in PageBlockNavigator * Use stored tensor view to update tile windows * Enlarge tensor view size * Remove debug code * Fix wrong tensor view size * Wrap tensor view into PageBlockNavigator * Add DataType member to PageBlockNavigator * Remove unnecessary member functions * Refind macro use * Fix typo * Add blank line between directives and actual code * Re-format files * Remove type in comment --------- Co-authored-by: carlushuang <carlus.huang@amd.com> Co-authored-by: rocking <ChunYu.Lai@amd.com>
This commit is contained in:
@@ -43,9 +43,12 @@ enum struct AlibiMode
|
||||
FROM_BOTTOM_RIGHT = 2,
|
||||
};
|
||||
|
||||
template <typename DataType, bool RowMajor = true>
|
||||
template <typename DataType, bool RowMajor = true, unsigned LogMaxSadOprndSize = 16>
|
||||
struct Alibi
|
||||
{
|
||||
static_assert(1 <= LogMaxSadOprndSize && LogMaxSadOprndSize <= 32,
|
||||
"for LogMaxSadOprndSize <= 16, we use SAD uint16_t, otherwise, use SAD uint32_t");
|
||||
|
||||
// RowMajor here means if pixel within the same thread are along the row, or col
|
||||
// this may impact the performance of update(), while the result are the same.
|
||||
// e.g. fwd prefer use RowMajor=true, bwd some cases prefer use RowMajor=false
|
||||
@@ -79,6 +82,19 @@ struct Alibi
|
||||
mode = mode_;
|
||||
}
|
||||
|
||||
CK_TILE_HOST uint32_t sad(uint32_t x, uint32_t y, uint32_t acc) { return sad_u32(x, y, acc); }
|
||||
|
||||
CK_TILE_DEVICE uint32_t sad(uint32_t x, uint32_t y, uint32_t acc)
|
||||
{
|
||||
if constexpr(LogMaxSadOprndSize <= 16)
|
||||
{
|
||||
return sad_u16(
|
||||
static_cast<uint16_t>(x), static_cast<uint16_t>(y), static_cast<uint16_t>(acc));
|
||||
}
|
||||
|
||||
return sad_u32(x, y, acc);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void update(DataType& pixel, index_t row_idx, index_t col_idx)
|
||||
{
|
||||
if constexpr(RowMajor)
|
||||
@@ -128,7 +144,7 @@ struct EmptyPositionEncoding
|
||||
// can convert from the FA style left/right to our generic coordinate
|
||||
// if left_size < 0 && right_size = 0, it is normal causal mask
|
||||
// local is left_size >=0 or right_size >=0
|
||||
template <typename DataType, bool RowMajor = true>
|
||||
template <typename DataType, bool RowMajor = true, unsigned LogMaxSadOprndSize = 16>
|
||||
CK_TILE_HOST_DEVICE auto make_alibi_from_lr_mask(DataType slope,
|
||||
index_t window_left_size,
|
||||
index_t window_right_size,
|
||||
@@ -142,7 +158,7 @@ CK_TILE_HOST_DEVICE auto make_alibi_from_lr_mask(DataType slope,
|
||||
AlibiMode alibi_mode =
|
||||
is_causal ? AlibiMode::VERTICAL
|
||||
: static_cast<AlibiMode>(mask_enum) /*either top-left or bottom-right*/;
|
||||
return Alibi<DataType, RowMajor>{slope, y_total, x_total, alibi_mode};
|
||||
return Alibi<DataType, RowMajor, LogMaxSadOprndSize>{slope, y_total, x_total, alibi_mode};
|
||||
}
|
||||
|
||||
// https://github.com/ofirpress/attention_with_linear_biases/blob/4b92f28a005ead2567abe2359f633e73e08f3833/fairseq/models/transformer.py#L742
|
||||
|
||||
108
include/ck_tile/ops/fmha/block/block_rotary_embedding.hpp
Normal file
108
include/ck_tile/ops/fmha/block/block_rotary_embedding.hpp
Normal file
@@ -0,0 +1,108 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// This class is used for codegen pattern matching
|
||||
enum class RotaryEmbeddingEnum
|
||||
{
|
||||
NONE = 0,
|
||||
INTERLEAVED = 1, // combine dimensions 0 & 1, 2 & 3, etc
|
||||
HALF_ROTATED = 2, // combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1, etc
|
||||
};
|
||||
|
||||
template <RotaryEmbeddingEnum>
|
||||
struct RotaryEmbeddingEnumToStr;
|
||||
|
||||
template <>
|
||||
struct RotaryEmbeddingEnumToStr<RotaryEmbeddingEnum::NONE>
|
||||
{
|
||||
static constexpr const char* name = "";
|
||||
};
|
||||
template <>
|
||||
struct RotaryEmbeddingEnumToStr<RotaryEmbeddingEnum::INTERLEAVED>
|
||||
{
|
||||
static constexpr const char* name = "inter";
|
||||
};
|
||||
template <>
|
||||
struct RotaryEmbeddingEnumToStr<RotaryEmbeddingEnum::HALF_ROTATED>
|
||||
{
|
||||
static constexpr const char* name = "half";
|
||||
};
|
||||
|
||||
template <RotaryEmbeddingEnum RotaryEnum, typename ComputeDataType = float>
|
||||
struct BlockRotaryEmbedding
|
||||
{
|
||||
template <typename DistributedTensor,
|
||||
typename OtherDramBlockWindow,
|
||||
typename RotaryCosDramBlockWindow,
|
||||
typename RotarySinDramBlockWindow>
|
||||
CK_TILE_HOST_DEVICE static void apply(DistributedTensor& tile,
|
||||
OtherDramBlockWindow other_window,
|
||||
RotaryCosDramBlockWindow rotary_cos_window,
|
||||
RotarySinDramBlockWindow rotary_sin_window,
|
||||
index_t rotary_dim,
|
||||
index_t thread_end)
|
||||
{
|
||||
using DataType = typename remove_cvref_t<DistributedTensor>::DataType;
|
||||
|
||||
if constexpr(RotaryEnum == RotaryEmbeddingEnum::INTERLEAVED)
|
||||
{
|
||||
auto rotary_cos_tile = load_tile(rotary_cos_window);
|
||||
auto rotary_sin_tile = load_tile(rotary_sin_window);
|
||||
|
||||
if(thread_end <= rotary_dim)
|
||||
{
|
||||
constexpr index_t thread_buffer_size = decltype(tile.thread_buf_)::size();
|
||||
static_for<0, thread_buffer_size, 2>{}([&](auto idx) {
|
||||
const auto left = type_convert<ComputeDataType>(tile.thread_buf_[idx]);
|
||||
const auto right = type_convert<ComputeDataType>(tile.thread_buf_[idx + 1]);
|
||||
|
||||
const auto cos =
|
||||
type_convert<ComputeDataType>(rotary_cos_tile.thread_buf_[idx / 2]);
|
||||
const auto sin =
|
||||
type_convert<ComputeDataType>(rotary_sin_tile.thread_buf_[idx / 2]);
|
||||
|
||||
tile.thread_buf_[idx] = type_convert<DataType>(left * cos - right * sin);
|
||||
tile.thread_buf_[idx + 1] = type_convert<DataType>(right * cos + left * sin);
|
||||
});
|
||||
}
|
||||
}
|
||||
else if constexpr(RotaryEnum == RotaryEmbeddingEnum::HALF_ROTATED)
|
||||
{
|
||||
if(thread_end <= rotary_dim)
|
||||
{
|
||||
const bool is_left = (thread_end <= (rotary_dim / 2));
|
||||
|
||||
move_tile_window(other_window, {0, is_left ? rotary_dim / 2 : -(rotary_dim / 2)});
|
||||
auto other_tile = load_tile(other_window);
|
||||
|
||||
move_tile_window(rotary_cos_window, {0, is_left ? 0 : -(rotary_dim / 2)});
|
||||
auto rotary_cos_tile = load_tile(rotary_cos_window);
|
||||
|
||||
move_tile_window(rotary_sin_window, {0, is_left ? 0 : -(rotary_dim / 2)});
|
||||
auto rotary_sin_tile = load_tile(rotary_sin_window);
|
||||
|
||||
constexpr index_t thread_buffer_size = decltype(tile.thread_buf_)::size();
|
||||
static_for<0, thread_buffer_size, 1>{}([&](auto idx) {
|
||||
const auto curr = type_convert<ComputeDataType>(tile.thread_buf_[idx]);
|
||||
const auto other = type_convert<ComputeDataType>(other_tile.thread_buf_[idx]);
|
||||
|
||||
const auto cos =
|
||||
type_convert<ComputeDataType>(rotary_cos_tile.thread_buf_[idx]);
|
||||
const auto sin =
|
||||
type_convert<ComputeDataType>(rotary_sin_tile.thread_buf_[idx]);
|
||||
|
||||
tile.thread_buf_[idx] =
|
||||
type_convert<DataType>(curr * cos + other * (is_left ? -sin : sin));
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
279
include/ck_tile/ops/fmha/block/page_block_navigator.hpp
Normal file
279
include/ck_tile/ops/fmha/block/page_block_navigator.hpp
Normal file
@@ -0,0 +1,279 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/core/tensor/tile_window.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// assume that we have only 1 page-block/tensor view
|
||||
template <typename TensorView>
|
||||
struct TrivialPageBlockNavigator
|
||||
{
|
||||
using DataType = typename TensorView::DataType;
|
||||
using WindowOrigin = multi_index<2>;
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr TrivialPageBlockNavigator(const TensorView& tensor_view_)
|
||||
: tensor_view(tensor_view_)
|
||||
{
|
||||
}
|
||||
|
||||
template <typename WindowLengths>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_tile_window(const WindowLengths& window_lengths,
|
||||
const WindowOrigin& window_origin) const
|
||||
{
|
||||
return make_tuple(/*block_index=*/0,
|
||||
ck_tile::make_tile_window(tensor_view, window_lengths, window_origin));
|
||||
}
|
||||
|
||||
template <typename WindowLengths, typename TileDistribution>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
make_tile_window(const WindowLengths& window_lengths,
|
||||
const WindowOrigin& window_origin,
|
||||
const TileDistribution& tile_distribution) const
|
||||
{
|
||||
return make_tuple(
|
||||
/*block_index=*/0,
|
||||
ck_tile::make_tile_window(
|
||||
tensor_view, window_lengths, window_origin, tile_distribution));
|
||||
}
|
||||
|
||||
template <typename TileWindow>
|
||||
CK_TILE_HOST_DEVICE static index_t
|
||||
move_tile_window(index_t /*block_index*/,
|
||||
TileWindow& tile_window,
|
||||
const typename remove_cvref_t<TileWindow>::BottomTensorIndex& step)
|
||||
{
|
||||
ck_tile::move_tile_window(tile_window, step);
|
||||
|
||||
return /*block_index=*/0;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr WindowOrigin
|
||||
to_local_window_origin(const WindowOrigin& global_window_origin)
|
||||
{
|
||||
return global_window_origin;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr WindowOrigin
|
||||
to_global_window_origin(index_t /*block_index*/, const WindowOrigin& local_window_origin)
|
||||
{
|
||||
return local_window_origin;
|
||||
}
|
||||
|
||||
private:
|
||||
TensorView tensor_view;
|
||||
};
|
||||
|
||||
// default page-block navigator, assume that tensor view size is same as page-block size or smaller
|
||||
// if tile window on last page-block
|
||||
template <typename DataType_, index_t VirtualDim, typename TensorView>
|
||||
struct PageBlockNavigator
|
||||
{
|
||||
using DataType = DataType_;
|
||||
static_assert(std::is_same_v<DataType, typename TensorView::DataType>);
|
||||
static_assert(VirtualDim == 0 || VirtualDim == 1, "only support 2d tile window");
|
||||
using WindowOrigin = multi_index<2>;
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr PageBlockNavigator(copy_const_t<DataType, void>* physical_blocks_,
|
||||
long_index_t block_stride_,
|
||||
long_index_t fixed_offset_,
|
||||
const int32_t* physical_block_indices_,
|
||||
index_t num_blocks_,
|
||||
index_t page_block_size_,
|
||||
const TensorView& complete_view_,
|
||||
const TensorView& last_view_)
|
||||
: physical_blocks(reinterpret_cast<DataType*>(physical_blocks_)),
|
||||
block_stride(block_stride_),
|
||||
fixed_offset(fixed_offset_),
|
||||
physical_block_indices(physical_block_indices_),
|
||||
num_blocks(num_blocks_),
|
||||
page_block_size(page_block_size_),
|
||||
complete_view(complete_view_),
|
||||
last_view(last_view_)
|
||||
{
|
||||
}
|
||||
|
||||
template <typename WindowLengths>
|
||||
CK_TILE_HOST_DEVICE auto make_tile_window(const WindowLengths& window_lengths,
|
||||
const WindowOrigin& window_origin) const
|
||||
{
|
||||
const index_t block_index = get_block_index(window_origin);
|
||||
const WindowOrigin local_window_origin = to_local_window_origin(window_origin);
|
||||
|
||||
auto new_tile_window =
|
||||
ck_tile::make_tile_window(is_last_block(block_index) ? last_view : complete_view,
|
||||
window_lengths,
|
||||
local_window_origin);
|
||||
new_tile_window.set_bottom_tensor_view_data_ptr(get_block_ptr(block_index));
|
||||
|
||||
return make_tuple(block_index, new_tile_window);
|
||||
}
|
||||
|
||||
template <typename WindowLengths, typename TileDistribution>
|
||||
CK_TILE_HOST_DEVICE auto make_tile_window(const WindowLengths& window_lengths,
|
||||
const WindowOrigin& window_origin,
|
||||
const TileDistribution& tile_distribution) const
|
||||
{
|
||||
const index_t block_index = get_block_index(window_origin);
|
||||
const WindowOrigin local_window_origin = to_local_window_origin(window_origin);
|
||||
|
||||
auto new_tile_window =
|
||||
ck_tile::make_tile_window(is_last_block(block_index) ? last_view : complete_view,
|
||||
window_lengths,
|
||||
local_window_origin,
|
||||
tile_distribution);
|
||||
new_tile_window.set_bottom_tensor_view_data_ptr(get_block_ptr(block_index));
|
||||
|
||||
return make_tuple(block_index, new_tile_window);
|
||||
}
|
||||
|
||||
template <typename TileWindow>
|
||||
CK_TILE_HOST_DEVICE index_t
|
||||
move_tile_window(index_t block_index,
|
||||
TileWindow& tile_window,
|
||||
const typename remove_cvref_t<TileWindow>::BottomTensorIndex& step) const
|
||||
{
|
||||
|
||||
ck_tile::move_tile_window(tile_window, step);
|
||||
|
||||
const WindowOrigin global_window_origin =
|
||||
to_global_window_origin(block_index, tile_window.get_window_origin());
|
||||
const WindowOrigin local_window_origin = to_local_window_origin(global_window_origin);
|
||||
|
||||
const index_t new_block_index = get_block_index(global_window_origin);
|
||||
/// TODO: only update necessary attributes
|
||||
tile_window.bottom_tensor_view_.desc_ =
|
||||
(is_last_block(new_block_index) ? last_view : complete_view).get_tensor_descriptor();
|
||||
tile_window.set_window_origin(local_window_origin);
|
||||
tile_window.set_bottom_tensor_view_data_ptr(get_block_ptr(new_block_index));
|
||||
|
||||
return new_block_index;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE bool is_last_block(index_t block_index) const
|
||||
{
|
||||
return block_index == num_blocks - 1;
|
||||
}
|
||||
|
||||
template <typename TileWindow>
|
||||
CK_TILE_HOST_DEVICE bool is_cross_block(index_t block_index,
|
||||
const TileWindow& tile_window) const
|
||||
{
|
||||
const index_t origin = tile_window.get_window_origin().at(number<VirtualDim>{});
|
||||
const index_t length = tile_window.get_window_lengths().at(number<VirtualDim>{});
|
||||
return (block_index < num_blocks - 1) && (page_block_size < origin + length);
|
||||
}
|
||||
|
||||
template <typename TileWindow>
|
||||
CK_TILE_HOST_DEVICE void
|
||||
move_to_block(index_t block_index, TileWindow& tile_window, index_t new_block_index) const
|
||||
{
|
||||
const multi_index<2> step = [&]() {
|
||||
const index_t origin_diff = (block_index - new_block_index) * page_block_size;
|
||||
if constexpr(VirtualDim == 0)
|
||||
{
|
||||
return make_multi_index(origin_diff, 0);
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_multi_index(0, origin_diff);
|
||||
}
|
||||
}();
|
||||
|
||||
/// TODO: only update necessary attributes
|
||||
tile_window.bottom_tensor_view_.desc_ =
|
||||
(is_last_block(new_block_index) ? last_view : complete_view).get_tensor_descriptor();
|
||||
tile_window.set_window_origin(tile_window.get_window_origin() + step);
|
||||
tile_window.set_bottom_tensor_view_data_ptr(get_block_ptr(new_block_index));
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE WindowOrigin
|
||||
to_local_window_origin(const WindowOrigin& global_window_origin) const
|
||||
{
|
||||
if constexpr(VirtualDim == 0)
|
||||
{
|
||||
const index_t length = global_window_origin.at(number<0>{});
|
||||
const index_t num_complete_blocks = integer_divide_floor(length, page_block_size);
|
||||
return make_multi_index(length - page_block_size * num_complete_blocks,
|
||||
global_window_origin.at(number<1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
const index_t length = global_window_origin.at(number<1>{});
|
||||
const index_t num_complete_blocks = integer_divide_floor(length, page_block_size);
|
||||
return make_multi_index(global_window_origin.at(number<0>{}),
|
||||
length - page_block_size * num_complete_blocks);
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE WindowOrigin
|
||||
to_global_window_origin(index_t block_index, const WindowOrigin& local_window_origin) const
|
||||
{
|
||||
if constexpr(VirtualDim == 0)
|
||||
{
|
||||
return make_multi_index(block_index * page_block_size +
|
||||
local_window_origin.at(number<0>{}),
|
||||
local_window_origin.at(number<1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_multi_index(local_window_origin.at(number<0>{}),
|
||||
block_index * page_block_size +
|
||||
local_window_origin.at(number<1>{}));
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
CK_TILE_HOST_DEVICE
|
||||
DataType* get_block_ptr(index_t block_index) const
|
||||
{
|
||||
return physical_blocks + physical_block_indices[block_index] * block_stride + fixed_offset;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE int32_t get_block_index(const WindowOrigin& global_window_origin) const
|
||||
{
|
||||
return integer_divide_floor(global_window_origin.at(number<VirtualDim>{}), page_block_size);
|
||||
}
|
||||
|
||||
DataType* physical_blocks;
|
||||
long_index_t block_stride;
|
||||
long_index_t fixed_offset;
|
||||
|
||||
const int32_t* physical_block_indices;
|
||||
index_t num_blocks;
|
||||
index_t page_block_size;
|
||||
|
||||
TensorView complete_view;
|
||||
TensorView last_view;
|
||||
};
|
||||
|
||||
template <typename TensorView>
|
||||
CK_TILE_HOST_DEVICE auto make_page_block_navigator(const TensorView& tensor_view)
|
||||
{
|
||||
return TrivialPageBlockNavigator<TensorView>(tensor_view);
|
||||
}
|
||||
|
||||
template <typename DataType, index_t VirtualDim, typename TensorView>
|
||||
CK_TILE_HOST_DEVICE auto make_page_block_navigator(copy_const_t<DataType, void>* physical_blocks,
|
||||
long_index_t block_stride,
|
||||
long_index_t fixed_offset,
|
||||
const int32_t* physical_block_indices,
|
||||
index_t num_blocks,
|
||||
index_t page_block_size,
|
||||
const TensorView& complete_view,
|
||||
const TensorView& last_view)
|
||||
{
|
||||
return PageBlockNavigator<DataType, VirtualDim, TensorView>(physical_blocks,
|
||||
block_stride,
|
||||
fixed_offset,
|
||||
physical_block_indices,
|
||||
num_blocks,
|
||||
page_block_size,
|
||||
complete_view,
|
||||
last_view);
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
Reference in New Issue
Block a user