diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
index 77b63a0c83..47cf6b3ad4 100644
--- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
+++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
@@ -1,5 +1,5 @@
 # SPDX-License-Identifier: MIT
-# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+# Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
 
 # generate kernel instances to speed up compilation
 import copy
@@ -8,21 +8,13 @@ import fnmatch
 import itertools
 from pathlib import Path
 from typing import List, Optional, Tuple, Dict, Literal
+from collections import defaultdict
 
 from codegen.cmake_config import *
 from codegen.cpp_symbol_map import *
+from codegen.utils import update_file
 
-BWD_DQDKDV_PIPELINE_MAP = {
-    "kr_ktr_vr_iglp" : "ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP",
-    "kr_ktr_vr"      : "ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR",
-}
-
-BWD_DQDKDV_PIPELINE_ENUM_MAP = {
-    "kr_ktr_vr_iglp" : "ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP",
-    "kr_ktr_vr"      : "ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR",
-}
-
 FMHA_BWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
 // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n
 // auto generated by generate.py
@@ -56,8 +48,8 @@ using fmha_bwd_shape_{F_idx} = ck_tile::TileFmhaBwdShape;
-using fmha_bwd_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
-                                                       {F_skpad},
+using fmha_bwd_trait_{F_idx} = ck_tile::TileFmhaTraits;
 
-using fmha_bwd_pipeline_{F_idx} = {F_pipeline};
+using fmha_bwd_pipeline_{F_idx} = ck_tile::BlockFmhaBwdDQDKDVPipeline;
 
 using fmha_bwd_dk_epilogue_{F_idx} = ck_tile::Default2DEpilogue<
     ck_tile::Default2DEpilogueProblem::AccDataType,
                                   typename FmhaBwdTypeConfig<{F_dtype}>::KGradDataType,
-                                  {F_skpad},
+                                  false,
                                   {F_dpad}>>;
 
 using fmha_bwd_dv_epilogue_{F_idx} = ck_tile::Default2DEpilogue<
     ck_tile::Default2DEpilogueProblem::AccDataType,
                                   typename FmhaBwdTypeConfig<{F_dtype}>::VGradDataType,
-                                  {F_skpad},
+                                  false,
                                   {F_dvpad}>>;
 
 using fmha_bwd_dq_dk_dv_kernel_{F_idx} =
@@ -115,13 +107,10 @@ using fmha_bwd_dq_dk_dv_kernel_{F_idx} =
 using dq_dk_dv_trait_{F_idx} = fmha_bwd_dq_dk_dv_traits_<{F_hdim},
                                                          {F_dtype},
                                                          {F_mode},
-                                                         {F_pipeline_enum},
                                                          fmha_mask_{F_idx},
                                                          fmha_dropout_{F_idx},
                                                          {F_bias},
                                                          {F_dbias},
-                                                         {F_spad},
-                                                         {F_skpad},
                                                          {F_dpad},
                                                          {F_dvpad},
                                                          {F_deterministic}>;
@@ -195,15 +184,18 @@ FMHA_BWD_API_PER_HDIM_CASE="""    {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <
 """
 
 FMHA_BWD_API_INNER_DISPATCH="""        {F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && ({F_dropout_check}) &&
-                ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.is_deterministic == {F_deterministic})) {{
-            using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1}, {F_dvpad}>;
-            using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_pipeline_enum}, {F_mask}, {F_dropout}, {F_bias}, {F_dbias}, {F_spad0}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_deterministic}>;
-            using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1}, {F_dpad}, {F_deterministic}>;
+                ({F_scheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.is_deterministic == {F_deterministic})) {{
+            using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, {F_dvpad}>;
+            using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_mask}, {F_dropout}, {F_bias}, {F_dbias}, {F_dpad}, {F_dvpad}, {F_deterministic}>;
+            using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, {F_dpad}, {F_deterministic}>;
             r = fmha_bwd_(s, a);
             return r;
         }}
 """
 
+# M0 size for 1d kernels (dot/convert)
+M0_1D = 64
+
 # GEMM0: Q@K=S^T
 # GEMM1: P^T@dO^T=dV(This was chosen as G1 to match fwd, but N1 must be equal to headdim_v)
 # GEMM2: dO@V=dP^T(This was chosen as G2 because of the calculation order)
@@ -249,8 +241,6 @@ class FmhaBwdDQDKDVKernel:
     F_hdim          : int  # hdim
     F_dtype         : str  # data type
    F_tile          : FmhaBwdDQDKDVTileSize
-    F_spad          : str  # true/false
-    F_skpad         : str  #
     F_dpad          : str  #
     F_dvpad         : str  #
     F_bias          : str  #
@@ -259,7 +249,6 @@ class FmhaBwdDQDKDVKernel:
     F_mask          : str  # value from MASK_MAP
     F_mode          : str  # value from MODE_MAP
     F_deterministic : str  #
-    F_pipeline      : str  #
     mask_impl       : str  #
 
     @property
@@ -293,8 +282,6 @@ class FmhaBwdDQDKDVKernel:
             F_wm1           = self.F_tile.F_wm1,
             F_wn1           = self.F_tile.F_wn1,
             F_wk1           = self.F_tile.F_wk1,
-            F_spad          = BOOL_MAP[self.F_spad],
-            F_skpad         = BOOL_MAP[self.F_skpad],
             F_dpad          = BOOL_MAP[self.F_dpad],
             F_dvpad         = BOOL_MAP[self.F_dvpad],
             F_bias          = BIAS_MAP[self.F_bias],
@@ -304,21 +291,18 @@ class FmhaBwdDQDKDVKernel:
             F_mask          = get_mask_map(self.mask_impl)[self.F_mask],
             F_mode          = MODE_MAP[self.F_mode],
             F_deterministic = BOOL_MAP[self.F_deterministic],
-            F_pipeline_enum = BWD_DQDKDV_PIPELINE_ENUM_MAP[self.F_pipeline],
-            F_pipeline      = BWD_DQDKDV_PIPELINE_MAP[self.F_pipeline])
+            )
 
     @property
     def name(self) -> str:
         def pad_name() -> str:
             n = ''
-            if self.F_spad == 't': n += 's'
-            if self.F_skpad == 't' : n += 'sk'
             if self.F_dpad == 't' : n += 'd'
             if self.F_dvpad == 't' : n += 'dv'
             if n != '' : n = 'p' + n
             return n
 
         pn = pad_name()
-        n = f"fmha_bwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + self.F_tile.name + f'_{self.F_pipeline}'
+        n = f"fmha_bwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + self.F_tile.name
         if pn != '' : n += f'_{pn}'
         else: n += '_npad'
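With the seqlen padding flags gone, only head-dim padding can appear in a generated kernel name. A minimal standalone sketch of the resulting naming rule (the free function and test values are illustrative, not part of the generator):

```python
# Sketch of the pad-suffix rule after this change: only the head-dim padding
# markers 'd'/'dv' survive; the 's'/'sk' seqlen markers no longer exist.
def pad_name(dpad: str, dvpad: str) -> str:
    n = ''
    if dpad == 't':
        n += 'd'
    if dvpad == 't':
        n += 'dv'
    return 'p' + n if n else ''

assert pad_name('t', 'f') == 'pd'
assert pad_name('t', 't') == 'pddv'
assert pad_name('f', 'f') == ''  # the caller appends '_npad' in this case
```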
@@ -347,20 +331,15 @@ return self.name + ".cpp"
 
 # TODO: design a more practical way to do it
-# this is current supported tile size & pipeline.
+# these are the currently supported tile sizes.
 def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype : str) -> Optional[dict]:
     if dtype == 'fp16' or dtype == 'bf16':
         return {
-            '32'  : [FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1),
-                     "kr_ktr_vr_iglp", "kr_ktr_vr"],
-            '64'  : [FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
-                     "kr_ktr_vr_iglp", "kr_ktr_vr"],
-            '128' : [FmhaBwdDQDKDVTileSize( 16, 128, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
-                     "kr_ktr_vr_iglp", "kr_ktr_vr"],
-            # '160' : [FmhaBwdDQDKDVTileSize( 32, 64, 160, 32, 160, 32, 32, 160, 160, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1),
-            #          "kr_ktr_vr_iglp", "kr_ktr_vr"],
-            '256' : [FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
-                     "kr_ktr_vr_iglp", "kr_ktr_vr"]
+            '32'  : FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1),
+            '64'  : FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
+            '128' : FmhaBwdDQDKDVTileSize( 16, 128, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
+            # '160' : FmhaBwdDQDKDVTileSize( 32, 64, 160, 32, 160, 32, 32, 160, 160, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1),
+            '256' : FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
         }
     else:
         return None
@@ -375,7 +354,7 @@ using fmha_bwd_dot_do_o_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdOGradDot
     typename FmhaBwdTypeConfig::ODataType,
     typename FmhaBwdTypeConfig::OGradDataType,
     typename FmhaBwdTypeConfig::DDataType,
-    /* BlockSize = */ 64,
+    /* BlockSize = M0 = */ 64,
     {F_hdim},
     {F_mode},
     fmha_bwd_dot_do_o_trait_{F_idx}>;
@@ -580,7 +559,6 @@ class FmhaBwdConvertQGradKernel:
 @dataclass(frozen=True)
 class FmhaBwdApiTrait:
     idx      : int  # this is not a tunable, but a counter to differentiate symbol
-    pipeline : str  # sync with fmha_bwd_traits<>, to generate fallback calls
     hdim     : int
     dtype    : str  # data type
@@ -590,9 +568,7 @@ class FmhaBwdApiTrait:
     bias     : str
     dbias    : str
     dropout  : str
-    spad     : str
-    spad1    : str  # spad for dot/convert kernel
-    skpad    : str
+    spad1d   : str  # spad for 1d kernels (dot/convert)
     dpad     : str
     dvpad    : str
     deterministic : str
@@ -611,24 +587,14 @@ class FmhaBwdApiTrait:
     def bhdv(self) -> int:
         return self.tile.F_bhdv
 
-    def scheck(self, spad1 : str) -> str:
-        if self.mode == 'group':
-            return 'true' # always support
-        elif self.spad == 't' and spad1 == 't':
-            return f'a.seqlen_q % {self.bm0} != 0'
-        elif self.spad == 'f' and spad1 == 't':
-            return f'a.seqlen_q % {self.bm0} == 0 and a.seqlen_q % 64 != 0'
-        else: # self.skpad == 'f' and skpad1 == 'f'
-            return 'a.seqlen_q % 64 == 0'
-
     @property
-    def skcheck(self) -> str:
+    def scheck(self) -> str:
         if self.mode == 'group':
             return 'true' # always support
-        elif self.skpad == 't':
-            return f'a.seqlen_k % {self.bn0} != 0'
-        else:
-            return f'a.seqlen_k % {self.bn0} == 0'
+        elif self.spad1d == 't':
+            return f'a.seqlen_q % {M0_1D} != 0'
+        else: # self.spad1d == 'f'
+            return f'a.seqlen_q % {M0_1D} == 0'
 
     @property
     def dcheck(self) -> str:
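The seqlen-q guard is now keyed off the fixed 1d tile size M0_1D = 64 rather than the removed per-tile spad/skpad flags. A standalone sketch of the predicate strings the new scheck property emits ('group' is the mode key confirmed by the source; the other mode key name is assumed):

```python
M0_1D = 64  # M0 size for the 1d dot/convert kernels, as defined above

def scheck(mode: str, spad1d: str) -> str:
    """Mirror of the new FmhaBwdApiTrait.scheck property (sketch only)."""
    if mode == 'group':
        return 'true'  # group mode always supports padding
    if spad1d == 't':
        return f'a.seqlen_q % {M0_1D} != 0'
    return f'a.seqlen_q % {M0_1D} == 0'

# Outside group mode the padded and unpadded variants are complementary, so
# exactly one generated dispatch branch can match any given seqlen_q:
for seqlen_q in (64, 100, 128, 129):
    assert (seqlen_q % M0_1D != 0) != (seqlen_q % M0_1D == 0)
```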
@@ -647,14 +613,14 @@ class FmhaBwdApiTrait:
         def get_occupancy(dtype, hdim):
             return 2
 
-        return FmhaBwdOGradDotOKernel(F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype, F_spad=self.spad1,
+        return FmhaBwdOGradDotOKernel(F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype, F_spad=self.spad1d,
                                       F_dvpad=self.dvpad, F_mode=self.mode, F_occupancy=get_occupancy(self.dtype, self.hdim))
 
     @property
     def dq_dk_dv_kernel(self) -> FmhaBwdDQDKDVKernel:
         return FmhaBwdDQDKDVKernel(F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype, F_tile=self.tile,
-                                   F_spad=self.spad, F_skpad=self.skpad, F_dpad=self.dpad, F_dvpad=self.dvpad, F_bias=self.bias,
-                                   F_dbias=self.dbias, F_dropout=self.dropout, F_mask=self.mask, F_mode=self.mode, F_deterministic=self.deterministic, F_pipeline=self.pipeline, mask_impl=self.mask_impl)
+                                   F_dpad=self.dpad, F_dvpad=self.dvpad, F_bias=self.bias, F_dbias=self.dbias, F_dropout=self.dropout,
+                                   F_mask=self.mask, F_mode=self.mode, F_deterministic=self.deterministic, mask_impl=self.mask_impl)
 
     @property
     def convert_dq_kernel(self) -> FmhaBwdConvertQGradKernel:
@@ -664,48 +630,46 @@ class FmhaBwdApiTrait:
             return 2
 
         return FmhaBwdConvertQGradKernel(F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype,
-                                         F_bm0=64, F_bn0=self.tile.F_bn0, F_spad=self.spad, F_dpad=self.dpad,
+                                         F_bm0=M0_1D, F_bn0=self.tile.F_bn0, F_spad=self.spad1d, F_dpad=self.dpad,
                                          F_mode=self.mode, F_occupancy=get_occupancy(self.dtype, self.hdim), F_deterministic=self.deterministic)
 
 class FmhaBwdApiPool:
     def __init__(self, mask_impl):
-        self.dq_dk_dv_pool = dict()
+        self.dq_dk_dv_pool = defaultdict(lambda: defaultdict(list))
         self.mask_impl = mask_impl
 
     def register_dq_dk_dv_traits(self, trait : FmhaBwdApiTrait) -> None:
         # TODO: do we need to check duplication?
-        if trait.dtype not in self.dq_dk_dv_pool.keys():
-            self.dq_dk_dv_pool[trait.dtype] = dict()
-        if trait.hdim not in self.dq_dk_dv_pool[trait.dtype].keys():
-            self.dq_dk_dv_pool[trait.dtype][trait.hdim] = list()
         self.dq_dk_dv_pool[trait.dtype][trait.hdim].append(copy.copy(trait))
 
+    @staticmethod
+    def if_(i: int) -> str:
+        return 'if' if i == 0 else 'else if'
+
+    def _api_inners(self, traits: List[FmhaBwdApiTrait]) -> str:
+        inners = ""
+        i = 0
+        for trait in traits:
+            inners += FMHA_BWD_API_INNER_DISPATCH.format(F_if=self.if_(i), F_mode=MODE_MAP[trait.mode],
+                F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias],
+                F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout_check=DROPOUT_CHECK_MAP[trait.dropout], F_dropout=DROPOUT_MAP[trait.dropout],
+                F_scheck=trait.scheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=trait.hdim, F_dtype=BWD_DTYPE_MAP[trait.dtype],
+                F_spad1d=BOOL_MAP[trait.spad1d], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
+                F_deterministic=BOOL_MAP[trait.deterministic])
+            i += 1
+        return inners
+
     @property
     def api(self) -> str:
         per_dtypes=str()
-        for i, dtype in enumerate(self.dq_dk_dv_pool.keys()):
+        for i, dtype in enumerate(self.dq_dk_dv_pool):
            per_hdim_case=str()
-            for j, hdim in enumerate(self.dq_dk_dv_pool[dtype].keys()):
+            for j, hdim in enumerate(self.dq_dk_dv_pool[dtype]):
                 traits=self.dq_dk_dv_pool[dtype][hdim]
-                inners=str()
-                for k, trait in enumerate(traits):
-                    if_k = 'if' if k == 0 else 'else if'
-                    for spad1 in ["t", "f"]:
-                        if (spad1 == "f" and (trait.spad == "t" or trait.mode == "group")):
-                            continue
-                        inners = inners + FMHA_BWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_pipeline_enum=BWD_DQDKDV_PIPELINE_ENUM_MAP[trait.pipeline],
-                            F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias],
-                            F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout_check=DROPOUT_CHECK_MAP[trait.dropout], F_dropout=DROPOUT_MAP[trait.dropout],
-                            F_scheck=trait.scheck(spad1=spad1), F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=hdim, F_dtype=BWD_DTYPE_MAP[dtype],
-                            F_spad0=BOOL_MAP[trait.spad], F_spad1=BOOL_MAP[spad1], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
-                            F_deterministic=BOOL_MAP[trait.deterministic])
-
-                if_j = 'if' if j == 0 else 'else if'
-                per_hdim_case = per_hdim_case + FMHA_BWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
-            if_i = 'if' if i == 0 else 'else if'
-            per_dtypes = per_dtypes + FMHA_BWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
+                inners = self._api_inners(traits)
+                per_hdim_case = per_hdim_case + FMHA_BWD_API_PER_HDIM_CASE.format(F_if=self.if_(j), F_hdim=hdim, F_inner_dispatch=inners)
+            per_dtypes += FMHA_BWD_API_PER_DTYPE.format(F_if=self.if_(i), F_dtype=dtype, F_hdim_case=per_hdim_case)
         if not per_dtypes: # empty string we add some ignore to suppress warning in api
             per_dtypes += ' (void)t ; (void)s ; (void)a;'
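The nested defaultdict removes the manual dtype/hdim key bookkeeping from register_dq_dk_dv_traits. A small self-contained demonstration of the idiom (keys and payloads are illustrative):

```python
from collections import defaultdict

# Two-level pool: missing dtype/hdim levels materialize on first access.
pool = defaultdict(lambda: defaultdict(list))
pool['fp16'][128].append('trait-a')  # no "if dtype not in pool" checks needed
pool['fp16'][128].append('trait-b')
assert pool['fp16'][128] == ['trait-a', 'trait-b']
assert pool['bf16'][64] == []  # note: lookups also auto-create empty entries
```

One caveat of the idiom: because plain lookups auto-create entries, iteration code such as the api property should only index keys that were actually registered.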
@@ -730,21 +694,16 @@ def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[Fm
         d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype)
         if d is None:
             continue
-        for hdim_str, mode, mask, bias, dbias, dropout, spad, spad1, skpad, dpad, dvpad, deterministic in itertools.product(d.keys(), MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], DROPOUT_MAP.keys(), *([["t", "f"]] * 6)):
-            tile = d[hdim_str][0]
-            ppl = d[hdim_str][1]
+        for hdim_str, mode, mask, bias, dbias, dropout, spad1d, dpad, dvpad, deterministic in itertools.product(d.keys(), MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], DROPOUT_MAP.keys(), *([["t", "f"]] * 4)):
+            tile = d[hdim_str]
             hdim = int(hdim_str)
-            if (mode == "group") and (spad == "f" or skpad == "f"):
-                continue
-            if (spad1 == "f") and (spad == "t" or mode == "group"):
+            if (mode == "group") and (spad1d == "f"):
                 continue
             if ((bias == "no" or bias == "alibi") and dbias == "t"):
                 continue
             if ("wg32" in dropout):
                 continue
-            if (dpad == "t" or dvpad == "t"):
-                ppl = d[hdim_str][2]
-            t = FmhaBwdApiTrait(idx=0, pipeline=ppl, hdim=hdim, dtype=dtype, mode=mode,tile=tile,mask=mask, bias=bias, dbias=dbias, dropout=dropout, spad=spad, spad1=spad1, skpad=skpad, dpad=dpad, dvpad=dvpad, deterministic=deterministic, mask_impl=mask_impl)
+            t = FmhaBwdApiTrait(idx=0, hdim=hdim, dtype=dtype, mode=mode, tile=tile, mask=mask, bias=bias, dbias=dbias, dropout=dropout, spad1d=spad1d, dpad=dpad, dvpad=dvpad, deterministic=deterministic, mask_impl=mask_impl)
 
             if not fnmatch.fnmatch(t.dot_do_o_kernel.name, filter_dot_do_o):
                 continue
@@ -808,13 +767,13 @@ def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[Fm
 
 def write_blobs(output_dir : Path, filter_list : str, receipt, optdim_list, mask_impl) -> None:
     api_pool, kernels_dot_do_o, kernels_dq_dk_dv, kernels_convert_dq = get_bwd_blobs(filter_list, receipt, mask_impl, optdim_list)
-    (output_dir / FMHA_BWD_API_FILENAME).write_text(api_pool.api)
+    update_file(output_dir / FMHA_BWD_API_FILENAME, api_pool.api)
     for k in kernels_dot_do_o:
-        (output_dir / k.filename).write_text(k.template)
+        update_file(output_dir / k.filename, k.template)
     for k in kernels_convert_dq:
-        (output_dir / k.filename).write_text(k.template)
+        update_file(output_dir / k.filename, k.template)
     for k in kernels_dq_dk_dv:
-        (output_dir / k.filename).write_text(k.template)
+        update_file(output_dir / k.filename, k.template)
 
 def list_blobs(file_path: Path, filter_list: str, receipt, optdim_list, mask_impl) -> None:
diff --git a/example/ck_tile/01_fmha/codegen/utils.py b/example/ck_tile/01_fmha/codegen/utils.py
new file mode 100644
index 0000000000..e3bbb18c42
--- /dev/null
+++ b/example/ck_tile/01_fmha/codegen/utils.py
@@ -0,0 +1,21 @@
+# SPDX-License-Identifier: MIT
+# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+# generate kernel instances to speed up compilation
+
+import os.path as path
+
+
+def update_file(file_path, content):
+    """Update the file at file_path with the given content if it differs from the existing content.
+
+    This avoids unnecessarily touching the file, which would trigger rebuilds.
+    """
+
+    existing_content = ""
+    if path.exists(file_path):
+        with open(file_path, "r") as file:
+            existing_content = file.read()
+    if existing_content == content:
+        return
+    with open(file_path, "w") as file:
+        file.write(content)
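A hypothetical demo of update_file's rebuild avoidance (assumes the module is importable as codegen.utils per this patch's layout; file names are illustrative). Identical content returns early without a write, so the file's mtime, and therefore the build system's dirty check, is left alone:

```python
import os
import tempfile

from codegen.utils import update_file  # location assumed from this patch

path = os.path.join(tempfile.mkdtemp(), "kernel.cpp")
update_file(path, "// generated\n")         # first call creates the file
before = os.stat(path).st_mtime_ns
update_file(path, "// generated\n")         # same content -> early return
assert os.stat(path).st_mtime_ns == before  # mtime untouched, no rebuild
update_file(path, "// regenerated\n")       # changed content -> rewritten
with open(path) as f:
    assert f.read() == "// regenerated\n"
```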
diff --git a/example/ck_tile/01_fmha/fmha_bwd.hpp b/example/ck_tile/01_fmha/fmha_bwd.hpp
index 9179dbd9be..c999cf750e 100644
--- a/example/ck_tile/01_fmha/fmha_bwd.hpp
+++ b/example/ck_tile/01_fmha/fmha_bwd.hpp
@@ -1,5 +1,5 @@
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
 
 #pragma once
 
@@ -357,31 +357,25 @@ auto fmha_bwd_convert_dq_create_kargs_and_grids(fmha_bwd_args args)
 template
 struct fmha_bwd_dq_dk_dv_traits_
 {
-    static constexpr ck_tile::index_t HDim = HDim_;
-    using DataType                         = ck_tile::remove_cvref_t;
-    static constexpr bool kIsGroupMode     = kIsGroupMode_;
-    static constexpr auto FmhaBwdPipelineEnum = FmhaBwdPipelineEnum_;
-    using FmhaMask                         = ck_tile::remove_cvref_t;
-    using FmhaDropout                      = ck_tile::remove_cvref_t;
-    static constexpr auto BiasEnum         = BiasEnum_;
-    static constexpr bool kHasBiasGrad     = kHasBiasGrad_;
-    static constexpr bool kPadS            = kPadS_;
-    static constexpr bool kPadSK           = kPadSK_;
-    static constexpr bool kPadD            = kPadD_;
-    static constexpr bool kPadDv           = kPadDv_;
-    static constexpr bool kIsDeterministic = kIsDeterministic_;
+    static constexpr ck_tile::index_t HDim = HDim_;
+    using DataType                         = ck_tile::remove_cvref_t;
+    static constexpr bool kIsGroupMode     = kIsGroupMode_;
+    using FmhaMask                         = ck_tile::remove_cvref_t;
+    using FmhaDropout                      = ck_tile::remove_cvref_t;
+    static constexpr auto BiasEnum         = BiasEnum_;
+    static constexpr bool kHasBiasGrad     = kHasBiasGrad_;
+    static constexpr bool kPadD            = kPadD_;
+    static constexpr bool kPadDv           = kPadDv_;
+    static constexpr bool kIsDeterministic = kIsDeterministic_;
 };
 
 template
diff --git a/include/ck_tile/core/tensor/null_tile_window.hpp b/include/ck_tile/core/tensor/null_tile_window.hpp
index de99be1965..f7eca73afb 100644
--- a/include/ck_tile/core/tensor/null_tile_window.hpp
+++ b/include/ck_tile/core/tensor/null_tile_window.hpp
@@ -1,5 +1,5 @@
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
 
 #pragma once
 
@@ -53,10 +53,13 @@ struct is_null_tile_window> : public std::true_type
 };
 } // namespace impl
 
+template
+constexpr bool is_null_tile_window_v = impl::is_null_tile_window>::value;
+
 template
 CK_TILE_DEVICE constexpr auto is_null_tile_window(const T&)
 {
-    return impl::is_null_tile_window>::value;
+    return is_null_tile_window_v>;
 }
 
 template
diff --git a/include/ck_tile/ops/fmha.hpp b/include/ck_tile/ops/fmha.hpp
index 30bea193b7..313de5f29a 100644
--- a/include/ck_tile/ops/fmha.hpp
+++ b/include/ck_tile/ops/fmha.hpp
@@ -24,8 +24,8 @@
 #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp"
+#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_selector.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
-#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp"
diff --git a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp
index ce3bf8fe8d..8b184b18f3 100644
--- a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp
+++ b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp
@@ -52,8 +52,6 @@ struct FmhaBwdDQDKDVKernel
     using BiasGradDataType = ck_tile::remove_cvref_t;
 
     static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode;
-    static constexpr bool kPadSeqLenQ  = FmhaPipeline::kPadSeqLenQ;
-    static constexpr bool kPadSeqLenK  = FmhaPipeline::kPadSeqLenK;
     static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
     static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
     static constexpr auto BiasEnum     = FmhaPipeline::BiasEnum;
@@ -85,8 +83,6 @@ struct FmhaBwdDQDKDVKernel
 #define _TS_ std::to_string
         auto pn = [&] () {
             std::string n;
-            if (kPadSeqLenQ) n += "s";
-            if (kPadSeqLenK) n += "sk";
             if (kPadHeadDimQ) n += "d";
             if (kPadHeadDimV) n += "dv";
             return n.empty() ? n : std::string("p") + n; }();
@@ -100,7 +96,7 @@ struct FmhaBwdDQDKDVKernel
             "r" + _TS_(gbr4::at(ck_tile::number<0>{})) + "x" + _TS_(gbr4::at(ck_tile::number<1>{})) + "x" + _TS_(gbr4::at(ck_tile::number<2>{})) + "_" +
             "w" + _TS_(gwt0::at(ck_tile::number<0>{})) + "x" + _TS_(gwt0::at(ck_tile::number<1>{})) + "x" + _TS_(gwt0::at(ck_tile::number<2>{})) + "_" +
             "w" + _TS_(gwt1::at(ck_tile::number<0>{})) + "x" + _TS_(gwt1::at(ck_tile::number<1>{})) + "x" + _TS_(gwt1::at(ck_tile::number<2>{})) + "_" +
-            ("o" + _TS_(kBlockPerCu) + "_") + _SS_(FmhaPipeline::name) + (pn.empty() ? "_npad" : "_" + pn) +
+            ("o" + _TS_(kBlockPerCu)) + (pn.empty() ? "_npad" : "_" + pn) +
             (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) +
             (kHasBiasGrad ? "_dbias" : "_ndbias") + (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kHasDropout ? "_dropout" : "_ndropout" ) +
             (kIsStoreRandval ? "_storerandval" : "" ) + (kIsDeterministic ? "_deterministic" : "_ndeterministic" );
"_deterministic" : "_ndeterministic" ); @@ -1221,7 +1217,7 @@ struct FmhaBwdDQDKDVKernel const auto q_dram = pad_tensor_view( q_dram_naive, make_tuple(number{}, number{}), - sequence{}); + sequence{}); const auto k_dram_naive = make_naive_tensor_view( k_ptr, @@ -1232,7 +1228,7 @@ struct FmhaBwdDQDKDVKernel const auto k_dram = pad_tensor_view( k_dram_naive, make_tuple(number{}, number{}), - sequence{}); + sequence{}); const auto v_dram = [&]() { const auto v_dram_naive = make_naive_tensor_view( @@ -1244,22 +1240,15 @@ struct FmhaBwdDQDKDVKernel return pad_tensor_view( v_dram_naive, make_tuple(number{}, number{}), - sequence{}); + sequence{}); }(); - const auto lse_dram = [&]() { - const auto lse_dram_naive = make_naive_tensor_view_packed( - lse_ptr, make_tuple(kargs.seqlen_q), number<1>{}); - return pad_tensor_view( - lse_dram_naive, make_tuple(number{}), sequence{}); - }(); + // lse and d should be fine to read unpaded data as they are not on the reduction dimension + const auto lse_dram = make_naive_tensor_view_packed( + lse_ptr, make_tuple(kargs.seqlen_q), number{}); - const auto d_dram = [&]() { - const auto d_dram_naive = make_naive_tensor_view_packed( - d_ptr, make_tuple(kargs.seqlen_q), number<1>{}); - return pad_tensor_view( - d_dram_naive, make_tuple(number{}), sequence{}); - }(); + const auto d_dram = make_naive_tensor_view_packed( + d_ptr, make_tuple(kargs.seqlen_q), number{}); const auto do_dram_naive = make_naive_tensor_view( do_ptr, @@ -1270,7 +1259,7 @@ struct FmhaBwdDQDKDVKernel const auto do_dram = pad_tensor_view( do_dram_naive, make_tuple(number{}, number{}), - sequence{}); + sequence{}); auto q_dram_window = make_tile_window( q_dram, @@ -1313,7 +1302,7 @@ struct FmhaBwdDQDKDVKernel return pad_tensor_view( dq_acc_dram_naive, make_tuple(number{}, number{}), - sequence{}); + sequence{}); }(); return make_tile_window( @@ -1341,7 +1330,7 @@ struct FmhaBwdDQDKDVKernel return pad_tensor_view( dq_acc_dram_naive, make_tuple(number{}, number{}), - sequence{}); + sequence{}); }(); return make_tile_window( @@ -1376,9 +1365,8 @@ struct FmhaBwdDQDKDVKernel number{}, number<1>{}); - return pad_tensor_view(bias_dram_naive, - bias_dram_window_lengths, - sequence{}); + return pad_tensor_view( + bias_dram_naive, bias_dram_window_lengths, sequence{}); }(); return make_tile_window(bias_dram, bias_dram_window_lengths, {0, i_n0}); @@ -1406,9 +1394,8 @@ struct FmhaBwdDQDKDVKernel number{}, number<1>{}); - return pad_tensor_view(dbias_dram_naive, - bias_dram_window_lengths, - sequence{}); + return pad_tensor_view( + dbias_dram_naive, bias_dram_window_lengths, sequence{}); }(); return make_tile_window(dbias_dram, bias_dram_window_lengths, {0, i_n0}); @@ -1495,9 +1482,8 @@ struct FmhaBwdDQDKDVKernel number<1>{}, number<1>{}); - return pad_tensor_view(randval_dram_naive, - randval_dram_window_lengths, - sequence{}); + return pad_tensor_view( + randval_dram_naive, randval_dram_window_lengths, sequence{}); }(); return make_tile_window(randval_dram, randval_dram_window_lengths, {0, i_n0}); @@ -1550,7 +1536,7 @@ struct FmhaBwdDQDKDVKernel return pad_tensor_view( dk_dram_naive, make_tuple(number{}, number{}), - sequence{}); + sequence{}); }(); auto dv_dram = [&]() { @@ -1564,7 +1550,7 @@ struct FmhaBwdDQDKDVKernel return pad_tensor_view( dv_dram_naive, make_tuple(number{}, number{}), - sequence{}); + sequence{}); }(); auto dk_dram_window = make_tile_window( diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp 
b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp index 8a13c0b060..1f11569533 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp @@ -49,8 +49,6 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim; static constexpr bool kIsGroupMode = Problem::kIsGroupMode; - static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; - static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; static constexpr auto BiasEnum = Problem::BiasEnum; @@ -72,8 +70,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad(); static constexpr index_t kAlignmentVGrad = kPadHeadDimV ? 1 : Policy::template GetAlignmentVGrad(); - static constexpr index_t kAlignmentBias = - kPadSeqLenK ? 1 : Policy::template GetTransposedAlignmentBias(); + static constexpr index_t kAlignmentBias = 1; static constexpr const char* name = "kr_ktr_vr"; @@ -554,7 +551,6 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR }); } - if constexpr(kPadSeqLenK || FmhaMask::IsMasking) { bool need_perpixel_check = mask.IsEdgeTile( seqlen_q_step, k_origin.at(number<0>{}), number{}, number{}); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp index c88b058d32..967fe2362d 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp @@ -49,8 +49,6 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim; static constexpr bool kIsGroupMode = Problem::kIsGroupMode; - static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; - static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; static constexpr auto BiasEnum = Problem::BiasEnum; @@ -72,8 +70,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad(); static constexpr index_t kAlignmentVGrad = kPadHeadDimV ? 1 : Policy::template GetAlignmentVGrad(); - static constexpr index_t kAlignmentBias = - kPadSeqLenK ? 
1 : Policy::template GetTransposedAlignmentBias(); + static constexpr index_t kAlignmentBias = 1; static constexpr const char* name = "kr_ktr_vr_iglp"; @@ -590,7 +587,6 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP }); } - if constexpr(kPadSeqLenK || FmhaMask::IsMasking) { bool need_perpixel_check = mask.IsEdgeTile( seqlen_q_step, k_origin.at(number<0>{}), number{}, number{}); @@ -849,7 +845,6 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP }); } - if constexpr(kPadSeqLenK || FmhaMask::IsMasking) { bool need_perpixel_check = mask.IsEdgeTile( seqlen_q_step, k_origin.at(number<0>{}), number{}, number{}); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_selector.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_selector.hpp new file mode 100644 index 0000000000..80c311de86 --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_selector.hpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp" + +namespace ck_tile { + +template +class BlockFmhaBwdDQDKDVPipelineSelector +{ + static constexpr bool has_dpad = Problem::Traits::kPadHeadDimQ || Problem::Traits::kPadHeadDimV; + + public: + using type = std::conditional_t, + BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP>; +}; + +template +class BlockFmhaBwdDQDKDVPipeline : public BlockFmhaBwdDQDKDVPipelineSelector::type +{ + public: + static constexpr const char* name = "auto"; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp deleted file mode 100644 index 27f58ef2f8..0000000000 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp +++ /dev/null @@ -1,15 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -namespace ck_tile { - -// This class is used for codegen pattern matching -enum class BlockFmhaBwdPipelineEnum -{ - KRKTRVR_IGLP = 0, - KRKTRVR, -}; - -} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp index c4c4a745a7..f6c79c7db6 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
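The new selector moves into the C++ type system the choice the generator previously made in Python (the removed "if dpad or dvpad: ppl = kr_ktr_vr" fallback earlier in this patch). A Python sketch of the same rule for reference; the direction — the plain KRKTRVR pipeline when a head dim is padded, the IGLP variant otherwise — mirrors that removed codegen logic, since the selector's template arguments are partly elided above:

```python
# Sketch of BlockFmhaBwdDQDKDVPipelineSelector's rule (names as in the header).
def select_pipeline(pad_head_dim_q: bool, pad_head_dim_v: bool) -> str:
    has_dpad = pad_head_dim_q or pad_head_dim_v
    return "kr_ktr_vr" if has_dpad else "kr_ktr_vr_iglp"

assert select_pipeline(False, False) == "kr_ktr_vr_iglp"  # fast IGLP path
assert select_pipeline(True, False) == "kr_ktr_vr"        # padded fallback
```

The exported BlockFmhaBwdDQDKDVPipeline reports name = "auto", which is why the pipeline tag also disappears from the generated kernel names above.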
diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp
deleted file mode 100644
index 27f58ef2f8..0000000000
--- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp
+++ /dev/null
@@ -1,15 +0,0 @@
-// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
-
-#pragma once
-
-namespace ck_tile {
-
-// This class is used for codegen pattern matching
-enum class BlockFmhaBwdPipelineEnum
-{
-    KRKTRVR_IGLP = 0,
-    KRKTRVR,
-};
-
-} // namespace ck_tile
diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp
index c4c4a745a7..f6c79c7db6 100644
--- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp
+++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp
@@ -1,5 +1,5 @@
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
 
 #pragma once
 
@@ -55,13 +55,13 @@ struct BlockFmhaBwdPipelineProblem
     static constexpr bool kIsDeterministic = kIsDeterministic_;
 
     // attributes from traits
-    static constexpr bool kPadSeqLenQ  = Traits::kPadSeqLenQ;
-    static constexpr bool kPadSeqLenK  = Traits::kPadSeqLenK;
     static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
     static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
     static constexpr auto BiasEnum     = Traits::BiasEnum;
     static constexpr bool kHasBiasGrad = Traits::kHasBiasGrad;
     static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
+    static_assert(!Traits::kPadSeqLenQ, "BlockFmhaBwdPipelineProblem does not need kPadSeqLenQ");
+    static_assert(!Traits::kPadSeqLenK, "BlockFmhaBwdPipelineProblem does not need kPadSeqLenK");
 };
 
 template