mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 03:07:02 +00:00
fix multi batch and hack page idx core
This commit is contained in:
@@ -448,14 +448,14 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
|
||||
else:
|
||||
# pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
|
||||
# pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', bias, lse, dropout, squant, mask))
|
||||
# pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
|
||||
# pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 'f', 'f', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 'f', 'f', bias, lse, dropout, squant, mask))
|
||||
# pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
|
||||
# pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
|
||||
# pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
|
||||
# pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
|
||||
|
||||
elif dtype in ['fp8', 'bf8']:
|
||||
# no need lse/dropout kernels
|
||||
|
||||
@@ -401,7 +401,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
cache_seqlen_ks.begin(),
|
||||
[&](auto seqlen_k) { return seqlen_k - seqlen_knew; });
|
||||
|
||||
#if 0
|
||||
#if 1
|
||||
// clang-format off
|
||||
std::cout << "seqlen_qs:"; for(auto xx : seqlen_qs) { std::cout << xx << ","; } std::cout << std::endl;
|
||||
std::cout << "seqlen_ks:"; for(auto xx : seqlen_ks) { std::cout << xx << ","; } std::cout << std::endl;
|
||||
@@ -481,12 +481,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
const auto seqstart_k_host = to_seqstarts(seqlen_ks);
|
||||
const auto seqstart_k_with_padding_host = to_seqstarts(seqlen_kpads);
|
||||
// std::vector<int32_t> page_idx_host(seqstart_k_host.back(), 0);
|
||||
ck_tile::HostTensor<int32_t> page_idx_host({seqstart_k_host.back()});
|
||||
// std::iota(page_idx_host.begin(), page_idx_host.end(), 0);
|
||||
iota_shuffle(page_idx_host.mData.begin(), page_idx_host.mData.end(), 0);
|
||||
// for (int i = 0; i < page_idx_host.get_element_space_size(); i++) {
|
||||
// page_idx_host(i) = (i + 19) % page_idx_host.size();
|
||||
// }
|
||||
ck_tile::HostTensor<int32_t> page_idx_host({seqstart_k_host.back() + 65536});
|
||||
std::iota(page_idx_host.begin(), page_idx_host.end(), 0);
|
||||
iota_shuffle(page_idx_host.mData.begin(), page_idx_host.mData.end() - 65536, 0);
|
||||
page_idx_host.savetxt("page_idx_host.txt", "int");
|
||||
|
||||
using TypeConfig = FmhaFwdTypeConfig<DataTypeConfig>;
|
||||
@@ -1023,6 +1020,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
(mode == mode_enum::group ? seqstart_k.GetDeviceBuffer() : nullptr);
|
||||
args.page_idx_ptr =
|
||||
(mode == mode_enum::group ? page_idx.GetDeviceBuffer() : nullptr);
|
||||
args.page_num =
|
||||
(mode == mode_enum::group ? shape_seqlen_k : 0);
|
||||
args.seqlen_k_ptr = ((mode == mode_enum::batch && use_kvcache) || 0 <= k_paddings_[0]
|
||||
? seqlen_k_buf.GetDeviceBuffer()
|
||||
: nullptr);
|
||||
|
||||
@@ -131,6 +131,7 @@ struct fmha_fwd_args
|
||||
seqlen_k_ptr; // only used if both 'seqstart_q_ptr' & 'seqstart_k_ptr' are not nullptr
|
||||
const void* page_idx_ptr;
|
||||
|
||||
ck_tile::index_t page_num;
|
||||
ck_tile::index_t seqlen_q;
|
||||
ck_tile::index_t seqlen_k;
|
||||
ck_tile::index_t batch;
|
||||
@@ -328,6 +329,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.seqstart_k_ptr,
|
||||
args.seqlen_k_ptr,
|
||||
args.page_idx_ptr,
|
||||
args.page_num,
|
||||
args.hdim_q,
|
||||
args.hdim_v,
|
||||
args.nhead_q,
|
||||
|
||||
@@ -338,7 +338,6 @@ struct tile_scatter_gather
|
||||
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
|
||||
constexpr auto idx_gather = idx_ys_start[number<YsGatherDim>{}];
|
||||
const auto page_offset = page_idx_[idx_gather];
|
||||
|
||||
// read from bottom tensor
|
||||
const vector_t vec_value =
|
||||
get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
|
||||
|
||||
@@ -266,6 +266,7 @@ struct FmhaFwdKernel
|
||||
const int32_t* seqstart_k_ptr;
|
||||
const int32_t* seqlen_k_ptr;
|
||||
const int32_t* page_idx;
|
||||
int32_t page_num;
|
||||
};
|
||||
|
||||
using Kargs = std::conditional_t<kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs>;
|
||||
@@ -598,6 +599,7 @@ struct FmhaFwdKernel
|
||||
const void* seqstart_k_ptr,
|
||||
const void* seqlen_k_ptr,
|
||||
const void* page_idx_ptr,
|
||||
ck_tile::index_t page_num,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head_q,
|
||||
@@ -657,7 +659,8 @@ struct FmhaFwdKernel
|
||||
reinterpret_cast<const int32_t*>(seqstart_q_ptr),
|
||||
reinterpret_cast<const int32_t*>(seqstart_k_ptr),
|
||||
reinterpret_cast<const int32_t*>(seqlen_k_ptr),
|
||||
reinterpret_cast<const int32_t*>(page_idx_ptr)};
|
||||
reinterpret_cast<const int32_t*>(page_idx_ptr),
|
||||
page_num};
|
||||
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
@@ -724,6 +727,7 @@ struct FmhaFwdKernel
|
||||
const void* seqstart_k_ptr,
|
||||
const void* seqlen_k_ptr,
|
||||
const void* page_idx_ptr,
|
||||
ck_tile::index_t page_num,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head_q,
|
||||
@@ -763,6 +767,7 @@ struct FmhaFwdKernel
|
||||
seqstart_k_ptr,
|
||||
seqlen_k_ptr,
|
||||
page_idx_ptr,
|
||||
page_num,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
num_head_q,
|
||||
@@ -805,6 +810,7 @@ struct FmhaFwdKernel
|
||||
const void* seqstart_k_ptr,
|
||||
const void* seqlen_k_ptr,
|
||||
const void* page_idx_ptr,
|
||||
ck_tile::index_t page_num,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head_q,
|
||||
@@ -844,6 +850,7 @@ struct FmhaFwdKernel
|
||||
seqstart_k_ptr,
|
||||
seqlen_k_ptr,
|
||||
page_idx_ptr,
|
||||
page_num,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
num_head_q,
|
||||
@@ -966,7 +973,7 @@ struct FmhaFwdKernel
|
||||
|
||||
long_index_t batch_offset_q = 0;
|
||||
// long_index_t batch_offset_k = 0;
|
||||
long_index_t batch_offset_v = 0;
|
||||
// long_index_t batch_offset_v = 0;
|
||||
long_index_t batch_offset_bias = 0;
|
||||
long_index_t batch_offset_randval = 0;
|
||||
long_index_t batch_offset_lse = 0;
|
||||
@@ -980,16 +987,19 @@ struct FmhaFwdKernel
|
||||
|
||||
batch_offset_q = query_start * kargs.stride_q;
|
||||
// batch_offset_k = key_start * kargs.stride_k;
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
batch_offset_v = key_start * kargs.stride_v;
|
||||
}
|
||||
else
|
||||
{
|
||||
batch_offset_v = key_start;
|
||||
}
|
||||
// if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
// {
|
||||
// batch_offset_v = key_start * kargs.stride_v;
|
||||
// }
|
||||
// else
|
||||
// {
|
||||
// batch_offset_v = key_start;
|
||||
// }
|
||||
|
||||
kargs.page_idx += key_start;
|
||||
// if(threadIdx.x==0){
|
||||
// printf("\nbid %d %d page id %d pagev %d\n", blockIdx.z, i_batch, key_start, kargs.page_idx[0]);
|
||||
// }
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
batch_offset_bias = query_start * kargs.stride_bias + key_start;
|
||||
@@ -1050,17 +1060,17 @@ struct FmhaFwdKernel
|
||||
const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) +
|
||||
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q +
|
||||
batch_offset_q;
|
||||
// const KDataType* k_ptr =
|
||||
// reinterpret_cast<const KDataType*>(kargs.k_ptr) +
|
||||
// static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k +
|
||||
// batch_offset_k;
|
||||
const KDataType* k_ptr =
|
||||
reinterpret_cast<const KDataType*>(kargs.k_ptr) +
|
||||
static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k;
|
||||
static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k;// +
|
||||
// batch_offset_k;
|
||||
// const KDataType* k_ptr =
|
||||
// reinterpret_cast<const KDataType*>(kargs.k_ptr) +
|
||||
// static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k;
|
||||
const VDataType* v_ptr =
|
||||
reinterpret_cast<const VDataType*>(kargs.v_ptr) +
|
||||
static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v +
|
||||
batch_offset_v;
|
||||
static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v;// +
|
||||
// batch_offset_v;
|
||||
ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr) +
|
||||
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o +
|
||||
batch_offset_o;
|
||||
@@ -1091,7 +1101,7 @@ struct FmhaFwdKernel
|
||||
const auto k_dram = [&]() {
|
||||
const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
k_ptr,
|
||||
make_tuple(kargs.seqlen_k, kargs.hdim_q),
|
||||
make_tuple(kargs.page_num, kargs.hdim_q),
|
||||
make_tuple(kargs.stride_k, 1),
|
||||
number<FmhaPipeline::kAlignmentK>{},
|
||||
number<1>{});
|
||||
@@ -1107,7 +1117,7 @@ struct FmhaFwdKernel
|
||||
{
|
||||
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
v_ptr,
|
||||
make_tuple(kargs.seqlen_k, kargs.hdim_v),
|
||||
make_tuple(kargs.page_num, kargs.hdim_v),
|
||||
make_tuple(kargs.stride_v, 1),
|
||||
number<FmhaPipeline::kAlignmentV>{},
|
||||
number<1>{});
|
||||
@@ -1115,7 +1125,7 @@ struct FmhaFwdKernel
|
||||
const auto v_dram_transposed =
|
||||
transform_tensor_view(v_dram_naive,
|
||||
make_tuple(make_pass_through_transform(kargs.hdim_v),
|
||||
make_pass_through_transform(kargs.seqlen_k)),
|
||||
make_pass_through_transform(kargs.page_num)),
|
||||
make_tuple(sequence<1>{}, sequence<0>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
|
||||
@@ -10,7 +10,6 @@
|
||||
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
#include "ck_tile/core/tensor/tile_scatter_gather.hpp"
|
||||
// #include "ck_tile/core/tensor/tile_scatter_gather_debug.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -315,6 +314,11 @@ struct BlockFmhaPipelineQRKSVS
|
||||
k_dram_block_window.get_window_origin(),
|
||||
k_dist,
|
||||
k_offsets); // K DRAM tile window for
|
||||
// auto k_dram_window = make_tile_window_debug(
|
||||
// k_dram_block_window.get_bottom_tensor_view(),
|
||||
// k_dram_block_window.get_window_lengths(),
|
||||
// k_dram_block_window.get_window_origin(),
|
||||
// k_dist); // K DRAM tile window for
|
||||
|
||||
auto k_block_tile = load_tile(k_dram_window);
|
||||
{
|
||||
|
||||
@@ -67,7 +67,6 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
static constexpr bool kStoreLSE = Problem::kStoreLSE;
|
||||
static constexpr bool kHasDropout = Problem::kHasDropout;
|
||||
|
||||
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
|
||||
// ... together with tensor distribution. tensor dist should able to overwrite this
|
||||
static constexpr index_t kAlignmentQ = Policy::template GetAlignmentQ<Problem>();
|
||||
|
||||
Reference in New Issue
Block a user