From a3296e00b8e4d5542e550b99a09a57b815b6ed74 Mon Sep 17 00:00:00 2001 From: Yi DING Date: Fri, 26 Sep 2025 16:42:59 +0800 Subject: [PATCH] [CK_TILE] FMHA BWD Pad HDim to a Multiple of 8 (#2918) [ROCm/composable_kernel commit: 32773fe5cb176efd2fcbb361f183164fc6525d8a] --- .../ck_tile/01_fmha/codegen/ops/fmha_bwd.py | 61 ++++++++++--------- example/ck_tile/01_fmha/fmha_bwd.hpp | 4 +- .../ck_tile/01_fmha/script/smoke_test_bwd.sh | 2 +- .../ops/fmha/kernel/fmha_bwd_kernel.hpp | 30 ++++----- ...k_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp | 16 ++--- ...a_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp | 16 ++--- ...ck_fmha_bwd_dq_dk_dv_pipeline_selector.hpp | 5 +- ...bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp | 16 ++--- ...wd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp | 16 ++--- .../block_fmha_bwd_pipeline_problem.hpp | 12 ++-- .../ops/fmha/pipeline/tile_fmha_traits.hpp | 17 ++++++ test/ck_tile/fmha/test_fmha_bwd.inc | 3 + 12 files changed, 110 insertions(+), 88 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index 36482e94c1..bd6a9044e9 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -50,16 +50,10 @@ using fmha_bwd_shape_{F_idx} = ck_tile::TileFmhaBwdShape; -using fmha_bwd_trait_{F_idx} = ck_tile::TileFmhaTraits; using fmha_mask_{F_idx} = {F_mask}; using fmha_dropout_{F_idx} = {F_dropout}; @@ -94,19 +88,19 @@ using fmha_bwd_dk_epilogue_{F_idx} = ck_tile::Default2DEpilogue< ck_tile::Default2DEpilogueProblem::AccDataType, typename FmhaBwdTypeConfig<{F_dtype}>::KGradDataType, false, - {F_dpad}>>; + ({F_dpad} > 0)>>; using fmha_bwd_dv_epilogue_{F_idx} = ck_tile::Default2DEpilogue< ck_tile::Default2DEpilogueProblem::AccDataType, typename FmhaBwdTypeConfig<{F_dtype}>::VGradDataType, false, - {F_dvpad}>>; + ({F_dvpad} > 0)>>; using fmha_bwd_dq_epilogue_{F_idx} = ck_tile::Default2DEpilogue< ck_tile::Default2DEpilogueProblem::AccDataType, typename FmhaBwdTypeConfig<{F_dtype}>::QGradDataType, false, - {F_dpad}>>; + ({F_dpad} > 0)>>; using fmha_bwd_dq_dk_dv_kernel_{F_idx} = ck_tile::FmhaBwdDQDKDVKernel; + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, ({F_dvpad} > 0)>; using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_mask}, {F_dropout}, {F_bias}, {F_dbias}, {F_dpad}, {F_dvpad}, {F_deterministic}, {F_trload}, {F_maxq}, {F_bn0}>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, {F_dpad}, {F_deterministic}, {F_convert_dq_bn0}>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, ({F_dpad} > 0), {F_deterministic}, {F_convert_dq_bn0}>; r = fmha_bwd_>(s, a); return r; }} @@ -278,8 +272,8 @@ class FmhaBwdDQDKDVKernel: F_hdim : int # hdim F_dtype : str # data type F_tile : FmhaBwdDQDKDVTileSize - F_dpad : str # - F_dvpad : str # + F_dpad : Literal[0, 8 ,1] + F_dvpad : Literal[0, 8 ,1] F_bias : str # F_dbias : str # F_dropout : str # @@ -320,8 +314,8 @@ class FmhaBwdDQDKDVKernel: F_wm1 = self.F_tile.F_wm1, F_wn1 = self.F_tile.F_wn1, F_wk1 = self.F_tile.F_wk1, - F_dpad = BOOL_MAP[self.F_dpad], - F_dvpad = BOOL_MAP[self.F_dvpad], + F_dpad = self.F_dpad, + F_dvpad = self.F_dvpad, F_bias = BIAS_MAP[self.F_bias], F_dbias = BOOL_MAP[self.F_dbias], F_dropout = DROPOUT_MAP[self.F_dropout], @@ -337,8 +331,8 @@ class FmhaBwdDQDKDVKernel: def name(self) -> str: def pad_name() -> str: n = '' - if self.F_dpad == 't' : n += 'd' - if self.F_dvpad == 't' : n += 'dv' + if self.F_dpad : n += f'd{self.F_dpad}' + if self.F_dvpad : n += f'dv{self.F_dvpad}' if n != '' : n = 'p' + n return n pn = pad_name() @@ -622,8 +616,8 @@ class FmhaBwdApiTrait: dbias : str dropout : str spad1d : str # spad for 1d kernels (dot/convert) - dpad : str - dvpad : str + dpad : Literal[0, 1, 8] + dvpad : Literal[0, 1, 8] deterministic : str mask_impl : str tr_load : str @@ -652,13 +646,13 @@ class FmhaBwdApiTrait: @property def dcheck(self) -> str: - if self.dpad == 't': return f'a.hdim_q % {self.bhdq} != 0' - else : return f'a.hdim_q % {self.bhdq} == 0' + if self.dpad == 0: return f'a.hdim_q % {self.bhdq} == 0' + else: return f'a.hdim_q % {self.dpad} == 0' @property def dvcheck(self) -> str: - if self.dvpad == 't': return f'a.hdim_v % {self.bhdv} != 0' - else : return f'a.hdim_v % {self.bhdv} == 0' + if self.dvpad == 0: return f'a.hdim_v % {self.bhdv} == 0' + else: return f'a.hdim_v % {self.dvpad} == 0' @property def extra_cond(self) -> str: @@ -678,8 +672,9 @@ class FmhaBwdApiTrait: def get_occupancy(dtype, hdim): return 2 + F_dvpad = 't' if self.dvpad else 'f' return FmhaBwdOGradDotOKernel(F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype, F_spad=self.spad1d, - F_dvpad=self.dvpad, F_mode=self.mode, F_occupancy=get_occupancy(self.dtype, self.hdim)) + F_dvpad=F_dvpad, F_mode=self.mode, F_occupancy=get_occupancy(self.dtype, self.hdim)) @property def dq_dk_dv_kernel(self) -> FmhaBwdDQDKDVKernel: @@ -694,8 +689,9 @@ class FmhaBwdApiTrait: def get_occupancy(dtype, hdim): return 2 + F_dpad = 't' if self.dpad else 'f' return FmhaBwdConvertQGradKernel(F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype, - F_bm0=M0_1D, F_bn0=self.convert_dq_bn0, F_spad=self.spad1d, F_dpad=self.dpad, + F_bm0=M0_1D, F_bn0=self.convert_dq_bn0, F_spad=self.spad1d, F_dpad=F_dpad, F_mode=self.mode, F_occupancy=get_occupancy(self.dtype, self.hdim), F_deterministic=self.deterministic, disabled=self.tile.max_seq_q != 0) @@ -721,7 +717,7 @@ class FmhaBwdApiPool: F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout_check=DROPOUT_CHECK_MAP[trait.dropout], F_dropout=DROPOUT_MAP[trait.dropout], F_scheck=trait.scheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=trait.hdim, F_dtype=BWD_DTYPE_MAP[trait.dtype], - F_spad1d=BOOL_MAP[trait.spad1d], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], + F_spad1d=BOOL_MAP[trait.spad1d], F_dpad=trait.dpad, F_dvpad=trait.dvpad, F_deterministic=BOOL_MAP[trait.deterministic], F_trload=BOOL_MAP[trait.tr_load], F_maxq=trait.tile.max_seq_q, F_convert_dq_enabled=BOOL_MAP[not trait.convert_dq_kernel.disabled], F_bn0=trait.tile.F_bn0, F_cond_extra=trait.extra_cond, F_convert_dq_bn0=trait.convert_dq_bn0) @@ -794,7 +790,10 @@ def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[Fm for dtype, tr_load in itertools.product(BWD_DTYPE_MAP.keys(), ["t", "f"]): tiles: Any = get_dq_dk_dv_tiles(dtype, tr_load) - for tile, mode, mask, bias, dbias, dropout, spad1d, dpad, dvpad, deterministic in itertools.product(tiles, MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], DROPOUT_MAP.keys(), *([["t", "f"]] * 4)): + dpad_options = itertools.product(*([[0, 8, 1]] * 2)) + tf = ["t", "f"] + for tile, mode, mask, bias, dbias, dropout, spad1d, (dpad, dvpad), deterministic in itertools.product( + tiles, MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), tf, DROPOUT_MAP.keys(), tf, dpad_options, tf): assert isinstance(tile, FmhaBwdDQDKDVTileSize), "tile must be FmhaBwdDQDKDVTileSize" hdim = tile.F_bhdq if (mode == "group") and (spad1d == "f"): @@ -805,8 +804,12 @@ def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[Fm continue if ("wg32" in dropout): continue - if tr_load == "t" and (dpad == "t" or dvpad == "t"): + if tr_load == "t": continue # tr_load cannot work with dpad or dvpad + else: # tr_load == "f" + # do not generate instance with only 1 of dpad/dvpad being 8 + if dpad != dvpad and dpad == 8: + continue if optdim_list != [-1]: if hdim not in optdim_list: continue diff --git a/example/ck_tile/01_fmha/fmha_bwd.hpp b/example/ck_tile/01_fmha/fmha_bwd.hpp index 378ff9c9f8..6cd1cd94fa 100644 --- a/example/ck_tile/01_fmha/fmha_bwd.hpp +++ b/example/ck_tile/01_fmha/fmha_bwd.hpp @@ -392,8 +392,8 @@ template ; using BiasGradDataType = ck_tile::remove_cvref_t; - static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode; - static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ; - static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV; - static constexpr auto BiasEnum = FmhaPipeline::BiasEnum; - static constexpr bool kHasBiasGrad = FmhaPipeline::kHasBiasGrad; - using FmhaMask = ck_tile::remove_cvref_t; + static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode; + static constexpr index_t kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ; + static constexpr index_t kPadHeadDimV = FmhaPipeline::kPadHeadDimV; + static constexpr auto BiasEnum = FmhaPipeline::BiasEnum; + static constexpr bool kHasBiasGrad = FmhaPipeline::kHasBiasGrad; + using FmhaMask = ck_tile::remove_cvref_t; using FmhaDropout = ck_tile::remove_cvref_t; static constexpr bool kHasMask = FmhaMask::IsMasking; static constexpr bool kHasDropout = FmhaDropout::IsDropout; @@ -100,8 +100,8 @@ struct FmhaBwdDQDKDVKernel #define _TS_ std::to_string auto pn = [&] () { std::string n; - if (kPadHeadDimQ) n += "d"; - if (kPadHeadDimV) n += "dv"; + if (kPadHeadDimQ) n += "d" + _TS_(kPadHeadDimQ); + if (kPadHeadDimV) n += "dv"+ _TS_(kPadHeadDimV); return n.empty() ? n : std::string("p") + n; }(); return _SS_("fmha_bwd_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s::name) + @@ -815,7 +815,7 @@ struct FmhaBwdDQDKDVKernel const auto q_dram = pad_tensor_view( q_dram_naive, make_tuple(number{}, number{}), - sequence{}); + sequence 0)>{}); const auto k_dram_naive = make_naive_tensor_view( k_ptr, @@ -826,7 +826,7 @@ struct FmhaBwdDQDKDVKernel const auto k_dram = pad_tensor_view( k_dram_naive, make_tuple(number{}, number{}), - sequence{}); + sequence 0)>{}); const auto v_dram = [&]() { const auto v_dram_naive = make_naive_tensor_view( @@ -838,7 +838,7 @@ struct FmhaBwdDQDKDVKernel return pad_tensor_view( v_dram_naive, make_tuple(number{}, number{}), - sequence{}); + sequence 0)>{}); }(); // lse and d should be fine to read unpaded data as they are not on the reduction dimension @@ -857,7 +857,7 @@ struct FmhaBwdDQDKDVKernel const auto do_dram = pad_tensor_view( do_dram_naive, make_tuple(number{}, number{}), - sequence{}); + sequence 0)>{}); auto q_dram_window = make_tile_window( q_dram, @@ -905,7 +905,7 @@ struct FmhaBwdDQDKDVKernel const auto dq_acc_dram = pad_tensor_view( dq_acc_dram_naive, make_tuple(number{}, number{}), - sequence{}); + sequence 0)>{}); return make_tile_window( dq_acc_dram, make_tuple(number{}, number{}), @@ -1089,7 +1089,7 @@ struct FmhaBwdDQDKDVKernel return pad_tensor_view( dk_dram_naive, make_tuple(number{}, number{}), - sequence{}); + sequence 0)>{}); }(); auto dv_dram = [&]() { @@ -1103,7 +1103,7 @@ struct FmhaBwdDQDKDVKernel return pad_tensor_view( dv_dram_naive, make_tuple(number{}, number{}), - sequence{}); + sequence 0)>{}); }(); auto dk_dram_window = make_tile_window( diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp index 5e63fb714a..ea024a0257 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp @@ -49,8 +49,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim; static constexpr bool kIsGroupMode = Problem::kIsGroupMode; - static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; - static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; + static constexpr index_t kPadHeadDimQ = Problem::kPadHeadDimQ; + static constexpr index_t kPadHeadDimV = Problem::kPadHeadDimV; static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad; static constexpr bool kIsDeterministic = Problem::kIsDeterministic; @@ -60,18 +60,18 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR // last dimension vector length used to create tensor view(and decide buffer_load vector length) // ... together with tensor distribution. tensor dist should able to overwrite this static constexpr index_t kAlignmentQ = - kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ(); + kPadHeadDimQ ? kPadHeadDimQ : Policy::template GetAlignmentQ(); static constexpr index_t kAlignmentK = - kPadHeadDimQ ? 1 : Policy::template GetAlignmentK(); + kPadHeadDimQ ? kPadHeadDimQ : Policy::template GetAlignmentK(); static constexpr index_t kAlignmentV = - kPadHeadDimV ? 1 : Policy::template GetAlignmentV(); + kPadHeadDimV ? kPadHeadDimV : Policy::template GetAlignmentV(); static constexpr index_t kAlignmentOGrad = - kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad(); + kPadHeadDimV ? kPadHeadDimV : Policy::template GetAlignmentOGrad(); static constexpr index_t kAlignmentQGrad = 1; static constexpr index_t kAlignmentKGrad = - kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad(); + kPadHeadDimQ ? kPadHeadDimQ : Policy::template GetAlignmentKGrad(); static constexpr index_t kAlignmentVGrad = - kPadHeadDimV ? 1 : Policy::template GetAlignmentVGrad(); + kPadHeadDimV ? kPadHeadDimV : Policy::template GetAlignmentVGrad(); static constexpr index_t kAlignmentBias = 1; static constexpr const char* name = "kr_ktr_vr"; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp index c402eaeac4..6393f227a2 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp @@ -49,8 +49,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim; static constexpr bool kIsGroupMode = Problem::kIsGroupMode; - static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; - static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; + static constexpr index_t kPadHeadDimQ = Problem::kPadHeadDimQ; + static constexpr index_t kPadHeadDimV = Problem::kPadHeadDimV; static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad; static constexpr bool kIsDeterministic = Problem::kIsDeterministic; @@ -60,18 +60,18 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP // last dimension vector length used to create tensor view(and decide buffer_load vector length) // ... together with tensor distribution. tensor dist should able to overwrite this static constexpr index_t kAlignmentQ = - kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ(); + kPadHeadDimQ ? kPadHeadDimQ : Policy::template GetAlignmentQ(); static constexpr index_t kAlignmentK = - kPadHeadDimQ ? 1 : Policy::template GetAlignmentK(); + kPadHeadDimQ ? kPadHeadDimQ : Policy::template GetAlignmentK(); static constexpr index_t kAlignmentV = - kPadHeadDimV ? 1 : Policy::template GetAlignmentV(); + kPadHeadDimV ? kPadHeadDimV : Policy::template GetAlignmentV(); static constexpr index_t kAlignmentOGrad = - kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad(); + kPadHeadDimV ? kPadHeadDimV : Policy::template GetAlignmentOGrad(); static constexpr index_t kAlignmentQGrad = 1; static constexpr index_t kAlignmentKGrad = - kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad(); + kPadHeadDimQ ? kPadHeadDimQ : Policy::template GetAlignmentKGrad(); static constexpr index_t kAlignmentVGrad = - kPadHeadDimV ? 1 : Policy::template GetAlignmentVGrad(); + kPadHeadDimV ? kPadHeadDimV : Policy::template GetAlignmentVGrad(); static constexpr index_t kAlignmentBias = 1; static constexpr const char* name = "kr_ktr_vr_iglp"; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_selector.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_selector.hpp index c3e84df934..abe024ced1 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_selector.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_selector.hpp @@ -14,7 +14,8 @@ namespace ck_tile { template class BlockFmhaBwdDQDKDVPipelineSelector { - static constexpr bool has_dpad = Problem::Traits::kPadHeadDimQ || Problem::Traits::kPadHeadDimV; + static constexpr bool has_dpad1 = + Problem::Traits::kPadHeadDimQ == 1 || Problem::Traits::kPadHeadDimV == 1; static constexpr bool is_decode = Problem::BlockFmhaShape::kMaxSeqLenQ > 0; public: @@ -24,7 +25,7 @@ class BlockFmhaBwdDQDKDVPipelineSelector std::conditional_t, BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR>, - std::conditional_t, BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP>>; using type = std::conditional_t, // diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp index 41cb4fc306..5cdb4fe1d7 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp @@ -49,8 +49,8 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim; static constexpr bool kIsGroupMode = Problem::kIsGroupMode; - static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; - static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; + static constexpr index_t kPadHeadDimQ = Problem::kPadHeadDimQ; + static constexpr index_t kPadHeadDimV = Problem::kPadHeadDimV; static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad; static constexpr bool kIsDeterministic = Problem::kIsDeterministic; @@ -60,18 +60,18 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR // last dimension vector length used to create tensor view(and decide buffer_load vector length) // ... together with tensor distribution. tensor dist should able to overwrite this static constexpr index_t kAlignmentQ = - kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ(); + kPadHeadDimQ ? kPadHeadDimQ : Policy::template GetAlignmentQ(); static constexpr index_t kAlignmentK = - kPadHeadDimQ ? 1 : Policy::template GetAlignmentK(); + kPadHeadDimQ ? kPadHeadDimQ : Policy::template GetAlignmentK(); static constexpr index_t kAlignmentV = - kPadHeadDimV ? 1 : Policy::template GetAlignmentV(); + kPadHeadDimV ? kPadHeadDimV : Policy::template GetAlignmentV(); static constexpr index_t kAlignmentOGrad = - kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad(); + kPadHeadDimV ? kPadHeadDimV : Policy::template GetAlignmentOGrad(); static constexpr index_t kAlignmentQGrad = 1; static constexpr index_t kAlignmentKGrad = - kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad(); + kPadHeadDimQ ? kPadHeadDimQ : Policy::template GetAlignmentKGrad(); static constexpr index_t kAlignmentVGrad = - kPadHeadDimV ? 1 : Policy::template GetAlignmentVGrad(); + kPadHeadDimV ? kPadHeadDimV : Policy::template GetAlignmentVGrad(); static constexpr index_t kAlignmentBias = 1; static constexpr const char* name = "trload_kr_ktr_vr"; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp index 6d90429407..3d5bfcc76a 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp @@ -51,8 +51,8 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim; static constexpr bool kIsGroupMode = Problem::kIsGroupMode; - static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; - static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; + static constexpr index_t kPadHeadDimQ = Problem::kPadHeadDimQ; + static constexpr index_t kPadHeadDimV = Problem::kPadHeadDimV; static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad; static constexpr bool kIsDeterministic = Problem::kIsDeterministic; @@ -62,18 +62,18 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR // last dimension vector length used to create tensor view(and decide buffer_load vector length) // ... together with tensor distribution. tensor dist should able to overwrite this static constexpr index_t kAlignmentQ = - kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ(); + kPadHeadDimQ ? kPadHeadDimQ : Policy::template GetAlignmentQ(); static constexpr index_t kAlignmentK = - kPadHeadDimQ ? 1 : Policy::template GetAlignmentK(); + kPadHeadDimQ ? kPadHeadDimQ : Policy::template GetAlignmentK(); static constexpr index_t kAlignmentV = - kPadHeadDimV ? 1 : Policy::template GetAlignmentV(); + kPadHeadDimV ? kPadHeadDimV : Policy::template GetAlignmentV(); static constexpr index_t kAlignmentOGrad = - kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad(); + kPadHeadDimV ? kPadHeadDimV : Policy::template GetAlignmentOGrad(); static constexpr index_t kAlignmentQGrad = 1; static constexpr index_t kAlignmentKGrad = - kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad(); + kPadHeadDimQ ? kPadHeadDimQ : Policy::template GetAlignmentKGrad(); static constexpr index_t kAlignmentVGrad = - kPadHeadDimV ? 1 : Policy::template GetAlignmentVGrad(); + kPadHeadDimV ? kPadHeadDimV : Policy::template GetAlignmentVGrad(); static constexpr index_t kAlignmentBias = 1; static constexpr const char* name = "trload_kr_ktr_vr"; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp index 99718a187f..38aff07093 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp @@ -57,13 +57,11 @@ struct BlockFmhaBwdPipelineProblem static constexpr bool kUseTrLoad = kUseTrLoad_; // attributes from traits - static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ; - static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV; - static constexpr auto BiasEnum = Traits::BiasEnum; - static constexpr bool kHasBiasGrad = Traits::kHasBiasGrad; - static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; - static_assert(!Traits::kPadSeqLenQ, "BlockFmhaBwdPipelineProblem does not need kPadSeqLenQ"); - static_assert(!Traits::kPadSeqLenK, "BlockFmhaBwdPipelineProblem does not need kPadSeqLenQ"); + static constexpr index_t kPadHeadDimQ = Traits::kPadHeadDimQ; + static constexpr index_t kPadHeadDimV = Traits::kPadHeadDimV; + static constexpr auto BiasEnum = Traits::BiasEnum; + static constexpr bool kHasBiasGrad = Traits::kHasBiasGrad; + static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; }; template +struct TileFmhaBwdTraits +{ + static constexpr index_t kPadHeadDimQ = kPadHeadDimQ_; + static constexpr index_t kPadHeadDimV = kPadHeadDimV_; + static constexpr auto BiasEnum = BiasEnum_; + static constexpr bool kHasBiasGrad = kHasBiasGrad_; + static constexpr index_t kBlockPerCu = kBlockPerCu_; + + static_assert(kPadHeadDimQ == 0 || kPadHeadDimQ == 8 || kPadHeadDimQ == 1); + static_assert(kPadHeadDimV == 0 || kPadHeadDimV == 8 || kPadHeadDimV == 1); +}; + template