diff --git a/CHANGELOG.md b/CHANGELOG.md index fe1e7ef345..438320d907 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -32,7 +32,8 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj * Added benchmarking support for tile engine GEMM Multi D. * Added block scaling support in CK_TILE GEMM, allowing flexible use of quantization matrices from either A or B operands. * Added the row-wise column-wise quantization for CK_TILE GEMM & CK_TILE Grouped GEMM. -* Added tensor-wise quantization for CK_TILE GEMM +* Added support for f32 to FMHA (fwd/bwd). +* Added tensor-wise quantization for CK_TILE GEMM. ### Optimized diff --git a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py index 802c9e51d7..81d34484a5 100644 --- a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py +++ b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py @@ -1,8 +1,9 @@ # SPDX-License-Identifier: MIT -# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. # generate kernel instances to speed up compilation FWD_DTYPE_MAP = { + "fp32" : "FmhaFwdFp32", "fp16" : "FmhaFwdFp16", "bf16" : "FmhaFwdBf16", "fp8" : "FmhaFwdFp8", @@ -12,6 +13,7 @@ FWD_DTYPE_MAP = { } BWD_DTYPE_MAP = { + "fp32": "FmhaBwdFp32", "fp16": "FmhaBwdFp16", "bf16": "FmhaBwdBf16" } diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py index 0d8f366d8a..e2f69fa49a 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py @@ -601,6 +601,13 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl cond &= pipeline.F_squant == 'f' if not cond: continue + + # fp32 only + if receipt == 800 or receipt == 801: + cond = dtype == 'fp32' + if not cond: + continue + api_pool.register_traits(k.api_trait()) gen.append(k) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index bd6a9044e9..7319ef7ea1 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -370,7 +370,14 @@ class FmhaBwdDQDKDVKernel: # TODO: design a more practical way to do it # this is current supported tile size. def get_dq_dk_dv_tiles(dtype : str, tr_load: str) -> List[FmhaBwdDQDKDVTileSize]: - if (dtype == 'fp16' or dtype == 'bf16') and tr_load == 'f': + if dtype == 'fp32' and tr_load == 'f': + return [ + # bm0, bn0, bk0, bk1, bk2, bk3, bk4, bhdq, bhdv, + FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 16, 16, 16, 16, 1), + FmhaBwdDQDKDVTileSize( 16, 64, 64, 16, 64, 16, 16, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 16, 16, 16, 16, 1), + FmhaBwdDQDKDVTileSize( 16, 64, 128, 16, 128, 16, 16, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 16, 16, 16, 16, 1), + ] + elif (dtype == 'fp16' or dtype == 'bf16') and tr_load == 'f': return [ FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1), FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), @@ -865,6 +872,30 @@ def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[Fm cond = dtype in ['fp16', 'bf16'] if not cond: continue + + # fp32 only, all variations + if receipt == 800: + cond = dtype == 'fp32' + cond &= dpad == dvpad + if not cond: + continue + # fp32 only, minimal set of parameters + elif receipt == 801: + cond = dtype == 'fp32' + cond &= hdim in [64, 128] + cond &= dpad == dvpad + cond &= mode == 'batch' + cond &= bias == 'no' + cond &= dropout == 'no' + cond &= mask == 's_no' + cond &= deterministic == "f" + if not cond: + continue + else: + # Don't build fp32 by default + if dtype == 'fp32': + continue + gen_dot_do_o[t.dot_do_o_kernel] = True gen_dq_dk_dv[t.dq_dk_dv_kernel] = True if not t.convert_dq_kernel.disabled: diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index da0c9ca931..f898d5f7b2 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -25,6 +25,7 @@ DTYPE_BITS = { K0_MAX_SUBMAX_MAP = { 32 : 32, + 48 : 48, 64 : 64, 96 : 128, 128: 128, @@ -164,7 +165,7 @@ float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config& return get_num_thread_blocks(a.batch, a.nhead_q, a.max_seqlen_q, kM0); }}; - const bool has_load_tr = ck_tile::is_load_tr_supported(); + [[maybe_unused]] const bool has_load_tr = ck_tile::is_load_tr_supported(); {F_dispatch} return r; @@ -249,9 +250,8 @@ class FmhaFwdApiTrait: else : return f'a.seqlen_q % {self.bm0} == 0' else: assert False - @property - def seqtune(self) -> str: - if self.bm0 == 128: return 'true/*fall back to largest tile*/' # group mode only generate spad/skpad == true + def seqtune(self, max_bm0 : int) -> str: + if self.bm0 == max_bm0: return 'true/*fall back to largest tile*/' else: return f'a.seqlen_q <= {self.bm0}' @@ -386,6 +386,7 @@ class FmhaFwdApiPool: per_hdim_case=str() for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()): traits=[t for t in self.pool[dtype][(hdim, hdim_v)] if tr_load == t.tr_load] + max_bm0 = max((t.bm0 for t in traits), default=0) inners=str() for k, trait in enumerate(traits): if_k = 'if' if k == 0 else 'else if' @@ -393,7 +394,7 @@ class FmhaFwdApiPool: F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout], F_skip=BOOL_MAP[trait.skip], F_trload=BOOL_MAP[trait.tr_load], - F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_seqtune=trait.seqtune, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, + F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_seqtune=trait.seqtune(max_bm0), F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_constraint=trait.constraint, F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max, @@ -534,7 +535,20 @@ class KernelComponentFactory: # this is current supported tile size per hdim @staticmethod def get_hdim_tile_size_dict(dtype : str) -> Optional[dict]: - if dtype == 'fp16' or dtype == 'bf16': + if dtype == 'fp32': + return { + # bm0, bn0, bk0, bn1, bk1, + ( 32, 32) : [FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + ( 48, 48) : [FmhaFwdTileSize( 32, 128, 16, 48, 16, 48, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1), + FmhaFwdTileSize(128, 64, 16, 48, 32, 48, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + ( 64, 64) : [FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + ( 96, 128) : [FmhaFwdTileSize(128, 64, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + (128, 128) : [FmhaFwdTileSize( 32, 128, 32, 128, 16, 128, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1), + FmhaFwdTileSize(128, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + (192, 192) : [FmhaFwdTileSize( 64, 64, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + (256, 256) : [FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + } + elif dtype == 'fp16' or dtype == 'bf16': return { (32, 32) : [FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], (64, 64) : [FmhaFwdTileSize(16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), @@ -572,7 +586,13 @@ class KernelComponentFactory: # TODO: currently for qr pipeline, let 't' padding to appear later!! # TODO: how to design this more generic? pipelines = [] - if dtype in ['fp16', 'bf16']: + if dtype in ['fp32']: + squant = 'f' + for logits, mask, bias, lse, dropout, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"]): + pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 'f')) + pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 'f')) + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) + elif dtype in ['fp16', 'bf16']: squant = 'f' for logits, mask, bias, lse, dropout, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"]): if hdim == 256 and hdim_v == 256: @@ -626,6 +646,8 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl continue #for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): for ((hdim, hdim_v), tiles), mode in itertools.product(d.items(), MODE_MAP.keys()): + for tile, next_tile in zip(tiles, tiles[1:]): + assert next_tile.F_bm0 >= tile.F_bm0, 'Tiles must be ordered by increasing bm0' for tile, pipeline in itertools.product(tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl)): if mode == "group": if pipeline.F_spad != 't' or pipeline.F_skpad != 't': @@ -635,12 +657,13 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl # NOTE: this is used to speedup deepseek prefill case, we don't gen training if pipeline.F_bias != 'no' or pipeline.F_dropout == 't': continue - if pipeline.tag != 'qr_async_trload' and (((hdim, hdim_v) == (128, 128) and tile.F_bn0 != 128) or ((hdim, hdim_v) != (128, 128) and tile.F_bm0 != 128)): - # non qr_async_trload only support km0=128 tile size when hdim is not 128 - # non qr_async only support kn0=128 tile size when hdim is 128 - continue - if pipeline.tag == 'qr_async_trload' and (((hdim, hdim_v) == (128, 128) and tile.F_bn0 == 128) or ((hdim, hdim_v) not in [(64, 64), (128, 128)])): - continue + if dtype != 'fp32': + if pipeline.tag != 'qr_async_trload' and (((hdim, hdim_v) == (128, 128) and tile.F_bn0 != 128) or ((hdim, hdim_v) != (128, 128) and tile.F_bm0 != 128)): + # non qr_async_trload only support km0=128 tile size when hdim is not 128 + # non qr_async only support kn0=128 tile size when hdim is 128 + continue + if pipeline.tag == 'qr_async_trload' and (((hdim, hdim_v) == (128, 128) and tile.F_bn0 == 128) or ((hdim, hdim_v) not in [(64, 64), (128, 128)])): + continue # logits_soft_cap is only allowed if no bias if not ((pipeline.F_logits == 't' and pipeline.F_bias == 'no') or pipeline.F_logits == 'f'): continue @@ -710,6 +733,31 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl if not cond: continue + # fp32 only, all variations + if receipt == 800: + cond = dtype == 'fp32' + cond &= pipeline.F_skip == 'f' + cond &= pipeline.F_logits == 'f' + if not cond: + continue + # fp32 only, minimal set of parameters + elif receipt == 801: + cond = dtype == 'fp32' + cond &= hdim in [48, 128] + cond &= mode == 'batch' + cond &= pipeline.F_bias == 'no' + cond &= pipeline.F_lse == 'f' + cond &= pipeline.F_dropout == 'f' + cond &= pipeline.F_skip == 'f' + cond &= pipeline.F_logits == 'f' + cond &= pipeline.F_mask == 's_no' + if not cond: + continue + else: + # Don't build fp32 by default + if dtype == 'fp32': + continue + api_pool.register_traits(k.api_trait()) gen.append(k) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py index 0ebeaddf9c..38491b56c4 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py @@ -184,6 +184,9 @@ class FmhaFwdAppendKVApiPool: per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=hdim, F_inner_dispatch=inners) if_i = 'if' if i == 0 else 'else if' per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) + if not per_dtypes: + # empty string we add some ignore to suppress warning in api + per_dtypes += ' (void)t ; (void)s ; (void)a;' return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_APPENDKV_API.format(F_dispatch = per_dtypes) @dataclass @@ -341,6 +344,13 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl, op cond &= pipeline.F_vlayout == 'row' if not cond: continue + + # fp32 only + if receipt == 800 or receipt == 801: + cond = dtype == 'fp32' + if not cond: + continue + api_pool.register_traits(k.api_trait()) gen.append(k) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index cee1505486..281357ef1e 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -768,6 +768,13 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl, opt cond &= pipeline.F_squant == 'f' if not cond: continue + + # fp32 only + if receipt == 800 or receipt == 801: + cond = dtype == 'fp32' + if not cond: + continue + api_pool.register_traits(k.api_trait()) gen.append(k) @@ -834,6 +841,13 @@ def get_fwd_splitkv_combine_blobs(kernel_filter : Optional[str], receipt, optdim cond = dtype in ['fp16', 'bf16'] if not cond: continue + + # fp32 only + if receipt == 800 or receipt == 801: + cond = dtype == 'fp32' + if not cond: + continue + gen.append(k) return gen diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py b/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py index df6b422981..3624b7b387 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py @@ -560,6 +560,12 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl if not cond: continue + # fp32 only + if receipt == 800 or receipt == 801: + cond = dtype == 'fp32' + if not cond: + continue + api_pool.register_traits(k.api_trait()) gen.append(k) diff --git a/example/ck_tile/01_fmha/example_fmha_bwd.cpp b/example/ck_tile/01_fmha/example_fmha_bwd.cpp index e0e1fba668..73b3c1e619 100644 --- a/example/ck_tile/01_fmha/example_fmha_bwd.cpp +++ b/example/ck_tile/01_fmha/example_fmha_bwd.cpp @@ -43,7 +43,7 @@ auto create_args(int argc, char* argv[]) "e(lementwise) or 1, elementwise bias with 1*1*s*s. e:1, 1*h*s*s. e:2, b*h*s*s\n" "a(libi) or 2, alibi with 1*h. a:1, b*h") .insert("dbias", "0", "output bias gradient or not") - .insert("prec", "fp16", "data type. fp16 or bf16") + .insert("prec", "fp16", "data type. fp32/fp16/bf16") .insert("mask", "0", "0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b')\n" @@ -159,7 +159,11 @@ int main(int argc, char* argv[]) return -1; const std::string data_type = arg_parser.get_str("prec"); - if(data_type == "fp16") + if(data_type == "fp32") + { + return run(arg_parser) == bwd_result::success ? 0 : -2; + } + else if(data_type == "fp16") { return run(arg_parser) == bwd_result::success ? 0 : -2; } diff --git a/example/ck_tile/01_fmha/example_fmha_fwd.cpp b/example/ck_tile/01_fmha/example_fmha_fwd.cpp index 79fda6d564..c27a5ce1ae 100644 --- a/example/ck_tile/01_fmha/example_fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/example_fmha_fwd.cpp @@ -67,7 +67,7 @@ auto create_args(int argc, char* argv[]) "n or 0, no bias\n" "e(lementwise) or 1, elementwise bias with 1*1*s*s. e:1, 1*h*s*s. e:2, b*h*s*s\n" "a(libi) or 2, alibi with 1*h. a:1, b*h") - .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8") + .insert("prec", "fp16", "data type. fp32/fp16/bf16/fp8/bf8") .insert("mask", "0", "0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b')\n" @@ -227,7 +227,11 @@ int main(int argc, char* argv[]) return -1; const std::string data_type = arg_parser.get_str("prec"); - if(data_type == "fp16") + if(data_type == "fp32") + { + return run(arg_parser) == fwd_result::success ? 0 : -2; + } + else if(data_type == "fp16") { return run(arg_parser) == fwd_result::success ? 0 : -2; } diff --git a/example/ck_tile/01_fmha/fmha_bwd_runner.hpp b/example/ck_tile/01_fmha/fmha_bwd_runner.hpp index d861b351d4..b6f2c8ca30 100644 --- a/example/ck_tile/01_fmha/fmha_bwd_runner.hpp +++ b/example/ck_tile/01_fmha/fmha_bwd_runner.hpp @@ -35,6 +35,14 @@ auto get_elimit(ck_tile::index_t /*hdim_q*/, ck_tile::index_t /*hdim_v*/) return ck_tile::make_tuple(rtol, atol); } +template <> +auto get_elimit(ck_tile::index_t /*hdim_q*/, ck_tile::index_t /*hdim_v*/) +{ + double rtol = 1e-4; + double atol = 1e-4; + return ck_tile::make_tuple(rtol, atol); +} + template <> auto get_elimit(ck_tile::index_t hdim_q, ck_tile::index_t hdim_v) { @@ -77,7 +85,9 @@ bwd_result fmha_bwd_run(mode_enum mode, std::optional json = std::nullopt) { const std::string data_type = []() { - if constexpr(std::is_same_v) + if constexpr(std::is_same_v) + return "fp32"; + else if constexpr(std::is_same_v) return "fp16"; else if constexpr(std::is_same_v) return "bf16"; @@ -776,7 +786,7 @@ bwd_result fmha_bwd_run(mode_enum mode, // non-deterministic kernels use atomic add to write dq // Some block may be skipped with causal mask and dq are not set to zeros // In these cases thus we need to zero out it first - if(!deterministic || mask.type == mask_enum::no_mask) + if(!deterministic || mask.type != mask_enum::no_mask) dq_acc_buf.SetZero(); ck_tile::stream_config stream_config_v{nullptr, true, 0, 0, 1}; diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index f5dd42a6bd..761def6d6a 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -17,6 +17,10 @@ #include #include +struct FmhaFwdFp32 +{ +}; + struct FmhaFwdFp16 { }; @@ -48,6 +52,22 @@ struct FmhaFwdFp8Fp32 template struct FmhaFwdTypeConfig; +template <> +struct FmhaFwdTypeConfig +{ + using QDataType = float; + using KDataType = float; + using VDataType = float; + using BiasDataType = float; + using RandValOutputDataType = uint8_t; + using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax + using PDataType = float; // data type for A matrix of second gemm + using OaccDataType = float; // data type for second gemm accumulation + using ODataType = float; +}; + template <> struct FmhaFwdTypeConfig { diff --git a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp index e58e040f19..0703af71e3 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp @@ -41,6 +41,14 @@ auto get_elimit(std::string /*init_method*/) return ck_tile::make_tuple(rtol, atol); } +template <> +auto get_elimit(std::string /*init_method*/) +{ + double rtol = 1e-5; + double atol = 1e-5; + return ck_tile::make_tuple(rtol, atol); +} + template <> auto get_elimit(std::string /*init_method*/) { @@ -180,7 +188,9 @@ fwd_result fmha_fwd_run(mode_enum mode, std::optional json = std::nullopt) { const std::string data_type = []() { - if constexpr(std::is_same_v) + if constexpr(std::is_same_v) + return "fp32"; + else if constexpr(std::is_same_v) return "fp16"; else if constexpr(std::is_same_v) return "bf16"; diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp index 7bc5ca5df8..de3427c33d 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -470,7 +470,7 @@ struct buffer_store<16> index_t /*flag*/ = 1) { static_assert(sizeof(T) == 16); - using mbuf_t = fp32x4_t; + using mbuf_t = uint32x4_t; #if HAS_RAW_BUFFER_BUILTINS index_t s_offset = i_offset; __builtin_amdgcn_raw_buffer_store_b128( @@ -496,7 +496,7 @@ struct buffer_store<8> index_t /*flag*/ = 1) { static_assert(sizeof(T) == 8); - using mbuf_t = fp32x2_t; + using mbuf_t = uint32x2_t; #if HAS_RAW_BUFFER_BUILTINS index_t s_offset = i_offset; __builtin_amdgcn_raw_buffer_store_b64( @@ -522,7 +522,7 @@ struct buffer_store<4> index_t /*flag*/ = 1) { static_assert(sizeof(T) == 4); - using mbuf_t = float; + using mbuf_t = uint32_t; #if HAS_RAW_BUFFER_BUILTINS index_t s_offset = i_offset; __builtin_amdgcn_raw_buffer_store_b32( @@ -548,7 +548,7 @@ struct buffer_store<2> index_t /*flag*/ = 1) { static_assert(sizeof(T) == 2); - using mbuf_t = short; + using mbuf_t = uint16_t; #if HAS_RAW_BUFFER_BUILTINS index_t s_offset = i_offset; __builtin_amdgcn_raw_buffer_store_b16( @@ -573,8 +573,8 @@ struct buffer_store<1> index_t i_offset /*max 0xFFF*/, index_t /*flag*/ = 1) { - static_assert(sizeof(T) == 4); - using mbuf_t = float; + static_assert(sizeof(T) == 1); + using mbuf_t = uint8_t; #if HAS_RAW_BUFFER_BUILTINS index_t s_offset = i_offset; __builtin_amdgcn_raw_buffer_store_b8( diff --git a/include/ck_tile/core/utility/philox_rand.hpp b/include/ck_tile/core/utility/philox_rand.hpp index 87abf5cc18..52b1489543 100644 --- a/include/ck_tile/core/utility/philox_rand.hpp +++ b/include/ck_tile/core/utility/philox_rand.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -55,7 +55,8 @@ class philox CK_TILE_HOST_DEVICE void get_random_8x8(uint8_t* out, const unsigned long long subsequence, - const index_t start_idx) const + const index_t idx0, + const index_t idx1) const { uint4 tmp_ph; tmp_ph = get_philox_4x32(subsequence); @@ -66,13 +67,12 @@ class philox tmp[2] = tmp_ph.z; tmp[3] = tmp_ph.w; uint32_t* out_tmp = reinterpret_cast(&out[0]); - out_tmp[0] = tmp[start_idx]; - out_tmp[1] = tmp[start_idx + 2]; + out_tmp[0] = tmp[idx0]; + out_tmp[1] = tmp[idx1]; } - CK_TILE_HOST_DEVICE void get_random_4x8(uint8_t* out, - const unsigned long long subsequence, - const index_t start_idx) const + CK_TILE_HOST_DEVICE void + get_random_4x8(uint8_t* out, const unsigned long long subsequence, const index_t idx) const { uint4 tmp_ph; tmp_ph = get_philox_4x32(subsequence); @@ -83,7 +83,7 @@ class philox tmp[2] = tmp_ph.z; tmp[3] = tmp_ph.w; uint32_t* out_tmp = reinterpret_cast(&out[0]); - out_tmp[0] = tmp[start_idx]; + out_tmp[0] = tmp[idx]; } private: diff --git a/include/ck_tile/core/utility/transpose_vectors.hpp b/include/ck_tile/core/utility/transpose_vectors.hpp index 497fd3b948..f0d7dae706 100644 --- a/include/ck_tile/core/utility/transpose_vectors.hpp +++ b/include/ck_tile/core/utility/transpose_vectors.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -34,7 +34,13 @@ struct transpose_vectors constexpr auto I3 = number<3>{}; constexpr auto I4 = number<4>{}; - if constexpr(sizeof(S) == 2) + if constexpr(sizeof(S) == 4) + { + static_for<0, NY, 1>{}([&](auto iy) { + static_for<0, NX, 1>{}([&](auto ix) { vy_tuple(iy)(ix) = vx_tuple[ix][iy]; }); + }); + } + else if constexpr(sizeof(S) == 2) { static_assert((NX % 2 == 0 && NY % 2 == 0), "wrong!"); diff --git a/include/ck_tile/host/reference/reference_batched_dropout_randval.hpp b/include/ck_tile/host/reference/reference_batched_dropout_randval.hpp index 2a02adaee3..ec6c6009b7 100644 --- a/include/ck_tile/host/reference/reference_batched_dropout_randval.hpp +++ b/include/ck_tile/host/reference/reference_batched_dropout_randval.hpp @@ -33,18 +33,22 @@ reference_batched_dropout_randval(HostTensor& randval_b_m // With SFactor = 2 it becomes: // C i: (16 * floor(GPR_num / 8) % 32) + 8 * floor(lane / 32) + (GPR_num % 8) // C j: (lane % 32) + // See ck_tile/ops/fmha/block/block_dropout.hpp for more details. - constexpr index_t max_warp_size = 64; - constexpr index_t warp_gemm_mn = 32; + // The number of Philox 4x32 results required to fill 32x32 tile of 8-bit values + constexpr index_t philox_per_tile = 64; + constexpr index_t warp_gemm_mn = 32; const index_t rows = integer_divide_ceil(real_seqlen_q, warp_gemm_mn); const index_t cols = integer_divide_ceil(real_seqlen_k, warp_gemm_mn); auto f = [&](index_t i_h, index_t row, index_t col) { uint2 rowcol = make_uint2(row, col); - for(index_t lane = 0; lane < max_warp_size; lane++) + for(index_t lane = 0; lane < philox_per_tile; lane++) { - philox ph(drop_seed, drop_offset + (batch * nhead + i_h) * max_warp_size + lane); + const uint64_t ph_head_offset = drop_offset + (batch * nhead + i_h) * philox_per_tile; + const index_t ph_offset = lane; + philox ph(drop_seed, ph_head_offset + ph_offset); uint8_t random_uint8_t[16]; ph.get_random_16x8(random_uint8_t, reinterpret_cast(rowcol)); diff --git a/include/ck_tile/ops/fmha/block/block_dropout.hpp b/include/ck_tile/ops/fmha/block/block_dropout.hpp index e036402e16..8abdd54cd9 100644 --- a/include/ck_tile/ops/fmha/block/block_dropout.hpp +++ b/include/ck_tile/ops/fmha/block/block_dropout.hpp @@ -1,17 +1,44 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "ck_tile/core.hpp" -#include "ck_tile/ops/gemm/warp/warp_gemm.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" namespace ck_tile { +// BlockDropoutBwd and BlockDropout (fwd) support two warp gemm tile sizes: 32x32 (MFMA only) and +// 16x16 (MFMA and WMMA). Even if fwd and bwd use different tile sizes, generated random +// numbers will be the same, they are also the same for MFMA (on CDNA), WMMA (on RDNA), or host +// (for verification, see ck_tile/host/reference/reference_batched_dropout_randval.hpp). +// +// The (row, col) coordinate of the current 32x32 tile in the P matrix determines a subsequence of +// random numbers (ph_subsequence). +// The (batch, head, 0..63) coordinate determines an offset in the subsequence (ph_head_offset and +// ph_offset). +// This means that subsequences are non-overlapping, reproducible and independent of mask or window. +// +// There are 3 modes (all produce the same results): +// * For 32x32 MFMA tile each of 64 lanes generates 4 * 32 bits or 16 bytes, so one warp generates +// the entire 32x32 tile (64 * 16 = 32 * 32). +// * For 16x16 MFMA tile one warp generates 1/4 of the 32x32 tile ((16 * 16) / (64 * 16) = 1/4), 4 +// warps generate the same 64 * 16 random bytes and each uses its own quarter. If kMPerBlock > +// MWarp * WG::kM one warp can generate two 16x16 tiles (MIterPerWarp = 2) so fewer instructions +// are needed for generating a 32x32 tile. +// * For 16x16 WMMA tile one warp generates 1/2 of the 32x32 tile ((16 * 16) / (32 * 16) = 1/2), 2 +// warps generate the same 64 * 16 random bytes and each uses its own half. If kMPerBlock > MWarp * +// WG::kM one warp can generate two 16x16 tiles. + +namespace detail { +// The number of Philox 4x32 results required to fill 32x32 tile of 8-bit values +constexpr index_t philox_per_tile = 64; +} // namespace detail + struct NullBlockDropout { template - __host__ __device__ static constexpr auto + CK_TILE_HOST_DEVICE static constexpr auto MakeRandvalDramWindow(RandValDramBlockWindowTmp& randval_dram_block_window_tmp, index_t seqlen_qk_start) { @@ -32,7 +59,9 @@ struct BlockDropout float rp_undrop_, uint8_t p_undrop_in_uint8_t_, bool is_store_randval_) - : ph(seed, offset + (i_batch * nheads + i_head) * get_warp_size() + get_lane_id()), + : ph_seed(amd_wave_read_first_lane(seed)), + ph_head_offset(amd_wave_read_first_lane(offset + (i_batch * nheads + i_head) * + detail::philox_per_tile)), rp_undrop(rp_undrop_), p_undrop_in_uint8_t(p_undrop_in_uint8_t_), is_store_randval(is_store_randval_) @@ -46,11 +75,15 @@ struct BlockDropout { constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); - using WG = remove_cvref_t())>; - constexpr index_t MWarp = config.template at<1>(); - constexpr index_t NWarp = config.template at<2>(); - constexpr index_t kMPerStep = MWarp * WG::kM; - constexpr index_t kNPerStep = NWarp * WG::kN; + using WG = remove_cvref_t())>; + constexpr bool IsWG32 = WG::kM == 32; + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + using BlockGemmShape = remove_cvref_t; + constexpr index_t kMPerBlock = BlockGemmShape::kM; + constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1; + constexpr index_t kMPerStep = MIterPerWarp * MWarp * WG::kM; + constexpr index_t kNPerStep = NWarp * WG::kN; const auto block_origin = randval_dram_block_window_tmp.get_window_origin(); auto randval_dram_window = [&]() { @@ -78,12 +111,17 @@ struct BlockDropout { constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); - using WG = remove_cvref_t())>; - constexpr index_t MWarp = config.template at<1>(); - constexpr index_t kMPerStep = MWarp * WG::kM; - constexpr index_t kNPerStep = WG::kN; - constexpr index_t kN1 = 8; - constexpr index_t kN0 = kNPerStep / kN1; + using WG = remove_cvref_t())>; + constexpr bool IsWG32 = WG::kM == 32; + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + using BlockGemmShape = remove_cvref_t; + constexpr index_t kMPerBlock = BlockGemmShape::kM; + constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1; + constexpr index_t kMPerStep = MIterPerWarp * MWarp * WG::kM; + constexpr index_t kNPerStep = NWarp * WG::kN; + constexpr index_t kN1 = 8; + constexpr index_t kN0 = kNPerStep / kN1; constexpr auto randval_lds_block_desc_0 = make_naive_tensor_descriptor( ck_tile::make_tuple(number{}, number{}, number{}), @@ -107,33 +145,35 @@ struct BlockDropout { constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); - constexpr index_t MWarp = config.template at<1>(); - constexpr index_t NWarp = config.template at<2>(); - - constexpr index_t MIterPerWarp = 1; + using WG = remove_cvref_t())>; + constexpr bool IsWG32 = WG::kM == 32; + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + using BlockGemmShape = remove_cvref_t; + constexpr index_t kMPerBlock = BlockGemmShape::kM; + constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1; constexpr index_t NIterPerWarp = 1; + // The tile distribution is different from the one in MakeRandValLdsShuffleTileDistribution, + // because it can combine 2 (MIterPerWarp) 16x16 subtiles for generating them at once constexpr auto randval_block_outer_part_dstr_encoding = tile_distribution_encoding< sequence<>, - tuple, sequence>, + tuple, sequence>, tuple>, - tuple>, + tuple>, sequence<1, 2>, - sequence<0, 0>>{}; + sequence<1, 0>>{}; // Use Bwd WarpGemm to ensure that Fwd's random values ​​are consistent with Bwd. - constexpr auto randval_block_inner_part_dstr_encoding = []() { - if constexpr(std::is_same_v && - std::is_same_v && - std::is_same_v) - { - return typename WarpGemmMfmaF16F16F32M32N32K16SwizzleA::CWarpDstrEncoding{}; - } - else - { - return typename WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA::CWarpDstrEncoding{}; - } - }(); + constexpr auto randval_block_inner_part_dstr_encoding = + typename WarpGemmDispatcher::CWarpDstrEncoding{}; constexpr auto randval_block_part_dstr_encode = detail::make_embed_tile_distribution_encoding(randval_block_outer_part_dstr_encoding, @@ -147,11 +187,13 @@ struct BlockDropout { constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); - using WG = remove_cvref_t())>; - constexpr index_t MWarp = config.template at<1>(); - constexpr index_t NWarp = config.template at<2>(); - - constexpr index_t MIterPerWarp = 1; + using WG = remove_cvref_t())>; + constexpr bool IsWG32 = WG::kM == 32; + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + using BlockGemmShape = remove_cvref_t; + constexpr index_t kMPerBlock = BlockGemmShape::kM; + constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1; constexpr index_t NIterPerWarp = 1; constexpr auto randval_block_outer_part_dstr_encoding = tile_distribution_encoding< @@ -181,14 +223,16 @@ struct BlockDropout { constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); - using WG = remove_cvref_t())>; - constexpr index_t MWarp = config.template at<1>(); - constexpr index_t NWarp = config.template at<2>(); - using BlockGemmShape = remove_cvref_t; - constexpr index_t kMPerBlock = BlockGemmShape::kM; - constexpr index_t kNPerBlock = BlockGemmShape::kN; - constexpr index_t kMPerStep = MWarp * WG::kM; - constexpr index_t kNPerStep = NWarp * WG::kN; + using WG = remove_cvref_t())>; + constexpr bool IsWG32 = WG::kM == 32; + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + using BlockGemmShape = remove_cvref_t; + constexpr index_t kMPerBlock = BlockGemmShape::kM; + constexpr index_t kNPerBlock = BlockGemmShape::kN; + constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1; + constexpr index_t kMPerStep = MIterPerWarp * MWarp * WG::kM; + constexpr index_t kNPerStep = NWarp * WG::kN; // randval tile in LDS auto randval_lds = make_tensor_view( @@ -200,42 +244,100 @@ struct BlockDropout // register distribute auto randval_dist_generated = make_static_distributed_tensor(MakeRandValTileDistribution()); - static_assert(randval_dist_generated.kThreadElementSpaceSize == 16); - auto randval_lds_read_window = + const auto randval_lds_read_window = make_tile_window(randval_lds_window.get_bottom_tensor_view(), randval_lds_window.get_window_lengths(), randval_lds_window.get_window_origin(), MakeRandValLdsShuffleTileDistribution()); - const int start_m0_idx = randval_dram_window.get_window_origin().at(number<0>{}); + const index_t start_m0_idx = randval_dram_window.get_window_origin().at(number<0>{}); + const index_t iMWarp = get_warp_id() / NWarp; + const index_t iNWarp = get_warp_id() % NWarp; + + auto generate_randval = [&](auto i_m0, auto i_n0) { + // Generate random numbers + uint8_t random_uint8_t[randval_dist_generated.kThreadElementSpaceSize]; + const index_t wg_m0 = (start_m0_idx / WG::kM) + (i_m0 * MWarp + iMWarp) * MIterPerWarp; + const index_t wg_n0 = (start_n0_idx / WG::kN) + (i_n0 * NWarp + iNWarp); + if constexpr(IsWG32) + { + // Generate the whole 32x32 tile at once (each tile consists of random numbers taken + // from a separate subsequence of Philox) + const unsigned long long ph_subsequence = + bit_cast(make_uint2(wg_m0, wg_n0)); + const index_t ph_offset = get_lane_id(); + const ck_tile::philox ph(ph_seed, ph_head_offset + ph_offset); + static_assert(randval_dist_generated.kThreadElementSpaceSize == 16); + ph.get_random_16x8(random_uint8_t, ph_subsequence); + } + else + { + // Generate one or two 16x16 subtiles of the 32x32 tile (depending on whether + // MIterPerWarp is equal to 1 or 2) + const unsigned long long ph_subsequence = + bit_cast(make_uint2(wg_m0 / 2, wg_n0 / 2)); + const index_t subtile_m0 = wg_m0 % 2; + if constexpr(get_warp_size() == 32) + { + const index_t ph_offset = (get_lane_id() & 15) + + (((get_lane_id() >> 4) & 1) << 5) + + ((wg_n0 % 2) << 4); + const ck_tile::philox ph(ph_seed, ph_head_offset + ph_offset); + if constexpr(MIterPerWarp == 1) + { + static_assert(randval_dist_generated.kThreadElementSpaceSize == 8); + ph.get_random_8x8( + random_uint8_t, ph_subsequence, subtile_m0 * 2 + 0, subtile_m0 * 2 + 1); + } + else + { + static_assert(randval_dist_generated.kThreadElementSpaceSize == 16); + ph.get_random_16x8(random_uint8_t, ph_subsequence); + } + } + else + { + const index_t subtile_n0 = (get_lane_id() >> 4) & 1; + const index_t ph_offset = (get_lane_id() & 47) + ((wg_n0 % 2) << 4); + const ck_tile::philox ph(ph_seed, ph_head_offset + ph_offset); + if constexpr(MIterPerWarp == 1) + { + static_assert(randval_dist_generated.kThreadElementSpaceSize == 4); + ph.get_random_4x8( + random_uint8_t, ph_subsequence, subtile_m0 * 2 + subtile_n0); + } + else + { + static_assert(randval_dist_generated.kThreadElementSpaceSize == 8); + ph.get_random_8x8( + random_uint8_t, ph_subsequence, 0 * 2 + subtile_n0, 1 * 2 + subtile_n0); + } + } + } + + constexpr auto randval_dist_generated_spans = + decltype(randval_dist_generated)::get_distributed_spans(); + int i_random_idx = 0; + sweep_tile_span(randval_dist_generated_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(randval_dist_generated_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = ck_tile::make_tuple(idx0, idx1); + randval_dist_generated(i_j_idx) = random_uint8_t[i_random_idx++]; + }); + }); + // Transpose randval using LDS + store_tile(randval_lds_window, randval_dist_generated); + block_sync_lds(); + const auto randval = load_tile(randval_lds_read_window); + block_sync_lds(); + return randval; + }; + if(is_store_randval) { static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) { static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) { - int block_row_start = (start_m0_idx / WG::kM) + (i_m0 * MWarp) + get_warp_id(); - int block_col_start = (start_n0_idx / WG::kN) + i_n0; - uint2 rowcol = make_uint2(block_row_start, block_col_start); - - // generate random number - uint8_t random_uint8_t[16]; - ph.get_random_16x8(random_uint8_t, - reinterpret_cast(rowcol)); - - constexpr auto randval_dist_generated_spans = - decltype(randval_dist_generated)::get_distributed_spans(); - int i_random_idx = 0; - sweep_tile_span(randval_dist_generated_spans[number<0>{}], [&](auto idx0) { - sweep_tile_span(randval_dist_generated_spans[number<1>{}], [&](auto idx1) { - constexpr auto i_j_idx = ck_tile::make_tuple(idx0, idx1); - randval_dist_generated(i_j_idx) = random_uint8_t[i_random_idx++]; - }); - }); - // save to LDS - store_tile(randval_lds_window, randval_dist_generated); - block_sync_lds(); - // read from LDS to register - auto randval = load_tile(randval_lds_read_window); + const auto randval = generate_randval(i_m0, i_n0); // save to Global const auto randval_store = cast_tile(randval); store_tile(randval_dram_window, randval_store); @@ -244,37 +346,21 @@ struct BlockDropout move_tile_window(randval_dram_window, {kMPerStep, -kNPerBlock}); }); move_tile_window(randval_dram_window, {-kMPerBlock, kNPerBlock}); - }; + } static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) { static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) { - int block_row_start = (start_m0_idx / WG::kM) + (i_m0 * MWarp) + get_warp_id(); - int block_col_start = (start_n0_idx / WG::kN) + i_n0; - uint2 rowcol = make_uint2(block_row_start, block_col_start); - - // generate random number - uint8_t random_uint8_t[16]; - ph.get_random_16x8(random_uint8_t, reinterpret_cast(rowcol)); - - constexpr auto randval_dist_generated_spans = - decltype(randval_dist_generated)::get_distributed_spans(); - int i_random_idx = 0; - sweep_tile_span(randval_dist_generated_spans[number<0>{}], [&](auto idx0) { - sweep_tile_span(randval_dist_generated_spans[number<1>{}], [&](auto idx1) { - constexpr auto i_j_idx = ck_tile::make_tuple(idx0, idx1); - randval_dist_generated(i_j_idx) = random_uint8_t[i_random_idx++]; - }); - }); - // save to LDS - store_tile(randval_lds_window, randval_dist_generated); - block_sync_lds(); - // read from LDS to register - auto randval = load_tile(randval_lds_read_window); + const auto randval = generate_randval(i_m0, i_n0); + // Drop values of P based on the generated probabilities constexpr auto randval_spans = decltype(randval)::get_distributed_spans(); sweep_tile_span(randval_spans[number<0>{}], [&](auto idx0) { sweep_tile_span(randval_spans[number<1>{}], [&](auto idx1) { - constexpr auto p_idx0 = tile_distributed_index{}; + constexpr auto p_idx0 = + tile_distributed_index()>{}; constexpr auto p_idx1 = - tile_distributed_index{}; + tile_distributed_index(), + idx1.impl_.template at<2>()>{}; constexpr auto p_idx = ck_tile::make_tuple(p_idx0, p_idx1); constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1); p_compute(p_idx) = randval[r_idx] <= p_undrop_in_uint8_t @@ -286,12 +372,15 @@ struct BlockDropout }); } - ck_tile::philox ph; + const unsigned long long ph_seed; + const unsigned long long ph_head_offset; const float rp_undrop; const uint8_t p_undrop_in_uint8_t; const bool is_store_randval; }; +// TODO: IsWG32_ is not needed as template parameter and can be removed. IsDropout_ == false can be +// replaced with NullBlockDropout. This requires changes in xformers and other libs. template struct BlockDropoutBwd; @@ -301,8 +390,8 @@ struct BlockDropoutBwd static constexpr bool IsDropout = false; static constexpr bool IsStoreRandval = IsStoreRandval_; - template - __host__ __device__ static constexpr auto + template + CK_TILE_HOST_DEVICE static constexpr auto MakeRandvalDramWindow(RandValDramBlockWindowTmp& randval_dram_block_window_tmp, index_t seqlen_qk_start) { @@ -316,10 +405,7 @@ struct BlockDropoutBwd template struct BlockDropoutBwd { - static constexpr bool IsDropout = true; - // true: 32*32 warp gemm - // false: 16*16 warp gemm - static constexpr bool IsWG32 = IsWG32_; + static constexpr bool IsDropout = true; static constexpr bool IsStoreRandval = IsStoreRandval_; CK_TILE_HOST_DEVICE BlockDropoutBwd(index_t i_batch, @@ -329,38 +415,30 @@ struct BlockDropoutBwd unsigned long long offset, float rp_undrop_, uint8_t p_undrop_in_uint8_t_) - : ph(seed, - offset + (i_batch * nheads + i_head) * get_warp_size() + - (IsWG32 ? get_lane_id() : ((get_lane_id() & 47) + ((get_warp_id() & 1) << 4)))), + : ph_seed(amd_wave_read_first_lane(seed)), + ph_head_offset(amd_wave_read_first_lane(offset + (i_batch * nheads + i_head) * + detail::philox_per_tile)), rp_undrop(rp_undrop_), p_undrop_in_uint8_t(p_undrop_in_uint8_t_) { } - template + template CK_TILE_HOST_DEVICE static constexpr auto MakeRandvalDramWindow(RandValDramBlockWindowTmp& randval_dram_block_window_tmp, index_t seqlen_qk_start) { constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); - using BlockGemmShape = remove_cvref_t; - using WG = remove_cvref_t())>; - constexpr index_t kMPerBlock = BlockGemmShape::kM; - constexpr index_t MWarp = config.template at<1>(); - constexpr index_t NWarp = config.template at<2>(); - constexpr bool MBwdWG16MultiIterCheck = (!IsFwd) && (!IsWG32) && (kMPerBlock > 16); - constexpr index_t kMPerStep = [&]() { - if constexpr(MBwdWG16MultiIterCheck) - { - return MWarp * WG::kM * 2; - } - else - { - return MWarp * WG::kM; - } - }(); - constexpr index_t kNPerStep = NWarp * WG::kN; + using WG = remove_cvref_t())>; + constexpr bool IsWG32 = WG::kM == 32; + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + using BlockGemmShape = remove_cvref_t; + constexpr index_t kMPerBlock = BlockGemmShape::kM; + constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1; + constexpr index_t kMPerStep = MIterPerWarp * MWarp * WG::kM; + constexpr index_t kNPerStep = NWarp * WG::kN; const auto block_origin = randval_dram_block_window_tmp.get_window_origin(); auto randval_dram_window = [&]() { @@ -384,85 +462,39 @@ struct BlockDropoutBwd } template - CK_TILE_HOST_DEVICE static constexpr auto MakeRandValLdsBlockDescriptor() - { - constexpr auto config = - BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); - using WG = remove_cvref_t())>; - constexpr index_t MWarp = config.template at<1>(); - constexpr index_t kMPerStep = MWarp * WG::kM; - constexpr index_t kNPerStep = WG::kN; - constexpr index_t kN1 = 8; - constexpr index_t kN0 = kNPerStep / kN1; - - constexpr auto randval_lds_block_desc_0 = make_naive_tensor_descriptor( - ck_tile::make_tuple(number{}, number{}, number{}), - ck_tile::make_tuple(number<(kMPerStep + 1) * kN1>{}, number{}, number<1>{}), - number{}, - number<1>{}); - - constexpr auto randval_lds_block_desc = transform_tensor_descriptor( - randval_lds_block_desc_0, - ck_tile::make_tuple( - make_pass_through_transform(number{}), - make_merge_transform(ck_tile::make_tuple(number{}, number{}))), - ck_tile::make_tuple(sequence<1>{}, sequence<0, 2>{}), - ck_tile::make_tuple(sequence<0>{}, sequence<1>{})); - - return randval_lds_block_desc; - } - - template CK_TILE_HOST_DEVICE static constexpr auto MakeRandValTileDistribution() { constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); - using BlockGemmShape = remove_cvref_t; - constexpr index_t kMPerBlock = BlockGemmShape::kM; - constexpr index_t MWarp = config.template at<1>(); - constexpr index_t NWarp = config.template at<2>(); - constexpr bool MBwdWG16MultiIterCheck = (!IsFwd) && (!IsWG32) && (kMPerBlock > 16); - - constexpr index_t MIterPerWarp = [&]() { - if constexpr(MBwdWG16MultiIterCheck) - { - return 2; - } - else - { - return 1; - } - }(); + using WG = remove_cvref_t())>; + constexpr bool IsWG32 = WG::kM == 32; + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + using BlockGemmShape = remove_cvref_t; + constexpr index_t kMPerBlock = BlockGemmShape::kM; + constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1; constexpr index_t NIterPerWarp = 1; constexpr auto randval_block_outer_part_dstr_encoding = tile_distribution_encoding< sequence<>, - tuple, sequence>, + tuple, sequence>, tuple>, - tuple>, + tuple>, sequence<1, 2>, - sequence<0, 0>>{}; + sequence<1, 0>>{}; - // Use Bwd WarpGemm to ensure that Fwd's random values ​​are consistent with Bwd. - // except headdim256. - constexpr auto randval_block_inner_part_dstr_encoding = []() { - if constexpr(std::is_same_v && - std::is_same_v && - std::is_same_v) - { - if constexpr(IsWG32) - return typename WarpGemmMfmaF16F16F32M32N32K16SwizzleA::CWarpDstrEncoding{}; - else - return typename WarpGemmMfmaF16F16F32M16N16K16::CWarpDstrEncoding{}; - } - else - { - if constexpr(IsWG32) - return typename WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA::CWarpDstrEncoding{}; - else - return typename WarpGemmMfmaBf16Bf16F32M16N16K16::CWarpDstrEncoding{}; - } - }(); + constexpr auto randval_block_inner_part_dstr_encoding = + typename WarpGemmDispatcher::CWarpDstrEncoding{}; + static_assert( + std::is_same_v, + typename WG::CWarpDstrEncoding>); constexpr auto randval_block_part_dstr_encode = detail::make_embed_tile_distribution_encoding(randval_block_outer_part_dstr_encoding, @@ -471,129 +503,6 @@ struct BlockDropoutBwd return make_static_tile_distribution(randval_block_part_dstr_encode); } - template - CK_TILE_HOST_DEVICE static constexpr auto MakeRandValLdsShuffleTileDistribution() - { - constexpr auto config = - BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); - using WG = remove_cvref_t())>; - constexpr index_t MWarp = config.template at<1>(); - constexpr index_t NWarp = config.template at<2>(); - - constexpr index_t MIterPerWarp = 1; - constexpr index_t NIterPerWarp = 1; - - constexpr auto randval_block_outer_part_dstr_encoding = tile_distribution_encoding< - sequence<>, - tuple, sequence>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 0>>{}; - - constexpr auto randval_block_part_dstr_encode = - detail::make_embed_tile_distribution_encoding(randval_block_outer_part_dstr_encoding, - typename WG::CWarpDstrEncoding{}); - - return make_static_tile_distribution(randval_block_part_dstr_encode); - } - - template - CK_TILE_HOST_DEVICE void Run(void* randval_ptr, - const index_t start_m0_idx, - const index_t start_n0_idx, - PComputeWindow& p_compute, - RandValDramWindow& randval_dram_window) const - { - constexpr auto config = - BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); - using WG = remove_cvref_t())>; - constexpr index_t MWarp = config.template at<1>(); - constexpr index_t NWarp = config.template at<2>(); - using BlockGemmShape = remove_cvref_t; - constexpr index_t kMPerBlock = BlockGemmShape::kM; - constexpr index_t kNPerBlock = BlockGemmShape::kN; - constexpr index_t kMPerStep = MWarp * WG::kM; - constexpr index_t kNPerStep = NWarp * WG::kN; - - // randval tile in LDS - auto randval_lds = make_tensor_view( - reinterpret_cast(randval_ptr), MakeRandValLdsBlockDescriptor()); - - auto randval_lds_window = make_tile_window( - randval_lds, MakeRandValLdsBlockDescriptor().get_lengths(), {0, 0}); - - // register distribute - auto randval_dist_generated = - make_static_distributed_tensor(MakeRandValTileDistribution()); - static_assert(randval_dist_generated.kThreadElementSpaceSize == 16); - - auto randval_lds_read_window = - make_tile_window(randval_lds_window.get_bottom_tensor_view(), - randval_lds_window.get_window_lengths(), - randval_lds_window.get_window_origin(), - MakeRandValLdsShuffleTileDistribution()); - - static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) { - static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) { - int block_row_start = (start_m0_idx / WG::kM) + (i_m0 * MWarp) + get_warp_id(); - int block_col_start = (start_n0_idx / WG::kN) + i_n0; - uint2 rowcol = make_uint2(block_row_start, block_col_start); - - // generate random number - uint8_t random_uint8_t[16]; - ph.get_random_16x8(random_uint8_t, reinterpret_cast(rowcol)); - - constexpr auto randval_dist_generated_spans = - decltype(randval_dist_generated)::get_distributed_spans(); - int i_random_idx = 0; - sweep_tile_span(randval_dist_generated_spans[number<0>{}], [&](auto idx0) { - sweep_tile_span(randval_dist_generated_spans[number<1>{}], [&](auto idx1) { - constexpr auto i_j_idx = ck_tile::make_tuple(idx0, idx1); - randval_dist_generated(i_j_idx) = random_uint8_t[i_random_idx++]; - }); - }); - // save to LDS - store_tile(randval_lds_window, randval_dist_generated); - block_sync_lds(); - // read from LDS to register - auto randval = load_tile(randval_lds_read_window); - constexpr auto randval_spans = decltype(randval)::get_distributed_spans(); - sweep_tile_span(randval_spans[number<0>{}], [&](auto idx0) { - sweep_tile_span(randval_spans[number<1>{}], [&](auto idx1) { - constexpr auto p_idx0 = tile_distributed_index{}; - constexpr auto p_idx1 = - tile_distributed_index{}; - constexpr auto p_idx = ck_tile::make_tuple(p_idx0, p_idx1); - constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1); - p_compute(p_idx) = randval[r_idx] <= p_undrop_in_uint8_t - ? p_compute[p_idx] * rp_undrop - : PComputeDataType(0); - }); - }); - // save to Global - if constexpr(IsStoreRandval) - { - const auto randval_store = cast_tile(randval); - store_tile(randval_dram_window, randval_store); - move_tile_window(randval_dram_window, {0, kNPerStep}); - } - }); - if constexpr(IsStoreRandval) - { - move_tile_window(randval_dram_window, {kMPerStep, -kNPerBlock}); - } - }); - if constexpr(IsStoreRandval) - { - move_tile_window(randval_dram_window, {-kMPerBlock, kNPerBlock}); - } - } - template { constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); - using WG = remove_cvref_t())>; - constexpr index_t MWarp = config.template at<1>(); - constexpr index_t NWarp = config.template at<2>(); - using BlockGemmShape = remove_cvref_t; - constexpr index_t kMPerBlock = BlockGemmShape::kM; - constexpr index_t kNPerBlock = BlockGemmShape::kN; - constexpr bool MBwdWG16MultiIterCheck = (!IsWG32) && (kMPerBlock > 16); - constexpr bool MBwdWG16SingleIterCheck = (!IsWG32) && (kMPerBlock == 16); - constexpr index_t kMPerStep = [&]() { - if constexpr(MBwdWG16MultiIterCheck) + using WG = remove_cvref_t())>; + constexpr bool IsWG32 = WG::kM == 32; + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + using BlockGemmShape = remove_cvref_t; + constexpr index_t kMPerBlock = BlockGemmShape::kM; + constexpr index_t kNPerBlock = BlockGemmShape::kN; + constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1; + constexpr index_t kMPerStep = MIterPerWarp * MWarp * WG::kM; + constexpr index_t kNPerStep = NWarp * WG::kN; + + // register distribute + auto randval_dist_generated = + make_static_distributed_tensor(MakeRandValTileDistribution()); + + const index_t iMWarp = get_warp_id() / NWarp; + const index_t iNWarp = get_warp_id() % NWarp; + + auto generate_randval = [&](auto i_m0, auto i_n0) { + // Generate random numbers + uint8_t random_uint8_t[randval_dist_generated.kThreadElementSpaceSize]; + const index_t wg_m0 = (start_m0_idx / WG::kM) + (i_m0 * MWarp + iMWarp) * MIterPerWarp; + const index_t wg_n0 = (start_n0_idx / WG::kN) + (i_n0 * NWarp + iNWarp); + if constexpr(IsWG32) { - return MWarp * WG::kM * 2; + // Generate the whole 32x32 tile at once (each tile consists of random numbers + // taken from a separate subsequence of Philox) + const unsigned long long ph_subsequence = + bit_cast(make_uint2(wg_m0, wg_n0)); + const index_t ph_offset = get_lane_id(); + const ck_tile::philox ph(ph_seed, ph_head_offset + ph_offset); + static_assert(randval_dist_generated.kThreadElementSpaceSize == 16); + ph.get_random_16x8(random_uint8_t, ph_subsequence); } else { - return MWarp * WG::kM; + // Generate one or two 16x16 subtiles of the 32x32 tile (depending on whether + // MIterPerWarp is equal to 1 or 2) + const unsigned long long ph_subsequence = + bit_cast(make_uint2(wg_m0 / 2, wg_n0 / 2)); + const index_t subtile_m0 = wg_m0 % 2; + if constexpr(get_warp_size() == 32) + { + const index_t ph_offset = (get_lane_id() & 15) + + (((get_lane_id() >> 4) & 1) << 5) + + ((wg_n0 % 2) << 4); + const ck_tile::philox ph(ph_seed, ph_head_offset + ph_offset); + if constexpr(MIterPerWarp == 1) + { + static_assert(randval_dist_generated.kThreadElementSpaceSize == 8); + ph.get_random_8x8( + random_uint8_t, ph_subsequence, subtile_m0 * 2 + 0, subtile_m0 * 2 + 1); + } + else + { + static_assert(randval_dist_generated.kThreadElementSpaceSize == 16); + ph.get_random_16x8(random_uint8_t, ph_subsequence); + } + } + else + { + const index_t subtile_n0 = (get_lane_id() >> 4) & 1; + const index_t ph_offset = (get_lane_id() & 47) + ((wg_n0 % 2) << 4); + const ck_tile::philox ph(ph_seed, ph_head_offset + ph_offset); + if constexpr(MIterPerWarp == 1) + { + static_assert(randval_dist_generated.kThreadElementSpaceSize == 4); + ph.get_random_4x8( + random_uint8_t, ph_subsequence, subtile_m0 * 2 + subtile_n0); + } + else + { + static_assert(randval_dist_generated.kThreadElementSpaceSize == 8); + ph.get_random_8x8( + random_uint8_t, ph_subsequence, 0 * 2 + subtile_n0, 1 * 2 + subtile_n0); + } + } } - }(); - constexpr index_t kNPerStep = NWarp * WG::kN; - // register distribute - auto randval = make_static_distributed_tensor( - MakeRandValTileDistribution()); - if constexpr(IsWG32) - static_assert(randval.kThreadElementSpaceSize == 16); - else - static_assert(randval.kThreadElementSpaceSize == 4 || - randval.kThreadElementSpaceSize == 8); + constexpr auto randval_dist_generated_spans = + decltype(randval_dist_generated)::get_distributed_spans(); + int i_random_idx = 0; + sweep_tile_span(randval_dist_generated_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(randval_dist_generated_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = ck_tile::make_tuple(idx0, idx1); + randval_dist_generated(i_j_idx) = random_uint8_t[i_random_idx++]; + }); + }); + return randval_dist_generated; + }; static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) { static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) { - int block_row_start, block_col_start; - if constexpr(IsWG32) - { - block_row_start = (start_m0_idx / WG::kM) + i_m0; - block_col_start = (start_n0_idx / WG::kN) + (i_n0 * NWarp) + get_warp_id(); - } - else - { - block_row_start = start_m0_idx / 32 + i_m0; - block_col_start = (start_n0_idx / 32) + get_warp_id() / 2 + i_n0 * 2; - } - uint2 rowcol = make_uint2(block_row_start, block_col_start); - - // generate random number - uint8_t* random_uint8_t_; - if constexpr(MBwdWG16SingleIterCheck) - { - uint8_t random_uint8_t[4]; - // m0t0 ~m0t15/m0t32~m0t47: 0 - // m0t16~m0t31/m0t48~m0t63: 1 - // m1t0 ~m1t15/m1t32~m1t47: 2 - // m1t16~m1t31/m1t48~m1t63: 3 - const index_t start_idx = - ((get_lane_id() >> 4) & 1) + (((start_m0_idx >> 4) & 1) << 1); - ph.get_random_4x8( - random_uint8_t, reinterpret_cast(rowcol), start_idx); - random_uint8_t_ = random_uint8_t; - } - else if constexpr(MBwdWG16MultiIterCheck) - { - uint8_t random_uint8_t[8]; - // t0 ~t15/t32~t47: 0 - // t16~t31/t48~t63: 1 - const index_t start_idx = (get_lane_id() >> 4) & 1; - ph.get_random_8x8( - random_uint8_t, reinterpret_cast(rowcol), start_idx); - random_uint8_t_ = random_uint8_t; - } - else - { - uint8_t random_uint8_t[16]; - ph.get_random_16x8(random_uint8_t, - reinterpret_cast(rowcol)); - random_uint8_t_ = random_uint8_t; - } - + const auto randval = generate_randval(i_m0, i_n0); + // Drop values of P based on the generated probabilities, negative sign is used to + // distinguish such values ​​later in bwd pipeline. constexpr auto randval_spans = decltype(randval)::get_distributed_spans(); - int i_random_idx = 0; sweep_tile_span(randval_spans[number<0>{}], [&](auto idx0) { sweep_tile_span(randval_spans[number<1>{}], [&](auto idx1) { - constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1); - randval(r_idx) = random_uint8_t_[i_random_idx++]; - constexpr auto p_idx0 = tile_distributed_index{}; + constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1); + constexpr auto p_idx0 = + tile_distributed_index(), + idx0.impl_.template at<1>(), + idx0.impl_.template at<2>()>{}; constexpr auto p_idx1 = tile_distributed_index{}; constexpr auto p_idx = ck_tile::make_tuple(p_idx0, p_idx1); p_compute(p_idx) = randval[r_idx] <= p_undrop_in_uint8_t @@ -717,7 +645,8 @@ struct BlockDropoutBwd } } - ck_tile::philox ph; + const unsigned long long ph_seed; + const unsigned long long ph_head_offset; const float rp_undrop; const uint8_t p_undrop_in_uint8_t; }; diff --git a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp index b2b00a07e4..980dfb06ae 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp @@ -82,6 +82,7 @@ struct FmhaBwdDQDKDVKernel // clang-format off template struct t2s; + template <> struct t2s { static constexpr const char * name = "fp32"; }; template <> struct t2s { static constexpr const char * name = "fp16"; }; template <> struct t2s { static constexpr const char * name = "bf16"; }; // clang-format on @@ -1187,6 +1188,7 @@ struct FmhaBwdOGradDotOKernel // clang-format off template struct t2s; + template <> struct t2s { static constexpr const char * name = "fp32"; }; template <> struct t2s { static constexpr const char * name = "fp16"; }; template <> struct t2s { static constexpr const char * name = "bf16"; }; // clang-format on @@ -1443,6 +1445,7 @@ struct FmhaBwdConvertQGradKernel // clang-format off template struct t2s; + template <> struct t2s { static constexpr const char * name = "fp32"; }; template <> struct t2s { static constexpr const char * name = "fp16"; }; template <> struct t2s { static constexpr const char * name = "bf16"; }; // clang-format on diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs_default_policy.hpp index 9c348495ff..f7ee88f906 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs_default_policy.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -32,12 +32,27 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVSDefaultPolicy constexpr auto warp_gemm = []() { constexpr index_t WarpGemmM = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}); - static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32); - if constexpr(std::is_same_v && - std::is_same_v && + if constexpr(std::is_same_v && + std::is_same_v && std::is_same_v) { + static_assert(WarpGemmM == 16); + + return WarpGemmDispatcher{}), + Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}), + Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}), + true>{}; + } + else if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32); + if constexpr(WarpGemmM == 32) return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{}; else if constexpr(WarpGemmM == 16) @@ -49,6 +64,8 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVSDefaultPolicy std::is_same_v && std::is_same_v) { + static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32); + if constexpr(WarpGemmM == 32) return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{}; else if constexpr(WarpGemmM == 16) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp index 67ab548dab..050eb48384 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -264,12 +264,27 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy constexpr auto warp_gemm = []() { constexpr index_t WarpGemmM = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}); - static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32); - if constexpr(std::is_same_v && - std::is_same_v && + if constexpr(std::is_same_v && + std::is_same_v && std::is_same_v) { + static_assert(WarpGemmM == 16); + + return WarpGemmDispatcher{}), + Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}), + Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}), + true>{}; + } + else if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32); + if constexpr(WarpGemmM == 32) return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{}; else if constexpr(WarpGemmM == 16) @@ -281,6 +296,8 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy std::is_same_v && std::is_same_v) { + static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32); + if constexpr(WarpGemmM == 32) return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{}; else if constexpr(WarpGemmM == 16) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp index dccb41ba44..9dba3c85d5 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp @@ -73,12 +73,27 @@ struct BlockFmhaPipelineQXCustomPolicy constexpr auto warp_gemm = []() { constexpr index_t WarpGemmM = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}); - static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32); - if constexpr(std::is_same_v && - std::is_same_v && + if constexpr(std::is_same_v && + std::is_same_v && std::is_same_v) { + static_assert(WarpGemmM == 16); + + return WarpGemmDispatcher{}), + Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}), + Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}), + true>{}; + } + else if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32); + if constexpr(WarpGemmM == 32) return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{}; else if constexpr(WarpGemmM == 16) @@ -90,6 +105,8 @@ struct BlockFmhaPipelineQXCustomPolicy std::is_same_v && std::is_same_v) { + static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32); + if constexpr(WarpGemmM == 32) return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{}; else if constexpr(WarpGemmM == 16) @@ -201,7 +218,7 @@ struct BlockFmhaPipelineQXCustomPolicy constexpr auto q_lds_block_desc_0 = make_naive_tensor_descriptor( make_tuple(number{}, number{}, number{}), make_tuple(number<(kMPerBlock + 1) * kKPack>{}, number{}, number<1>{}), - number<8>{}, + number{}, number<1>{}); constexpr auto q_lds_block_desc = transform_tensor_descriptor( @@ -228,14 +245,29 @@ struct BlockFmhaPipelineQXCustomPolicy typename Problem::BlockFmhaShape::Gemm0BlockWarps, typename Problem::BlockFmhaShape::Gemm0WarpTile>>; - constexpr index_t WarpGemmM = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}); - static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32); - constexpr auto warp_gemm = []() { - if constexpr(std::is_same_v && - std::is_same_v && + constexpr index_t WarpGemmM = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}); + + if constexpr(std::is_same_v && + std::is_same_v && std::is_same_v) { + static_assert(WarpGemmM == 16); + + return WarpGemmDispatcher{}), + Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}), + Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}), + true>{}; + } + else if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32); + if constexpr(WarpGemmM == 32) return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{}; else if constexpr(WarpGemmM == 16) @@ -247,6 +279,8 @@ struct BlockFmhaPipelineQXCustomPolicy std::is_same_v && std::is_same_v) { + static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32); + if constexpr(WarpGemmM == 32) return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{}; else if constexpr(WarpGemmM == 16) @@ -258,6 +292,8 @@ struct BlockFmhaPipelineQXCustomPolicy std::is_same_v && std::is_same_v) { + static_assert(WarpGemmM == 32); + // TODO: hard coded here. Otherwise, it may incorrect result constexpr index_t swizzle_factor = 4; return WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution< @@ -507,7 +543,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy{}, number{}, number{}), make_tuple(number<(kNPerBlock + 1) * kKPack>{}, number{}, number<1>{}), - number<8>{}, + number{}, number<1>{}); constexpr auto k_lds_block_desc = transform_tensor_descriptor( @@ -806,15 +842,14 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy, @@ -824,7 +859,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy, // N0 K2 N2 sequence<0, 2, 2>>{}); } - else if constexpr(get_warp_size() % (kKPack / K3 * N0) == 0) + else if constexpr(get_warp_size() % (K2 * N0) == 0) { constexpr index_t K1 = get_warp_size() / (K2 * N0); constexpr index_t K0 = kBlockSize / get_warp_size(); @@ -863,13 +898,40 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy, tuple, sequence>, - tuple, sequence<1, 2>>, + tuple, sequence<1, 2>>, // N1, N2 K0 tuple, sequence<2, 0>>, - sequence<1, 2>, + sequence<1, 2>, // N0 K1 sequence<0, 1>>{}); + if constexpr(container_reduce(dstr.get_lengths(), std::multiplies{}, 1) == + kNPerBlock * kKPerBlock) + { + return dstr; + } + else + { + static_assert(kKPerBlock % 16 == 0); + constexpr index_t kKPerIter = kKPerBlock % 32 == 0 ? 32 : 16; + constexpr index_t K0_m = kKPerBlock / kKPerIter; + constexpr index_t K2 = 2; + constexpr index_t K1_m = kKPerIter / K2; + constexpr index_t N2_m = get_warp_size() / K1_m; + constexpr index_t N0_m = kNPerBlock / (N2_m * N1); + constexpr auto dstr_m = make_static_tile_distribution( + tile_distribution_encoding< + sequence<1>, + tuple, sequence>, + tuple, sequence<1, 2>>, // N1, N2 K1 + tuple, sequence<2, 1>>, + sequence<2, 1, 2>, // K0 N0 K2 + sequence<0, 0, 2>>{}); + static_assert(container_reduce(dstr_m.get_lengths(), + std::multiplies{}, + 1) == kNPerBlock * kKPerBlock); + return dstr_m; + } } } @@ -897,14 +959,14 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy, tuple, sequence>, @@ -913,7 +975,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy, // N0 K2 <-> N2 sequence<0, 2, 2>>{}); } - else if constexpr(get_warp_size() % (kKPack / K3 * N0) == 0) + else if constexpr(get_warp_size() % (K2 * N0) == 0) { constexpr index_t K1 = get_warp_size() / (K2 * N0); constexpr index_t K0 = kBlockSize / get_warp_size(); diff --git a/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp b/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp index 41a744ea91..ca82519e72 100644 --- a/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp +++ b/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp @@ -7,20 +7,22 @@ namespace ck_tile { -static CK_TILE_HOST_DEVICE constexpr index_t ceil_to_qualified_tile_length(index_t len) +template +static CK_TILE_HOST_DEVICE constexpr index_t ceil_to_qualified_tile_length() { - if(len == 96) + if constexpr(Headdim == 48) + return 48; + else if constexpr(Headdim == 96) return 128; - if(len == 160) + else if constexpr(Headdim == 160) return 256; - if(len == 192) + else if constexpr(Headdim == 192) return 192; - - // only length of 96, 160 and power-of-two is supported - if(!(len & (len - 1))) - return len; - - return 0; + else if constexpr(is_power_of_two_integer(Headdim)) + return Headdim; + else + static_assert(Headdim == 0, + "only Headdim of 48, 96, 160, 192 and power-of-two is supported"); }; template (); // v, rowmajor : seqlen*hdim, colmajor : hdim*seqlen static constexpr bool IsVLayoutRowMajor = IsVLayoutRowMajor_; diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp index f83bbc2a18..21f21e1aa0 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp @@ -12,6 +12,24 @@ namespace ck_tile { +// fp32 + +using WarpGemmMfmaF32F32F32M16N16K4 = WarpGemmImpl< + WarpGemmAttributeMfma>>; + +template +using WarpGemmMfmaF32F32F32M16N16K16 = WarpGemmImpl, + 4, + AttrNumAccess>>; + +template +using WarpGemmMfmaF32F32F32M16N16K16TransposedCDistribution = + WarpGemmImpl, + 4, + AttrNumAccess>>; + // fp16 using WarpGemmMfmaF16F16F32M32N32K8 = WarpGemmImpl< diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp index 11a8416fb2..7528760439 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -61,6 +61,135 @@ enum class WGAttrCtlEnum DISPATCH_MFMA_(mfma_, "+a", "v", "v", "a") \ } +// F32 +template +struct WarpGemmAttributeMfmaImplF32F32F32M16N16K4 +{ + static constexpr WGAttrCtlEnum Ctrl = Ctrl_; + + using ADataType = float; + using BDataType = float; + using CDataType = float; + + using AVecType = ext_vector_t; + using BVecType = ext_vector_t; + using CVecType = ext_vector_t; + + static constexpr index_t kM = 16; + static constexpr index_t kN = 16; + static constexpr index_t kK = 4; + + static constexpr index_t kAMBlock = 1; + static constexpr index_t kBNBlock = 1; + + static constexpr index_t kAMLane = 16; + static constexpr index_t kBNLane = 16; + static constexpr index_t kABKLane = 4; + static constexpr index_t kABKPerLane = 1; + + static constexpr index_t kCMLane = 4; + static constexpr index_t kCNLane = 16; + static constexpr index_t kCM0PerLane = 1; + static constexpr index_t kCM1PerLane = 4; + + // c_vec += a_vec * b_vec + template + CK_TILE_DEVICE void operator()(CVecType& c_vec, + const AVecType& a_vec, + const BVecType& b_vec, + bool_constant = {}) const + { + DISPATCH_MFMA_CTRL_("v_mfma_f32_16x16x4f32", Ctrl) + else + { +#if defined(__gfx9__) + c_vec = __builtin_amdgcn_mfma_f32_16x16x4f32(a_vec[0], b_vec[0], c_vec, 0, 0, 0); +#else + ck_tile::ignore = c_vec; + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; +#endif + } + } + + // c_vec = a_vec * b_vec + CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const + { +#if defined(__gfx9__) + return bit_cast( + __builtin_amdgcn_mfma_f32_16x16x4f32(a_vec[0], b_vec[0], CVecType{0.f}, 0, 0, 0)); +#else + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; + return CVecType{0.f}; +#endif + } +}; + +template +struct WarpGemmAttributeMfmaImplF32F32F32M32N32K2 +{ + static constexpr WGAttrCtlEnum Ctrl = Ctrl_; + + using ADataType = float; + using BDataType = float; + using CDataType = float; + + using AVecType = ext_vector_t; + using BVecType = ext_vector_t; + using CVecType = ext_vector_t; + + static constexpr index_t kM = 32; + static constexpr index_t kN = 32; + static constexpr index_t kK = 2; + + static constexpr index_t kAMBlock = 1; + static constexpr index_t kBNBlock = 1; + + static constexpr index_t kAMLane = 32; + static constexpr index_t kBNLane = 32; + static constexpr index_t kABKLane = 2; + static constexpr index_t kABKPerLane = 1; + + static constexpr index_t kCMLane = 2; + static constexpr index_t kCNLane = 32; + static constexpr index_t kCM0PerLane = 4; + static constexpr index_t kCM1PerLane = 4; + + // c_vec += a_vec * b_vec + template + CK_TILE_DEVICE void operator()(CVecType& c_vec, + const AVecType& a_vec, + const BVecType& b_vec, + bool_constant = {}) const + { + DISPATCH_MFMA_CTRL_("v_mfma_f32_32x32x2f32", Ctrl) + else + { +#if defined(__gfx9__) + c_vec = __builtin_amdgcn_mfma_f32_32x32x2f32(a_vec[0], b_vec[0], c_vec, 0, 0, 0); +#else + ck_tile::ignore = c_vec; + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; +#endif + } + } + + // c_vec = a_vec * b_vec + CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const + { +#if defined(__gfx9__) + return bit_cast( + __builtin_amdgcn_mfma_f32_32x32x2f32(a_vec[0], b_vec[0], CVecType{0.f}, 0, 0, 0)); +#else + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; + return CVecType{0.f}; +#endif + } +}; + // V_MFMA_F32_16x16x32_BF16 template struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K32 diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp index 5eedd42b04..924f7c4a54 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp @@ -23,6 +23,11 @@ template struct WarpGemmDispatcher { using Type = WarpGemmMfmaF32F32F32M16N16K4; }; +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfmaF32F32F32M16N16K16<>; }; +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfmaF32F32F32M16N16K16TransposedCDistribution<>; }; // fp16 // ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity template<> struct WarpGemmDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K8; }; diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index df3a03cca8..292bc41a0b 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -38,8 +38,10 @@ set(REGRESSION_TESTS test_conv_tensor_rearrange test_gemm_mx test_ck_tile_batched_transpose + test_ck_tile_fmha_bwd_fp32 test_ck_tile_fmha_bwd_bf16 test_ck_tile_fmha_bwd_fp16 + test_ck_tile_fmha_fwd_fp32 test_ck_tile_fmha_fwd_bf16 test_ck_tile_fmha_fwd_fp16 test_ck_tile_fmha_fwd_fp8 diff --git a/test/ck_tile/fmha/CMakeLists.txt b/test/ck_tile/fmha/CMakeLists.txt index b17d682560..8e5cce4c0b 100644 --- a/test/ck_tile/fmha/CMakeLists.txt +++ b/test/ck_tile/fmha/CMakeLists.txt @@ -6,12 +6,18 @@ endif() set(FMHA_BWD_INSTANCES "tile_fmha_bwd_instances") set(FMHA_FWD_INSTANCES "tile_fmha_fwd_instances") +add_gtest_executable(test_ck_tile_fmha_bwd_fp32 test_fmha_bwd_fp32.cpp) +target_link_libraries(test_ck_tile_fmha_bwd_fp32 PRIVATE ${FMHA_BWD_INSTANCES}) + add_gtest_executable(test_ck_tile_fmha_bwd_bf16 test_fmha_bwd_bf16.cpp) target_link_libraries(test_ck_tile_fmha_bwd_bf16 PRIVATE ${FMHA_BWD_INSTANCES}) add_gtest_executable(test_ck_tile_fmha_bwd_fp16 test_fmha_bwd_fp16.cpp) target_link_libraries(test_ck_tile_fmha_bwd_fp16 PRIVATE ${FMHA_BWD_INSTANCES}) +add_gtest_executable(test_ck_tile_fmha_fwd_fp32 test_fmha_fwd_fp32.cpp) +target_link_libraries(test_ck_tile_fmha_fwd_fp32 PRIVATE ${FMHA_FWD_INSTANCES}) + add_gtest_executable(test_ck_tile_fmha_fwd_bf16 test_fmha_fwd_bf16.cpp) target_link_libraries(test_ck_tile_fmha_fwd_bf16 PRIVATE ${FMHA_FWD_INSTANCES}) @@ -23,8 +29,10 @@ target_link_libraries(test_ck_tile_fmha_fwd_fp8 PRIVATE ${FMHA_FWD_INSTANCES}) add_custom_target(test_ck_tile_fmha DEPENDS + test_ck_tile_fmha_bwd_fp32 test_ck_tile_fmha_bwd_bf16 test_ck_tile_fmha_bwd_fp16 + test_ck_tile_fmha_fwd_fp32 test_ck_tile_fmha_fwd_bf16 test_ck_tile_fmha_fwd_fp16 test_ck_tile_fmha_fwd_fp8 diff --git a/test/ck_tile/fmha/test_fmha_bwd_fp32.cpp b/test/ck_tile/fmha/test_fmha_bwd_fp32.cpp new file mode 100644 index 0000000000..d409d0dd30 --- /dev/null +++ b/test/ck_tile/fmha/test_fmha_bwd_fp32.cpp @@ -0,0 +1,20 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "example/ck_tile/01_fmha/fmha_bwd.hpp" +#include "example/ck_tile/01_fmha/fmha_bwd_runner.hpp" + +#include "gtest/gtest.h" + +using DataTypeConfig = FmhaBwdFp32; + +using ::testing::Values; +using ::testing::ValuesIn; + +const auto HDimValues = Values(std::tuple{32, -1}, std::tuple{64, -1}, std::tuple{128, -1}); + +const auto ModeValues = Values(mode_enum::batch, mode_enum::group); + +constexpr std::string init_method = "uf"; + +#include "test_fmha_bwd.inc" diff --git a/test/ck_tile/fmha/test_fmha_fwd.inc b/test/ck_tile/fmha/test_fmha_fwd.inc index 9497122594..ccca5cf969 100644 --- a/test/ck_tile/fmha/test_fmha_fwd.inc +++ b/test/ck_tile/fmha/test_fmha_fwd.inc @@ -515,6 +515,8 @@ class PagedKV : public TestWithParam, { }; +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(PagedKV); + INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaFwd, PagedKV, Combine(SplitKVHDimValues, @@ -580,6 +582,8 @@ class SplitKV : public TestWithParam, { }; +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(SplitKV); + INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaFwd, SplitKV, Combine(SplitKVHDimValues, @@ -662,6 +666,8 @@ INSTANTIATE_TEST_SUITE_P( std::tuple{2, 3, 1, 264, 265, "1"}, std::tuple{4, 4, 2, 71, 64, "1"}))); +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(AppendKV); + TEST_P(AppendKV, Test) { auto [hdims, diff --git a/test/ck_tile/fmha/test_fmha_fwd_fp32.cpp b/test/ck_tile/fmha/test_fmha_fwd_fp32.cpp new file mode 100644 index 0000000000..00f1eb0629 --- /dev/null +++ b/test/ck_tile/fmha/test_fmha_fwd_fp32.cpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "example/ck_tile/01_fmha/fmha_fwd.hpp" +#include "example/ck_tile/01_fmha/fmha_fwd_runner.hpp" + +#include "gtest/gtest.h" + +#include +#include + +using ::testing::Values; + +using DataTypeConfig = FmhaFwdFp32; + +const auto HDimValues = Values(std::tuple{32, -1}, + std::tuple{48, -1}, + std::tuple{64, -1}, + std::tuple{96, 128}, + std::tuple{128, -1}, + std::tuple{192, -1}, + std::tuple{256, -1}); + +const auto SplitKVHDimValues = Values(); + +const auto AppendKVHDimValues = Values(); + +const auto ModeValues = Values(mode_enum::batch, mode_enum::group); + +const auto IsVRowmajorValues = Values(true); + +const bool squant = false; +const std::string init_method = "uf"; +const bool def_lse = true; +const bool def_is_v_rowmajor = true; + +int adjust_seqlen(int seqlen) { return seqlen; } + +#include "test_fmha_fwd.inc"