From 4e6c28522cf242479d7b07b20553344f055789f1 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Tue, 25 Jun 2024 10:12:13 +0000 Subject: [PATCH] Fix wrong K values after appending --- .../01_fmha/codegen/ops/fmha_fwd_appendkv.py | 123 ++++++++---------- example/ck_tile/01_fmha/fmha_fwd.cpp | 13 ++ example/ck_tile/01_fmha/fmha_fwd.hpp | 36 ++--- .../fmha/kernel/fmha_fwd_appendkv_kernel.hpp | 108 ++++++++------- .../fmha_fwd_appendkv_tile_partitioner.hpp | 41 +++--- .../block_fmha_fwd_appendkv_pipeline.hpp | 62 ++++++--- ...a_fwd_appendkv_pipeline_default_policy.hpp | 30 ++++- ...ock_fmha_fwd_appendkv_pipeline_problem.hpp | 26 +++- 8 files changed, 250 insertions(+), 189 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py index 0fb1807c5b..f0d0fd720a 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py @@ -13,7 +13,6 @@ from codegen.cmake_config import * from codegen.cpp_symbol_map import * from codegen.ops.fmha_fwd import ( - FmhaFwdTileSize, FmhaFwdApiTrait, DTYPE_BITS, FMHA_FWD_KERNEL_HEADER, @@ -25,17 +24,6 @@ from codegen.ops.fmha_fwd import ( FMHA_FWD_APPENDKV_KERNEL_BODY=""" using fmha_dtype_{F_idx} = {F_dtype}; -using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}>; -using fmha_block_warps_{F_idx} = ck_tile::sequence<{F_rm}, {F_rn}, {F_rk}>; -using fmha_warp_tile_{F_idx} = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>; - -using fmha_shape_{F_idx} = ck_tile::TileFmhaShape; - using fmha_trait_{F_idx} = ck_tile::TileFmhaFwdAppendKVTraits<{F_spad}, {F_skpad}, {F_dpad}, @@ -46,7 +34,11 @@ using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaFwdAppendKVPipelineProbl typename FmhaFwdTypeConfig::QDataType, typename FmhaFwdTypeConfig::KDataType, typename FmhaFwdTypeConfig::VDataType, - fmha_shape_{F_idx}, + {F_bs}, + {F_bsk}, + {F_bd}, + {F_bdv}, + {F_vlayout}, {F_mode}, fmha_trait_{F_idx}>; @@ -54,10 +46,10 @@ using fmha_pipeline_{F_idx} = ck_tile::BlockFmhaFwdAppendKVPipeline< fmha_pipeline_problem_{F_idx}>; using fmha_kernel_{F_idx} = - ck_tile::FmhaFwdAppendKVKernel, + ck_tile::FmhaFwdAppendKVKernel, fmha_pipeline_{F_idx}>; -using trait_{F_idx} = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, +using trait_{F_idx} = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; #include @@ -86,7 +78,7 @@ float fmha_fwd_appendkv(fmha_fwd_appendkv_traits t, fmha_fwd_appendkv_args a, co FMHA_FWD_APPENDKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ - using trait_ = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; + using trait_ = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; return fmha_fwd_appendkv_(s, a); }} """ @@ -97,12 +89,10 @@ class FmhaFwdAppendKVApiTrait: hdim : str dtype : str # data type mode : str # value from MODE_MAP - bm0 : int # tile size along q seqlen (block size) - bn0 : int # tile size along qk seqlen - bk0 : int # tile size along qk gemm unroll - bn1 : int # tile size along v head_dim - bk1 : int # tile size along kv gemm unroll - bk0blen : int + bs : int # tile size along q seqlen + bsk : int # tile size along k seqlen + bd : int # tile size along qk gemm unroll + bdv : int # tile size along kv gemm unroll vlayout : str spad : str skpad : str @@ -111,30 +101,30 @@ class FmhaFwdAppendKVApiTrait: @property def name(self) -> str: - return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0blen}-'+\ + return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bs}-{self.bsk}-{self.bd}-{self.bdv}-'+\ f'{self.vlayout}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}' @property def scheck(self) -> str: if self.mode == 'group': return 'true/*group mode spad always true*/' # group mode only generate spad/skpad == true - if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' - else : return f'a.seqlen_q % {self.bm0} == 0' + if self.spad == 't' : return f'true /*a.seqlen_q % {self.bs} != 0*/' + else : return f'a.seqlen_q % {self.bs} == 0' @property def skcheck(self) -> str: if self.mode == 'group': return 'true/*group mode skpad always true*/' # group mode only generate spad/skpad == true - if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' - else : return f'a.seqlen_k % {self.bn0} == 0' + if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bsk} != 0*/' + else : return f'a.seqlen_k % {self.bsk} == 0' @property def dcheck(self) -> str: - if self.dpad == 't': return f'true /*a.hdim_q % {self.bk0blen} != 0*/' # TODO: order of get_pipelines() matters! (ugly) - else : return f'a.hdim_q % {self.bk0blen} == 0' + if self.dpad == 't': return f'true /*a.hdim_q % {self.bd} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.hdim_q % {self.bd} == 0' @property def dvcheck(self) -> str: - if self.dvpad == 't': return f'true /*a.hdim_v % {self.bk0blen} != 0*/' # TODO: order of get_pipelines() matters! (ugly) - else : return f'a.hdim_v % {self.bk0blen} == 0' + if self.dvpad == 't': return f'true /*a.hdim_v % {self.bdv} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.hdim_v % {self.bdv} == 0' @dataclass class FmhaFwdAppendKVPipeline: @@ -186,21 +176,32 @@ class FmhaFwdAppendKVApiPool: inners = inners + FMHA_FWD_APPENDKV_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], - F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0blen=trait.bk0blen, - F_hdim=hdim, F_dtype=DTYPE_MAP[dtype]) + F_bs=trait.bs, F_bsk=trait.bsk, F_bd=trait.bd, F_bdv=trait.bdv, F_hdim=hdim, F_dtype=DTYPE_MAP[dtype]) if_j = 'if' if j == 0 else 'else if' per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) if_i = 'if' if i == 0 else 'else if' per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_APPENDKV_API.format(F_dispatch = per_dtypes) +@dataclass +class FmhaFwdAppendKVTileSize: + F_bs : int # tile size along q seqlen + F_bsk : int # tile size along k seqlen + F_bd : int # tile size along qk gemm unroll + F_bdv : int # tile size along kv gemm unroll + F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy + @property + def name(self) -> str: + return f"b{self.F_bs}x{self.F_bsk}x{self.F_bd}x{self.F_bdv}" +\ + ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}") + @dataclass class FmhaFwdAppendKVKernel: F_idx : int # this is not a tunable, but a counter to differentiate symbol F_hdim : int # hdim F_dtype : str # data type F_mode : str # value from MODE_MAP - F_tile : FmhaFwdTileSize + F_tile : FmhaFwdAppendKVTileSize F_pipeline : FmhaFwdAppendKVPipeline mask_impl : str @@ -212,18 +213,10 @@ class FmhaFwdAppendKVKernel: F_idx = self.F_idx, F_hdim = self.F_hdim, F_dtype = DTYPE_MAP[self.F_dtype], - F_bm0 = self.F_tile.F_bm0, - F_bn0 = self.F_tile.F_bn0, - F_bk0 = self.F_tile.F_bk0, - F_bn1 = self.F_tile.F_bn1, - F_bk1 = self.F_tile.F_bk1, - F_bk0blen = self.F_tile.F_bk0blen, - F_rm = self.F_tile.F_rm, - F_rn = self.F_tile.F_rn, - F_rk = self.F_tile.F_rk, - F_wm = self.F_tile.F_wm, - F_wn = self.F_tile.F_wn, - F_wk = self.F_tile.F_wk, + F_bs = self.F_tile.F_bs, + F_bsk = self.F_tile.F_bsk, + F_bd = self.F_tile.F_bd, + F_bdv = self.F_tile.F_bdv, F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout], F_spad = BOOL_MAP[self.F_pipeline.F_spad], F_skpad = BOOL_MAP[self.F_pipeline.F_skpad], @@ -247,12 +240,10 @@ class FmhaFwdAppendKVKernel: hdim=str(self.F_hdim), dtype=self.F_dtype, mode=self.F_mode, - bm0=self.F_tile.F_bm0, - bn0=self.F_tile.F_bn0, - bk0=self.F_tile.F_bk0, - bn1=self.F_tile.F_bn1, - bk1=self.F_tile.F_bk1, - bk0blen=self.F_tile.F_bk0blen, + bs=self.F_tile.F_bs, + bsk=self.F_tile.F_bsk, + bd=self.F_tile.F_bd, + bdv=self.F_tile.F_bdv, vlayout=self.F_pipeline.F_vlayout, spad=self.F_pipeline.F_spad, skpad=self.F_pipeline.F_skpad, @@ -261,24 +252,24 @@ class FmhaFwdAppendKVKernel: # TODO: design a more practical way to do it # this is current supported tile size per hdim -def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: +def get_fmha_fwd_appendkv_tile_dict_from_dtype(dtype : str) -> Optional[dict]: if dtype == 'fp16' or dtype == 'bf16': return { - '32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 32, 32, 16, -1), - '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 32, 32, 16, -1), - '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 32, 32, 16, -1), - '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 32, 32, 16, -1), + '32' : FmhaFwdAppendKVTileSize(64, 64, 32, 32, -1), + '64' : FmhaFwdAppendKVTileSize(64, 64, 64, 64, -1), + '128' : FmhaFwdAppendKVTileSize(64, 64, 128, 128, -1), + '256' : FmhaFwdAppendKVTileSize(64, 64, 256, 256, -1), } elif dtype == 'fp8' or dtype == 'bf8': return { - '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 32, 32, 32, -1), - '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 32, 32, 32, -1), - '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 32, 32, 32, -1) + '64' : FmhaFwdAppendKVTileSize(64, 64, 64, 64, -1), + '128' : FmhaFwdAppendKVTileSize(64, 64, 128, 128, -1), + '256' : FmhaFwdAppendKVTileSize(64, 64, 256, 256, -1) } else: return None -def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFwdAppendKVApiPool, List[FmhaFwdAppendKVKernel]]: +def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFwdAppendKVApiPool, List[FmhaFwdAppendKVKernel]]: # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad # support this in future def get_pipelines(dtype, hdim) -> List[FmhaFwdAppendKVPipeline]: @@ -289,8 +280,8 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm squant = 't' if dtype == 'fp8' else 'f' pipelines = [] if dtype in ['fp16', 'bf16']: - pipelines.append(FmhaFwdAppendKVPipeline('row', 'f', 'f', 'f', 'f')) - pipelines.append(FmhaFwdAppendKVPipeline('col', 'f', 'f', 'f', 'f')) + # pipelines.append(FmhaFwdAppendKVPipeline('row', 'f', 'f', 'f', 'f')) + # pipelines.append(FmhaFwdAppendKVPipeline('col', 'f', 'f', 'f', 'f')) pipelines.append(FmhaFwdAppendKVPipeline('row', 't', 't', 't', 't')) pipelines.append(FmhaFwdAppendKVPipeline('col', 't', 't', 't', 't')) @@ -306,7 +297,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm api_pool = FmhaFwdAppendKVApiPool(mask_impl) for dtype in DTYPE_MAP.keys(): - d = get_fmha_fwd_tile_dict_from_dtype(dtype) + d = get_fmha_fwd_appendkv_tile_dict_from_dtype(dtype) if d == None: continue #for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): @@ -347,14 +338,14 @@ def write_fwd_appendkv_api(api_pool : FmhaFwdAppendKVApiPool, autogen_dir: Path) (autogen_dir / FMHA_FWD_APPENDKV_API_FILENAME).write_text(api_pool.api) def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None: - api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl) + api_pool, kernels = get_fwd_appendkv_blobs(kernel_filter, receipt, mask_impl) for kernel in kernels: write_single_kernel(kernel, output_dir) write_fwd_appendkv_api(api_pool, output_dir) def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None: with file_path.open('a') as f: - _, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl) + _, kernels = get_fwd_appendkv_blobs(kernel_filter, receipt, mask_impl) for kernel in kernels: f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_APPENDKV_API_FILENAME) + "\n") \ No newline at end of file diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index 9a2d8e8158..075b1c93a8 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -472,7 +472,9 @@ bool run(const ck_tile::ArgParser& arg_parser) q_buf.ToDevice(q_host.data()); k_buf.ToDevice(k_host.data()); + knew_buf.ToDevice(knew_host.data()); v_buf.ToDevice(v_host.data()); + vnew_buf.ToDevice(vnew_host.data()); bias_buf.ToDevice(bias_host.data()); seqstart_q.ToDevice(seqstart_q_host.data()); seqstart_k.ToDevice(seqlen_kpads[0] < 0 ? seqstart_k_host.data() @@ -727,6 +729,17 @@ bool run(const ck_tile::ArgParser& arg_parser) << std::setprecision(2) << tflops << " TFlops, " << std::setprecision(2) << gb_per_sec << " GB/s" << std::flush; + k_buf.FromDevice(k_host.data()); + for(int row = 0; row < shape_seqlen_k; ++row) + { + printf("[POYENC][HOST] k_host[%3d] = ", row); + for(int col = 0; col < hdim_q; ++col) + { + printf("%11.7f", ck_tile::type_convert(k_host(0, 0, row, col))); + } + printf("\n"); + } + if(!do_validation) { std::cout << std::flush << std::endl; diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 79672d930f..438a076604 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -345,6 +345,10 @@ auto fmha_fwd_appendkv_create_kargs_and_grids(fmha_fwd_appendkv_args args) }(); dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v); + printf("[POYENC][HOST] grid size: %2d,%2d,%2d\n", + static_cast(grids.x), + static_cast(grids.y), + static_cast(grids.z)); return ck_tile::make_tuple(kargs, grids); } @@ -400,33 +404,29 @@ float fmha_fwd_(const ck_tile::stream_config&, fmha_fwd_args); template struct fmha_fwd_appendkv_traits_ { - static constexpr ck_tile::index_t HDim = HDim_; - using DataType = ck_tile::remove_cvref_t; - static constexpr bool kIsGroupMode = kIsGroupMode_; - static constexpr ck_tile::index_t kM0 = kM0_; - static constexpr ck_tile::index_t kN0 = kN0_; - static constexpr ck_tile::index_t kK0 = kK0_; - static constexpr ck_tile::index_t kN1 = kN1_; - static constexpr ck_tile::index_t kK1 = kK1_; - static constexpr ck_tile::index_t kK0BlockLength = kK0BlockLength_; - static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_; + static constexpr ck_tile::index_t HDim = HDim_; + using DataType = ck_tile::remove_cvref_t; + static constexpr bool kIsGroupMode = kIsGroupMode_; + static constexpr ck_tile::index_t kTileSizeS = kTileSizeS_; + static constexpr ck_tile::index_t kTileSizeSk = kTileSizeSk_; + static constexpr ck_tile::index_t kTileSizeD = kTileSizeD_; + static constexpr ck_tile::index_t kTileSizeDv = kTileSizeDv_; + static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_; // static constexpr bool kApplyRotray = kApplyRotray_; static constexpr bool kPadS = kPadS_; - static constexpr bool kPadSK = kPadSK_; + static constexpr bool kPadSk = kPadSk_; static constexpr bool kPadD = kPadD_; static constexpr bool kPadDv = kPadDv_; }; diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp index b9580dd626..6ab5c8c96a 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp @@ -43,11 +43,8 @@ struct FmhaFwdAppendKVKernel __host__ static std::string GetName() { - // sync with generate.py - // clang-format off - using bfs = typename FmhaPipeline::BlockFmhaShape; - using gbr = typename bfs::Gemm0BlockWarps; - using gwt = typename bfs::Gemm0WarpTile; +// sync with generate.py +// clang-format off #define _SS_ std::string #define _TS_ std::to_string auto pn = [&] () { @@ -58,13 +55,10 @@ struct FmhaFwdAppendKVKernel if (kPadHeadDimV) n += "dv"; return n.empty() ? n : std::string("p") + n; }(); return - _SS_("fmha_fwd_appendkv_d") + _TS_(bfs::kK0BlockLength) + "_" + _SS_(t2s::name) + + _SS_("fmha_fwd_appendkv_d") + _TS_(FmhaPipeline::kTileSizeD) + "_" + _SS_(t2s::name) + "_" + (kIsGroupMode ? "group" : "batch") + "_" - "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" + - _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kK0BlockLength) + "_" + - "r" + _TS_(gbr::at(ck_tile::number<0>{})) + "x" + _TS_(gbr::at(ck_tile::number<1>{})) + "x" + _TS_(gbr::at(ck_tile::number<2>{})) + "_" + - "w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" + - (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + + "b" + _TS_(FmhaPipeline::kTileSizeS) + "x" + _TS_(FmhaPipeline::kTileSizeSk) + "x" + _TS_(FmhaPipeline::kTileSizeD) + "x" + + _TS_(FmhaPipeline::kTileSizeDv) + "_" + (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + "v" + (std::is_same_v ? "r" : "c") + (pn.empty() ? "" : "_" + pn); #undef _SS_ #undef _TS_ @@ -271,11 +265,10 @@ struct FmhaFwdAppendKVKernel __shared__ char smem_ptr[GetSmemSize()]; // divide problem - const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = - TilePartitioner{}(kargs.seqlen_q, kargs.hdim_v); + const auto [i_tile_sk, i_nhead, i_batch] = TilePartitioner{}(kargs.seqlen_q, kargs.hdim_v); - const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0); - const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1); + const index_t i_sk = __builtin_amdgcn_readfirstlane(i_tile_sk * FmhaPipeline::kTileSizeSk); + // const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1); long_index_t batch_offset_q = 0; long_index_t batch_offset_k = 0; @@ -306,12 +299,14 @@ struct FmhaFwdAppendKVKernel const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0]; +#if 0 // # of required blocks is different in each groups, terminate unnecessary blocks // earlier if(kargs.seqlen_q <= i_m0) { return; } +#endif if(kargs.seqlen_k_ptr != nullptr) { @@ -334,16 +329,16 @@ struct FmhaFwdAppendKVKernel const QDataType* q_ptr = reinterpret_cast(kargs.q_ptr) + static_cast(i_nhead) * kargs.nhead_stride_q + batch_offset_q; - const KDataType* k_ptr = - reinterpret_cast(kargs.k_ptr) + + KDataType* k_ptr = + reinterpret_cast(kargs.k_ptr) + static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k + batch_offset_k; const KDataType* knew_ptr = reinterpret_cast(kargs.knew_ptr) + static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_knew + batch_offset_knew; - const VDataType* v_ptr = - reinterpret_cast(kargs.v_ptr) + + VDataType* v_ptr = + reinterpret_cast(kargs.v_ptr) + static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v + batch_offset_v; const VDataType* vnew_ptr = @@ -362,7 +357,7 @@ struct FmhaFwdAppendKVKernel return pad_tensor_view( q_dram_naive, - make_tuple(number{}, number{}), + make_tuple(number{}, number{}), sequence{}); }(); const auto k_dram = [&]() { @@ -375,7 +370,7 @@ struct FmhaFwdAppendKVKernel return pad_tensor_view( k_dram_naive, - make_tuple(number{}, number{}), + make_tuple(number{}, number{}), sequence{}); }(); const auto knew_dram = [&]() { @@ -388,7 +383,7 @@ struct FmhaFwdAppendKVKernel return pad_tensor_view( knew_dram_naive, - make_tuple(number{}, number{}), + make_tuple(number{}, number{}), sequence{}); }(); const auto v_dram = [&]() { @@ -408,10 +403,10 @@ struct FmhaFwdAppendKVKernel make_tuple(sequence<1>{}, sequence<0>{}), make_tuple(sequence<0>{}, sequence<1>{})); - return pad_tensor_view( - v_dram_transposed, - make_tuple(number{}, number{}), - sequence{}); + return pad_tensor_view(v_dram_transposed, + make_tuple(number{}, + number{}), + sequence{}); } else { @@ -422,10 +417,10 @@ struct FmhaFwdAppendKVKernel number{}, number<1>{}); - return pad_tensor_view( - v_dram_naive, - make_tuple(number{}, number{}), - sequence{}); + return pad_tensor_view(v_dram_naive, + make_tuple(number{}, + number{}), + sequence{}); } }(); const auto vnew_dram = [&]() { @@ -445,10 +440,10 @@ struct FmhaFwdAppendKVKernel make_tuple(sequence<1>{}, sequence<0>{}), make_tuple(sequence<0>{}, sequence<1>{})); - return pad_tensor_view( - vnew_dram_transposed, - make_tuple(number{}, number{}), - sequence{}); + return pad_tensor_view(vnew_dram_transposed, + make_tuple(number{}, + number{}), + sequence{}); } else { @@ -459,35 +454,36 @@ struct FmhaFwdAppendKVKernel number{}, number<1>{}); - return pad_tensor_view( - vnew_dram_naive, - make_tuple(number{}, number{}), - sequence{}); + return pad_tensor_view(vnew_dram_naive, + make_tuple(number{}, + number{}), + sequence{}); } }(); - - auto q_dram_window = - make_tile_window(q_dram, - make_tuple(number{}, number{}), - {i_m0, 0}); + auto q_dram_window = make_tile_window( + q_dram, + make_tuple(number{}, number{}), + {0, 0}); auto k_dram_window = make_tile_window( - k_dram, make_tuple(number{}, number{}), {0, 0}); + k_dram, + make_tuple(number{}, number{}), + {kargs.seqlen_k - kargs.seqlen_knew, 0}); - auto knew_dram_window = - make_tile_window(knew_dram, - make_tuple(number{}, number{}), - {0, 0}); + auto knew_dram_window = make_tile_window( + knew_dram, + make_tuple(number{}, number{}), + {i_sk, 0}); - auto v_dram_window = - make_tile_window(v_dram, - make_tuple(number{}, number{}), - {i_n1, 0}); + auto v_dram_window = make_tile_window( + v_dram, + make_tuple(number{}, number{}), + {kargs.seqlen_k - kargs.seqlen_knew, 0}); - auto vnew_dram_window = - make_tile_window(vnew_dram, - make_tuple(number{}, number{}), - {i_n1, 0}); + auto vnew_dram_window = make_tile_window( + vnew_dram, + make_tuple(number{}, number{}), + {i_sk, 0}); FmhaPipeline{}(q_dram_window, k_dram_window, diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_tile_partitioner.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_tile_partitioner.hpp index 641cc47f3b..b4732a04ea 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_tile_partitioner.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_tile_partitioner.hpp @@ -7,35 +7,30 @@ namespace ck_tile { -template +template struct FmhaFwdAppendKVTilePartitioner { - using BlockFmhaShape = ck_tile::remove_cvref_t; + static constexpr ck_tile::index_t kTileSizeS = kTileSizeS_; + static constexpr ck_tile::index_t kTileSizeSk = kTileSizeSk_; + static constexpr ck_tile::index_t kTileSizeD = kTileSizeD_; + static constexpr ck_tile::index_t kTileSizeDv = kTileSizeDv_; - static constexpr ck_tile::index_t kM0 = BlockFmhaShape::kM0; - static constexpr ck_tile::index_t kN0 = BlockFmhaShape::kN0; - static constexpr ck_tile::index_t kK0 = BlockFmhaShape::kK0; - static constexpr ck_tile::index_t kN1 = BlockFmhaShape::kN1; - static constexpr ck_tile::index_t kK1 = BlockFmhaShape::kK1; + static_assert(kTileSizeD == kTileSizeDv); - static constexpr const char* name = "shb"; - - CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_, - ck_tile::index_t nhead_, - ck_tile::index_t seqlen_q_, - ck_tile::index_t hdim_v_) + CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size, + ck_tile::index_t nhead, + ck_tile::index_t seqlen_knew, + ck_tile::index_t /*hdim_v*/) { + assert(ck_tile::integer_divide_ceil(hdim_v, kTileSizeD) == 1); + // TODO: this may need tuning - return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kM0) * - ck_tile::integer_divide_ceil(hdim_v_, kN1), - nhead_, - batch_size_); + return dim3(ck_tile::integer_divide_ceil(seqlen_knew, kTileSizeSk), nhead, batch_size); } - CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/, ck_tile::index_t hdim_v) + CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/, ck_tile::index_t /*hdim_v*/) { - // const index_t num_tile_m0 = seqlen_q / kM0; - const index_t num_tile_n1 = ck_tile::integer_divide_ceil(hdim_v, kN1); + // const index_t num_tile_n1 = ck_tile::integer_divide_ceil(hdim_v, kN1); const index_t i_block = blockIdx.x; const index_t i_nhead = blockIdx.y; @@ -46,10 +41,10 @@ struct FmhaFwdAppendKVTilePartitioner index_t modulus = dividend - quotient * divisor; return ck_tile::make_tuple(quotient, modulus); }; + (void)f; + // const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1); - const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1); - - return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch); + return ck_tile::make_tuple(i_block, i_nhead, i_batch); } }; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp index 683e3a8659..82e5951f6a 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp @@ -17,17 +17,14 @@ struct BlockFmhaFwdAppendKVPipeline using KDataType = typename Problem::KDataType; using VDataType = typename Problem::VDataType; - using BlockFmhaShape = typename Problem::BlockFmhaShape; - using VLayout = typename BlockFmhaShape::VLayout; + using VLayout = typename Problem::VLayout; static constexpr index_t kBlockSize = Problem::kBlockSize; - static constexpr index_t kM0 = BlockFmhaShape::kM0; - static constexpr index_t kN0 = BlockFmhaShape::kN0; - static constexpr index_t kK0 = BlockFmhaShape::kK0; - static constexpr index_t kN1 = BlockFmhaShape::kN1; - static constexpr index_t kK1 = BlockFmhaShape::kK1; - static constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength; + static constexpr index_t kTileSizeS = Problem::kTileSizeS; + static constexpr index_t kTileSizeSk = Problem::kTileSizeSk; + static constexpr index_t kTileSizeD = Problem::kTileSizeD; + static constexpr index_t kTileSizeDv = Problem::kTileSizeDv; static constexpr bool kIsGroupMode = Problem::kIsGroupMode; static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; @@ -53,19 +50,19 @@ struct BlockFmhaFwdAppendKVPipeline return Problem::kBlockPerCu; else { - if constexpr(kK0BlockLength <= 32) + if constexpr(kTileSizeD <= 32) { return 2; } - else if constexpr(kK0BlockLength <= 64) + else if constexpr(kTileSizeD <= 64) { return 3; } - else if constexpr(kK0BlockLength <= 128) + else if constexpr(kTileSizeD <= 128) { return 2; } - else if constexpr(kK0BlockLength <= 256) + else if constexpr(kTileSizeD <= 256) { return 1; } @@ -90,11 +87,11 @@ struct BlockFmhaFwdAppendKVPipeline CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const QElementFunction& q_element_func, - const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile const KElementFunction& k_element_func, const KnewDramBlockWindowTmp& knew_dram_block_window_tmp, // N0*K0 tile const KnewElementFunction& knew_element_func, - const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile const VElementFunction& v_element_func, const VnewDramBlockWindowTmp& vnew_dram_block_window_tmp, // N1*K1 tile const VnewElementFunction& vnew_element_func, @@ -111,6 +108,39 @@ struct BlockFmhaFwdAppendKVPipeline (void)vnew_dram_block_window_tmp; (void)vnew_element_func; (void)smem_ptr; + + auto knew_dram_block_window = + make_tile_window(knew_dram_block_window_tmp.get_bottom_tensor_view(), + knew_dram_block_window_tmp.get_window_lengths(), + {0, 0}); + + auto knew_dram_window = + make_tile_window(knew_dram_block_window.get_bottom_tensor_view(), + knew_dram_block_window.get_window_lengths(), + knew_dram_block_window.get_window_origin(), + Policy::template MakeKnewDramTileDistribution()); + + auto knew_tile = load_tile(knew_dram_window); + if(blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && threadIdx.x == 0) + { + constexpr auto spans = decltype(knew_tile)::get_distributed_spans(); + sweep_tile_span(spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(spans[number<1>{}], [&](auto idx1) { + const auto tile_idx = get_x_indices_from_distributed_indices( + knew_tile.get_tile_distribution(), make_tuple(idx0, idx1)); + + const auto row = tile_idx.at(number<0>{}); + const auto col = tile_idx.at(number<1>{}); + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + printf("[POYENC][DEVICE] knew_tile(%2d,%2d): %11.7f\n", + row, + col, + type_convert(knew_tile(i_j_idx))); + }); + }); + } + store_tile(k_dram_block_window_tmp, knew_tile); } template CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, - const KDramBlockWindowTmp& k_dram_block_window_tmp, + KDramBlockWindowTmp& k_dram_block_window_tmp, const KnewDramBlockWindowTmp& knew_dram_block_window_tmp, - const VDramBlockWindowTmp& v_dram_block_window_tmp, + VDramBlockWindowTmp& v_dram_block_window_tmp, const VnewDramBlockWindowTmp& vnew_dram_block_window_tmp, void* smem_ptr) const { diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp index b5580a1695..542ba03e18 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp @@ -28,13 +28,13 @@ struct BlockFmhaFwdAppendKVPipelineDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentV() { - using VLayout = remove_cvref_t; + using VLayout = remove_cvref_t; using VDataType = remove_cvref_t; if constexpr(std::is_same_v) { constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + constexpr index_t kNPerBlock = Problem::kTileSizeSk; + constexpr index_t kKPerBlock = Problem::kTileSizeDv; constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; // TODO: not correct! @@ -54,6 +54,30 @@ struct BlockFmhaFwdAppendKVPipelineDefaultPolicy { return 1; } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeKnewDramTileDistribution() + { + using KDataType = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::kTileSizeSk; + constexpr index_t kKPerBlock = Problem::kTileSizeD; + + constexpr index_t K1 = 16 / sizeof(KDataType); + constexpr index_t K0 = kKPerBlock / K1; + constexpr index_t N2 = get_warp_size() / K0; + constexpr index_t N1 = kBlockSize / get_warp_size(); + constexpr index_t N0 = kNPerBlock / (N2 * N1); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_problem.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_problem.hpp index 0c107b039b..9e831e5f1e 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_problem.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_problem.hpp @@ -10,20 +10,32 @@ namespace ck_tile { template struct BlockFmhaFwdAppendKVPipelineProblem { - using QDataType = remove_cvref_t; - using KDataType = remove_cvref_t; - using VDataType = remove_cvref_t; - using BlockFmhaShape = remove_cvref_t; - using Traits = remove_cvref_t; + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using Traits = remove_cvref_t; - static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size(); + static constexpr index_t kBlockSize = 256; static constexpr bool kIsGroupMode = kIsGroupMode_; + static constexpr index_t kTileSizeS = kTileSizeS_; + static constexpr index_t kTileSizeSk = kTileSizeSk_; + static constexpr index_t kTileSizeD = kTileSizeD_; + static constexpr index_t kTileSizeDv = kTileSizeDv_; + + using VLayout = std::conditional_t; + // attributes from traits static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ; static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;