Fix vnew append error

This commit is contained in:
PoYen, Chen
2024-06-26 17:00:07 +00:00
parent 4e6c28522c
commit 8fb567c286
5 changed files with 165 additions and 30 deletions

View File

@@ -518,8 +518,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(is_v_rowmajor)
return i_perm ? hdim_v : nhead_k * hdim_v;
else
return i_perm ? (shape_seqlen_k - seqlen_knew)
: nhead_k * (shape_seqlen_k - seqlen_knew);
return i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k;
}();
const ck_tile::index_t stride_vnew = [&]() {
if(is_v_rowmajor)
@@ -528,16 +527,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
return i_perm ? seqlen_knew : nhead_k * seqlen_knew;
}();
// setup nhead_stride_* arguments
const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q);
const ck_tile::index_t nhead_stride_k =
(i_perm ? (shape_seqlen_k - seqlen_knew) * hdim_q : hdim_q);
const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q);
const ck_tile::index_t nhead_stride_k = (i_perm ? shape_seqlen_k * hdim_q : hdim_q);
const ck_tile::index_t nhead_stride_knew = (i_perm ? seqlen_knew * hdim_q : hdim_q);
const ck_tile::index_t nhead_stride_v = [&]() {
if(is_v_rowmajor)
return i_perm ? (shape_seqlen_k - seqlen_knew) * hdim_v : hdim_v;
return i_perm ? shape_seqlen_k * hdim_v : hdim_v;
else
return i_perm ? hdim_v * (shape_seqlen_k - seqlen_knew)
: (shape_seqlen_k - seqlen_knew);
return i_perm ? hdim_v * shape_seqlen_k : shape_seqlen_k;
}();
const ck_tile::index_t nhead_stride_vnew = [&]() {
if(is_v_rowmajor)
@@ -546,12 +543,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
return i_perm ? hdim_v * seqlen_knew : seqlen_knew;
}();
// setup batch_stride_* arguments
const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q);
const ck_tile::index_t batch_stride_k =
(nhead_k * (shape_seqlen_k - seqlen_knew) * hdim_q);
const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q);
const ck_tile::index_t batch_stride_k = (nhead_k * shape_seqlen_k * hdim_q);
const ck_tile::index_t batch_stride_knew = (nhead_k * seqlen_knew * hdim_q);
const ck_tile::index_t batch_stride_v =
(nhead_k * hdim_v * (shape_seqlen_k - seqlen_knew));
const ck_tile::index_t batch_stride_v = (nhead_k * hdim_v * shape_seqlen_k);
const ck_tile::index_t batch_stride_vnew = (nhead_k * hdim_v * seqlen_knew);
return fmha_fwd_appendkv_args{q_buf.GetDeviceBuffer(),
@@ -729,16 +724,43 @@ bool run(const ck_tile::ArgParser& arg_parser)
<< std::setprecision(2) << tflops << " TFlops, " << std::setprecision(2) << gb_per_sec
<< " GB/s" << std::flush;
k_buf.FromDevice(k_host.data());
for(int row = 0; row < shape_seqlen_k; ++row)
#if defined(ENABLE_HOST_DEBUG_PRINT)
if(!do_validation)
{
printf("[POYENC][HOST] k_host[%3d] = ", row);
for(int col = 0; col < hdim_q; ++col)
#if 0
k_buf.FromDevice(k_host.data());
for(int row = 0; row < shape_seqlen_k; ++row)
{
printf("%11.7f", ck_tile::type_convert<float>(k_host(0, 0, row, col)));
printf("[POYENC][HOST] k_host[%3d] = ", row);
for(int col = 0; col < hdim_q; ++col)
{
printf("%11.7f", ck_tile::type_convert<float>(k_host(0, 0, row, col)));
}
printf("\n");
}
printf("\n");
#endif
#if 1
v_buf.FromDevice(v_host.data());
for(int row = 0; row < shape_seqlen_k; ++row)
{
printf("[POYENC][HOST] v_host[%3d] = ", row);
for(int col = 0; col < hdim_v; ++col)
{
if(vlayout == "r")
{
printf("%11.7f", ck_tile::type_convert<float>(v_host(0, 0, row, col)));
}
else
{
printf("%11.7f", ck_tile::type_convert<float>(v_host(0, 0, col, row)));
}
}
printf("\n");
}
#endif
}
#endif
if(!do_validation)
{

View File

@@ -344,7 +344,7 @@ auto fmha_fwd_appendkv_create_kargs_and_grids(fmha_fwd_appendkv_args args)
}
}();
dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v);
dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.seqlen_knew, args.hdim_v);
printf("[POYENC][HOST] grid size: %2d,%2d,%2d\n",
static_cast<int>(grids.x),
static_cast<int>(grids.y),

View File

@@ -270,6 +270,28 @@ struct FmhaFwdAppendKVKernel
const index_t i_sk = __builtin_amdgcn_readfirstlane(i_tile_sk * FmhaPipeline::kTileSizeSk);
// const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1);
#if defined(ENABLE_KERNEL_DEBUG_PRINT)
#define PRINTF(expr) printf("[POYENC][DEVICE] " #expr ": %2d\n", (expr));
if(blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && threadIdx.x == TID)
{
PRINTF(kargs.stride_k);
PRINTF(kargs.nhead_stride_k);
PRINTF(kargs.batch_stride_k);
PRINTF(kargs.stride_knew);
PRINTF(kargs.nhead_stride_knew);
PRINTF(kargs.batch_stride_knew);
PRINTF(kargs.stride_v);
PRINTF(kargs.nhead_stride_v);
PRINTF(kargs.batch_stride_v);
PRINTF(kargs.stride_vnew);
PRINTF(kargs.nhead_stride_vnew);
PRINTF(kargs.batch_stride_vnew);
}
#endif
long_index_t batch_offset_q = 0;
long_index_t batch_offset_k = 0;
long_index_t batch_offset_knew =
@@ -477,13 +499,26 @@ struct FmhaFwdAppendKVKernel
auto v_dram_window = make_tile_window(
v_dram,
make_tuple(number<FmhaPipeline::kTileSizeSk>{}, number<FmhaPipeline::kTileSizeDv>{}),
{kargs.seqlen_k - kargs.seqlen_knew, 0});
make_tuple(number<FmhaPipeline::kTileSizeDv>{}, number<FmhaPipeline::kTileSizeSk>{}),
{0, kargs.seqlen_k - kargs.seqlen_knew});
auto vnew_dram_window = make_tile_window(
vnew_dram,
make_tuple(number<FmhaPipeline::kTileSizeSk>{}, number<FmhaPipeline::kTileSizeDv>{}),
{i_sk, 0});
make_tuple(number<FmhaPipeline::kTileSizeDv>{}, number<FmhaPipeline::kTileSizeSk>{}),
{0, i_sk});
#if defined(ENABLE_KERNEL_DEBUG_PRINT)
if(blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && threadIdx.x == TID)
{
printf("[POYENC][DEVICE] kargs.seqlen_k - kargs.seqlen_knew: %d\n",
kargs.seqlen_k - kargs.seqlen_knew);
printf("[POYENC][DEVICE] i_sk: %d\n", i_sk);
printf("[POYENC][DEVICE] v_dram.get_length(0): %d\n",
v_dram.get_tensor_descriptor().get_length(number<0>{}));
printf("[POYENC][DEVICE] v_dram.get_length(1): %d\n",
v_dram.get_tensor_descriptor().get_length(number<1>{}));
}
#endif
FmhaPipeline{}(q_dram_window,
k_dram_window,

View File

@@ -121,26 +121,45 @@ struct BlockFmhaFwdAppendKVPipeline
Policy::template MakeKnewDramTileDistribution<Problem>());
auto knew_tile = load_tile(knew_dram_window);
if(blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && threadIdx.x == 0)
/// TODO: apply RoPE on knew_tile here
store_tile(k_dram_block_window_tmp, knew_tile);
auto vnew_dram_block_window =
make_tile_window(vnew_dram_block_window_tmp.get_bottom_tensor_view(),
vnew_dram_block_window_tmp.get_window_lengths(),
{0, 0});
auto vnew_dram_window =
make_tile_window(vnew_dram_block_window.get_bottom_tensor_view(),
vnew_dram_block_window.get_window_lengths(),
vnew_dram_block_window.get_window_origin(),
Policy::template MakeVnewDramTileDistribution<Problem>());
auto vnew_tile = load_tile(vnew_dram_window);
#if defined(ENABLE_PIPELINE_DEBUG_PRINT)
if(blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && threadIdx.x == TID)
{
constexpr auto spans = decltype(knew_tile)::get_distributed_spans();
printf("[POYENC][DEVICE] tid: %d\n", TID);
constexpr auto spans = decltype(vnew_tile)::get_distributed_spans();
sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices(
knew_tile.get_tile_distribution(), make_tuple(idx0, idx1));
vnew_tile.get_tile_distribution(), make_tuple(idx0, idx1));
const auto row = tile_idx.at(number<0>{});
const auto col = tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1);
printf("[POYENC][DEVICE] knew_tile(%2d,%2d): %11.7f\n",
printf("[POYENC][DEVICE] vnew_tile(%2d,%2d): %11.7f\n",
row,
col,
type_convert<float>(knew_tile(i_j_idx)));
type_convert<float>(vnew_tile(i_j_idx)));
});
});
}
store_tile(k_dram_block_window_tmp, knew_tile);
#endif
store_tile(v_dram_block_window_tmp, vnew_tile);
}
template <typename QDramBlockWindowTmp,

