fix multi batch and hack page idx core

This commit is contained in:
coderfeli
2025-04-15 01:21:23 +00:00
parent 234528e06c
commit ff281e135d
7 changed files with 51 additions and 38 deletions

View File

@@ -448,14 +448,14 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
else:
# pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
# pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', bias, lse, dropout, squant, mask))
# pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
# pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 'f', 'f', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 'f', 'f', bias, lse, dropout, squant, mask))
# pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
# pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
# pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
# pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
elif dtype in ['fp8', 'bf8']:
# no need lse/dropout kernels

View File

@@ -401,7 +401,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
cache_seqlen_ks.begin(),
[&](auto seqlen_k) { return seqlen_k - seqlen_knew; });
#if 0
#if 1
// clang-format off
std::cout << "seqlen_qs:"; for(auto xx : seqlen_qs) { std::cout << xx << ","; } std::cout << std::endl;
std::cout << "seqlen_ks:"; for(auto xx : seqlen_ks) { std::cout << xx << ","; } std::cout << std::endl;
@@ -481,12 +481,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
const auto seqstart_k_host = to_seqstarts(seqlen_ks);
const auto seqstart_k_with_padding_host = to_seqstarts(seqlen_kpads);
// std::vector<int32_t> page_idx_host(seqstart_k_host.back(), 0);
ck_tile::HostTensor<int32_t> page_idx_host({seqstart_k_host.back()});
// std::iota(page_idx_host.begin(), page_idx_host.end(), 0);
iota_shuffle(page_idx_host.mData.begin(), page_idx_host.mData.end(), 0);
// for (int i = 0; i < page_idx_host.get_element_space_size(); i++) {
// page_idx_host(i) = (i + 19) % page_idx_host.size();
// }
ck_tile::HostTensor<int32_t> page_idx_host({seqstart_k_host.back() + 65536});
std::iota(page_idx_host.begin(), page_idx_host.end(), 0);
iota_shuffle(page_idx_host.mData.begin(), page_idx_host.mData.end() - 65536, 0);
page_idx_host.savetxt("page_idx_host.txt", "int");
using TypeConfig = FmhaFwdTypeConfig<DataTypeConfig>;
@@ -1023,6 +1020,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
(mode == mode_enum::group ? seqstart_k.GetDeviceBuffer() : nullptr);
args.page_idx_ptr =
(mode == mode_enum::group ? page_idx.GetDeviceBuffer() : nullptr);
args.page_num =
(mode == mode_enum::group ? shape_seqlen_k : 0);
args.seqlen_k_ptr = ((mode == mode_enum::batch && use_kvcache) || 0 <= k_paddings_[0]
? seqlen_k_buf.GetDeviceBuffer()
: nullptr);

View File

@@ -131,6 +131,7 @@ struct fmha_fwd_args
seqlen_k_ptr; // only used if both 'seqstart_q_ptr' & 'seqstart_k_ptr' are not nullptr
const void* page_idx_ptr;
ck_tile::index_t page_num;
ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k;
ck_tile::index_t batch;
@@ -328,6 +329,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.seqstart_k_ptr,
args.seqlen_k_ptr,
args.page_idx_ptr,
args.page_num,
args.hdim_q,
args.hdim_v,
args.nhead_q,

View File

