mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-08 15:30:23 +00:00
Fix vnew append error
This commit is contained in:
@@ -518,8 +518,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
if(is_v_rowmajor)
|
||||
return i_perm ? hdim_v : nhead_k * hdim_v;
|
||||
else
|
||||
return i_perm ? (shape_seqlen_k - seqlen_knew)
|
||||
: nhead_k * (shape_seqlen_k - seqlen_knew);
|
||||
return i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k;
|
||||
}();
|
||||
const ck_tile::index_t stride_vnew = [&]() {
|
||||
if(is_v_rowmajor)
|
||||
@@ -528,16 +527,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
return i_perm ? seqlen_knew : nhead_k * seqlen_knew;
|
||||
}();
|
||||
// setup nhead_stride_* arguments
|
||||
const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q);
|
||||
const ck_tile::index_t nhead_stride_k =
|
||||
(i_perm ? (shape_seqlen_k - seqlen_knew) * hdim_q : hdim_q);
|
||||
const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q);
|
||||
const ck_tile::index_t nhead_stride_k = (i_perm ? shape_seqlen_k * hdim_q : hdim_q);
|
||||
const ck_tile::index_t nhead_stride_knew = (i_perm ? seqlen_knew * hdim_q : hdim_q);
|
||||
const ck_tile::index_t nhead_stride_v = [&]() {
|
||||
if(is_v_rowmajor)
|
||||
return i_perm ? (shape_seqlen_k - seqlen_knew) * hdim_v : hdim_v;
|
||||
return i_perm ? shape_seqlen_k * hdim_v : hdim_v;
|
||||
else
|
||||
return i_perm ? hdim_v * (shape_seqlen_k - seqlen_knew)
|
||||
: (shape_seqlen_k - seqlen_knew);
|
||||
return i_perm ? hdim_v * shape_seqlen_k : shape_seqlen_k;
|
||||
}();
|
||||
const ck_tile::index_t nhead_stride_vnew = [&]() {
|
||||
if(is_v_rowmajor)
|
||||
@@ -546,12 +543,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
return i_perm ? hdim_v * seqlen_knew : seqlen_knew;
|
||||
}();
|
||||
// setup batch_stride_* arguments
|
||||
const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q);
|
||||
const ck_tile::index_t batch_stride_k =
|
||||
(nhead_k * (shape_seqlen_k - seqlen_knew) * hdim_q);
|
||||
const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q);
|
||||
const ck_tile::index_t batch_stride_k = (nhead_k * shape_seqlen_k * hdim_q);
|
||||
const ck_tile::index_t batch_stride_knew = (nhead_k * seqlen_knew * hdim_q);
|
||||
const ck_tile::index_t batch_stride_v =
|
||||
(nhead_k * hdim_v * (shape_seqlen_k - seqlen_knew));
|
||||
const ck_tile::index_t batch_stride_v = (nhead_k * hdim_v * shape_seqlen_k);
|
||||
const ck_tile::index_t batch_stride_vnew = (nhead_k * hdim_v * seqlen_knew);
|
||||
|
||||
return fmha_fwd_appendkv_args{q_buf.GetDeviceBuffer(),
|
||||
@@ -729,16 +724,43 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
<< std::setprecision(2) << tflops << " TFlops, " << std::setprecision(2) << gb_per_sec
|
||||
<< " GB/s" << std::flush;
|
||||
|
||||
k_buf.FromDevice(k_host.data());
|
||||
for(int row = 0; row < shape_seqlen_k; ++row)
|
||||
#if defined(ENABLE_HOST_DEBUG_PRINT)
|
||||
if(!do_validation)
|
||||
{
|
||||
printf("[POYENC][HOST] k_host[%3d] = ", row);
|
||||
for(int col = 0; col < hdim_q; ++col)
|
||||
#if 0
|
||||
k_buf.FromDevice(k_host.data());
|
||||
for(int row = 0; row < shape_seqlen_k; ++row)
|
||||
{
|
||||
printf("%11.7f", ck_tile::type_convert<float>(k_host(0, 0, row, col)));
|
||||
printf("[POYENC][HOST] k_host[%3d] = ", row);
|
||||
for(int col = 0; col < hdim_q; ++col)
|
||||
{
|
||||
printf("%11.7f", ck_tile::type_convert<float>(k_host(0, 0, row, col)));
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
printf("\n");
|
||||
#endif
|
||||
|
||||
#if 1
|
||||
v_buf.FromDevice(v_host.data());
|
||||
for(int row = 0; row < shape_seqlen_k; ++row)
|
||||
{
|
||||
printf("[POYENC][HOST] v_host[%3d] = ", row);
|
||||
for(int col = 0; col < hdim_v; ++col)
|
||||
{
|
||||
if(vlayout == "r")
|
||||
{
|
||||
printf("%11.7f", ck_tile::type_convert<float>(v_host(0, 0, row, col)));
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("%11.7f", ck_tile::type_convert<float>(v_host(0, 0, col, row)));
|
||||
}
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
|
||||
if(!do_validation)
|
||||
{
|
||||
|
||||
@@ -344,7 +344,7 @@ auto fmha_fwd_appendkv_create_kargs_and_grids(fmha_fwd_appendkv_args args)
|
||||
}
|
||||
}();
|
||||
|
||||
dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v);
|
||||
dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.seqlen_knew, args.hdim_v);
|
||||
printf("[POYENC][HOST] grid size: %2d,%2d,%2d\n",
|
||||
static_cast<int>(grids.x),
|
||||
static_cast<int>(grids.y),
|
||||
|
||||
Reference in New Issue
Block a user