View File

@@ -78,6 +78,65 @@ struct BlockFmhaFwdAppendKVPipelineDefaultPolicy
sequence<1, 2>,
sequence<0, 1>>{});
}
// Returns 16 / sizeof(VDataType), i.e. how many VDataType elements fit in 16 bytes.
// NOTE(review): going by the name, presumably the K-pack width used for V in shared
// memory -- confirm against the callers.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackV()
{
    // TODO: this is for 3d layout
    using ElementType        = remove_cvref_t<typename Problem::VDataType>;
    constexpr int kPackBytes = 16;
    return kPackBytes / sizeof(ElementType);
}
// Builds the static tile distribution for the Vnew tile.
//
// Problem supplies VLayout, VDataType, kBlockSize, kTileSizeDv and kTileSizeSk.
// The returned value is make_static_tile_distribution() applied to an encoding whose
// first dimension (kTileSizeDv) is split as <N0, N1, N2> and whose second dimension
// (kTileSizeSk) is split as <K0, K1>.
//
// RowMajor VLayout is rejected at compile time by the static_assert below.
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeVnewDramTileDistribution()
{
    using VLayout = remove_cvref_t<typename Problem::VLayout>;

    constexpr bool kIsRowMajor =
        std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>;
    static_assert(!kIsRowMajor,
                  "MakeVnewDramTileDistribution: row-major VLayout is not supported");

    // Quantities that are computed identically for every layout.
    constexpr index_t kBlockSize = Problem::kBlockSize;
    constexpr index_t kNPerBlock = Problem::kTileSizeDv;
    constexpr index_t kKPerBlock = Problem::kTileSizeSk;
    constexpr index_t N1         = kBlockSize / get_warp_size();

    if constexpr(kIsRowMajor)
    {
        // NOTE(review): this branch cannot be instantiated while the static_assert above
        // is in place. It is kept unchanged so the row-major path can be revisited;
        // confirm whether it should be enabled or removed.
        constexpr index_t K1 = GetAlignmentV<Problem>();
        constexpr index_t K0 = kKPerBlock / K1;
        constexpr index_t N2 = get_warp_size() / K0;
        constexpr index_t N0 = kNPerBlock / (N2 * N1);
        static_assert(N0 != 0);

        return make_static_tile_distribution(
            tile_distribution_encoding<sequence<1>,
                                       tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
                                       tuple<sequence<1>, sequence<1, 2>>,
                                       tuple<sequence<1>, sequence<2, 0>>,
                                       sequence<1, 2>,
                                       sequence<0, 1>>{});
    }
    else
    {
        using VDataType = remove_cvref_t<typename Problem::VDataType>;

        // K1 is the number of VDataType elements in 16 bytes.
        constexpr index_t K1 = 16 / sizeof(VDataType);
        constexpr index_t K0 = kKPerBlock / K1;
        constexpr index_t N2 = get_warp_size() / K0;
        constexpr index_t N0 = kNPerBlock / (N2 * N1);
        // Same guard as the row-major branch: a zero N0 means kTileSizeDv is smaller
        // than N1 * N2, which would yield an empty leading split in the encoding.
        static_assert(N0 != 0);

        return make_static_tile_distribution(
            tile_distribution_encoding<sequence<1>,
                                       tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
                                       tuple<sequence<1>, sequence<1, 2>>,
                                       tuple<sequence<1>, sequence<2, 0>>,
                                       sequence<1, 2>,
                                       sequence<0, 1>>{});
    }
}
};
} // namespace ck_tile