[CK_TILE] More fmha splitkv optimizations (#1588)

* Use pre-defined constants for readability

* Use vector write for o_acc tensor

* Remove no-longer used policy method

* Deprecate no-longer used policy/pipeline

* Specify gemm0/gemm1 block warps separately in codegen

* Fix wrong ps_idx creation logic

* Add single-warp block gemm

* Supoprt single-warp gemm0

* Make MakeCBlockTile() as static method

* Use MakeCBlockTile() to get underlying tile distribution

* Use kNumGemm1Warps to compute # threads for gemm1

* Put normal case in the if clause

* Refine fmha splitkv block mapping

* Refine & fix the lse_acc/o_acc layout

* Fix wrong LDS size for K tile

* Use kK0=64 for hdim=128,256 fmha splitkv kernels

* Use kK1=64 for hdim=32,64,128 fmha splitkv kernels

* Undo kK0/kK1 changes

* Use more reasonable GetAlignmentV() computation

* Using store_tile() in fmha splitkv kernel epilogue
This commit is contained in:
Po Yen Chen
2024-10-26 18:35:45 +08:00
committed by GitHub
parent 37f7afed1e
commit 54f0e6f4bb
22 changed files with 422 additions and 199 deletions

View File

@@ -69,7 +69,8 @@ struct FmhaFwdKernel
// sync with generate.py
// clang-format off
using bfs = typename FmhaPipeline::BlockFmhaShape;
using gbr = typename bfs::Gemm0BlockWarps;
using g0br = typename bfs::Gemm0BlockWarps;
using g1br = typename bfs::Gemm1BlockWarps;
using gwt = typename bfs::Gemm0WarpTile;
#define _SS_ std::string
#define _TS_ std::to_string
@@ -85,7 +86,8 @@ struct FmhaFwdKernel
"_" + (kIsGroupMode ? "group" : "batch") + "_" + _SS_(TilePartitioner::name) + "_"
"b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" +
_TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kK0BlockLength) + "_" +
"r" + _TS_(gbr::at(ck_tile::number<0>{})) + "x" + _TS_(gbr::at(ck_tile::number<1>{})) + "x" + _TS_(gbr::at(ck_tile::number<2>{})) + "_" +
"r" + _TS_(g0br::at(ck_tile::number<0>{})) + "x" + _TS_(g0br::at(ck_tile::number<1>{})) + "x" + _TS_(g0br::at(ck_tile::number<2>{})) + "_" +
"r" + _TS_(g1br::at(ck_tile::number<0>{})) + "x" + _TS_(g1br::at(ck_tile::number<1>{})) + "x" + _TS_(g1br::at(ck_tile::number<2>{})) + "_" +
"w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" +
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn) +

View File

@@ -65,7 +65,8 @@ struct FmhaFwdSplitKVKernel
// sync with generate.py
// clang-format off
using bfs = typename FmhaPipeline::BlockFmhaShape;
using gbr = typename bfs::Gemm0BlockWarps;
using g0br = typename bfs::Gemm0BlockWarps;
using g1br = typename bfs::Gemm1BlockWarps;
using gwt = typename bfs::Gemm0WarpTile;
#define _SS_ std::string
#define _TS_ std::to_string
@@ -81,7 +82,8 @@ struct FmhaFwdSplitKVKernel
"_" + (kIsGroupMode ? "group" : "batch") + "_"
"b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" +
_TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kK0BlockLength) + "_" +
"r" + _TS_(gbr::at(ck_tile::number<0>{})) + "x" + _TS_(gbr::at(ck_tile::number<1>{})) + "x" + _TS_(gbr::at(ck_tile::number<2>{})) + "_" +
"r" + _TS_(g0br::at(ck_tile::number<0>{})) + "x" + _TS_(g0br::at(ck_tile::number<1>{})) + "x" + _TS_(g0br::at(ck_tile::number<2>{})) + "_" +
"r" + _TS_(g1br::at(ck_tile::number<0>{})) + "x" + _TS_(g1br::at(ck_tile::number<1>{})) + "x" + _TS_(g1br::at(ck_tile::number<2>{})) + "_" +
"w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" +
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn) +
@@ -894,7 +896,7 @@ struct FmhaFwdSplitKVKernel
o_acc_ptr,
make_tuple(kargs.seqlen_q, kargs.hdim_v),
make_tuple(kargs.stride_o_acc, 1),
number<1>{},
number<FmhaPipeline::kAlignmentOacc>{},
number<1>{});
return pad_tensor_view(

View File

@@ -26,8 +26,8 @@ struct FmhaFwdSplitKVTilePartitioner
{
// TODO: this may need tuning
return dim3(ck_tile::integer_divide_ceil(max_seqlen_q, kM0) *
ck_tile::integer_divide_ceil(hdim_v, kN1),
nhead * num_splits,
ck_tile::integer_divide_ceil(hdim_v, kN1) * num_splits,
nhead,
batch_size);
}
@@ -42,8 +42,9 @@ struct FmhaFwdSplitKVTilePartitioner
return ck_tile::make_tuple(quotient, modulus);
};
const auto [i_tile_m, i_tile_n] = f(blockIdx.x, num_tile_n1);
const auto [i_nhead, i_split] = f(blockIdx.y, num_splits);
const auto [mn, i_split] = f(blockIdx.x, num_splits);
const auto [i_tile_m, i_tile_n] = f(mn, num_tile_n1);
const index_t i_nhead = blockIdx.y;
const index_t i_batch = blockIdx.z;
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_split, i_nhead, i_batch);