Fix wrong grid size

This commit is contained in:
PoYen, Chen
2024-07-23 14:20:52 +00:00
parent 52b47810bb
commit ca4b208b60
2 changed files with 3 additions and 3 deletions

View File

@@ -536,7 +536,7 @@ auto fmha_fwd_appendkv_create_kargs_and_grids(fmha_fwd_appendkv_args args)
}
}();
dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.seqlen_knew, args.hdim_v);
dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.seqlen_knew);
HOST_DEBUG_STMTS
{
printf("[HOST] grid size: %2d,%2d,%2d\n",

View File

@@ -19,11 +19,11 @@ struct FmhaFwdAppendKVTilePartitioner
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size,
ck_tile::index_t nhead,
ck_tile::index_t seqlen_q,
ck_tile::index_t max_seqlen_q,
ck_tile::index_t seqlen_knew)
{
// TODO: this may need tuning
return dim3(std::max(ck_tile::integer_divide_ceil(seqlen_q, kM0),
return dim3(std::max(ck_tile::integer_divide_ceil(max_seqlen_q, kM0),
ck_tile::integer_divide_ceil(seqlen_knew, kN0)),
nhead,
batch_size);