From 5be2aae20ecf0bc48ca061939f4d607790472a2b Mon Sep 17 00:00:00 2001 From: shay-li77 Date: Sat, 2 Aug 2025 00:16:37 +0800 Subject: [PATCH] atomic16 base impl formatting code fix compile error fix conflict use global_atomic_pk_add instr remove redundant modifications formatting code remove seqstart_dq_acc in varlen mode formatting code --- .../ck_tile/01_fmha/codegen/ops/fmha_bwd.py | 55 +++-- example/ck_tile/01_fmha/fmha_bwd.cpp | 94 ++++++-- example/ck_tile/01_fmha/fmha_bwd.hpp | 9 +- .../ck_tile/01_fmha/script/smoke_test_bwd.sh | 14 +- .../core/arch/amd_buffer_addressing.hpp | 48 ++++ .../arch/amd_buffer_addressing_builtins.hpp | 48 ++++ include/ck_tile/core/tensor/tensor_view.hpp | 2 +- .../fmha/kernel/fmha_batch_prefill_kernel.hpp | 2 +- .../ops/fmha/kernel/fmha_bwd_kernel.hpp | 228 +++++++++++++++--- .../pipeline/block_fmha_bwd_convert_dq.hpp | 38 ++- ...k_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp | 51 ++-- ...a_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp | 70 ++++-- ...bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp | 2 + ...wd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp | 2 + ...block_fmha_bwd_pipeline_default_policy.hpp | 50 ++++ .../block_fmha_bwd_pipeline_problem.hpp | 6 + 16 files changed, 603 insertions(+), 116 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index 0391191fb2..f7ca08469b 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -83,6 +83,7 @@ using fmha_bwd_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdPipelineProblem< fmha_bwd_shape_{F_idx}, {F_mode}, {F_deterministic}, + {F_atomic32}, fmha_mask_{F_idx}, fmha_dropout_{F_idx}, {F_trload}, @@ -124,6 +125,7 @@ using dq_dk_dv_trait_{F_idx} = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dpad}, {F_dvpad}, {F_deterministic}, + {F_atomic32}, {F_trload}, {F_maxq}>; @@ -218,10 +220,10 @@ def FMHA_BWD_API_COND_STATEMENT(F_cond: str, F_body: str, *, indent=0, if_ = 0) FMHA_BWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && ({F_dropout_check}) && - ({F_scheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.is_deterministic == {F_deterministic})) {{ + ({F_scheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_dq_reduce_check})) {{ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, {F_dvpad}>; - using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_mask}, {F_dropout}, {F_bias}, {F_dbias}, {F_dpad}, {F_dvpad}, {F_deterministic}, {F_trload}, {F_maxq}>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, {F_dpad}, {F_deterministic}>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_mask}, {F_dropout}, {F_bias}, {F_dbias}, {F_dpad}, {F_dvpad}, {F_deterministic}, {F_atomic32}, {F_trload}, {F_maxq}>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, {F_dpad}, {F_deterministic}, {F_atomic32}>; r = fmha_bwd_>(s, a); return r; }} @@ -285,8 +287,9 @@ class FmhaBwdDQDKDVKernel: F_mask : str # value from MASK_MAP F_mode : str # value from MODE_MAP F_deterministic : str # + F_atomic32 : str # will not be used if deterministic set to 1 mask_impl : str # - F_trload : str # + F_trload : str # @property def template(self) -> str: @@ -328,6 +331,7 @@ class FmhaBwdDQDKDVKernel: F_mask = get_mask_map(self.mask_impl)[self.F_mask], F_mode = MODE_MAP[self.F_mode], F_deterministic = BOOL_MAP[self.F_deterministic], + F_atomic32 = BOOL_MAP[self.F_atomic32], F_trload = BOOL_MAP[self.F_trload], F_maxq = self.F_tile.max_seq_q ) @@ -362,7 +366,8 @@ class FmhaBwdDQDKDVKernel: else: n += '_ndropout' if self.F_deterministic == 't' : n += '_deterministic' - else: n += '_ndeterministic' + elif self.F_atomic32 == 't' : n += '_atomic32' + else: n += '_atomic16' if self.F_trload == 't' : n += '_trload' else: n += '_ntrload' @@ -504,8 +509,10 @@ using fmha_bwd_convert_dq_pipeline_problem_{F_idx} = {F_bm0}, {F_bn0}, {F_hdim}, + {F_wn0}, {F_mode}, {F_deterministic}, + {F_atomic32}, fmha_bwd_convert_dq_trait_{F_idx}>; using fmha_bwd_convert_dq_{F_idx} = @@ -519,7 +526,8 @@ using convert_dq_trait_{F_idx} = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_mode}, {F_spad}, {F_dpad}, - {F_deterministic}>; + {F_deterministic}, + {F_atomic32}>; #include @@ -563,11 +571,13 @@ class FmhaBwdConvertQGradKernel: F_dtype : str # data type F_bm0 : int # tile size along q seqlen (block size) F_bn0 : int # tile size along k seqlen + F_wn0 : int # warp size along n in gemm0/gemm2/gemm4 F_spad : str # true/false F_dpad : str # F_mode : str # value from MODE_MAP F_occupancy : int # F_deterministic : str # + F_atomic32 : str disabled : bool # sometimes this kernel is not used @property @@ -579,11 +589,13 @@ class FmhaBwdConvertQGradKernel: F_dtype = BWD_DTYPE_MAP[self.F_dtype], F_bm0 = self.F_bm0, F_bn0 = self.F_bn0, + F_wn0 = self.F_wn0, F_spad = BOOL_MAP[self.F_spad], F_dpad = BOOL_MAP[self.F_dpad], F_mode = MODE_MAP[self.F_mode], F_occupancy = self.F_occupancy, - F_deterministic = BOOL_MAP[self.F_deterministic]) + F_deterministic = BOOL_MAP[self.F_deterministic], + F_atomic32 = BOOL_MAP[self.F_atomic32]) @property def name(self) -> str: @@ -594,11 +606,12 @@ class FmhaBwdConvertQGradKernel: if n != '' : n = 'p' + n return n pn = pad_name() - n = f"fmha_bwd_convert_dq_d{self.F_hdim}_{self.F_dtype}_b{self.F_bm0}x{self.F_bn0}_{self.F_mode}_o{self.F_occupancy}" + n = f"fmha_bwd_convert_dq_d{self.F_hdim}_{self.F_dtype}_b{self.F_bm0}x{self.F_bn0}_wn0{self.F_wn0}_{self.F_mode}_o{self.F_occupancy}" if pn != '' : n += f'_{pn}' else: n += '_npad' if self.F_deterministic == 't' : n += '_deterministic' - else: n += '_ndeterministic' + elif self.F_atomic32 == 't' : n += '_atomic32' + else: n += '_atomic16' return n @property @@ -621,6 +634,7 @@ class FmhaBwdApiTrait: dpad : str dvpad : str deterministic : str + atomic32 : str mask_impl : str tr_load : str @@ -656,6 +670,12 @@ class FmhaBwdApiTrait: if self.dvpad == 't': return f'a.hdim_v % {self.bhdv} != 0' else : return f'a.hdim_v % {self.bhdv} == 0' + @property + def dq_reduce_check(self) -> str: + if self.deterministic == 't' : return 't.is_deterministic' + elif self.atomic32 == 't' : return '!t.is_deterministic && t.is_atomic_fp32' + else : return '!t.is_deterministic && !t.is_atomic_fp32' + @property def dot_do_o_kernel(self) -> FmhaBwdOGradDotOKernel: # TODO: we don't support tuning yet, so pick up one value for pad/occupancy @@ -670,7 +690,8 @@ class FmhaBwdApiTrait: def dq_dk_dv_kernel(self) -> FmhaBwdDQDKDVKernel: return FmhaBwdDQDKDVKernel(F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype, F_tile=self.tile, F_dpad=self.dpad, F_dvpad=self.dvpad, F_bias=self.bias, F_dbias=self.dbias, F_dropout=self.dropout, - F_mask=self.mask, F_mode=self.mode, F_deterministic=self.deterministic, mask_impl=self.mask_impl, F_trload=self.tr_load) + F_mask=self.mask, F_mode=self.mode, F_deterministic=self.deterministic, F_atomic32=self.atomic32, + mask_impl=self.mask_impl, F_trload=self.tr_load) @property def convert_dq_kernel(self) -> FmhaBwdConvertQGradKernel: @@ -680,9 +701,9 @@ class FmhaBwdApiTrait: return 2 return FmhaBwdConvertQGradKernel(F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype, - F_bm0=M0_1D, F_bn0=self.tile.F_bn0, F_spad=self.spad1d, F_dpad=self.dpad, + F_bm0=M0_1D, F_bn0=self.tile.F_bn0, F_wn0=self.tile.F_wn0, F_spad=self.spad1d, F_dpad=self.dpad, F_mode=self.mode, F_occupancy=get_occupancy(self.dtype, self.hdim), - F_deterministic=self.deterministic, disabled=self.tile.max_seq_q != 0) + F_deterministic=self.deterministic, F_atomic32=self.atomic32, disabled=self.tile.max_seq_q != 0) class FmhaBwdApiPool: def __init__(self, mask_impl): @@ -705,9 +726,9 @@ class FmhaBwdApiPool: inners += FMHA_BWD_API_INNER_DISPATCH.format(F_if=self.if_(i), F_mode=MODE_MAP[trait.mode], F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout_check=DROPOUT_CHECK_MAP[trait.dropout], F_dropout=DROPOUT_MAP[trait.dropout], - F_scheck=trait.scheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=trait.hdim, F_dtype=BWD_DTYPE_MAP[trait.dtype], + F_scheck=trait.scheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_dq_reduce_check=trait.dq_reduce_check, F_hdim=trait.hdim, F_dtype=BWD_DTYPE_MAP[trait.dtype], F_spad1d=BOOL_MAP[trait.spad1d], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], - F_deterministic=BOOL_MAP[trait.deterministic], F_trload=BOOL_MAP[trait.tr_load], F_maxq=trait.tile.max_seq_q, + F_deterministic=BOOL_MAP[trait.deterministic], F_atomic32=BOOL_MAP[trait.atomic32], F_trload=BOOL_MAP[trait.tr_load], F_maxq=trait.tile.max_seq_q, F_convert_dq_enabled=BOOL_MAP[not trait.convert_dq_kernel.disabled]) i += 1 return inners @@ -778,7 +799,7 @@ def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[Fm for dtype, tr_load in itertools.product(BWD_DTYPE_MAP.keys(), ["t", "f"]): tiles: Any = get_dq_dk_dv_tiles(dtype, tr_load) - for tile, mode, mask, bias, dbias, dropout, spad1d, dpad, dvpad, deterministic in itertools.product(tiles, MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], DROPOUT_MAP.keys(), *([["t", "f"]] * 4)): + for tile, mode, mask, bias, dbias, dropout, spad1d, dpad, dvpad, deterministic, atomic32 in itertools.product(tiles, MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], DROPOUT_MAP.keys(), *([["t", "f"]] * 5)): assert isinstance(tile, FmhaBwdDQDKDVTileSize), "tile must be FmhaBwdDQDKDVTileSize" hdim = tile.F_bhdq if (mode == "group") and (spad1d == "f"): @@ -787,11 +808,13 @@ def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[Fm continue if ((bias == "no" or bias == "alibi") and dbias == "t"): continue + if ((deterministic == 't' or tr_load == "t") and atomic32 == 'f'): + continue if ("wg32" in dropout): continue if tr_load == "t" and (dpad == "t" or dvpad == "t"): continue # tr_load cannot work with dpad or dvpad - t = FmhaBwdApiTrait(idx=0, hdim=hdim, dtype=dtype, mode=mode,tile=tile,mask=mask, bias=bias, dbias=dbias, dropout=dropout, spad1d=spad1d, dpad=dpad, dvpad=dvpad, deterministic=deterministic, mask_impl=mask_impl, tr_load=tr_load) + t = FmhaBwdApiTrait(idx=0, hdim=hdim, dtype=dtype, mode=mode,tile=tile,mask=mask, bias=bias, dbias=dbias, dropout=dropout, spad1d=spad1d, dpad=dpad, dvpad=dvpad, deterministic=deterministic, atomic32=atomic32, mask_impl=mask_impl, tr_load=tr_load) if not fnmatch.fnmatch(t.dot_do_o_kernel.name, filter_dot_do_o): continue diff --git a/example/ck_tile/01_fmha/fmha_bwd.cpp b/example/ck_tile/01_fmha/fmha_bwd.cpp index 9f1e0f6948..a3c18fa876 100644 --- a/example/ck_tile/01_fmha/fmha_bwd.cpp +++ b/example/ck_tile/01_fmha/fmha_bwd.cpp @@ -94,7 +94,8 @@ auto create_args(int argc, char* argv[]) .insert("deterministic", "0", "if set to 1 will use multi-buffer reduction strategy for dq, atomic opeartion " - "will not be used"); + "will not be used") + .insert("atomic_fp32", "1", "if set to 0 will use atomic fp16/bf16"); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); @@ -122,7 +123,19 @@ auto get_elimit(ck_tile::index_t hdim_q, ck_tile::index_t hdim_v) return ck_tile::make_tuple(rtol, atol); } -template +ck_tile::index_t get_bit_ceil(const ck_tile::index_t dim_value) +{ + unsigned un = static_cast(dim_value); + un |= un >> 1; + un |= un >> 2; + un |= un >> 4; + un |= un >> 8; + un |= un >> 16; + un++; + return static_cast(un); +} + +template bool run(const ck_tile::ArgParser& arg_parser) { std::string data_type = arg_parser.get_str("prec"); @@ -198,6 +211,7 @@ bool run(const ck_tile::ArgParser& arg_parser) int stream_repeat = arg_parser.get_int("repeat"); bool kname = arg_parser.get_bool("kname"); bool deterministic = arg_parser.get_bool("deterministic"); + bool atomic_fp32 = arg_parser.get_bool("atomic_fp32"); ck_tile::stream_config stream_config{nullptr, true, @@ -226,6 +240,7 @@ bool run(const ck_tile::ArgParser& arg_parser) using KGradDataType = typename TypeConfig::KGradDataType; using VGradDataType = typename TypeConfig::VGradDataType; using BiasGradDataType = typename TypeConfig::BiasGradDataType; + using QGradAccDataType = std::conditional_t; // accumulation numbers for performance evaluation std::size_t flop = 0, num_byte = 0; @@ -277,12 +292,26 @@ bool run(const ck_tile::ArgParser& arg_parser) return std::array{b, s, h, d}; }; + // for dq_acc padding in atomic16 + constexpr ck_tile::index_t seqlen_dq_acc_tile_size = 16; + const ck_tile::index_t hdim_q_pad = get_bit_ceil(hdim_q); + const ck_tile::index_t hdim_q_dq_acc = atomic_fp32 ? hdim_q : hdim_q_pad; + const ck_tile::index_t max_seqlen_q_aligned = + ck_tile::integer_least_multiple(max_seqlen_q, seqlen_dq_acc_tile_size); + // host memory for storing all the tensor elements const ck_tile::index_t shape_batch = (mode == mode_enum::batch ? batch : 1); const ck_tile::index_t shape_seqlen_q = (mode == mode_enum::batch ? seqlen_q : seqstart_q_host.back()); const ck_tile::index_t shape_seqlen_k = (mode == mode_enum::batch ? seqlen_k : seqstart_k_host.back()); + const ck_tile::index_t shape_seqlen_dq_acc_batch_mode = + atomic_fp32 ? seqlen_q : ck_tile::integer_least_multiple(seqlen_q, seqlen_dq_acc_tile_size); + const ck_tile::index_t shape_seqlen_dq_acc_group_mode = + atomic_fp32 ? seqstart_q_host.back() : max_seqlen_q_aligned * batch; + const ck_tile::index_t shape_seqlen_dq_acc = + (mode == mode_enum::batch ? shape_seqlen_dq_acc_batch_mode + : shape_seqlen_dq_acc_group_mode); const ck_tile::index_t kN0 = (hdim_q <= 128) ? 128 : 64; const ck_tile::index_t nsplits = deterministic ? ck_tile::integer_divide_ceil(max_seqlen_k, kN0) : 1; @@ -323,10 +352,16 @@ bool run(const ck_tile::ArgParser& arg_parser) use_dbias ? get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, max_seqlen_k) : std::array{1, 1, 1, 1} /* dummy shape for simplifying code */); - ck_tile::HostTensor dq_acc_host( - i_perm - ? std::array{nsplits, shape_batch, nhead, shape_seqlen_q, hdim_q} - : std::array{nsplits, shape_batch, shape_seqlen_q, nhead, hdim_q}); + + bool dq_acc_perm = i_perm || !atomic_fp32; // need to permute for atomic16 + ck_tile::HostTensor dq_acc_host( + dq_acc_perm ? std::array{nsplits, + shape_batch, + nhead, + shape_seqlen_dq_acc, + hdim_q_dq_acc} + : std::array{ + nsplits, shape_batch, shape_seqlen_dq_acc, nhead, hdim_q_dq_acc}); if(init_method == 0) { @@ -438,7 +473,8 @@ bool run(const ck_tile::ArgParser& arg_parser) use_dbias, p_drop > 0.0f, s_randval, - deterministic}; + deterministic, + atomic_fp32}; auto fmha_args = [&]() { assert(nhead % nhead_k == 0); /// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q, @@ -455,6 +491,8 @@ bool run(const ck_tile::ArgParser& arg_parser) const ck_tile::index_t stride_dk = (i_perm ? hdim_q : nhead * hdim_q); const ck_tile::index_t stride_dv = (i_perm ? hdim_v : nhead * hdim_v); const ck_tile::index_t stride_dbias = (i_perm ? max_seqlen_k : nhead * max_seqlen_k); + const ck_tile::index_t stride_dq_acc = + (dq_acc_perm ? hdim_q_dq_acc : nhead * hdim_q_dq_acc); // setup nhead_stride_* arguments const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q); const ck_tile::index_t nhead_stride_k = (i_perm ? shape_seqlen_k * hdim_q : hdim_q); @@ -466,6 +504,8 @@ bool run(const ck_tile::ArgParser& arg_parser) const ck_tile::index_t nhead_stride_lsed = shape_seqlen_q; const ck_tile::index_t nhead_stride_dbias = (i_perm ? shape_seqlen_q * max_seqlen_k : max_seqlen_k); + const ck_tile::index_t nhead_stride_dq_acc = + (dq_acc_perm ? shape_seqlen_dq_acc * hdim_q_dq_acc : hdim_q_dq_acc); // setup batch_stride_* arguments const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q); const ck_tile::index_t batch_stride_k = (nhead_k * shape_seqlen_k * hdim_q); @@ -478,9 +518,9 @@ bool run(const ck_tile::ArgParser& arg_parser) const ck_tile::index_t batch_stride_dk = (nhead * shape_seqlen_k * hdim_q); const ck_tile::index_t batch_stride_dv = (nhead * shape_seqlen_k * hdim_v); const ck_tile::index_t batch_stride_dbias = (nhead * shape_seqlen_q * max_seqlen_k); + const ck_tile::index_t batch_stride_dq_acc = (nhead * shape_seqlen_dq_acc * hdim_q_dq_acc); const ck_tile::index_t split_stride_dq_acc = - (shape_batch * nhead * shape_seqlen_q * hdim_q); - + (shape_batch * nhead * shape_seqlen_dq_acc * hdim_q_dq_acc); const auto drop_seed_offset = [&]() -> decltype(fmha_bwd_args::drop_seed_offset) { if(drop_prefs) { @@ -516,6 +556,7 @@ bool run(const ck_tile::ArgParser& arg_parser) batch, max_seqlen_q, max_seqlen_k, + max_seqlen_q_aligned, hdim_q, hdim_v, nhead, @@ -529,8 +570,8 @@ bool run(const ck_tile::ArgParser& arg_parser) stride_o, stride_randval, stride_do, - stride_q, // stride_dq_acc - stride_q, // stride_dq + stride_dq_acc, // stride_dq_acc + stride_q, // stride_dq stride_dk, stride_dv, stride_dbias, @@ -542,10 +583,10 @@ bool run(const ck_tile::ArgParser& arg_parser) nhead_stride_randval, nhead_stride_do, nhead_stride_lsed, - nhead_stride_q, // nhead_stride_dq_acc - nhead_stride_q, // nhead_stride_dq - nhead_stride_k, // nhead_stride_dk - nhead_stride_v, // nhead_stride_dv + nhead_stride_dq_acc, // nhead_stride_dq_acc + nhead_stride_q, // nhead_stride_dq + nhead_stride_k, // nhead_stride_dk + nhead_stride_v, // nhead_stride_dv nhead_stride_dbias, batch_stride_q, batch_stride_k, @@ -555,8 +596,8 @@ bool run(const ck_tile::ArgParser& arg_parser) batch_stride_randval, batch_stride_do, batch_stride_lsed, - batch_stride_q, // batch_stride_dq_acc - batch_stride_q, // batch_stride_dq + batch_stride_dq_acc, // batch_stride_dq_acc + batch_stride_q, // batch_stride_dq batch_stride_dk, batch_stride_dv, batch_stride_dbias, @@ -985,13 +1026,28 @@ int main(int argc, char* argv[]) return -1; const std::string data_type = arg_parser.get_str("prec"); + const bool atomic_fp32 = arg_parser.get_bool("atomic_fp32"); if(data_type == "fp16") { - return run(arg_parser) ? 0 : -2; + if(atomic_fp32) + { + return run(arg_parser) ? 0 : -2; + } + else + { + return run(arg_parser) ? 0 : -2; + } } else if(data_type == "bf16") { - return run(arg_parser) ? 0 : -2; + if(atomic_fp32) + { + return run(arg_parser) ? 0 : -2; + } + else + { + return run(arg_parser) ? 0 : -2; + } } return -3; diff --git a/example/ck_tile/01_fmha/fmha_bwd.hpp b/example/ck_tile/01_fmha/fmha_bwd.hpp index 8d35b2d12c..0dc700b8c0 100644 --- a/example/ck_tile/01_fmha/fmha_bwd.hpp +++ b/example/ck_tile/01_fmha/fmha_bwd.hpp @@ -98,6 +98,7 @@ struct fmha_bwd_args ck_tile::index_t batch; ck_tile::index_t max_seqlen_q; ck_tile::index_t max_seqlen_k; + ck_tile::index_t max_seqlen_q_aligned; ck_tile::index_t hdim_q; ck_tile::index_t hdim_v; ck_tile::index_t nhead_q; @@ -180,6 +181,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args) args.seqstart_q_ptr, args.seqstart_k_ptr, args.seqlen_k_ptr, + args.max_seqlen_q_aligned, args.hdim_q, args.hdim_v, args.nhead_q, @@ -332,6 +334,7 @@ auto fmha_bwd_convert_dq_create_kargs_and_grids(fmha_bwd_args args) args.dq_ptr, args.seqstart_q_ptr, args.seqstart_k_ptr, + args.max_seqlen_q_aligned, args.hdim_q, args.stride_dq, args.stride_dq_acc, @@ -371,6 +374,7 @@ template struct fmha_bwd_dq_dk_dv_traits_ @@ -412,7 +416,8 @@ template + bool kIsDeterministic_, + bool kAtomic32_ = true> struct fmha_bwd_convert_dq_traits_ { static constexpr ck_tile::index_t HDim = HDim_; @@ -421,6 +426,7 @@ struct fmha_bwd_convert_dq_traits_ static constexpr bool kPadS = kPadS_; static constexpr bool kPadD = kPadD_; static constexpr bool kIsDeterministic = kIsDeterministic_; + static constexpr bool kAtomic32 = kAtomic32_; }; template @@ -445,6 +451,7 @@ struct fmha_bwd_traits bool has_dropout; bool is_store_randval; bool is_deterministic; + bool is_atomic_fp32; // TODO: padding check is inside this api }; template diff --git a/example/ck_tile/01_fmha/script/smoke_test_bwd.sh b/example/ck_tile/01_fmha/script/smoke_test_bwd.sh index d123f842a2..8e416cadfb 100755 --- a/example/ck_tile/01_fmha/script/smoke_test_bwd.sh +++ b/example/ck_tile/01_fmha/script/smoke_test_bwd.sh @@ -18,13 +18,14 @@ for bias in "n" "a" ; do for dbias in 0 ; do for p_drop in 0.0 0.2 ; do for deterministic in 0 ; do +for atomic_fp32 in 0 1 ; do -$EXE -prec=$prec -b=1 -h=4 -h_k=2 -d=$hdim -s=259 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS -$EXE -prec=$prec -b=2 -h=2 -d=$hdim -s=516 -s_k=253 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS -$EXE -prec=$prec -b=1 -h=4 -h_k=1 -d=$hdim -s=500 -s_k=251 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=1 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS -$EXE -prec=$prec -b=1 -h=2 -d=$hdim -s=900 -s_k=258 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=2 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS -$EXE -prec=$prec -b=2 -h=1 -d=$hdim -s=987 -s_k=219 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=t:128,30 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS -$EXE -prec=$prec -b=2 -h=3 -h_k=1 -d=$hdim -s=244 -s_k=499 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=b:4,35 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -b=1 -h=4 -h_k=2 -d=$hdim -s=259 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -atomic_fp32=$atomic_fp32 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -b=2 -h=2 -d=$hdim -s=516 -s_k=253 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -atomic_fp32=$atomic_fp32 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -b=1 -h=4 -h_k=1 -d=$hdim -s=500 -s_k=251 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=1 -deterministic=$deterministic -atomic_fp32=$atomic_fp32 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -b=1 -h=2 -d=$hdim -s=900 -s_k=258 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=2 -deterministic=$deterministic -atomic_fp32=$atomic_fp32 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -b=2 -h=1 -d=$hdim -s=987 -s_k=219 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=t:128,30 -deterministic=$deterministic -atomic_fp32=$atomic_fp32 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -b=2 -h=3 -h_k=1 -d=$hdim -s=244 -s_k=499 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=b:4,35 -deterministic=$deterministic -atomic_fp32=$atomic_fp32 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS done done @@ -34,4 +35,5 @@ done done done done +done set +x diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp index 7a9c017eb2..8c976b9f41 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -790,6 +790,34 @@ struct buffer_atomic_add_if } }; +template +struct buffer_atomic_add_if +{ + template + CK_TILE_DEVICE void operator()(const T& value, + int32x4_t res /*buffer resource*/, + index_t v_offset, + index_t /*s_offset*/, + index_t i_offset /*max 0xFFF*/, + index_t flag = 1) + { + static_assert(sizeof(T) == 4); + auto save_exec = __builtin_amdgcn_read_exec(); + using mbuf_t = float; + asm volatile("v_cmpx_le_u32 exec, 1, %4\n" + "global_atomic_pk_add_f16 %0, %1, %2 offset:%3\n" + "s_mov_b64 exec %5" + : + : "v"(v_offset), + "v"(bit_cast(value)), + "s"(res.xy), + "n"(i_offset), + "v"(flag), + "s"(save_exec) + : "memory"); + } +}; + template struct buffer_atomic_add; @@ -813,6 +841,26 @@ struct buffer_atomic_add } }; +template +struct buffer_atomic_add +{ + template + CK_TILE_DEVICE void operator()(const T& value, + int32x4_t res /*buffer resource*/, + index_t v_offset, + index_t /*s_offset*/, + index_t i_offset /*max 0xFFF*/, + index_t /*flag = 1*/) + { + static_assert(sizeof(T) == 4); + using mbuf_t = float; + asm volatile("global_atomic_pk_add_f16 %0, %1, %2 offset:%3" + : + : "v"(v_offset), "v"(bit_cast(value)), "s"(res.xy), "n"(i_offset) + : "memory"); + } +}; + namespace impl { // below type indicate the data type used for buffer load inline asm // clang-format off diff --git a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp index 4013b51479..49b7386e58 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp @@ -658,6 +658,34 @@ struct buffer_atomic_add_if } }; +template +struct buffer_atomic_add_if +{ + template + CK_TILE_DEVICE void operator()(const T& value, + int32x4_t res /*buffer resource*/, + index_t v_offset, + index_t /*s_offset*/, + index_t i_offset /*max 0xFFF*/, + index_t flag = 1) + { + static_assert(sizeof(T) == 4); + auto save_exec = __builtin_amdgcn_read_exec(); + using mbuf_t = float; + asm volatile("v_cmpx_le_u32 exec, 1, %4\n" + "global_atomic_pk_add_f16 %0, %1, %2 offset:%3\n" + "s_mov_b64 exec %5" + : + : "v"(v_offset), + "v"(bit_cast(value)), + "s"(res.xy), + "n"(i_offset), + "v"(flag), + "s"(save_exec) + : "memory"); + } +}; + template struct buffer_atomic_add; @@ -681,6 +709,26 @@ struct buffer_atomic_add } }; +template +struct buffer_atomic_add +{ + template + CK_TILE_DEVICE void operator()(const T& value, + int32x4_t res /*buffer resource*/, + index_t v_offset, + index_t /*s_offset*/, + index_t i_offset /*max 0xFFF*/, + index_t /*flag = 1*/) + { + static_assert(sizeof(T) == 4); + using mbuf_t = float; + asm volatile("global_atomic_pk_add_f16 %0, %1, %2 offset:%3" + : + : "v"(v_offset), "v"(bit_cast(value)), "s"(res.xy), "n"(i_offset) + : "memory"); + } +}; + namespace impl { // below type indicate the data type used for buffer load inline asm // clang-format off diff --git a/include/ck_tile/core/tensor/tensor_view.hpp b/include/ck_tile/core/tensor/tensor_view.hpp index 6fa8f898e5..fb209ba827 100644 --- a/include/ck_tile/core/tensor/tensor_view.hpp +++ b/include/ck_tile/core/tensor/tensor_view.hpp @@ -455,7 +455,7 @@ CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType* __restrict__ p, auto buffer_view = make_buffer_view(p, desc.get_element_space_size()); - return tensor_view{buffer_view, desc}; + return tensor_view{buffer_view, desc}; } template {}, number{}), {i_m0, i_n1}); - EpiloguePipeline{}(o_dram_window, o_acc_tile); + EpiloguePipeline{}(o_dram_window, o_acc_tile, nullptr); } }; diff --git a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp index c1f85cb5e6..55570748de 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp @@ -71,15 +71,20 @@ struct FmhaBwdDQDKDVKernel static constexpr bool kHasDropout = FmhaDropout::IsDropout; static constexpr bool kIsStoreRandval = FmhaDropout::IsStoreRandval; static constexpr bool kIsDeterministic = FmhaPipeline::kIsDeterministic; + static constexpr bool kIsAtomic32 = FmhaPipeline::kIsAtomic32; static constexpr bool kUseTrLoad = FmhaPipeline::kUseTrLoad; static constexpr index_t kMaxSeqLenQ = FmhaPipeline::BlockFmhaShape::kMaxSeqLenQ; static_assert(kUseQrQtrDorPipeline == (kMaxSeqLenQ != 0)); + static_assert(!kUseTrLoad || kIsAtomic32); + static_assert(!kIsDeterministic || kIsAtomic32); #if defined(__gfx950__) static constexpr bool kIsAvialable = true; #else static constexpr bool kIsAvialable = !kUseTrLoad; #endif + using QGradAccDataType = std::conditional_t; + // clang-format off template struct t2s; template <> struct t2s { static constexpr const char * name = "fp16"; }; @@ -116,7 +121,7 @@ struct FmhaBwdDQDKDVKernel ("o" + _TS_(kBlockPerCu)) + (pn.empty() ? "_npad" : "_" + pn) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + (kHasBiasGrad ? "_dbias" : "_ndbias") + (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kHasDropout ? "_dropout" : "_ndropout" ) + - (kIsStoreRandval ? "_storerandval" : "" ) + (kIsDeterministic ? "_deterministic" : "_ndeterministic" ) + (kUseTrLoad ? "_trload" : "_ntrload"); + (kIsStoreRandval ? "_storerandval" : "" ) + (kIsDeterministic ? "_deterministic" : (kIsAtomic32 ? "_atomic32" : "_atomic16")) + (kUseTrLoad ? "_trload" : "_ntrload"); #undef _SS_ #undef _TS_ // clang-format on @@ -274,6 +279,11 @@ struct FmhaBwdDQDKDVKernel ck_tile::index_t split_stride_dq_acc = 0; }; + struct FmhaBwdAtomic16GroupModeKargs + { + ck_tile::index_t max_seqlen_q_aligned = 0; + }; + struct FmhaBwdBatchModeKargs : FmhaBwdCommonKargs, std::conditional_t>, std::conditional_t>, std::conditional_t>, - std::conditional_t> + std::conditional_t>, + std::conditional_t> { const int32_t* seqstart_q_ptr; const int32_t* seqstart_k_ptr; @@ -518,6 +529,7 @@ struct FmhaBwdDQDKDVKernel const void* seqstart_q_ptr, const void* seqstart_k_ptr, const void* seqlen_k_ptr, + ck_tile::index_t max_seqlen_q_aligned, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, @@ -589,6 +601,7 @@ struct FmhaBwdDQDKDVKernel {}, // placeholder for mask {}, // placeholder for dropout {}, // placeholder for deterministic + {}, // placeholder for atomic16 reinterpret_cast(seqstart_q_ptr), reinterpret_cast(seqstart_k_ptr), reinterpret_cast(seqlen_k_ptr)}; @@ -644,6 +657,11 @@ struct FmhaBwdDQDKDVKernel kargs.split_stride_dq_acc = split_stride_dq_acc; } + if constexpr(!kIsAtomic32) + { + kargs.max_seqlen_q_aligned = max_seqlen_q_aligned; + } + return kargs; } @@ -707,13 +725,22 @@ struct FmhaBwdDQDKDVKernel // get starting offset for each batch const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; const long_index_t key_start = kargs.seqstart_k_ptr[i_batch]; + long_index_t dq_acc_start = 0; + if constexpr(kIsAtomic32) + { + dq_acc_start = kargs.seqstart_q_ptr[i_batch]; + } + else + { + dq_acc_start = kargs.max_seqlen_q_aligned * i_batch; + } batch_offset_q = query_start * kargs.stride_q; batch_offset_k = key_start * kargs.stride_k; batch_offset_v = key_start * kargs.stride_v; batch_offset_do = query_start * kargs.stride_do; batch_offset_lsed = query_start; - batch_offset_dq_acc = query_start * kargs.stride_dq_acc; + batch_offset_dq_acc = dq_acc_start * kargs.stride_dq_acc; batch_offset_dk = key_start * kargs.stride_dk; batch_offset_dv = key_start * kargs.stride_dv; if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) @@ -879,7 +906,9 @@ struct FmhaBwdDQDKDVKernel auto dq_dram_window = [&, i_tile_n_ = i_tile_n, i_nhead_ = i_nhead]() { constexpr bool kUseKSplit = !kUseQrQtrDorPipeline && kIsDeterministic; - using DType = std::conditional_t; + + using DType = std:: + conditional_t; auto dq_acc_ptr = reinterpret_cast(kargs.dq_acc_ptr) + [&]() { if constexpr(kUseKSplit) @@ -893,17 +922,71 @@ struct FmhaBwdDQDKDVKernel constexpr auto DstInMemOp = conditional_expr( memory_operation_enum::set, memory_operation_enum::atomic_add); - const auto dq_acc_dram_naive = - make_naive_tensor_view( - dq_acc_ptr, - make_tuple(kargs.seqlen_q, kargs.hdim_q), - make_tuple(kargs.stride_dq_acc, 1), - number{}, - number<1>{}); - const auto dq_acc_dram = pad_tensor_view( - dq_acc_dram_naive, - make_tuple(number{}, number{}), - sequence{}); + + auto dq_acc_dram = [&]() { + if constexpr(kIsAtomic32) + { + + const auto dq_acc_dram_naive = + make_naive_tensor_view( + dq_acc_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_q), + make_tuple(kargs.stride_dq_acc, 1), + number{}, + number<1>{}); + return pad_tensor_view( + dq_acc_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + else + { + constexpr index_t m_pack = 2; // dword alignment for atomic 16 instr. + constexpr index_t mfma_m1_per_lane = 4; + constexpr index_t m1_pack_num = mfma_m1_per_lane / m_pack; + constexpr index_t mfma_n_lane = FmhaPipeline::kGemm4WarpN; + constexpr index_t mfma_m_lane = get_warp_size() / mfma_n_lane; + constexpr index_t m_align_size = mfma_m1_per_lane * mfma_m_lane; + + static_assert( + FmhaPipeline::kM0 % m_align_size == 0, + "tiling size in the m direction must be divisible by the m align size."); + + index_t M0 = (kargs.seqlen_q + FmhaPipeline::kM0 - 1) / m_align_size; + constexpr auto dq_acc_n = FmhaPipeline::kQKHeaddim; + constexpr index_t N0 = dq_acc_n / mfma_n_lane; + + const auto q_grad_dram_desc_0 = make_naive_tensor_descriptor( + make_tuple(M0, + number{}, + number{}, + number{}, + number{}, + number{}), + make_tuple(number{}, + number{}, + number{}, + number{}, + number{}, + number<1>{}), + number{}, + number<1>{}); + + const auto q_grad_dram_desc = transform_tensor_descriptor( + q_grad_dram_desc_0, + make_tuple( + make_merge_transform(make_tuple(M0, + number{}, + number{}, + number{})), + make_merge_transform(make_tuple(number{}, number{}))), + make_tuple(sequence<0, 3, 2, 5>{}, sequence<1, 4>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + return make_tensor_view(dq_acc_ptr, + q_grad_dram_desc); + } + }(); return make_tile_window( dq_acc_dram, make_tuple(number{}, number{}), @@ -1430,14 +1513,18 @@ struct FmhaBwdConvertQGradKernel static constexpr ck_tile::index_t kM0 = FmhaBwdConvertQGrad::kM0; static constexpr ck_tile::index_t kN0 = FmhaBwdConvertQGrad::kN0; static constexpr ck_tile::index_t kQKHeaddim = FmhaBwdConvertQGrad::kQKHeaddim; + static constexpr ck_tile::index_t kGemm4WarpN = FmhaBwdConvertQGrad::kGemm4WarpN; using AccDataType = ck_tile::remove_cvref_t; using QGradDataType = ck_tile::remove_cvref_t; + using QGradAccDataType = + ck_tile::remove_cvref_t; static constexpr bool kIsGroupMode = FmhaBwdConvertQGrad::kIsGroupMode; static constexpr bool kPadSeqLenQ = FmhaBwdConvertQGrad::kPadSeqLenQ; static constexpr bool kPadHeadDimQ = FmhaBwdConvertQGrad::kPadHeadDimQ; static constexpr bool kIsDeterministic = FmhaBwdConvertQGrad::kIsDeterministic; + static constexpr bool kIsAtomic32 = FmhaBwdConvertQGrad::kIsAtomic32; // clang-format off template struct t2s; @@ -1463,7 +1550,7 @@ struct FmhaBwdConvertQGradKernel + "b" + _TS_(kM0) + "x" + _TS_(kN0) + "_" + (kIsGroupMode ? "group" : "batch") + "_" + ("o" + _TS_(kBlockPerCu)) + (pn.empty() ? "_npad" : "_" + pn) - + (kIsDeterministic ? "_deterministic" : "_ndeterministic") ; + + (kIsDeterministic ? "_deterministic" : (kIsAtomic32 ? "_atomic32" : "_atomic16")) ; #undef _SS_ #undef _TS_ // clang-format on @@ -1498,6 +1585,11 @@ struct FmhaBwdConvertQGradKernel ck_tile::index_t split_stride_dq_acc = 0; }; + struct FmhaBwdConvertQGradAtomic16GroupModeKargs + { + ck_tile::index_t max_seqlen_q_aligned = 0; + }; + struct FmhaBwdConvertQGradBatchModeKargs : FmhaBwdConvertQGradCommonKargs, std::conditional_t> + FmhaBwdConvertQGradEmptyKargs<0>>, + std::conditional_t> { const int32_t* seqstart_q_ptr; const int32_t* seqstart_k_ptr; @@ -1564,6 +1659,7 @@ struct FmhaBwdConvertQGradKernel void* dq_ptr, const void* seqstart_q_ptr, const void* seqstart_k_ptr, + ck_tile::index_t max_seqlen_q_aligned, ck_tile::index_t hdim_q, ck_tile::index_t stride_dq, ck_tile::index_t stride_dq_acc, @@ -1580,7 +1676,8 @@ struct FmhaBwdConvertQGradKernel stride_dq_acc, nhead_stride_dq, nhead_stride_dq_acc}, - {}, + {}, // placeholder for deterministic + {}, // placeholder for atomic16 reinterpret_cast(seqstart_q_ptr), reinterpret_cast(seqstart_k_ptr)}; @@ -1589,6 +1686,11 @@ struct FmhaBwdConvertQGradKernel kargs.split_stride_dq_acc = split_stride_dq_acc; } + if constexpr(!kIsAtomic32) + { + kargs.max_seqlen_q_aligned = max_seqlen_q_aligned; + } + return kargs; } @@ -1624,8 +1726,17 @@ struct FmhaBwdConvertQGradKernel { // get starting offset for each batch const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; - batch_offset_dq = query_start * kargs.stride_dq; - batch_offset_dq_acc = query_start * kargs.stride_dq_acc; + long_index_t dq_acc_start = 0; + if constexpr(kIsAtomic32) + { + dq_acc_start = kargs.seqstart_q_ptr[i_batch]; + } + else + { + dq_acc_start = kargs.max_seqlen_q_aligned * i_batch; + } + batch_offset_dq = query_start * kargs.stride_dq; + batch_offset_dq_acc = dq_acc_start * kargs.stride_dq_acc; // get real # queries & # keys under group mode const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; @@ -1676,20 +1787,75 @@ struct FmhaBwdConvertQGradKernel } else { - const AccDataType* dq_acc_ptr = - reinterpret_cast(kargs.dq_acc_ptr) + + const QGradAccDataType* dq_acc_ptr = + reinterpret_cast(kargs.dq_acc_ptr) + static_cast(i_nhead_) * (kargs.nhead_stride_dq_acc) + batch_offset_dq_acc; + if constexpr(kIsAtomic32) + { - auto dq_acc_dram_naive = make_naive_tensor_view( - dq_acc_ptr, - make_tuple(kargs.seqlen_q, kargs.hdim_q), - make_tuple(kargs.stride_dq_acc, 1), - number{}, - number<1>{}); - return pad_tensor_view(dq_acc_dram_naive, - make_tuple(number{}, number{}), - sequence{}); + auto dq_acc_dram_naive = make_naive_tensor_view( + dq_acc_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_q), + make_tuple(kargs.stride_dq_acc, 1), + number{}, + number<1>{}); + return pad_tensor_view(dq_acc_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + else + { + constexpr index_t m_pack = 2; // dword alignment for atomic 16 instr. + constexpr index_t mfma_m1_per_lane = 4; + constexpr index_t m1_pack_num = mfma_m1_per_lane / m_pack; + constexpr index_t mfma_n_lane = kGemm4WarpN; + constexpr index_t mfma_m_lane = get_warp_size() / mfma_n_lane; + constexpr index_t m_align_size = mfma_m1_per_lane * mfma_m_lane; + + static_assert( + kM0 % m_align_size == 0, + "tiling size in the m direction must be divisible by the m align size."); + + index_t M0 = (kargs.seqlen_q + m_align_size - 1) / m_align_size; + constexpr auto dq_acc_n = kQKHeaddim; + constexpr index_t N0 = dq_acc_n / mfma_n_lane; + + const auto q_grad_dram_desc_0 = make_naive_tensor_descriptor( + make_tuple(M0, + number{}, + number{}, + number{}, + number{}, + number{}), + make_tuple(number{}, + number{}, + number{}, + number{}, + number{}, + number<1>{}), + number{}, + number<1>{}); + + const auto q_grad_dram_desc = transform_tensor_descriptor( + q_grad_dram_desc_0, + make_tuple( + make_merge_transform(make_tuple(M0, + number{}, + number{}, + number{})), + make_merge_transform(make_tuple(number{}, number{}))), + make_tuple(sequence<0, 3, 2, 5>{}, sequence<1, 4>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + auto dq_acc_dram_view = make_tensor_view( + dq_acc_ptr, q_grad_dram_desc); + return pad_tensor_view( + dq_acc_dram_view, + make_tuple(number{}, number{}), + sequence{}); // we have already padded the dram buffer + // in headdim direction + } } }(); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp index 3da1104169..fad323f344 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp @@ -20,11 +20,15 @@ struct BlockFmhaBwdConvertQGrad static constexpr index_t kBlockPerCu = Problem::kBlockPerCu; static constexpr index_t kBlockSize = Problem::kBlockSize; static constexpr index_t kQKHeaddim = Problem::kQKHeaddim; + static constexpr index_t kGemm4WarpN = Problem::kGemm4WarpN; static constexpr bool kIsGroupMode = Problem::kIsGroupMode; static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; static constexpr bool kIsDeterministic = Problem::kIsDeterministic; + static constexpr bool kIsAtomic32 = Problem::kIsAtomic32; + + using QGradAccDataType = std::conditional_t; static constexpr index_t kAlignmentQGradAcc = kPadHeadDimQ ? 1 : Policy::template GetAlignmentPostQGradAcc(); @@ -40,7 +44,7 @@ struct BlockFmhaBwdConvertQGrad QGradDramBlockWindowTmp& dq_dram_block_window_tmp) const { static_assert( - std::is_same_v> && std::is_same_v>, @@ -48,16 +52,32 @@ struct BlockFmhaBwdConvertQGrad static_assert(kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}], "wrong!"); - auto dq_acc_dram_window = - make_tile_window(dq_acc_dram_block_window_tmp.get_bottom_tensor_view(), - dq_acc_dram_block_window_tmp.get_window_lengths(), - dq_acc_dram_block_window_tmp.get_window_origin(), - Policy::template MakePostQGradDramTileDistribution()); + if constexpr(kIsAtomic32) + { + auto dq_acc_dram_window = + make_tile_window(dq_acc_dram_block_window_tmp.get_bottom_tensor_view(), + dq_acc_dram_block_window_tmp.get_window_lengths(), + dq_acc_dram_block_window_tmp.get_window_origin(), + Policy::template MakePostQGradDramTileDistribution()); - auto dq_acc = load_tile(dq_acc_dram_window); - const auto dq = cast_tile(dq_acc); + auto dq_acc = load_tile(dq_acc_dram_window); + const auto dq = cast_tile(dq_acc); - store_tile(dq_dram_block_window_tmp, dq); + store_tile(dq_dram_block_window_tmp, dq); + } + else + { + auto dq_acc_dram_window = make_tile_window( + dq_acc_dram_block_window_tmp.get_bottom_tensor_view(), + dq_acc_dram_block_window_tmp.get_window_lengths(), + dq_acc_dram_block_window_tmp.get_window_origin(), + Policy::template MakePostQGradAccAtomic16DramTileDistribution()); + auto shuffled_dq = make_static_distributed_tensor( + Policy::template MakePostQGradAtomic16DramTileDistribution()); + auto dq_acc = load_tile(dq_acc_dram_window); + shuffle_tile(shuffled_dq, dq_acc); + store_tile(dq_dram_block_window_tmp, shuffled_dq); + } } // Reduce + Convert diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp index 5e63fb714a..377002ebaa 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp @@ -38,15 +38,16 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR static constexpr index_t kBlockPerCu = Problem::kBlockPerCu; static constexpr index_t kBlockSize = Problem::kBlockSize; - static constexpr index_t kM0 = BlockFmhaShape::kM0; - static constexpr index_t kN0 = BlockFmhaShape::kN0; - static constexpr index_t kK0 = BlockFmhaShape::kK0; - static constexpr index_t kK1 = BlockFmhaShape::kK1; - static constexpr index_t kK2 = BlockFmhaShape::kK2; - static constexpr index_t kK3 = BlockFmhaShape::kK3; - static constexpr index_t kK4 = BlockFmhaShape::kK4; - static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; - static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim; + static constexpr index_t kM0 = BlockFmhaShape::kM0; + static constexpr index_t kN0 = BlockFmhaShape::kN0; + static constexpr index_t kK0 = BlockFmhaShape::kK0; + static constexpr index_t kK1 = BlockFmhaShape::kK1; + static constexpr index_t kK2 = BlockFmhaShape::kK2; + static constexpr index_t kK3 = BlockFmhaShape::kK3; + static constexpr index_t kK4 = BlockFmhaShape::kK4; + static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; + static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim; + static constexpr index_t kGemm4WarpN = BlockFmhaShape::Gemm0WarpTile::at(ck_tile::number<1>{}); static constexpr bool kIsGroupMode = Problem::kIsGroupMode; static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; @@ -54,6 +55,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad; static constexpr bool kIsDeterministic = Problem::kIsDeterministic; + static constexpr bool kIsAtomic32 = Problem::kIsAtomic32; static constexpr bool kUseTrLoad = Problem::kUseTrLoad; static_assert(!kUseTrLoad, "This pipeline does not use trload!"); @@ -468,14 +470,26 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR {0, 0}, Policy::template MakeShuffledBiasTileDistribution()); - // ----------------------------Loop write out------------------------------// - auto dq_dram_window = make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(), - dq_dram_block_window_tmp.get_window_lengths(), - {seqlen_q_start, 0}); - using SPBlockTileType = decltype(gemm_0.MakeCBlockTile()); using SPGradBlockTileType = decltype(gemm_2.MakeCBlockTile()); using QGradBlockTileType = decltype(gemm_4.MakeCBlockTile()); + // ----------------------------Loop write out------------------------------// + auto dq_dram_window = [&]() { + if constexpr(kIsAtomic32) + { + return make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(), + dq_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, 0}); + } + else + { + return make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(), + dq_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, 0}, + decltype(cast_tile( + QGradBlockTileType{}))::get_tile_distribution()); + } + }(); index_t i_total_loops = 0; index_t seqlen_q_step = seqlen_q_start; @@ -750,7 +764,14 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR } else { - update_tile(dq_dram_window, dq_acc); + if constexpr(kIsAtomic32) + { + update_tile(dq_dram_window, dq_acc); + } + else + { + update_tile(dq_dram_window, cast_tile(dq_acc)); + } } move_tile_window(dq_dram_window, {kM0, 0}); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp index b883aad155..8e78f5cfa8 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp @@ -38,15 +38,16 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP static constexpr index_t kBlockPerCu = Problem::kBlockPerCu; static constexpr index_t kBlockSize = Problem::kBlockSize; - static constexpr index_t kM0 = BlockFmhaShape::kM0; - static constexpr index_t kN0 = BlockFmhaShape::kN0; - static constexpr index_t kK0 = BlockFmhaShape::kK0; - static constexpr index_t kK1 = BlockFmhaShape::kK1; - static constexpr index_t kK2 = BlockFmhaShape::kK2; - static constexpr index_t kK3 = BlockFmhaShape::kK3; - static constexpr index_t kK4 = BlockFmhaShape::kK4; - static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; - static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim; + static constexpr index_t kM0 = BlockFmhaShape::kM0; + static constexpr index_t kN0 = BlockFmhaShape::kN0; + static constexpr index_t kK0 = BlockFmhaShape::kK0; + static constexpr index_t kK1 = BlockFmhaShape::kK1; + static constexpr index_t kK2 = BlockFmhaShape::kK2; + static constexpr index_t kK3 = BlockFmhaShape::kK3; + static constexpr index_t kK4 = BlockFmhaShape::kK4; + static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; + static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim; + static constexpr index_t kGemm4WarpN = BlockFmhaShape::Gemm0WarpTile::at(ck_tile::number<1>{}); static constexpr bool kIsGroupMode = Problem::kIsGroupMode; static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; @@ -54,6 +55,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad; static constexpr bool kIsDeterministic = Problem::kIsDeterministic; + static constexpr bool kIsAtomic32 = Problem::kIsAtomic32; static constexpr bool kUseTrLoad = Problem::kUseTrLoad; static_assert(!kUseTrLoad, "This pipeline does not use trload!"); @@ -467,14 +469,26 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP {0, 0}, Policy::template MakeShuffledBiasTileDistribution()); - // ----------------------------Loop write out------------------------------// - auto dq_dram_window = make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(), - dq_dram_block_window_tmp.get_window_lengths(), - {seqlen_q_start, 0}); - using SPBlockTileType = decltype(gemm_0.MakeCBlockTile()); using SPGradBlockTileType = decltype(gemm_2.MakeCBlockTile()); using QGradBlockTileType = decltype(gemm_4.MakeCBlockTile()); + // ----------------------------Loop write out------------------------------// + auto dq_dram_window = [&]() { + if constexpr(kIsAtomic32) + { + return make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(), + dq_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, 0}); + } + else + { + return make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(), + dq_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, 0}, + decltype(cast_tile( + QGradBlockTileType{}))::get_tile_distribution()); + } + }(); index_t i_total_loops = 0; index_t seqlen_q_step = seqlen_q_start; @@ -792,8 +806,20 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP } else { - update_tile(dq_dram_window, dq_acc); + if constexpr(kIsAtomic32) + { + update_tile(dq_dram_window, dq_acc); + } + else + { + buffer_store_fence(); + update_tile_raw(dq_dram_window, + cast_tile(dq_acc), + number<-1>{}, + bool_constant{}); + } } + move_tile_window(dq_dram_window, {kM0, 0}); i_total_loops += 1; @@ -1027,14 +1053,24 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc); tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dk_acc); } - if constexpr(kIsDeterministic) { store_tile(dq_dram_window, dq_acc); } else { - update_tile(dq_dram_window, dq_acc); + if constexpr(kIsAtomic32) + { + update_tile(dq_dram_window, dq_acc); + } + else + { + buffer_store_fence(); + update_tile_raw(dq_dram_window, + cast_tile(dq_acc), + number<-1>{}, + bool_constant{}); + } } return make_tuple(dk_acc, dv_acc); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp index 9bd78b4077..7bdc17a480 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp @@ -54,8 +54,10 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad; static constexpr bool kIsDeterministic = Problem::kIsDeterministic; + static constexpr bool kIsAtomic32 = Problem::kIsAtomic32; static constexpr bool kUseTrLoad = Problem::kUseTrLoad; static_assert(kUseTrLoad, "This pipeline uses trload!"); + static_assert(kIsAtomic32, "This pipeline does not use atomic16!"); // last dimension vector length used to create tensor view(and decide buffer_load vector length) // ... together with tensor distribution. tensor dist should able to overwrite this diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp index 5adb64564d..2aa7123462 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp @@ -56,8 +56,10 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad; static constexpr bool kIsDeterministic = Problem::kIsDeterministic; + static constexpr bool kIsAtomic32 = Problem::kIsAtomic32; static constexpr bool kUseTrLoad = Problem::kUseTrLoad; static_assert(kUseTrLoad, "This pipeline uses trload!"); + static_assert(kIsAtomic32, "This pipeline does not use atomic16!"); // last dimension vector length used to create tensor view(and decide buffer_load vector length) // ... together with tensor distribution. tensor dist should able to overwrite this diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp index 68ead7c765..61b1440551 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp @@ -741,6 +741,56 @@ struct BlockFmhaBwdPipelineDefaultPolicy return dstr; } + template + CK_TILE_HOST_DEVICE static constexpr auto MakePostQGradAccAtomic16DramTileDistribution() + { + + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kMPerBlock = Problem::kM0; + constexpr index_t kNPerBlock = Problem::kQKHeaddim; + + constexpr index_t mPack = 2; // for b16 + constexpr index_t M1 = mPack; + constexpr index_t M0 = kMPerBlock / M1; + + constexpr index_t N0 = kBlockSize / get_warp_size(); + constexpr index_t N1 = get_warp_size() / M0; + constexpr index_t N2 = kNPerBlock / (N0 * N1); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<0, 1>>, + sequence<2, 1>, + sequence<2, 1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakePostQGradAtomic16DramTileDistribution() + { + + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kMPerBlock = Problem::kM0; + constexpr index_t kNPerBlock = Problem::kQKHeaddim; + + constexpr index_t mPack = 2; // for b16 + constexpr index_t M1 = mPack; + constexpr index_t M0 = kMPerBlock / M1; + + constexpr index_t N0 = kBlockSize / get_warp_size(); + constexpr index_t N1 = get_warp_size() / M0; + constexpr index_t N2 = kNPerBlock / (N0 * N1); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<0, 1>>, + sequence<1, 2>, + sequence<1, 2>>{}); + } + // these are for lds template CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackQ() diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp index 99718a187f..e150848f16 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp @@ -25,6 +25,7 @@ template struct BlockFmhaBwdConvertQGradPipelineProblem { @@ -115,8 +119,10 @@ struct BlockFmhaBwdConvertQGradPipelineProblem static constexpr index_t kM0 = kM0_; static constexpr index_t kN0 = kN0_; static constexpr index_t kQKHeaddim = kQKHeaddim_; + static constexpr index_t kGemm4WarpN = kGemm4WarpN_; static constexpr bool kIsGroupMode = kIsGroupMode_; static constexpr bool kIsDeterministic = kIsDeterministic_; + static constexpr bool kIsAtomic32 = kIsAtomic32_; // attributes from traits static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;