mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-07 08:15:04 +00:00
Fix tile window navigation bugs
This commit is contained in:
@@ -262,11 +262,12 @@ float fmha_fwd_dispatch(fmha_fwd_traits traits,
|
||||
{
|
||||
return fmha_fwd_splitkv(traits, args, config);
|
||||
}
|
||||
else
|
||||
#endif
|
||||
#if 0
|
||||
{
|
||||
return fmha_fwd(traits, args, config);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename DataType>
|
||||
@@ -444,8 +445,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
return false;
|
||||
}
|
||||
|
||||
#define ENABLE_PAGED_KVCACHE 0
|
||||
|
||||
int stream_warmup = arg_parser.get_int("warmup");
|
||||
int stream_repeat = arg_parser.get_int("repeat");
|
||||
bool kname = arg_parser.get_bool("kname");
|
||||
@@ -507,8 +506,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
}
|
||||
|
||||
const ck_tile::index_t max_num_blocks =
|
||||
(ENABLE_PAGED_KVCACHE && 0 < page_block_size
|
||||
? batch * std::min(1, ck_tile::integer_divide_ceil(max_seqlen_k, page_block_size))
|
||||
(0 < page_block_size
|
||||
? batch * std::max(1, ck_tile::integer_divide_ceil(max_seqlen_k, page_block_size))
|
||||
: 0);
|
||||
|
||||
// legalize num_splits according to other options
|
||||
@@ -545,21 +544,20 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
: (seqlen_kpads[0] < 0 ? seqstart_k_host.back()
|
||||
: seqstart_k_with_padding_host.back()));
|
||||
|
||||
std::cerr << "[POYENC] num_blocks: " << max_num_blocks << std::endl;
|
||||
std::cerr << "[HOST] num_blocks: " << max_num_blocks << std::endl;
|
||||
|
||||
ck_tile::HostTensor<QDataType> q_host(
|
||||
get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q));
|
||||
ck_tile::HostTensor<KDataType> k_host(
|
||||
ENABLE_PAGED_KVCACHE && 0 < page_block_size
|
||||
? get_lengths(i_perm, max_num_blocks, nhead_k, page_block_size, hdim_q)
|
||||
: get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q));
|
||||
0 < page_block_size ? get_lengths(i_perm, max_num_blocks, nhead_k, page_block_size, hdim_q)
|
||||
: get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q));
|
||||
/// NOTICE: always use same shape for knew_host & vnew_host in batch/group mode
|
||||
ck_tile::HostTensor<KDataType> knew_host(
|
||||
0 < seqlen_knew
|
||||
? get_lengths(i_perm, batch, nhead_k, seqlen_knew, hdim_q)
|
||||
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */);
|
||||
ck_tile::HostTensor<VDataType> v_host(
|
||||
ENABLE_PAGED_KVCACHE && 0 < page_block_size
|
||||
USE_PAGED_VCACHE && 0 < page_block_size
|
||||
? (is_v_rowmajor
|
||||
? get_lengths(i_perm, max_num_blocks, nhead_k, page_block_size, hdim_v)
|
||||
: get_lengths(i_perm, max_num_blocks, nhead_k, hdim_v, page_block_size))
|
||||
@@ -606,9 +604,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
|
||||
|
||||
ck_tile::HostTensor<int32_t> block_table_host(
|
||||
ENABLE_PAGED_KVCACHE && 0 < page_block_size
|
||||
? std::array<ck_tile::index_t, 2>{batch, max_num_blocks / batch}
|
||||
: std::array<ck_tile::index_t, 2>{1, 1});
|
||||
0 < page_block_size ? std::array<ck_tile::index_t, 2>{batch, max_num_blocks / batch}
|
||||
: std::array<ck_tile::index_t, 2>{1, 1});
|
||||
|
||||
if(init_method == "ui" || init_method == "0")
|
||||
{
|
||||
@@ -884,7 +881,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
if(is_v_rowmajor)
|
||||
return i_perm ? hdim_v : nhead_k * hdim_v;
|
||||
else
|
||||
return ENABLE_PAGED_KVCACHE && 0 < page_block_size
|
||||
return USE_PAGED_VCACHE && 0 < page_block_size
|
||||
? (i_perm ? page_block_size : nhead_k * page_block_size)
|
||||
: (i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k);
|
||||
}();
|
||||
@@ -894,16 +891,16 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v);
|
||||
// setup nhead_stride_* arguments
|
||||
const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q);
|
||||
const ck_tile::index_t nhead_stride_k = (ENABLE_PAGED_KVCACHE && 0 < page_block_size
|
||||
? (i_perm ? page_block_size * hdim_q : hdim_q)
|
||||
: (i_perm ? shape_seqlen_k * hdim_q : hdim_q));
|
||||
const ck_tile::index_t nhead_stride_k =
|
||||
(0 < page_block_size ? (i_perm ? page_block_size * hdim_q : hdim_q)
|
||||
: (i_perm ? shape_seqlen_k * hdim_q : hdim_q));
|
||||
const ck_tile::index_t nhead_stride_v = [&]() {
|
||||
if(is_v_rowmajor)
|
||||
return ENABLE_PAGED_KVCACHE && 0 < page_block_size
|
||||
return USE_PAGED_VCACHE && 0 < page_block_size
|
||||
? (i_perm ? page_block_size * hdim_v : hdim_v)
|
||||
: (i_perm ? shape_seqlen_k * hdim_v : hdim_v);
|
||||
else
|
||||
return ENABLE_PAGED_KVCACHE && 0 < page_block_size
|
||||
return USE_PAGED_VCACHE && 0 < page_block_size
|
||||
? (i_perm ? hdim_v * page_block_size : page_block_size)
|
||||
: (i_perm ? hdim_v * shape_seqlen_k : shape_seqlen_k);
|
||||
}();
|
||||
@@ -917,11 +914,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
// setup batch_stride_* arguments
|
||||
const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q);
|
||||
const ck_tile::index_t batch_stride_k =
|
||||
(ENABLE_PAGED_KVCACHE && 0 < page_block_size ? (nhead_k * page_block_size * hdim_q)
|
||||
: (nhead_k * shape_seqlen_k * hdim_q));
|
||||
(0 < page_block_size ? (nhead_k * page_block_size * hdim_q)
|
||||
: (nhead_k * shape_seqlen_k * hdim_q));
|
||||
const ck_tile::index_t batch_stride_v =
|
||||
(ENABLE_PAGED_KVCACHE && 0 < page_block_size ? (nhead_k * hdim_v * page_block_size)
|
||||
: (nhead_k * hdim_v * shape_seqlen_k));
|
||||
(USE_PAGED_VCACHE && 0 < page_block_size ? (nhead_k * hdim_v * page_block_size)
|
||||
: (nhead_k * hdim_v * shape_seqlen_k));
|
||||
const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * shape_seqlen_k);
|
||||
const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k);
|
||||
const ck_tile::index_t batch_stride_lse = (nhead * max_seqlen_q);
|
||||
@@ -1076,7 +1073,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
}
|
||||
#endif
|
||||
|
||||
if (ENABLE_PAGED_KVCACHE && 0 < page_block_size) {
|
||||
if (0 < page_block_size) {
|
||||
if(i_perm) {
|
||||
k_host_ref.ForEach([&](auto& self, auto i) {
|
||||
self(i) = k_host(block_table_host(wb, i[1] / page_block_size), i[0] / nr, i[1] % page_block_size, i[2]);
|
||||
@@ -1128,7 +1125,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
});
|
||||
}
|
||||
#endif
|
||||
if (ENABLE_PAGED_KVCACHE && 0 < page_block_size) {
|
||||
if (USE_PAGED_VCACHE && 0 < page_block_size) {
|
||||
if (is_v_rowmajor) {
|
||||
if(i_perm) {
|
||||
v_host_ref.ForEach([&](auto& self, auto i) {
|
||||
|
||||
@@ -214,6 +214,12 @@ struct tile_window_with_static_distribution
|
||||
|
||||
CK_TILE_DEVICE constexpr auto get_window_origin() const { return window_origin_; }
|
||||
|
||||
/// Repoint this window's bottom tensor view at a different data buffer.
/// Only the buffer's p_data_ pointer is overwritten; nothing else in the
/// window is modified by this call.
/// @param data new base pointer stored into bottom_tensor_view_.buf_.p_data_
CK_TILE_DEVICE constexpr void
|
||||
set_bottom_tensor_view_data_ptr(typename BottomTensorView::DataType* data)
|
||||
{
|
||||
bottom_tensor_view_.buf_.p_data_ = data;
|
||||
}
|
||||
|
||||
// move thread's window adaptor coordinate and bottom tensor coordinate
|
||||
// [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] ==> [x0', x1', ...] ==> [offset]
|
||||
CK_TILE_DEVICE void move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
@@ -843,6 +849,17 @@ struct tile_window_with_static_lengths
|
||||
|
||||
CK_TILE_DEVICE constexpr auto get_window_origin() const { return window_origin_; }
|
||||
|
||||
/// Replace the stored window origin with new_window_origin.
/// This is an absolute assignment (window_origin_ is overwritten, not offset).
/// @param new_window_origin value copied into window_origin_
CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex& new_window_origin)
|
||||
{
|
||||
window_origin_ = new_window_origin;
|
||||
}
|
||||
|
||||
/// Repoint this window's bottom tensor view at a different data buffer.
/// Only the buffer's p_data_ pointer is overwritten; window_origin_ is left
/// as-is by this call.
/// @param data new base pointer stored into bottom_tensor_view_.buf_.p_data_
CK_TILE_DEVICE constexpr void
|
||||
set_bottom_tensor_view_data_ptr(typename BottomTensorView::DataType* data)
|
||||
{
|
||||
bottom_tensor_view_.buf_.p_data_ = data;
|
||||
}
|
||||
|
||||
// move window-origin
|
||||
CK_TILE_DEVICE void move(const BottomTensorIndex& step) { window_origin_ += step; }
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ template <typename DataType_, typename TensorViewLengths_, typename TensorViewSt
|
||||
struct SimpleTileWindowNavigator
|
||||
{
|
||||
using DataType = DataType_;
|
||||
using WindowOrigin = multi_index<2>;
|
||||
using TensorViewLengths = TensorViewLengths_;
|
||||
using TensorViewStrides = TensorViewStrides_;
|
||||
|
||||
@@ -24,17 +25,40 @@ struct SimpleTileWindowNavigator
|
||||
template <typename TensorView, typename WindowLengths>
|
||||
CK_TILE_DEVICE static constexpr auto
|
||||
make_tile_window(const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
|
||||
const multi_index<TensorView::get_num_of_dimension()>& window_origin)
|
||||
const WindowOrigin& window_origin)
|
||||
{
|
||||
return ck_tile::make_tile_window(tile_window, window_origin);
|
||||
return make_tuple(
|
||||
/*block_index=*/0, ck_tile::make_tile_window(tile_window, window_origin));
|
||||
}
|
||||
|
||||
template <typename TileWindow>
|
||||
CK_TILE_DEVICE void
|
||||
move_tile_window(TileWindow& tile_window,
|
||||
CK_TILE_DEVICE static index_t
|
||||
move_tile_window(index_t /*block_index*/,
|
||||
TileWindow& tile_window,
|
||||
const typename remove_cvref_t<TileWindow>::BottomTensorIndex& step)
|
||||
{
|
||||
ck_tile::move_tile_window(tile_window, step);
|
||||
|
||||
return /*block_index=*/0;
|
||||
}
|
||||
|
||||
/// Block index for a global window origin. This navigator always reports
/// block 0 and ignores the origin argument.
/// TODO: remove this method after finishing debugging
CK_TILE_DEVICE static constexpr int32_t
|
||||
get_block_index(const WindowOrigin& /*global_window_origin*/)
|
||||
{
|
||||
return /*block_index=*/0;
|
||||
}
|
||||
|
||||
/// Map a global window origin to a local one. For this navigator the mapping
/// is the identity: the argument is returned unchanged.
/// @param global_window_origin origin to convert
/// @return a copy of global_window_origin
CK_TILE_DEVICE static constexpr WindowOrigin
|
||||
to_local_window_origin(const WindowOrigin& global_window_origin)
|
||||
{
|
||||
return global_window_origin;
|
||||
}
|
||||
|
||||
/// Map a local window origin back to a global one. For this navigator the
/// mapping is the identity: the block index is ignored and the local origin
/// is returned unchanged.
/// @param local_window_origin origin to convert
/// @return a copy of local_window_origin
CK_TILE_DEVICE static constexpr WindowOrigin
|
||||
to_global_window_origin(index_t /*block_index*/, const WindowOrigin& local_window_origin)
|
||||
{
|
||||
return local_window_origin;
|
||||
}
|
||||
|
||||
TensorViewLengths lengths;
|
||||
@@ -57,24 +81,24 @@ struct PagedTileWindowNavigator
|
||||
{
|
||||
using DataType = DataType_;
|
||||
static constexpr index_t VirtualDim = VirtualDim_;
|
||||
static_assert(VirtualDim == 0 || VirtualDim == 1);
|
||||
static_assert(VirtualDim == 0 || VirtualDim == 1, "only support 2d tile window");
|
||||
using WindowOrigin = multi_index<2>;
|
||||
using TensorViewLengths = TensorViewLengths_;
|
||||
using TensorViewStrides = TensorViewStrides_;
|
||||
|
||||
CK_TILE_DEVICE constexpr PagedTileWindowNavigator(copy_const_t<DataType, void>* blocks_,
|
||||
long_index_t block_stride_,
|
||||
long_index_t head_stride_,
|
||||
long_index_t row_stride_,
|
||||
const int32_t* block_indices_,
|
||||
index_t num_blocks_,
|
||||
index_t page_block_size_,
|
||||
const TensorViewLengths& lengths_,
|
||||
const TensorViewStrides& strides_)
|
||||
: blocks(reinterpret_cast<DataType*>(blocks_)),
|
||||
CK_TILE_DEVICE constexpr PagedTileWindowNavigator(
|
||||
copy_const_t<DataType, void>* physical_blocks_,
|
||||
long_index_t block_stride_,
|
||||
long_index_t fixed_offset_,
|
||||
const int32_t* physical_block_indices_,
|
||||
index_t num_blocks_,
|
||||
index_t page_block_size_,
|
||||
const TensorViewLengths& lengths_,
|
||||
const TensorViewStrides& strides_)
|
||||
: physical_blocks(reinterpret_cast<DataType*>(physical_blocks_)),
|
||||
block_stride(block_stride_),
|
||||
head_stride(head_stride_),
|
||||
row_stride(row_stride_),
|
||||
block_indices(block_indices_),
|
||||
fixed_offset(fixed_offset_),
|
||||
physical_block_indices(physical_block_indices_),
|
||||
num_blocks(num_blocks_),
|
||||
page_block_size(page_block_size_),
|
||||
lengths(lengths_),
|
||||
@@ -85,36 +109,89 @@ struct PagedTileWindowNavigator
|
||||
template <typename TensorView, typename WindowLengths>
|
||||
CK_TILE_DEVICE auto
|
||||
make_tile_window(const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
|
||||
const multi_index<TensorView::get_num_of_dimension()>& window_origin) const
|
||||
const WindowOrigin& window_origin) const
|
||||
{
|
||||
/// TODO: convert global window origin to local window origin
|
||||
auto local_window_origin = window_origin;
|
||||
const index_t block_index = get_block_index(window_origin);
|
||||
const WindowOrigin local_window_origin = to_local_window_origin(window_origin);
|
||||
|
||||
return ck_tile::make_tile_window(tile_window, local_window_origin);
|
||||
auto new_tile_window = ck_tile::make_tile_window(tile_window, local_window_origin);
|
||||
new_tile_window.set_bottom_tensor_view_data_ptr(get_block_ptr(block_index));
|
||||
|
||||
return make_tuple(block_index, new_tile_window);
|
||||
}
|
||||
|
||||
template <typename TileWindow>
|
||||
CK_TILE_DEVICE void
|
||||
move_tile_window(TileWindow& tile_window,
|
||||
CK_TILE_DEVICE index_t
|
||||
move_tile_window(index_t block_index,
|
||||
TileWindow& tile_window,
|
||||
const typename remove_cvref_t<TileWindow>::BottomTensorIndex& step) const
|
||||
{
|
||||
/// TODO: reset pointer and adjust local window origin
|
||||
|
||||
ck_tile::move_tile_window(tile_window, step);
|
||||
|
||||
const WindowOrigin global_window_origin =
|
||||
to_global_window_origin(block_index, tile_window.get_window_origin());
|
||||
const WindowOrigin local_window_origin = to_local_window_origin(global_window_origin);
|
||||
|
||||
const index_t new_block_index = get_block_index(global_window_origin);
|
||||
tile_window.set_window_origin(local_window_origin);
|
||||
tile_window.set_bottom_tensor_view_data_ptr(get_block_ptr(new_block_index));
|
||||
|
||||
return new_block_index;
|
||||
}
|
||||
|
||||
DataType* get_block_base(index_t block_index)
|
||||
CK_TILE_DEVICE
|
||||
DataType* get_block_ptr(index_t block_index) const
|
||||
{
|
||||
return blocks + block_index * block_stride + head_stride;
|
||||
return physical_blocks + physical_block_indices[block_index] * block_stride + fixed_offset;
|
||||
}
|
||||
|
||||
DataType* base(index_t i_virtual) { return get_block_base(); }
|
||||
/// Index of the page block that contains a global window origin.
/// Takes the origin's coordinate along dimension VirtualDim and divides it
/// by page_block_size via integer_divide_floor.
/// @param global_window_origin window origin in global (un-paged) coordinates
/// @return quotient of the VirtualDim coordinate by page_block_size
CK_TILE_DEVICE int32_t get_block_index(const WindowOrigin& global_window_origin) const
{
    const index_t paged_coord = global_window_origin.at(number<VirtualDim>{});

    return integer_divide_floor(paged_coord, page_block_size);
}
|
||||
|
||||
DataType* blocks;
|
||||
/// Convert a global window origin into an origin relative to its page block.
/// The coordinate along dimension VirtualDim has the start of its page block
/// (page_block_size * integer_divide_floor(coord, page_block_size)) subtracted;
/// the other coordinate is passed through untouched.
/// @param global_window_origin window origin in global coordinates
/// @return origin whose VirtualDim coordinate is the offset within the page block
CK_TILE_DEVICE WindowOrigin
to_local_window_origin(const WindowOrigin& global_window_origin) const
{
    const index_t global_coord = global_window_origin.at(number<VirtualDim>{});
    const index_t full_blocks  = integer_divide_floor(global_coord, page_block_size);
    const index_t local_coord  = global_coord - page_block_size * full_blocks;

    if constexpr(VirtualDim == 0)
    {
        return make_multi_index(local_coord, global_window_origin.at(number<1>{}));
    }
    else
    {
        return make_multi_index(global_window_origin.at(number<0>{}), local_coord);
    }
}
|
||||
|
||||
/// Convert a page-block-relative window origin back to global coordinates.
/// Adds block_index * page_block_size to the coordinate along dimension
/// VirtualDim; the other coordinate is passed through untouched.
/// @param block_index index of the page block the local origin refers to
/// @param local_window_origin origin relative to that page block
/// @return origin with the VirtualDim coordinate shifted by the block start
CK_TILE_DEVICE WindowOrigin
to_global_window_origin(index_t block_index, const WindowOrigin& local_window_origin) const
{
    const index_t block_start = block_index * page_block_size;

    const index_t local_dim0 = local_window_origin.at(number<0>{});
    const index_t local_dim1 = local_window_origin.at(number<1>{});

    if constexpr(VirtualDim == 0)
    {
        return make_multi_index(block_start + local_dim0, local_dim1);
    }
    else
    {
        return make_multi_index(local_dim0, block_start + local_dim1);
    }
}
|
||||
|
||||
DataType* physical_blocks;
|
||||
long_index_t block_stride;
|
||||
long_index_t head_stride;
|
||||
long_index_t row_stride;
|
||||
long_index_t fixed_offset;
|
||||
|
||||
const int32_t* block_indices;
|
||||
const int32_t* physical_block_indices;
|
||||
index_t num_blocks;
|
||||
index_t page_block_size;
|
||||
|
||||
@@ -126,22 +203,21 @@ template <typename DataType,
|
||||
index_t VirtualDim,
|
||||
typename TensorViewLengths,
|
||||
typename TensorViewStrides>
|
||||
CK_TILE_DEVICE constexpr auto make_tile_window_navigator(copy_const_t<DataType, void>* blocks,
|
||||
long_index_t block_stride,
|
||||
long_index_t head_stride,
|
||||
long_index_t row_stride,
|
||||
const int32_t* block_indices,
|
||||
index_t num_blocks,
|
||||
index_t page_block_size,
|
||||
const TensorViewLengths& lengths,
|
||||
const TensorViewStrides& strides)
|
||||
CK_TILE_DEVICE constexpr auto
|
||||
make_tile_window_navigator(copy_const_t<DataType, void>* physical_blocks,
|
||||
long_index_t block_stride,
|
||||
long_index_t fixed_offset,
|
||||
const int32_t* physical_block_indices,
|
||||
index_t num_blocks,
|
||||
index_t page_block_size,
|
||||
const TensorViewLengths& lengths,
|
||||
const TensorViewStrides& strides)
|
||||
{
|
||||
return PagedTileWindowNavigator<DataType, VirtualDim, TensorViewLengths, TensorViewStrides>(
|
||||
blocks,
|
||||
physical_blocks,
|
||||
block_stride,
|
||||
head_stride,
|
||||
row_stride,
|
||||
block_indices,
|
||||
fixed_offset,
|
||||
physical_block_indices,
|
||||
num_blocks,
|
||||
page_block_size,
|
||||
lengths,
|
||||
|
||||
@@ -533,7 +533,7 @@ struct FmhaFwdSplitKVKernel
|
||||
const long_index_t batch_offset_o_acc =
|
||||
static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc;
|
||||
|
||||
auto k_tile_navigator = [&, i_batch_ = i_batch]() {
|
||||
auto k_tile_navigator = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
|
||||
if constexpr(kIsPagedKV)
|
||||
{
|
||||
const auto* block_indices =
|
||||
@@ -545,17 +545,17 @@ struct FmhaFwdSplitKVKernel
|
||||
return make_tile_window_navigator<const KDataType, 0>(
|
||||
kargs.k_ptr,
|
||||
kargs.batch_stride_k,
|
||||
kargs.nhead_stride_k,
|
||||
kargs.stride_k,
|
||||
static_cast<long_index_t>(i_nhead_ / kargs.nhead_ratio_qk) *
|
||||
kargs.nhead_stride_k,
|
||||
block_indices,
|
||||
num_blocks,
|
||||
kargs.page_block_size,
|
||||
make_tuple(kargs.seqlen_k, kargs.hdim_q),
|
||||
make_tuple(kargs.page_block_size, kargs.hdim_q),
|
||||
make_tuple(kargs.stride_k, 1));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tile_window_navigator<KDataType>(
|
||||
return make_tile_window_navigator<const KDataType>(
|
||||
make_tuple(kargs.seqlen_k, kargs.hdim_q), make_tuple(kargs.stride_k, 1));
|
||||
}
|
||||
}();
|
||||
@@ -611,14 +611,8 @@ struct FmhaFwdSplitKVKernel
|
||||
batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
|
||||
if constexpr(kIsPagedKV)
|
||||
{
|
||||
const auto* block_indices =
|
||||
reinterpret_cast<const int32_t*>(kargs.block_table_ptr) +
|
||||
i_batch * kargs.batch_stride_block_table;
|
||||
const index_t num_blocks =
|
||||
integer_divide_ceil(kargs.seqlen_k, kargs.page_block_size);
|
||||
|
||||
// batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
|
||||
// batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
|
||||
batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
|
||||
batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -641,22 +635,10 @@ struct FmhaFwdSplitKVKernel
|
||||
const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) +
|
||||
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q +
|
||||
batch_offset_q;
|
||||
const KDataType* k_ptr = [&, i_nhead_ = i_nhead]() {
|
||||
if constexpr(kIsPagedKV)
|
||||
{
|
||||
return reinterpret_cast<const KDataType*>(kargs.k_ptr) +
|
||||
static_cast<long_index_t>(i_nhead_ / kargs.nhead_ratio_qk) *
|
||||
kargs.nhead_stride_k +
|
||||
batch_offset_k;
|
||||
}
|
||||
else
|
||||
{
|
||||
return reinterpret_cast<const KDataType*>(kargs.k_ptr) +
|
||||
static_cast<long_index_t>(i_nhead_ / kargs.nhead_ratio_qk) *
|
||||
kargs.nhead_stride_k +
|
||||
batch_offset_k;
|
||||
}
|
||||
}();
|
||||
const KDataType* k_ptr =
|
||||
reinterpret_cast<const KDataType*>(kargs.k_ptr) +
|
||||
static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k +
|
||||
batch_offset_k;
|
||||
const VDataType* v_ptr =
|
||||
reinterpret_cast<const VDataType*>(kargs.v_ptr) +
|
||||
static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v +
|
||||
|
||||
@@ -240,7 +240,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
}
|
||||
}
|
||||
|
||||
auto k_dram_block_window =
|
||||
auto [i_block0, k_dram_block_window] =
|
||||
k_tile_navigator.make_tile_window(k_dram_block_window_tmp, {seqlen_k_start, 0});
|
||||
|
||||
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
|
||||
@@ -278,7 +278,9 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
|
||||
auto k_block_tile = load_tile(k_dram_window);
|
||||
{
|
||||
k_tile_navigator.move_tile_window(k_dram_window, {0, kK0});
|
||||
// moving k_dram_window is an in-page-block operation, so there is
|
||||
// no need to invoke k_tile_navigator.move_tile_window() here.
|
||||
move_tile_window(k_dram_window, {0, kK0});
|
||||
clear_tile(s_acc); // initialize C
|
||||
store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile));
|
||||
k_block_tile = load_tile(k_dram_window);
|
||||
@@ -306,7 +308,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
sequence<kM0, (i_k0 + 1) * kK0>{}),
|
||||
k_lds_window);
|
||||
block_sync_lds();
|
||||
k_tile_navigator.move_tile_window(k_dram_window, {0, kK0});
|
||||
move_tile_window(k_dram_window, {0, kK0});
|
||||
|
||||
store_tile(
|
||||
k_lds_window,
|
||||
@@ -354,7 +356,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
}
|
||||
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
const auto k_origin = k_dram_block_window.get_window_origin();
|
||||
const auto k_origin = k_tile_navigator.to_global_window_origin(
|
||||
i_block0, k_dram_block_window.get_window_origin());
|
||||
constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
|
||||
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
|
||||
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
|
||||
@@ -383,7 +386,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
/// TODO: only check in last iteration without increasing code size
|
||||
if constexpr(kHasUnevenSplits)
|
||||
{
|
||||
const auto k_origin = k_dram_block_window.get_window_origin();
|
||||
const auto k_origin = k_tile_navigator.to_global_window_origin(
|
||||
i_block0, k_dram_block_window.get_window_origin());
|
||||
set_tile_if(s_acc,
|
||||
-numeric<SMPLComputeDataType>::infinity(),
|
||||
[&, seqlen_k_end_ = seqlen_k_end](auto tile_idx) {
|
||||
@@ -395,7 +399,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
|
||||
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
|
||||
{
|
||||
const auto k_origin = k_dram_block_window.get_window_origin();
|
||||
const auto k_origin = k_tile_navigator.to_global_window_origin(
|
||||
i_block0, k_dram_block_window.get_window_origin());
|
||||
bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
|
||||
k_origin.at(number<0>{}),
|
||||
number<kM0>{},
|
||||
@@ -555,7 +560,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
});
|
||||
}
|
||||
// move K tile windows
|
||||
k_tile_navigator.move_tile_window(k_dram_block_window, {kN0, 0});
|
||||
i_block0 = k_tile_navigator.move_tile_window(i_block0, k_dram_block_window, {kN0, 0});
|
||||
// tail
|
||||
{
|
||||
block_sync_lds();
|
||||
|
||||
Reference in New Issue
Block a user