diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index c61db0c4cb..fdb869337d 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -536,7 +536,7 @@ auto fmha_fwd_appendkv_create_kargs_and_grids(fmha_fwd_appendkv_args args) } }(); - dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.seqlen_knew, args.hdim_v); + dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.seqlen_knew); HOST_DEBUG_STMTS { printf("[HOST] grid size: %2d,%2d,%2d\n", diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_tile_partitioner.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_tile_partitioner.hpp index 97c9b960c2..1190520995 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_tile_partitioner.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_tile_partitioner.hpp @@ -19,11 +19,11 @@ struct FmhaFwdAppendKVTilePartitioner CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size, ck_tile::index_t nhead, - ck_tile::index_t seqlen_q, + ck_tile::index_t max_seqlen_q, ck_tile::index_t seqlen_knew) { // TODO: this may need tuning - return dim3(std::max(ck_tile::integer_divide_ceil(seqlen_q, kM0), + return dim3(std::max(ck_tile::integer_divide_ceil(max_seqlen_q, kM0), ck_tile::integer_divide_ceil(seqlen_knew, kN0)), nhead, batch_size);