Fix vnew append error

This commit is contained in:
PoYen, Chen
2024-06-26 17:00:07 +00:00
parent 4e6c28522c
commit 8fb567c286
5 changed files with 165 additions and 30 deletions

View File

@@ -518,8 +518,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(is_v_rowmajor)
return i_perm ? hdim_v : nhead_k * hdim_v;
else
return i_perm ? (shape_seqlen_k - seqlen_knew)
: nhead_k * (shape_seqlen_k - seqlen_knew);
return i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k;
}();
const ck_tile::index_t stride_vnew = [&]() {
if(is_v_rowmajor)
@@ -528,16 +527,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
return i_perm ? seqlen_knew : nhead_k * seqlen_knew;
}();
// setup nhead_stride_* arguments
const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q);
const ck_tile::index_t nhead_stride_k =
(i_perm ? (shape_seqlen_k - seqlen_knew) * hdim_q : hdim_q);
const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q);
const ck_tile::index_t nhead_stride_k = (i_perm ? shape_seqlen_k * hdim_q : hdim_q);
const ck_tile::index_t nhead_stride_knew = (i_perm ? seqlen_knew * hdim_q : hdim_q);
const ck_tile::index_t nhead_stride_v = [&]() {
if(is_v_rowmajor)
return i_perm ? (shape_seqlen_k - seqlen_knew) * hdim_v : hdim_v;
return i_perm ? shape_seqlen_k * hdim_v : hdim_v;
else
return i_perm ? hdim_v * (shape_seqlen_k - seqlen_knew)
: (shape_seqlen_k - seqlen_knew);
return i_perm ? hdim_v * shape_seqlen_k : shape_seqlen_k;
}();
const ck_tile::index_t nhead_stride_vnew = [&]() {
if(is_v_rowmajor)
@@ -546,12 +543,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
return i_perm ? hdim_v * seqlen_knew : seqlen_knew;
}();
// setup batch_stride_* arguments
const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q);
const ck_tile::index_t batch_stride_k =
(nhead_k * (shape_seqlen_k - seqlen_knew) * hdim_q);
const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q);
const ck_tile::index_t batch_stride_k = (nhead_k * shape_seqlen_k * hdim_q);
const ck_tile::index_t batch_stride_knew = (nhead_k * seqlen_knew * hdim_q);
const ck_tile::index_t batch_stride_v =
(nhead_k * hdim_v * (shape_seqlen_k - seqlen_knew));
const ck_tile::index_t batch_stride_v = (nhead_k * hdim_v * shape_seqlen_k);
const ck_tile::index_t batch_stride_vnew = (nhead_k * hdim_v * seqlen_knew);
return fmha_fwd_appendkv_args{q_buf.GetDeviceBuffer(),
@@ -729,16 +724,43 @@ bool run(const ck_tile::ArgParser& arg_parser)
<< std::setprecision(2) << tflops << " TFlops, " << std::setprecision(2) << gb_per_sec
<< " GB/s" << std::flush;
k_buf.FromDevice(k_host.data());
for(int row = 0; row < shape_seqlen_k; ++row)
#if defined(ENABLE_HOST_DEBUG_PRINT)
if(!do_validation)
{
printf("[POYENC][HOST] k_host[%3d] = ", row);
for(int col = 0; col < hdim_q; ++col)
#if 0
k_buf.FromDevice(k_host.data());
for(int row = 0; row < shape_seqlen_k; ++row)
{
printf("%11.7f", ck_tile::type_convert<float>(k_host(0, 0, row, col)));
printf("[POYENC][HOST] k_host[%3d] = ", row);
for(int col = 0; col < hdim_q; ++col)
{
printf("%11.7f", ck_tile::type_convert<float>(k_host(0, 0, row, col)));
}
printf("\n");
}
printf("\n");
#endif
#if 1
v_buf.FromDevice(v_host.data());
for(int row = 0; row < shape_seqlen_k; ++row)
{
printf("[POYENC][HOST] v_host[%3d] = ", row);
for(int col = 0; col < hdim_v; ++col)
{
if(vlayout == "r")
{
printf("%11.7f", ck_tile::type_convert<float>(v_host(0, 0, row, col)));
}
else
{
printf("%11.7f", ck_tile::type_convert<float>(v_host(0, 0, col, row)));
}
}
printf("\n");
}
#endif
}
#endif
if(!do_validation)
{

View File

@@ -344,7 +344,7 @@ auto fmha_fwd_appendkv_create_kargs_and_grids(fmha_fwd_appendkv_args args)
}
}();
dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v);
dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.seqlen_knew, args.hdim_v);
printf("[POYENC][HOST] grid size: %2d,%2d,%2d\n",
static_cast<int>(grids.x),
static_cast<int>(grids.y),