@@ -338,7 +338,6 @@ struct tile_scatter_gather
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
constexpr auto idx_gather = idx_ys_start[number<YsGatherDim>{}];
const auto page_offset = page_idx_[idx_gather];
// read from bottom tensor
const vector_t vec_value =
get_bottom_tensor_view().template get_vectorized_elements<vector_t>(

View File

@@ -266,6 +266,7 @@ struct FmhaFwdKernel
const int32_t* seqstart_k_ptr;
const int32_t* seqlen_k_ptr;
const int32_t* page_idx;
int32_t page_num;
};
using Kargs = std::conditional_t<kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs>;
@@ -598,6 +599,7 @@ struct FmhaFwdKernel
const void* seqstart_k_ptr,
const void* seqlen_k_ptr,
const void* page_idx_ptr,
ck_tile::index_t page_num,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
@@ -657,7 +659,8 @@ struct FmhaFwdKernel
reinterpret_cast<const int32_t*>(seqstart_q_ptr),
reinterpret_cast<const int32_t*>(seqstart_k_ptr),
reinterpret_cast<const int32_t*>(seqlen_k_ptr),
reinterpret_cast<const int32_t*>(page_idx_ptr)};
reinterpret_cast<const int32_t*>(page_idx_ptr),
page_num};
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
@@ -724,6 +727,7 @@ struct FmhaFwdKernel
const void* seqstart_k_ptr,
const void* seqlen_k_ptr,
const void* page_idx_ptr,
ck_tile::index_t page_num,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
@@ -763,6 +767,7 @@ struct FmhaFwdKernel
seqstart_k_ptr,
seqlen_k_ptr,
page_idx_ptr,
page_num,
hdim_q,
hdim_v,
num_head_q,
@@ -805,6 +810,7 @@ struct FmhaFwdKernel
const void* seqstart_k_ptr,
const void* seqlen_k_ptr,
const void* page_idx_ptr,
ck_tile::index_t page_num,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
@@ -844,6 +850,7 @@ struct FmhaFwdKernel
seqstart_k_ptr,
seqlen_k_ptr,
page_idx_ptr,
page_num,
hdim_q,
hdim_v,
num_head_q,
@@ -966,7 +973,7 @@ struct FmhaFwdKernel
long_index_t batch_offset_q = 0;
// long_index_t batch_offset_k = 0;
long_index_t batch_offset_v = 0;
// long_index_t batch_offset_v = 0;
long_index_t batch_offset_bias = 0;
long_index_t batch_offset_randval = 0;
long_index_t batch_offset_lse = 0;
@@ -980,16 +987,19 @@ struct FmhaFwdKernel
batch_offset_q = query_start * kargs.stride_q;
// batch_offset_k = key_start * kargs.stride_k;
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
batch_offset_v = key_start * kargs.stride_v;
}
else
{
batch_offset_v = key_start;
}
// if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
// {
// batch_offset_v = key_start * kargs.stride_v;
// }
// else
// {
// batch_offset_v = key_start;
// }
kargs.page_idx += key_start;
// if(threadIdx.x==0){
// printf("\nbid %d %d page id %d pagev %d\n", blockIdx.z, i_batch, key_start, kargs.page_idx[0]);
// }
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
batch_offset_bias = query_start * kargs.stride_bias + key_start;
@@ -1050,17 +1060,17 @@ struct FmhaFwdKernel
const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) +
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q +
batch_offset_q;
// const KDataType* k_ptr =
// reinterpret_cast<const KDataType*>(kargs.k_ptr) +
// static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k +
// batch_offset_k;
const KDataType* k_ptr =
reinterpret_cast<const KDataType*>(kargs.k_ptr) +
static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k;
static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k;// +
// batch_offset_k;
// const KDataType* k_ptr =
// reinterpret_cast<const KDataType*>(kargs.k_ptr) +
// static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k;
const VDataType* v_ptr =
reinterpret_cast<const VDataType*>(kargs.v_ptr) +
static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v +
batch_offset_v;
static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v;// +
// batch_offset_v;
ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr) +
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o +
batch_offset_o;
@@ -1091,7 +1101,7 @@ struct FmhaFwdKernel
const auto k_dram = [&]() {
const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
k_ptr,
make_tuple(kargs.seqlen_k, kargs.hdim_q),
make_tuple(kargs.page_num, kargs.hdim_q),
make_tuple(kargs.stride_k, 1),
number<FmhaPipeline::kAlignmentK>{},
number<1>{});
@@ -1107,7 +1117,7 @@ struct FmhaFwdKernel
{
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
v_ptr,
make_tuple(kargs.seqlen_k, kargs.hdim_v),
make_tuple(kargs.page_num, kargs.hdim_v),
make_tuple(kargs.stride_v, 1),
number<FmhaPipeline::kAlignmentV>{},
number<1>{});
@@ -1115,7 +1125,7 @@ struct FmhaFwdKernel
const auto v_dram_transposed =
transform_tensor_view(v_dram_naive,
make_tuple(make_pass_through_transform(kargs.hdim_v),
make_pass_through_transform(kargs.seqlen_k)),
make_pass_through_transform(kargs.page_num)),
make_tuple(sequence<1>{}, sequence<0>{}),
make_tuple(sequence<0>{}, sequence<1>{}));

View File

@@ -10,7 +10,6 @@
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "ck_tile/core/tensor/tile_scatter_gather.hpp"
// #include "ck_tile/core/tensor/tile_scatter_gather_debug.hpp"
namespace ck_tile {
@@ -315,6 +314,11 @@ struct BlockFmhaPipelineQRKSVS
k_dram_block_window.get_window_origin(),
k_dist,
k_offsets); // K DRAM tile window for
// auto k_dram_window = make_tile_window_debug(
// k_dram_block_window.get_bottom_tensor_view(),
// k_dram_block_window.get_window_lengths(),
// k_dram_block_window.get_window_origin(),
// k_dist); // K DRAM tile window for
auto k_block_tile = load_tile(k_dram_window);
{

View File

@@ -67,7 +67,6 @@ struct BlockFmhaPipelineQRKSVSAsync
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr bool kHasDropout = Problem::kHasDropout;
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should able to overwrite this
static constexpr index_t kAlignmentQ = Policy::template GetAlignmentQ<Problem>();