diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 860ee20d3e..805803fed4 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -36,13 +36,12 @@ FMHA_FWD_KERNEL_BODY=""" using fmha_dtype_{F_idx} = {F_dtype}; using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}>; -using fmha_block_warps_{F_idx} = ck_tile::sequence<{F_rm}, {F_rn}, {F_rk}>; using fmha_warp_tile_{F_idx} = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>; using fmha_shape_{F_idx} = ck_tile::TileFmhaShape, fmha_warp_tile_{F_idx}, - fmha_block_warps_{F_idx}, + ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>, fmha_warp_tile_{F_idx}, {F_vlayout}>; @@ -291,9 +290,12 @@ class FmhaFwdTileSize: F_bn1 : int # tile size along v head_dim F_bk1 : int # tile size along kv gemm unroll F_bk0blen : int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile) - F_rm : int # number of warps along q seqlen (block warps) - F_rn : int # number of warps along k seqlen(not used) - F_rk : int # number of warps along gemm-k(not used) + F_rm0 : int # number of warps for gemm0 along q seqlen + F_rn0 : int # number of warps for gemm0 along k seqlen + F_rk0 : int # number of warps for gemm0 along head dim q (not used) + F_rm1 : int # number of warps for gemm1 along q seqlen + F_rn1 : int # number of warps for gemm1 along head dim v + F_rk1 : int # number of warps for gemm1 along k seqlen (not used) F_wm : int # warp size along m (warp size) F_wn : int # warp size along n F_wk : int # warp size along k @@ -301,8 +303,8 @@ class FmhaFwdTileSize: @property def name(self) -> str: return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0blen}" +\ - f"_r{self.F_rm}x{self.F_rn}x{self.F_rk}_w{self.F_wm}x{self.F_wn}x{self.F_wk}" +\ - ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}") + f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}" +\ + f"_w{self.F_wm}x{self.F_wn}x{self.F_wk}" + ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}") @dataclass class FmhaFwdKernel: @@ -334,9 +336,12 @@ class FmhaFwdKernel: F_bn1 = self.F_tile.F_bn1, F_bk1 = self.F_tile.F_bk1, F_bk0blen = self.F_tile.F_bk0blen, - F_rm = self.F_tile.F_rm, - F_rn = self.F_tile.F_rn, - F_rk = self.F_tile.F_rk, + F_rm0 = self.F_tile.F_rm0, + F_rn0 = self.F_tile.F_rn0, + F_rk0 = self.F_tile.F_rk0, + F_rm1 = self.F_tile.F_rm1, + F_rn1 = self.F_tile.F_rn1, + F_rk1 = self.F_tile.F_rk1, F_wm = self.F_tile.F_wm, F_wn = self.F_tile.F_wn, F_wk = self.F_tile.F_wk, @@ -394,16 +399,16 @@ class FmhaFwdKernel: def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: if dtype == 'fp16' or dtype == 'bf16': return { - '32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 32, 32, 16, -1), - '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 32, 32, 16, -1), - '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 32, 32, 16, -1), - '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 32, 32, 16, -1), + '32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, -1), + '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, -1), + '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, -1), + '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, -1), } elif dtype == 'fp8' or dtype == 'bf8': return { - '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 32, 32, 32, -1), - '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 32, 32, 32, -1), - '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 32, 32, 32, -1) + '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, -1), + '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, -1), + '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, -1) } else: return None diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index 57360ea995..46c26b22cd 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -42,13 +42,12 @@ namespace {{ template struct kernel_runner {{ using fmha_block_tile = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}>; -using fmha_block_warps = ck_tile::sequence<{F_rm}, {F_rn}, {F_rk}>; using fmha_warp_tile = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>; using fmha_shape = ck_tile::TileFmhaShape, fmha_warp_tile, - fmha_block_warps, + ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>, fmha_warp_tile, {F_vlayout}>; @@ -162,10 +161,12 @@ using fmha_pipeline_problem = ck_tile::BlockFmhaSplitKVCombinePipelineProblem< using fmha_pipeline = ck_tile::BlockFmhaFwdSplitKVCombinePipeline< fmha_pipeline_problem>; +/// FIXME: use {F_spad}/{F_dvpad} as kPadM/kPadN parameters after solving +/// store_tile_raw() data corruption issue using fmha_epilogue = ck_tile::Default2DEpilogue::OaccDataType, typename FmhaFwdTypeConfig<{F_dtype}>::ODataType, - {F_spad}, {F_dvpad}>>; + false, false>>; using fmha_kernel = ck_tile::FmhaFwdSplitKVCombineKernel, @@ -458,9 +459,12 @@ class FmhaFwdSplitKVKernel: F_bn1 = self.F_tile.F_bn1, F_bk1 = self.F_tile.F_bk1, F_bk0blen = self.F_tile.F_bk0blen, - F_rm = self.F_tile.F_rm, - F_rn = self.F_tile.F_rn, - F_rk = self.F_tile.F_rk, + F_rm0 = self.F_tile.F_rm0, + F_rn0 = self.F_tile.F_rn0, + F_rk0 = self.F_tile.F_rk0, + F_rm1 = self.F_tile.F_rm1, + F_rn1 = self.F_tile.F_rn1, + F_rk1 = self.F_tile.F_rk1, F_wm = self.F_tile.F_wm, F_wn = self.F_tile.F_wn, F_wk = self.F_tile.F_wk, @@ -553,16 +557,16 @@ class FmhaFwdSplitKVCombineKernel: def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: if dtype == 'fp16' or dtype == 'bf16': return { - '32' : FmhaFwdTileSize(32, 64, 16, 32, 32, 32, 2, 1, 1, 16, 16, 16, -1), - '64' : FmhaFwdTileSize(64, 64, 32, 64, 32, 64, 4, 1, 1, 16, 16, 16, -1), - '128' : FmhaFwdTileSize(64, 128, 32, 128, 32, 128, 4, 1, 1, 16, 16, 16, -1), - '256' : FmhaFwdTileSize(64, 128, 32, 256, 32, 256, 4, 1, 1, 16, 16, 16, -1), + '32' : FmhaFwdTileSize(32, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 16, 16, 16, -1), + '64' : FmhaFwdTileSize(64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, -1), + '128' : FmhaFwdTileSize(64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, -1), + '256' : FmhaFwdTileSize(64, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, -1), } elif dtype == 'fp8' or dtype == 'bf8': return { - '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 32, 32, 32, -1), - '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 32, 32, 32, -1), - '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 32, 32, 32, -1) + '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, -1), + '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, -1), + '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, -1) } else: return None diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index 6d519a7ea8..14291715fb 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -557,33 +557,16 @@ bool run(const ck_tile::ArgParser& arg_parser) } #endif - struct - { - auto operator()(bool permute, - ck_tile::index_t b /*batch*/, - ck_tile::index_t h /*nhead*/, - ck_tile::index_t s /*seqlen*/, - ck_tile::index_t d /*hdim*/) - { - if(permute) - return std::array{b, h, s, d}; - else - return std::array{b, s, h, d}; - } - - auto operator()(bool permute, - ck_tile::index_t ns /*num_splits*/, - ck_tile::index_t b /*batch*/, - ck_tile::index_t h /*nhead*/, - ck_tile::index_t s /*seqlen*/, - ck_tile::index_t d /*hdim*/) - { - if(permute) - return std::array{ns, b, h, s, d}; - else - return std::array{ns, b, s, h, d}; - } - } get_lengths; + static const auto get_lengths = [](bool permute, + ck_tile::index_t b /*batch*/, + ck_tile::index_t h /*nhead*/, + ck_tile::index_t s /*seqlen*/, + ck_tile::index_t d /*hdim*/) { + if(permute) + return std::array{b, h, s, d}; + else + return std::array{b, s, h, d}; + }; bool is_v_rowmajor = vlayout == std::string("r"); @@ -635,12 +618,15 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::HostTensor lse_acc_host( 1 < num_splits || use_kvcache - ? std::array{num_splits, shape_batch, nhead, shape_seqlen_q} + ? std::array{shape_batch, nhead, num_splits, shape_seqlen_q} : std::array{1, 1, 1, 1}); ck_tile::HostTensor o_acc_host( - 1 < num_splits || use_kvcache - ? get_lengths(o_perm, num_splits, shape_batch, nhead, shape_seqlen_q, hdim_v) - : std::array{1, 1, 1, 1, 1}); + 1 < num_splits || use_kvcache ? std::array{shape_batch, + nhead, + num_splits, + shape_seqlen_q, + hdim_v} + : std::array{1, 1, 1, 1, 1}); // batch mode of lse data layout is [batch, nhead, seqlen_q] // group mode of lse data layout is [nhead, total_seqlen_q] @@ -880,7 +866,7 @@ bool run(const ck_tile::ArgParser& arg_parser) }(); const ck_tile::index_t stride_bias = (i_perm ? shape_seqlen_k : 1 * shape_seqlen_k); const ck_tile::index_t stride_randval = (max_seqlen_k); - const ck_tile::index_t stride_o_acc = (o_perm ? hdim_v : nhead * hdim_v); + const ck_tile::index_t stride_o_acc = (hdim_v); const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v); // setup nhead_stride_* arguments const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q); @@ -906,8 +892,8 @@ bool run(const ck_tile::ArgParser& arg_parser) (i_perm ? 0 * shape_seqlen_q * shape_seqlen_k : 0 * shape_seqlen_k); const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k); const ck_tile::index_t nhead_stride_lse = shape_seqlen_q; - const ck_tile::index_t nhead_stride_lse_acc = shape_seqlen_q; - const ck_tile::index_t nhead_stride_o_acc = (o_perm ? shape_seqlen_q * hdim_v : hdim_v); + const ck_tile::index_t nhead_stride_lse_acc = (num_splits * shape_seqlen_q); + const ck_tile::index_t nhead_stride_o_acc = (num_splits * shape_seqlen_q * hdim_v); const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v); // setup batch_stride_* arguments const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q); @@ -922,13 +908,13 @@ bool run(const ck_tile::ArgParser& arg_parser) const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * shape_seqlen_k); const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k); const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q); - const ck_tile::index_t batch_stride_lse_acc = (nhead * shape_seqlen_q); - const ck_tile::index_t batch_stride_o_acc = (nhead * shape_seqlen_q * hdim_v); - const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); + const ck_tile::index_t batch_stride_lse_acc = (nhead * num_splits * shape_seqlen_q); + const ck_tile::index_t batch_stride_o_acc = (nhead * num_splits * shape_seqlen_q * hdim_v); + const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); const ck_tile::index_t batch_stride_block_table = (max_num_page_blocks / batch); // setup split_stride_* arguments (only used in split-kv kernel) - const ck_tile::index_t split_stride_lse_acc = (shape_batch * nhead * shape_seqlen_q); - const ck_tile::index_t split_stride_o_acc = (shape_batch * nhead * shape_seqlen_q * hdim_v); + const ck_tile::index_t split_stride_lse_acc = (shape_seqlen_q); + const ck_tile::index_t split_stride_o_acc = (shape_seqlen_q * hdim_v); args.q_ptr = q_buf.GetDeviceBuffer(); args.k_ptr = k_buf.GetDeviceBuffer(); diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index adabda165c..8c1f6c8058 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -69,7 +69,8 @@ struct FmhaFwdKernel // sync with generate.py // clang-format off using bfs = typename FmhaPipeline::BlockFmhaShape; - using gbr = typename bfs::Gemm0BlockWarps; + using g0br = typename bfs::Gemm0BlockWarps; + using g1br = typename bfs::Gemm1BlockWarps; using gwt = typename bfs::Gemm0WarpTile; #define _SS_ std::string #define _TS_ std::to_string @@ -85,7 +86,8 @@ struct FmhaFwdKernel "_" + (kIsGroupMode ? "group" : "batch") + "_" + _SS_(TilePartitioner::name) + "_" "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" + _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kK0BlockLength) + "_" + - "r" + _TS_(gbr::at(ck_tile::number<0>{})) + "x" + _TS_(gbr::at(ck_tile::number<1>{})) + "x" + _TS_(gbr::at(ck_tile::number<2>{})) + "_" + + "r" + _TS_(g0br::at(ck_tile::number<0>{})) + "x" + _TS_(g0br::at(ck_tile::number<1>{})) + "x" + _TS_(g0br::at(ck_tile::number<2>{})) + "_" + + "r" + _TS_(g1br::at(ck_tile::number<0>{})) + "x" + _TS_(g1br::at(ck_tile::number<1>{})) + "x" + _TS_(g1br::at(ck_tile::number<2>{})) + "_" + "w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" + (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" + "v" + (std::is_same_v ? "r" : "c") + (pn.empty() ? "" : "_" + pn) + diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp index 34f75990c6..ea30025b53 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp @@ -65,7 +65,8 @@ struct FmhaFwdSplitKVKernel // sync with generate.py // clang-format off using bfs = typename FmhaPipeline::BlockFmhaShape; - using gbr = typename bfs::Gemm0BlockWarps; + using g0br = typename bfs::Gemm0BlockWarps; + using g1br = typename bfs::Gemm1BlockWarps; using gwt = typename bfs::Gemm0WarpTile; #define _SS_ std::string #define _TS_ std::to_string @@ -81,7 +82,8 @@ struct FmhaFwdSplitKVKernel "_" + (kIsGroupMode ? "group" : "batch") + "_" "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" + _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kK0BlockLength) + "_" + - "r" + _TS_(gbr::at(ck_tile::number<0>{})) + "x" + _TS_(gbr::at(ck_tile::number<1>{})) + "x" + _TS_(gbr::at(ck_tile::number<2>{})) + "_" + + "r" + _TS_(g0br::at(ck_tile::number<0>{})) + "x" + _TS_(g0br::at(ck_tile::number<1>{})) + "x" + _TS_(g0br::at(ck_tile::number<2>{})) + "_" + + "r" + _TS_(g1br::at(ck_tile::number<0>{})) + "x" + _TS_(g1br::at(ck_tile::number<1>{})) + "x" + _TS_(g1br::at(ck_tile::number<2>{})) + "_" + "w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" + (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" + "v" + (std::is_same_v ? "r" : "c") + (pn.empty() ? "" : "_" + pn) + @@ -894,7 +896,7 @@ struct FmhaFwdSplitKVKernel o_acc_ptr, make_tuple(kargs.seqlen_q, kargs.hdim_v), make_tuple(kargs.stride_o_acc, 1), - number<1>{}, + number{}, number<1>{}); return pad_tensor_view( diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp index 2d06ba1762..675a31019e 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp @@ -26,8 +26,8 @@ struct FmhaFwdSplitKVTilePartitioner { // TODO: this may need tuning return dim3(ck_tile::integer_divide_ceil(max_seqlen_q, kM0) * - ck_tile::integer_divide_ceil(hdim_v, kN1), - nhead * num_splits, + ck_tile::integer_divide_ceil(hdim_v, kN1) * num_splits, + nhead, batch_size); } @@ -42,8 +42,9 @@ struct FmhaFwdSplitKVTilePartitioner return ck_tile::make_tuple(quotient, modulus); }; - const auto [i_tile_m, i_tile_n] = f(blockIdx.x, num_tile_n1); - const auto [i_nhead, i_split] = f(blockIdx.y, num_splits); + const auto [mn, i_split] = f(blockIdx.x, num_splits); + const auto [i_tile_m, i_tile_n] = f(mn, num_tile_n1); + const index_t i_nhead = blockIdx.y; const index_t i_batch = blockIdx.z; return ck_tile::make_tuple(i_tile_m, i_tile_n, i_split, i_nhead, i_batch); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp index 75af7be82f..6e7416ce8e 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp @@ -64,6 +64,9 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS return kPadSeqLenK ? 1 : Policy::template GetAlignmentV(); }(); + static constexpr index_t kAlignmentOacc = + kPadHeadDimV ? 1 : Policy::template GetAlignmentOacc(); + static constexpr index_t kAlignmentBias = kPadSeqLenK ? 1 : Policy::template GetAlignmentBias(); @@ -252,11 +255,11 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS k_dram_block_window_lengths, {adjusted_seqlen_k_start, 0}); const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); - auto bias_dram_window = make_tile_window( - bias_dram_block_window_tmp.get_bottom_tensor_view(), - bias_dram_block_window_tmp.get_window_lengths(), - {bias_origin.at(number<0>{}), adjusted_seqlen_k_start}, // M/N - Policy::template MakeBiasDramTileDistribution()); + auto bias_dram_window = + make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(), + bias_dram_block_window_tmp.get_window_lengths(), + {bias_origin.at(number<0>{}), adjusted_seqlen_k_start}, // M/N + Policy::template MakeBiasDramTileDistribution()); auto [i_page_block_v, v_dram_window] = v_page_block_navigator.make_tile_window( v_dram_block_window_lengths, diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp index 338319ab32..b7f1f042ed 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp @@ -9,11 +9,20 @@ namespace ck_tile { // This pipeline is qkv all located in LDS -using BlockFmhaFwdSplitKVPipelineQRKSVSDefaultPolicy = - BlockFmhaPipelineQXKSVSCustomPolicy; +struct BlockFmhaFwdSplitKVPipelineQRKSVSDefaultPolicy + : BlockFmhaPipelineQXKSVSCustomPolicy +{ + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentOacc() + { + using OaccDataType = remove_cvref_t; + + return static_cast(16 / sizeof(OaccDataType)); + } +}; } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp index 1846664e7d..d9da2f088c 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp @@ -39,8 +39,11 @@ struct BlockFmhaPipelineProblem using FmhaMask = remove_cvref_t; using Traits = remove_cvref_t; - static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size(); - static constexpr bool kIsGroupMode = kIsGroupMode_; + static constexpr index_t kNumGemm0Warps = BlockFmhaShape::NumGemm0Warps; + static constexpr index_t kNumGemm1Warps = BlockFmhaShape::NumGemm1Warps; + static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size(); + + static constexpr bool kIsGroupMode = kIsGroupMode_; // attributes from traits static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ; @@ -84,8 +87,11 @@ struct BlockFmhaFwdSplitKVPipelineProblem using FmhaMask = remove_cvref_t; using Traits = remove_cvref_t; - static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size(); - static constexpr bool kIsGroupMode = kIsGroupMode_; + static constexpr index_t kNumGemm0Warps = BlockFmhaShape::NumGemm0Warps; + static constexpr index_t kNumGemm1Warps = BlockFmhaShape::NumGemm1Warps; + static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size(); + + static constexpr bool kIsGroupMode = kIsGroupMode_; // attributes from traits static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp index 281ddc07b7..6837ffdee7 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp @@ -242,11 +242,11 @@ struct BlockFmhaPipelineQRKSVS {seqlen_k_start, 0}); const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); - auto bias_dram_window = make_tile_window( - bias_dram_block_window_tmp.get_bottom_tensor_view(), - bias_dram_block_window_tmp.get_window_lengths(), - {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N - Policy::template MakeBiasDramTileDistribution()); + auto bias_dram_window = + make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(), + bias_dram_block_window_tmp.get_window_lengths(), + {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N + Policy::template MakeBiasDramTileDistribution()); auto randval_dram_window = dropout.template MakeRandvalDramWindow( randval_dram_block_window_tmp, seqlen_k_start); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index 19f569c45b..c4872def1d 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -314,11 +314,11 @@ struct BlockFmhaPipelineQRKSVSAsync }(); const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); - auto bias_dram_window = make_tile_window( - bias_dram_block_window_tmp.get_bottom_tensor_view(), - bias_dram_block_window_tmp.get_window_lengths(), - {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N - Policy::template MakeBiasDramTileDistribution()); + auto bias_dram_window = + make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(), + bias_dram_block_window_tmp.get_window_lengths(), + {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N + Policy::template MakeBiasDramTileDistribution()); auto randval_dram_window = dropout.template MakeRandvalDramWindow( randval_dram_block_window_tmp, seqlen_k_start); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp index bc9ca93d09..d08a8d4899 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp @@ -9,9 +9,10 @@ namespace ck_tile { +/// NOTICE: we no-longer use this pipeline. // This pipeline is qkv all located in LDS template -struct BlockFmhaPipelineQSKSVS +struct [[deprecated]] BlockFmhaPipelineQSKSVS { using Problem = remove_cvref_t; using Policy = remove_cvref_t; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp index a66d2be78b..807ad65485 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp @@ -15,6 +15,7 @@ #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_one_warp_v1.hpp" // TODO: remove this #define K_LDS_LOAD_USE_OFFSET_TRANSFORM 0 @@ -64,13 +65,28 @@ struct BlockFmhaPipelineQXCustomPolicy constexpr index_t M1 = MWarp; constexpr index_t M0 = kMPerBlock / (M2 * M1); - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<2, 1>>, - tuple, sequence<1, 2>>, - sequence<1, 2, 2>, - sequence<0, 0, 2>>{}); + if constexpr(1 < Problem::kNumGemm0Warps) + { + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1>>, + tuple, sequence<1, 2>>, + sequence<1, 2, 2>, + sequence<0, 0, 2>>{}); + } + else + { + static_assert(MWarp == 1); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2, 2>, + sequence<0, 0, 2>>{}); + } } template @@ -80,7 +96,7 @@ struct BlockFmhaPipelineQXCustomPolicy BlockGemmProblem, @@ -129,12 +145,16 @@ struct BlockFmhaPipelineQXCustomPolicy typename Problem::BlockFmhaShape::Gemm0BlockWarps, decltype(warp_gemm)>; - return BlockGemmARegBSmemCRegV2{}; + if constexpr(1 < Problem::kNumGemm0Warps) + return BlockGemmARegBSmemCRegV2{}; + else + return BlockGemmARegBSmemCRegOneWarpV1{}; } }; +/// NOTICE: we no-longer use this policy. template <> -struct BlockFmhaPipelineQXCustomPolicy +struct [[deprecated]] BlockFmhaPipelineQXCustomPolicy { static constexpr bool QLoadOnce = false; @@ -364,12 +384,15 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy(16 / sizeof(VDataType))); + constexpr index_t kMinVecLoad = 4 / sizeof(VDataType); - // TODO: not correct! - if constexpr(total_pixels > 4) - return 4; - else - return 2; + constexpr index_t kVecLoad = ((total_pixels / kMaxVecLoad) >= kMinVecLoad) + ? kMaxVecLoad + : (total_pixels / kMinVecLoad); + + return kVecLoad; } else { @@ -383,10 +406,8 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy())>; constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); using WG = remove_cvref_t())>; - using CWarpDstr = typename WG::CWarpDstr; - constexpr auto vec = - CWarpDstr{}.get_ys_to_d_descriptor().get_lengths().at(number{}); - return vec; + + return WG::WarpGemmAttribute::Impl::kCM1PerLane; } template @@ -395,10 +416,8 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy())>; constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); using WG = remove_cvref_t())>; - using CWarpDstr = typename WG::CWarpDstr; - constexpr auto vec = - CWarpDstr{}.get_ys_to_d_descriptor().get_lengths().at(number{}); - return vec; + + return WG::WarpGemmAttribute::Impl::kCM1PerLane; } template @@ -449,44 +468,12 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy - CK_TILE_HOST_DEVICE static constexpr auto MakeQRegBlockDescriptor() - { - constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0BlockLength; - - constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); - - using WG = remove_cvref_t())>; - - constexpr index_t MWarp = config.template at<1>(); - constexpr index_t NWarp = config.template at<2>(); - - constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WG::kM); - constexpr index_t KIterPerWarp = kKPerBlock / WG::kK; - - constexpr auto q_block_outer_dstr_encoding = - tile_distribution_encoding, - tuple, sequence>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 0>>{}; - - constexpr auto q_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - q_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{}); - - constexpr auto q_block_dstr = make_static_tile_distribution(q_block_dstr_encode); - - return q_block_dstr; - } - // TODO: this is used for non async copy desc. unify in the future template CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptor() { constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; constexpr index_t kKPack = GetSmemKPackK(); constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor( @@ -886,36 +873,10 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy + template CK_TILE_HOST_DEVICE static constexpr auto MakeBiasDramTileDistribution() { - constexpr index_t MPerBlock = Problem::BlockFmhaShape::kM0; - constexpr index_t NPerBlock = Problem::BlockFmhaShape::kN0; - - constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); - using WG = remove_cvref_t())>; - - constexpr index_t MWarp = config.template at<1>(); - constexpr index_t NWarp = config.template at<2>(); - - constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); - constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); - - // Construct C-Block-HostTensor - constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< - sequence<>, - tuple, sequence>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 0>>{}; - - constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); - - constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); - - return c_block_dstr; + return BlockGemm::MakeCBlockTile().get_tile_distribution(); } template @@ -972,7 +933,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy, diff --git a/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp b/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp index 64a61e94d1..f2bb2200fe 100644 --- a/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp +++ b/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp @@ -21,10 +21,15 @@ struct TileFmhaShape using Gemm1BlockWarps = remove_cvref_t; using Gemm1WarpTile = remove_cvref_t; - static constexpr index_t NumWarps = + static constexpr index_t NumGemm0Warps = reduce_on_sequence(Gemm0BlockWarps{}, multiplies{}, number<1>{}); + static constexpr index_t NumGemm1Warps = + reduce_on_sequence(Gemm1BlockWarps{}, multiplies{}, number<1>{}); + static_assert(NumGemm1Warps % NumGemm0Warps == 0); - static_assert(NumWarps == reduce_on_sequence(Gemm1BlockWarps{}, multiplies{}, number<1>{})); + static constexpr index_t NumWarps = max(NumGemm0Warps, NumGemm1Warps); + + static_assert(std::is_same_v); static constexpr index_t kM0 = BlockTile::at(number<0>{}); // tile size along q seqlen static constexpr index_t kN0 = BlockTile::at(number<1>{}); // tile size along k seqlen diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index e708255703..4ca773479e 100644 --- a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -8,6 +8,7 @@ #include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_default_policy.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_one_warp_v1.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp" diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp index 9a5c2aae5c..728a04d830 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp @@ -157,7 +157,7 @@ struct BlockGemmARegBRegCRegV1 }); } - CK_TILE_DEVICE constexpr auto MakeCBlockTile() const + CK_TILE_DEVICE static constexpr auto MakeCBlockTile() { constexpr index_t MPerBlock = BlockGemmShape::kM; constexpr index_t NPerBlock = BlockGemmShape::kN; diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_one_warp_v1.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_one_warp_v1.hpp new file mode 100644 index 0000000000..ff23f63556 --- /dev/null +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_one_warp_v1.hpp @@ -0,0 +1,237 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp" + +namespace ck_tile { + +// A is block distributed tensor +// B is block window on shared memory +// C is block distributed tensor +template +struct BlockGemmARegBSmemCRegOneWarpV1 +{ + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + + static constexpr index_t kBlockSize = Problem::kBlockSize; + static_assert(kBlockSize == get_warp_size(), "Check failed!"); + + // C += A * B + template + CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, + const ABlockTensorTmp& a_block_tensor_tmp, + const BBlockWindowTmp& b_block_window_tmp) const + { + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + // constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}]; + // constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}]; + // constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}]; + constexpr index_t MPerBlock = BlockGemmShape::kM; + constexpr index_t NPerBlock = BlockGemmShape::kN; + constexpr index_t KPerBlock = BlockGemmShape::kK; + + // static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && + // KPerBlock == BlockGemmShape::kK, + // "wrong!"); + + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + static_assert(MWarp == 1 && NWarp == 1, "Check failed!"); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); + constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + + constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp; + constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp; + + const index_t iNWarp = 0; + + constexpr auto a_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple<>, + tuple<>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{}); + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); + + constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode); + + // constrcut from A-block-tensor from A-Block-tensor-tmp + // FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent + // distribution + auto a_block_tensor = + make_static_distributed_tensor(a_block_dstr); + + a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer(); + + // construct B-warp-window + auto b_warp_window_tmp = make_tile_window( + b_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + b_block_window_tmp.get_window_origin() + multi_index<2>{iNWarp * WG::kN, 0}, + make_static_tile_distribution(typename WG::BWarpDstrEncoding{})); + +#if 0 // FIXME: using array will cause register spill + array, NIterPerWarp> b_warp_windows{ + {b_warp_window_tmp}}; + + for(index_t nIter = 0; nIter < NIterPerWarp; nIter++) + { + for(index_t kIter = 0; kIter < KIterPerWarp; kIter++) + { + move_tile_window(b_warp_windows(nIter)(kIter), + {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); + } + } +#else + statically_indexed_array< + statically_indexed_array, + NIterPerWarp> + b_warp_windows; + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + b_warp_windows(nIter)(kIter) = b_warp_window_tmp; + + move_tile_window(b_warp_windows(nIter)(kIter), + {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); + }); + }); +#endif + + // check C-block-distribution + static_assert( + std::is_same_v, + remove_cvref_t>, + "wrong!"); + + using AWarpDstr = typename WG::AWarpDstr; + using CWarpDstr = typename WG::CWarpDstr; + + using AWarpTensor = typename WG::AWarpTensor; + using CWarpTensor = typename WG::CWarpTensor; + + constexpr auto a_warp_y_lengths = + to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto c_warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + + constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t{}; + constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + + // hot loop: + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + // read A warp tensor from A block tensor + AWarpTensor a_warp_tensor; + + a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read B warp tensor from B Block window + const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter)); + + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM + WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + }); + }); + }); + } + + CK_TILE_DEVICE static constexpr auto MakeCBlockTile() + { + constexpr index_t MPerBlock = BlockGemmShape::kM; + constexpr index_t NPerBlock = BlockGemmShape::kN; + + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + static_assert(MWarp == 1 && NWarp == 1, "Check failed!"); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); + // constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + + constexpr auto c_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple<>, + tuple<>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); + + static_assert(decltype(c_block_dstr_encode)::NDimP == 1, "Check failed!"); + + constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); + auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); + return c_block_tensor; + } + + // C = A * B + template + CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp, + const BBlockWindowTmp& b_block_window_tmp) const + { + auto c_block_tensor = MakeCBlockTile(); + operator()(c_block_tensor, a_block_tensor_tmp, b_block_window_tmp); + return c_block_tensor; + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp index beab457b90..98e5538c0a 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp @@ -181,7 +181,7 @@ struct BlockGemmARegBSmemCRegV1 }); } - CK_TILE_DEVICE constexpr auto MakeCBlockTile() const + CK_TILE_DEVICE static constexpr auto MakeCBlockTile() { constexpr index_t MPerBlock = BlockGemmShape::kM; constexpr index_t NPerBlock = BlockGemmShape::kN; diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp index 4a82702c1f..173ef0a02e 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp @@ -182,7 +182,7 @@ struct BlockGemmARegBSmemCRegV2 }); } - CK_TILE_DEVICE constexpr auto MakeCBlockTile() const + CK_TILE_DEVICE static constexpr auto MakeCBlockTile() { constexpr index_t MPerBlock = BlockGemmShape::kM; constexpr index_t NPerBlock = BlockGemmShape::kN; diff --git a/include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp b/include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp index 3d142df4d4..d28aa9e787 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp @@ -180,7 +180,7 @@ struct BlockGemmASmemBRegCRegV1 }); } - CK_TILE_DEVICE constexpr auto MakeCBlockTile() const + CK_TILE_DEVICE static constexpr auto MakeCBlockTile() { constexpr index_t MPerBlock = BlockGemmShape::kM; constexpr index_t NPerBlock = BlockGemmShape::kN; diff --git a/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp b/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp index ac45221709..dc0b411352 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp @@ -167,7 +167,7 @@ struct BlockGemmASmemBSmemCRegV1 }); } - CK_TILE_DEVICE constexpr auto MakeCBlockTile() const + CK_TILE_DEVICE static constexpr auto MakeCBlockTile() { constexpr index_t MPerBlock = BlockGemmShape::kM; constexpr index_t NPerBlock = BlockGemmShape::kN; diff --git a/include/ck_tile/ops/reduce/block/block_reduce.hpp b/include/ck_tile/ops/reduce/block/block_reduce.hpp index 63c364331d..a01265ad5d 100644 --- a/include/ck_tile/ops/reduce/block/block_reduce.hpp +++ b/include/ck_tile/ops/reduce/block/block_reduce.hpp @@ -22,7 +22,7 @@ CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor, constexpr index_t idim_p_lane = NDimP - 1; - const auto ps_idx = make_array(get_block_id(), get_lane_id()); + const auto ps_idx = detail::get_partition_index(acc_tensor.get_tile_distribution()); const auto rs_idx = acc_tensor.get_tile_distribution().calculate_rs_index_from_ps_index(ps_idx); constexpr index_t thread_buf_size = AccDistributedTensor_::get_thread_buffer_size();